├── LICENSE ├── README.md ├── code ├── GenGraph │ ├── _init_paths.py │ ├── bwmorph.py │ └── make_graph_db.py ├── _init_paths.py ├── check_pixel_mean_value.py ├── config.py ├── extract_subimages_HRF.py ├── model.py ├── test_CNN.py ├── test_CNN_HRF.py ├── test_VGN.py ├── test_VGN_HRF.py ├── train_CNN.py ├── train_VGN.py └── util.py ├── models └── empty.gitkeep ├── pretrained_model └── empty.gitkeep └── results ├── Image_12R.gif ├── empty.gitkeep └── im0239.gif /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Seung Yeon Shin 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | ******************************************** 24 | THIRD-PARTY SOFTWARE NOTICES AND INFORMATION 25 | ******************************************** 26 | 27 | This project is based on or includes material from the project(s) listed below 28 | ("Third Party Code"). We are not the original author of the Third Party Code. 29 | The original copyright notice and license under which we received such 30 | Third Party Code are set out below. This Third Party Code is licensed to you 31 | under their original license terms set forth below. 32 | 33 | 1. GAT (https://github.com/PetarV-/GAT) 34 | 35 | MIT License 36 | 37 | Copyright (c) 2018 Petar Veličković 38 | 39 | Permission is hereby granted, free of charge, to any person obtaining a copy 40 | of this software and associated documentation files (the "Software"), to deal 41 | in the Software without restriction, including without limitation the rights 42 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 43 | copies of the Software, and to permit persons to whom the Software is 44 | furnished to do so, subject to the following conditions: 45 | 46 | The above copyright notice and this permission notice shall be included in all 47 | copies or substantial portions of the Software. 48 | 49 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 50 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 51 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 52 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 53 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 54 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 55 | SOFTWARE. 56 | 57 | *************************************************** 58 | END OF THIRD-PARTY SOFTWARE NOTICES AND INFORMATION 59 | *************************************************** 60 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Vessel Graph Network (VGN) 2 | This is the code for ["Deep Vessel Segmentation by Learning Graphical Connectivity"](https://www.sciencedirect.com/science/article/pii/S1361841519300982). 3 | 4 | ## Dependency 5 | * Python 2.7.12 6 | * Tensorflow 1.12 7 | * networkx 2.0 8 | * scipy 1.1.0 9 | * mahotas 1.4.3 10 | * matplotlib 2.2.4 11 | * easydict 1.7 12 | * scikit-image 0.14.2 13 | * scikit-fmm 0.0.9 14 | * scikit-learn 0.19.0 15 | 16 | ## Datasets 17 | * The VGN is evaluated on four retinal image datasets, namely the [DRIVE](https://www.isi.uu.nl/Research/Databases/DRIVE/), [STARE](http://cecas.clemson.edu/~ahoover/stare/), [CHASE_DB1](https://blogs.kingston.ac.uk/retinal/chasedb1/), and [HRF](https://www5.cs.fau.de/research/data/fundus-images/) datasets, which all are publicly available. 18 | * The coronary artery X-ray angiography (CA-XRA) dataset we additionally used for evaluation can not be shared regrettably. 19 | 20 | ## Precomputed Results 21 | We provide precomputed results of the VGN on the four retinal image datasets. [[OneDrive]](https://1drv.ms/u/s!AmnLATyiwjphhZ0BquyksorE0YV7nA?e=OmHhGW) 22 | 23 | ## Testing a Model 24 | 1. Download available trained models. [[OneDrive]](https://1drv.ms/u/s!AmnLATyiwjphhZ0CYhSYOqHmnQw4UQ?e=eRgvcq) 25 | 2. Run a test script among `test_CNN.py`, `test_CNN_HRF.py`, `test_VGN.py`, or `test_VGN_HRF.py`, with appropriate input arguments including the path for the downloaded model. 26 | 27 | ## Training a Model 28 | We use a sequential training scheme composed of an initial pretraining of the CNN followed by joint training, including fine-tuning of the CNN module, of the whole VGN. Before the joint training, training graphs must be constructed from vessel probability maps inferred from the pretrained CNN. 29 | 30 | ### CNN Pretraining 31 | (This step can be skipped by using a pretrained model we share.) 32 | 1. Download an ImageNet pretrained model. [[OneDrive]](https://1drv.ms/u/s!AmnLATyiwjphhZ0AqBHI2Y0nALUdoQ?e=NG4kVS) 33 | 2. Run `train_CNN.py` with appropriate input arguments including the path for the downloaded pretrained model. 34 | 35 | ### Training Graph Construction 36 | 1. Run `GenGraph/make_graph_db.py`. 37 | 38 | ### VGN Training 39 | 1. Place the generated graphs ('.graph_res') and vessel probability images ('_prob.png') inferred from the pretrained CNN in a new directory 'args.save_root/graph' 40 | 2. Run `train_VGN.py` with appropriate input arguments including the path for the pretrained CNN model. 41 | 42 | ## Demo Videos 43 | Two example results, each of which is from the STARE and CHASE_DB1 datasets. The images in each row from left to right are, the original input image, GT, result. The last column is slowly changed from the baseline CNN result to the VGN result to better show the difference. 44 | ![](results/im0239.gif) 45 | ![](results/Image_12R.gif) 46 | 47 | ## Citation 48 | ``` 49 | @article{shin_media19, 50 | title = "Deep vessel segmentation by learning graphical connectivity", 51 | journal = "Medical Image Analysis", 52 | volume = "58", 53 | pages = "101556", 54 | year = "2019", 55 | issn = "1361-8415", 56 | doi = "https://doi.org/10.1016/j.media.2019.101556", 57 | url = "http://www.sciencedirect.com/science/article/pii/S1361841519300982", 58 | author = "Seung Yeon Shin and Soochahn Lee and Il Dong Yun and Kyoung Mu Lee", 59 | } 60 | ``` 61 | -------------------------------------------------------------------------------- /code/GenGraph/_init_paths.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | def add_path(path): 5 | if path not in sys.path: 6 | sys.path.insert(0, path) 7 | 8 | this_dir = os.path.dirname(__file__) 9 | lib_path = os.path.join(this_dir, '..') 10 | add_path(lib_path) 11 | -------------------------------------------------------------------------------- /code/GenGraph/bwmorph.py: -------------------------------------------------------------------------------- 1 | # Duplication of 'bwmorph' in matlab 2 | # referred to 3 | # https://gist.github.com/joefutrelle/562f25bbcf20691217b8 4 | 5 | 6 | import numpy as np 7 | from scipy import ndimage as ndi 8 | 9 | 10 | OPS = ['dilate', 'fill', 'thin', 'branchpoints', 'endpoints'] 11 | 12 | 13 | # lookup tables 14 | LUT_THIN_1 = ~np.array([0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 15 | 1,1,1,1,1,1,1,0,1,1,0,0,1,1,0,0, 16 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 17 | 1,1,0,0,1,1,0,0,1,1,0,0,1,1,0,0, 18 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 19 | 1,1,1,1,1,1,1,1,1,0,0,0,1,1,0,0, 20 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 21 | 1,1,1,1,1,1,1,1,1,1,0,0,1,1,0,0, 22 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 23 | 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 24 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 25 | 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 26 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 27 | 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 28 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 29 | 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 30 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 31 | 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 32 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 33 | 1,1,0,0,0,1,0,0,1,1,0,0,1,1,0,0, 34 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 35 | 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 36 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 37 | 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 38 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 39 | 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 40 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 41 | 0,1,0,0,0,1,0,0,1,1,1,1,1,1,1,1, 42 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 43 | 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 44 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 45 | 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1], dtype=np.bool) 46 | 47 | LUT_THIN_2 = ~np.array([0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 48 | 1,1,1,1,1,1,1,1,1,1,1,0,1,1,1,1, 49 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 50 | 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 51 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 52 | 1,1,1,1,1,1,1,1,1,0,1,0,1,1,1,1, 53 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 54 | 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 55 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 56 | 1,1,1,1,1,1,1,1,0,0,1,0,1,1,1,1, 57 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 58 | 0,1,1,1,0,1,1,1,0,0,1,1,0,1,1,1, 59 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 60 | 1,1,1,1,1,1,1,1,0,0,1,0,1,1,1,1, 61 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 62 | 0,1,1,1,0,1,1,1,0,0,1,1,0,1,1,1, 63 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 64 | 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 65 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 66 | 1,1,1,1,0,1,1,1,1,1,1,1,1,1,1,1, 67 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 68 | 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 69 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 70 | 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 71 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 72 | 1,1,1,1,1,1,1,1,0,0,1,0,1,1,1,1, 73 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 74 | 0,1,1,1,0,1,1,1,0,0,1,1,0,1,1,1, 75 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 76 | 0,1,1,1,1,1,1,1,0,0,1,0,1,1,1,1, 77 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 78 | 0,1,1,1,0,1,1,1,0,0,1,1,0,1,1,1], dtype=np.bool) 79 | 80 | LUT_ENDPOINTS = ~np.array([0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 81 | 1,1,1,1,1,0,1,1,1,1,0,1,0,0,0,1, 82 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 83 | 1,0,0,0,1,0,1,1,0,0,0,0,0,0,0,1, 84 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 85 | 1,0,0,0,0,0,0,0,1,1,0,1,0,0,0,1, 86 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 87 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1, 88 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 89 | 1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 90 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 91 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 92 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 93 | 1,0,0,0,0,0,0,0,1,1,0,1,0,0,0,1, 94 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 95 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1, 96 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 97 | 1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 98 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 99 | 1,0,0,0,1,0,1,1,0,0,0,0,0,0,0,1, 100 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 101 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 102 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 103 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1, 104 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 105 | 1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 106 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 107 | 1,0,0,0,1,0,1,1,0,0,0,0,0,0,0,1, 108 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 109 | 1,0,0,0,0,0,0,0,1,1,0,1,0,0,0,1, 110 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 111 | 1,0,0,0,1,0,1,1,1,1,0,1,1,1,1,0], dtype=np.bool) 112 | 113 | LUT_BRANCHPOINTS = ~np.array([0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 114 | 0,0,0,0,0,0,0,1,0,0,0,1,0,1,1,1, 115 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 116 | 0,0,0,1,0,1,1,1,0,1,1,1,1,1,1,1, 117 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 118 | 0,0,0,1,0,1,1,1,0,1,1,1,1,1,1,1, 119 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 120 | 0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 121 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 122 | 0,0,0,1,0,1,1,1,0,1,1,1,1,1,1,1, 123 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 124 | 0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 125 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 126 | 0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 127 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 128 | 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 129 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 130 | 0,0,0,1,0,1,1,1,0,1,1,1,1,1,1,1, 131 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 132 | 0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 133 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 134 | 0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 135 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 136 | 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 137 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 138 | 0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 139 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 140 | 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 141 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 142 | 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 143 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 144 | 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1], dtype=np.bool) 145 | 146 | LUT_BACKCOUNT4 = np.array([0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 147 | 1,1,1,1,1,2,1,1,1,1,2,1,2,2,2,1, 148 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 149 | 1,2,2,2,1,2,1,1,2,2,3,2,2,2,2,1, 150 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 151 | 1,2,2,2,2,3,2,2,1,1,2,1,2,2,2,1, 152 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 153 | 2,3,3,3,2,3,2,2,2,2,3,2,2,2,2,1, 154 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 155 | 1,2,2,2,2,3,2,2,2,2,3,2,3,3,3,2, 156 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 157 | 2,3,3,3,2,3,2,2,3,3,4,3,3,3,3,2, 158 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 159 | 1,2,2,2,2,3,2,2,1,1,2,1,2,2,2,1, 160 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 161 | 2,3,3,3,2,3,2,2,2,2,3,2,2,2,2,1, 162 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 163 | 1,2,2,2,2,3,2,2,2,2,3,2,3,3,3,2, 164 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 165 | 1,2,2,2,1,2,1,1,2,2,3,2,2,2,2,1, 166 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 167 | 2,3,3,3,3,4,3,3,2,2,3,2,3,3,3,2, 168 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 169 | 2,3,3,3,2,3,2,2,2,2,3,2,2,2,2,1, 170 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 171 | 1,2,2,2,2,3,2,2,2,2,3,2,3,3,3,2, 172 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 173 | 1,2,2,2,1,2,1,1,2,2,3,2,2,2,2,1, 174 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 175 | 1,2,2,2,2,3,2,2,1,1,2,1,2,2,2,1, 176 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 177 | 1,2,2,2,1,2,1,1,1,1,2,1,1,1,1,0], dtype=np.uint8) 178 | 179 | LUT_DILATE = np.array([0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 180 | 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 181 | 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 182 | 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 183 | 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 184 | 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 185 | 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 186 | 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 187 | 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 188 | 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 189 | 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 190 | 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 191 | 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 192 | 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 193 | 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 194 | 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 195 | 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 196 | 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 197 | 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 198 | 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 199 | 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 200 | 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 201 | 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 202 | 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 203 | 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 204 | 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 205 | 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 206 | 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 207 | 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 208 | 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 209 | 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 210 | 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1], dtype=np.bool) 211 | 212 | LUT_FILL = np.array([0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 213 | 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 214 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 215 | 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 216 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 217 | 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 218 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 219 | 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 220 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 221 | 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 222 | 0,0,0,0,0,0,0,0,0,0,1,1,0,0,1,1, 223 | 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 224 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 225 | 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 226 | 0,0,0,0,0,0,0,0,0,0,1,1,0,0,1,1, 227 | 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 228 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 229 | 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 230 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 231 | 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 232 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 233 | 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 234 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 235 | 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 236 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 237 | 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 238 | 0,0,0,0,0,0,0,0,0,0,1,1,0,0,1,1, 239 | 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 240 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 241 | 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 242 | 0,0,0,0,0,0,0,0,0,0,1,1,0,0,1,1, 243 | 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1], dtype=np.bool) 244 | 245 | 246 | def bwmorph(image, op, n_iter=None): 247 | 248 | # check parameters 249 | if op not in OPS: 250 | raise ValueError('Undefined OP is used') 251 | 252 | if n_iter is None: 253 | n = -1 254 | elif n_iter <= 0: 255 | raise ValueError('n_iter must be > 0') 256 | else: 257 | n = n_iter 258 | 259 | # check that we have a 2d binary image, and convert it 260 | # to uint8 261 | bw = np.array(image).astype(np.uint16) 262 | 263 | if bw.ndim != 2: 264 | raise ValueError('2D array required') 265 | if not np.all(np.in1d(image.flat,(0,1))): 266 | raise ValueError('Image contains values other than 0 and 1') 267 | 268 | # neighborhood mask 269 | mask = np.array([[ 1, 8, 64], 270 | [ 2, 16,128], 271 | [ 4, 32,256]],dtype=np.uint16) 272 | 273 | # iterate either 1) indefinitely or 2) up to iteration limit 274 | while n != 0: 275 | before = np.sum(bw) # count points before thinning 276 | 277 | if op == 'dilate': 278 | 279 | bw[np.take(LUT_DILATE, ndi.correlate(bw, mask, mode='constant'))] = 1 280 | 281 | elif op == 'fill': 282 | 283 | bw[np.take(LUT_FILL, ndi.correlate(bw, mask, mode='constant'))] = 1 284 | 285 | elif op == 'thin': 286 | 287 | # for each subiteration 288 | for lut in [LUT_THIN_1, LUT_THIN_2]: 289 | # correlate image with neighborhood mask 290 | N = ndi.correlate(bw, mask, mode='constant') 291 | # take deletion decision from this subiteration's LUT 292 | D = np.take(lut, N) 293 | # perform deletion 294 | bw[D] = 0 295 | 296 | elif op == 'branchpoints': 297 | 298 | # Initial branch point candidates 299 | C = np.copy(bw) 300 | C[np.take(LUT_BRANCHPOINTS, ndi.correlate(bw, mask, mode='constant'))] = 0 301 | C = C.astype(np.bool) 302 | 303 | # Background 4-Connected Object Count (Vp) 304 | B = np.take(LUT_BACKCOUNT4, ndi.correlate(bw, mask, mode='constant')) 305 | 306 | # End Points (Vp = 1) 307 | E = (B == 1) 308 | 309 | # Final branch point candidates 310 | F = (~E)*C 311 | 312 | # Generate mask that defines pixels for which Vp = 2 and no 313 | # foreground neighbor q for which Vq > 2 314 | 315 | # Vp = 2 Mask 316 | Vp = ((B == 2) & (~E)) 317 | 318 | # Vq > 2 Mask 319 | Vq = ((B > 2) & (~E)) 320 | 321 | # Dilate Vq 322 | D = np.copy(Vq) 323 | D[np.take(LUT_DILATE, ndi.correlate(Vq, mask, mode='constant'))] = 1 324 | 325 | # Intersection between dilated Vq and final candidates w/ Vp = 2 326 | M = (F & Vp) & D 327 | 328 | # Final Branch Points 329 | bw = F & (~M) 330 | 331 | break 332 | 333 | elif op == 'endpoints': 334 | 335 | # correlate image with neighborhood mask 336 | N = ndi.correlate(bw, mask, mode='constant') 337 | # take deletion decision from the LUT 338 | D = np.take(LUT_ENDPOINTS, N) 339 | # perform deletion 340 | bw[D] = 0 341 | 342 | else: 343 | pass 344 | 345 | after = np.sum(bw) # count points after thinning 346 | 347 | if before == after: 348 | # iteration had no effect: finish 349 | break 350 | 351 | # count down to iteration limit (or endlessly negative) 352 | n -= 1 353 | 354 | return bw.astype(np.bool) -------------------------------------------------------------------------------- /code/GenGraph/make_graph_db.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import skimage.io 3 | import os 4 | import networkx as nx 5 | import pickle as pkl 6 | import scipy.sparse as sp 7 | from scipy import ndimage 8 | import mahotas as mh 9 | import multiprocessing 10 | import matplotlib.pyplot as plt 11 | import argparse 12 | import skfmm 13 | from scipy.ndimage.morphology import distance_transform_edt 14 | 15 | import _init_paths 16 | from bwmorph import bwmorph 17 | from config import cfg 18 | import util 19 | 20 | DEBUG = False 21 | 22 | 23 | def parse_args(): 24 | """ 25 | Parse input arguments 26 | """ 27 | parser = argparse.ArgumentParser(description='Make a graph db') 28 | parser.add_argument('--dataset', default='DRIVE', \ 29 | help='Dataset to use: Can be DRIVE or STARE or CHASE_DB1 or HRF', type=str) 30 | """parser.add_argument('--use_multiprocessing', action='store_true', \ 31 | default=False, help='Whether to use the python multiprocessing module')""" 32 | parser.add_argument('--use_multiprocessing', default=True, \ 33 | help='Whether to use the python multiprocessing module', type=bool) 34 | parser.add_argument('--source_type', default='result', \ 35 | help='Source to be used: Can be result or gt', type=str) 36 | parser.add_argument('--win_size', default=4, \ 37 | help='Window size for srns', type=int) # for srns # [4,8,16] 38 | parser.add_argument('--edge_method', default='geo_dist', \ 39 | help='Edge construction method: Can be geo_dist or eu_dist', type=str) 40 | parser.add_argument('--edge_dist_thresh', default=10, \ 41 | help='Distance threshold for edge construction', type=float) # [10,20,40] 42 | 43 | args = parser.parse_args() 44 | return args 45 | 46 | 47 | def generate_graph_using_srns((img_name, im_root_path, cnn_result_root_path, params)): 48 | 49 | win_size_str = '%.2d_%.2d'%(params.win_size,params.edge_dist_thresh) 50 | 51 | if params.source_type=='gt': 52 | win_size_str = win_size_str + '_gt' 53 | 54 | if 'DRIVE' in img_name: 55 | im_ext = '_image.tif' 56 | label_ext = '_label.gif' 57 | len_y = 592 58 | len_x = 592 59 | elif 'STARE' in img_name: 60 | im_ext = '.ppm' 61 | label_ext = '.ah.ppm' 62 | len_y = 704 63 | len_x = 704 64 | elif 'CHASE_DB1' in img_name: 65 | im_ext = '.jpg' 66 | label_ext = '_1stHO.png' 67 | len_y = 1024 68 | len_x = 1024 69 | elif 'HRF' in img_name: 70 | im_ext = '.bmp' 71 | label_ext = '.tif' 72 | len_y = 768 73 | len_x = 768 74 | 75 | cur_filename = img_name[util.find(img_name,'/')[-1]+1:] 76 | print 'processing '+cur_filename 77 | cur_im_path = os.path.join(im_root_path, cur_filename+im_ext) 78 | cur_gt_mask_path = os.path.join(im_root_path, cur_filename+label_ext) 79 | if params.source_type=='gt': 80 | cur_res_prob_path = cur_gt_mask_path 81 | else: 82 | cur_res_prob_path = os.path.join(cnn_result_root_path, cur_filename+'_prob.png') 83 | 84 | cur_vis_res_im_savepath = os.path.join(cnn_result_root_path, cur_filename+'_'+win_size_str+'_vis_graph_res_on_im.png') 85 | cur_vis_res_mask_savepath = os.path.join(cnn_result_root_path, cur_filename+'_'+win_size_str+'_vis_graph_res_on_mask.png') 86 | cur_res_graph_savepath = os.path.join(cnn_result_root_path, cur_filename+'_'+win_size_str+'.graph_res') 87 | # Note that there is no difference on above paths according to 'params.edge_method' 88 | 89 | im = skimage.io.imread(cur_im_path) 90 | 91 | gt_mask = skimage.io.imread(cur_gt_mask_path) 92 | gt_mask = gt_mask.astype(float)/255 93 | gt_mask = gt_mask>=0.5 94 | 95 | vesselness = skimage.io.imread(cur_res_prob_path) 96 | vesselness = vesselness.astype(float)/255 97 | 98 | temp = np.copy(im) 99 | im = np.zeros((len_y,len_x,3), dtype=temp.dtype) 100 | im[:temp.shape[0],:temp.shape[1],:] = temp 101 | temp = np.copy(gt_mask) 102 | gt_mask = np.zeros((len_y,len_x), dtype=temp.dtype) 103 | gt_mask[:temp.shape[0], :temp.shape[1]] = temp 104 | temp = np.copy(vesselness) 105 | vesselness = np.zeros((len_y,len_x), dtype=temp.dtype) 106 | vesselness[:temp.shape[0], :temp.shape[1]] = temp 107 | 108 | # find local maxima 109 | im_y = im.shape[0] 110 | im_x = im.shape[1] 111 | y_quan = range(0,im_y,args.win_size) 112 | y_quan = sorted(list(set(y_quan) | set([im_y]))) 113 | x_quan = range(0,im_x,args.win_size) 114 | x_quan = sorted(list(set(x_quan) | set([im_x]))) 115 | 116 | max_val = [] 117 | max_pos = [] 118 | for y_idx in xrange(len(y_quan)-1): 119 | for x_idx in xrange(len(x_quan)-1): 120 | cur_patch = vesselness[y_quan[y_idx]:y_quan[y_idx+1],x_quan[x_idx]:x_quan[x_idx+1]] 121 | if np.sum(cur_patch)==0: 122 | max_val.append(0) 123 | max_pos.append((y_quan[y_idx]+cur_patch.shape[0]/2,x_quan[x_idx]+cur_patch.shape[1]/2)) 124 | else: 125 | max_val.append(np.amax(cur_patch)) 126 | temp = np.unravel_index(cur_patch.argmax(), cur_patch.shape) 127 | max_pos.append((y_quan[y_idx]+temp[0],x_quan[x_idx]+temp[1])) 128 | 129 | graph = nx.Graph() 130 | 131 | # add nodes 132 | for node_idx, (node_y, node_x) in enumerate(max_pos): 133 | graph.add_node(node_idx, kind='MP', y=node_y, x=node_x, label=node_idx) 134 | print 'node label', node_idx, 'pos', (node_y,node_x), 'added' 135 | 136 | speed = vesselness 137 | if params.source_type=='gt': 138 | speed = bwmorph(speed, 'dilate', n_iter=1) 139 | speed = speed.astype(float) 140 | 141 | edge_dist_thresh_sq = params.edge_dist_thresh**2 142 | 143 | node_list = list(graph.nodes) 144 | for i, n in enumerate(node_list): 145 | 146 | if speed[graph.node[n]['y'],graph.node[n]['x']]==0: 147 | continue 148 | neighbor = speed[max(0,graph.node[n]['y']-1):min(im_y,graph.node[n]['y']+2), \ 149 | max(0,graph.node[n]['x']-1):min(im_x,graph.node[n]['x']+2)] 150 | 151 | if np.mean(neighbor)<0.1: 152 | continue 153 | 154 | if params.edge_method=='geo_dist': 155 | 156 | phi = np.ones_like(speed) 157 | phi[graph.node[n]['y'],graph.node[n]['x']] = -1 158 | tt = skfmm.travel_time(phi, speed, narrow=params.edge_dist_thresh) # travel time 159 | 160 | if DEBUG: 161 | plt.figure() 162 | plt.imshow(tt, interpolation='nearest') 163 | plt.show() 164 | 165 | plt.cla() 166 | plt.clf() 167 | plt.close() 168 | 169 | for n_comp in node_list[i+1:]: 170 | geo_dist = tt[graph.node[n_comp]['y'],graph.node[n_comp]['x']] # travel time 171 | if geo_dist < params.edge_dist_thresh: 172 | graph.add_edge(n, n_comp, weight=params.edge_dist_thresh/(params.edge_dist_thresh+geo_dist)) 173 | print 'An edge BTWN', 'node', n, '&', n_comp, 'is constructed' 174 | 175 | elif params.edge_method=='eu_dist': 176 | 177 | for n_comp in node_list[i+1:]: 178 | eu_dist = (graph.node[n_comp]['y']-graph.node[n]['y'])**2 + (graph.node[n_comp]['x']-graph.node[n]['x'])**2 179 | if eu_dist < edge_dist_thresh_sq: 180 | graph.add_edge(n, n_comp, weight=1.) 181 | print 'An edge BTWN', 'node', n, '&', n_comp, 'is constructed' 182 | 183 | else: 184 | raise NotImplementedError 185 | 186 | # visualize the constructed graph 187 | util.visualize_graph(im, graph, show_graph=False, \ 188 | save_graph=True, num_nodes_each_type=[0,graph.number_of_nodes()], save_path=cur_vis_res_im_savepath) 189 | util.visualize_graph(gt_mask, graph, show_graph=False, \ 190 | save_graph=True, num_nodes_each_type=[0,graph.number_of_nodes()], save_path=cur_vis_res_mask_savepath) 191 | 192 | # save as files 193 | nx.write_gpickle(graph, cur_res_graph_savepath, protocol=pkl.HIGHEST_PROTOCOL) 194 | 195 | graph.clear() 196 | 197 | 198 | if __name__ == '__main__': 199 | 200 | args = parse_args() 201 | 202 | print('Called with args:') 203 | print(args) 204 | 205 | if args.dataset=='DRIVE': 206 | train_set_txt_path = '../../DRIVE/train.txt' 207 | test_set_txt_path = '../../DRIVE/test.txt' 208 | im_root_path = '../../DRIVE/all' 209 | cnn_result_root_path = '../new_exp/DRIVE_cnn/test' 210 | elif args.dataset=='STARE': 211 | train_set_txt_path = '../../STARE/train.txt' 212 | test_set_txt_path = '../../STARE/test.txt' 213 | im_root_path = '../../STARE/all' 214 | cnn_result_root_path = '../STARE_cnn/res_resized' 215 | elif args.dataset=='CHASE_DB1': 216 | train_set_txt_path = '../../CHASE_DB1/train.txt' 217 | test_set_txt_path = '../../CHASE_DB1/test.txt' 218 | im_root_path = '../../CHASE_DB1/all' 219 | cnn_result_root_path = '../CHASE_cnn/test_resized_graph_gen' 220 | elif args.dataset=='HRF': 221 | train_set_txt_path = '../../HRF/train_768.txt' 222 | test_set_txt_path = '../../HRF/test_768.txt' 223 | im_root_path = '../../HRF/all_768' 224 | cnn_result_root_path = '../HRF_cnn/test' 225 | 226 | with open(train_set_txt_path) as f: 227 | train_img_names = [x.strip() for x in f.readlines()] 228 | with open(test_set_txt_path) as f: 229 | test_img_names = [x.strip() for x in f.readlines()] 230 | 231 | len_train = len(train_img_names) 232 | len_test = len(test_img_names) 233 | 234 | func = generate_graph_using_srns 235 | func_arg_train = map(lambda x: (train_img_names[x], im_root_path, cnn_result_root_path, args), xrange(len_train)) 236 | func_arg_test = map(lambda x: (test_img_names[x], im_root_path, cnn_result_root_path, args), xrange(len_test)) 237 | 238 | if args.use_multiprocessing: 239 | pool = multiprocessing.Pool(processes=20) 240 | 241 | pool.map(func, func_arg_train) 242 | pool.map(func, func_arg_test) 243 | 244 | pool.terminate() 245 | else: 246 | for x in func_arg_train: 247 | func(x) 248 | for x in func_arg_test: 249 | func(x) -------------------------------------------------------------------------------- /code/_init_paths.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | def add_path(path): 5 | if path not in sys.path: 6 | sys.path.insert(0, path) 7 | 8 | this_dir = os.path.dirname(__file__) 9 | lib_path = os.path.join(this_dir, 'GenGraph') 10 | add_path(lib_path) -------------------------------------------------------------------------------- /code/check_pixel_mean_value.py: -------------------------------------------------------------------------------- 1 | # To get the pixel mean value from training images 2 | # coded by syshin (170805) 3 | # DRIVE : [ 180.73042427 97.27726367 57.10662087] on mask 4 | # DRIVE : [ 126.83705873 69.0154593 41.42158474] 5 | # STARE : [ 150.29591358 83.55034309 27.50114876] 6 | # ALL : [ 136.00693286 74.69702627 35.98020038] # DRIVE+STARE 7 | # CHASE_DB1 : [ 113.95299712 39.80676852 6.88010585] 8 | # HRF : [164.41978937 51.82606062 27.12979025] 9 | 10 | 11 | import os 12 | import numpy as np 13 | import skimage.io 14 | 15 | from config import cfg 16 | 17 | 18 | DATASET='HRF' 19 | 20 | 21 | if DATASET=='DRIVE': 22 | 23 | dataset_root_path = '/mnt/hdd1/DRIVE' 24 | train_img_names = sorted(os.listdir(os.path.join(dataset_root_path, 'training/images'))) 25 | train_img_names = map(lambda x: x[:2], train_img_names) 26 | rgb_cum_sum = np.zeros((3,)) 27 | cum_num_pixels = 0. 28 | for cur_img_name in train_img_names: 29 | cur_img = skimage.io.imread(os.path.join(dataset_root_path,'training/images',cur_img_name+'_training.tif')) 30 | #cur_mask = skimage.io.imread(os.path.join(dataset_root_path,'training/mask',cur_img_name+'_training_mask.gif')) 31 | #cur_mask = cur_mask>100 32 | #cur_img = cur_img*np.dstack((cur_mask,cur_mask,cur_mask)) 33 | cur_rgb_sum = np.sum(cur_img, axis=(0,1)) 34 | rgb_cum_sum = rgb_cum_sum + cur_rgb_sum 35 | #cum_num_pixels += np.sum(cur_mask) 36 | cum_num_pixels += np.cumprod(cur_img.shape)[1] 37 | 38 | mean_rgb_val = rgb_cum_sum/cum_num_pixels 39 | print mean_rgb_val 40 | 41 | elif DATASET=='STARE': 42 | 43 | dataset_root_path = '/mnt/hdd1/STARE' 44 | train_img_names = sorted(os.listdir(os.path.join(dataset_root_path, 'stare-images'))) 45 | train_img_names = map(lambda x: x[:6], train_img_names[:10]) 46 | rgb_cum_sum = np.zeros((3,)) 47 | cum_num_pixels = 0. 48 | for cur_img_name in train_img_names: 49 | cur_img = skimage.io.imread(os.path.join(dataset_root_path,'stare-images',cur_img_name+'.ppm')) 50 | cur_rgb_sum = np.sum(cur_img, axis=(0,1)) 51 | rgb_cum_sum = rgb_cum_sum + cur_rgb_sum 52 | cum_num_pixels += np.cumprod(np.shape(cur_img))[1] 53 | 54 | mean_rgb_val = rgb_cum_sum/cum_num_pixels 55 | print mean_rgb_val 56 | 57 | elif DATASET=='CHASE_DB1': 58 | 59 | train_set_txt_path = '/home/syshin/Documents/CA/CHASE_DB1/train.txt' 60 | 61 | with open(train_set_txt_path) as f: 62 | train_img_names = [x.strip() for x in f.readlines()] 63 | 64 | rgb_cum_sum = np.zeros((3,)) 65 | cum_num_pixels = 0. 66 | for cur_img_name in train_img_names: 67 | cur_img = skimage.io.imread(cur_img_name+'.jpg') 68 | cur_rgb_sum = np.sum(cur_img, axis=(0,1)) 69 | rgb_cum_sum = rgb_cum_sum + cur_rgb_sum 70 | cum_num_pixels += np.cumprod(np.shape(cur_img))[1] 71 | 72 | mean_rgb_val = rgb_cum_sum/cum_num_pixels 73 | print mean_rgb_val 74 | 75 | elif DATASET=='HRF': 76 | 77 | train_set_txt_path = '/home/syshin/Documents/CA/HRF/train_fr.txt' 78 | 79 | with open(train_set_txt_path) as f: 80 | train_img_names = [x.strip() for x in f.readlines()] 81 | 82 | rgb_cum_sum = np.zeros((3,)) 83 | cum_num_pixels = 0. 84 | for cur_img_name in train_img_names: 85 | cur_img = skimage.io.imread(cur_img_name+'.jpg') 86 | cur_rgb_sum = np.sum(cur_img, axis=(0,1)) 87 | rgb_cum_sum = rgb_cum_sum + cur_rgb_sum 88 | cum_num_pixels += np.cumprod(np.shape(cur_img))[1] 89 | 90 | mean_rgb_val = rgb_cum_sum/cum_num_pixels 91 | print mean_rgb_val 92 | 93 | elif DATASET=='ALL': 94 | 95 | dataset_root_path = '/mnt/hdd1/DRIVE' 96 | train_img_names = sorted(os.listdir(os.path.join(dataset_root_path, 'training/images'))) 97 | train_img_names = map(lambda x: x[:2], train_img_names) 98 | rgb_cum_sum = np.zeros((3,)) 99 | cum_num_pixels = 0. 100 | for cur_img_name in train_img_names: 101 | cur_img = skimage.io.imread(os.path.join(dataset_root_path,'training/images',cur_img_name+'_training.tif')) 102 | #cur_mask = skimage.io.imread(os.path.join(dataset_root_path,'training/mask',cur_img_name+'_training_mask.gif')) 103 | #cur_mask = cur_mask>100 104 | #cur_img = cur_img*np.dstack((cur_mask,cur_mask,cur_mask)) 105 | cur_rgb_sum = np.sum(cur_img, axis=(0,1)) 106 | rgb_cum_sum = rgb_cum_sum + cur_rgb_sum 107 | #cum_num_pixels += np.sum(cur_mask) 108 | cum_num_pixels += np.cumprod(cur_img.shape)[1] 109 | 110 | dataset_root_path = '/mnt/hdd1/STARE' 111 | train_img_names = sorted(os.listdir(os.path.join(dataset_root_path, 'stare-images'))) 112 | train_img_names = map(lambda x: x[:6], train_img_names[:10]) 113 | for cur_img_name in train_img_names: 114 | cur_img = skimage.io.imread(os.path.join(dataset_root_path,'stare-images',cur_img_name+'.ppm')) 115 | cur_rgb_sum = np.sum(cur_img, axis=(0,1)) 116 | rgb_cum_sum = rgb_cum_sum + cur_rgb_sum 117 | cum_num_pixels += np.cumprod(np.shape(cur_img))[1] 118 | 119 | mean_rgb_val = rgb_cum_sum/cum_num_pixels 120 | print mean_rgb_val -------------------------------------------------------------------------------- /code/config.py: -------------------------------------------------------------------------------- 1 | """ Common config file 2 | """ 3 | 4 | 5 | import numpy as np 6 | from easydict import EasyDict as edict 7 | 8 | __C = edict() 9 | cfg = __C 10 | 11 | 12 | 13 | ##### Training (general) ##### 14 | 15 | __C.TRAIN = edict() 16 | 17 | __C.TRAIN.MODEL_SAVE_PATH = 'train' 18 | 19 | __C.TRAIN.DISPLAY = 10 20 | 21 | __C.TRAIN.TEST_ITERS = 500 22 | 23 | __C.TRAIN.SNAPSHOT_ITERS = 500 24 | 25 | __C.TRAIN.WEIGHT_DECAY_RATE = 0.0005 26 | 27 | __C.TRAIN.MOMENTUM = 0.9 28 | 29 | __C.TRAIN.BATCH_SIZE = 1 # for CNN 30 | 31 | __C.TRAIN.GRAPH_BATCH_SIZE = 1 # for VGN 32 | 33 | ##### Training (paths) ##### 34 | 35 | __C.TRAIN.DRIVE_SET_TXT_PATH = '../DRIVE/train.txt' 36 | 37 | __C.TRAIN.STARE_SET_TXT_PATH = '../STARE/train.txt' 38 | 39 | __C.TRAIN.CHASE_DB1_SET_TXT_PATH = '../CHASE_DB1/train.txt' 40 | 41 | __C.TRAIN.HRF_SET_TXT_PATH = '../HRF/train_768.txt' 42 | 43 | __C.TRAIN.TEMP_GRAPH_SAVE_PATH = 'graph' 44 | 45 | ##### Training (augmentation) ##### 46 | 47 | # horizontal flipping 48 | __C.TRAIN.USE_LR_FLIPPED = True 49 | 50 | # vertical flipping 51 | __C.TRAIN.USE_UD_FLIPPED = False 52 | 53 | # rotation 54 | __C.TRAIN.USE_ROTATION = False 55 | __C.TRAIN.ROTATION_MAX_ANGLE = 45 56 | 57 | # scaling 58 | __C.TRAIN.USE_SCALING = False 59 | __C.TRAIN.SCALING_RANGE = [1., 1.25] 60 | 61 | # cropping 62 | __C.TRAIN.USE_CROPPING = False 63 | __C.TRAIN.CROPPING_MAX_MARGIN = 0.05 # in ratio 64 | 65 | # brightness adjustment 66 | __C.TRAIN.USE_BRIGHTNESS_ADJUSTMENT = True 67 | __C.TRAIN.BRIGHTNESS_ADJUSTMENT_MAX_DELTA = 0.2 68 | 69 | # contrast adjustment 70 | __C.TRAIN.USE_CONTRAST_ADJUSTMENT = True 71 | __C.TRAIN.CONTRAST_ADJUSTMENT_LOWER_FACTOR = 0.5 72 | __C.TRAIN.CONTRAST_ADJUSTMENT_UPPER_FACTOR = 1.5 73 | 74 | 75 | 76 | ##### Test (general) ##### 77 | 78 | __C.TEST = edict() 79 | 80 | ##### Test (paths) ##### 81 | 82 | __C.TEST.DRIVE_SET_TXT_PATH = '../DRIVE/test.txt' 83 | 84 | __C.TEST.STARE_SET_TXT_PATH = '../STARE/test.txt' 85 | 86 | __C.TEST.CHASE_DB1_SET_TXT_PATH = '../CHASE_DB1/test.txt' 87 | 88 | __C.TEST.HRF_SET_TXT_PATH = '../HRF/test_768.txt' 89 | #__C.TEST.HRF_SET_TXT_PATH = '../HRF/test_fr.txt' 90 | 91 | __C.TEST.RES_SAVE_PATH = 'test' 92 | 93 | # especially for the HRF dataset 94 | __C.TEST.WHOLE_IMG_RES_SAVE_PATH = 'test_whole_img' 95 | 96 | 97 | 98 | ##### Misc. ##### 99 | 100 | __C.PIXEL_MEAN_DRIVE = [126.837, 69.015, 41.422] 101 | 102 | __C.PIXEL_MEAN_STARE = [150.296, 83.550, 27.501] 103 | 104 | __C.PIXEL_MEAN_CHASE_DB1 = [113.953, 39.807, 6.880] 105 | 106 | __C.PIXEL_MEAN_HRF = [164.420, 51.826, 27.130] 107 | 108 | __C.EPSILON = 1e-03 109 | 110 | 111 | 112 | ##### Feature normalization ##### 113 | 114 | __C.USE_BRN = True 115 | 116 | __C.GN_MIN_NUM_G = 8 117 | 118 | __C.GN_MIN_CHS_PER_G = 16 -------------------------------------------------------------------------------- /code/extract_subimages_HRF.py: -------------------------------------------------------------------------------- 1 | # to make multiple sub-images for training and test by cropping the whole images 2 | # coded by syshin (180428) 3 | 4 | import numpy as np 5 | import os 6 | import skimage.io 7 | 8 | IMG_SIZE = [2336,3504] 9 | SUB_IMG_SIZE = 768 10 | 11 | if __name__ == '__main__': 12 | 13 | save_root_path = '../HRF/all_768' 14 | if not os.path.isdir(save_root_path): 15 | os.mkdir(save_root_path) 16 | 17 | set_name = 'train' 18 | # we apply differnt degrees of overlapping for train & test sets 19 | # when extract sub-images 20 | if set_name=='train': 21 | fr_set_txt_path = '../HRF/train_fr.txt' 22 | crop_set_txt_path = '../HRF/train_768.txt' 23 | 24 | y_mins = range(0,IMG_SIZE[0]-SUB_IMG_SIZE+1,SUB_IMG_SIZE/2) 25 | x_mins = range(0,IMG_SIZE[1]-SUB_IMG_SIZE+1,SUB_IMG_SIZE/2) 26 | y_mins = sorted(list(set(y_mins + [IMG_SIZE[0]-SUB_IMG_SIZE]))) 27 | x_mins = sorted(list(set(x_mins + [IMG_SIZE[1]-SUB_IMG_SIZE]))) 28 | 29 | elif set_name=='test': 30 | fr_set_txt_path = '../HRF/test_fr.txt' 31 | crop_set_txt_path = '../HRF/test_768.txt' 32 | 33 | y_mins = range(0,IMG_SIZE[0]-SUB_IMG_SIZE+1,SUB_IMG_SIZE-50) 34 | x_mins = range(0,IMG_SIZE[1]-SUB_IMG_SIZE+1,SUB_IMG_SIZE-50) 35 | y_mins = sorted(list(set(y_mins + [IMG_SIZE[0]-SUB_IMG_SIZE]))) 36 | x_mins = sorted(list(set(x_mins + [IMG_SIZE[1]-SUB_IMG_SIZE]))) 37 | 38 | with open(fr_set_txt_path) as f: 39 | img_names = [x.strip() for x in f.readlines()] 40 | 41 | file_p = open(crop_set_txt_path, 'w') 42 | 43 | for cur_img_name in img_names: 44 | cur_img = skimage.io.imread(cur_img_name+'.jpg') 45 | cur_mask = skimage.io.imread(cur_img_name+'.tif') 46 | cur_fov_mask = skimage.io.imread(cur_img_name+'_mask.tif') 47 | 48 | for y_idx, cur_y_min in enumerate(y_mins): 49 | for x_idx, cur_x_min in enumerate(x_mins): 50 | cur_sub_img = cur_img[cur_y_min:cur_y_min+SUB_IMG_SIZE,cur_x_min:cur_x_min+SUB_IMG_SIZE,:] 51 | cur_sub_mask = cur_mask[cur_y_min:cur_y_min+SUB_IMG_SIZE,cur_x_min:cur_x_min+SUB_IMG_SIZE] 52 | cur_sub_fov_mask = cur_fov_mask[cur_y_min:cur_y_min+SUB_IMG_SIZE,cur_x_min:cur_x_min+SUB_IMG_SIZE,0] 53 | 54 | temp = cur_img_name[cur_img_name.rfind('/')+1:] 55 | 56 | cur_sub_img_name = os.path.join(save_root_path, temp + '_{:d}_{:d}'.format(y_idx,x_idx)) 57 | 58 | file_p.write(cur_sub_img_name+'\n') 59 | file_p.flush() 60 | 61 | cur_save_path = os.path.join(cur_sub_img_name + '.bmp') 62 | skimage.io.imsave(cur_save_path, cur_sub_img) 63 | cur_save_path = os.path.join(cur_sub_img_name + '.tif') 64 | skimage.io.imsave(cur_save_path, cur_sub_mask) 65 | cur_save_path = os.path.join(cur_sub_img_name + '_mask.tif') 66 | skimage.io.imsave(cur_save_path, cur_sub_fov_mask) 67 | 68 | file_p.close() -------------------------------------------------------------------------------- /code/test_CNN.py: -------------------------------------------------------------------------------- 1 | # coded by syshin 2 | 3 | import numpy as np 4 | import os 5 | import pdb 6 | import skimage.io 7 | import argparse 8 | import tensorflow as tf 9 | 10 | from config import cfg 11 | from model import vessel_segm_cnn 12 | import util 13 | 14 | 15 | def parse_args(): 16 | """ 17 | Parse input arguments 18 | """ 19 | parser = argparse.ArgumentParser(description='Test a vessel_segm_cnn network') 20 | parser.add_argument('--dataset', default='DRIVE', help='Dataset to use: Can be DRIVE or STARE or CHASE_DB1', type=str) 21 | parser.add_argument('--cnn_model', default='driu', help='CNN model to use', type=str) 22 | parser.add_argument('--use_fov_mask', default=True, help='Whether to use fov masks', type=bool) 23 | parser.add_argument('--opt', default='adam', help='Optimizer to use: Can be sgd or adam', type=str) # declared but not used 24 | parser.add_argument('--lr', default=1e-02, help='Learning rate to use: Can be any floating point number', type=float) # declared but not used 25 | parser.add_argument('--lr_decay', default='pc', help='Learning rate decay to use: Can be pc or exp', type=str) # declared but not used 26 | parser.add_argument('--max_iters', default=50000, help='Maximum number of iterations', type=int) # declared but not used 27 | parser.add_argument('--model_path', default='../models/DRIVE/DRIU*/DRIU_DRIVE.ckpt', help='path for a model(.ckpt) to load', type=str) 28 | parser.add_argument('--save_root', default='DRIU_DRIVE', help='root path to save test results', type=str) 29 | 30 | args = parser.parse_args() 31 | return args 32 | 33 | 34 | if __name__ == '__main__': 35 | 36 | args = parse_args() 37 | 38 | print('Called with args:') 39 | print(args) 40 | 41 | if args.dataset=='DRIVE': 42 | train_set_txt_path = cfg.TRAIN.DRIVE_SET_TXT_PATH 43 | test_set_txt_path = cfg.TEST.DRIVE_SET_TXT_PATH 44 | elif args.dataset=='STARE': 45 | train_set_txt_path = cfg.TRAIN.STARE_SET_TXT_PATH 46 | test_set_txt_path = cfg.TEST.STARE_SET_TXT_PATH 47 | elif args.dataset=='CHASE_DB1': 48 | train_set_txt_path = cfg.TRAIN.CHASE_DB1_SET_TXT_PATH 49 | test_set_txt_path = cfg.TEST.CHASE_DB1_SET_TXT_PATH 50 | 51 | with open(train_set_txt_path) as f: 52 | train_img_names = [x.strip() for x in f.readlines()] 53 | with open(test_set_txt_path) as f: 54 | test_img_names = [x.strip() for x in f.readlines()] 55 | 56 | len_train = len(train_img_names) 57 | len_test = len(test_img_names) 58 | 59 | data_layer_train = util.DataLayer(train_img_names, is_training=False) 60 | data_layer_test = util.DataLayer(test_img_names, is_training=False) 61 | 62 | res_save_path = args.save_root + '/' + cfg.TEST.RES_SAVE_PATH if len(args.save_root)>0 else cfg.TEST.RES_SAVE_PATH 63 | 64 | if len(args.save_root)>0 and not os.path.isdir(args.save_root): 65 | os.mkdir(args.save_root) 66 | if not os.path.isdir(res_save_path): 67 | os.mkdir(res_save_path) 68 | 69 | network = vessel_segm_cnn(args, None) 70 | 71 | config = tf.ConfigProto() 72 | config.gpu_options.allow_growth = True 73 | sess = tf.InteractiveSession(config=config) 74 | saver = tf.train.Saver() 75 | 76 | sess.run(tf.global_variables_initializer()) 77 | 78 | assert args.model_path, 'Model path is not available' 79 | print "Loading model..." 80 | saver.restore(sess, args.model_path) 81 | 82 | f_log = open(os.path.join(res_save_path,'log.txt'), 'w') 83 | f_log.write(args.model_path+'\n') 84 | timer = util.Timer() 85 | 86 | train_loss_list = [] 87 | for _ in xrange(int(np.ceil(float(len_train)/cfg.TRAIN.BATCH_SIZE))): 88 | 89 | timer.tic() 90 | 91 | # get one batch 92 | img_list, blobs_train = data_layer_train.forward() 93 | 94 | img = blobs_train['img'] 95 | label = blobs_train['label'] 96 | if args.use_fov_mask: 97 | fov_mask = blobs_train['fov'] 98 | else: 99 | fov_mask = np.ones(label.shape, dtype=label.dtype) 100 | 101 | loss_val, fg_prob_map = sess.run( 102 | [network.loss, network.fg_prob], 103 | feed_dict={ 104 | network.is_training: False, 105 | network.imgs: img, 106 | network.labels: label, 107 | network.fov_masks: fov_mask 108 | }) 109 | 110 | timer.toc() 111 | 112 | #fg_prob_map = fg_prob_map*fov_mask.astype(float) 113 | 114 | train_loss_list.append(loss_val) 115 | 116 | # save qualitative results 117 | cur_batch_size = len(img_list) 118 | reshaped_fg_prob_map = fg_prob_map.reshape((cur_batch_size,fg_prob_map.shape[1],fg_prob_map.shape[2])) 119 | if args.dataset=='DRIVE': 120 | if args.dataset=='DRIVE': 121 | mask = np.concatenate(map(lambda x: np.expand_dims(skimage.io.imread(x+'_mask.gif'), axis=0), img_list), axis=0) 122 | else: 123 | mask = np.concatenate(map(lambda x: np.expand_dims(skimage.io.imread(x+'_mask.tif'), axis=0), img_list), axis=0) 124 | mask = ((mask.astype(float)/255)>=0.5).astype(float) 125 | reshaped_fg_prob_map = reshaped_fg_prob_map*mask 126 | reshaped_output = reshaped_fg_prob_map>=0.5 127 | for img_idx in xrange(cur_batch_size): 128 | cur_test_img_path = img_list[img_idx] 129 | temp_name = cur_test_img_path[util.find(cur_test_img_path,'/')[-1]+1:] 130 | 131 | cur_reshaped_fg_prob_map = (reshaped_fg_prob_map[img_idx,:,:]*255).astype(int) 132 | cur_reshaped_output = reshaped_output[img_idx,:,:].astype(int)*255 133 | 134 | cur_fg_prob_save_path = os.path.join(res_save_path, temp_name + '_prob.png') 135 | cur_output_save_path = os.path.join(res_save_path, temp_name + '_output.png') 136 | 137 | skimage.io.imsave(cur_fg_prob_save_path, cur_reshaped_fg_prob_map) 138 | skimage.io.imsave(cur_output_save_path, cur_reshaped_output) 139 | 140 | test_loss_list = [] 141 | all_cnn_labels = np.zeros((0,)) 142 | all_cnn_preds = np.zeros((0,)) 143 | all_cnn_labels_roi = np.zeros((0,)) 144 | all_cnn_preds_roi = np.zeros((0,)) 145 | for _ in xrange(int(np.ceil(float(len_test)/cfg.TRAIN.BATCH_SIZE))): 146 | 147 | timer.tic() 148 | 149 | # get one batch 150 | img_list, blobs_test = data_layer_test.forward() 151 | 152 | img = blobs_test['img'] 153 | label = blobs_test['label'] 154 | if args.use_fov_mask: 155 | fov_mask = blobs_test['fov'] 156 | else: 157 | fov_mask = np.ones(label.shape, dtype=label.dtype) 158 | 159 | loss_val, fg_prob_map = sess.run( 160 | [network.loss, network.fg_prob], 161 | feed_dict={ 162 | network.is_training: False, 163 | network.imgs: img, 164 | network.labels: label, 165 | network.fov_masks: fov_mask 166 | }) 167 | 168 | timer.toc() 169 | 170 | #fg_prob_map = fg_prob_map*fov_mask.astype(float) 171 | 172 | test_loss_list.append(loss_val) 173 | 174 | all_cnn_labels = np.concatenate((all_cnn_labels,np.reshape(label, (-1)))) 175 | all_cnn_preds = np.concatenate((all_cnn_preds,np.reshape(fg_prob_map, (-1)))) 176 | 177 | # save qualitative results 178 | cur_batch_size = len(img_list) 179 | reshaped_fg_prob_map = fg_prob_map.reshape((cur_batch_size,fg_prob_map.shape[1],fg_prob_map.shape[2])) 180 | 181 | if args.dataset=='DRIVE': 182 | if args.dataset=='DRIVE': 183 | mask = np.concatenate(map(lambda x: np.expand_dims(skimage.io.imread(x+'_mask.gif'), axis=0), img_list), axis=0) 184 | else: 185 | mask = np.concatenate(map(lambda x: np.expand_dims(skimage.io.imread(x+'_mask.tif'), axis=0), img_list), axis=0) 186 | 187 | mask = ((mask.astype(float)/255)>=0.5).astype(float) 188 | label_roi = label[mask.astype(bool)] 189 | fg_prob_map_roi = fg_prob_map[mask.astype(bool)] 190 | all_cnn_labels_roi = np.concatenate((all_cnn_labels_roi,np.reshape(label_roi, (-1)))) 191 | all_cnn_preds_roi = np.concatenate((all_cnn_preds_roi,np.reshape(fg_prob_map_roi, (-1)))) 192 | reshaped_fg_prob_map = reshaped_fg_prob_map*mask 193 | label = np.squeeze(label.astype(float), axis=-1)*mask 194 | else: 195 | label = np.squeeze(label.astype(float), axis=-1) 196 | 197 | reshaped_output = reshaped_fg_prob_map>=0.5 198 | for img_idx in xrange(cur_batch_size): 199 | cur_test_img_path = img_list[img_idx] 200 | temp_name = cur_test_img_path[util.find(cur_test_img_path,'/')[-1]+1:] 201 | 202 | cur_reshaped_fg_prob_map = (reshaped_fg_prob_map[img_idx,:,:]*255).astype(int) 203 | cur_reshaped_fg_prob_map_inv = ((1.-reshaped_fg_prob_map[img_idx,:,:])*255).astype(int) 204 | cur_reshaped_output = reshaped_output[img_idx,:,:].astype(int)*255 205 | 206 | cur_fg_prob_save_path = os.path.join(res_save_path, temp_name + '_prob.png') 207 | cur_fg_prob_inv_save_path = os.path.join(res_save_path, temp_name + '_prob_inv.png') 208 | cur_output_save_path = os.path.join(res_save_path, temp_name + '_output.png') 209 | cur_numpy_save_path = os.path.join(res_save_path, temp_name + '.npy') 210 | 211 | skimage.io.imsave(cur_fg_prob_save_path, cur_reshaped_fg_prob_map) 212 | skimage.io.imsave(cur_fg_prob_inv_save_path, cur_reshaped_fg_prob_map_inv) 213 | skimage.io.imsave(cur_output_save_path, cur_reshaped_output) 214 | np.save(cur_numpy_save_path, reshaped_fg_prob_map[img_idx,:,:]) 215 | 216 | cnn_auc_test, cnn_ap_test = util.get_auc_ap_score(all_cnn_labels, all_cnn_preds) 217 | all_cnn_labels_bin = np.copy(all_cnn_labels).astype(np.bool) 218 | all_cnn_preds_bin = all_cnn_preds>=0.5 219 | all_cnn_correct = all_cnn_labels_bin==all_cnn_preds_bin 220 | cnn_acc_test = np.mean(all_cnn_correct.astype(np.float32)) 221 | 222 | if args.dataset=='DRIVE': 223 | cnn_auc_test_roi, cnn_ap_test_roi = util.get_auc_ap_score(all_cnn_labels_roi, all_cnn_preds_roi) 224 | all_cnn_labels_bin_roi = np.copy(all_cnn_labels_roi).astype(np.bool) 225 | all_cnn_preds_bin_roi = all_cnn_preds_roi>=0.5 226 | all_cnn_correct_roi = all_cnn_labels_bin_roi==all_cnn_preds_bin_roi 227 | cnn_acc_test_roi = np.mean(all_cnn_correct_roi.astype(np.float32)) 228 | 229 | #print 'train_loss: %.4f'%(np.mean(train_loss_list)) 230 | print 'test_loss: %.4f'%(np.mean(test_loss_list)) 231 | print 'test_cnn_acc: %.4f, test_cnn_auc: %.4f, test_cnn_ap: %.4f'%(cnn_acc_test, cnn_auc_test, cnn_ap_test) 232 | if args.dataset=='DRIVE': 233 | print 'test_cnn_acc_roi: %.4f, test_cnn_auc_roi: %.4f, test_cnn_ap_roi: %.4f'%(cnn_acc_test_roi, cnn_auc_test_roi, cnn_ap_test_roi) 234 | 235 | #f_log.write('train_loss '+str(np.mean(train_loss_list))+'\n') 236 | f_log.write('test_loss '+str(np.mean(test_loss_list))+'\n') 237 | f_log.write('test_cnn_acc '+str(cnn_acc_test)+'\n') 238 | f_log.write('test_cnn_auc '+str(cnn_auc_test)+'\n') 239 | f_log.write('test_cnn_ap '+str(cnn_ap_test)+'\n') 240 | if args.dataset=='DRIVE': 241 | f_log.write('test_cnn_acc_roi '+str(cnn_acc_test_roi)+'\n') 242 | f_log.write('test_cnn_auc_roi '+str(cnn_auc_test_roi)+'\n') 243 | f_log.write('test_cnn_ap_roi '+str(cnn_ap_test_roi)+'\n') 244 | 245 | f_log.flush() 246 | 247 | print 'speed: {:.3f}s'.format(timer.average_time) 248 | 249 | f_log.close() 250 | sess.close() 251 | print("Test complete.") -------------------------------------------------------------------------------- /code/test_CNN_HRF.py: -------------------------------------------------------------------------------- 1 | # a special script for testing images in the HRF dataset 2 | # here, multiple sub-images from a single image are independently tested 3 | # and tiled to make a result for the whole image 4 | # coded by syshin (180430) 5 | 6 | import numpy as np 7 | import os 8 | import pdb 9 | import skimage.io 10 | import argparse 11 | import tensorflow as tf 12 | 13 | from config import cfg 14 | from model import vessel_segm_cnn 15 | import util 16 | 17 | 18 | def parse_args(): 19 | """ 20 | Parse input arguments 21 | """ 22 | parser = argparse.ArgumentParser(description='Test a vessel_segm_cnn network') 23 | parser.add_argument('--dataset', default='HRF', help='Dataset to use', type=str) 24 | parser.add_argument('--cnn_model', default='driu_large', help='CNN model to use', type=str) 25 | parser.add_argument('--opt', default='sgd', help='Optimizer to use: Can be sgd or adam', type=str) # declared but not used 26 | parser.add_argument('--lr', default=1e-03, help='Learning rate to use: Can be any floating point number', type=float) # declared but not used 27 | parser.add_argument('--lr_decay', default='pc', help='Learning rate decay to use: Can be pc or exp', type=str) # declared but not used 28 | parser.add_argument('--max_iters', default=100000, help='Maximum number of iterations', type=int) # declared but not used 29 | parser.add_argument('--model_path', default='../models/HRF/DRIU*/DRIU_HRF.ckpt', help='path for a model(.ckpt) to load', type=str) 30 | parser.add_argument('--save_root', default='DRIU_HRF', help='root path to save test results', type=str) 31 | 32 | args = parser.parse_args() 33 | return args 34 | 35 | 36 | if __name__ == '__main__': 37 | 38 | args = parse_args() 39 | 40 | print('Called with args:') 41 | print(args) 42 | 43 | # added for testing on a restricted test set 44 | test_type = ['dr', 'g', 'h'] 45 | # added for testing on a restricted test set 46 | 47 | with open('../HRF/test_fr.txt') as f: 48 | test_whole_img_paths = [x.strip() for x in f.readlines()] # different to 'test_img_names' 49 | 50 | # added for testing on a restricted test set 51 | temp = [] 52 | for i in xrange(len(test_whole_img_paths)): 53 | for j in xrange(len(test_type)): 54 | if test_type[j] in test_whole_img_paths[i][util.find(test_whole_img_paths[i],'/')[-1]+1:]: 55 | temp.append(test_whole_img_paths[i]) 56 | break 57 | test_whole_img_paths = temp 58 | # added for testing on a restricted test set 59 | 60 | test_whole_img_names = map(lambda x: x[util.find(x,'/')[-1]+1:], test_whole_img_paths) 61 | 62 | test_set_txt_path = cfg.TEST.HRF_SET_TXT_PATH 63 | IMG_SIZE = [2336,3504] 64 | SUB_IMG_SIZE = 768 65 | y_mins = range(0,IMG_SIZE[0]-SUB_IMG_SIZE+1,SUB_IMG_SIZE-50) 66 | x_mins = range(0,IMG_SIZE[1]-SUB_IMG_SIZE+1,SUB_IMG_SIZE-50) 67 | y_mins = sorted(list(set(y_mins + [IMG_SIZE[0]-SUB_IMG_SIZE]))) 68 | x_mins = sorted(list(set(x_mins + [IMG_SIZE[1]-SUB_IMG_SIZE]))) 69 | 70 | with open(test_set_txt_path) as f: 71 | test_img_names = [x.strip() for x in f.readlines()] 72 | 73 | # added for testing on a restricted test set 74 | temp = [] 75 | for i in xrange(len(test_img_names)): 76 | for j in xrange(len(test_type)): 77 | if test_type[j] in test_img_names[i][util.find(test_img_names[i],'/')[-1]+1:]: 78 | temp.append(test_img_names[i]) 79 | break 80 | test_img_names = temp 81 | # added for testing on a restricted test set 82 | 83 | len_test = len(test_img_names) 84 | 85 | data_layer_test = util.DataLayer(test_img_names, is_training=False) 86 | 87 | res_save_path = args.save_root + '/' + cfg.TEST.WHOLE_IMG_RES_SAVE_PATH if len(args.save_root)>0 else cfg.TEST.WHOLE_IMG_RES_SAVE_PATH 88 | if len(args.save_root)>0 and not os.path.isdir(args.save_root): 89 | os.mkdir(args.save_root) 90 | if not os.path.isdir(res_save_path): 91 | os.mkdir(res_save_path) 92 | 93 | network = vessel_segm_cnn(args, None) 94 | 95 | config = tf.ConfigProto() 96 | config.gpu_options.allow_growth = True 97 | sess = tf.InteractiveSession(config=config) 98 | saver = tf.train.Saver() 99 | 100 | sess.run(tf.global_variables_initializer()) 101 | 102 | assert args.model_path, 'Model path is not available' 103 | print "Loading model..." 104 | saver.restore(sess, args.model_path) 105 | 106 | f_log = open(os.path.join(res_save_path,'log.txt'), 'w') 107 | f_log.write(args.model_path+'\n') 108 | timer = util.Timer() 109 | 110 | all_cnn_labels = np.concatenate(map(lambda x: np.expand_dims(skimage.io.imread(x+'.tif'), axis=0), test_whole_img_paths), axis=0) 111 | all_cnn_masks = np.concatenate(map(lambda x: np.expand_dims(skimage.io.imread(x+'_mask.tif'), axis=0), test_whole_img_paths), axis=0) 112 | all_cnn_masks = all_cnn_masks[:,:,:,0] 113 | all_cnn_labels = ((all_cnn_labels.astype(float)/255)>=0.5).astype(float) 114 | all_cnn_masks = ((all_cnn_masks.astype(float)/255)>=0.5).astype(float) 115 | all_cnn_preds_cum_sum = np.zeros(all_cnn_labels.shape) 116 | all_cnn_preds_cum_num = np.zeros(all_cnn_labels.shape) 117 | 118 | for _ in xrange(int(np.ceil(float(len_test)/cfg.TRAIN.BATCH_SIZE))): 119 | 120 | timer.tic() 121 | 122 | # get one batch 123 | img_list, blobs_test = data_layer_test.forward() 124 | 125 | img = blobs_test['img'] 126 | label = blobs_test['label'] 127 | 128 | fg_prob_map = sess.run( 129 | [network.fg_prob], 130 | feed_dict={ 131 | network.is_training: False, 132 | network.imgs: img, 133 | network.labels: label, 134 | }) 135 | fg_prob_map = fg_prob_map[0] 136 | fg_prob_map = fg_prob_map.reshape((fg_prob_map.shape[1],fg_prob_map.shape[2])) 137 | 138 | timer.toc() 139 | 140 | cur_batch_size = len(img_list) 141 | for i in xrange(cur_batch_size): 142 | cur_test_img_path = img_list[i] 143 | temp_name = cur_test_img_path[util.find(cur_test_img_path,'/')[-1]+1:] 144 | 145 | temp_name_splits = temp_name.split('_') 146 | 147 | img_idx = test_whole_img_names.index(temp_name_splits[0]+'_'+temp_name_splits[1]) 148 | cur_y_min = y_mins[int(temp_name_splits[2])] 149 | cur_x_min = x_mins[int(temp_name_splits[3])] 150 | 151 | all_cnn_preds_cum_sum[img_idx,cur_y_min:cur_y_min+SUB_IMG_SIZE,cur_x_min:cur_x_min+SUB_IMG_SIZE] += fg_prob_map 152 | all_cnn_preds_cum_num[img_idx,cur_y_min:cur_y_min+SUB_IMG_SIZE,cur_x_min:cur_x_min+SUB_IMG_SIZE] += 1 153 | 154 | all_cnn_preds = np.divide(all_cnn_preds_cum_sum,all_cnn_preds_cum_num) 155 | 156 | all_cnn_labels_roi = all_cnn_labels[all_cnn_masks.astype(bool)] 157 | all_cnn_preds_roi = all_cnn_preds[all_cnn_masks.astype(bool)] 158 | 159 | # save qualitative results 160 | reshaped_fg_prob_map = all_cnn_preds*all_cnn_masks 161 | reshaped_output = reshaped_fg_prob_map>=0.5 162 | for img_idx, temp_name in enumerate(test_whole_img_names): 163 | 164 | cur_reshaped_fg_prob_map = (reshaped_fg_prob_map[img_idx,:,:]*255).astype(int) 165 | cur_reshaped_fg_prob_map_inv = ((1.-reshaped_fg_prob_map[img_idx,:,:])*255).astype(int) 166 | cur_reshaped_output = reshaped_output[img_idx,:,:].astype(int)*255 167 | 168 | cur_fg_prob_save_path = os.path.join(res_save_path, temp_name + '_prob.png') 169 | cur_fg_prob_inv_save_path = os.path.join(res_save_path, temp_name + '_prob_inv.png') 170 | cur_output_save_path = os.path.join(res_save_path, temp_name + '_output.png') 171 | cur_numpy_save_path = os.path.join(res_save_path, temp_name + '.npy') 172 | 173 | skimage.io.imsave(cur_fg_prob_save_path, cur_reshaped_fg_prob_map) 174 | skimage.io.imsave(cur_fg_prob_inv_save_path, cur_reshaped_fg_prob_map_inv) 175 | skimage.io.imsave(cur_output_save_path, cur_reshaped_output) 176 | np.save(cur_numpy_save_path, reshaped_fg_prob_map[img_idx,:,:]) 177 | 178 | all_cnn_labels = np.reshape(all_cnn_labels, (-1)) 179 | all_cnn_preds = np.reshape(all_cnn_preds, (-1)) 180 | all_cnn_labels_roi = np.reshape(all_cnn_labels_roi, (-1)) 181 | all_cnn_preds_roi = np.reshape(all_cnn_preds_roi, (-1)) 182 | 183 | cnn_auc_test, cnn_ap_test = util.get_auc_ap_score(all_cnn_labels, all_cnn_preds) 184 | all_cnn_labels_bin = np.copy(all_cnn_labels).astype(np.bool) 185 | all_cnn_preds_bin = all_cnn_preds>=0.5 186 | all_cnn_correct = all_cnn_labels_bin==all_cnn_preds_bin 187 | cnn_acc_test = np.mean(all_cnn_correct.astype(np.float32)) 188 | 189 | cnn_auc_test_roi, cnn_ap_test_roi = util.get_auc_ap_score(all_cnn_labels_roi, all_cnn_preds_roi) 190 | all_cnn_labels_bin_roi = np.copy(all_cnn_labels_roi).astype(np.bool) 191 | all_cnn_preds_bin_roi = all_cnn_preds_roi>=0.5 192 | all_cnn_correct_roi = all_cnn_labels_bin_roi==all_cnn_preds_bin_roi 193 | cnn_acc_test_roi = np.mean(all_cnn_correct_roi.astype(np.float32)) 194 | 195 | print 'test_cnn_acc: %.4f, test_cnn_auc: %.4f, test_cnn_ap: %.4f'%(cnn_acc_test, cnn_auc_test, cnn_ap_test) 196 | print 'test_cnn_acc_roi: %.4f, test_cnn_auc_roi: %.4f, test_cnn_ap_roi: %.4f'%(cnn_acc_test_roi, cnn_auc_test_roi, cnn_ap_test_roi) 197 | 198 | f_log.write('test_cnn_acc '+str(cnn_acc_test)+'\n') 199 | f_log.write('test_cnn_auc '+str(cnn_auc_test)+'\n') 200 | f_log.write('test_cnn_ap '+str(cnn_ap_test)+'\n') 201 | f_log.write('test_cnn_acc_roi '+str(cnn_acc_test_roi)+'\n') 202 | f_log.write('test_cnn_auc_roi '+str(cnn_auc_test_roi)+'\n') 203 | f_log.write('test_cnn_ap_roi '+str(cnn_ap_test_roi)+'\n') 204 | 205 | f_log.flush() 206 | 207 | print 'speed: {:.3f}s'.format(timer.average_time) 208 | 209 | f_log.close() 210 | sess.close() 211 | print("Test complete.") -------------------------------------------------------------------------------- /code/test_VGN.py: -------------------------------------------------------------------------------- 1 | # updated by syshin (180829) 2 | 3 | import numpy as np 4 | import os 5 | import pdb 6 | import argparse 7 | import skimage.io 8 | import networkx as nx 9 | import pickle as pkl 10 | import multiprocessing 11 | import skfmm 12 | import skimage.transform 13 | import tensorflow as tf 14 | 15 | from config import cfg 16 | from model import vessel_segm_vgn 17 | import util 18 | 19 | 20 | def parse_args(): 21 | """ 22 | Parse input arguments 23 | """ 24 | parser = argparse.ArgumentParser(description='Test a vessel_segm_vgn network') 25 | parser.add_argument('--dataset', default='CHASE_DB1', help='Dataset to use: Can be DRIVE or STARE or CHASE_DB1', type=str) 26 | #parser.add_argument('--use_multiprocessing', action='store_true', default=False, help='Whether to use the python multiprocessing module') 27 | parser.add_argument('--use_multiprocessing', default=True, help='Whether to use the python multiprocessing module', type=bool) 28 | parser.add_argument('--multiprocessing_num_proc', default=8, help='Number of CPU processes to use', type=int) 29 | parser.add_argument('--win_size', default=16, help='Window size for srns', type=int) # for srns # [4,8,16] 30 | parser.add_argument('--edge_type', default='srns_geo_dist_binary', \ 31 | help='Graph edge type: Can be srns_geo_dist_binary or srns_geo_dist_weighted', type=str) 32 | parser.add_argument('--edge_geo_dist_thresh', default=40, help='Threshold for geodesic distance', type=float) # [10,20,40] 33 | parser.add_argument('--model_path', default='../models/CHASE_DB1/VGN/win_size=16/VGN_CHASE.ckpt', \ 34 | help='Path for a trained model(.ckpt)', type=str) 35 | parser.add_argument('--save_root', default='../models/CHASE_DB1/VGN/win_size=16', \ 36 | help='Root path to save test results', type=str) 37 | 38 | ### cnn module related ### 39 | parser.add_argument('--cnn_model', default='driu', help='CNN model to use', type=str) 40 | parser.add_argument('--cnn_loss_on', default=True, help='Whether to use a cnn loss for training', type=bool) 41 | 42 | ### gnn module related ### 43 | parser.add_argument('--gnn_loss_on', default=True, help='Whether to use a gnn loss for training', type=bool) 44 | parser.add_argument('--gnn_loss_weight', default=1., help='Relative weight on the gnn loss', type=float) 45 | # gat # 46 | parser.add_argument('--gat_n_heads', default=[4,4], help='Numbers of heads in each layer', type=list) # [4,1] 47 | #parser.add_argument('--gat_n_heads', nargs='+', help='Numbers of heads in each layer', type=int) # [4,1] 48 | parser.add_argument('--gat_hid_units', default=[16], help='Numbers of hidden units per each attention head in each layer', type=list) 49 | #parser.add_argument('--gat_hid_units', nargs='+', help='Numbers of hidden units per each attention head in each layer', type=int) 50 | parser.add_argument('--gat_use_residual', action='store_true', default=False, help='Whether to use residual learning in GAT') 51 | 52 | ### inference module related ### 53 | parser.add_argument('--norm_type', default=None, help='Norm. type', type=str) 54 | parser.add_argument('--use_enc_layer', action='store_true', default=False, \ 55 | help='Whether to use additional conv. layers in the inference module') 56 | parser.add_argument('--infer_module_loss_masking_thresh', default=0.05, \ 57 | help='Threshold for loss masking', type=float) 58 | parser.add_argument('--infer_module_kernel_size', default=3, \ 59 | help='Conv. kernel size for the inference module', type=int) 60 | parser.add_argument('--infer_module_grad_weight', default=1., \ 61 | help='Relative weight of the grad. on the inference module', type=float) 62 | 63 | ### training (declared but not used) ### 64 | parser.add_argument('--do_simul_training', default=True, \ 65 | help='Whether to train the gnn and inference modules simultaneously or not', type=bool) 66 | parser.add_argument('--max_iters', default=50000, help='Maximum number of iterations', type=int) 67 | parser.add_argument('--old_net_ft_lr', default=0., help='Learnining rate for fine-tuning of old parts of network', type=float) 68 | parser.add_argument('--new_net_lr', default=1e-02, help='Learnining rate for a new part of network', type=float) 69 | parser.add_argument('--opt', default='adam', help='Optimizer to use: Can be sgd or adam', type=str) 70 | parser.add_argument('--lr_scheduling', default='pc', help='How to change the learning rate during training', type=str) 71 | parser.add_argument('--lr_decay_tp', default=1., help='When to decrease the lr during training', type=float) # for pc 72 | 73 | 74 | args = parser.parse_args() 75 | return args 76 | 77 | 78 | def make_graph_using_srns((fg_prob_map, edge_type, win_size, edge_geo_dist_thresh, img_path)): 79 | 80 | if 'srns' not in edge_type: 81 | raise NotImplementedError 82 | 83 | # find local maxima 84 | vesselness = fg_prob_map 85 | 86 | im_y = vesselness.shape[0] 87 | im_x = vesselness.shape[1] 88 | y_quan = range(0,im_y,win_size) 89 | y_quan = sorted(list(set(y_quan) | set([im_y]))) 90 | x_quan = range(0,im_x,win_size) 91 | x_quan = sorted(list(set(x_quan) | set([im_x]))) 92 | 93 | max_val = [] 94 | max_pos = [] 95 | for y_idx in xrange(len(y_quan)-1): 96 | for x_idx in xrange(len(x_quan)-1): 97 | cur_patch = vesselness[y_quan[y_idx]:y_quan[y_idx+1],x_quan[x_idx]:x_quan[x_idx+1]] 98 | if np.sum(cur_patch)==0: 99 | max_val.append(0) 100 | max_pos.append((y_quan[y_idx]+cur_patch.shape[0]/2,x_quan[x_idx]+cur_patch.shape[1]/2)) 101 | else: 102 | max_val.append(np.amax(cur_patch)) 103 | temp = np.unravel_index(cur_patch.argmax(), cur_patch.shape) 104 | max_pos.append((y_quan[y_idx]+temp[0],x_quan[x_idx]+temp[1])) 105 | 106 | graph = nx.Graph() 107 | 108 | # add nodes 109 | for node_idx, (node_y, node_x) in enumerate(max_pos): 110 | graph.add_node(node_idx, kind='MP', y=node_y, x=node_x, label=node_idx) 111 | print 'node label', node_idx, 'pos', (node_y,node_x), 'added' 112 | 113 | speed = vesselness 114 | 115 | node_list = list(graph.nodes) 116 | for i, n in enumerate(node_list): 117 | 118 | phi = np.ones_like(speed) 119 | phi[graph.node[n]['y'],graph.node[n]['x']] = -1 120 | if speed[graph.node[n]['y'],graph.node[n]['x']]==0: 121 | continue 122 | 123 | neighbor = speed[max(0,graph.node[n]['y']-1):min(im_y,graph.node[n]['y']+2), \ 124 | max(0,graph.node[n]['x']-1):min(im_x,graph.node[n]['x']+2)] 125 | if np.mean(neighbor)<0.1: 126 | continue 127 | 128 | tt = skfmm.travel_time(phi, speed, narrow=edge_geo_dist_thresh) # travel time 129 | 130 | for n_comp in node_list[i+1:]: 131 | geo_dist = tt[graph.node[n_comp]['y'],graph.node[n_comp]['x']] # travel time 132 | if geo_dist < edge_geo_dist_thresh: 133 | graph.add_edge(n, n_comp, weight=edge_geo_dist_thresh/(edge_geo_dist_thresh+geo_dist)) 134 | print 'An edge BTWN', 'node', n, '&', n_comp, 'is constructed' 135 | 136 | # save as a file 137 | savepath = img_path+'_%.2d_%.2d'%(win_size,edge_geo_dist_thresh)+'.graph_res' 138 | nx.write_gpickle(graph, savepath, protocol=pkl.HIGHEST_PROTOCOL) 139 | graph.clear() 140 | print 'generated a graph for '+img_path 141 | 142 | 143 | if __name__ == '__main__': 144 | 145 | args = parse_args() 146 | 147 | print('Called with args:') 148 | print(args) 149 | 150 | if args.dataset=='DRIVE': 151 | im_root_path = '../DRIVE/all' 152 | test_set_txt_path = cfg.TEST.DRIVE_SET_TXT_PATH 153 | im_ext = '_image.tif' 154 | label_ext = '_label.gif' 155 | elif args.dataset=='STARE': 156 | im_root_path = '../STARE/all' 157 | test_set_txt_path = cfg.TEST.STARE_SET_TXT_PATH 158 | im_ext = '.ppm' 159 | label_ext = '.ah.ppm' 160 | elif args.dataset=='CHASE_DB1': 161 | im_root_path = '../CHASE_DB1/all' 162 | test_set_txt_path = cfg.TEST.CHASE_DB1_SET_TXT_PATH 163 | im_ext = '.jpg' 164 | label_ext = '_1stHO.png' 165 | 166 | if args.use_multiprocessing: 167 | pool = multiprocessing.Pool(processes=args.multiprocessing_num_proc) 168 | 169 | res_save_path = args.save_root + '/' + cfg.TEST.RES_SAVE_PATH if len(args.save_root)>0 else cfg.TEST.RES_SAVE_PATH 170 | 171 | if len(args.save_root)>0 and not os.path.isdir(args.save_root): 172 | os.mkdir(args.save_root) 173 | if not os.path.isdir(res_save_path): 174 | os.mkdir(res_save_path) 175 | 176 | with open(test_set_txt_path) as f: 177 | test_img_names = [x.strip() for x in f.readlines()] 178 | 179 | len_test = len(test_img_names) 180 | 181 | data_layer_test = util.DataLayer(test_img_names, \ 182 | is_training=False, \ 183 | use_padding=True) 184 | 185 | network = vessel_segm_vgn(args, None) 186 | 187 | config = tf.ConfigProto() 188 | config.gpu_options.allow_growth = True 189 | sess = tf.InteractiveSession(config=config) 190 | 191 | saver = tf.train.Saver() 192 | 193 | sess.run(tf.global_variables_initializer()) 194 | if args.model_path is not None: 195 | print "Loading model..." 196 | saver.restore(sess, args.model_path) 197 | 198 | f_log = open(os.path.join(res_save_path,'log.txt'), 'w') 199 | f_log.write(str(args)+'\n') 200 | f_log.flush() 201 | timer = util.Timer() 202 | 203 | print("Testing the model...") 204 | 205 | ### make cnn results ### 206 | res_list = [] 207 | for _ in xrange(int(np.ceil(float(len_test)/cfg.TRAIN.GRAPH_BATCH_SIZE))): 208 | 209 | # get one batch 210 | img_list, blobs_test = data_layer_test.forward() 211 | 212 | img = blobs_test['img'] 213 | label = blobs_test['label'] 214 | fov = blobs_test['fov'] 215 | 216 | conv_feats, fg_prob_tensor, \ 217 | cnn_feat_dict, cnn_feat_spatial_sizes_dict = sess.run( 218 | [network.conv_feats, 219 | network.img_fg_prob, 220 | network.cnn_feat, 221 | network.cnn_feat_spatial_sizes], 222 | feed_dict={ 223 | network.imgs: img, 224 | network.labels: label 225 | }) 226 | 227 | cur_batch_size = len(img_list) 228 | for img_idx in xrange(cur_batch_size): 229 | cur_res = {} 230 | cur_res['img_path'] = img_list[img_idx] 231 | cur_res['img'] = img[[img_idx],:,:,:] 232 | cur_res['label'] = label[[img_idx],:,:,:] 233 | cur_res['conv_feats'] = conv_feats[[img_idx],:,:,:] 234 | cur_res['cnn_fg_prob_map'] = fg_prob_tensor[img_idx,:,:,0] 235 | cur_res['cnn_feat'] = {k: v[[img_idx],:,:,:] for k, v in zip(cnn_feat_dict.keys(), cnn_feat_dict.values())} 236 | cur_res['cnn_feat_spatial_sizes'] = cnn_feat_spatial_sizes_dict 237 | cur_res['graph'] = None # will be filled at the next step 238 | cur_res['final_fg_prob_map'] = cur_res['cnn_fg_prob_map'] 239 | cur_res['ap_list'] = [] 240 | 241 | if args.dataset=='DRIVE': 242 | """img_name = img_list[img_idx] 243 | temp = img_name[util.find(img_name,'/')[-1]:] 244 | if args.dataset=='DRIVE': 245 | mask = skimage.io.imread(im_root_path + temp +'_mask.gif') 246 | else: 247 | mask = skimage.io.imread(im_root_path + temp +'_mask.tif')""" 248 | mask = fov[img_idx,:,:,0] 249 | cur_res['mask'] = mask 250 | 251 | # compute the current AP 252 | cur_label = label[img_idx,:,:,0] 253 | label_roi = cur_label[mask.astype(bool)].reshape((-1)) 254 | fg_prob_map_roi = cur_res['cnn_fg_prob_map'][mask.astype(bool)].reshape((-1)) 255 | _, cur_cnn_ap = util.get_auc_ap_score(label_roi, fg_prob_map_roi) 256 | cur_res['ap'] = cur_cnn_ap 257 | cur_res['ap_list'].append(cur_cnn_ap) 258 | else: 259 | # compute the current AP 260 | cur_label = label[img_idx,:,:,0].reshape((-1)) 261 | fg_prob_map = cur_res['cnn_fg_prob_map'].reshape((-1)) 262 | _, cur_cnn_ap = util.get_auc_ap_score(cur_label, fg_prob_map) 263 | cur_res['ap'] = cur_cnn_ap 264 | cur_res['ap_list'].append(cur_cnn_ap) 265 | 266 | res_list.append(cur_res) 267 | 268 | ### make final results ### 269 | # make graphs 270 | func_arg = [] 271 | for img_idx in xrange(len(res_list)): 272 | temp_fg_prob_map = res_list[img_idx]['final_fg_prob_map'] 273 | func_arg.append((temp_fg_prob_map, args.edge_type, args.win_size, args.edge_geo_dist_thresh, res_list[img_idx]['img_path'])) 274 | if args.use_multiprocessing: 275 | pool.map(make_graph_using_srns, func_arg) 276 | else: 277 | for x in func_arg: 278 | make_graph_using_srns(x) 279 | 280 | # load graphs 281 | for img_idx in xrange(len(res_list)): 282 | loadpath = res_list[img_idx]['img_path']+'_%.2d_%.2d'%(args.win_size,args.edge_geo_dist_thresh)+'.graph_res' 283 | temp_graph = nx.read_gpickle(loadpath) 284 | res_list[img_idx]['graph'] = temp_graph 285 | 286 | # make final results 287 | for img_idx in xrange(len(res_list)): 288 | 289 | cur_img = res_list[img_idx]['img'] 290 | cur_conv_feats = res_list[img_idx]['conv_feats'] 291 | cur_cnn_feat = res_list[img_idx]['cnn_feat'] 292 | cur_cnn_feat_spatial_sizes = res_list[img_idx]['cnn_feat_spatial_sizes'] 293 | cur_graph = res_list[img_idx]['graph'] 294 | 295 | cur_graph = nx.convert_node_labels_to_integers(cur_graph) 296 | node_byxs = util.get_node_byx_from_graph(cur_graph, [cur_graph.number_of_nodes()]) 297 | 298 | if 'geo_dist_weighted' in args.edge_type: 299 | adj = nx.adjacency_matrix(cur_graph) 300 | else: 301 | adj = nx.adjacency_matrix(cur_graph,weight=None).astype(float) 302 | 303 | adj_norm = util.preprocess_graph_gat(adj) 304 | 305 | cur_feed_dict = \ 306 | { 307 | network.imgs: cur_img, 308 | network.conv_feats: cur_conv_feats, 309 | network.node_byxs: node_byxs, 310 | network.adj: adj_norm, 311 | network.is_lr_flipped: False, 312 | network.is_ud_flipped: False 313 | } 314 | cur_feed_dict.update({network.cnn_feat[cur_key]: cur_cnn_feat[cur_key] for cur_key in network.cnn_feat.keys()}) 315 | cur_feed_dict.update({network.cnn_feat_spatial_sizes[cur_key]: cur_cnn_feat_spatial_sizes[cur_key] for cur_key in network.cnn_feat_spatial_sizes.keys()}) 316 | 317 | res_prob_map = sess.run( 318 | [network.post_cnn_img_fg_prob], 319 | feed_dict=cur_feed_dict) 320 | res_prob_map = res_prob_map[0] 321 | 322 | res_prob_map = res_prob_map.reshape((res_prob_map.shape[1], res_prob_map.shape[2])) 323 | 324 | # compute the current AP 325 | if args.dataset=='DRIVE': 326 | cur_label = res_list[img_idx]['label'] 327 | cur_label = np.squeeze(cur_label) 328 | cur_mask = res_list[img_idx]['mask'] 329 | label_roi = cur_label[cur_mask.astype(bool)].reshape((-1)) 330 | fg_prob_map_roi = res_prob_map[cur_mask.astype(bool)].reshape((-1)) 331 | _, cur_ap = util.get_auc_ap_score(label_roi, fg_prob_map_roi) 332 | res_prob_map = res_prob_map*cur_mask 333 | else: 334 | cur_label = res_list[img_idx]['label'] 335 | cur_label = np.squeeze(cur_label) 336 | _, cur_ap = util.get_auc_ap_score(cur_label.reshape((-1)), res_prob_map.reshape((-1))) 337 | 338 | res_list[img_idx]['ap'] = cur_ap 339 | res_list[img_idx]['ap_list'].append(cur_ap) 340 | res_list[img_idx]['final_fg_prob_map'] = res_prob_map 341 | 342 | ### calculate performance measures ### 343 | all_labels = np.zeros((0,)) 344 | all_preds = np.zeros((0,)) 345 | for img_idx in xrange(len(res_list)): 346 | 347 | cur_label = res_list[img_idx]['label'] 348 | cur_label = np.squeeze(cur_label) 349 | cur_pred = res_list[img_idx]['final_fg_prob_map'] 350 | 351 | # save qualitative results 352 | img_path = res_list[img_idx]['img_path'] 353 | temp = img_path[util.find(img_path,'/')[-1]:] 354 | 355 | temp_output = (cur_pred*255).astype(int) 356 | cur_save_path = res_save_path + temp + '_prob_final.png' 357 | skimage.io.imsave(cur_save_path, temp_output) 358 | 359 | cur_save_path = res_save_path + temp + '.npy' 360 | np.save(cur_save_path, cur_pred) 361 | 362 | temp_output = ((1.-cur_pred)*255).astype(int) 363 | cur_save_path = res_save_path + temp + '_prob_final_inv.png' 364 | skimage.io.imsave(cur_save_path, temp_output) 365 | # save qualitative results 366 | 367 | if args.dataset=='DRIVE': 368 | cur_mask = res_list[img_idx]['mask'] 369 | cur_label = cur_label[cur_mask.astype(bool)] 370 | cur_pred = cur_pred[cur_mask.astype(bool)] 371 | 372 | all_labels = np.concatenate((all_labels,np.reshape(cur_label, (-1)))) 373 | all_preds = np.concatenate((all_preds,np.reshape(cur_pred, (-1)))) 374 | 375 | print 'AP list for ' + res_list[img_idx]['img_path'] + ' : ' + str(res_list[img_idx]['ap_list']) 376 | f_log.write('AP list for ' + res_list[img_idx]['img_path'] + ' : ' + str(res_list[img_idx]['ap_list']) + '\n') 377 | 378 | auc_test, ap_test = util.get_auc_ap_score(all_labels, all_preds) 379 | all_labels_bin = np.copy(all_labels).astype(np.bool) 380 | all_preds_bin = all_preds>=0.5 381 | all_correct = all_labels_bin==all_preds_bin 382 | acc_test = np.mean(all_correct.astype(np.float32)) 383 | 384 | print 'test_acc: %.4f, test_auc: %.4f, test_ap: %.4f'%(acc_test, auc_test, ap_test) 385 | 386 | f_log.write('test_acc '+str(acc_test)+'\n') 387 | f_log.write('test_auc '+str(auc_test)+'\n') 388 | f_log.write('test_ap '+str(ap_test)+'\n') 389 | f_log.flush() 390 | 391 | f_log.close() 392 | sess.close() 393 | if args.use_multiprocessing: 394 | pool.terminate() 395 | print("Testing complete.") -------------------------------------------------------------------------------- /code/test_VGN_HRF.py: -------------------------------------------------------------------------------- 1 | # a special script for testing images in the HRF dataset 2 | # here, multiple sub-images from a single image are independently tested 3 | # and tiled to make a result for the whole image 4 | # coded by syshin (180430) 5 | # updated by syshin (180903) 6 | 7 | import numpy as np 8 | import os 9 | import pdb 10 | import argparse 11 | import skimage.io 12 | import networkx as nx 13 | import pickle as pkl 14 | import multiprocessing 15 | import skfmm 16 | import tensorflow as tf 17 | 18 | from config import cfg 19 | from model import vessel_segm_vgn 20 | import util 21 | 22 | 23 | def parse_args(): 24 | """ 25 | Parse input arguments 26 | """ 27 | parser = argparse.ArgumentParser(description='Test a vessel_segm_vgn network') 28 | parser.add_argument('--dataset', default='HRF', help='Dataset to use', type=str) 29 | #parser.add_argument('--use_multiprocessing', action='store_true', default=False, help='Whether to use the python multiprocessing module') 30 | parser.add_argument('--use_multiprocessing', default=True, help='Whether to use the python multiprocessing module', type=bool) 31 | parser.add_argument('--multiprocessing_num_proc', default=8, help='Number of CPU processes to use', type=int) 32 | parser.add_argument('--win_size', default=32, help='Window size for srns', type=int) # for srns # [4,8,16,32] 33 | parser.add_argument('--edge_type', default='srns_geo_dist_binary', \ 34 | help='Graph edge type: Can be srns_geo_dist_binary or srns_geo_dist_weighted', type=str) 35 | parser.add_argument('--edge_geo_dist_thresh', default=80, help='Threshold for geodesic distance', type=float) # [10,20,40,80] 36 | parser.add_argument('--model_path', default='../models/HRF/VGN/VGN_HRF.ckpt', \ 37 | help='Path for a trained model(.ckpt)', type=str) 38 | parser.add_argument('--save_root', default='../models/HRF/VGN', \ 39 | help='Root path to save test results', type=str) 40 | 41 | ### cnn module related ### 42 | parser.add_argument('--cnn_model', default='driu_large', help='CNN model to use', type=str) 43 | parser.add_argument('--cnn_loss_on', default=True, help='Whether to use a cnn loss for training', type=bool) 44 | 45 | ### gnn module related ### 46 | parser.add_argument('--gnn_loss_on', default=True, help='Whether to use a gnn loss for training', type=bool) 47 | parser.add_argument('--gnn_loss_weight', default=1., help='Relative weight on the gnn loss', type=float) 48 | # gat # 49 | parser.add_argument('--gat_n_heads', default=[4,4], help='Numbers of heads in each layer', type=list) # [4,1] 50 | #parser.add_argument('--gat_n_heads', nargs='+', help='Numbers of heads in each layer', type=int) # [4,1] 51 | parser.add_argument('--gat_hid_units', default=[16], help='Numbers of hidden units per each attention head in each layer', type=list) 52 | #parser.add_argument('--gat_hid_units', nargs='+', help='Numbers of hidden units per each attention head in each layer', type=int) 53 | parser.add_argument('--gat_use_residual', action='store_true', default=False, help='Whether to use residual learning in GAT') 54 | 55 | ### inference module related ### 56 | parser.add_argument('--norm_type', default=None, help='Norm. type', type=str) 57 | parser.add_argument('--use_enc_layer', action='store_true', default=False, \ 58 | help='Whether to use additional conv. layers in a infer_module') 59 | parser.add_argument('--infer_module_loss_masking_thresh', default=0.05, \ 60 | help='Threshold for loss masking', type=float) 61 | parser.add_argument('--infer_module_kernel_size', default=3, \ 62 | help='Conv. kernel size for the inference module', type=int) 63 | parser.add_argument('--infer_module_grad_weight', default=1., \ 64 | help='Relative weight of the grad. on the infer_module', type=float) 65 | 66 | ### training (declared but not used) ### 67 | parser.add_argument('--do_simul_training', default=True, \ 68 | help='Whether to train the gnn and inference modules simultaneously or not', type=bool) 69 | parser.add_argument('--max_iters', default=100000, help='Maximum number of iterations', type=int) 70 | parser.add_argument('--old_net_ft_lr', default=1e-03, help='Learnining rate for fine-tuning of old parts of network', type=float) 71 | parser.add_argument('--new_net_lr', default=1e-03, help='Learnining rate for a new part of network', type=float) 72 | parser.add_argument('--opt', default='adam', help='Optimizer to use: Can be sgd or adam', type=str) 73 | parser.add_argument('--lr_scheduling', default='pc', help='How to change the learning rate during training', type=str) 74 | parser.add_argument('--lr_decay_tp', default=1., help='When to decrease the lr during training', type=float) # for pc 75 | 76 | 77 | args = parser.parse_args() 78 | return args 79 | 80 | 81 | def save_dict(dic, filename): 82 | with open(filename, 'wb') as f: 83 | pkl.dump(dic, f) 84 | 85 | 86 | def load_dict(filename): 87 | with open(filename, 'rb') as f: 88 | dic = pkl.load(f) 89 | return dic 90 | 91 | 92 | # This was modified to include loading CNN results due to a memory problem 93 | def make_graph_using_srns((res_file_path, edge_type, win_size, edge_geo_dist_thresh)): 94 | 95 | if 'srns' not in edge_type: 96 | raise NotImplementedError 97 | 98 | # loading 99 | cur_res = load_dict(res_file_path) 100 | fg_prob_map = cur_res['temp_fg_prob_map'] 101 | img_path = cur_res['img_path'] 102 | 103 | # find local maxima 104 | vesselness = fg_prob_map 105 | 106 | im_y = vesselness.shape[0] 107 | im_x = vesselness.shape[1] 108 | y_quan = range(0,im_y,win_size) 109 | y_quan = sorted(list(set(y_quan) | set([im_y]))) 110 | x_quan = range(0,im_x,win_size) 111 | x_quan = sorted(list(set(x_quan) | set([im_x]))) 112 | 113 | max_val = [] 114 | max_pos = [] 115 | for y_idx in xrange(len(y_quan)-1): 116 | for x_idx in xrange(len(x_quan)-1): 117 | cur_patch = vesselness[y_quan[y_idx]:y_quan[y_idx+1],x_quan[x_idx]:x_quan[x_idx+1]] 118 | if np.sum(cur_patch)==0: 119 | max_val.append(0) 120 | max_pos.append((y_quan[y_idx]+cur_patch.shape[0]/2,x_quan[x_idx]+cur_patch.shape[1]/2)) 121 | else: 122 | max_val.append(np.amax(cur_patch)) 123 | temp = np.unravel_index(cur_patch.argmax(), cur_patch.shape) 124 | max_pos.append((y_quan[y_idx]+temp[0],x_quan[x_idx]+temp[1])) 125 | 126 | graph = nx.Graph() 127 | 128 | # add nodes 129 | for node_idx, (node_y, node_x) in enumerate(max_pos): 130 | graph.add_node(node_idx, kind='MP', y=node_y, x=node_x, label=node_idx) 131 | print 'node label', node_idx, 'pos', (node_y,node_x), 'added' 132 | 133 | speed = vesselness 134 | 135 | node_list = list(graph.nodes) 136 | for i, n in enumerate(node_list): 137 | 138 | phi = np.ones_like(speed) 139 | phi[graph.node[n]['y'],graph.node[n]['x']] = -1 140 | if speed[graph.node[n]['y'],graph.node[n]['x']]==0: 141 | continue 142 | 143 | neighbor = speed[max(0,graph.node[n]['y']-1):min(im_y,graph.node[n]['y']+2), \ 144 | max(0,graph.node[n]['x']-1):min(im_x,graph.node[n]['x']+2)] 145 | if np.mean(neighbor)<0.1: 146 | continue 147 | 148 | tt = skfmm.travel_time(phi, speed, narrow=edge_geo_dist_thresh) # travel time 149 | 150 | for n_comp in node_list[i+1:]: 151 | geo_dist = tt[graph.node[n_comp]['y'],graph.node[n_comp]['x']] # travel time 152 | if geo_dist < edge_geo_dist_thresh: 153 | graph.add_edge(n, n_comp, weight=edge_geo_dist_thresh/(edge_geo_dist_thresh+geo_dist)) 154 | print 'An edge BTWN', 'node', n, '&', n_comp, 'is constructed' 155 | 156 | # (re-)save 157 | cur_res['graph'] = graph 158 | save_dict(cur_res, res_file_path) 159 | print 'generated a graph for '+img_path 160 | 161 | 162 | if __name__ == '__main__': 163 | 164 | args = parse_args() 165 | 166 | print('Called with args:') 167 | print(args) 168 | 169 | # added for testing on a restricted test set 170 | test_type = ['dr', 'g', 'h'] 171 | # added for testing on a restricted test set 172 | 173 | IMG_SIZE = [2336,3504] 174 | SUB_IMG_SIZE = 768 175 | SUB_IMG_ROOT_PATH = '../HRF/all_768' 176 | y_mins = range(0,IMG_SIZE[0]-SUB_IMG_SIZE+1,SUB_IMG_SIZE-50) 177 | x_mins = range(0,IMG_SIZE[1]-SUB_IMG_SIZE+1,SUB_IMG_SIZE-50) 178 | y_mins = sorted(list(set(y_mins + [IMG_SIZE[0]-SUB_IMG_SIZE]))) 179 | x_mins = sorted(list(set(x_mins + [IMG_SIZE[1]-SUB_IMG_SIZE]))) 180 | 181 | with open('../HRF/test_fr.txt') as f: 182 | test_whole_img_paths = [x.strip() for x in f.readlines()] 183 | 184 | # added for testing on a restricted test set 185 | temp = [] 186 | for i in xrange(len(test_whole_img_paths)): 187 | for j in xrange(len(test_type)): 188 | if test_type[j] in test_whole_img_paths[i][util.find(test_whole_img_paths[i],'/')[-1]+1:]: 189 | temp.append(test_whole_img_paths[i]) 190 | break 191 | test_whole_img_paths = temp 192 | # added for testing on a restricted test set 193 | 194 | test_whole_img_names = map(lambda x: x[util.find(x,'/')[-1]+1:], test_whole_img_paths) # different to 'test_img_names' 195 | 196 | if args.use_multiprocessing: 197 | pool = multiprocessing.Pool(processes=args.multiprocessing_num_proc) 198 | 199 | #temp_graph_save_path = args.save_root + '/' + cfg.TRAIN.TEMP_GRAPH_SAVE_PATH if len(args.save_root)>0 else cfg.TRAIN.TEMP_GRAPH_SAVE_PATH 200 | res_save_path = args.save_root + '/' + cfg.TEST.RES_SAVE_PATH if len(args.save_root)>0 else cfg.TEST.RES_SAVE_PATH 201 | 202 | if len(args.save_root)>0 and not os.path.isdir(args.save_root): 203 | os.mkdir(args.save_root) 204 | if not os.path.isdir(res_save_path): 205 | os.mkdir(res_save_path) 206 | 207 | with open('../HRF/test_768.txt') as f: 208 | test_img_names = [x.strip() for x in f.readlines()] 209 | 210 | # added for testing on a restricted test set 211 | temp = [] 212 | for i in xrange(len(test_img_names)): 213 | for j in xrange(len(test_type)): 214 | if test_type[j] in test_img_names[i][util.find(test_img_names[i],'/')[-1]+1:]: 215 | temp.append(test_img_names[i]) 216 | break 217 | test_img_names = temp 218 | # added for testing on a restricted test set 219 | 220 | len_test = len(test_img_names) 221 | 222 | data_layer_test = util.DataLayer(test_img_names, \ 223 | is_training=False, \ 224 | use_padding=False) 225 | 226 | network = vessel_segm_vgn(args, None) 227 | 228 | config = tf.ConfigProto() 229 | config.gpu_options.allow_growth = True 230 | sess = tf.InteractiveSession(config=config) 231 | 232 | saver = tf.train.Saver() 233 | 234 | sess.run(tf.global_variables_initializer()) 235 | assert args.model_path, 'Model path is not available' 236 | print "Loading model..." 237 | saver.restore(sess, args.model_path) 238 | 239 | f_log = open(os.path.join(res_save_path,'log.txt'), 'w') 240 | f_log.write(str(args)+'\n') 241 | f_log.flush() 242 | timer = util.Timer() 243 | 244 | all_labels = np.concatenate(map(lambda x: np.expand_dims(skimage.io.imread(x+'.tif'), axis=0), test_whole_img_paths), axis=0) 245 | all_masks = np.concatenate(map(lambda x: np.expand_dims(skimage.io.imread(x+'_mask.tif'), axis=0), test_whole_img_paths), axis=0) 246 | all_masks = all_masks[:,:,:,0] 247 | all_labels = ((all_labels.astype(float)/255)>=0.5).astype(float) 248 | all_masks = ((all_masks.astype(float)/255)>=0.5).astype(float) 249 | all_preds_cum_sum = np.zeros(all_labels.shape) 250 | all_preds_cum_num = np.zeros(all_labels.shape) 251 | 252 | print("Testing the model...") 253 | 254 | ### make cnn results (sub-image-wise) ### 255 | res_file_path_list = [] 256 | for _ in xrange(int(np.ceil(float(len_test)/cfg.TRAIN.BATCH_SIZE))): 257 | 258 | # get one batch 259 | img_list, blobs_test = data_layer_test.forward() 260 | 261 | img = blobs_test['img'] 262 | label = blobs_test['label'] 263 | 264 | conv_feats, fg_prob_tensor, \ 265 | cnn_feat_dict, cnn_feat_spatial_sizes_dict = sess.run( 266 | [network.conv_feats, 267 | network.img_fg_prob, 268 | network.cnn_feat, 269 | network.cnn_feat_spatial_sizes], 270 | feed_dict={ 271 | network.imgs: img, 272 | network.labels: label 273 | }) 274 | 275 | cur_batch_size = len(img_list) 276 | for img_idx in xrange(cur_batch_size): 277 | cur_res = {} 278 | cur_res['img_path'] = img_list[img_idx] 279 | cur_res['img'] = img[[img_idx],:,:,:] 280 | cur_res['label'] = label[[img_idx],:,:,:] 281 | cur_res['conv_feats'] = conv_feats[[img_idx],:,:,:] 282 | cur_res['temp_fg_prob_map'] = fg_prob_tensor[img_idx,:,:,0] 283 | cur_res['cnn_feat'] = {k: v[[img_idx],:,:,:] for k, v in zip(cnn_feat_dict.keys(), cnn_feat_dict.values())} 284 | cur_res['cnn_feat_spatial_sizes'] = cnn_feat_spatial_sizes_dict 285 | cur_res['graph'] = None # will be filled at the next step 286 | cur_res['final_fg_prob_map'] = cur_res['temp_fg_prob_map'] 287 | cur_res['ap_list'] = [] 288 | 289 | img_name = img_list[img_idx] 290 | temp = img_name[util.find(img_name,'/')[-1]:] 291 | mask = skimage.io.imread(SUB_IMG_ROOT_PATH + temp +'_mask.tif') 292 | mask = ((mask.astype(float)/255)>=0.5).astype(float) 293 | cur_res['mask'] = mask 294 | 295 | # compute the current AP 296 | cur_label = label[img_idx,:,:,0] 297 | label_roi = cur_label[mask.astype(bool)].reshape((-1)) 298 | fg_prob_map_roi = cur_res['temp_fg_prob_map'][mask.astype(bool)].reshape((-1)) 299 | _, cur_cnn_ap = util.get_auc_ap_score(label_roi, fg_prob_map_roi) 300 | cur_res['ap'] = cur_cnn_ap 301 | cur_res['ap_list'].append(cur_cnn_ap) 302 | 303 | # (initial) save 304 | cur_res_file_path = res_save_path + temp + '.pkl' 305 | save_dict(cur_res, cur_res_file_path) 306 | res_file_path_list.append(cur_res_file_path) 307 | 308 | ### make final results (sub-image-wise) ### 309 | # make graphs & append it to the existing pickle files # re-save 310 | func_arg = [] 311 | for img_idx in xrange(len(res_file_path_list)): 312 | func_arg.append((res_file_path_list[img_idx], args.edge_type, args.win_size, args.edge_geo_dist_thresh)) 313 | if args.use_multiprocessing: 314 | pool.map(make_graph_using_srns, func_arg) 315 | else: 316 | for x in func_arg: 317 | make_graph_using_srns(x) 318 | 319 | # make final results 320 | for img_idx in xrange(len(res_file_path_list)): 321 | 322 | # load 323 | cur_res = load_dict(res_file_path_list[img_idx]) 324 | 325 | cur_img = cur_res['img'] 326 | cur_conv_feats = cur_res['conv_feats'] 327 | cur_cnn_feat = cur_res['cnn_feat'] 328 | cur_cnn_feat_spatial_sizes = cur_res['cnn_feat_spatial_sizes'] 329 | cur_graph = cur_res['graph'] 330 | 331 | cur_graph = nx.convert_node_labels_to_integers(cur_graph) 332 | node_byxs = util.get_node_byx_from_graph(cur_graph, [cur_graph.number_of_nodes()]) 333 | 334 | if 'geo_dist_weighted' in args.edge_type: 335 | adj = nx.adjacency_matrix(cur_graph) 336 | else: 337 | adj = nx.adjacency_matrix(cur_graph,weight=None).astype(float) 338 | 339 | adj_norm = util.preprocess_graph_gat(adj) 340 | 341 | cur_feed_dict = \ 342 | { 343 | network.imgs: cur_img, 344 | network.conv_feats: cur_conv_feats, 345 | network.node_byxs: node_byxs, 346 | network.adj: adj_norm, 347 | network.is_lr_flipped: False, 348 | network.is_ud_flipped: False 349 | } 350 | cur_feed_dict.update({network.cnn_feat[cur_key]: cur_cnn_feat[cur_key] for cur_key in network.cnn_feat.keys()}) 351 | cur_feed_dict.update({network.cnn_feat_spatial_sizes[cur_key]: cur_cnn_feat_spatial_sizes[cur_key] for cur_key in network.cnn_feat_spatial_sizes.keys()}) 352 | 353 | res_prob_map = sess.run( 354 | [network.post_cnn_img_fg_prob], 355 | feed_dict=cur_feed_dict) 356 | 357 | res_prob_map = res_prob_map[0] 358 | res_prob_map = res_prob_map.reshape((res_prob_map.shape[1], res_prob_map.shape[2])) 359 | 360 | # compute the current AP 361 | cur_label = cur_res['label'] 362 | cur_label = np.squeeze(cur_label) 363 | cur_mask = cur_res['mask'] 364 | label_roi = cur_label[cur_mask.astype(bool)].reshape((-1)) 365 | fg_prob_map_roi = res_prob_map[cur_mask.astype(bool)].reshape((-1)) 366 | _, cur_ap = util.get_auc_ap_score(label_roi, fg_prob_map_roi) 367 | res_prob_map = res_prob_map*cur_mask 368 | 369 | cur_res['ap'] = cur_ap 370 | cur_res['ap_list'].append(cur_ap) 371 | cur_res['final_fg_prob_map'] = res_prob_map 372 | cur_res['temp_fg_prob_map'] = res_prob_map 373 | 374 | # (re-)save 375 | save_dict(cur_res, res_file_path_list[img_idx]) 376 | 377 | ### aggregate final results ### 378 | for cur_res_file_path in res_file_path_list: 379 | 380 | # load 381 | cur_res = load_dict(cur_res_file_path) 382 | 383 | res_prob_map = cur_res['final_fg_prob_map'] 384 | 385 | img_name = cur_res['img_path'] 386 | temp_name = img_name[util.find(img_name,'/')[-1]+1:] 387 | 388 | temp_name_splits = temp_name.split('_') 389 | 390 | img_idx = test_whole_img_names.index(temp_name_splits[0]+'_'+temp_name_splits[1]) 391 | cur_y_min = y_mins[int(temp_name_splits[2])] 392 | cur_x_min = x_mins[int(temp_name_splits[3])] 393 | 394 | all_preds_cum_sum[img_idx,cur_y_min:cur_y_min+SUB_IMG_SIZE,cur_x_min:cur_x_min+SUB_IMG_SIZE] += res_prob_map 395 | all_preds_cum_num[img_idx,cur_y_min:cur_y_min+SUB_IMG_SIZE,cur_x_min:cur_x_min+SUB_IMG_SIZE] += 1 396 | 397 | print 'AP list for ' + img_name + ' : ' + str(cur_res['ap_list']) 398 | f_log.write('AP list for ' + img_name + ' : ' + str(cur_res['ap_list']) + '\n') 399 | 400 | all_preds = np.divide(all_preds_cum_sum,all_preds_cum_num) 401 | 402 | all_labels_roi = all_labels[all_masks.astype(bool)] 403 | all_preds_roi = all_preds[all_masks.astype(bool)] 404 | 405 | # save qualitative results 406 | reshaped_fg_prob_map = all_preds*all_masks 407 | reshaped_output = reshaped_fg_prob_map>=0.5 408 | for img_idx, temp_name in enumerate(test_whole_img_names): 409 | 410 | cur_reshaped_fg_prob_map = (reshaped_fg_prob_map[img_idx,:,:]*255).astype(int) 411 | cur_reshaped_fg_prob_map_inv = ((1.-reshaped_fg_prob_map[img_idx,:,:])*255).astype(int) 412 | cur_reshaped_output = reshaped_output[img_idx,:,:].astype(int)*255 413 | 414 | cur_fg_prob_save_path = os.path.join(res_save_path, temp_name + '_prob.png') 415 | cur_fg_prob_inv_save_path = os.path.join(res_save_path, temp_name + '_prob_inv.png') 416 | cur_output_save_path = os.path.join(res_save_path, temp_name + '_output.png') 417 | cur_numpy_save_path = os.path.join(res_save_path, temp_name + '.npy') 418 | 419 | skimage.io.imsave(cur_fg_prob_save_path, cur_reshaped_fg_prob_map) 420 | skimage.io.imsave(cur_fg_prob_inv_save_path, cur_reshaped_fg_prob_map_inv) 421 | skimage.io.imsave(cur_output_save_path, cur_reshaped_output) 422 | np.save(cur_numpy_save_path, reshaped_fg_prob_map[img_idx,:,:]) 423 | 424 | all_labels = np.reshape(all_labels, (-1)) 425 | all_preds = np.reshape(all_preds, (-1)) 426 | all_labels_roi = np.reshape(all_labels_roi, (-1)) 427 | all_preds_roi = np.reshape(all_preds_roi, (-1)) 428 | 429 | auc_test, ap_test = util.get_auc_ap_score(all_labels, all_preds) 430 | all_labels_bin = np.copy(all_labels).astype(np.bool) 431 | all_preds_bin = all_preds>=0.5 432 | all_correct = all_labels_bin==all_preds_bin 433 | acc_test = np.mean(all_correct.astype(np.float32)) 434 | 435 | auc_test_roi, ap_test_roi = util.get_auc_ap_score(all_labels_roi, all_preds_roi) 436 | all_labels_bin_roi = np.copy(all_labels_roi).astype(np.bool) 437 | all_preds_bin_roi = all_preds_roi>=0.5 438 | all_correct_roi = all_labels_bin_roi==all_preds_bin_roi 439 | acc_test_roi = np.mean(all_correct_roi.astype(np.float32)) 440 | 441 | print 'test_acc: %.4f, test_auc: %.4f, test_ap: %.4f'%(acc_test, auc_test, ap_test) 442 | print 'test_acc_roi: %.4f, test_auc_roi: %.4f, test_ap_roi: %.4f'%(acc_test_roi, auc_test_roi, ap_test_roi) 443 | 444 | f_log.write('test_acc '+str(acc_test)+'\n') 445 | f_log.write('test_auc '+str(auc_test)+'\n') 446 | f_log.write('test_ap '+str(ap_test)+'\n') 447 | f_log.write('test_acc_roi '+str(acc_test_roi)+'\n') 448 | f_log.write('test_auc_roi '+str(auc_test_roi)+'\n') 449 | f_log.write('test_ap_roi '+str(ap_test_roi)+'\n') 450 | 451 | f_log.flush() 452 | 453 | f_log.close() 454 | sess.close() 455 | if args.use_multiprocessing: 456 | pool.terminate() 457 | print("Test complete.") -------------------------------------------------------------------------------- /code/train_CNN.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import pdb 4 | import skimage.io 5 | import argparse 6 | import tensorflow as tf 7 | 8 | from config import cfg 9 | from model import vessel_segm_cnn 10 | import util 11 | 12 | 13 | def parse_args(): 14 | """ 15 | Parse input arguments 16 | """ 17 | parser = argparse.ArgumentParser(description='Train a vessel_segm_cnn network') 18 | parser.add_argument('--dataset', default='DRIVE', help='Dataset to use: Can be DRIVE or STARE or CHASE_DB1 or HRF', type=str) 19 | parser.add_argument('--cnn_model', default='driu', help='CNN model to use', type=str) 20 | parser.add_argument('--use_fov_mask', default=True, help='Whether to use fov masks', type=bool) 21 | parser.add_argument('--opt', default='adam', help='Optimizer to use: Can be sgd or adam', type=str) 22 | parser.add_argument('--lr', default=1e-02, help='Learning rate to use: Can be any floating point number', type=float) 23 | parser.add_argument('--lr_decay', default='pc', help='Learning rate decay to use: Can be const or pc or exp', type=str) 24 | parser.add_argument('--max_iters', default=50000, help='Maximum number of iterations', type=int) 25 | parser.add_argument('--pretrained_model', default='../pretrained_model/VGG_imagenet.npy', help='path for a pretrained model(.npy)', type=str) 26 | #parser.add_argument('--pretrained_model', default=None, help='path for a pretrained model(.ckpt)', type=str) 27 | parser.add_argument('--save_root', default='DRIU_DRIVE', help='root path to save trained models and test results', type=str) 28 | 29 | args = parser.parse_args() 30 | return args 31 | 32 | 33 | def load(data_path, session, ignore_missing=False): 34 | data_dict = np.load(data_path).item() 35 | for key in data_dict: 36 | with tf.variable_scope(key, reuse=True): 37 | for subkey in data_dict[key]: 38 | if subkey=='weights': 39 | target_subkey='W' 40 | elif subkey=='biases': 41 | target_subkey='b' 42 | try: 43 | var = tf.get_variable(target_subkey) 44 | session.run(var.assign(data_dict[key][subkey])) 45 | print "assign pretrain model "+subkey+ " to "+key 46 | except ValueError: 47 | print "ignore "+key+"/"+subkey 48 | #print "ignore "+key 49 | if not ignore_missing: 50 | raise 51 | 52 | 53 | if __name__ == '__main__': 54 | args = parse_args() 55 | 56 | print('Called with args:') 57 | print(args) 58 | 59 | if args.dataset=='DRIVE': 60 | train_set_txt_path = cfg.TRAIN.DRIVE_SET_TXT_PATH 61 | test_set_txt_path = cfg.TEST.DRIVE_SET_TXT_PATH 62 | elif args.dataset=='STARE': 63 | train_set_txt_path = cfg.TRAIN.STARE_SET_TXT_PATH 64 | test_set_txt_path = cfg.TEST.STARE_SET_TXT_PATH 65 | elif args.dataset=='CHASE_DB1': 66 | train_set_txt_path = cfg.TRAIN.CHASE_DB1_SET_TXT_PATH 67 | test_set_txt_path = cfg.TEST.CHASE_DB1_SET_TXT_PATH 68 | elif args.dataset=='HRF': 69 | train_set_txt_path = cfg.TRAIN.HRF_SET_TXT_PATH 70 | test_set_txt_path = cfg.TEST.HRF_SET_TXT_PATH 71 | 72 | with open(train_set_txt_path) as f: 73 | train_img_names = [x.strip() for x in f.readlines()] 74 | with open(test_set_txt_path) as f: 75 | test_img_names = [x.strip() for x in f.readlines()] 76 | 77 | if args.dataset=='HRF': 78 | test_img_names = map(lambda x: test_img_names[x], range(7,len(test_img_names),20)) 79 | 80 | len_train = len(train_img_names) 81 | len_test = len(test_img_names) 82 | 83 | data_layer_train = util.DataLayer(train_img_names, is_training=True) 84 | data_layer_test = util.DataLayer(test_img_names, is_training=False) 85 | 86 | model_save_path = args.save_root + '/' + cfg.TRAIN.MODEL_SAVE_PATH if len(args.save_root)>0 else cfg.TRAIN.MODEL_SAVE_PATH 87 | res_save_path = args.save_root + '/' + cfg.TEST.RES_SAVE_PATH if len(args.save_root)>0 else cfg.TEST.RES_SAVE_PATH 88 | if len(args.save_root)>0 and not os.path.isdir(args.save_root): 89 | os.mkdir(args.save_root) 90 | if not os.path.isdir(model_save_path): 91 | os.mkdir(model_save_path) 92 | if not os.path.isdir(res_save_path): 93 | os.mkdir(res_save_path) 94 | 95 | network = vessel_segm_cnn(args, None) 96 | 97 | config = tf.ConfigProto() 98 | config.gpu_options.allow_growth = True 99 | sess = tf.InteractiveSession(config=config) 100 | 101 | saver = tf.train.Saver(max_to_keep=100) 102 | summary_writer = tf.summary.FileWriter(model_save_path, sess.graph) 103 | 104 | sess.run(tf.global_variables_initializer()) 105 | if args.pretrained_model is not None: 106 | print "Loading model..." 107 | load(args.pretrained_model, sess, ignore_missing=True) 108 | 109 | f_log = open(os.path.join(model_save_path,'log.txt'), 'w') 110 | last_snapshot_iter = -1 111 | timer = util.Timer() 112 | 113 | train_loss_list = [] 114 | test_loss_list = [] 115 | print("Training the model...") 116 | for iter in xrange(args.max_iters): 117 | 118 | timer.tic() 119 | 120 | # get one batch 121 | _, blobs_train = data_layer_train.forward() 122 | 123 | if args.use_fov_mask: 124 | fov_masks = blobs_train['fov'] 125 | else: 126 | fov_masks = np.ones(blobs_train['label'].shape, dtype=blobs_train['label'].dtype) 127 | 128 | _, loss_val, accuracy_val, pre_val, rec_val = sess.run( 129 | [network.train_op, network.loss, network.accuracy, network.precision, network.recall], 130 | feed_dict={ 131 | network.is_training: True, 132 | network.imgs: blobs_train['img'], 133 | network.labels: blobs_train['label'], 134 | network.fov_masks: fov_masks 135 | }) 136 | 137 | timer.toc() 138 | train_loss_list.append(loss_val) 139 | 140 | if (iter+1) % (cfg.TRAIN.DISPLAY) == 0: 141 | print 'iter: %d / %d, loss: %.4f, accuracy: %.4f, precision: %.4f, recall: %.4f'\ 142 | %(iter+1, args.max_iters, loss_val, accuracy_val, pre_val, rec_val) 143 | print 'speed: {:.3f}s / iter'.format(timer.average_time) 144 | 145 | if (iter+1) % cfg.TRAIN.SNAPSHOT_ITERS == 0: 146 | last_snapshot_iter = iter 147 | filename = os.path.join(model_save_path,('iter_{:d}'.format(iter+1) + '.ckpt')) 148 | saver.save(sess, filename) 149 | print 'Wrote snapshot to: {:s}'.format(filename) 150 | 151 | if (iter+1) % cfg.TRAIN.TEST_ITERS == 0: 152 | 153 | all_labels = np.zeros((0,)) 154 | all_preds = np.zeros((0,)) 155 | 156 | for _ in xrange(int(np.ceil(float(len_test)/cfg.TRAIN.BATCH_SIZE))): 157 | 158 | # get one batch 159 | img_list, blobs_test = data_layer_test.forward() 160 | 161 | imgs = blobs_test['img'] 162 | labels = blobs_test['label'] 163 | if args.use_fov_mask: 164 | fov_masks = blobs_test['fov'] 165 | else: 166 | fov_masks = np.ones(labels.shape, dtype=labels.dtype) 167 | 168 | loss_val, fg_prob_map = sess.run( 169 | [network.loss, network.fg_prob], 170 | feed_dict={ 171 | network.is_training: False, 172 | network.imgs: imgs, 173 | network.labels: labels, 174 | network.fov_masks: fov_masks 175 | }) 176 | 177 | test_loss_list.append(loss_val) 178 | 179 | all_labels = np.concatenate((all_labels,np.reshape(labels, (-1)))) 180 | fg_prob_map = fg_prob_map*fov_masks.astype(float) 181 | all_preds = np.concatenate((all_preds,np.reshape(fg_prob_map, (-1)))) 182 | 183 | # save qualitative results 184 | cur_batch_size = len(img_list) 185 | reshaped_fg_prob_map = fg_prob_map.reshape((cur_batch_size,fg_prob_map.shape[1],fg_prob_map.shape[2])) 186 | reshaped_output = reshaped_fg_prob_map>=0.5 187 | for img_idx in xrange(cur_batch_size): 188 | cur_test_img_path = img_list[img_idx] 189 | temp_name = cur_test_img_path[util.find(cur_test_img_path,'/')[-1]+1:] 190 | 191 | cur_reshaped_fg_prob_map = (reshaped_fg_prob_map[img_idx,:,:]*255).astype(int) 192 | cur_reshaped_output = reshaped_output[img_idx,:,:].astype(int)*255 193 | 194 | cur_fg_prob_save_path = os.path.join(res_save_path, temp_name + '_prob.png') 195 | cur_output_save_path = os.path.join(res_save_path, temp_name + '_output.png') 196 | 197 | skimage.io.imsave(cur_fg_prob_save_path, cur_reshaped_fg_prob_map) 198 | skimage.io.imsave(cur_output_save_path, cur_reshaped_output) 199 | 200 | auc_test, ap_test = util.get_auc_ap_score(all_labels, all_preds) 201 | all_labels_bin = np.copy(all_labels).astype(np.bool) 202 | all_preds_bin = all_preds>=0.5 203 | all_correct = all_labels_bin==all_preds_bin 204 | acc_test = np.mean(all_correct.astype(np.float32)) 205 | 206 | summary = tf.Summary() 207 | summary.value.add(tag="train_loss", simple_value=float(np.mean(train_loss_list))) 208 | summary.value.add(tag="test_loss", simple_value=float(np.mean(test_loss_list))) 209 | summary.value.add(tag="test_acc", simple_value=float(acc_test)) 210 | summary.value.add(tag="test_auc", simple_value=float(auc_test)) 211 | summary.value.add(tag="test_ap", simple_value=float(ap_test)) 212 | summary_writer.add_summary(summary, global_step=iter+1) 213 | summary_writer.flush() 214 | 215 | print 'iter: %d / %d, train_loss: %.4f'%(iter+1, args.max_iters, np.mean(train_loss_list)) 216 | print 'iter: %d / %d, test_loss: %.4f, test_acc: %.4f, test_auc: %.4f, test_ap: %.4f'\ 217 | %(iter+1, args.max_iters, np.mean(test_loss_list), acc_test, auc_test, ap_test) 218 | 219 | f_log.write('iter: '+str(iter+1)+' / '+str(args.max_iters)+'\n') 220 | f_log.write('train_loss '+str(np.mean(train_loss_list))+'\n') 221 | f_log.write('iter: '+str(iter+1)+' / '+str(args.max_iters)+'\n') 222 | f_log.write('test_loss '+str(np.mean(test_loss_list))+'\n') 223 | f_log.write('test_acc '+str(acc_test)+'\n') 224 | f_log.write('test_auc '+str(auc_test)+'\n') 225 | f_log.write('test_ap '+str(ap_test)+'\n') 226 | f_log.flush() 227 | 228 | train_loss_list = [] 229 | test_loss_list = [] 230 | 231 | if last_snapshot_iter != iter: 232 | filename = os.path.join(model_save_path,('iter_{:d}'.format(iter+1) + '.ckpt')) 233 | saver.save(sess, filename) 234 | print 'Wrote snapshot to: {:s}'.format(filename) 235 | 236 | f_log.close() 237 | sess.close() 238 | print("Training complete.") -------------------------------------------------------------------------------- /code/train_VGN.py: -------------------------------------------------------------------------------- 1 | # updated by syshin (180825) 2 | # do the following steps before running this script 3 | # (1) run the script 'GenGraph/make_graph_db.py' 4 | # to generate training/test graphs 5 | # (2) place the generated graphs ('.graph_res') 6 | # and cnn results ('_prob.png') in 7 | # a new directory 'args.save_root/graph' 8 | 9 | import numpy as np 10 | import os 11 | import pdb 12 | import argparse 13 | import skimage.io 14 | import networkx as nx 15 | import pickle as pkl 16 | import multiprocessing 17 | import sys 18 | import skfmm 19 | import tensorflow as tf 20 | 21 | import _init_paths 22 | from config import cfg 23 | from model import vessel_segm_vgn 24 | import util 25 | from train_CNN import load 26 | 27 | 28 | def parse_args(): 29 | """ 30 | Parse input arguments 31 | """ 32 | parser = argparse.ArgumentParser(description='Train a vessel_segm_vgn network') 33 | parser.add_argument('--dataset', default='DRIVE', help='Dataset to use: Can be DRIVE or STARE or CHASE_DB1 or HRF', type=str) 34 | #parser.add_argument('--use_multiprocessing', action='store_true', default=False, help='Whether to use the python multiprocessing module') 35 | parser.add_argument('--use_multiprocessing', default=True, help='Whether to use the python multiprocessing module', type=bool) 36 | parser.add_argument('--multiprocessing_num_proc', default=8, help='Number of CPU processes to use', type=int) 37 | parser.add_argument('--win_size', default=4, help='Window size for srns', type=int) # for srns # [4,8,16] 38 | parser.add_argument('--edge_type', default='srns_geo_dist_binary', \ 39 | help='Graph edge type: Can be srns_geo_dist_binary or srns_geo_dist_weighted', type=str) 40 | parser.add_argument('--edge_geo_dist_thresh', default=10, help='Threshold for geodesic distance', type=float) # [10,20,40] 41 | parser.add_argument('--pretrained_model', default='../models/DRIVE/DRIU*/DRIU_DRIVE.ckpt', \ 42 | help='Path for a pretrained model(.ckpt)', type=str) 43 | parser.add_argument('--save_root', default='../models/DRIVE/VGN_DRIVE', \ 44 | help='Root path to save trained models and test results', type=str) 45 | 46 | ### cnn module related ### 47 | parser.add_argument('--cnn_model', default='driu', help='CNN model to use', type=str) 48 | #parser.add_argument('--cnn_loss_on', action='store_true', default=False, help='Whether to use a cnn loss for training') 49 | parser.add_argument('--cnn_loss_on', default=True, help='Whether to use a cnn loss for training', type=bool) 50 | 51 | ### gnn module related ### 52 | #parser.add_argument('--gnn_loss_on', action='store_true', default=False, help='Whether to use a gnn loss for training') 53 | parser.add_argument('--gnn_loss_on', default=True, help='Whether to use a gnn loss for training', type=bool) 54 | parser.add_argument('--gnn_loss_weight', default=1., help='Relative weight on the gnn loss', type=float) 55 | parser.add_argument('--gnn_feat_dropout_prob', default=0.5, help='Dropout prob. for feat. in gnn layers', type=float) 56 | parser.add_argument('--gnn_att_dropout_prob', default=0.5, help='Dropout prob. for att. in gnn layers', type=float) 57 | # gat # 58 | parser.add_argument('--gat_n_heads', default=[4,4], help='Numbers of heads in each layer', type=list) # [4,1] 59 | #parser.add_argument('--gat_n_heads', nargs='+', help='Numbers of heads in each layer', type=int) # [4,1] 60 | parser.add_argument('--gat_hid_units', default=[16], help='Numbers of hidden units per each attention head in each layer', type=list) 61 | #parser.add_argument('--gat_hid_units', nargs='+', help='Numbers of hidden units per each attention head in each layer', type=int) 62 | parser.add_argument('--gat_use_residual', action='store_true', default=False, help='Whether to use residual learning in GAT') 63 | 64 | ### inference module related ### 65 | parser.add_argument('--norm_type', default=None, help='Norm. type', type=str) 66 | parser.add_argument('--use_enc_layer', action='store_true', default=False, \ 67 | help='Whether to use additional conv. layers in the inference module') 68 | parser.add_argument('--infer_module_loss_masking_thresh', default=0.05, \ 69 | help='Threshold for loss masking', type=float) 70 | parser.add_argument('--infer_module_kernel_size', default=3, \ 71 | help='Conv. kernel size for the inference module', type=int) 72 | parser.add_argument('--infer_module_grad_weight', default=1., \ 73 | help='Relative weight of the grad. on the inference module', type=float) 74 | parser.add_argument('--infer_module_dropout_prob', default=0.1, \ 75 | help='Dropout prob. for layers in the inference module', type=float) 76 | 77 | ### training (declared but not used) ### 78 | parser.add_argument('--do_simul_training', default=True, \ 79 | help='Whether to train the gnn and inference modules simultaneously or not', type=bool) 80 | parser.add_argument('--max_iters', default=50000, help='Maximum number of iterations', type=int) 81 | parser.add_argument('--old_net_ft_lr', default=1e-02, help='Learnining rate for fine-tuning of old parts of network', type=float) 82 | parser.add_argument('--new_net_lr', default=1e-02, help='Learnining rate for a new part of network', type=float) 83 | parser.add_argument('--opt', default='adam', help='Optimizer to use: Can be sgd or adam', type=str) 84 | parser.add_argument('--lr_scheduling', default='pc', help='How to change the learning rate during training', type=str) # [pc] 85 | parser.add_argument('--lr_decay_tp', default=1., help='When to decrease the lr during training', type=float) # for pc 86 | #parser.add_argument('--use_graph_update', action='store_true', default=False, help='Whether to update graphs during training') 87 | parser.add_argument('--use_graph_update', default=True, \ 88 | help='Whether to update graphs during training', type=bool) 89 | parser.add_argument('--graph_update_period', default=10000, help='Graph update period', type=int) 90 | parser.add_argument('--use_fov_mask', default=True, help='Whether to use fov masks', type=bool) 91 | 92 | 93 | args = parser.parse_args() 94 | return args 95 | 96 | 97 | def restore_from_pretrained_model(sess, saver, network, pretrained_model_path): 98 | splits = pretrained_model_path.split('/') 99 | if ('DRIU*' in splits) or ('DRIU*' in splits) or ('DRIU*' in splits) or ('DRIU*' in splits): 100 | var_dict = {} 101 | for v in tf.trainable_variables(): 102 | t_var_name = v.name 103 | highest_level_name = t_var_name[:t_var_name.find('/')] 104 | if highest_level_name in network.var_to_restore: 105 | if highest_level_name=='img_output': 106 | t_var_name_to_restore = 'output'+t_var_name[t_var_name.find('/'):t_var_name.rfind(':')] 107 | var_dict[t_var_name_to_restore] = v 108 | elif ('gat' in highest_level_name) or ('post_cnn' in highest_level_name): 109 | pass 110 | else: 111 | var_dict[t_var_name[:t_var_name.rfind(':')]] = v 112 | 113 | loader = tf.train.Saver(var_list=var_dict) 114 | loader.restore(sess, pretrained_model_path) 115 | else: 116 | saver.restore(sess, pretrained_model_path) 117 | 118 | 119 | def make_train_qual_res((img_name, fg_prob_map, temp_graph_save_path, args)): 120 | 121 | if 'srns' not in args.edge_type: 122 | raise NotImplementedError 123 | 124 | win_size_str = '%.2d_%.2d'%(args.win_size,args.edge_geo_dist_thresh) 125 | 126 | cur_filename = img_name[util.find(img_name,'/')[-1]+1:] 127 | 128 | print 'regenerating a graph for '+cur_filename 129 | 130 | temp = (fg_prob_map*255).astype(int) 131 | cur_save_path = os.path.join(temp_graph_save_path, cur_filename+'_prob.png') 132 | skimage.io.imsave(cur_save_path, temp) 133 | 134 | cur_res_graph_savepath = os.path.join(temp_graph_save_path, cur_filename+'_'+win_size_str+'.graph_res') 135 | 136 | # find local maxima 137 | vesselness = fg_prob_map 138 | 139 | im_y = vesselness.shape[0] 140 | im_x = vesselness.shape[1] 141 | y_quan = range(0,im_y,args.win_size) 142 | y_quan = sorted(list(set(y_quan) | set([im_y]))) 143 | x_quan = range(0,im_x,args.win_size) 144 | x_quan = sorted(list(set(x_quan) | set([im_x]))) 145 | 146 | max_val = [] 147 | max_pos = [] 148 | for y_idx in xrange(len(y_quan)-1): 149 | for x_idx in xrange(len(x_quan)-1): 150 | cur_patch = vesselness[y_quan[y_idx]:y_quan[y_idx+1],x_quan[x_idx]:x_quan[x_idx+1]] 151 | if np.sum(cur_patch)==0: 152 | max_val.append(0) 153 | max_pos.append((y_quan[y_idx]+cur_patch.shape[0]/2,x_quan[x_idx]+cur_patch.shape[1]/2)) 154 | else: 155 | max_val.append(np.amax(cur_patch)) 156 | temp = np.unravel_index(cur_patch.argmax(), cur_patch.shape) 157 | max_pos.append((y_quan[y_idx]+temp[0],x_quan[x_idx]+temp[1])) 158 | 159 | graph = nx.Graph() 160 | 161 | # add nodes 162 | for node_idx, (node_y, node_x) in enumerate(max_pos): 163 | graph.add_node(node_idx, kind='MP', y=node_y, x=node_x, label=node_idx) 164 | print 'node label', node_idx, 'pos', (node_y,node_x), 'added' 165 | 166 | speed = vesselness 167 | 168 | node_list = list(graph.nodes) 169 | for i, n in enumerate(node_list): 170 | 171 | phi = np.ones_like(speed) 172 | phi[graph.node[n]['y'],graph.node[n]['x']] = -1 173 | if speed[graph.node[n]['y'],graph.node[n]['x']]==0: 174 | continue 175 | 176 | neighbor = speed[max(0,graph.node[n]['y']-1):min(im_y,graph.node[n]['y']+2), \ 177 | max(0,graph.node[n]['x']-1):min(im_x,graph.node[n]['x']+2)] 178 | if np.mean(neighbor)<0.1: 179 | continue 180 | 181 | tt = skfmm.travel_time(phi, speed, narrow=args.edge_geo_dist_thresh) # travel time 182 | 183 | for n_comp in node_list[i+1:]: 184 | geo_dist = tt[graph.node[n_comp]['y'],graph.node[n_comp]['x']] # travel time 185 | if geo_dist < args.edge_geo_dist_thresh: 186 | graph.add_edge(n, n_comp, weight=args.edge_geo_dist_thresh/(args.edge_geo_dist_thresh+geo_dist)) 187 | print 'An edge BTWN', 'node', n, '&', n_comp, 'is constructed' 188 | 189 | # save as files 190 | nx.write_gpickle(graph, cur_res_graph_savepath, protocol=pkl.HIGHEST_PROTOCOL) 191 | 192 | graph.clear() 193 | 194 | 195 | if __name__ == '__main__': 196 | 197 | args = parse_args() 198 | 199 | print('Called with args:') 200 | print(args) 201 | 202 | if args.dataset=='DRIVE': 203 | im_root_path = '../DRIVE/all' 204 | train_set_txt_path = cfg.TRAIN.DRIVE_SET_TXT_PATH 205 | test_set_txt_path = cfg.TEST.DRIVE_SET_TXT_PATH 206 | elif args.dataset=='STARE': 207 | im_root_path = '../STARE/all' 208 | train_set_txt_path = cfg.TRAIN.STARE_SET_TXT_PATH 209 | test_set_txt_path = cfg.TEST.STARE_SET_TXT_PATH 210 | elif args.dataset=='CHASE_DB1': 211 | im_root_path = '../CHASE_DB1/all' 212 | train_set_txt_path = cfg.TRAIN.CHASE_DB1_SET_TXT_PATH 213 | test_set_txt_path = cfg.TEST.CHASE_DB1_SET_TXT_PATH 214 | elif args.dataset=='HRF': 215 | im_root_path = '../HRF/all_768' 216 | train_set_txt_path = cfg.TRAIN.HRF_SET_TXT_PATH 217 | test_set_txt_path = cfg.TEST.HRF_SET_TXT_PATH 218 | 219 | if args.use_multiprocessing: 220 | pool = multiprocessing.Pool(processes=args.multiprocessing_num_proc) 221 | 222 | model_save_path = args.save_root + '/' + cfg.TRAIN.MODEL_SAVE_PATH if len(args.save_root)>0 else cfg.TRAIN.MODEL_SAVE_PATH 223 | res_save_path = args.save_root + '/' + cfg.TEST.RES_SAVE_PATH if len(args.save_root)>0 else cfg.TEST.RES_SAVE_PATH 224 | temp_graph_save_path = args.save_root + '/' + cfg.TRAIN.TEMP_GRAPH_SAVE_PATH if len(args.save_root)>0 else cfg.TRAIN.TEMP_GRAPH_SAVE_PATH 225 | if len(args.save_root)>0 and not os.path.isdir(args.save_root): 226 | os.mkdir(args.save_root) 227 | if not os.path.isdir(model_save_path): 228 | os.mkdir(model_save_path) 229 | if not os.path.isdir(res_save_path): 230 | os.mkdir(res_save_path) 231 | 232 | with open(train_set_txt_path) as f: 233 | train_img_names = [x.strip() for x in f.readlines()] 234 | with open(test_set_txt_path) as f: 235 | test_img_names = [x.strip() for x in f.readlines()] 236 | 237 | if args.dataset=='HRF': 238 | test_img_names = map(lambda x: test_img_names[x], range(7,len(test_img_names),20)) 239 | 240 | len_train = len(train_img_names) 241 | len_test = len(test_img_names) 242 | 243 | # revise "train_img_names" and "test_img_names" 244 | for i in xrange(len_train): 245 | temp = train_img_names[i] 246 | train_img_names[i] = temp_graph_save_path + temp[util.find(temp,'/')[-1]:] 247 | for i in xrange(len_test): 248 | temp = test_img_names[i] 249 | test_img_names[i] = temp_graph_save_path + temp[util.find(temp,'/')[-1]:] 250 | 251 | data_layer_train = util.GraphDataLayer(train_img_names, is_training=True, \ 252 | edge_type=args.edge_type, \ 253 | win_size=args.win_size, edge_geo_dist_thresh=args.edge_geo_dist_thresh) 254 | data_layer_test = util.GraphDataLayer(test_img_names, is_training=False, \ 255 | edge_type=args.edge_type, \ 256 | win_size=args.win_size, edge_geo_dist_thresh=args.edge_geo_dist_thresh) 257 | 258 | network = vessel_segm_vgn(args, None) 259 | 260 | config = tf.ConfigProto() 261 | config.gpu_options.allow_growth = True 262 | sess = tf.InteractiveSession(config=config) 263 | 264 | saver = tf.train.Saver(max_to_keep=100) 265 | summary_writer = tf.summary.FileWriter(model_save_path, sess.graph) 266 | 267 | sess.run(tf.global_variables_initializer()) 268 | if args.pretrained_model is not None: 269 | print "Loading model..." 270 | ext_str = args.pretrained_model[args.pretrained_model.rfind('.')+1:] 271 | if ext_str=='ckpt': 272 | restore_from_pretrained_model(sess, saver, network, args.pretrained_model) 273 | elif ext_str=='npy': 274 | load(args.pretrained_model, sess, ignore_missing=True) 275 | 276 | f_log = open(os.path.join(model_save_path,'log.txt'), 'w') 277 | f_log.write(str(args)+'\n') 278 | f_log.flush() 279 | last_snapshot_iter = -1 280 | timer = util.Timer() 281 | 282 | # for graph update 283 | required_num_iters_for_train_set_update = int(np.ceil(float(len_train)/cfg.TRAIN.GRAPH_BATCH_SIZE)) 284 | required_num_iters_for_test_set_update = int(np.ceil(float(len_test)/cfg.TRAIN.GRAPH_BATCH_SIZE)) 285 | if args.use_graph_update: 286 | next_update_start = args.graph_update_period 287 | next_update_end = next_update_start+required_num_iters_for_train_set_update-1 288 | else: 289 | next_update_start = sys.maxint 290 | next_update_end = sys.maxint 291 | 292 | train_loss_list = [] 293 | train_cnn_loss_list = [] 294 | train_gnn_loss_list = [] 295 | train_infer_module_loss_list = [] 296 | test_loss_list = [] 297 | test_cnn_loss_list = [] 298 | test_gnn_loss_list = [] 299 | test_infer_module_loss_list = [] 300 | graph_update_func_arg = [] 301 | test_loss_logs = [] 302 | print("Training the model...") 303 | for iter in xrange(args.max_iters): 304 | 305 | timer.tic() 306 | 307 | # get one batch 308 | img_list, blobs_train = data_layer_train.forward() 309 | 310 | img = blobs_train['img'] 311 | label = blobs_train['label'] 312 | if args.use_fov_mask: 313 | fov_mask = blobs_train['fov'] 314 | else: 315 | fov_mask = np.ones(label.shape, dtype=label.dtype) 316 | 317 | graph = blobs_train['graph'] 318 | num_of_nodes_list = blobs_train['num_of_nodes_list'] 319 | 320 | node_byxs = util.get_node_byx_from_graph(graph, num_of_nodes_list) 321 | probmap = blobs_train['probmap'] 322 | pixel_weights = fov_mask*((probmap>=args.infer_module_loss_masking_thresh) | label) 323 | pixel_weights = pixel_weights.astype(float) 324 | 325 | if 'geo_dist_weighted' in args.edge_type: 326 | adj = nx.adjacency_matrix(graph) 327 | else: 328 | adj = nx.adjacency_matrix(graph,weight=None).astype(float) 329 | 330 | adj_norm = util.preprocess_graph_gat(adj) 331 | 332 | is_lr_flipped = False 333 | is_ud_flipped = False 334 | rot90_num = 0 335 | 336 | if blobs_train['vec_aug_on'][0]: 337 | is_lr_flipped = True 338 | if blobs_train['vec_aug_on'][1]: 339 | is_ud_flipped = True 340 | if blobs_train['vec_aug_on'][2]: 341 | rot90_num = blobs_train['rot_angle']/90 342 | 343 | if args.lr_scheduling=='pc': 344 | cur_lr = sess.run([network.lr_handler]) 345 | cur_lr = cur_lr[0] 346 | else: 347 | raise NotImplementedError 348 | 349 | _, loss_val, \ 350 | cnn_fg_prob_mat, \ 351 | cnn_loss_val, cnn_accuracy_val, cnn_precision_val, cnn_recall_val, \ 352 | gnn_loss_val, gnn_accuracy_val, \ 353 | infer_module_fg_prob_mat, \ 354 | infer_module_loss_val, infer_module_accuracy_val, infer_module_precision_val, infer_module_recall_val, \ 355 | node_logits, node_labels = sess.run( 356 | [network.train_op, network.loss, 357 | network.img_fg_prob, 358 | network.cnn_loss, network.cnn_accuracy, network.cnn_precision, network.cnn_recall, 359 | network.gnn_loss, network.gnn_accuracy, 360 | network.post_cnn_img_fg_prob, 361 | network.post_cnn_loss, network.post_cnn_accuracy, network.post_cnn_precision, network.post_cnn_recall, 362 | network.node_logits, network.node_labels 363 | ], 364 | feed_dict={ 365 | network.imgs: img, 366 | network.labels: label, 367 | network.fov_masks: fov_mask, 368 | network.node_byxs: node_byxs, 369 | network.adj: adj_norm, 370 | network.pixel_weights: pixel_weights, 371 | network.gnn_feat_dropout: args.gnn_feat_dropout_prob, 372 | network.gnn_att_dropout: args.gnn_att_dropout_prob, 373 | network.post_cnn_dropout: args.infer_module_dropout_prob, 374 | network.is_lr_flipped: is_lr_flipped, 375 | network.is_ud_flipped: is_ud_flipped, 376 | network.rot90_num: rot90_num, 377 | network.learning_rate: cur_lr 378 | }) 379 | 380 | timer.toc() 381 | train_loss_list.append(loss_val) 382 | train_cnn_loss_list.append(cnn_loss_val) 383 | train_gnn_loss_list.append(gnn_loss_val) 384 | train_infer_module_loss_list.append(infer_module_loss_val) 385 | 386 | if (iter+1) % (cfg.TRAIN.DISPLAY) == 0: 387 | print 'iter: %d / %d, loss: %.4f'\ 388 | %(iter+1, args.max_iters, loss_val) 389 | print 'cnn_loss: %.4f, cnn_accuracy: %.4f, cnn_precision: %.4f, cnn_recall: %.4f'\ 390 | %(cnn_loss_val, cnn_accuracy_val, cnn_precision_val, cnn_recall_val) 391 | print 'gnn_loss: %.4f, gnn_accuracy: %.4f'\ 392 | %(gnn_loss_val, gnn_accuracy_val) 393 | print 'infer_module_loss: %.4f, infer_module_accuracy: %.4f, infer_module_precision: %.4f, infer_module_recall: %.4f'\ 394 | %(infer_module_loss_val, infer_module_accuracy_val, infer_module_precision_val, infer_module_recall_val) 395 | print 'speed: {:.3f}s / iter'.format(timer.average_time) 396 | 397 | if (iter+1) % cfg.TRAIN.SNAPSHOT_ITERS == 0: 398 | last_snapshot_iter = iter 399 | filename = os.path.join(model_save_path,('iter_{:d}'.format(iter+1) + '.ckpt')) 400 | saver.save(sess, filename) 401 | print 'Wrote snapshot to: {:s}'.format(filename) 402 | 403 | if (iter+1)==next_update_start-1: 404 | data_layer_train.reinit(train_img_names, is_training=False, \ 405 | edge_type=args.edge_type, \ 406 | win_size=args.win_size, edge_geo_dist_thresh=args.edge_geo_dist_thresh) 407 | 408 | if ((iter+1)=next_update_start) and ((iter+1)<=next_update_end): 409 | 410 | # save qualitative results 411 | # here, we make (segm. and corresponding) graphs, 412 | # which will be used as GT graphs during a next training period, 413 | # from current estimated vesselnesses 414 | cur_batch_size = len(img_list) 415 | reshaped_fg_prob_map = infer_module_fg_prob_mat.reshape((cur_batch_size,infer_module_fg_prob_mat.shape[1],infer_module_fg_prob_mat.shape[2])) 416 | 417 | for j in xrange(cur_batch_size): 418 | graph_update_func_arg.append((img_list[j], reshaped_fg_prob_map[j,:,:], temp_graph_save_path, args)) 419 | 420 | if (iter+1)==next_update_end: 421 | if args.use_multiprocessing: 422 | pool.map(make_train_qual_res, graph_update_func_arg) 423 | else: 424 | for x in graph_update_func_arg: 425 | make_train_qual_res(x) 426 | graph_update_func_arg = [] 427 | 428 | data_layer_train.reinit(train_img_names, is_training=True, \ 429 | edge_type=args.edge_type, \ 430 | win_size=args.win_size, edge_geo_dist_thresh=args.edge_geo_dist_thresh) 431 | next_update_start = next_update_start + args.graph_update_period 432 | next_update_end = next_update_start+required_num_iters_for_train_set_update-1 433 | 434 | if (iter+1) % cfg.TRAIN.TEST_ITERS == 0: 435 | 436 | # cnn module related 437 | all_cnn_labels = np.zeros((0,)) 438 | all_cnn_preds = np.zeros((0,)) 439 | 440 | # gnn module related 441 | all_gnn_labels = np.zeros((0,)) 442 | all_gnn_preds = np.zeros((0,)) 443 | 444 | # inference module related 445 | all_infer_module_preds = np.zeros((0,)) 446 | 447 | for _ in xrange(int(np.ceil(float(len_test)/cfg.TRAIN.GRAPH_BATCH_SIZE))): 448 | 449 | # get one batch 450 | img_list, blobs_test = data_layer_test.forward() 451 | 452 | img = blobs_test['img'] 453 | label = blobs_test['label'] 454 | if args.use_fov_mask: 455 | fov_mask = blobs_test['fov'] 456 | else: 457 | fov_mask = np.ones(label.shape, dtype=label.dtype) 458 | 459 | graph = blobs_test['graph'] 460 | num_of_nodes_list = blobs_test['num_of_nodes_list'] 461 | 462 | node_byxs = util.get_node_byx_from_graph(graph, num_of_nodes_list) 463 | pixel_weights = fov_mask 464 | 465 | if 'geo_dist_weighted' in args.edge_type: 466 | adj = nx.adjacency_matrix(graph) 467 | else: 468 | adj = nx.adjacency_matrix(graph,weight=None).astype(float) 469 | 470 | adj_norm = util.preprocess_graph_gat(adj) 471 | 472 | loss_val, \ 473 | cnn_fg_prob_mat, cnn_loss_val, \ 474 | gnn_labels, gnn_prob_vec, gnn_loss_val, \ 475 | infer_module_fg_prob_mat, infer_module_loss_val = sess.run( 476 | [network.loss, 477 | network.img_fg_prob, network.cnn_loss, 478 | network.node_labels, network.gnn_prob, network.gnn_loss, 479 | network.post_cnn_img_fg_prob, network.post_cnn_loss], 480 | feed_dict={ 481 | network.imgs: img, 482 | network.labels: label, 483 | network.fov_masks: fov_mask, 484 | network.node_byxs: node_byxs, 485 | network.adj: adj_norm, 486 | network.pixel_weights: pixel_weights, 487 | network.is_lr_flipped: False, 488 | network.is_ud_flipped: False 489 | }) 490 | 491 | cnn_fg_prob_mat = cnn_fg_prob_mat*fov_mask.astype(float) 492 | infer_module_fg_prob_mat = infer_module_fg_prob_mat*fov_mask.astype(float) 493 | 494 | test_loss_list.append(loss_val) 495 | test_cnn_loss_list.append(cnn_loss_val) 496 | test_gnn_loss_list.append(gnn_loss_val) 497 | test_infer_module_loss_list.append(infer_module_loss_val) 498 | 499 | all_cnn_labels = np.concatenate((all_cnn_labels,np.reshape(label, (-1)))) 500 | all_cnn_preds = np.concatenate((all_cnn_preds,np.reshape(cnn_fg_prob_mat, (-1)))) 501 | 502 | all_gnn_labels = np.concatenate((all_gnn_labels,gnn_labels)) 503 | all_gnn_preds = np.concatenate((all_gnn_preds,gnn_prob_vec)) 504 | 505 | all_infer_module_preds = np.concatenate((all_infer_module_preds,np.reshape(infer_module_fg_prob_mat, (-1)))) 506 | 507 | # save qualitative results 508 | cur_batch_size = len(img_list) 509 | reshaped_cnn_fg_prob_map = cnn_fg_prob_mat.reshape((cur_batch_size,cnn_fg_prob_mat.shape[1],cnn_fg_prob_mat.shape[2])) 510 | reshaped_infer_module_fg_prob_mat = infer_module_fg_prob_mat.reshape((cur_batch_size,infer_module_fg_prob_mat.shape[1],infer_module_fg_prob_mat.shape[2])) 511 | for j in xrange(cur_batch_size): 512 | cur_img_name = img_list[j] 513 | cur_img_name = cur_img_name[util.find(cur_img_name,'/')[-1]:] 514 | 515 | cur_cnn_fg_prob_map = reshaped_cnn_fg_prob_map[j,:,:] 516 | cur_infer_module_fg_prob_map = reshaped_infer_module_fg_prob_mat[j,:,:] 517 | 518 | cur_map = (cur_cnn_fg_prob_map*255).astype(int) 519 | cur_save_path = res_save_path + cur_img_name + '_prob_cnn.png' 520 | skimage.io.imsave(cur_save_path, cur_map) 521 | 522 | cur_map = (cur_infer_module_fg_prob_map*255).astype(int) 523 | cur_save_path = res_save_path + cur_img_name + '_prob_infer_module.png' 524 | skimage.io.imsave(cur_save_path, cur_map) 525 | 526 | cnn_auc_test, cnn_ap_test = util.get_auc_ap_score(all_cnn_labels, all_cnn_preds) 527 | all_cnn_labels_bin = np.copy(all_cnn_labels).astype(np.bool) 528 | all_cnn_preds_bin = all_cnn_preds>=0.5 529 | all_cnn_correct = all_cnn_labels_bin==all_cnn_preds_bin 530 | cnn_acc_test = np.mean(all_cnn_correct.astype(np.float32)) 531 | 532 | gnn_auc_test, gnn_ap_test = util.get_auc_ap_score(all_gnn_labels, all_gnn_preds) 533 | all_gnn_labels_bin = np.copy(all_gnn_labels).astype(np.bool) 534 | all_gnn_preds_bin = all_gnn_preds>=0.5 535 | all_gnn_correct = all_gnn_labels_bin==all_gnn_preds_bin 536 | gnn_acc_test = np.mean(all_gnn_correct.astype(np.float32)) 537 | 538 | infer_module_auc_test, infer_module_ap_test = util.get_auc_ap_score(all_cnn_labels, all_infer_module_preds) 539 | all_infer_module_preds_bin = all_infer_module_preds>=0.5 540 | all_infer_module_correct = all_cnn_labels_bin==all_infer_module_preds_bin 541 | infer_module_acc_test = np.mean(all_infer_module_correct.astype(np.float32)) 542 | 543 | summary = tf.Summary() 544 | summary.value.add(tag="train_loss", simple_value=float(np.mean(train_loss_list))) 545 | summary.value.add(tag="train_cnn_loss", simple_value=float(np.mean(train_cnn_loss_list))) 546 | summary.value.add(tag="train_gnn_loss", simple_value=float(np.mean(train_gnn_loss_list))) 547 | summary.value.add(tag="train_infer_module_loss", simple_value=float(np.mean(train_infer_module_loss_list))) 548 | summary.value.add(tag="test_loss", simple_value=float(np.mean(test_loss_list))) 549 | summary.value.add(tag="test_cnn_loss", simple_value=float(np.mean(test_cnn_loss_list))) 550 | summary.value.add(tag="test_gnn_loss", simple_value=float(np.mean(test_gnn_loss_list))) 551 | summary.value.add(tag="test_infer_module_loss", simple_value=float(np.mean(test_infer_module_loss_list))) 552 | summary.value.add(tag="test_cnn_acc", simple_value=float(cnn_acc_test)) 553 | summary.value.add(tag="test_cnn_auc", simple_value=float(cnn_auc_test)) 554 | summary.value.add(tag="test_cnn_ap", simple_value=float(cnn_ap_test)) 555 | summary.value.add(tag="test_gnn_acc", simple_value=float(gnn_acc_test)) 556 | summary.value.add(tag="test_gnn_auc", simple_value=float(gnn_auc_test)) 557 | summary.value.add(tag="test_gnn_ap", simple_value=float(gnn_ap_test)) 558 | summary.value.add(tag="test_infer_module_acc", simple_value=float(infer_module_acc_test)) 559 | summary.value.add(tag="test_infer_module_auc", simple_value=float(infer_module_auc_test)) 560 | summary.value.add(tag="test_infer_module_ap", simple_value=float(infer_module_ap_test)) 561 | summary.value.add(tag="lr", simple_value=float(cur_lr)) 562 | summary_writer.add_summary(summary, global_step=iter+1) 563 | summary_writer.flush() 564 | 565 | print 'iter: %d / %d, train_loss: %.4f, train_cnn_loss: %.4f, train_gnn_loss: %.4f, train_infer_module_loss: %.4f'\ 566 | %(iter+1, args.max_iters, np.mean(train_loss_list), np.mean(train_cnn_loss_list), np.mean(train_gnn_loss_list), np.mean(train_infer_module_loss_list)) 567 | print 'iter: %d / %d, test_loss: %.4f, test_cnn_loss: %.4f, test_gnn_loss: %.4f, test_infer_module_loss: %.4f'\ 568 | %(iter+1, args.max_iters, np.mean(test_loss_list), np.mean(test_cnn_loss_list), np.mean(test_gnn_loss_list), np.mean(test_infer_module_loss_list)) 569 | print 'test_cnn_acc: %.4f, test_cnn_auc: %.4f, test_cnn_ap: %.4f'%(cnn_acc_test, cnn_auc_test, cnn_ap_test) 570 | print 'test_gnn_acc: %.4f, test_gnn_auc: %.4f, test_gnn_ap: %.4f'%(gnn_acc_test, gnn_auc_test, gnn_ap_test) 571 | print 'test_infer_module_acc: %.4f, test_infer_module_auc: %.4f, test_infer_module_ap: %.4f'%(infer_module_acc_test, infer_module_auc_test, infer_module_ap_test) 572 | print 'lr: %.8f'%(cur_lr) 573 | 574 | f_log.write('iter: '+str(iter+1)+' / '+str(args.max_iters)+'\n') 575 | f_log.write('train_loss '+str(np.mean(train_loss_list))+'\n') 576 | f_log.write('train_cnn_loss '+str(np.mean(train_cnn_loss_list))+'\n') 577 | f_log.write('train_gnn_loss '+str(np.mean(train_gnn_loss_list))+'\n') 578 | f_log.write('train_infer_module_loss '+str(np.mean(train_infer_module_loss_list))+'\n') 579 | f_log.write('iter: '+str(iter+1)+' / '+str(args.max_iters)+'\n') 580 | f_log.write('test_loss '+str(np.mean(test_loss_list))+'\n') 581 | f_log.write('test_cnn_loss '+str(np.mean(test_cnn_loss_list))+'\n') 582 | f_log.write('test_gnn_loss '+str(np.mean(test_gnn_loss_list))+'\n') 583 | f_log.write('test_infer_module_loss '+str(np.mean(test_infer_module_loss_list))+'\n') 584 | f_log.write('test_cnn_acc '+str(cnn_acc_test)+'\n') 585 | f_log.write('test_cnn_auc '+str(cnn_auc_test)+'\n') 586 | f_log.write('test_cnn_ap '+str(cnn_ap_test)+'\n') 587 | f_log.write('test_gnn_acc '+str(gnn_acc_test)+'\n') 588 | f_log.write('test_gnn_auc '+str(gnn_auc_test)+'\n') 589 | f_log.write('test_gnn_ap '+str(gnn_ap_test)+'\n') 590 | f_log.write('test_infer_module_acc '+str(infer_module_acc_test)+'\n') 591 | f_log.write('test_infer_module_auc '+str(infer_module_auc_test)+'\n') 592 | f_log.write('test_infer_module_ap '+str(infer_module_ap_test)+'\n') 593 | f_log.write('lr '+str(cur_lr)+'\n') 594 | f_log.flush() 595 | 596 | test_loss_logs.append(float(np.mean(test_loss_list))) 597 | 598 | train_loss_list = [] 599 | train_cnn_loss_list = [] 600 | train_gnn_loss_list = [] 601 | train_infer_module_loss_list = [] 602 | test_loss_list = [] 603 | test_cnn_loss_list = [] 604 | test_gnn_loss_list = [] 605 | test_infer_module_loss_list = [] 606 | 607 | # cnn module related 608 | all_cnn_labels = np.zeros((0,)) 609 | all_cnn_preds = np.zeros((0,)) 610 | 611 | # gnn module related 612 | all_gnn_labels = np.zeros((0,)) 613 | all_gnn_preds = np.zeros((0,)) 614 | 615 | # inference module related 616 | all_infer_module_preds = np.zeros((0,)) 617 | 618 | if last_snapshot_iter != iter: 619 | filename = os.path.join(model_save_path,('iter_{:d}'.format(iter+1) + '.ckpt')) 620 | saver.save(sess, filename) 621 | print 'Wrote snapshot to: {:s}'.format(filename) 622 | 623 | f_log.close() 624 | sess.close() 625 | if args.use_multiprocessing: 626 | pool.terminate() 627 | print("Training complete.") -------------------------------------------------------------------------------- /code/util.py: -------------------------------------------------------------------------------- 1 | """ Common util file 2 | """ 3 | 4 | 5 | import numpy as np 6 | import numpy.random as npr 7 | import os 8 | import skimage.io 9 | import skimage.transform 10 | import time 11 | import pdb 12 | import networkx as nx 13 | import scipy.sparse as sp 14 | from sklearn.metrics import roc_auc_score 15 | from sklearn.metrics import average_precision_score 16 | import matplotlib.pyplot as plt 17 | import random 18 | import skimage.draw 19 | 20 | import _init_paths 21 | from config import cfg 22 | 23 | STR_ELEM = np.array([[1,1,1],[1,1,0],[0,0,0]], dtype=np.bool) # eight-neighbors 24 | #STR_ELEM = np.array([[0,1,0],[1,1,0],[0,0,0]], dtype=np.bool) # four-neighbors 25 | 26 | # graph visualization 27 | VIS_FIG_SIZE = (10,10) 28 | VIS_NODE_SIZE = 50 29 | VIS_ALPHA = 0.5 # (both for nodes and edges) 30 | VIS_NODE_COLOR = ['b','r','y','g'] # tp/fp/fn(+tn)/tn 31 | VIS_EDGE_COLOR = ['b','g','r'] # tp/fn/fp 32 | 33 | DEBUG = False 34 | 35 | 36 | class DataLayer(object): 37 | 38 | def __init__(self, db, is_training, use_padding=False): 39 | """Set the db to be used by this layer.""" 40 | self._db = db 41 | self._is_training = is_training 42 | self._use_padding = use_padding 43 | if self._is_training: 44 | self._shuffle_db_inds() 45 | else: 46 | self._db_inds() 47 | 48 | def _shuffle_db_inds(self): 49 | """Randomly permute the db.""" 50 | self._perm = np.random.permutation(np.arange(len(self._db))) 51 | self._cur = 0 52 | 53 | def _db_inds(self): 54 | """Permute the db.""" 55 | self._perm = np.arange(len(self._db)) 56 | self._cur = 0 57 | 58 | def _get_next_minibatch_inds(self): 59 | """Return the db indices for the next minibatch.""" 60 | cur_batch_size = cfg.TRAIN.BATCH_SIZE 61 | if self._is_training: 62 | if self._cur + cfg.TRAIN.BATCH_SIZE > len(self._db): 63 | self._shuffle_db_inds() 64 | else: 65 | rem = len(self._db) - self._cur 66 | if rem >= cfg.TRAIN.BATCH_SIZE: 67 | cur_batch_size = cfg.TRAIN.BATCH_SIZE 68 | else: 69 | cur_batch_size = rem 70 | 71 | db_inds = self._perm[self._cur:self._cur + cur_batch_size] 72 | self._cur += cur_batch_size 73 | if (not self._is_training) and (self._cur>=len(self._db)): 74 | self._db_inds() 75 | 76 | return db_inds 77 | 78 | def _get_next_minibatch(self): 79 | """Return the blobs to be used for the next minibatch.""" 80 | db_inds = self._get_next_minibatch_inds() 81 | minibatch_db = [self._db[i] for i in db_inds] 82 | return minibatch_db, get_minibatch(minibatch_db, self._is_training, \ 83 | use_padding=self._use_padding) 84 | 85 | def forward(self): 86 | """Get blobs and copy them into this layer's top blob vector.""" 87 | img_list, blobs = self._get_next_minibatch() 88 | return img_list, blobs 89 | 90 | 91 | class GraphDataLayer(object): 92 | 93 | def __init__(self, db, is_training, \ 94 | edge_type='srns_geo_dist_binary', \ 95 | win_size=8, edge_geo_dist_thresh=20): 96 | """Set the db to be used by this layer.""" 97 | self._db = db 98 | self._is_training = is_training 99 | self._edge_type = edge_type 100 | self._win_size = win_size 101 | self._edge_geo_dist_thresh = edge_geo_dist_thresh 102 | if self._is_training: 103 | self._shuffle_db_inds() 104 | else: 105 | self._db_inds() 106 | 107 | def _shuffle_db_inds(self): 108 | """Randomly permute the db.""" 109 | self._perm = np.random.permutation(np.arange(len(self._db))) 110 | self._cur = 0 111 | 112 | def _db_inds(self): 113 | """Permute the db.""" 114 | self._perm = np.arange(len(self._db)) 115 | self._cur = 0 116 | 117 | def _get_next_minibatch_inds(self): 118 | """Return the db indices for the next minibatch.""" 119 | cur_batch_size = cfg.TRAIN.GRAPH_BATCH_SIZE 120 | if self._is_training: 121 | if self._cur + cfg.TRAIN.GRAPH_BATCH_SIZE > len(self._db): 122 | self._shuffle_db_inds() 123 | else: 124 | rem = len(self._db) - self._cur 125 | if rem >= cfg.TRAIN.GRAPH_BATCH_SIZE: 126 | cur_batch_size = cfg.TRAIN.GRAPH_BATCH_SIZE 127 | else: 128 | cur_batch_size = rem 129 | 130 | db_inds = self._perm[self._cur:self._cur + cur_batch_size] 131 | self._cur += cur_batch_size 132 | if (not self._is_training) and (self._cur>=len(self._db)): 133 | self._db_inds() 134 | 135 | return db_inds 136 | 137 | def _get_next_minibatch(self): 138 | """Return the blobs to be used for the next minibatch.""" 139 | db_inds = self._get_next_minibatch_inds() 140 | minibatch_db = [self._db[i] for i in db_inds] 141 | return minibatch_db, get_minibatch(minibatch_db, self._is_training, \ 142 | is_about_graph=True, \ 143 | edge_type=self._edge_type, \ 144 | win_size=self._win_size, \ 145 | edge_geo_dist_thresh=self._edge_geo_dist_thresh) 146 | 147 | def forward(self): 148 | """Get blobs and copy them into this layer's top blob vector.""" 149 | img_list, blobs = self._get_next_minibatch() 150 | return img_list, blobs 151 | 152 | def reinit(self, db, is_training, \ 153 | edge_type='srns_geo_dist_binary', \ 154 | win_size=8, edge_geo_dist_thresh=20): 155 | """Reinitialize with new arguments.""" 156 | self._db = db 157 | self._is_training = is_training 158 | self._edge_type = edge_type 159 | self._win_size = win_size 160 | self._edge_geo_dist_thresh = edge_geo_dist_thresh 161 | if self._is_training: 162 | self._shuffle_db_inds() 163 | else: 164 | self._db_inds() 165 | 166 | 167 | def get_minibatch(minibatch_db, is_training, \ 168 | is_about_graph=False, \ 169 | edge_type='srns_geo_dist_binary', \ 170 | win_size=8, edge_geo_dist_thresh=20, \ 171 | use_padding=False): 172 | """Given a minibatch_db, construct a blob.""" 173 | 174 | if not is_about_graph: 175 | im_blob, label_blob, fov_blob = _get_image_fov_blob(minibatch_db, is_training, use_padding=use_padding) 176 | blobs = {'img': im_blob, 'label': label_blob, 'fov': fov_blob} 177 | 178 | else: 179 | im_blob, label_blob, fov_blob, probmap_blob, \ 180 | all_union_graph, \ 181 | num_of_nodes_list, vec_aug_on, rot_angle = \ 182 | _get_graph_fov_blob(minibatch_db, is_training, edge_type, win_size, edge_geo_dist_thresh) 183 | 184 | blobs = {'img': im_blob, 'label': label_blob, 'fov': fov_blob, 'probmap': probmap_blob, 185 | 'graph': all_union_graph, 186 | 'num_of_nodes_list': num_of_nodes_list, 187 | 'vec_aug_on': vec_aug_on, 188 | 'rot_angle': rot_angle} 189 | 190 | return blobs 191 | 192 | 193 | def _get_image_fov_blob(minibatch_db, is_training, use_padding=False): 194 | """Builds an input blob from the images in the minibatch_db.""" 195 | 196 | num_images = len(minibatch_db) 197 | processed_ims = [] 198 | processed_labels = [] 199 | processed_fovs = [] 200 | if 'DRIVE' in minibatch_db[0]: 201 | im_ext = '_image.tif' 202 | label_ext = '_label.gif' 203 | fov_ext = '_mask.gif' 204 | pixel_mean = cfg.PIXEL_MEAN_DRIVE 205 | len_y = 592 206 | len_x = 592 207 | elif 'STARE' in minibatch_db[0]: 208 | im_ext = '.ppm' 209 | label_ext = '.ah.ppm' 210 | fov_ext = '_mask.png' 211 | pixel_mean = cfg.PIXEL_MEAN_STARE 212 | len_y = 704 213 | len_x = 704 214 | elif 'CHASE_DB1' in minibatch_db[0]: 215 | im_ext = '.jpg' 216 | label_ext = '_1stHO.png' 217 | fov_ext = '_mask.tif' 218 | pixel_mean = cfg.PIXEL_MEAN_CHASE_DB1 219 | len_y = 1024 220 | len_x = 1024 221 | elif 'HRF' in minibatch_db[0]: 222 | im_ext = '.bmp' 223 | label_ext = '.tif' 224 | fov_ext = '_mask.tif' 225 | pixel_mean = cfg.PIXEL_MEAN_HRF 226 | len_y = 768 227 | len_x = 768 228 | 229 | for i in xrange(num_images): 230 | im = skimage.io.imread(minibatch_db[i]+im_ext) 231 | label = skimage.io.imread(minibatch_db[i]+label_ext) 232 | label = label.reshape((label.shape[0],label.shape[1],1)) 233 | fov = skimage.io.imread(minibatch_db[i]+fov_ext) 234 | if fov.ndim==2: 235 | fov = fov.reshape((fov.shape[0],fov.shape[1],1)) 236 | else: 237 | fov = fov[:,:,[0]] 238 | 239 | if use_padding: 240 | temp = np.copy(im) 241 | im = np.zeros((len_y,len_x,3), dtype=temp.dtype) 242 | im[:temp.shape[0],:temp.shape[1],:] = temp 243 | temp = np.copy(label) 244 | label = np.zeros((len_y,len_x,1), dtype=temp.dtype) 245 | label[:temp.shape[0],:temp.shape[1],:] = temp 246 | temp = np.copy(fov) 247 | fov = np.zeros((len_y,len_x,1), dtype=temp.dtype) 248 | fov[:temp.shape[0],:temp.shape[1],:] = temp 249 | 250 | processed_im, processed_label, processed_fov, _ = \ 251 | prep_im_fov_for_blob(im, label, fov, pixel_mean, is_training) 252 | processed_ims.append(processed_im) 253 | processed_labels.append(processed_label) 254 | processed_fovs.append(processed_fov) 255 | 256 | # Create a blob to hold the input images & labels & fovs 257 | im_blob = im_list_to_blob(processed_ims) 258 | label_blob = im_list_to_blob(processed_labels) 259 | fov_blob = im_list_to_blob(processed_fovs) 260 | 261 | return im_blob, label_blob, fov_blob 262 | 263 | 264 | def _get_graph_fov_blob(minibatch_db, is_training, edge_type='srns_geo_dist_binary', \ 265 | win_size=8, edge_geo_dist_thresh=20): 266 | """Builds an input blob from the graphs in the minibatch_db.""" 267 | 268 | num_graphs = len(minibatch_db) 269 | processed_ims = [] # image related 270 | processed_labels = [] # image related 271 | processed_fovs = [] # image related 272 | processed_probmaps = [] # image related 273 | all_graphs = [] # graph related 274 | num_of_nodes_list = [] # graph related 275 | 276 | # to apply the same aug in a mini-batch # 277 | if num_graphs > 1: 278 | given_aug_vec = np.zeros((7,), dtype=np.bool) 279 | if cfg.TRAIN.USE_LR_FLIPPED and npr.random_sample() >= 0.5: 280 | given_aug_vec[0] = True 281 | if cfg.TRAIN.USE_UD_FLIPPED and npr.random_sample() >= 0.5: 282 | given_aug_vec[1] = True 283 | if cfg.TRAIN.USE_ROTATION: 284 | given_aug_vec[2] = True 285 | if cfg.TRAIN.USE_SCALING: 286 | given_aug_vec[3] = True 287 | if cfg.TRAIN.USE_CROPPING: 288 | given_aug_vec[4] = True 289 | if cfg.TRAIN.USE_BRIGHTNESS_ADJUSTMENT: 290 | given_aug_vec[5] = True 291 | if cfg.TRAIN.USE_CONTRAST_ADJUSTMENT: 292 | given_aug_vec[6] = True 293 | # to apply the same aug in a mini-batch # 294 | 295 | if 'DRIVE' in minibatch_db[0]: 296 | im_root_path = '../DRIVE/all' 297 | im_ext = '_image.tif' 298 | label_ext = '_label.gif' 299 | fov_ext = '_mask.gif' 300 | pixel_mean = cfg.PIXEL_MEAN_DRIVE 301 | len_y = 592 302 | len_x = 592 303 | elif 'STARE' in minibatch_db[0]: 304 | im_root_path = '../STARE/all' 305 | im_ext = '.ppm' 306 | label_ext = '.ah.ppm' 307 | fov_ext = '_mask.png' 308 | pixel_mean = cfg.PIXEL_MEAN_STARE 309 | len_y = 704 310 | len_x = 704 311 | elif 'CHASE' in minibatch_db[0]: 312 | im_root_path = '../CHASE_DB1/all' 313 | im_ext = '.jpg' 314 | label_ext = '_1stHO.png' 315 | fov_ext = '_mask.tif' 316 | pixel_mean = cfg.PIXEL_MEAN_CHASE_DB1 317 | len_y = 1024 318 | len_x = 1024 319 | elif 'HRF' in minibatch_db[0]: 320 | im_root_path = '../HRF/all_768' 321 | im_ext = '.bmp' 322 | label_ext = '.tif' 323 | fov_ext = '_mask.tif' 324 | pixel_mean = cfg.PIXEL_MEAN_HRF 325 | len_y = 768 326 | len_x = 768 327 | for i in xrange(num_graphs): 328 | 329 | # load images 330 | cur_path = minibatch_db[i] 331 | cur_name = cur_path[find(cur_path,'/')[-1]+1:] 332 | 333 | im = skimage.io.imread(os.path.join(im_root_path, cur_name+im_ext)) 334 | label = skimage.io.imread(os.path.join(im_root_path, cur_name+label_ext)) 335 | label = label.reshape((label.shape[0],label.shape[1],1)) 336 | fov = skimage.io.imread(os.path.join(im_root_path, cur_name+fov_ext)) 337 | if fov.ndim==2: 338 | fov = fov.reshape((fov.shape[0],fov.shape[1],1)) 339 | else: 340 | fov = fov[:,:,[0]] 341 | probmap = skimage.io.imread(cur_path+'_prob.png') # cnn results will be used for loss masking 342 | probmap = probmap.reshape((probmap.shape[0],probmap.shape[1],1)) 343 | 344 | temp = np.copy(im) 345 | im = np.zeros((len_y,len_x,3), dtype=temp.dtype) 346 | im[:temp.shape[0],:temp.shape[1],:] = temp 347 | temp = np.copy(label) 348 | label = np.zeros((len_y,len_x,1), dtype=temp.dtype) 349 | label[:temp.shape[0],:temp.shape[1],:] = temp 350 | temp = np.copy(fov) 351 | fov = np.zeros((len_y,len_x,1), dtype=temp.dtype) 352 | fov[:temp.shape[0],:temp.shape[1],:] = temp 353 | temp = np.copy(probmap) 354 | probmap = np.zeros((len_y,len_x,1), dtype=temp.dtype) 355 | probmap[:temp.shape[0],:temp.shape[1],:] = temp 356 | 357 | # load graphs 358 | if 'srns' not in edge_type: 359 | raise NotImplementedError 360 | else: 361 | win_size_str = '_%.2d_%.2d'%(win_size,edge_geo_dist_thresh) 362 | graph = nx.read_gpickle(cur_path+win_size_str+'.graph_res') 363 | 364 | union_graph = nx.convert_node_labels_to_integers(graph) 365 | n_nodes_in_graph = union_graph.number_of_nodes() 366 | node_idx_map = np.zeros(im.shape[:2]) 367 | for j in xrange(n_nodes_in_graph): 368 | node_idx_map[union_graph.nodes[j]['y'],union_graph.nodes[j]['x']] = j+1 369 | 370 | if num_graphs > 1: # not used 371 | raise NotImplementedError 372 | else: 373 | processed_im, processed_label, processed_fov, processed_probmap, processed_node_idx_map, \ 374 | vec_aug_on, (crop_y1,crop_y2,crop_x1,crop_x2), rot_angle = \ 375 | prep_im_label_fov_probmap_for_blob(im, label, fov, probmap, node_idx_map, pixel_mean, is_training, win_size) 376 | 377 | processed_ims.append(processed_im) 378 | processed_labels.append(processed_label) 379 | processed_fovs.append(processed_fov) 380 | processed_probmaps.append(processed_probmap) 381 | 382 | node_ys, node_xs = np.where(processed_node_idx_map) 383 | for j in xrange(len(node_ys)): 384 | cur_node_idx = processed_node_idx_map[node_ys[j],node_xs[j]] 385 | union_graph.nodes[cur_node_idx-1]['y'] = node_ys[j] 386 | union_graph.nodes[cur_node_idx-1]['x'] = node_xs[j] 387 | union_graph = nx.convert_node_labels_to_integers(union_graph) 388 | n_nodes_in_graph = union_graph.number_of_nodes() 389 | 390 | """if vec_aug_on[0]: 391 | for j in xrange(n_nodes_in_graph): 392 | union_graph.nodes[j]['x'] = label.shape[1]-union_graph.nodes[j]['x']-1 393 | 394 | if vec_aug_on[1]: 395 | for j in xrange(n_nodes_in_graph): 396 | union_graph.nodes[j]['y'] = label.shape[0]-union_graph.nodes[j]['y']-1""" 397 | 398 | if vec_aug_on[4]: 399 | del_node_list = [] 400 | for j in xrange(n_nodes_in_graph): 401 | if (union_graph.nodes[j]['y']>=crop_y1 and \ 402 | union_graph.nodes[j]['y']=crop_x1 and \ 404 | union_graph.nodes[j]['x']= 0.5: 453 | vec_aug_on[0] = True 454 | im = im[:, ::-1, :] 455 | label = label[:, ::-1, :] 456 | fov = fov[:, ::-1, :] 457 | 458 | if cfg.TRAIN.USE_UD_FLIPPED and npr.random_sample() >= 0.5: 459 | vec_aug_on[1] = True 460 | im = im[::-1, :, :] 461 | label = label[::-1, :, :] 462 | fov = fov[::-1, :, :] 463 | 464 | if cfg.TRAIN.USE_ROTATION: 465 | vec_aug_on[2] = True 466 | rot_angle = np.random.uniform(-cfg.TRAIN.ROTATION_MAX_ANGLE,cfg.TRAIN.ROTATION_MAX_ANGLE) 467 | """im_r = skimage.transform.rotate(im[:,:,0], rot_angle, cval=0.) 468 | im_g = skimage.transform.rotate(im[:,:,1], rot_angle, cval=0.) 469 | im_b = skimage.transform.rotate(im[:,:,2], rot_angle, cval=0.)""" 470 | im_r = skimage.transform.rotate(im[:,:,0], rot_angle, cval=pixel_mean[0]/255.) 471 | im_g = skimage.transform.rotate(im[:,:,1], rot_angle, cval=pixel_mean[1]/255.) 472 | im_b = skimage.transform.rotate(im[:,:,2], rot_angle, cval=pixel_mean[2]/255.) 473 | im = np.dstack((im_r,im_g,im_b)) 474 | label = skimage.transform.rotate(label, rot_angle, cval=0., order=0) 475 | fov = skimage.transform.rotate(fov, rot_angle, cval=0., order=0) 476 | 477 | if cfg.TRAIN.USE_SCALING: 478 | vec_aug_on[3] = True 479 | scale = np.random.uniform(cfg.TRAIN.SCALING_RANGE[0],cfg.TRAIN.SCALING_RANGE[1]) 480 | im = skimage.transform.rescale(im, scale) 481 | label = skimage.transform.rescale(label, scale, order=0) 482 | fov = skimage.transform.rescale(fov, scale, order=0) 483 | 484 | if cfg.TRAIN.USE_CROPPING: 485 | vec_aug_on[4] = True 486 | cur_h = np.random.random_integers(im.shape[0]*0.5,im.shape[0]*0.8) 487 | cur_w = np.random.random_integers(im.shape[1]*0.5,im.shape[1]*0.8) 488 | cur_y1 = np.random.random_integers(0,im.shape[0]-cur_h) 489 | cur_x1 = np.random.random_integers(0,im.shape[1]-cur_w) 490 | cur_y2 = cur_y1 + cur_h 491 | cur_x2 = cur_x1 + cur_w 492 | im = im[cur_y1:cur_y2,cur_x1:cur_x2,:] 493 | label = label[cur_y1:cur_y2,cur_x1:cur_x2,:] 494 | fov = fov[cur_y1:cur_y2,cur_x1:cur_x2,:] 495 | 496 | if cfg.TRAIN.USE_BRIGHTNESS_ADJUSTMENT: 497 | vec_aug_on[5] = True 498 | im += np.random.uniform(-cfg.TRAIN.BRIGHTNESS_ADJUSTMENT_MAX_DELTA,cfg.TRAIN.BRIGHTNESS_ADJUSTMENT_MAX_DELTA) 499 | im = np.clip(im, 0, 1) 500 | 501 | if cfg.TRAIN.USE_CONTRAST_ADJUSTMENT: 502 | vec_aug_on[6] = True 503 | mm = np.mean(im) 504 | im = (im-mm)*np.random.uniform(cfg.TRAIN.CONTRAST_ADJUSTMENT_LOWER_FACTOR,cfg.TRAIN.CONTRAST_ADJUSTMENT_UPPER_FACTOR) + mm 505 | im = np.clip(im, 0, 1) 506 | 507 | #skimage.io.imsave('img_1.bmp', (im*255).astype(int)) 508 | 509 | # original 510 | im -= np.array(pixel_mean)/255. 511 | im = im*255. 512 | 513 | """# contrast enhancement 514 | im_f = skimage.filters.gaussian(im, sigma=10, multichannel=True) 515 | im = im-im_f 516 | im = im*255.""" 517 | 518 | label = label>=0.5 519 | fov = fov>=0.5 520 | 521 | #skimage.io.imsave('label_1.bmp', (label.reshape(label.shape[0],label.shape[1])*255).astype(int)) 522 | 523 | return im, label, fov, vec_aug_on 524 | 525 | 526 | def prep_im_label_fov_probmap_for_blob(im, label, fov, probmap, node_idx_map, pixel_mean, is_training, win_size): 527 | """Preprocess images for use in a blob.""" 528 | 529 | im = im.astype(np.float32, copy=False)/255. 530 | label = label.astype(np.float32, copy=False)/255. 531 | fov = fov.astype(np.float32, copy=False)/255. 532 | probmap = probmap.astype(np.float32, copy=False)/255. 533 | 534 | vec_aug_on = np.zeros((7,), dtype=np.bool) 535 | 536 | cur_y1 = 0 537 | cur_y2 = 0 538 | cur_x1 = 0 539 | cur_x2 = 0 540 | rot_angle = 0 541 | if is_training: 542 | if cfg.TRAIN.USE_LR_FLIPPED and npr.random_sample() >= 0.5: 543 | vec_aug_on[0] = True 544 | im = im[:, ::-1, :] 545 | label = label[:, ::-1, :] 546 | fov = fov[:, ::-1, :] 547 | probmap = probmap[:, ::-1, :] 548 | node_idx_map = node_idx_map[:, ::-1] 549 | 550 | if cfg.TRAIN.USE_UD_FLIPPED and npr.random_sample() >= 0.5: 551 | vec_aug_on[1] = True 552 | im = im[::-1, :, :] 553 | label = label[::-1, :, :] 554 | fov = fov[::-1, :, :] 555 | probmap = probmap[::-1, :, :] 556 | node_idx_map = node_idx_map[::-1, :] 557 | 558 | if cfg.TRAIN.USE_ROTATION: 559 | vec_aug_on[2] = True 560 | 561 | len_ori_y,len_ori_x = im.shape[:2] 562 | 563 | rot_angle = np.random.choice([0,90,180,270]) 564 | im_r = skimage.transform.rotate(im[:,:,0], rot_angle, cval=pixel_mean[0]/255., resize=True) 565 | im_g = skimage.transform.rotate(im[:,:,1], rot_angle, cval=pixel_mean[1]/255., resize=True) 566 | im_b = skimage.transform.rotate(im[:,:,2], rot_angle, cval=pixel_mean[2]/255., resize=True) 567 | im = np.dstack((im_r,im_g,im_b)) 568 | label = skimage.transform.rotate(label, rot_angle, cval=0., order=0, resize=True) 569 | fov = skimage.transform.rotate(fov, rot_angle, cval=0., order=0, resize=True) 570 | probmap = skimage.transform.rotate(probmap, rot_angle, cval=0., resize=True) 571 | node_idx_map = skimage.transform.rotate(node_idx_map, rot_angle, cval=0., order=0, resize=True) 572 | 573 | im = im[:len_ori_y,:len_ori_x,:] 574 | label = label[:len_ori_y,:len_ori_x,:] 575 | fov = fov[:len_ori_y,:len_ori_x,:] 576 | probmap = probmap[:len_ori_y,:len_ori_x,:] 577 | node_idx_map = node_idx_map[:len_ori_y,:len_ori_x] 578 | 579 | if cfg.TRAIN.USE_SCALING: 580 | vec_aug_on[3] = True 581 | scale = np.random.uniform(cfg.TRAIN.SCALING_RANGE[0],cfg.TRAIN.SCALING_RANGE[1]) 582 | im = skimage.transform.rescale(im, scale) 583 | label = skimage.transform.rescale(label, scale, order=0) 584 | fov = skimage.transform.rescale(fov, scale, order=0) 585 | probmap = skimage.transform.rescale(probmap, scale) 586 | node_idx_map = skimage.transform.rescale(node_idx_map, scale, order=0) 587 | 588 | if cfg.TRAIN.USE_CROPPING: 589 | vec_aug_on[4] = True 590 | 591 | # cropping dependent on 'win_size' 592 | cur_h = (np.random.random_integers(im.shape[0]*0.5,im.shape[0]*0.8)//win_size)*win_size 593 | cur_w = (np.random.random_integers(im.shape[1]*0.5,im.shape[1]*0.8)//win_size)*win_size 594 | if vec_aug_on[0]: 595 | cur_y1 = np.random.choice(range(im.shape[0]%win_size,im.shape[0]-cur_h,win_size)) 596 | cur_x1 = np.random.choice(range(im.shape[1]%win_size,im.shape[1]-cur_w,win_size)) 597 | else: 598 | cur_y1 = np.random.choice(range(0,im.shape[0]-cur_h,win_size)) 599 | cur_x1 = np.random.choice(range(0,im.shape[1]-cur_w,win_size)) 600 | cur_y2 = cur_y1 + cur_h 601 | cur_x2 = cur_x1 + cur_w 602 | 603 | im = im[cur_y1:cur_y2,cur_x1:cur_x2,:] 604 | label = label[cur_y1:cur_y2,cur_x1:cur_x2,:] 605 | fov = fov[cur_y1:cur_y2,cur_x1:cur_x2,:] 606 | probmap = probmap[cur_y1:cur_y2,cur_x1:cur_x2,:] 607 | node_idx_map = node_idx_map[cur_y1:cur_y2,cur_x1:cur_x2] 608 | 609 | if cfg.TRAIN.USE_BRIGHTNESS_ADJUSTMENT: 610 | vec_aug_on[5] = True 611 | im += np.random.uniform(-cfg.TRAIN.BRIGHTNESS_ADJUSTMENT_MAX_DELTA,cfg.TRAIN.BRIGHTNESS_ADJUSTMENT_MAX_DELTA) 612 | im = np.clip(im, 0, 1) 613 | 614 | if cfg.TRAIN.USE_CONTRAST_ADJUSTMENT: 615 | vec_aug_on[6] = True 616 | mm = np.mean(im) 617 | im = (im-mm)*np.random.uniform(cfg.TRAIN.CONTRAST_ADJUSTMENT_LOWER_FACTOR,cfg.TRAIN.CONTRAST_ADJUSTMENT_UPPER_FACTOR) + mm 618 | im = np.clip(im, 0, 1) 619 | 620 | # original 621 | im -= np.array(pixel_mean)/255. 622 | im = im*255. 623 | 624 | """# contrast enhancement 625 | im_f = skimage.filters.gaussian(im, sigma=10, multichannel=True) 626 | im = im-im_f 627 | im = im*255.""" 628 | 629 | label = label>=0.5 630 | fov = fov>=0.5 631 | 632 | return im, label, fov, probmap, node_idx_map, vec_aug_on, (cur_y1,cur_y2,cur_x1,cur_x2), rot_angle 633 | 634 | 635 | def im_list_to_blob(ims): 636 | """Convert a list of images into a network input.""" 637 | 638 | max_shape = np.array([im.shape for im in ims]).max(axis=0) 639 | num_images = len(ims) 640 | blob = np.zeros((num_images, max_shape[0], max_shape[1], max_shape[2]), 641 | dtype=ims[0].dtype) 642 | for i in xrange(num_images): 643 | im = ims[i] 644 | blob[i, 0:im.shape[0], 0:im.shape[1], :] = im 645 | 646 | return blob 647 | 648 | 649 | def find(s, ch): 650 | return [i for i, ltr in enumerate(s) if ltr == ch] 651 | 652 | 653 | class Timer(object): 654 | """A simple timer.""" 655 | def __init__(self): 656 | self.total_time = 0. 657 | self.calls = 0 658 | self.start_time = 0. 659 | self.diff = 0. 660 | self.average_time = 0. 661 | 662 | def tic(self): 663 | # using time.time instead of time.clock because time time.clock 664 | # does not normalize for multithreading 665 | self.start_time = time.time() 666 | 667 | def toc(self, average=True): 668 | self.diff = time.time() - self.start_time 669 | self.total_time += self.diff 670 | self.calls += 1 671 | self.average_time = self.total_time / self.calls 672 | if average: 673 | return self.average_time 674 | else: 675 | return self.diff 676 | 677 | 678 | def sparse_to_tuple(sparse_mx): 679 | if not sp.isspmatrix_coo(sparse_mx): 680 | sparse_mx = sparse_mx.tocoo() 681 | coords = np.vstack((sparse_mx.row, sparse_mx.col)).transpose() 682 | values = sparse_mx.data 683 | shape = sparse_mx.shape 684 | return coords, values, shape 685 | 686 | 687 | # append self connections (diagonal terms in adjacency matrix) and binarize 688 | def preprocess_graph_gat(adj): 689 | adj = sp.coo_matrix(adj) 690 | adj = adj + sp.eye(adj.shape[0]) # self-loop 691 | adj[adj > 0.0] = 1.0 692 | if not sp.isspmatrix_coo(adj): 693 | adj = adj.tocoo() 694 | adj = adj.astype(np.float32) 695 | indices = np.vstack((adj.col, adj.row)).transpose() 696 | return indices, adj.data, adj.shape 697 | 698 | 699 | def visualize_graph(im, graph, show_graph=False, save_graph=True, \ 700 | num_nodes_each_type=None, custom_node_color=None, \ 701 | tp_edges=None, fn_edges=None, fp_edges=None, \ 702 | save_path='graph.png'): 703 | 704 | plt.figure(figsize=VIS_FIG_SIZE) 705 | if im.dtype==np.bool: 706 | bg = im.astype(int)*255 707 | else: 708 | bg = im 709 | 710 | if len(bg.shape)==2: 711 | plt.imshow(bg, cmap='gray', vmin=0, vmax=255) 712 | elif len(bg.shape)==3: 713 | plt.imshow(bg) 714 | #plt.imshow(bg, cmap='gray', vmin=0, vmax=255) 715 | plt.axis('off') 716 | pos = {} 717 | node_list = list(graph.nodes) 718 | for i in node_list: 719 | pos[i] = [graph.nodes[i]['x'],graph.nodes[i]['y']] 720 | 721 | if custom_node_color is not None: 722 | node_color = custom_node_color 723 | else: 724 | if num_nodes_each_type is None: 725 | node_color = 'b' 726 | else: 727 | if not (graph.number_of_nodes()==np.sum(num_nodes_each_type)): 728 | raise ValueError('Wrong number of nodes') 729 | node_color = [VIS_NODE_COLOR[0]]*num_nodes_each_type[0] + [VIS_NODE_COLOR[1]]*num_nodes_each_type[1] 730 | 731 | nx.draw(graph, pos, node_color='green', edge_color='blue', width=1, node_size=10, alpha=VIS_ALPHA) 732 | #nx.draw(graph, pos, node_color='darkgreen', edge_color='black', width=3, node_size=30, alpha=VIS_ALPHA) 733 | #nx.draw(graph, pos, node_color=node_color, node_size=VIS_NODE_SIZE, alpha=VIS_ALPHA) 734 | 735 | if tp_edges is not None: 736 | nx.draw_networkx_edges(graph, pos, 737 | edgelist=tp_edges, 738 | width=3, alpha=VIS_ALPHA, edge_color=VIS_EDGE_COLOR[0]) 739 | if fn_edges is not None: 740 | nx.draw_networkx_edges(graph, pos, 741 | edgelist=fn_edges, 742 | width=3, alpha=VIS_ALPHA, edge_color=VIS_EDGE_COLOR[1]) 743 | if fp_edges is not None: 744 | nx.draw_networkx_edges(graph, pos, 745 | edgelist=fp_edges, 746 | width=3, alpha=VIS_ALPHA, edge_color=VIS_EDGE_COLOR[2]) 747 | 748 | if save_graph: 749 | plt.savefig(save_path, bbox_inches='tight', pad_inches=0) 750 | if show_graph: 751 | plt.show() 752 | 753 | plt.cla() 754 | plt.clf() 755 | plt.close() 756 | 757 | 758 | def get_auc_ap_score(labels, preds): 759 | 760 | auc_score = roc_auc_score(labels, preds) 761 | ap_score = average_precision_score(labels, preds) 762 | 763 | return auc_score, ap_score 764 | 765 | 766 | def check_symmetric(a, tol=1e-8): 767 | return np.allclose(a, a.T, atol=tol) 768 | 769 | 770 | def get_node_byx_from_graph(graph, num_of_nodes_list): 771 | node_byxs = np.zeros((graph.number_of_nodes(),3), dtype=np.int32) 772 | node_idx = 0 773 | for sub_graph_idx, cur_num_nodes in enumerate(num_of_nodes_list): 774 | for i in range(node_idx,node_idx+cur_num_nodes): 775 | node_byxs[i,:] = [sub_graph_idx,graph.nodes[i]['y'],graph.nodes[i]['x']] 776 | node_idx = node_idx+cur_num_nodes 777 | 778 | return node_byxs -------------------------------------------------------------------------------- /models/empty.gitkeep: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /pretrained_model/empty.gitkeep: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /results/Image_12R.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/syshin1014/VGN/fa54c8b9d30db8d4273a4eb315ab1857e610746b/results/Image_12R.gif -------------------------------------------------------------------------------- /results/empty.gitkeep: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /results/im0239.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/syshin1014/VGN/fa54c8b9d30db8d4273a4eb315ab1857e610746b/results/im0239.gif --------------------------------------------------------------------------------