├── .gitignore
├── README.md
├── best_weights.h5
├── demo.py
├── environment.yml
├── images
│   ├── messidor_test.tif
│   ├── messidor_test_mask.tif
│   ├── messidor_test_prediction.png
│   └── unet_fod.png
├── predict_od_fov.py
├── requirements.txt
├── script-train-messidor.py
└── util
    ├── __init__.py
    ├── data_generator0218.py
    ├── od_coords.py
    ├── unet_triclass_whole_image.py
    └── util.py

/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

# Specific folders
results/

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Joint Retinal Optical Disc and Fovea Detection

Here you can find the implementation of a new strategy for simultaneously locating the optic disc (OD) and the
fovea in eye fundus images. This method was presented at MICCAI 2018 in Granada (oral presentation). If you find this code useful
for your research, please consider citing our paper:

> Meyer M.I., Galdran A., Mendonça A.M., Campilho A. A Pixel-Wise Distance Regression Approach for Joint Retinal Optical Disc and Fovea Detection. In: Medical Image Computing and Computer Assisted Intervention – MICCAI 2018, LNCS, vol 11071, pp 39-47, 2018.
doi:10.1007/978-3-030-00934-2_5

### Update (April 2020)
I have added an **environment.yml** file with updated requirements so that the demo works as expected. Note that the training and prediction scripts were not tested with the updated libraries, since I have stopped working on this project. Please take this into account when running the code!

Introduction
------------
In contrast with previous techniques, the proposed method does not attempt to directly detect only the OD and fovea centers.
Instead, the distance to both locations is regressed for every pixel in a retinal image.
This regression problem can be solved by means of a Fully Convolutional Neural Network.
This strategy poses a multi-task-like problem, in which the information of every pixel contributes to generating a globally
consistent prediction map where the likelihood of the OD and fovea locations is maximized.

In particular, we use a U-Net architecture together with a loss function suitable for pixel-wise
distance regression (L2 loss).

![](images/unet_fod.png)

Installation
------------
To install all the requirements to run this code, first clone this repository to your local machine:
```
git clone https://github.com/minesmeyer/od-fovea-regression.git
```

If you have an Anaconda installation, you can use the `conda` package manager as follows:
```
conda env create -f environment.yml
```
This creates a `conda` environment named `od-fovea` (the `name` field in `environment.yml`). To use a different name, add `-n env_name` to the command. Activate the environment with `conda activate od-fovea` (or your chosen name) before running the scripts.

Training
--------

The method was trained and validated on the Messidor dataset. If you wish to replicate the training, the first step is
to exclude the Messidor images that do not contain OD and fovea location information, as provided by [1]
(http://www.uhu.es/retinopathy/eng/bd.php).
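The exclusion step, together with the two-fold split described in the next paragraph, can be sketched in plain Python. This is a minimal, hypothetical sketch. The record fields (`od_x`, `od_y`, `fovea_x`, `fovea_y`), the seed, and the helper names are assumptions. They do not come from this repository or from the annotation files in [1]. The 455/113 train/validation sizes are taken from `nsamples` and `nsamples_val` in `script-train-messidor.py`, which appear to divide each half of 568 images.

```python
import random

# Hypothetical field names for the OD and fovea annotations of one image.
COORD_KEYS = ("od_x", "od_y", "fovea_x", "fovea_y")


def filter_annotated(records):
    """Keep only the records that have both OD and fovea coordinates."""
    return [r for r in records if all(r.get(k) is not None for k in COORD_KEYS)]


def two_fold_split(names, seed=0):
    """Shuffle the image names with a fixed seed and cut them into two halves."""
    names = sorted(names)               # deterministic starting order (a copy)
    random.Random(seed).shuffle(names)  # reproducible shuffle
    mid = len(names) // 2
    return names[:mid], names[mid:]


def train_val_split(half, n_val):
    """Hold out the last n_val images of one half for validation."""
    n_train = len(half) - n_val
    return half[:n_train], half[n_train:]


if __name__ == "__main__":
    # 1136 placeholder names standing in for the annotated Messidor images
    names = ["img_%04d.tif" % i for i in range(1136)]
    half_1, half_2 = two_fold_split(names, seed=42)
    train_1, val_1 = train_val_split(half_1, n_val=113)
    print(len(half_1), len(half_2), len(train_1), len(val_1))  # 568 568 455 113
```

Fixing the seed keeps *half_1* and *half_2* identical across runs, which matters when the two splits are trained separately.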

Split the remaining 1136 images into two halves (*half_1* and *half_2*) and train your model on two different splits:
* **Split 1** uses *half_1* for training and *half_2* for testing.
* **Split 2** uses *half_2* for training and *half_1* for testing.

To start training, run `script-train-messidor.py`. The split being used and the directory where the model's weights are saved are set by editing `data_dir` and `path_model` at the top of the script.

Evaluating
----------

To obtain the predictions on Messidor, run `predict_od_fov.py`. The split being used and the location of the model's weights are set through the `data_dir`, `masks_dir`, `model_dir` and `weights_file` variables at the top of the script.


Predicting OD and Fovea Location on a Single Image
--------------------------------------------------

The pre-trained weights of the model are in the file `best_weights.h5`. Please be aware that this model was trained on
a particular random split of the data, as described above. If you use it on a Messidor subset without re-training, some of
the images may belong to the original training set. The example image in the `images/` folder is from the test set of the
particular split this model was trained on.

You can run the `demo.py` script, which takes a retinal image and returns the locations of the OD and fovea:
```
python demo.py --img_dir images/messidor_test.tif --mask_dir images/messidor_test_mask.tif
```
This script saves an image to a `results/` folder, which is created automatically if it does not exist.

##### NOTE:
When possible, input images should be cropped around the FOV, or a mask of the FOV should be provided as shown above.
If no FOV mask is provided, this implementation will still return optic disc and fovea locations, although results may
be suboptimal.
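The FOV handling in `demo.py` has two parts. `crop_image` cuts the image to the bounding box of the mask, and `get_original_coords` adds the box offset back to a landmark. The NumPy sketch below illustrates both steps. The helper `odc.get_new_peaks` is not shown in this repository dump, so the rescaling from the 512x512 prediction grid back to the crop size is an assumption. This sketch scales rows and columns independently, matching the per-axis resize that `transform.resize` applies to the crop. The function names are hypothetical.

```python
import numpy as np


def fov_bbox(mask):
    """Return (up, bot, left, right): extreme row/column indices of the nonzero FOV mask."""
    rows, cols = np.nonzero(mask)[:2]
    return rows.min(), rows.max(), cols.min(), cols.max()


def crop_to_fov(img, mask):
    """Crop like crop_image() in demo.py (note the half-open slices)."""
    up, bot, left, right = fov_bbox(mask)
    return img[up:bot, left:right]


def to_original_coords(coords, crop_shape, mask, pred_size=512):
    """Map a (row, col) peak on the pred_size x pred_size map back to the full image.

    Assumption: the crop was resized to pred_size x pred_size with rows and
    columns scaled independently, so the inverse is a per-axis rescale
    followed by the bounding-box offset used in get_original_coords().
    """
    up, _, left, _ = fov_bbox(mask)
    row = coords[0] * crop_shape[0] / pred_size
    col = coords[1] * crop_shape[1] / pred_size
    return float(row + up), float(col + left)


if __name__ == "__main__":
    mask = np.zeros((1000, 1500))
    mask[100:900, 200:1300] = 1                # synthetic rectangular FOV
    img = np.random.rand(1000, 1500, 3)
    crop = crop_to_fov(img, mask)
    print(crop.shape)                                            # (799, 1099, 3)
    print(to_original_coords((256, 256), crop.shape[:2], mask))  # (499.5, 749.5)
```

Because the crop removes the dark border before resizing, the 512x512 network input uses more of its resolution for the retina itself. This is one reason a FOV mask tends to help.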

Alternatively, you can run `demo.py` with the flag `--estimate_fov`, which performs a FOV
segmentation based on simple thresholding:
```
python demo.py --img_dir images/messidor_test.tif --estimate_fov
```

![](images/messidor_test_prediction.png)

----------------------------------

### References
1. Gegundez-Arias, M.E., Marin, D., Bravo, J.M., Suero, A.: Locating the fovea
center position in digital fundus images using thresholding and feature extraction
techniques. Computerized Medical Imaging and Graphics 37, 386–393 (2013)


--------------------------------------------------------------------------------
/best_weights.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/minesmeyer/od-fovea-regression/8d1aec38c98d8e95a79b0a6d143e3a8796fe53e0/best_weights.h5

--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
#!/usr/bin/python
""" Example run of the OD and fovea detector. """

import numpy as np
import matplotlib.pyplot as plt
from skimage import io, img_as_float, transform

from skimage import filters, color
from skimage.measure import regionprops, label
from skimage.morphology import erosion, dilation, disk


from util.unet_triclass_whole_image import unet
import util.od_coords as odc
import util.util as ut

import argparse

parser = argparse.ArgumentParser(description='OD Fovea detection')

parser.add_argument('-i',
                    '--img_dir',
                    help='Path to the input image',
                    type=str,
                    default='images/messidor_test.tif')

# No default mask: a mask only matches the image it was drawn for
parser.add_argument('-m',
                    '--mask_dir',
                    help='Path to the FOV mask (optional)',
                    type=str,
                    default=None)

parser.add_argument('-e',
                    '--estimate_fov',
                    action='store_true',
                    help='Estimate the FOV mask by simple thresholding')

# Define auxiliary functions
def crop_image(img, mask):
    """ Crops the image to the edges of a given FOV mask. """
    b = np.nonzero(mask)
    up, bot = np.amin(b[0]), np.amax(b[0])
    left, right = np.amin(b[1]), np.amax(b[1])
    cr_img = img[up:bot, left:right]
    return cr_img

def get_original_coords(coords, mask):
    """ Get the landmark coordinates for the original resolution. """
    b = np.nonzero(mask)
    up, bot = np.amin(b[0]), np.amax(b[0])
    left, right = np.amin(b[1]), np.amax(b[1])

    new_coords = (coords[0] + up, coords[1] + left)
    return new_coords

def get_mask_fov(image):
    """ Estimate a FOV mask from the original image by simple thresholding. """

    im = color.rgb2gray(image)
    im = filters.rank.enhance_contrast(im, disk(5))
    im = filters.gaussian(im, 2)

    val = filters.threshold_otsu(im)

    mask = np.zeros(im.shape)
    mask[np.where(im < val)] = 1

    # apply a closing operation to minimize contamination by isolated pixels
    mask = dilation(mask, selem=disk(5))
    mask = erosion(mask, selem=disk(5))

    label_image = label(mask)

    # select the largest connected component as the background
    props = regionprops(label_image)
    pp = [props[i].area for i in range(len(props))]

    mask_new = np.ones(label_image.shape)
    # np.argmax(pp)+1 because label_image also considers 0, but props does not
    mask_new[np.where(label_image == np.unique(label_image)[np.argmax(pp)+1])] = 0

    return mask_new

def demo_od_fovea_detection(args):

    ## Define the model and load the pre-trained weights
    weights_file = 'best_weights.h5'

    model = unet(3, 512, drop=0.)

    m1 = model.get_unet(nf=8)  # U-Net using upsampling
    m1.load_weights(weights_file)

    ## Load the image

    img = img_as_float(io.imread(args.img_dir))

    # --estimate_fov takes precedence over a mask file
    if args.estimate_fov:
        mask = get_mask_fov(img)
        img_crop = crop_image(img, mask)
        img_to_pred = transform.resize(img_crop, (512, 512), order=0, mode='constant')
    elif args.mask_dir is not None:
        mask = img_as_float(io.imread(args.mask_dir))
        img_crop = crop_image(img, mask)
        img_to_pred = transform.resize(img_crop, (512, 512), order=0, mode='constant')
    else:
        img_to_pred = transform.resize(img, (512, 512), order=0, mode='constant')

    # per-channel z-score normalization
    img_to_pred = (img_to_pred - img_to_pred.mean(axis=(0, 1))) / img_to_pred.std(axis=(0, 1))

    ## Get the location prediction

    dist_map_pred = m1.predict(img_to_pred[np.newaxis, :, :, :])
    pred_map = dist_map_pred[0, :, :, 0]

    ## Get the OD and Fovea locations from this distance map

    peak_coords = odc.get_peak_coordinates(pred_map, threshold=0.2)
    od_coords, fov_coords = odc.determine_od(img_to_pred, peak_coords, neigh=12)

    ## Get the coordinates in the original resolution
    if args.estimate_fov or (args.mask_dir is not None):
        od_resh = odc.get_new_peaks(od_coords, img_crop.shape[:2])
        f_resh = odc.get_new_peaks(fov_coords, img_crop.shape[:2])

        od_resh = get_original_coords(od_resh, mask)
        f_resh = get_original_coords(f_resh, mask)
    else:
        od_resh = odc.get_new_peaks(od_coords, img.shape[:2])
        f_resh = odc.get_new_peaks(fov_coords, img.shape[:2])

    print('===> OD coordinates: ', od_resh)
    print('===> FOVEA coordinates: ', f_resh)

    fig, ax = plt.subplots(1, 2, figsize=(15, 10))
    ax[0].imshow(img)
    ax[0].plot(od_resh[1], od_resh[0], 'b.')
    ax[0].plot(f_resh[1], f_resh[0], 'r.')

    ax[1].imshow(pred_map)
    ax[1].plot(od_coords[1], od_coords[0], 'b.')
    ax[1].plot(fov_coords[1], fov_coords[0], 'r.')
    plt.show()

    plt.figure(figsize=(10, 10))
    plt.imshow(img)
    plt.plot(od_resh[1], od_resh[0], 'b.')
    plt.plot(f_resh[1], f_resh[0], 'r.')
    plt.title('Predicted location of OD (blue) and Fovea (red)')
    plt.xlabel('OD: ( {0}, {1}) Fovea: ({2}, {3}) '.format(od_resh[0], od_resh[1],
                                                          f_resh[0], f_resh[1]))

    ut.create_dir('results/')
    plt.savefig('results/demo.png')


if __name__ == "__main__":
    demo_od_fovea_detection(parser.parse_args())

--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
name: od-fovea
channels:
  - conda-forge
  - defaults
dependencies:
  - _libgcc_mutex=0.1
  - _openmp_mutex=4.5
  - _tflow_select=2.3.0
  - absl-py=0.9.0
  - astor=0.7.1
  - binutils_impl_linux-64=2.34
  - binutils_linux-64=2.34
  - c-ares=1.15.0
  - ca-certificates=2019.11.28
  - certifi=2019.11.28
  - cloudpickle=1.3.0
  - cycler=0.10.0
  - cytoolz=0.10.1
  - dask-core=2.13.0
  - decorator=4.4.2
  - freetype=2.10.1
  - gast=0.2.2
  - gcc_impl_linux-64=7.3.0
  - gcc_linux-64=7.3.0
  - google-pasta=0.2.0
  - grpcio=1.23.0
  - gxx_impl_linux-64=7.3.0
  - gxx_linux-64=7.3.0
  - h5py=2.10.0
  - hdf5=1.10.5
  - icu=64.2
  - imageio=2.8.0
  - joblib=0.14.1
  - jpeg=9c
  - keras=2.3.1
  - keras-applications=1.0.8
  - keras-preprocessing=1.1.0
  - kiwisolver=1.1.0
  - ld_impl_linux-64=2.34
  - libblas=3.8.0
  - libcblas=3.8.0
  - libffi=3.2.1
  - libgcc-ng=9.2.0
  - libgfortran-ng=7.3.0
  - libgomp=9.2.0
  - libgpuarray=0.7.6
  - liblapack=3.8.0
  - libopenblas=0.3.7
  - libpng=1.6.37
  - libprotobuf=3.11.3
  - libstdcxx-ng=9.2.0
  - libtiff=4.1.0
  - libwebp-base=1.1.0
  - lz4-c=1.8.3
  - mako=1.1.0
  - markdown=3.2.1
  - markupsafe=1.1.1
  - matplotlib-base=3.2.1
  - ncurses=6.1
  - networkx=2.4
  - numpy=1.18.1
  - olefile=0.46
  - openssl=1.1.1f
  - opt_einsum=3.2.0
  - pandas=1.0.3
  - pillow=7.0.0
  - pip=20.0.2
  - protobuf=3.11.3
  - pygpu=0.7.6
  - pyparsing=2.4.6
  - python=3.6.10
  - python-dateutil=2.8.1
  - python_abi=3.6
  - pytz=2019.3
  - pywavelets=1.1.1
  - pyyaml=5.3.1
  - readline=8.0
  - scikit-image=0.16.2
  - scikit-learn=0.22.2.post1
  - scipy=1.4.1
  - setuptools=46.1.3
  - six=1.14.0
  - sqlite=3.30.1
  - tensorboard=1.15.0
  - tensorflow=1.15.0
  - tensorflow-base=1.15.0
  - tensorflow-estimator=1.15.1
  - termcolor=1.1.0
  - theano=1.0.4
  - tk=8.6.10
  - toolz=0.10.0
  - tornado=6.0.4
  - werkzeug=0.16.1
  - wheel=0.34.2
  - wrapt=1.12.1
  - xz=5.2.4
  - yaml=0.2.2
  - zlib=1.2.11
  - zstd=1.4.4

--------------------------------------------------------------------------------
/images/messidor_test.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/minesmeyer/od-fovea-regression/8d1aec38c98d8e95a79b0a6d143e3a8796fe53e0/images/messidor_test.tif

--------------------------------------------------------------------------------
/images/messidor_test_mask.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/minesmeyer/od-fovea-regression/8d1aec38c98d8e95a79b0a6d143e3a8796fe53e0/images/messidor_test_mask.tif

--------------------------------------------------------------------------------
/images/messidor_test_prediction.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/minesmeyer/od-fovea-regression/8d1aec38c98d8e95a79b0a6d143e3a8796fe53e0/images/messidor_test_prediction.png

--------------------------------------------------------------------------------
/images/unet_fod.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/minesmeyer/od-fovea-regression/8d1aec38c98d8e95a79b0a6d143e3a8796fe53e0/images/unet_fod.png

--------------------------------------------------------------------------------
/predict_od_fov.py:
--------------------------------------------------------------------------------
import numpy as np
import matplotlib.pyplot as plt
import skimage.io as io
import os
import pandas as pd

from scipy.ndimage.morphology import binary_erosion

from util.unet_triclass_whole_image import unet
import util.od_coords as odc
import util.util as ut

# Turn interactive plotting off
plt.ioff()


# #### TESTING ON SPLIT 1 OF MESSIDOR
data_dir = '../messidor_od_fovea/split1/resized/test/'

masks_dir = '../messidor_od_fovea/split1/test/'

test_dir = os.path.join(data_dir, 'images/')

model_dir = 'results/Messidor_split1/messidor_split1/'

weights_file = 'best_weights.h5'

# # Define model and load weights
model = unet(3, 512, drop=0.)
m1 = model.get_unet(nf=8)  # U-Net using upsampling
m1.load_weights(weights_file)

resized_coordinates = {'Image_File_Name': [],
                       'fovea-x-coordinate': [], 'fovea-y-coordinate': [],
                       'od_x_coords': [], 'od_y_coords': []}

# create the folder where the CSV with the predictions is written at the end
ut.create_dir(model_dir + 'Preds_test_decay7/')

for fl in sorted(os.listdir(test_dir)):

    resized_coordinates['Image_File_Name'].append(fl)

    # Predict peak location (in 512x512), after per-channel z-score normalization
    img = io.imread(test_dir + fl)
    img_to_pred = ((img - img.mean(axis=(0, 1))) /
                   (img.std(axis=(0, 1))))

    dist_map_pred = m1.predict(img_to_pred[np.newaxis, :, :, :])
    pred_map = dist_map_pred[0, :, :, 0]

    peak_coords = odc.get_peak_coordinates(pred_map, threshold=0.2)

    od_coords, fov_coords = odc.determine_od(img_to_pred, peak_coords, neigh=9)

    print(od_coords, fov_coords)

    # uncomment to save the predictions to disk
    # io.imsave(model_dir + 'Preds_maps_decay7/' + fl, pred_map, cmap='gray')

    plt.imshow(pred_map)
    plt.plot(od_coords[1], od_coords[0], 'b.')
    plt.plot(fov_coords[1], fov_coords[0], 'r.')

    # Get the locations in the original coordinates (1488x2240)
    fl_img = fl[:-4] + '.tif'
    fl_msk = fl[:-4] + '_test_mask.gif'
    img = io.imread(masks_dir + 'images/' + fl_img)/255.
    mask = io.imread(masks_dir + 'masks/' + fl_msk)/255.
    mask = binary_erosion(mask, structure=np.ones((10, 10)))

    # crop the images to the FOV

    b = np.nonzero(mask)
    up, bot = np.min(b[0]), np.max(b[0])
    left, right = np.min(b[1]), np.max(b[1])
    cr_img = img[up:bot, left:right]

    sh_crop = cr_img.shape

    # Resize the predictions to the shape of the crop
    # cr_img_rz = resize(img, img_to_pred)

    od_resh = odc.get_new_peaks(od_coords, sh_crop[:2])
    f_resh = odc.get_new_peaks(fov_coords, sh_crop[:2])

    od_resh = (od_resh[0] + up, od_resh[1] + left)
    f_resh = (f_resh[0] + up, f_resh[1] + left)

    resized_coordinates['fovea-y-coordinate'].append(f_resh[0])
    resized_coordinates['fovea-x-coordinate'].append(f_resh[1])
    resized_coordinates['od_y_coords'].append(od_resh[0])
    resized_coordinates['od_x_coords'].append(od_resh[1])

    fig, ax = plt.subplots(1, 2, figsize=(15, 10))
    ax[0].imshow(img)
    ax[0].plot(od_resh[1], od_resh[0], 'b.')
    ax[0].plot(f_resh[1], f_resh[0], 'r.')

    ax[1].imshow(pred_map)
    ax[1].plot(od_coords[1], od_coords[0], 'b.')
    ax[1].plot(fov_coords[1], fov_coords[0], 'r.')

    # uncomment to save the predictions to disk
    # plt.savefig(model_dir + 'Preds_test_decay7/' + fl)

    # close the figures so they do not accumulate over the loop
    plt.close('all')

coords_df = pd.DataFrame(data=resized_coordinates)

coords_df.to_csv(model_dir + 'Preds_test_decay7/fovea_od_preds.csv')

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
# # This file may be used to create an environment using:
# Original requirements, for reference. Please use the file environment.yml instead!
# # $ conda create --name <env> --file <this file>
# # platform: linux-64
# backports=1.0=py27h63c9359_1
# backports.functools_lru_cache=1.4=py27he8db605_1
# backports.shutil_get_terminal_size=1.0.0=py27h5bc021e_2
# backports_abc=0.5=py27h7b3c97b_0
# bleach=1.5.0=py27_0
# bokeh=0.12.13=py27h5233db4_0
# bzip2=1.0.6=h9a117a8_4
# ca-certificates=2018.03.07=0
# cairo=1.14.12=h77bcde2_0
# certifi=2018.4.16=py27_0
# click=6.7=py27h4225b90_0
# cloudpickle=0.5.2=py27_1
# configparser=3.5.0=py27h5117587_0
# cudatoolkit=8.0=3
# cudnn=7.0.5=cuda8.0_0
# cycler=0.10.0=py27hc7354d3_0
# dask=0.16.1=py27_0
# dask-core=0.16.1=py27_0
# dbus=1.12.2=hd3e2b69_0
# decorator=4.2.1=py27_0
# distributed=1.20.2=py27_0
# entrypoints=0.2.3=py27h502b47d_2
# enum34=1.1.6=py27h99a27e9_1
# expat=2.2.5=he0dffb1_0
# ffmpeg=3.4=h7264315_0
# fontconfig=2.12.4=h88586e7_1
# freetype=2.8=hab7d2ae_1
# funcsigs=1.0.2=py27h83f16ab_0
# functools32=3.2.3.2=py27h4ead58f_1
# futures=3.2.0=py27h7b459c0_0
# glib=2.53.6=h5d9569c_2
# gmp=6.1.2=h6c8ec71_1
# graphite2=1.3.10=hf63cedd_1
# gst-plugins-base=1.12.4=h33fb286_0
# gstreamer=1.12.4=hb53b477_0
# h5py=2.7.1=py27h2697762_0
# harfbuzz=1.7.4=hc5b324e_0
# hdf5=1.10.1=h9caa474_1
# heapdict=1.0.0=py27h33770af_0
# html5lib=0.9999999=py27_0
# icu=58.2=h9c2bf20_1
# imageio=2.2.0=py27hf108a7f_0
# intel-openmp=2018.0.0=hc7b2577_8
# ipykernel=4.8.0=py27_0
# ipython=5.4.1=py27_2
# ipython_genutils=0.2.0=py27h89fb69b_0
# ipywidgets=7.1.1=py27_0
# jasper=1.900.1=hd497a04_4
# jinja2=2.10=py27h4114e70_0
# jpeg=9b=h024ee3a_2
# jsonschema=2.6.0=py27h7ed5aa4_0
# jupyter=1.0.0=py27_4
# jupyter_client=5.2.2=py27_0
# jupyter_console=5.2.0=py27hc6bee7e_1
# jupyter_core=4.4.0=py27h345911c_0
# keras=2.1.3=py27_0
# libedit=3.1=heed3624_0
# libffi=3.2.1=hd88cf55_4
# libgcc-ng=7.2.0=h7cc24e2_2
# libgfortran-ng=7.2.0=h9f7466a_2
# libopus=1.2.1=hb9ed12e_0
# libpng=1.6.34=hb9fc6fc_0
# libprotobuf=3.4.1=h5b8497f_0
# libsodium=1.0.15=hf101ebd_0
# libstdcxx-ng=7.2.0=h7a57d05_2
# libtiff=4.0.9=h28f6b97_0
# libvpx=1.6.1=h888fd40_0
# libxcb=1.12=hcd93eb1_4
# libxml2=2.9.7=h26e45fe_0
# locket=0.2.0=py27h73929a2_1
# markdown=2.6.9=py27_0
# markupsafe=1.0=py27h97b2822_1
# matplotlib=2.1.2=py27h0e671d2_0
# mistune=0.8.3=py27_0
# mkl=2018.0.1=h19d6760_4
# mock=2.0.0=py27h0c0c831_0
# msgpack-python=0.5.1=py27h6bb024c_0
# nbconvert=5.3.1=py27he041f76_0
# nbformat=4.4.0=py27hed7f2b2_0
# ncurses=6.0=h9df7e31_2
# networkx=2.1=py27_0
# notebook>=5.7.2=py27_0
# #notebook=5.4.1=py27_0
# #notebook=5.4.0=py27_0
# numpy=1.14.0=py27h3dfced4_1
# olefile=0.45.1=py27_0
# opencv=3.3.1=py27h6cbbc71_1
# openssl=1.0.2o=h20670df_0
# pandas=0.22.0=py27hf484d3e_0
# pandoc=1.19.2.1=hea2e7c5_1
# pandocfilters=1.4.2=py27h428e1e5_1
# partd=0.3.8=py27h4e55004_0
# pathlib2=2.3.0=py27h6e9d198_0
# patsy=0.5.0=py27_0
# pbr=3.1.1=py27hf64632f_0
# pcre=8.41=hc27e229_1
# pexpect=4.3.1=py27_0
# pickleshare=0.7.4=py27h09770e1_0
# pillow=5.0.0=py27h3deb7b8_0
# pip=9.0.1=py27ha730c48_4
# pixman=0.34.0=hceecf20_3
# prompt_toolkit=1.0.15=py27h1b593e1_0
# protobuf=3.4.1=py27h2ba6a9c_0
# psutil=5.4.3=py27h14c3975_0
# ptyprocess=0.5.2=py27h4ccb14c_0
# pygments=2.2.0=py27h4a8b6f5_0
# pyparsing=2.2.0=py27hf1513f8_1
# pyqt=5.6.0=py27h4b1e83c_5
# python=2.7.14=h1571d57_29
# python-dateutil=2.6.1=py27h4ca5741_1
# pytz=2017.3=py27h001bace_0
# pywavelets=0.5.2=py27hecda097_0
# pyyaml>=4.2b1=py27_0
# pyzmq=16.0.3=py27hc579512_0
# qt=5.6.2=h974d657_12
# qtconsole=4.3.1=py27hc444b0d_0
# readline=7.0=ha6073c6_4
# scandir=1.6=py27hf7388dc_0
# scikit-image=0.13.1=py27h14c3975_1
# scikit-learn=0.19.1=py27h445a80a_0
# scipy=1.0.0=py27hf5f0f52_0
# seaborn=0.8.1=py27h633ea1e_0
# send2trash=1.4.2=py27_0
# setuptools=38.4.0=py27_0
# simplegeneric=0.8.1=py27h19e43cd_0
# singledispatch=3.4.0.3=py27h9bcb476_0
# sip=4.18.1=py27he9ba0ab_2
# six=1.11.0=py27h5f960f1_1
# sortedcontainers=1.5.9=py27_0
# sqlite=3.22.0=h1bed415_0
# ssl_match_hostname=3.5.0.1=py27h4ec10b9_2
# statsmodels=0.8.0=py27hc87d62d_0
# subprocess32=3.2.7=py27h373dbce_0
# tblib=1.3.2=py27h51fe5ba_0
# tensorflow=1.4.1=0
# tensorflow-base=1.4.1=py27hd00c003_2
# tensorflow-gpu=1.4.1=0
# tensorflow-gpu-base=1.4.1=py27h01caf0a_0
# tensorflow-tensorboard=0.4.0=py27hf484d3e_0
# terminado=0.8.1=py27_1
# testpath=0.3.1=py27hc38d2c4_0
# tk=8.6.7=hc745277_3
# toolz=0.9.0=py27_0
# tornado=4.5.3=py27_0
# traitlets=4.3.2=py27hd6ce930_0
# wcwidth=0.1.7=py27h9e3e1ab_0
# werkzeug=0.14.1=py27_0
# wheel=0.30.0=py27h2bc6bb2_1
# widgetsnbextension=3.1.0=py27_0
# xz=5.2.3=h55aa19d_2
# yaml=0.1.7=had09818_2
# zeromq=4.2.2=hbedb6e5_2
# zict=0.1.3=py27h12c336c_0
# zlib=1.2.11=ha838bed_2

--------------------------------------------------------------------------------
/script-train-messidor.py:
--------------------------------------------------------------------------------
""" Main script for training on the Messidor or IDRiD datasets """

import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import pickle
import os

from util.data_generator0218 import TwoImageIterator
import util.util as ut
from util.unet_triclass_whole_image import unet

from keras.optimizers import Adam
from keras import backend as K


def check_EarlyStop(vloss, tr_loss, patience=5):
    """ Early-stopping test based on the generalization loss (GL), the training
    progress (Pk) over the last `patience` epochs, and their quotient PQ = GL / Pk.
    Returns 'early_stop' when the latest validation loss is worse than the best
    previous one (GL > 0) and either PQ > 0.5 or the training loss has nearly
    plateaued (Pk < 1.1); otherwise returns 'pass'. """

    Eopt = np.min(vloss[:-1])
    GL = ((vloss[-1] / Eopt) - 1)
    Pk = (np.sum(tr_loss[-patience:]) / (
        patience * (np.min(tr_loss[-patience:]))))
    PQ = GL / Pk

    # return [PQ, GL, Pk]
    if (GL > 0.):
        if PQ > 0.5:
            return 'early_stop'
        elif Pk < 1.1:
            return 'early_stop'
        else:
            return 'pass'
    else:
        return 'pass'


# #### TRAINING ON SPLIT 1 OF MESSIDOR
data_dir = '../messidor_od_fovea/split1/resized/'

a_tr, b_tr = 'train/images/', 'train/gdt/'
a_val, b_val = 'val/images/', 'val/gdt/'

# ## For Messidor
nsamples = 455
nsamples_val = 113

# ## For IDRiD
# nsamples = 373.
# nsamples_val = 40.

nb_epochs = 601
batch_size = 16

b_iter = int(np.ceil(nsamples / batch_size))

val_batch_size = int(np.ceil(nsamples_val/b_iter))

print(b_iter)

decay = 7

# #### Define the Model
model = unet(3, 512, drop=0.5)
m1 = model.get_unet(nf=8)  # U-Net using upsampling
lr = 0.005
adam = Adam(lr=lr)
m1.compile(optimizer=adam, loss='mse')
path_model = '../results/Messidor_split1_NEW/2804_messidor_split1_nf8_decay{0}/'.format(decay)

# #### Set the iterators
train_it = TwoImageIterator(data_dir, a_dir_name=a_tr,
                            b_dir_name=b_tr, N=-1,
                            batch_size=batch_size, shuffle=True, seed=None,
                            target_size=(512, 512), nch_gdt=1,
                            rotation_range=0.2, height_shift_range=0., shear_range=0.,
                            width_shift_range=0., zoom_range=0., fill_mode='constant',
                            cval=0., horizontal_flip=True, vertical_flip=True,
                            cspace='rgb',
                            normalize_tanh=False, zscore=True, decay=decay, dataset='messidor')

val_it = TwoImageIterator(data_dir, a_dir_name=a_val,
                          b_dir_name=b_val, N=-1,
                          batch_size=val_batch_size, shuffle=True, seed=None,
                          target_size=(512, 512), nch_gdt=1,
                          cspace='rgb',
                          normalize_tanh=False, zscore=True, decay=decay, dataset='messidor')

# Create the folder where intermediate results will be saved for verification
tr_folder = 'train_continued/'
ut.create_dir(path_model + tr_folder)

# #### Training Loop
l_ep = {}
acc_ep = {}

losses = {'train': [], 'val': []}
epoch = 0
for e in range(epoch, nb_epochs):

    print('Epoch %d' % (e + 1))
    l_ep['train'] = np.zeros(b_iter)
    l_ep['val'] = np.zeros(b_iter)

    acc_ep['train'] = np.zeros(b_iter)
    acc_ep['val'] = np.zeros(b_iter)

    for it in range(b_iter):
        x_batch, y_batch = next(train_it)
        xval, yval = next(val_it)

        tmp = m1.fit(x_batch, y_batch, batch_size=batch_size, epochs=1,
                     shuffle=True, validation_data=(xval, yval),
                     verbose=0)

        # save loss for further inspection
        l_ep['train'][it] = tmp.history['loss'][0]
        l_ep['val'][it] = tmp.history['val_loss'][0]

        print('batch' + str(it) + ': ')
        print(tmp.history)
        if np.isnan(tmp.history['loss'][0]):
            raise Exception('Loss is NaN')

    val_loss, tr_loss = np.median(l_ep['val']), np.median(l_ep['train'])

    print('loss: [%.6f], val_loss: [%0.6f]' % (tr_loss, val_loss))

    losses['train'].append(np.mean(l_ep['train']))
    losses['val'].append(np.mean(l_ep['val']))

    # Save best model
    if e > epoch+1:

        Eopt = np.min(losses['val'][:-1])

        if losses['val'][-1] < Eopt:
            m1.save((path_model + 'best_model.h5'),
                    overwrite=True)
            m1.save_weights((path_model + 'best_weights.h5'),
                            overwrite=True)

    # save intermediate results to the results folder every 10 epochs
    if e % 10 == 0:

        ypred = m1.predict(xval)
        x_plt = (xval[0] - xval[0].min()) / (xval[0].max() -
                                             xval[0].min())
        fig, ax = plt.subplots(1, 3, figsize=(10, 10))

        ax[0].imshow(x_plt)
        ax[1].imshow(yval[0, :, :, 0], cmap='gray')
        ax[2].imshow(ypred[0, :, :, 0], cmap='gray')

        plt.savefig((path_model + tr_folder + 'val_pred_e_' + str(e) + '.png'))
        plt.close(fig)

    sv_path = os.path.join(path_model, tr_folder)
    with open(sv_path + 'losses.pkl', 'wb') as f:
        pickle.dump(losses, f)

    # plot the training and validation losses
    ut.plot_loss(losses['train'], 'Train Loss', 'unet_loss.png',
                 sv_path, title='Training Loss', ylim=(0, 0.05))
    ut.plot_loss(losses['val'], 'Validation Loss', 'unet_loss_val.png',
                 sv_path, title='Validation Loss', ylim=(0, 0.05))

    # every 50 epochs, halve the current learning rate if the criterion fires
    if (e >= epoch+50) and (e % 50 == 0):
        if check_EarlyStop(losses['val'], losses['train'], patience=50) == 'early_stop':
            lr = lr / 2
            print('Decreasing learning rate to lr= %f' % lr)
            K.set_value(m1.optimizer.lr, lr)

    # every 100 epochs, stop training if the criterion fires over a longer window
    if (e >= 100) and (e % 100 == 0):
        if check_EarlyStop(losses['val'], losses['train'], patience=100) == 'early_stop':
            print('Early Stopping Model...')
            break

--------------------------------------------------------------------------------
/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/minesmeyer/od-fovea-regression/8d1aec38c98d8e95a79b0a6d143e3a8796fe53e0/util/__init__.py

--------------------------------------------------------------------------------
/util/data_generator0218.py:
--------------------------------------------------------------------------------
"""
Iterator to load images from the datasets, and related functions.
"""
import os
import numpy as np
import pandas as pd

from skimage import io
from skimage import img_as_float
from skimage.transform import resize
from scipy.ndimage import distance_transform_edt

import pickle

from keras import backend as K
from keras.preprocessing.image import Iterator
from keras.preprocessing.image import apply_transform, flip_axis
from keras.preprocessing.image import transform_matrix_offset_center

def normalize_for_tanh(batch):
    """Map values in [0, max] to [-1, 1]."""
    tanh_batch = batch - np.max(batch)/2.
    tanh_batch /= np.max(batch)/2.
    return tanh_batch

class TwoImageIterator(Iterator):
    """Class to iterate A and B images at the same time, while applying desired
    transformations online."""

    def __init__(self, directory, a_dir_name='A', b_dir_name=None, N=-1,
                 batch_size=32, shuffle=True, seed=None, target_size=(512,512),
                 cspace='rgb', nch_gdt=1,
                 zscore=True, normalize_tanh=False,
                 return_mode='normal', decay=5, dataset='idrid',
                 rotation_range=0., height_shift_range=0., shear_range=0.,
                 width_shift_range=0., zoom_range=0., fill_mode='constant',
                 cval=0., horizontal_flip=False, vertical_flip=False):

        """
        Iterate through the image directory, apply transformations and return
        the distance map calculated on the fly. If b_dir_name is not None, the
        ground truth is read from that directory instead.

        Files under the directories A and B will be returned at the same time.
        Parameters:
        - directory: base directory of the dataset. Should contain two
          directories named a_dir_name and b_dir_name;
        - a_dir_name: name of the directory under `directory` that contains
          the A images;
        - b_dir_name: name of the directory under `directory` that contains
          the B images;
        - N: if -1, uses the entire dataset. Otherwise only the first N files
          are used;
        - batch_size: the size of the batches to create;
        - shuffle: if True the order of the images in X will be shuffled;
        - seed: seed for the random number generator;
        - return_mode: 'normal' or 'fnames'. Default: 'normal'
            - 'normal' returns: [batch_a, batch_b]
            - 'fnames' returns: [batch_a, batch_b, files]
        - decay: exponent applied to the normalized distance map. Default: 5
        - dataset: dataset to load. Can handle Messidor and IDRiD. Default: 'idrid'

        """
        self.directory = directory

        self.a_dir = os.path.join(directory, a_dir_name)
        self.a_fnames = sorted(os.listdir(self.a_dir))

        self.b_dir_name = b_dir_name
        if b_dir_name is not None:
            self.b_dir = os.path.join(directory, b_dir_name)
            self.b_fnames = sorted(os.listdir(self.b_dir))

        # Use only a subset of the files. Good to easily overfit the model
        if N > 0:
            self.a_fnames = self.a_fnames[:N]
            if b_dir_name is not None:
                self.b_fnames = self.b_fnames[:N]
        self.N = len(self.a_fnames)

        self.ch_order = K.image_dim_ordering()

        # Preprocess images
        self.cspace = cspace  # colour space

        # Image shape
        self.target_size = target_size
        self.nch_gdt = nch_gdt

        self.nch = len(self.cspace)  # number of input channels, e.g. 3 for 'rgb'

        self.img_shape_a = self._get_img_shape(self.target_size, ch=self.nch)
        self.img_shape_b = self._get_img_shape(self.target_size, ch=self.nch_gdt)

        if self.ch_order == 'tf':
            self.channel_index = 3
            self.row_index = 1
            self.col_index = 2
        else:
            self.channel_index = 1
            self.row_index = 2
            self.col_index = 3

        # Normalizations
        self.normalize_tanh = normalize_tanh
        self.zscore = zscore

        # Transformations
        self.rotation_range = rotation_range
        self.height_shift_range = height_shift_range
        self.width_shift_range = width_shift_range
        self.shear_range = shear_range
        self.fill_mode = fill_mode
        self.cval = cval
        self.horizontal_flip = horizontal_flip
        self.vertical_flip = vertical_flip
        if np.isscalar(zoom_range):
            self.zoom_range = [1 - zoom_range, 1 + zoom_range]
        elif len(zoom_range) == 2:
            self.zoom_range = [zoom_range[0], zoom_range[1]]
        else:
            raise ValueError('zoom_range should be a float or a tuple/list of two floats.')

        self.return_mode = return_mode

        self.decay = decay
        self.dataset = dataset

        super(TwoImageIterator, self).__init__(len(self.a_fnames), batch_size,
                                               shuffle, seed)

    def _get_img_shape(self, size, ch=3):

        if self.ch_order == 'tf':
            img_shape = size + (ch,)
        else:
            img_shape = (ch,) + size

        return img_shape

    def _load_img_pair(self, idx):
        """
        Load images and apply pre-processing
        :param idx: index of file to load in the list of names
        :return: aa: image
                 bb: ground truth
        """
        aa = img_as_float(io.imread(os.path.join(self.a_dir, self.a_fnames[idx])))
        bb = img_as_float(io.imread(os.path.join(self.b_dir, self.b_fnames[idx])))

        if self.nch_gdt == 3:
            # fix for the case when the .png has an alpha channel
            if bb.shape[-1] == 4:
                bb = bb[:, :, :3]
        elif self.nch_gdt == 1:
            # add a channel axis when the ground truth is a 2-D (grayscale) image
            if len(bb.shape) == 2:
                bb = bb[:, :, np.newaxis]

        return aa, bb

    def _random_transform(self, a, b, is_batch=False):
        """
        Random dataset augmentation.

        Adapted from https://github.com/fchollet/keras/blob/master/keras/preprocessing/image.py
        """
        if is_batch is False:
            # a and b are single images, so they don't have image number at index 0
            img_row_index = self.row_index - 1
            img_col_index = self.col_index - 1
            img_channel_index = self.channel_index - 1
        else:
            img_row_index = self.row_index
            img_col_index = self.col_index
            img_channel_index = self.channel_index
        # use composition of homographies to generate final transform that needs to be applied
        if self.rotation_range:
            theta = np.pi / 180 * np.random.uniform(-self.rotation_range,
                                                    self.rotation_range)
        else:
            theta = 0
        rotation_matrix = np.array([[np.cos(theta), -np.sin(theta), 0],
                                    [np.sin(theta), np.cos(theta), 0],
                                    [0, 0, 1]])
        if self.height_shift_range:
            tx = np.random.uniform(-self.height_shift_range, self.height_shift_range) \
                * a.shape[img_row_index]
        else:
            tx = 0

        if self.width_shift_range:
            ty = np.random.uniform(-self.width_shift_range, self.width_shift_range) \
                * a.shape[img_col_index]
        else:
            ty = 0

        translation_matrix = np.array([[1, 0, tx],
                                       [0, 1, ty],
                                       [0, 0, 1]])

        if self.zoom_range[0] == 1 and self.zoom_range[1] == 1:
            zx, zy = 1, 1
        else:
            zx, zy = np.random.uniform(self.zoom_range[0], self.zoom_range[1], 2)
        zoom_matrix = np.array([[zx, 0, 0],
                                [0, zy, 0],
                                [0, 0, 1]])

        if self.shear_range:
            shear = np.random.uniform(-self.shear_range, self.shear_range)
        else:
            shear = 0
        shear_matrix = np.array([[1, -np.sin(shear), 0],
                                 [0, np.cos(shear), 0],
                                 [0, 0, 1]])

        transform_matrix = np.dot(np.dot(np.dot(rotation_matrix, translation_matrix), shear_matrix),
                                  zoom_matrix)

        h, w = a.shape[img_row_index], a.shape[img_col_index]
        transform_matrix = transform_matrix_offset_center(transform_matrix, h, w)
        a = apply_transform(a, transform_matrix, img_channel_index,
                            fill_mode=self.fill_mode, cval=self.cval)
        b = apply_transform(b, transform_matrix, img_channel_index,
                            fill_mode=self.fill_mode, cval=self.cval)

        if self.horizontal_flip:
            if np.random.random() < 0.5:
                a = flip_axis(a, img_col_index)
                b = flip_axis(b, img_col_index)

        if self.vertical_flip:
            if np.random.random() < 0.5:
                a = flip_axis(a, img_row_index)
                b = flip_axis(b, img_row_index)

        return a, b

    def get_dist_maps(self, coords, shp=(512,512)):
        """
        Build the regression target from the landmark coordinates.
        :param coords: [(fovea_x, fovea_y), (od_x, od_y)] in pixels of an image of shape shp
        :param shp: (rows, cols) of the image the coordinates refer to
        :return: (512, 512, 1) map equal to 1 at both landmarks and decaying as
                 (1 - d / max(d)) ** decay, where d is the Euclidean distance to
                 the nearest landmark (resized to 512x512 first if shp differs)
        """
        fx, fy = coords[0]
        odx, ody = coords[1]

        distance = np.ones(shp)
        distance[fy, fx] = 0
        distance[ody, odx] = 0
        distance = distance_transform_edt(distance)
        distance = distance[:, :, np.newaxis]
        if shp != (512,512):
            distance = resize(1 - distance / np.max(distance), (512,512,1)) ** self.decay
        else:
            distance = (1 - distance / np.max(distance)) ** self.decay
        return distance

    def next(self):
        """Get the next pair of the sequence."""

        # Lock the iterator when the index is changed.
        with self.lock:
            index_array = next(self.index_generator)
        current_batch_size = len(index_array)

        # Initialize the arrays according to the size of the output images
        batch_a = np.zeros((current_batch_size,) + self.img_shape_a)
        batch_b = np.zeros((current_batch_size,) + self.img_shape_b[:-1]
                           + (self.nch_gdt,))

        files = []
        ind = []

        if self.b_dir_name is None:
            if self.dataset == 'messidor':
                ### For Messidor
                with open(os.path.join(self.directory, 'resized_coords.pkl'), 'rb') as f:
                    all_coords = pickle.load(f)

            elif self.dataset == 'idrid':
                ### For IDRiD
                file_csv_od = os.path.join(self.directory, 'IDRiD_OD_Center_Training_set.csv')
                file_csv_fov = os.path.join(self.directory, 'IDRiD_Fovea_Center_Training_set.csv')

                gt_fovea = pd.read_csv(file_csv_fov)
                # keep only the first 3 columns and 413 rows; the rest is garbage data
                gt_fovea.drop(gt_fovea.columns[3:], axis=1, inplace=True)
                gt_fovea.drop(gt_fovea.index[413:], inplace=True)

                gt_od = pd.read_csv(file_csv_od)
                # keep only the first 3 columns and 413 rows; the rest is garbage data
                gt_od.drop(gt_od.columns[3:], axis=1, inplace=True)
                gt_od.drop(gt_od.index[413:], inplace=True)

        # Load images and apply transformations
        for i, j in enumerate(index_array):
            im_id = self.a_fnames[j][:-4]  # file name without its extension

            if self.b_dir_name is not None:
                a_img, b_img = self._load_img_pair(j)

            else:
                a_img = img_as_float(io.imread(os.path.join(self.a_dir, self.a_fnames[j])))

                if self.dataset == 'messidor':
                    ### For Messidor
                    a_idx = np.where(np.array(all_coords['Image']) == im_id + '.tif')[0][0]
                    coords = [all_coords['fovea'][a_idx], all_coords['od'][a_idx]]
                    # get the distance maps
                    b_img = self.get_dist_maps(coords)

                elif self.dataset == 'idrid':
                    ### For IDRiD
                    fovea_coords = gt_fovea[gt_fovea['Image No'] == im_id]
                    fx = int(fovea_coords['X- Coordinate'].iloc[0])
                    fy = int(fovea_coords['Y - Coordinate'].iloc[0])
                    od_coords = gt_od[gt_od['Image No'] == im_id]
                    odx = int(od_coords['X- Coordinate'].iloc[0])
                    ody = int(od_coords['Y - Coordinate'].iloc[0])
                    coords = [(fx, fy), (odx, ody)]
                    b_img = self.get_dist_maps(coords, shp=(2848, 4288))

            a_img, b_img = self._random_transform(a_img, b_img)
            if self.zscore is True:
                a_img = (a_img - a_img.mean()) / (a_img.std())

            batch_a[i] = a_img
            batch_b[i] = b_img

            files.append(self.a_fnames[j])

        # when using tanh activation the values must lie in [-1, 1]
        if self.normalize_tanh is True and self.zscore is False:
            batch_a = normalize_for_tanh(batch_a)
            batch_b = normalize_for_tanh(batch_b)

        if self.return_mode == 'normal':
            return [batch_a, batch_b]

        elif self.return_mode == 'fnames':
            return [batch_a, batch_b, files]
--------------------------------------------------------------------------------
/util/od_coords.py:
--------------------------------------------------------------------------------
""" Functions to deal with OD and Fovea localization. """

import numpy as np

import matplotlib.pyplot as plt

from skimage.feature import blob_log
from skimage.color import rgb2gray
from skimage.feature import peak_local_max

from scipy.ndimage import binary_fill_holes

def find_od_f(pred):
    """Return the (row, col) coordinates of the two strongest peaks of the
    predicted map, at least 50 pixels apart."""
    coordinates = peak_local_max(pred, min_distance=50, num_peaks=2)

    return coordinates

def plot_coords(img, coords):

    plt.imshow(img)
    plt.plot(coords[:, 1], coords[:, 0], 'r.')


def get_new_peaks(coords, shp):
    """Rescale a point from the 512x512 prediction to an image of shape shp
    (coords and shp must use the same axis order)."""
    xo, yo = shp
    xp, yp = coords

    xn = (xp * xo) / 512
    yn = (yp * yo) / 512
    coord_new = (xn, yn)

    return coord_new


def distance_metric(pred_coords, orig_coords):
    """Euclidean distance between predicted and ground-truth coordinates."""
    xp, yp = pred_coords
    xo, yo = orig_coords

    dist = np.sqrt((xo - xp) ** 2 + (yo - yp) ** 2)

    return dist


def distance_error(pred_coords, orig_coords, od_radius=88., r=1):
    """Euclidean distance, and the same distance in units of r * od_radius."""
    xp, yp = pred_coords
    xo, yo = orig_coords

    dist = np.sqrt((xo - xp) ** 2 + (yo - yp) ** 2)

    error_od_radius = dist / (od_radius * r)

    return dist, error_od_radius


def determine_od(image, coords, neigh=3):
    """ Determines which peak corresponds to the OD and to the Fovea.
    input params:
        image: the RGB image
        coords: the coordinates of the two selected peaks
        neigh: the neighbourhood to consider for evaluation
    returns:
        od_coords: the coordinates of the peak selected as OD
        fov_coords: the coordinates of the peak selected as Fovea
    """
    # clamp peaks close to the border so that each one still has a full
    # neighbourhood inside the 512x512 image (coords is modified in place)
    coords[np.where(coords < neigh)] = neigh
    coords[np.where(coords > (511-neigh))] = (511-neigh)

    coord_new1, coord_new2 = coords[0], coords[1]

    # Calculate the mean green-channel intensity of each peak and its neighbourhood
    i1 = np.mean(image[:, :, 1][coord_new1[0]-neigh:coord_new1[0]+neigh,
                                coord_new1[1]-neigh:coord_new1[1]+neigh])
    i2 = np.mean(image[:, :, 1][coord_new2[0]-neigh:coord_new2[0]+neigh,
                                coord_new2[1]-neigh:coord_new2[1]+neigh])

    # The OD is expected to have higher intensity
    if i1 >= i2:
        od_coords = coord_new1
        fov_coords = coord_new2

    elif i1 < i2:
    [...]
    indices = np.where(collapsedc>0)[0]
    cmin = indices[0]
    cmax = indices[-1]

    # Same for rows
    collapsedr = np.sum(od_mask, axis=1)
    # These indices will be already sorted
    indices = np.where(collapsedr == collapsedr.max())[0]
    r = indices[int(round((len(indices) - 1) / 2))]
    indices = np.where(collapsedr>0)[0]
    rmin = indices[0]
    rmax = indices[-1]

    dc = cmax - cmin
    dr = rmax - rmin
    return dc, dr

def get_centroid(mask, fill=True):
    """
    Return (col, row) of the OD or fovea mask, taken as the middle of the
    columns and of the rows with the largest mask extent (an approximation
    of the centroid).
    """
    if fill is True:
        mask = binary_fill_holes(mask)

    collapsedc = np.sum(mask, axis=0)
    indices = np.where(collapsedc == collapsedc.max())[0]
    c = indices[int(round((len(indices) - 1) / 2))]

    collapsedr = np.sum(mask, axis=1)
    indices = np.where(collapsedr == collapsedr.max())[0]
    r = indices[int(round((len(indices) - 1) / 2))]

    return c, r


def get_peak_coordinates(image, threshold=0.2):
    """Detect blobs with a Laplacian-of-Gaussian detector, lowering the
    threshold until at least two are found."""
    image_gray = rgb2gray(image)
    image_gray = np.pad(image_gray, (15, 15), 'constant')
    blobs = blob_log(image_gray, min_sigma=10, max_sigma=50, threshold=threshold)

    if blobs.shape[0] < 2:
        new_blobs = np.copy(blobs)

        while new_blobs.shape[0] < 2:

            threshold = 0.8 * threshold
            print(threshold)
            if threshold < 0.001:
                print('Threshold too low! Passing...')
                break
            else:
                # search the same padded grayscale image as the first pass
                new_blobs = blob_log(image_gray, min_sigma=10, max_sigma=50,
                                     threshold=threshold)

        blobs = new_blobs
        print(blobs.shape)
        if blobs.shape[0] < 2:
            # add a fixed fallback point near the image centre
            blobs = np.concatenate((blobs, [[256, 256, 0]]), axis=0)

    blobs = blobs - 15  # to account for the initial padding
    blobs[np.where(blobs > 512)] = 0
    blobs[np.where(blobs < 0)] = 0

    blobs = blobs[:, :2].astype('int')

    bb2 = blobs[:, :2].astype('int')

    #if blobs.shape[0] > 2:
    #    sorted_indx = np.argsort(image[bb2[:, 0], bb2[:, 1]], axis=None)[::-1]
    #    print(sorted_indx)
    #    blobs = bb2[sorted_indx[:2]]
    return blobs
--------------------------------------------------------------------------------
/util/unet_triclass_whole_image.py:
--------------------------------------------------------------------------------
"""
The U-Net scaled for 512x512 inputs
"""

from keras.layers import Input, Concatenate, Dense
from keras.layers.core import Activation, Reshape, Dropout
from keras.layers.convolutional import Conv2D
from keras.layers.normalization import BatchNormalization
from keras.layers.advanced_activations import LeakyReLU
from keras.models import Model
from keras.layers import UpSampling2D

from keras import backend as K
from keras.optimizers import Adam

import numpy as np


class unet():

    def __init__(self, nch, sz=512, drop = 0.):

        self.sz = sz
        self.nch = nch
        self.nlayers = int(np.floor(np.log(sz)/np.log(2)))+1
        self.drop = drop

    # Define the neural network
    def get_unet(self, nf):

        input_n = Input((self.sz, self.sz, self.nch))

        ##### Encoder part
        # nch x 512 x 512
        conv = Conv2D(nf, (3, 3), padding='same')(input_n)
        conv = LeakyReLU(alpha=0.01)(conv)
        conv = BatchNormalization()(conv)
        conv = Dropout(self.drop)(conv)

        # nf x 512 x 512
        conv1 = Conv2D(nf*2, (3, 3), padding='same', strides=(2, 2))(conv)
        # 2nf x 256 x 256
        conv1 = LeakyReLU(alpha=0.01)(conv1)
        conv1 = BatchNormalization()(conv1)
        conv1 = Dropout(self.drop)(conv1)

        conv2 = Conv2D(nf*4, (3, 3), padding='same', strides=(2, 2))(conv1)
        # 4nf x 128 x 128
        conv2 = LeakyReLU(alpha=0.01)(conv2)
        conv2 = BatchNormalization()(conv2)
        conv2 = Dropout(self.drop)(conv2)

        conv3 = Conv2D(nf*8, (3, 3), padding='same', strides=(2, 2))(conv2)
        # 8nf x 64 x 64
        conv3 = LeakyReLU(alpha=0.01)(conv3)
        conv3 = BatchNormalization()(conv3)
        conv3 = Dropout(self.drop)(conv3)

        conv4 = Conv2D(nf*8, (3, 3), padding='same', strides=(2, 2))(conv3)
        # 8nf x 32 x 32
        conv4 = LeakyReLU(alpha=0.01)(conv4)
        conv4 = BatchNormalization()(conv4)
        conv4 = Dropout(self.drop)(conv4)

        ##### Upsample path
        up4 = Concatenate(axis=-1)([UpSampling2D(size=(2, 2))(conv4), conv3])
        # 16nf x 64 x 64 (upsampled conv4 + skip connection from conv3)
        upconv1 = Conv2D(nf*8, (3, 3), padding='same')(up4)
        upconv1 = LeakyReLU(alpha=0.01)(upconv1)
        upconv1 = BatchNormalization()(upconv1)
        upconv1 = Dropout(self.drop)(upconv1)

        upconv1 = Conv2D(nf*8, (3, 3), padding='same')(upconv1)
        upconv1 = LeakyReLU(alpha=0.01)(upconv1)
        upconv1 = BatchNormalization()(upconv1)
        upconv1 = Dropout(self.drop)(upconv1)

        up3 = Concatenate(axis=-1)([UpSampling2D(size=(2, 2))(upconv1), conv2])
        # 12nf x 128 x 128 (upsampled upconv1 + skip connection from conv2)
        upconv2 = Conv2D(nf, (3, 3), padding='same')(up3)
        upconv2 = LeakyReLU(alpha=0.01)(upconv2)
        upconv2 = BatchNormalization()(upconv2)
        upconv2 = Dropout(self.drop)(upconv2)

        upconv2 = Conv2D(nf, (3, 3), padding='same')(upconv2)
        upconv2 = LeakyReLU(alpha=0.01)(upconv2)
        upconv2 = BatchNormalization()(upconv2)
        upconv2 = Dropout(self.drop)(upconv2)

        up2 = Concatenate(axis=-1)([UpSampling2D(size=(2, 2))(upconv2), conv1])
        # 3nf x 256 x 256 (upsampled upconv2 + skip connection from conv1)
        upconv3 = Conv2D(nf, (3, 3), padding='same')(up2)
        upconv3 = LeakyReLU(alpha=0.01)(upconv3)
        upconv3 = BatchNormalization()(upconv3)
        upconv3 = Dropout(self.drop)(upconv3)

        upconv3 = Conv2D(nf, (3, 3), padding='same')(upconv3)
        upconv3 = LeakyReLU(alpha=0.01)(upconv3)
        upconv3 = BatchNormalization()(upconv3)
        upconv3 = Dropout(self.drop)(upconv3)

        up1 = Concatenate(axis=-1)([UpSampling2D(size=(2, 2))(upconv3), conv])
        # 2nf x 512 x 512 (upsampled upconv3 + skip connection from conv)
        upconv4 = Conv2D(nf, (3, 3), padding='same')(up1)
        upconv4 = LeakyReLU(alpha=0.01)(upconv4)
        upconv4 = BatchNormalization()(upconv4)
        upconv4 = Dropout(self.drop)(upconv4)

        upconv4 = Conv2D(nf, (3, 3), padding='same')(upconv4)
        upconv4 = LeakyReLU(alpha=0.01)(upconv4)
        upconv4 = BatchNormalization()(upconv4)
        upconv4 = Dropout(self.drop)(upconv4)

        # single-channel output; the sigmoid keeps it in [0, 1], the range of the distance-map target
        conv_final = Conv2D(1, (3, 3), padding='same')(upconv4)
        act = 'sigmoid'
        out_distances = Activation(act)(conv_final)

        model_distances = Model(inputs=input_n, outputs=out_distances)

        return model_distances
--------------------------------------------------------------------------------
/util/util.py:
--------------------------------------------------------------------------------
"""Auxiliary methods. """

import numpy as np
import matplotlib.pyplot as plt
import os
import pickle


def create_dir(mypath):
    """Create a directory if it does not exist."""
    try:
        os.makedirs(mypath)
    except OSError:
        if os.path.isdir(mypath):
            pass
        else:
            raise


def plot_loss(loss, label, filename, log_dir, acc=None, title='', ylim=None):
    """Plot a loss curve and save it to log_dir/filename.

    If ylim is None, the y-axis is limited to (0, 0.5), or to (0, 1) when
    acc is given."""
    loss = np.array(loss)
    plt.figure(figsize=(5, 4))
    plt.plot(loss, label=label)
    if ylim is not None:
        plt.ylim(ylim)
    else:
        if acc is None:
            plt.ylim((0, 0.5))
        else:
            plt.ylim((0, 1.))

    plt.title(title)
    plt.savefig(os.path.join(log_dir, filename))
    plt.clf()
    plt.close('all')
--------------------------------------------------------------------------------
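The target built by `TwoImageIterator.get_dist_maps` and the decoding steps in `util/od_coords.py` (`find_od_f`, `determine_od`, `get_new_peaks`, `distance_error`) can be checked end to end without Keras or a trained model. What follows is a minimal NumPy-only sketch, assuming landmarks are given as (x, y) = (column, row) as in `get_dist_maps`. All function names are hypothetical. `two_peaks` is a greedy stand-in for `skimage.feature.peak_local_max(..., num_peaks=2)` and does not reproduce its exact neighbourhood or border rules. The pixel distance is computed directly instead of with `scipy.ndimage.distance_transform_edt`; for two zero pixels both give the Euclidean distance to the nearest landmark.

```python
import numpy as np


def distance_map_target(fovea_xy, od_xy, shape=(512, 512), decay=5):
    """Target map: 1 at both landmarks, decaying as (1 - d / max(d)) ** decay.

    d is the Euclidean distance from each pixel to the nearest landmark;
    coordinates are (x, y) = (column, row), as assumed in get_dist_maps.
    """
    rows, cols = np.indices(shape)
    d_fovea = np.hypot(rows - fovea_xy[1], cols - fovea_xy[0])
    d_od = np.hypot(rows - od_xy[1], cols - od_xy[0])
    dist = np.minimum(d_fovea, d_od)  # distance to the nearest landmark
    return (1.0 - dist / dist.max()) ** decay


def two_peaks(heatmap, min_distance=50):
    """Greedily pick two (row, col) maxima at least min_distance apart (hypothetical helper)."""
    work = np.asarray(heatmap, dtype=float).copy()
    rows, cols = np.indices(work.shape)
    peaks = []
    for _ in range(2):
        r, c = np.unravel_index(np.argmax(work), work.shape)
        peaks.append((int(r), int(c)))
        # Suppress a disc around the chosen peak before searching again
        work[np.hypot(rows - r, cols - c) <= min_distance] = -np.inf
    return peaks


def assign_od_fovea(green, peaks, neigh=3):
    """Same rule as determine_od: the brighter green-channel neighbourhood is the OD."""
    h, w = green.shape
    means = []
    for r, c in peaks:
        # Clamp so the (2 * neigh) x (2 * neigh) window stays inside the image
        r = min(max(r, neigh), h - 1 - neigh)
        c = min(max(c, neigh), w - 1 - neigh)
        means.append(green[r - neigh:r + neigh, c - neigh:c + neigh].mean())
    if means[0] >= means[1]:
        return peaks[0], peaks[1]  # (od, fovea)
    return peaks[1], peaks[0]


def rescale_to_original(point_rc, orig_shape, net_size=512):
    """Map a (row, col) on the net_size x net_size prediction to orig_shape."""
    return (point_rc[0] * orig_shape[0] / net_size,
            point_rc[1] * orig_shape[1] / net_size)


def normalized_error(pred, gt, od_radius=88.0, r=1):
    """Euclidean error in pixels and in units of r * od_radius (cf. distance_error)."""
    dist = float(np.hypot(gt[0] - pred[0], gt[1] - pred[1]))
    return dist, dist / (od_radius * r)


if __name__ == "__main__":
    # Synthetic 64x64 example: fovea at (x=20, y=30), optic disc at (x=45, y=32)
    target = distance_map_target(fovea_xy=(20, 30), od_xy=(45, 32), shape=(64, 64))
    peaks = two_peaks(target, min_distance=10)
    green = np.zeros((64, 64))
    green[28:37, 40:50] = 1.0  # bright synthetic "disc" around row 32, col 45
    od, fovea = assign_od_fovea(green, peaks)
    print("peaks:", peaks, "od:", od, "fovea:", fovea)
```

Raising `decay` narrows the two peaks of the target around the landmarks. The map itself does not say which peak is which, so the sketch, like `determine_od`, leaves that decision to an image cue (mean green-channel intensity).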