├── .gitignore ├── LICENSE ├── README.md ├── pics ├── cls.png ├── correction.png ├── raw.jpg ├── seg.png └── structure.jpg ├── predict.py ├── requirements-conda.txt ├── requirements-pip.txt ├── train.py └── utils ├── __init__.py ├── crop_prediction.py ├── data_augmentation.py ├── define_model.py ├── prepare_dataset.py └── process_data_for_ALL_dataset.py /.gitignore: -------------------------------------------------------------------------------- 1 | data/* 2 | .vscode/* 3 | output/* 4 | logs/* 5 | trained_model/* 6 | 7 | 8 | # Byte-compiled / optimized / DLL files 9 | __pycache__/ 10 | *.py[cod] 11 | *$py.class 12 | 13 | # C extensions 14 | *.so 15 | 16 | # Distribution / packaging 17 | .Python 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | MANIFEST 34 | 35 | # PyInstaller 36 | # Usually these files are written by a python script from a template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # pyenv 83 | .python-version 84 | 85 | # celery beat schedule file 86 | celerybeat-schedule 87 | 88 | # SageMath parsed files 89 | *.sage.py 90 | 91 | # Environments 92 | .env 93 | .venv 94 | env/ 95 | venv/ 96 | ENV/ 97 | env.bak/ 98 | venv.bak/ 99 | 100 | # Spyder project settings 101 | .spyderproject 102 | .spyproject 103 | 104 | # Rope project settings 105 | .ropeproject 106 | 107 | # mkdocs documentation 108 | /site 109 | 110 | # mypy 111 | .mypy_cache/ 112 | *.DS_Store 113 | 114 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 conscienceli 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the 
Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SeqNet 2 | 3 | ## Joint Learning of Vessel Segmentation and Artery/Vein Classification 4 | 5 | Retinal imaging serves as a valuable tool for diagnosis of various diseases. However, reading retinal images is a difficult and time-consuming task even for experienced specialists. The fundamental step towards automated retinal image analysis is vessel segmentation and artery/vein classification, which provide various information on potential disorders. To improve the performance of the existing automated methods for retinal image analysis, we propose a two-step vessel classification. We adopt a UNet-based model, SeqNet, to accurately segment vessels from the background and make prediction on the vessel type. Our model does segmentation and classification sequentially, which alleviates the problem of label distribution bias and facilitates training. 6 | 7 | ## Model 8 | 9 | ![Network Structure](./pics/structure.jpg) 10 | 11 | Fig.1 The network architecture of SeqNet. 12 | 13 | ## Usage 14 | 15 | When training, datasets should be placed at `./data/ALL`, following the data structure defined in `./utils/prepare_dataset.py`. 16 | 17 | Training: 18 | 19 | ```bash 20 | python train.py 21 | ``` 22 | 23 | Models will be saved in `./trained_model/` and results will be saved at `./output/`. 
24 | 25 | Prediction: 26 | 27 | ```bash 28 | python predict.py -i ./data/test_images/ -o ./output/ 29 | ``` 30 | 31 | ## Pretrained Weights 32 | 33 | Here is a model trained on multiple datasets (all images in DRIVE, LES-AV, and HRF are used for training). I now use it for universal retinal vessel extraction and classification. In my tests, it works well on new data even with very different brightness, color, etc. In my case, no fine-tuning was needed. 34 | 35 | [Download from Google Drive](https://drive.google.com/file/d/1OYjzu0gixtga6e7Rvb2mZoSSYJkXWRNB/view?usp=sharing) 36 | 37 | Please put it under `trained_model/ALL/`. 38 | 39 | Below are the classification results for retinal images from some other datasets. 40 | 41 | ![Raw File](./pics/raw.jpg) 42 | ![Segmentation Result](./pics/seg.png) 43 | ![Classification Result](./pics/cls.png) 44 | 45 | This result looks good to me. We asked a clinician to validate this result, and the correction is shown below. 46 | ![Correction Result](./pics/correction.png) 47 | 48 | ## Publication 49 | 50 | If you want to use this work, please consider citing the following paper. 51 | 52 | ```bib 53 | @inproceedings{li2020joint, 54 | title={Joint Learning of Vessel Segmentation and Artery/Vein Classification with Post-processing}, 55 | author={Li, Liangzhi and Verma, Manisha and Nakashima, Yuta and Kawasaki, Ryo and Nagahara, Hajime}, 56 | booktitle={Medical Imaging with Deep Learning}, 57 | year={2020} 58 | } 59 | ``` 60 | 61 | You can find the PDF, poster, and talk video (later) of this paper [here](https://www.liangzhili.com/publication/li-2020-joint/). 62 | 63 | ## Acknowledgements 64 | 65 | This work was supported by the Council for Science, Technology and Innovation (CSTI), cross-ministerial Strategic Innovation Promotion Program (SIP), "Innovative AI Hospital System" (Funding Agency: National Institute of Biomedical Innovation, Health and Nutrition (NIBIOHN)). 66 | 67 | ## License 68 | 69 | This project is licensed under the MIT License. 
70 | -------------------------------------------------------------------------------- /pics/cls.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/conscienceli/SeqNet/30a805996d3590dfd675ebe5505a3cf0e61de9ad/pics/cls.png -------------------------------------------------------------------------------- /pics/correction.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/conscienceli/SeqNet/30a805996d3590dfd675ebe5505a3cf0e61de9ad/pics/correction.png -------------------------------------------------------------------------------- /pics/raw.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/conscienceli/SeqNet/30a805996d3590dfd675ebe5505a3cf0e61de9ad/pics/raw.jpg -------------------------------------------------------------------------------- /pics/seg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/conscienceli/SeqNet/30a805996d3590dfd675ebe5505a3cf0e61de9ad/pics/seg.png -------------------------------------------------------------------------------- /pics/structure.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/conscienceli/SeqNet/30a805996d3590dfd675ebe5505a3cf0e61de9ad/pics/structure.jpg -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | ############Test 2 | import argparse 3 | import os 4 | import tensorflow as tf 5 | from keras.backend import tensorflow_backend 6 | 7 | from utils import define_model, crop_prediction 8 | from keras.layers import ReLU 9 | from tqdm import tqdm 10 | import numpy as np 11 | from skimage.transform import resize 12 | import cv2 13 | 14 | from PIL import Image 15 | 
def _stitch_probability_map(pred, crop_size, stride_size, new_height, new_width, side):
    """Merge one model output (a stack of patch predictions) into a 0-255 map.

    The patches are averaged back into a full image with
    ``crop_prediction.recompone_overlap``, the padding added before patch
    extraction is cropped away (``side`` x ``side``), and the first channel of
    the first image is min-max scaled to ``[0, 255]``.
    """
    pred_patches = crop_prediction.pred_to_patches(pred, crop_size, stride_size)
    pred_imgs = crop_prediction.recompone_overlap(pred_patches, crop_size, stride_size, new_height, new_width)
    prob = pred_imgs[0, 0:side, 0:side, 0]
    low, high = np.min(prob), np.max(prob)
    if high == low:
        # A constant map would divide 0 by 0 and write NaNs to disk.
        return np.zeros_like(prob)
    return 255. * (prob - low) / (high - low)


def predict(ACTIVATION='ReLU', dropout=0.1, batch_size=32, repeat=4, minimum_kernel=32,
            epochs=200, iteration=3, crop_size=128, stride_size=3,
            input_path='', output_path='', DATASET='ALL'):
    """Run the trained network on every image in ``input_path``.

    For each image, four PNG files named after the input are written below
    ``output_path``: ``out_seg/`` (vessel map), ``out_art/`` (artery map),
    ``out_vei/`` (vein map) and ``out_final/`` (vessel intensity stored in
    channel 2 where artery >= vein, otherwise in channel 0; ``cv2.imwrite``
    stores channels in BGR order, so these appear as red and blue).

    Args:
        ACTIVATION: name of the activation layer class, looked up in this
            module's globals.
        dropout, minimum_kernel, iteration: forwarded to
            ``define_model.get_unet``.
        batch_size: batch size used for ``model.predict``.
        repeat: unused; kept for signature compatibility.
        epochs, crop_size, DATASET: only used to build the weight file name
            ``trained_model/{DATASET}/Final_Emer_Iteration_..._epochs_{epochs}.hdf5``.
        stride_size: stride between neighbouring patches.
        input_path: directory holding the input images.
        output_path: directory that receives the four output folders
            (an empty value means the current directory).

    Raises:
        ValueError: if ``input_path`` is empty.
        KeyError: if ``ACTIVATION`` is not a name in this module.
    """
    if not input_path:
        # The old string concatenation turned '' into '/', i.e. the filesystem root.
        raise ValueError("input_path must point to a directory of images")
    output_path = output_path or '.'

    exts = ('png', 'jpg', 'jpeg', 'tif', 'tiff', 'bmp', 'gif')
    net_side = 576  # every image is resized to net_side x net_side before patching

    # Compare extensions case-insensitively so that e.g. "IMG_01.JPG" is not skipped.
    names = [name for name in sorted(os.listdir(input_path))
             if os.path.splitext(name)[1][1:].lower() in exts]

    out_dirs = {key: os.path.join(output_path, key)
                for key in ('out_seg', 'out_art', 'out_vei', 'out_final')}
    for out_dir in out_dirs.values():
        os.makedirs(out_dir, exist_ok=True)

    activation = globals()[ACTIVATION]
    model = define_model.get_unet(minimum_kernel=minimum_kernel, do=dropout, activation=activation, iteration=iteration)
    model_name = f"Final_Emer_Iteration_{iteration}_cropsize_{crop_size}_epochs_{epochs}"
    print("Model : %s" % model_name)
    load_path = f"trained_model/{DATASET}/{model_name}.hdf5"
    model.load_weights(load_path, by_name=False)

    for name in tqdm(names):
        filename = os.path.splitext(name)[0]
        with Image.open(os.path.join(input_path, name)) as pil_img:
            image_size = pil_img.size  # PIL reports (width, height)
            # Palette, grayscale and RGBA files would otherwise reach the network
            # with the wrong channel count.
            # NOTE(review): assumes the model takes 3-channel input - confirm in define_model.
            img = np.array(pil_img.convert('RGB')) / 255.
        img = resize(img, [net_side, net_side])
        out_size = image_size[::-1]  # skimage expects (height, width)

        patches_pred, new_height, new_width, _ = crop_prediction.get_test_patches(img, crop_size, stride_size)
        preds = model.predict(patches_pred, batch_size=batch_size)

        # NOTE(review): the output indices follow the original script; the layout of
        # the model's output list is defined in define_model and not checked here.
        stitched = {
            'out_seg': _stitch_probability_map(preds[iteration], crop_size, stride_size,
                                               new_height, new_width, net_side),
            'out_art': _stitch_probability_map(preds[2 * iteration + 1], crop_size, stride_size,
                                               new_height, new_width, net_side),
            'out_vei': _stitch_probability_map(preds[3 * iteration + 2], crop_size, stride_size,
                                               new_height, new_width, net_side),
        }
        for key in ('out_seg', 'out_art', 'out_vei'):
            cv2.imwrite(os.path.join(out_dirs[key], f"{filename}.png"), resize(stitched[key], out_size))

        # Final composite: each pixel carries the vessel intensity in the channel of
        # whichever class (artery or vein) scored higher.
        pred_seg = stitched['out_seg']
        is_artery = stitched['out_art'] >= stitched['out_vei']
        pred_final = np.zeros(pred_seg.shape + (3,), dtype=pred_seg.dtype)
        pred_final[..., 2] = np.where(is_artery, pred_seg, 0)
        pred_final[..., 0] = np.where(is_artery, 0, pred_seg)
        cv2.imwrite(os.path.join(out_dirs['out_final'], f"{filename}.png"), resize(pred_final, out_size))


if __name__ == "__main__":
    os.environ["CUDA_VISIBLE_DEVICES"]="0"
    config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
    session = tf.Session(config=config)
    tensorflow_backend.set_session(session)


    # define the program description
    des_text = 'Please use -i to specify the input dir and -o to specify the output dir.'
111 | 112 | # initiate the parser 113 | parser = argparse.ArgumentParser(description=des_text) 114 | parser.add_argument('--input', '-i', help="(Required) Path of input dir") 115 | parser.add_argument('--output', '-o', help="(Optional) Path of output dir") 116 | args = parser.parse_args() 117 | 118 | if not args.input: 119 | print('Please specify the input dir with -i') 120 | exit(1) 121 | 122 | input_path = args.input 123 | 124 | if not args.output: 125 | output_path = './output/' 126 | else: 127 | output_path = args.output 128 | if output_path.endswith('/'): 129 | output_path = output_path[:-1] 130 | 131 | 132 | #stride_size = 3 will be better, but slower 133 | predict(batch_size=24, epochs=200, iteration=3, stride_size=3, crop_size=128, 134 | input_path=input_path, output_path=output_path) -------------------------------------------------------------------------------- /requirements-conda.txt: -------------------------------------------------------------------------------- 1 | # This file may be used to create an environment using: 2 | # $ conda create --name --file 3 | # platform: linux-64 4 | _libgcc_mutex=0.1=conda_forge 5 | _openmp_mutex=4.5=0_gnu 6 | _tflow_select=2.1.0=gpu 7 | absl-py=0.9.0=py37hc8dfbb8_1 8 | appdirs=1.4.3=py_1 9 | astor=0.7.1=py_0 10 | binutils_linux-64=2.34=hc952b39_18 11 | blas=2.14=openblas 12 | blinker=1.4=py_1 13 | brotlipy=0.7.0=py37h8f50634_1000 14 | bzip2=1.0.8=h516909a_2 15 | c-ares=1.15.0=h516909a_1001 16 | ca-certificates=2020.4.5.1=hecc5488_0 17 | cachetools=3.1.1=py_0 18 | cairo=1.16.0=hcf35c78_1003 19 | certifi=2020.4.5.1=py37hc8dfbb8_0 20 | cffi=1.14.0=py37hd463f26_0 21 | chardet=3.0.4=py37hc8dfbb8_1006 22 | click=7.1.2=pyh9f0ad1d_0 23 | cloudpickle=1.4.1=py_0 24 | cryptography=2.9.2=py37hb09aad4_0 25 | cudatoolkit=10.0.130=0 26 | cudnn=7.6.5=cuda10.0_0 27 | cupti=10.0.130=0 28 | cycler=0.10.0=py_2 29 | cytoolz=0.10.1=py37h516909a_0 30 | dask-core=2.16.0=py_0 31 | dbus=1.13.6=he372182_0 32 | decorator=4.4.2=py_0 33 | 
expat=2.2.9=he1b5a44_2 34 | ffmpeg=4.2=h167e202_0 35 | fontconfig=2.13.1=h86ecdb6_1001 36 | freetype=2.10.1=he06d7ca_0 37 | gast=0.2.2=py_0 38 | gcc_impl_linux-64=7.3.0=hd420e75_5 39 | gcc_linux-64=7.3.0=h553295d_18 40 | gettext=0.19.8.1=hc5be6a0_1002 41 | giflib=5.2.1=h516909a_2 42 | glib=2.64.2=h6f030ca_0 43 | gmp=6.2.0=he1b5a44_2 44 | gnutls=3.6.5=hd3a4fd2_1002 45 | google-auth=1.14.2=pyh9f0ad1d_0 46 | google-auth-oauthlib=0.4.1=py_2 47 | google-pasta=0.2.0=pyh8c360ce_0 48 | graphite2=1.3.13=he1b5a44_1001 49 | grpcio=1.27.2=py37hf8bcb03_0 50 | gst-plugins-base=1.14.5=h0935bb2_2 51 | gstreamer=1.14.5=h36ae1b5_2 52 | gxx_impl_linux-64=7.3.0=hdf63c60_5 53 | gxx_linux-64=7.3.0=h553295d_18 54 | h5py=2.10.0=nompi_py37h513d04c_102 55 | harfbuzz=2.4.0=h9f30f68_3 56 | hdf5=1.10.5=nompi_h3c11f04_1104 57 | icu=64.2=he1b5a44_1 58 | idna=2.9=py_1 59 | imagecodecs-lite=2019.12.3=py37h8f50634_0 60 | imageio=2.8.0=py_0 61 | jasper=1.900.1=h07fcdf6_1006 62 | joblib=0.14.1=py_0 63 | jpeg=9c=h14c3975_1001 64 | keras=2.3.1=py37_0 65 | keras-applications=1.0.8=py_1 66 | keras-preprocessing=1.1.0=py_0 67 | kiwisolver=1.2.0=py37h99015e2_0 68 | lame=3.100=h14c3975_1001 69 | libblas=3.8.0=14_openblas 70 | libcblas=3.8.0=14_openblas 71 | libclang=9.0.1=default_hde54327_0 72 | libffi=3.2.1=he1b5a44_1007 73 | libgcc-ng=9.2.0=h24d8f2e_2 74 | libgfortran-ng=7.3.0=hdf63c60_5 75 | libgomp=9.2.0=h24d8f2e_2 76 | libgpuarray=0.7.6=h14c3975_1003 77 | libiconv=1.15=h516909a_1006 78 | liblapack=3.8.0=14_openblas 79 | liblapacke=3.8.0=14_openblas 80 | libllvm9=9.0.1=he513fc3_1 81 | libopenblas=0.3.7=h5ec1e0e_6 82 | libopencv=4.2.0=py37_5 83 | libpng=1.6.37=hed695b0_1 84 | libprotobuf=3.11.4=h8b12597_0 85 | libstdcxx-ng=9.2.0=hdf63c60_2 86 | libtiff=4.1.0=hc3755c2_3 87 | libuuid=2.32.1=h14c3975_1000 88 | libwebp=1.0.2=h56121f0_5 89 | libxcb=1.13=h14c3975_1002 90 | libxkbcommon=0.10.0=he1b5a44_0 91 | libxml2=2.9.10=hee79883_0 92 | lz4-c=1.9.2=he1b5a44_1 93 | mako=1.1.0=py_0 94 | markdown=3.2.1=py_0 95 
| markupsafe=1.1.1=py37h8f50634_1 96 | matplotlib-base=3.2.1=py37h30547a4_0 97 | ncurses=6.1=hf484d3e_1002 98 | nettle=3.4.1=h1bed415_1002 99 | networkx=2.4=py_1 100 | nspr=4.25=he1b5a44_0 101 | nss=3.47=he751ad9_0 102 | numpy=1.18.4=py37h8960a57_0 103 | oauthlib=3.0.1=py_0 104 | olefile=0.46=py_0 105 | opencv=4.2.0=py37_5 106 | openh264=1.8.0=hdbcaa40_1000 107 | openssl=1.1.1g=h516909a_0 108 | opt_einsum=3.2.1=py_0 109 | packaging=20.1=py_0 110 | pcre=8.44=he1b5a44_0 111 | pillow=7.1.2=py37h718be6c_0 112 | pip=20.1=pyh9f0ad1d_0 113 | pixman=0.38.0=h516909a_1003 114 | pooch=1.1.0=py_0 115 | protobuf=3.11.4=py37h3340039_1 116 | pthread-stubs=0.4=h14c3975_1001 117 | py-opencv=4.2.0=py37h43977f1_5 118 | pyasn1=0.4.8=py_0 119 | pyasn1-modules=0.2.7=py_0 120 | pycparser=2.20=py_0 121 | pygpu=0.7.6=py37hc1659b7_1000 122 | pyjwt=1.7.1=py_0 123 | pyopenssl=19.1.0=py_1 124 | pyparsing=2.4.7=pyh9f0ad1d_0 125 | pysocks=1.7.1=py37hc8dfbb8_1 126 | python=3.7.6=h8356626_5_cpython 127 | python-dateutil=2.8.1=py_0 128 | python_abi=3.7=1_cp37m 129 | pywavelets=1.1.1=py37h03ebfcd_1 130 | pyyaml=5.3.1=py37h8f50634_0 131 | qt=5.12.5=hd8c4c69_1 132 | readline=8.0=hf8c457e_0 133 | requests=2.23.0=pyh8c360ce_2 134 | requests-oauthlib=1.2.0=py_0 135 | rsa=4.0=py_0 136 | scikit-image=0.17.1=py37h0da4684_0 137 | scikit-learn=0.22.1=py37h22eb022_0 138 | scipy=1.4.1=py37ha3d9a3c_3 139 | setuptools=46.1.3=py37hc8dfbb8_0 140 | six=1.14.0=py_1 141 | sqlite=3.30.1=hcee41ef_0 142 | tensorboard=1.15.0=pyhb230dea_0 143 | tensorflow=1.15.0=gpu_py37h0f0df58_0 144 | tensorflow-base=1.15.0=gpu_py37h9dcbed7_0 145 | tensorflow-estimator=1.15.1=pyh2649769_0 146 | tensorflow-gpu=1.15.0=h0d30ee6_0 147 | termcolor=1.1.0=py_2 148 | theano=1.0.4=py37he1b5a44_1001 149 | tifffile=2020.5.7=py_0 150 | tk=8.6.10=hed695b0_0 151 | toolz=0.10.0=py_0 152 | tornado=6.0.4=py37h8f50634_1 153 | tqdm=4.46.0=py_0 154 | urllib3=1.25.9=py_0 155 | webencodings=0.5.1=py37_1 156 | werkzeug=0.16.1=py_0 157 | wheel=0.34.2=py_1 158 | 
wrapt=1.12.1=py37h8f50634_1 159 | x264=1!152.20180806=h14c3975_0 160 | xorg-kbproto=1.0.7=h14c3975_1002 161 | xorg-libice=1.0.10=h516909a_0 162 | xorg-libsm=1.2.3=h84519dc_1000 163 | xorg-libx11=1.6.9=h516909a_0 164 | xorg-libxau=1.0.9=h14c3975_0 165 | xorg-libxdmcp=1.1.3=h516909a_0 166 | xorg-libxext=1.3.4=h516909a_0 167 | xorg-libxrender=0.9.10=h516909a_1002 168 | xorg-renderproto=0.11.1=h14c3975_1002 169 | xorg-xextproto=7.3.0=h14c3975_1002 170 | xorg-xproto=7.0.31=h14c3975_1007 171 | xz=5.2.5=h516909a_0 172 | yaml=0.2.4=h516909a_0 173 | zlib=1.2.11=h516909a_1006 174 | zstd=1.4.4=h6597ccf_3 175 | -------------------------------------------------------------------------------- /requirements-pip.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.9.0 2 | appdirs==1.4.3 3 | astor==0.7.1 4 | blinker==1.4 5 | brotlipy==0.7.0 6 | cachetools==3.1.1 7 | certifi==2020.4.5.1 8 | cffi==1.14.0 9 | chardet==3.0.4 10 | click==7.1.2 11 | cloudpickle==1.4.1 12 | cryptography==2.9.2 13 | cycler==0.10.0 14 | cytoolz==0.10.1 15 | dask==2.16.0 16 | decorator==4.4.2 17 | gast==0.2.2 18 | google-auth==1.14.2 19 | google-auth-oauthlib==0.4.1 20 | google-pasta==0.2.0 21 | grpcio==1.27.2 22 | h5py==2.10.0 23 | idna==2.9 24 | imagecodecs-lite==2019.12.3 25 | imageio==2.8.0 26 | joblib==0.14.1 27 | Keras==2.3.1 28 | Keras-Applications==1.0.8 29 | Keras-Preprocessing==1.1.0 30 | kiwisolver==1.2.0 31 | Mako==1.1.0 32 | Markdown==3.2.1 33 | MarkupSafe==1.1.1 34 | matplotlib==3.2.1 35 | networkx==2.4 36 | numpy==1.18.4 37 | oauthlib==3.0.1 38 | olefile==0.46 39 | opt-einsum==3.2.1 40 | packaging==20.1 41 | Pillow==7.1.2 42 | pooch==1.1.0 43 | protobuf==3.11.4 44 | pyasn1==0.4.8 45 | pyasn1-modules==0.2.7 46 | pycparser==2.20 47 | pygpu==0.7.6 48 | PyJWT==1.7.1 49 | pyOpenSSL==19.1.0 50 | pyparsing==2.4.7 51 | PySocks==1.7.1 52 | python-dateutil==2.8.1 53 | PyWavelets==1.1.1 54 | PyYAML==5.3.1 55 | requests==2.23.0 56 | requests-oauthlib==1.2.0 
def train(iteration=3, DATASET='ALL', crop_size=128, need_au=True, ACTIVATION='ReLU', dropout=0.1, batch_size=32,
          repeat=4, minimum_kernel=32, epochs=200):
    """Train the network on ``DATASET`` and checkpoint the best weights.

    The dataset is prepared with ``prepare_dataset.prepareDataset``, the model
    is built with ``define_model.get_unet`` and trained with
    ``fit_generator``. The checkpoint with the lowest ``seg_final_out_loss``
    is saved to ``trained_model/{DATASET}/{model_name}.hdf5`` and TensorBoard
    logs go to ``logs/{DATASET}/``.

    Args:
        iteration, dropout, minimum_kernel: forwarded to ``get_unet``.
        DATASET: dataset name; also the sub-folder for weights and logs.
        crop_size: forwarded to the data generator and part of the model name.
        need_au: forwarded to the data generator as ``au``.
        ACTIVATION: name of the activation layer class, looked up in this
            module's globals.
        batch_size, repeat: forwarded to ``define_model.Generator``.
        epochs: number of training epochs; part of the model name.

    Returns:
        The ``History`` object returned by ``fit_generator``.

    Raises:
        OSError: if the output directories cannot be created.
        KeyError: if ``ACTIVATION`` is not a name in this module.
    """
    model_name = f"Final_Emer_Iteration_{iteration}_cropsize_{crop_size}_epochs_{epochs}"

    print("Model : %s" % model_name)

    prepare_dataset.prepareDataset(DATASET)

    activation = globals()[ACTIVATION]
    model = define_model.get_unet(minimum_kernel=minimum_kernel, do=dropout, activation=activation, iteration=iteration)

    # exist_ok already covers "directory exists"; any other failure would make the
    # checkpoint callback fail after a full epoch, so surface it immediately.
    os.makedirs(f"trained_model/{DATASET}/", exist_ok=True)
    os.makedirs(f"logs/{DATASET}/", exist_ok=True)

    # Optional warm start. This stays best-effort, but failures are reported
    # instead of being swallowed by a bare ``except``.
    # NOTE(review): this path differs from ``save_path`` below, so a checkpoint
    # written by this script is never picked up here - confirm that is intended.
    load_path = f"trained_model/{DATASET}/{model_name}_weights.best.hdf5"
    if os.path.isfile(load_path):
        try:
            model.load_weights(load_path, by_name=True)
        except Exception as err:
            print("Could not load weights from %s: %s" % (load_path, err))
        else:
            print("Loaded weights from %s" % load_path)

    now = datetime.now()  # current date and time
    date_time = now.strftime("%Y-%m-%d---%H-%M-%S")

    tensorboard = TensorBoard(
        log_dir=f"logs/{DATASET}/Final_Emer-Iteration_{iteration}-Cropsize_{crop_size}-Epochs_{epochs}---{date_time}",
        histogram_freq=0, batch_size=batch_size, write_graph=True, write_grads=True,
        write_images=True, embeddings_freq=0, embeddings_layer_names=None,
        embeddings_metadata=None, embeddings_data=None, update_freq='epoch')

    save_path = f"trained_model/{DATASET}/{model_name}.hdf5"
    checkpoint = ModelCheckpoint(save_path, monitor='seg_final_out_loss', verbose=1, save_best_only=True, mode='min')

    data_generator = define_model.Generator(batch_size, repeat, DATASET)

    history = model.fit_generator(data_generator.gen(au=need_au, crop_size=crop_size, iteration=iteration),
                                  epochs=epochs, verbose=1,
                                  steps_per_epoch=100 * data_generator.n // batch_size,
                                  use_multiprocessing=True, workers=8,
                                  callbacks=[tensorboard, checkpoint])
    return history


if __name__ == "__main__":
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
    session = tf.Session(config=config)
    tensorflow_backend.set_session(session)

    # epochs>100 will be enough, but slower
    train(batch_size=24, iteration=3,
          epochs=200, crop_size=128)
def get_test_patches(img, crop_size, stride_size, rl=False):
    """Pad a single image and cut it into overlapping square patches.

    Args:
        img: one image as an ``(H, W, C)`` array.
        crop_size: side length of each square patch.
        stride_size: step between neighbouring patches.
        rl: unused; kept for signature compatibility.

    Returns:
        ``(patches, padded_height, padded_width, batch)`` where ``batch`` is
        the unpadded image with a leading batch axis of size 1.
    """
    test_img_adjust = np.asarray([img])
    test_imgs = paint_border(test_img_adjust, crop_size, stride_size)
    test_img_patch = extract_patches(test_imgs, crop_size, stride_size)
    return test_img_patch, test_imgs.shape[1], test_imgs.shape[2], test_img_adjust


def extract_patches(full_imgs, crop_size, stride_size):
    """Cut every image of a batch into overlapping ``crop_size`` patches.

    Patches are ordered image by image, then row by row, then column by
    column, and returned as a float64 array of shape
    ``(n_patches, crop_size, crop_size, C)``.

    Raises:
        ValueError: if ``full_imgs`` is not 4-D, is smaller than one patch, or
            its size does not fit a whole number of strides (pad it with
            ``paint_border`` first).
    """
    if full_imgs.ndim != 4:
        raise ValueError("full_imgs must be a 4-D array (N, H, W, C)")
    n_imgs, img_h, img_w, n_channels = full_imgs.shape
    if img_h < crop_size or img_w < crop_size:
        raise ValueError("images are smaller than crop_size; pad them with paint_border first")
    if (img_h - crop_size) % stride_size != 0 or (img_w - crop_size) % stride_size != 0:
        raise ValueError("image size does not fit a whole number of strides; "
                         "pad the images with paint_border first")

    n_rows = (img_h - crop_size) // stride_size + 1
    n_cols = (img_w - crop_size) // stride_size + 1
    patches = np.empty((n_imgs * n_rows * n_cols, crop_size, crop_size, n_channels))

    k = 0
    for i in range(n_imgs):
        for row in range(n_rows):
            top = row * stride_size
            for col in range(n_cols):
                left = col * stride_size
                patches[k] = full_imgs[i, top:top + crop_size, left:left + crop_size, :]
                k += 1
    return patches


def _padded_length(length, crop_size, stride_size):
    """Return the smallest length >= ``length`` that patches tile exactly."""
    if length <= crop_size:
        return crop_size
    return length + (-(length - crop_size)) % stride_size


def paint_border(imgs, crop_size, stride_size):
    """Zero-pad a batch at the bottom/right so patches tile it exactly.

    Height and width are padded independently. The previous version only
    worked when both dimensions needed padding or neither did, and never
    padded an image smaller than one patch.

    Args:
        imgs: 4-D array ``(N, H, W, C)``.
        crop_size: side length of each square patch.
        stride_size: step between neighbouring patches.

    Returns:
        ``imgs`` itself when no padding is needed, otherwise a new float64
        array with the original content in the top-left corner.

    Raises:
        ValueError: if ``imgs`` is not 4-D.
    """
    if imgs.ndim != 4:
        raise ValueError("imgs must be a 4-D array (N, H, W, C)")
    n_imgs, img_h, img_w, n_channels = imgs.shape
    new_h = _padded_length(img_h, crop_size, stride_size)
    new_w = _padded_length(img_w, crop_size, stride_size)
    if new_h == img_h and new_w == img_w:
        return imgs

    padded = np.zeros((n_imgs, new_h, new_w, n_channels))
    padded[:, :img_h, :img_w, :] = imgs
    return padded


def pred_to_patches(pred, crop_size, stride_size):
    """Return ``pred`` unchanged.

    The function has always returned its input immediately; the unreachable
    reshaping code that followed the early return has been removed.
    ``crop_size`` and ``stride_size`` are kept for signature compatibility.
    """
    return pred


def recompone_overlap(preds, crop_size, stride_size, img_h, img_w):
    """Average overlapping patch predictions back into full images.

    Args:
        preds: 4-D array ``(n_patches, crop_size, crop_size, C)`` in the order
            produced by ``extract_patches``.
        crop_size: side length of each square patch.
        stride_size: step between neighbouring patches.
        img_h, img_w: size of the (padded) images the patches were cut from.

    Returns:
        float64 array ``(n_images, img_h, img_w, C)`` holding, per pixel, the
        mean of all patch predictions that cover it.

    Raises:
        ValueError: if ``preds`` is not 4-D, the image size does not match the
            patch grid, or the patch count is not a positive multiple of the
            number of patches per image.
    """
    if preds.ndim != 4:
        raise ValueError("preds must be a 4-D array (n_patches, h, w, C)")
    if img_h < crop_size or img_w < crop_size:
        raise ValueError("img_h and img_w must be at least crop_size")
    if (img_h - crop_size) % stride_size != 0 or (img_w - crop_size) % stride_size != 0:
        # Otherwise the last rows/columns would not be covered by any patch.
        raise ValueError("img_h/img_w do not fit a whole number of strides")

    n_rows = (img_h - crop_size) // stride_size + 1
    n_cols = (img_w - crop_size) // stride_size + 1
    n_per_img = n_rows * n_cols
    if preds.shape[0] == 0 or preds.shape[0] % n_per_img != 0:
        raise ValueError("number of patches (%d) is not a positive multiple of the "
                         "%d patches per image" % (preds.shape[0], n_per_img))
    n_full_imgs = preds.shape[0] // n_per_img

    full_prob = np.zeros((n_full_imgs, img_h, img_w, preds.shape[3]))  # sum of predictions
    full_sum = np.zeros((n_full_imgs, img_h, img_w, preds.shape[3]))   # number of covering patches

    k = 0  # iterator over all the patches
    for i in range(n_full_imgs):
        for row in range(n_rows):
            top = row * stride_size
            for col in range(n_cols):
                left = col * stride_size
                full_prob[i, top:top + crop_size, left:left + crop_size, :] += preds[k]
                full_sum[i, top:top + crop_size, left:left + crop_size, :] += 1
                k += 1
    return full_prob / full_sum
print("According to the dimension inserted, there are " +str(N_full_imgs) +" full images (of " +str(img_h)+"x" +str(img_w) +" each)") 107 | full_prob = np.zeros( 108 | (N_full_imgs, img_h, img_w, preds.shape[3])) # itialize to zero mega array with sum of Probabilities 109 | full_sum = np.zeros((N_full_imgs, img_h, img_w, preds.shape[3])) 110 | 111 | k = 0 # iterator over all the patches 112 | for i in range(N_full_imgs): 113 | for h in range((img_h - patch_h) // stride_height + 1): 114 | for w in range((img_w - patch_w) // stride_width + 1): 115 | full_prob[i, h * stride_height:(h * stride_height) + patch_h, 116 | w * stride_width:(w * stride_width) + patch_w, :] += preds[k] 117 | full_sum[i, h * stride_height:(h * stride_height) + patch_h, 118 | w * stride_width:(w * stride_width) + patch_w, :] += 1 119 | k += 1 120 | # print(k,preds.shape[0]) 121 | assert (k == preds.shape[0]) 122 | assert (np.min(full_sum) >= 1.0) # at least one 123 | final_avg = full_prob / full_sum 124 | # print('using avg') 125 | return final_avg -------------------------------------------------------------------------------- /utils/data_augmentation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from keras_preprocessing import image 3 | 4 | 5 | def random_flip(img, masks, masks2, u=0.5): 6 | if np.random.random() < u: 7 | img = image.flip_axis(img, 1) 8 | for i in range(masks.shape[0]): 9 | masks[i] = image.flip_axis(masks[i], 1) 10 | for i in range(masks2.shape[0]): 11 | masks2[i] = image.flip_axis(masks2[i], 1) 12 | if np.random.random() < u: 13 | img = image.flip_axis(img, 0) 14 | for i in range(masks.shape[0]): 15 | masks[i] = image.flip_axis(masks[i], 0) 16 | for i in range(masks2.shape[0]): 17 | masks2[i] = image.flip_axis(masks2[i], 0) 18 | return img, masks, masks2 19 | 20 | 21 | def random_rotate(img, masks, masks2, rotate_limit=(-20, 20), u=0.5): 22 | if np.random.random() < u: 23 | theta = 
def random_shear(img, masks, masks2, intensity_range=(-0.5, 0.5), u=0.5):
    """Randomly shear an image and both of its mask stacks by one shared amount.

    With probability ``u`` a single shear value is drawn uniformly from
    ``intensity_range`` and passed as ``shear=`` to
    ``image.apply_affine_transform`` for ``img`` and for every entry of
    ``masks`` and ``masks2``, so the image and its labels stay aligned.

    Args:
        img: Image array handed to ``image.apply_affine_transform``.
        masks: Stack of masks indexed along axis 0. Entries are overwritten
            in place, so the caller's array is modified.
        masks2: Second stack of masks, treated exactly like ``masks``.
        intensity_range: ``(low, high)`` bounds of the sampled shear value.
        u: Probability of applying the transform at all.

    Returns:
        The tuple ``(img, masks, masks2)``; the inputs are returned untouched
        when the transform is skipped.
    """
    if np.random.random() < u:
        # Sample between the bounds as given. The earlier code negated the
        # lower bound, which collapsed a symmetric range such as (-5, 5) to
        # uniform(5, 5) and therefore always produced the maximum shear.
        sh = np.random.uniform(intensity_range[0], intensity_range[1])
        img = image.apply_affine_transform(img, shear=sh)
        for i in range(masks.shape[0]):
            masks[i] = image.apply_affine_transform(masks[i], shear=sh)
        for i in range(masks2.shape[0]):
            masks2[i] = image.apply_affine_transform(masks2[i], shear=sh)
    return img, masks, masks2
def random_channel_shift(x, limit, channel_axis=2):
    """Add an independent random offset to every channel of ``x``.

    Each channel receives its own offset drawn uniformly from
    ``[-limit, limit]``; the shifted values are clipped to the minimum and
    maximum of the whole input, so the overall value range is preserved.

    Args:
        x: Array with a channel dimension.
        limit: Half-width of the uniform offset distribution.
        channel_axis: Index of the channel dimension. Negative indices are
            accepted and count from the end.

    Returns:
        A new array with the same shape and axis order as ``x``.
    """
    # moveaxis normalises negative indices. The previous rollaxis round trip
    # left the channels in front whenever channel_axis was negative.
    channels_first = np.moveaxis(x, channel_axis, 0)
    min_x, max_x = np.min(channels_first), np.max(channels_first)
    # One uniform draw per channel, in channel order.
    shifted = [
        np.clip(x_ch + np.random.uniform(-limit, limit), min_x, max_x)
        for x_ch in channels_first
    ]
    return np.moveaxis(np.stack(shifted, axis=0), 0, channel_axis)
import data_augmentation, prepare_dataset 20 | 21 | 22 | def get_unet(minimum_kernel=32, do=0, activation=ReLU, iteration=1): 23 | inputs = Input((None, None, 3)) 24 | conv1 = Dropout(do)(activation()(Conv2D(minimum_kernel, (3, 3), padding='same')(inputs))) 25 | conv1 = Dropout(do)(activation()(Conv2D(minimum_kernel, (3, 3), padding='same')(conv1))) 26 | a = conv1 27 | pool1 = MaxPooling2D(pool_size=(2, 2))(conv1) 28 | 29 | conv2 = Dropout(do)(activation()(Conv2D(minimum_kernel * 2, (3, 3), padding='same')(pool1))) 30 | conv2 = Dropout(do)(activation()(Conv2D(minimum_kernel * 2, (3, 3), padding='same')(conv2))) 31 | b = conv2 32 | pool2 = MaxPooling2D(pool_size=(2, 2))(conv2) 33 | 34 | conv3 = Dropout(do)(activation()(Conv2D(minimum_kernel * 4, (3, 3), padding='same')(pool2))) 35 | conv3 = Dropout(do)(activation()(Conv2D(minimum_kernel * 4, (3, 3), padding='same')(conv3))) 36 | pool3 = MaxPooling2D(pool_size=(2, 2))(conv3) 37 | 38 | conv4 = Dropout(do)(activation()(Conv2D(minimum_kernel * 8, (3, 3), padding='same')(pool3))) 39 | conv4 = Dropout(do)(activation()(Conv2D(minimum_kernel * 8, (3, 3), padding='same')(conv4))) 40 | pool4 = MaxPooling2D(pool_size=(2, 2))(conv4) 41 | 42 | conv5 = Dropout(do)(activation()(Conv2D(minimum_kernel * 16, (3, 3), padding='same')(pool4))) 43 | conv5 = Dropout(do)(activation()(Conv2D(minimum_kernel * 16, (3, 3), padding='same')(conv5))) 44 | 45 | up6 = concatenate([Conv2DTranspose(minimum_kernel * 8, (2, 2), strides=(2, 2), padding='same')(conv5), conv4], 46 | axis=3) 47 | conv6 = Dropout(do)(activation()(Conv2D(minimum_kernel * 8, (3, 3), padding='same')(up6))) 48 | conv6 = Dropout(do)(activation()(Conv2D(minimum_kernel * 8, (3, 3), padding='same')(conv6))) 49 | 50 | up7 = concatenate([Conv2DTranspose(minimum_kernel * 4, (2, 2), strides=(2, 2), padding='same')(conv6), conv3], 51 | axis=3) 52 | conv7 = Dropout(do)(activation()(Conv2D(minimum_kernel * 4, (3, 3), padding='same')(up7))) 53 | conv7 = 
Dropout(do)(activation()(Conv2D(minimum_kernel * 4, (3, 3), padding='same')(conv7))) 54 | 55 | up8 = concatenate([Conv2DTranspose(minimum_kernel * 2, (2, 2), strides=(2, 2), padding='same')(conv7), conv2], 56 | axis=3) 57 | conv8 = Dropout(do)(activation()(Conv2D(minimum_kernel * 2, (3, 3), padding='same')(up8))) 58 | conv8 = Dropout(do)(activation()(Conv2D(minimum_kernel * 2, (3, 3), padding='same')(conv8))) 59 | 60 | up9 = concatenate([Conv2DTranspose(minimum_kernel, (2, 2), strides=(2, 2), padding='same')(conv8), conv1], axis=3) 61 | conv9 = Dropout(do)(activation()(Conv2D(minimum_kernel, (3, 3), padding='same')(up9))) 62 | conv9 = Dropout(do)(activation()(Conv2D(minimum_kernel, (3, 3), padding='same')(conv9))) 63 | 64 | pt_conv1a = Conv2D(minimum_kernel, (3, 3), padding='same') 65 | pt_activation1a = activation() 66 | pt_dropout1a = Dropout(do) 67 | pt_conv1b = Conv2D(minimum_kernel, (3, 3), padding='same') 68 | pt_activation1b = activation() 69 | pt_dropout1b = Dropout(do) 70 | pt_pooling1 = MaxPooling2D(pool_size=(2, 2)) 71 | 72 | pt_conv2a = Conv2D(minimum_kernel * 2, (3, 3), padding='same') 73 | pt_activation2a = activation() 74 | pt_dropout2a = Dropout(do) 75 | pt_conv2b = Conv2D(minimum_kernel * 2, (3, 3), padding='same') 76 | pt_activation2b = activation() 77 | pt_dropout2b = Dropout(do) 78 | pt_pooling2 = MaxPooling2D(pool_size=(2, 2)) 79 | 80 | pt_conv3a = Conv2D(minimum_kernel * 4, (3, 3), padding='same') 81 | pt_activation3a = activation() 82 | pt_dropout3a = Dropout(do) 83 | pt_conv3b = Conv2D(minimum_kernel * 4, (3, 3), padding='same') 84 | pt_activation3b = activation() 85 | pt_dropout3b = Dropout(do) 86 | 87 | pt_tranconv8 = Conv2DTranspose(minimum_kernel * 2, (2, 2), strides=(2, 2), padding='same') 88 | pt_conv8a = Conv2D(minimum_kernel * 2, (3, 3), padding='same') 89 | pt_activation8a = activation() 90 | pt_dropout8a = Dropout(do) 91 | pt_conv8b = Conv2D(minimum_kernel * 2, (3, 3), padding='same') 92 | pt_activation8b = activation() 93 | 
pt_dropout8b = Dropout(do) 94 | 95 | pt_tranconv9 = Conv2DTranspose(minimum_kernel, (2, 2), strides=(2, 2), padding='same') 96 | pt_conv9a = Conv2D(minimum_kernel, (3, 3), padding='same') 97 | pt_activation9a = activation() 98 | pt_dropout9a = Dropout(do) 99 | pt_conv9b = Conv2D(minimum_kernel, (3, 3), padding='same') 100 | pt_activation9b = activation() 101 | pt_dropout9b = Dropout(do) 102 | 103 | conv9s = [conv9] 104 | outs = [] 105 | a_layers = [a] 106 | for iteration_id in range(iteration): 107 | out = Conv2D(1, (1, 1), activation='sigmoid', name=f'out1{iteration_id + 1}')(conv9s[-1]) 108 | outs.append(out) 109 | 110 | conv1 = pt_dropout1a(pt_activation1a(pt_conv1a(conv9s[-1]))) 111 | conv1 = pt_dropout1b(pt_activation1b(pt_conv1b(conv1))) 112 | a_layers.append(conv1) 113 | conv1 = concatenate(a_layers, axis=3) 114 | conv1 = Conv2D(minimum_kernel, (1, 1), padding='same')(conv1) 115 | pool1 = pt_pooling1(conv1) 116 | 117 | conv2 = pt_dropout2a(pt_activation2a(pt_conv2a(pool1))) 118 | conv2 = pt_dropout2b(pt_activation2b(pt_conv2b(conv2))) 119 | pool2 = pt_pooling2(conv2) 120 | 121 | conv3 = pt_dropout3a(pt_activation3a(pt_conv3a(pool2))) 122 | conv3 = pt_dropout3b(pt_activation3b(pt_conv3b(conv3))) 123 | 124 | up8 = concatenate([pt_tranconv8(conv3), conv2], axis=3) 125 | conv8 = pt_dropout8a(pt_activation8a(pt_conv8a(up8))) 126 | conv8 = pt_dropout8b(pt_activation8b(pt_conv8b(conv8))) 127 | 128 | up9 = concatenate([pt_tranconv9(conv8), conv1], axis=3) 129 | conv9 = pt_dropout9a(pt_activation9a(pt_conv9a(up9))) 130 | conv9 = pt_dropout9b(pt_activation9b(pt_conv9b(conv9))) 131 | 132 | conv9s.append(conv9) 133 | 134 | seg_final_out = Conv2D(1, (1, 1), activation='sigmoid', name='seg_final_out')(conv9) 135 | outs.append(seg_final_out) 136 | 137 | # to cls 138 | def masked_input(args): 139 | x, inputs = args 140 | return x * inputs 141 | cls_in = Lambda(masked_input)([seg_final_out, inputs]) 142 | # cls_in = concatenate([cls_in, crossing_final_out], axis=3) 143 | 
cls_in = Lambda(lambda x: K.stop_gradient(x))(cls_in) 144 | 145 | # to cls (artery) 146 | conv1 = Dropout(do)(activation()(Conv2D(minimum_kernel, (3, 3), padding='same')(cls_in))) 147 | conv1 = Dropout(do)(activation()(Conv2D(minimum_kernel, (3, 3), padding='same')(conv1))) 148 | pool1 = MaxPooling2D(pool_size=(2, 2))(conv1) 149 | 150 | conv2 = Dropout(do)(activation()(Conv2D(minimum_kernel * 2, (3, 3), padding='same')(pool1))) 151 | conv2 = Dropout(do)(activation()(Conv2D(minimum_kernel * 2, (3, 3), padding='same')(conv2))) 152 | pool2 = MaxPooling2D(pool_size=(2, 2))(conv2) 153 | 154 | conv3 = Dropout(do)(activation()(Conv2D(minimum_kernel * 4, (3, 3), padding='same')(pool2))) 155 | conv3 = Dropout(do)(activation()(Conv2D(minimum_kernel * 4, (3, 3), padding='same')(conv3))) 156 | pool3 = MaxPooling2D(pool_size=(2, 2))(conv3) 157 | 158 | conv4 = Dropout(do)(activation()(Conv2D(minimum_kernel * 8, (3, 3), padding='same')(pool3))) 159 | conv4 = Dropout(do)(activation()(Conv2D(minimum_kernel * 8, (3, 3), padding='same')(conv4))) 160 | pool4 = MaxPooling2D(pool_size=(2, 2))(conv4) 161 | 162 | conv5 = Dropout(do)(activation()(Conv2D(minimum_kernel * 16, (3, 3), padding='same')(pool4))) 163 | conv5 = Dropout(do)(activation()(Conv2D(minimum_kernel * 16, (3, 3), padding='same')(conv5))) 164 | 165 | up6 = concatenate([Conv2DTranspose(minimum_kernel * 8, (2, 2), strides=(2, 2), padding='same')(conv5), conv4], 166 | axis=3) 167 | conv6 = Dropout(do)(activation()(Conv2D(minimum_kernel * 8, (3, 3), padding='same')(up6))) 168 | conv6 = Dropout(do)(activation()(Conv2D(minimum_kernel * 8, (3, 3), padding='same')(conv6))) 169 | 170 | up7 = concatenate([Conv2DTranspose(minimum_kernel * 4, (2, 2), strides=(2, 2), padding='same')(conv6), conv3], 171 | axis=3) 172 | conv7 = Dropout(do)(activation()(Conv2D(minimum_kernel * 4, (3, 3), padding='same')(up7))) 173 | conv7 = Dropout(do)(activation()(Conv2D(minimum_kernel * 4, (3, 3), padding='same')(conv7))) 174 | 175 | up8 = 
concatenate([Conv2DTranspose(minimum_kernel * 2, (2, 2), strides=(2, 2), padding='same')(conv7), conv2], 176 | axis=3) 177 | conv8 = Dropout(do)(activation()(Conv2D(minimum_kernel * 2, (3, 3), padding='same')(up8))) 178 | conv8 = Dropout(do)(activation()(Conv2D(minimum_kernel * 2, (3, 3), padding='same')(conv8))) 179 | 180 | up9 = concatenate([Conv2DTranspose(minimum_kernel, (2, 2), strides=(2, 2), padding='same')(conv8), conv1], axis=3) 181 | conv9 = Dropout(do)(activation()(Conv2D(minimum_kernel, (3, 3), padding='same')(up9))) 182 | conv9 = Dropout(do)(activation()(Conv2D(minimum_kernel, (3, 3), padding='same')(conv9))) 183 | 184 | 185 | pt_cls_art_conv1a = Conv2D(minimum_kernel, (3, 3), padding='same') 186 | pt_cls_art_activation1a = activation() 187 | pt_cls_art_dropout1a = Dropout(do) 188 | pt_cls_art_conv1b = Conv2D(minimum_kernel, (3, 3), padding='same') 189 | pt_cls_art_activation1b = activation() 190 | pt_cls_art_dropout1b = Dropout(do) 191 | pt_cls_art_pooling1 = MaxPooling2D(pool_size=(2, 2)) 192 | 193 | pt_cls_art_conv2a = Conv2D(minimum_kernel * 2, (3, 3), padding='same') 194 | pt_cls_art_activation2a = activation() 195 | pt_cls_art_dropout2a = Dropout(do) 196 | pt_cls_art_conv2b = Conv2D(minimum_kernel * 2, (3, 3), padding='same') 197 | pt_cls_art_activation2b = activation() 198 | pt_cls_art_dropout2b = Dropout(do) 199 | pt_cls_art_pooling2 = MaxPooling2D(pool_size=(2, 2)) 200 | 201 | pt_cls_art_conv3a = Conv2D(minimum_kernel * 4, (3, 3), padding='same') 202 | pt_cls_art_activation3a = activation() 203 | pt_cls_art_dropout3a = Dropout(do) 204 | pt_cls_art_conv3b = Conv2D(minimum_kernel * 4, (3, 3), padding='same') 205 | pt_cls_art_activation3b = activation() 206 | pt_cls_art_dropout3b = Dropout(do) 207 | 208 | pt_cls_art_tranconv8 = Conv2DTranspose(minimum_kernel * 2, (2, 2), strides=(2, 2), padding='same') 209 | pt_cls_art_conv8a = Conv2D(minimum_kernel * 2, (3, 3), padding='same') 210 | pt_cls_art_activation8a = activation() 211 | 
pt_cls_art_dropout8a = Dropout(do) 212 | pt_cls_art_conv8b = Conv2D(minimum_kernel * 2, (3, 3), padding='same') 213 | pt_cls_art_activation8b = activation() 214 | pt_cls_art_dropout8b = Dropout(do) 215 | 216 | pt_cls_art_tranconv9 = Conv2DTranspose(minimum_kernel, (2, 2), strides=(2, 2), padding='same') 217 | pt_cls_art_conv9a = Conv2D(minimum_kernel, (3, 3), padding='same') 218 | pt_cls_art_activation9a = activation() 219 | pt_cls_art_dropout9a = Dropout(do) 220 | pt_cls_art_conv9b = Conv2D(minimum_kernel, (3, 3), padding='same') 221 | pt_cls_art_activation9b = activation() 222 | pt_cls_art_dropout9b = Dropout(do) 223 | 224 | conv9s_cls_art = [conv9] 225 | a_layers = [a] 226 | for iteration_id in range(iteration): 227 | out = Conv2D(1, (1, 1), activation='sigmoid', name=f'out1_cls_art{iteration_id + 1}')(conv9s_cls_art[-1]) 228 | outs.append(out) 229 | 230 | conv1 = pt_cls_art_dropout1a(pt_cls_art_activation1a(pt_cls_art_conv1a(conv9s_cls_art[-1]))) 231 | conv1 = pt_cls_art_dropout1b(pt_cls_art_activation1b(pt_cls_art_conv1b(conv1))) 232 | a_layers.append(conv1) 233 | conv1 = concatenate(a_layers, axis=3) 234 | conv1 = Conv2D(minimum_kernel, (1, 1), padding='same')(conv1) 235 | pool1 = pt_cls_art_pooling1(conv1) 236 | 237 | conv2 = pt_cls_art_dropout2a(pt_cls_art_activation2a(pt_cls_art_conv2a(pool1))) 238 | conv2 = pt_cls_art_dropout2b(pt_cls_art_activation2b(pt_cls_art_conv2b(conv2))) 239 | pool2 = pt_cls_art_pooling2(conv2) 240 | 241 | conv3 = pt_cls_art_dropout3a(pt_cls_art_activation3a(pt_cls_art_conv3a(pool2))) 242 | conv3 = pt_cls_art_dropout3b(pt_cls_art_activation3b(pt_cls_art_conv3b(conv3))) 243 | 244 | up8 = concatenate([pt_cls_art_tranconv8(conv3), conv2], axis=3) 245 | conv8 = pt_cls_art_dropout8a(pt_cls_art_activation8a(pt_cls_art_conv8a(up8))) 246 | conv8 = pt_cls_art_dropout8b(pt_cls_art_activation8b(pt_cls_art_conv8b(conv8))) 247 | 248 | up9 = concatenate([pt_cls_art_tranconv9(conv8), conv1], axis=3) 249 | conv9 = 
pt_cls_art_dropout9a(pt_cls_art_activation9a(pt_cls_art_conv9a(up9))) 250 | conv9 = pt_cls_art_dropout9b(pt_cls_art_activation9b(pt_cls_art_conv9b(conv9))) 251 | 252 | conv9s_cls_art.append(conv9) 253 | 254 | 255 | cls_art_final_out = Conv2D(1, (1, 1), activation='sigmoid', name='cls_art_final_out')(conv9) 256 | 257 | outs.append(cls_art_final_out) 258 | 259 | 260 | # to cls (vein) 261 | conv1 = Dropout(do)(activation()(Conv2D(minimum_kernel, (3, 3), padding='same')(cls_in))) 262 | conv1 = Dropout(do)(activation()(Conv2D(minimum_kernel, (3, 3), padding='same')(conv1))) 263 | pool1 = MaxPooling2D(pool_size=(2, 2))(conv1) 264 | 265 | conv2 = Dropout(do)(activation()(Conv2D(minimum_kernel * 2, (3, 3), padding='same')(pool1))) 266 | conv2 = Dropout(do)(activation()(Conv2D(minimum_kernel * 2, (3, 3), padding='same')(conv2))) 267 | pool2 = MaxPooling2D(pool_size=(2, 2))(conv2) 268 | 269 | conv3 = Dropout(do)(activation()(Conv2D(minimum_kernel * 4, (3, 3), padding='same')(pool2))) 270 | conv3 = Dropout(do)(activation()(Conv2D(minimum_kernel * 4, (3, 3), padding='same')(conv3))) 271 | pool3 = MaxPooling2D(pool_size=(2, 2))(conv3) 272 | 273 | conv4 = Dropout(do)(activation()(Conv2D(minimum_kernel * 8, (3, 3), padding='same')(pool3))) 274 | conv4 = Dropout(do)(activation()(Conv2D(minimum_kernel * 8, (3, 3), padding='same')(conv4))) 275 | pool4 = MaxPooling2D(pool_size=(2, 2))(conv4) 276 | 277 | conv5 = Dropout(do)(activation()(Conv2D(minimum_kernel * 16, (3, 3), padding='same')(pool4))) 278 | conv5 = Dropout(do)(activation()(Conv2D(minimum_kernel * 16, (3, 3), padding='same')(conv5))) 279 | 280 | up6 = concatenate([Conv2DTranspose(minimum_kernel * 8, (2, 2), strides=(2, 2), padding='same')(conv5), conv4], 281 | axis=3) 282 | conv6 = Dropout(do)(activation()(Conv2D(minimum_kernel * 8, (3, 3), padding='same')(up6))) 283 | conv6 = Dropout(do)(activation()(Conv2D(minimum_kernel * 8, (3, 3), padding='same')(conv6))) 284 | 285 | up7 = concatenate([Conv2DTranspose(minimum_kernel * 
4, (2, 2), strides=(2, 2), padding='same')(conv6), conv3], 286 | axis=3) 287 | conv7 = Dropout(do)(activation()(Conv2D(minimum_kernel * 4, (3, 3), padding='same')(up7))) 288 | conv7 = Dropout(do)(activation()(Conv2D(minimum_kernel * 4, (3, 3), padding='same')(conv7))) 289 | 290 | up8 = concatenate([Conv2DTranspose(minimum_kernel * 2, (2, 2), strides=(2, 2), padding='same')(conv7), conv2], 291 | axis=3) 292 | conv8 = Dropout(do)(activation()(Conv2D(minimum_kernel * 2, (3, 3), padding='same')(up8))) 293 | conv8 = Dropout(do)(activation()(Conv2D(minimum_kernel * 2, (3, 3), padding='same')(conv8))) 294 | 295 | up9 = concatenate([Conv2DTranspose(minimum_kernel, (2, 2), strides=(2, 2), padding='same')(conv8), conv1], axis=3) 296 | conv9 = Dropout(do)(activation()(Conv2D(minimum_kernel, (3, 3), padding='same')(up9))) 297 | conv9 = Dropout(do)(activation()(Conv2D(minimum_kernel, (3, 3), padding='same')(conv9))) 298 | 299 | 300 | pt_cls_vei_conv1a = Conv2D(minimum_kernel, (3, 3), padding='same') 301 | pt_cls_vei_activation1a = activation() 302 | pt_cls_vei_dropout1a = Dropout(do) 303 | pt_cls_vei_conv1b = Conv2D(minimum_kernel, (3, 3), padding='same') 304 | pt_cls_vei_activation1b = activation() 305 | pt_cls_vei_dropout1b = Dropout(do) 306 | pt_cls_vei_pooling1 = MaxPooling2D(pool_size=(2, 2)) 307 | 308 | pt_cls_vei_conv2a = Conv2D(minimum_kernel * 2, (3, 3), padding='same') 309 | pt_cls_vei_activation2a = activation() 310 | pt_cls_vei_dropout2a = Dropout(do) 311 | pt_cls_vei_conv2b = Conv2D(minimum_kernel * 2, (3, 3), padding='same') 312 | pt_cls_vei_activation2b = activation() 313 | pt_cls_vei_dropout2b = Dropout(do) 314 | pt_cls_vei_pooling2 = MaxPooling2D(pool_size=(2, 2)) 315 | 316 | pt_cls_vei_conv3a = Conv2D(minimum_kernel * 4, (3, 3), padding='same') 317 | pt_cls_vei_activation3a = activation() 318 | pt_cls_vei_dropout3a = Dropout(do) 319 | pt_cls_vei_conv3b = Conv2D(minimum_kernel * 4, (3, 3), padding='same') 320 | pt_cls_vei_activation3b = activation() 321 | 
pt_cls_vei_dropout3b = Dropout(do) 322 | 323 | pt_cls_vei_tranconv8 = Conv2DTranspose(minimum_kernel * 2, (2, 2), strides=(2, 2), padding='same') 324 | pt_cls_vei_conv8a = Conv2D(minimum_kernel * 2, (3, 3), padding='same') 325 | pt_cls_vei_activation8a = activation() 326 | pt_cls_vei_dropout8a = Dropout(do) 327 | pt_cls_vei_conv8b = Conv2D(minimum_kernel * 2, (3, 3), padding='same') 328 | pt_cls_vei_activation8b = activation() 329 | pt_cls_vei_dropout8b = Dropout(do) 330 | 331 | pt_cls_vei_tranconv9 = Conv2DTranspose(minimum_kernel, (2, 2), strides=(2, 2), padding='same') 332 | pt_cls_vei_conv9a = Conv2D(minimum_kernel, (3, 3), padding='same') 333 | pt_cls_vei_activation9a = activation() 334 | pt_cls_vei_dropout9a = Dropout(do) 335 | pt_cls_vei_conv9b = Conv2D(minimum_kernel, (3, 3), padding='same') 336 | pt_cls_vei_activation9b = activation() 337 | pt_cls_vei_dropout9b = Dropout(do) 338 | 339 | conv9s_cls_vei = [conv9] 340 | a_layers = [a] 341 | for iteration_id in range(iteration): 342 | out = Conv2D(1, (1, 1), activation='sigmoid', name=f'out1_cls_vei{iteration_id + 1}')(conv9s_cls_vei[-1]) 343 | outs.append(out) 344 | 345 | conv1 = pt_cls_vei_dropout1a(pt_cls_vei_activation1a(pt_cls_vei_conv1a(conv9s_cls_vei[-1]))) 346 | conv1 = pt_cls_vei_dropout1b(pt_cls_vei_activation1b(pt_cls_vei_conv1b(conv1))) 347 | a_layers.append(conv1) 348 | conv1 = concatenate(a_layers, axis=3) 349 | conv1 = Conv2D(minimum_kernel, (1, 1), padding='same')(conv1) 350 | pool1 = pt_cls_vei_pooling1(conv1) 351 | 352 | conv2 = pt_cls_vei_dropout2a(pt_cls_vei_activation2a(pt_cls_vei_conv2a(pool1))) 353 | conv2 = pt_cls_vei_dropout2b(pt_cls_vei_activation2b(pt_cls_vei_conv2b(conv2))) 354 | pool2 = pt_cls_vei_pooling2(conv2) 355 | 356 | conv3 = pt_cls_vei_dropout3a(pt_cls_vei_activation3a(pt_cls_vei_conv3a(pool2))) 357 | conv3 = pt_cls_vei_dropout3b(pt_cls_vei_activation3b(pt_cls_vei_conv3b(conv3))) 358 | 359 | up8 = concatenate([pt_cls_vei_tranconv8(conv3), conv2], axis=3) 360 | conv8 = 
def random_crop(img, mask, mask_onehot, crop_size):
    """Cut one randomly placed square window out of an image and its labels.

    The same ``crop_size`` x ``crop_size`` window is taken from axes 0/1 of
    ``img`` and from axes 1/2 of ``mask`` and ``mask_onehot`` (both are
    converted with ``np.array`` first), so all three crops stay aligned.

    Returns:
        A tuple ``(img_crop, mask_crop, mask_onehot_crop)``.
    """
    height = img.shape[0]
    width = img.shape[1]

    # Row offset is drawn before the column offset; randint includes both
    # end points, so a crop as large as the image always starts at 0.
    top = randint(0, height - crop_size)
    left = randint(0, width - crop_size)

    rows = slice(top, top + crop_size)
    cols = slice(left, left + crop_size)

    img_crop = img[rows, cols, :]
    mask_crop = np.array(mask)[:, rows, cols]
    onehot_crop = np.array(mask_onehot)[:, rows, cols]
    return img_crop, mask_crop, onehot_crop
class Generator():
    """Endless source of ``(images, targets)`` training batches.

    Images and ground-truth stacks are fetched once through
    ``prepare_dataset.getTrainingData``. :meth:`gen` then walks over them in
    windows of ``step`` samples, optionally augments and crops them, and
    yields the image batch together with a dict of targets keyed by the
    model's output-layer names.
    """

    def __init__(self, batch_size, repeat, dataset):
        """Load the training data handles and derive the per-batch step.

        Args:
            batch_size: Nominal number of samples per yielded batch.
            repeat: Number of samples derived from each source image.
                Values of ``batch_size`` or more are clamped to
                ``batch_size`` with a step of one source image per batch.
            dataset: Dataset name forwarded to
                ``prepare_dataset.getTrainingData``.
        """
        self.lock = threading.Lock()
        self.dataset = dataset
        with self.lock:
            self.list_images_all = prepare_dataset.getTrainingData(0, self.dataset)
            self.list_gt_all = prepare_dataset.getTrainingData(1, self.dataset)
            self.list_gt_all_onehot = prepare_dataset.getTrainingData(1, self.dataset, need_one_hot=True)
        self.n = len(self.list_images_all)
        self.index = 0
        self.repeat = repeat
        self.batch_size = batch_size
        # Number of source images consumed per batch.
        self.step = self.batch_size // self.repeat

        if self.repeat >= self.batch_size:
            self.repeat = self.batch_size
            self.step = 1

    def gen(self, au=True, crop_size=48, iteration=None):
        """Yield ``(images, targets)`` batches forever.

        Args:
            au: Apply ``data_augmentation.random_augmentation`` when true.
            crop_size: Side length of the random crops. When it equals
                ``prepare_dataset.DESIRED_DATA_SHAPE[0]`` no cropping is done
                and, with augmentation on, ``repeat`` augmented variants of
                every source image are produced instead.
            iteration: Number of intermediate outputs per branch; it controls
                how many ``out1*`` keys are emitted. Must be an int.

        Yields:
            ``(np.array(images), targets)`` where ``targets`` maps output
            names to slices 0, 2 and 3 (axis 1) of the ground-truth stack.

        Raises:
            TypeError: If ``iteration`` is left as ``None``.
        """
        if iteration is None:
            # range(None) used to fail with an opaque message only after a
            # whole batch had been augmented; fail early and say why.
            raise TypeError('iteration must be an int, not None')

        while True:
            start = self.index % self.n
            # The final window of an epoch is truncated at the dataset end.
            stop = min(self.index + self.step, self.n)
            self.index = (self.index + self.step) % self.n

            images_base = self.list_images_all[start:stop]
            gt_base = self.list_gt_all[start:stop]
            gt_onehot_base = self.list_gt_all_onehot[start:stop]

            full_size = crop_size == prepare_dataset.DESIRED_DATA_SHAPE[0]

            images_aug = []
            gt_aug = []
            gt_onehot_aug = []
            for image, gt, gt2 in zip(images_base, gt_base, gt_onehot_base):
                if not au:
                    images_aug.append(image)
                    gt_aug.append(gt)
                    gt_onehot_aug.append(gt2)
                    continue

                # Full-size batches get `repeat` augmented variants per image;
                # cropped batches get one variant that is cropped `repeat`
                # times further down.
                variants = self.repeat if full_size else 1
                for _ in range(variants):
                    # Augment fresh copies every time. The affine transforms
                    # write into the mask stacks in place, so re-feeding the
                    # previous result made every stored mask alias one array
                    # that no longer matched its image.
                    aug_image, aug_gt, aug_gt2 = data_augmentation.random_augmentation(
                        np.array(image), np.array(gt), np.array(gt2))
                    images_aug.append(aug_image)
                    gt_aug.append(aug_gt)
                    gt_onehot_aug.append(aug_gt2)

            if full_size:
                list_images = images_aug
                list_gt = gt_aug
            else:
                list_images = []
                list_gt = []
                for image, gt, gt2 in zip(images_aug, gt_aug, gt_onehot_aug):
                    for _ in range(self.repeat):
                        # The one-hot crop is not part of the yielded targets.
                        image_, gt_, _ = random_crop(image, gt, gt2, crop_size)
                        list_images.append(image_)
                        list_gt.append(gt_)

            # Convert once instead of once per output key.
            gt_batch = np.array(list_gt)
            seg_target = gt_batch[:, 0]
            art_target = gt_batch[:, 2]
            vei_target = gt_batch[:, 3]

            outs = {}
            for iteration_id in range(iteration):
                outs[f'out1{iteration_id + 1}'] = seg_target
            outs['seg_final_out'] = seg_target
            outs['cls_art_final_out'] = art_target
            outs['cls_vei_final_out'] = vei_target
            for iteration_id in range(iteration):
                outs[f'out1_cls_art{iteration_id + 1}'] = art_target
            for iteration_id in range(iteration):
                outs[f'out1_cls_vei{iteration_id + 1}'] = vei_target
            yield np.array(list_images), outs
def read_input(path):
    """Load an image file as a float array with an explicit channel axis.

    Pixel values are divided by 255. Two-dimensional (single-channel) images
    get a trailing channel axis of length one; images that already have a
    channel axis are returned as loaded.

    Args:
        path: Path of the image file to open with PIL.

    Returns:
        A NumPy array with at least three dimensions.
    """
    with Image.open(path) as im:
        x = np.array(im) / 255.
    # Decide by rank, not by "last dimension == 3": that test added a
    # spurious axis to 4-channel images and missed grayscale images that are
    # exactly three pixels wide.
    if x.ndim == 2:
        return x[..., np.newaxis]
    return x
def createHDF5(data, HDF5_data_path, one_hot=False):
    """Write ``data`` to an HDF5 file inside ``HDF5_data_path``.

    The array is stored under the dataset name ``'data'`` in ``data.hdf5``,
    or in ``data_onehot.hdf5`` when ``one_hot`` is true. An existing file is
    overwritten.

    Args:
        data: Array-like content of the dataset.
        HDF5_data_path: Target directory; created if it does not exist.
        one_hot: Selects the one-hot file name.
    """
    try:
        os.makedirs(HDF5_data_path, exist_ok=True)
    except OSError:
        # Best effort, as before: if the directory really is unusable, the
        # h5py.File call below reports the actual problem.
        pass

    file_name = 'data_onehot.hdf5' if one_hot else 'data.hdf5'
    # The context manager closes the file, so the data is flushed to disk
    # before returning; the handle used to be leaked.
    with h5py.File(os.path.join(HDF5_data_path, file_name), 'w') as f:
        f.create_dataset('data', data=data)
def getTestData(XorYorMask, dataset, need_one_hot=False):
    """Open the preprocessed HDF5 test data for ``dataset``.

    Args:
        XorYorMask: ``0`` selects the test images, ``1`` the test labels,
            anything else the test masks.
        dataset: Dataset name; becomes a directory component of the HDF5 path.
        need_one_hot: Read ``data_onehot.hdf5`` instead of ``data.hdf5``.

    Returns:
        The ``'data'`` dataset of the opened HDF5 file, or ``None`` when masks
        are requested but no mask path is configured.

    Raises:
        ValueError: If ``raw_data_path`` holds no test entries after the two
            training entries.
    """
    test_paths = raw_data_path[2:4]
    if len(test_paths) < 2:
        # Unpacking the short slice used to fail with an opaque
        # "not enough values to unpack" message.
        raise ValueError(
            'raw_data_path has no test image/label entries after the two training entries')
    raw_test_x_path, raw_test_y_path = test_paths

    if XorYorMask == 0:
        raw_splited = raw_test_x_path.split('/')
    elif XorYorMask == 1:
        raw_splited = raw_test_y_path.split('/')
    else:
        # The mask path is an optional module-level setting; looking it up
        # this way avoids a NameError when the module does not define it.
        raw_test_mask_path = globals().get('raw_test_mask_path')
        if not raw_test_mask_path:
            return None
        raw_splited = raw_test_mask_path.split('/')

    file_name = 'data_onehot.hdf5' if need_one_hot else 'data.hdf5'
    data_path = ''.join([HDF5_data_path, dataset, '/', '/'.join(raw_splited[3:-1]), '/', file_name])
    # The file is left open on purpose: the returned dataset object is
    # backed by it.
    f = h5py.File(data_path, 'r')
    data = f['data']

    return data
9 | for name in files: 10 | if os.path.splitext(os.path.join(root, name))[1].lower() in formats: 11 | if os.path.isfile(os.path.splitext(os.path.join(root, name))[0] + ".png"): 12 | # print("A png file already exists for %s" % name) 13 | None 14 | # If a png is *NOT* present, create one from other formats. 15 | else: 16 | outfile = os.path.splitext(os.path.join(root, name))[0] + ".png" 17 | try: 18 | print(os.path.join(root, name)) 19 | im = Image.open(os.path.join(root, name)) 20 | print("Generating png for %s" % name) 21 | im.thumbnail(im.size) 22 | im.save(outfile, "PNG", quality=100) 23 | os.remove(os.path.join(root, name)) 24 | except Exception as e: 25 | print(e) 26 | 27 | # rename av files for HRF dataset 28 | for root, dirs, files in os.walk(yourpath, topdown=False): 29 | for name in files: 30 | if name.endswith('_AVmanual.png'): 31 | print(os.path.join(root, name)) 32 | os.rename(os.path.join(root, name), 33 | os.path.join(root, name).replace('_AVmanual','')) 34 | --------------------------------------------------------------------------------