├── .gitignore ├── LICENSE ├── README.md ├── config.py ├── experiments ├── Example │ ├── Precision_recall.png │ ├── ROC.png │ ├── args.pkl │ ├── args.txt │ ├── best_model.pth │ ├── events.out.tfevents.1611322639.yx-Server1 │ ├── latest_model.pth │ ├── log_2021-01-22-21-37.csv │ ├── performance.txt │ ├── result.npy │ ├── result_img │ │ ├── Result_01_test.png │ │ ├── Result_02_test.png │ │ ├── Result_03_test.png │ │ ├── Result_04_test.png │ │ └── Result_05_test.png │ ├── sample_input_imgs.png │ ├── sample_input_masks.png │ ├── test_log.txt │ └── train_log.txt └── README.md ├── function.py ├── lib ├── __init__.py ├── common.py ├── dataset.py ├── datasetV2.py ├── extract_patches.py ├── logger.py ├── losses │ ├── __init__.py │ ├── loss.py │ └── loss_lab.py ├── metrics.py ├── pre_processing.py └── visualize.py ├── models ├── DenseUnet.py ├── LadderNet.py ├── UNetFamily.py ├── __init__.py └── nn │ ├── __init__.py │ └── attention.py ├── prepare_dataset ├── chasedb1.py ├── data_path_list │ └── .gitignore ├── drive.py └── stare.py ├── test.py ├── tools ├── README.md ├── ablation │ ├── ablation_plot.py │ └── ablation_plot_with_detail.py ├── merge_k-flod_plot.py └── visualization │ ├── detail_comparison.py │ ├── detail_comparison2.py │ └── preprocess_visualization.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # my options 2 | data/ 3 | !/data/README.md 4 | experiments/* 5 | !experiments/README.md 6 | !experiments/Example 7 | 8 | # tools/ablation 9 | # tools/visualization 10 | 11 | .vscode/ 12 | 13 | # Byte-compiled / optimized / DLL files 14 | __pycache__/ 15 | *.py[cod] 16 | *$py.class -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## __VesselSeg-Pytorch__ : _Retinal vessel segmentation toolkit based on pytorch_ 2 | ### Introduction 3 | This project is a retinal blood vessel segmentation code based on python and pytorch framework, including data preprocessing, model training and testing, visualization, etc. This project is suitable for researchers who study retinal vessel segmentation. 
4 | ![Segmentation results](http://ww1.sinaimg.cn/mw690/a2d5ce76ly1gn16jnugz6j20ri08w465.jpg)
5 | ### Requirements
6 | The main packages and versions of the Python environment are as follows:
7 | ```
8 | # Name Version
9 | python 3.7.9
10 | pytorch 1.7.0
11 | torchvision 0.8.0
12 | cudatoolkit 10.2.89
13 | cudnn 7.6.5
14 | matplotlib 3.3.2
15 | numpy 1.19.2
16 | opencv 3.4.2
17 | pandas 1.1.3
18 | pillow 8.0.1
19 | scikit-learn 0.23.2
20 | scipy 1.5.2
21 | tensorboardX 2.1
22 | tqdm 4.54.1
23 | ```
24 | The code of this project has been verified in the environment above. In addition, PyTorch maintains good backward compatibility (version >= 1.0), so __I suggest you first try running the code in your existing PyTorch environment.__
25 | 
26 | The current version has problems reading the `.tif` images of the DRIVE dataset on Windows. __It is recommended that you use Linux for training and testing.__
27 | 
28 | ---
29 | ## Usage
30 | ### 0) Download Project
31 | 
32 | Run ```git clone https://github.com/lee-zq/VesselSeg-Pytorch.git```
33 | The project structure and the purpose of each file are as follows:
34 | ```
35 | VesselSeg-Pytorch # Source code
36 | ├── config.py # Configuration information
37 | ├── lib # Function library
38 | │ ├── common.py
39 | │ ├── dataset.py # Dataset class to load training data
40 | │ ├── datasetV2.py # Dataset class to load training data with lower memory usage
41 | │ ├── extract_patches.py # Extract training and test samples
42 | │ ├── visualize.py # Visualization of images and results
43 | │ ├── __init__.py
44 | │ ├── logger.py # To create logs
45 | │ ├── losses
46 | │ ├── metrics.py # Evaluation metrics
47 | │ └── pre_processing.py # Data preprocessing
48 | ├── models # All models are created in this folder
49 | │ ├── DenseUnet.py
50 | │ ├── __init__.py
51 | │ ├── LadderNet.py
52 | │ ├── nn
53 | │ └── UNetFamily.py
54 | ├── prepare_dataset # Prepare the dataset (organize the image paths of the datasets)
55 | │ ├── chasedb1.py
56 | │ ├── data_path_list # Image path lists of the datasets
57 | │ ├── drive.py
58 | │ └── stare.py
59 | ├── tools # Some tools
60 | │ ├── ablation_plot.py
61 | │ ├── ablation_plot_with_detail.py
62 | │ ├── merge_k-flod_plot.py
63 | │ └── visualization
64 | ├── function.py # Creating dataloader, training and validation functions
65 | ├── test.py # Test file
66 | └── train.py # Train file
67 | ```
68 | ### 1) Datasets preparation
69 | 1. Please download the retinal image datasets (DRIVE, STARE and CHASE_DB1) from [TianYi Cloud](https://cloud.189.cn/t/UJrmYrFZBzIn). Alternatively, you can download the three datasets from their official pages: [DRIVE](http://www.isi.uu.nl/Research/Databases/DRIVE/), [STARE](http://www.ces.clemson.edu/ahoover/stare/) and [CHASE_DB1]().
70 | 2. Unzip the downloaded `datasets.rar` file. The resulting structure is as follows:
71 | ```
72 | datasets
73 | ├── CHASEDB1
74 | │ ├── 1st_label
75 | │ ├── 2nd_label
76 | │ ├── images
77 | │ └── mask
78 | ├── DRIVE
79 | │ ├── test
80 | │ └── training
81 | └── STARE
82 | ├── 1st_labels_ah
83 | ├── images
84 | ├── mask
85 | └── snd_label_vk
86 | ```
87 | 3. Create the data path index files (.txt).
88 | Please modify the data folder path `data_root_path` (in [`drive.py`](https://github.com/lee-zq/VesselSeg-Pytorch/blob/master/prepare_dataset/drive.py), [`stare.py`](https://github.com/lee-zq/VesselSeg-Pytorch/blob/master/prepare_dataset/stare.py) and [`chasedb1.py`](https://github.com/lee-zq/VesselSeg-Pytorch/blob/master/prepare_dataset/chasedb1.py)) to the absolute path of the datasets downloaded above, then run:
89 | ```
90 | python ./prepare_dataset/drive.py
91 | ```
92 | In the same way, the data path files of the other two datasets can be obtained. The results are saved in the [`./prepare_dataset/data_path_list`](https://github.com/lee-zq/VesselSeg-Pytorch/tree/master/prepare_dataset/data_path_list) folder.
93 | ### 2) Training model
94 | Please confirm the configuration information in [`config.py`](https://github.com/lee-zq/VesselSeg-Pytorch/blob/master/config.py). Pay special attention to `train_data_path_list` and `test_data_path_list`. Then run:
95 | ```
96 | CUDA_VISIBLE_DEVICES=1 python train.py --save UNet_vessel_seg --batch_size 64
97 | ```
98 | You can configure the training information in `config.py`, or override the configuration parameters on the command line. The training results will be saved to the corresponding directory (the `--save` name) in the `experiments` folder.
99 | ### 3) Testing model
100 | The test process also reads its parameters from [`config.py`](https://github.com/lee-zq/VesselSeg-Pytorch/blob/master/config.py), and you can likewise modify them through the command line. Run:
101 | ```
102 | CUDA_VISIBLE_DEVICES=1 python test.py --save UNet_vessel_seg
103 | ```
104 | The above command loads the `best_model.pth` in `./experiments/UNet_vessel_seg` and evaluates it on the test set; the test results are saved in the same folder (see the example at the end of this README for one way to reload them).
105 | 
106 | ## Visualization
107 | 0. Training sample visualization
108 | ![train data](http://ww1.sinaimg.cn/mw690/a2d5ce76ly1gn1710u3s4j20hg06y0tt.jpg)
109 | 1. Segmentation results
110 | The original image, predicted probability map, predicted binary image and ground truth:
111 | DRIVE:
112 | ![results drive](http://ww1.sinaimg.cn/mw690/a2d5ce76ly1gn16yw1u1qj21qs0g8tou.jpg)
113 | STARE:
114 | ![results stare](http://ww1.sinaimg.cn/mw690/a2d5ce76ly1gn1ewrd5luj225s0gt4mc.jpg)
115 | CHASE_DB1:
116 | ![results chasedb1](http://ww1.sinaimg.cn/mw690/a2d5ce76ly1gn1fqwmy8dj23300qo7wh.jpg)
117 | 
118 | ## To DO
119 | * [ ] Add other retinal vessel segmentation models and their performance.
120 | * [ ] Add SOTA loss functions.
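## Example: loading saved results
The files written to an experiment folder (`args.pkl`, `result.npy`, `performance.txt`, etc.) can be reloaded for further analysis. The snippet below is a minimal sketch and is not part of the repository: it assumes that `args.pkl` was written with `joblib.dump` (as done by `save_args` in `lib/common.py`) and that `result.npy` stores the pixel probability predictions of the test set; check the array shape before relying on it.
```
import joblib
import numpy as np

exp_dir = "./experiments/UNet_vessel_seg"  # the --save name of your experiment

# Reload the configuration saved by save_args in lib/common.py
args = joblib.load(exp_dir + "/args.pkl")
print("test patch size:", args.test_patch_height, args.test_patch_width)

# Reload the pixel probability predictions (assumed shape [N, 1, H, W], values in [0, 1])
pred = np.load(exp_dir + "/result.npy")
print("prediction array shape:", pred.shape)

# Example: binarize the probabilities with a fixed 0.5 threshold
pred_binary = (pred >= 0.5).astype(np.uint8)
print("predicted vessel pixel ratio: %.4f" % pred_binary.mean())
```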
121 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def parse_args(): 4 | parser = argparse.ArgumentParser() 5 | 6 | # in/out 7 | parser.add_argument('--outf', default='./experiments', 8 | help='trained model will be saved at here') 9 | parser.add_argument('--save', default='UNet_vessel_seg', 10 | help='save name of experiment in args.outf directory') 11 | 12 | # data 13 | parser.add_argument('--train_data_path_list', 14 | default='./prepare_dataset/data_path_list/STARE/train.txt') 15 | parser.add_argument('--test_data_path_list', 16 | default='./prepare_dataset/data_path_list/STARE/test.txt') 17 | parser.add_argument('--train_patch_height', default=64) 18 | parser.add_argument('--train_patch_width', default=64) 19 | parser.add_argument('--N_patches', default=150000, 20 | help='Number of training image patches') 21 | parser.add_argument('--inside_FOV', default='center', 22 | help='Choose from [not,center,all]') 23 | parser.add_argument('--val_ratio', default=0.1, 24 | help='The ratio of the validation set in the training set') 25 | parser.add_argument('--sample_visualization', default=True, 26 | help='Visualization of training samples') 27 | # model parameters 28 | parser.add_argument('--in_channels', default=1,type=int, 29 | help='input channels of model') 30 | parser.add_argument('--classes', default=2,type=int, 31 | help='output channels of model') 32 | 33 | # training 34 | parser.add_argument('--N_epochs', default=50, type=int, 35 | help='number of total epochs to run') 36 | parser.add_argument('--batch_size', default=64, 37 | type=int, help='batch size') 38 | parser.add_argument('--early-stop', default=6, type=int, 39 | help='early stopping') 40 | parser.add_argument('--lr', default=0.0005, type=float, 41 | help='initial learning rate') 42 | parser.add_argument('--val_on_test', default=False, type=bool, 43 | help='Validation on testset') 44 | 45 | # for pre_trained checkpoint 46 | parser.add_argument('--start_epoch', default=1, 47 | help='Start epoch') 48 | parser.add_argument('--pre_trained', default=None, 49 | help='(path of trained _model)load trained model to continue train') 50 | 51 | # testing 52 | parser.add_argument('--test_patch_height', default=96) 53 | parser.add_argument('--test_patch_width', default=96) 54 | parser.add_argument('--stride_height', default=16) 55 | parser.add_argument('--stride_width', default=16) 56 | 57 | # hardware setting 58 | parser.add_argument('--cuda', default=True, type=bool, 59 | help='Use GPU calculating') 60 | 61 | args = parser.parse_args() 62 | 63 | return args 64 | -------------------------------------------------------------------------------- /experiments/Example/Precision_recall.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lee-zq/VesselSeg-Pytorch/b4f6571fc1fb1fbdaad60ff9282a54a1f1c455fa/experiments/Example/Precision_recall.png -------------------------------------------------------------------------------- /experiments/Example/ROC.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lee-zq/VesselSeg-Pytorch/b4f6571fc1fb1fbdaad60ff9282a54a1f1c455fa/experiments/Example/ROC.png -------------------------------------------------------------------------------- /experiments/Example/args.pkl: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/lee-zq/VesselSeg-Pytorch/b4f6571fc1fb1fbdaad60ff9282a54a1f1c455fa/experiments/Example/args.pkl -------------------------------------------------------------------------------- /experiments/Example/args.txt: -------------------------------------------------------------------------------- 1 | outf: /ssd/lzq/projects/vesselseg/experiments 2 | save: UNet_vessel_seg 3 | train_data_path_list: /ssd/lzq/projects/vesselseg/prepare_dataset/data_path_list/DRIVE/train.txt 4 | test_data_path_list: /ssd/lzq/projects/vesselseg/prepare_dataset/data_path_list/DRIVE/test.txt 5 | train_patch_height: 48 6 | train_patch_width: 48 7 | N_patches: 100000 8 | inside_FOV: True 9 | val_ratio: 0.1 10 | in_channels: 1 11 | classes: 2 12 | N_epochs: 50 13 | batch_size: 64 14 | early_stop: 6 15 | lr: 0.0001 16 | val_on_test: True 17 | start_epoch: 1 18 | pre_trained: None 19 | test_patch_height: 96 20 | test_patch_width: 96 21 | stride_height: 16 22 | stride_width: 16 23 | cuda: True 24 | -------------------------------------------------------------------------------- /experiments/Example/best_model.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lee-zq/VesselSeg-Pytorch/b4f6571fc1fb1fbdaad60ff9282a54a1f1c455fa/experiments/Example/best_model.pth -------------------------------------------------------------------------------- /experiments/Example/events.out.tfevents.1611322639.yx-Server1: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lee-zq/VesselSeg-Pytorch/b4f6571fc1fb1fbdaad60ff9282a54a1f1c455fa/experiments/Example/events.out.tfevents.1611322639.yx-Server1 -------------------------------------------------------------------------------- /experiments/Example/latest_model.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lee-zq/VesselSeg-Pytorch/b4f6571fc1fb1fbdaad60ff9282a54a1f1c455fa/experiments/Example/latest_model.pth -------------------------------------------------------------------------------- /experiments/Example/log_2021-01-22-21-37.csv: -------------------------------------------------------------------------------- 1 | epoch,train_loss,val_auc_roc,val_acc,val_f1 2 | 1,0.149005,0.978863,0.954974,0.822566 3 | 2,0.112972,0.980005,0.956922,0.826443 4 | 3,0.107827,0.980281,0.957215,0.82673 5 | 4,0.104817,0.980439,0.956755,0.828352 6 | 5,0.102211,0.980667,0.957071,0.826338 7 | 6,0.1,0.980501,0.956719,0.827644 8 | 7,0.097532,0.97973,0.95682,0.826485 9 | 8,0.095253,0.979032,0.955755,0.822463 10 | 9,0.09293,0.979556,0.956356,0.824103 11 | 10,0.09074,0.978833,0.956238,0.822186 12 | 11,0.088608,0.978919,0.95645,0.822891 13 | -------------------------------------------------------------------------------- /experiments/Example/performance.txt: -------------------------------------------------------------------------------- 1 | AUC ROC curve: 0.9806667939202203 2 | AUC PR curve: 0.9148753238140122 3 | F1 score: 0.8263400354419006 4 | Accuracy: 0.9570709869653733 5 | Sensitivity(SE): 0.8024059593282425 6 | Specificity(SP): 0.9796293088690451 7 | Precision: 0.8517458153171612 8 | 9 | Confusion matrix:[[3879816 80678] 10 | [ 114140 463509]] -------------------------------------------------------------------------------- /experiments/Example/result.npy: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/lee-zq/VesselSeg-Pytorch/b4f6571fc1fb1fbdaad60ff9282a54a1f1c455fa/experiments/Example/result.npy -------------------------------------------------------------------------------- /experiments/Example/result_img/Result_01_test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lee-zq/VesselSeg-Pytorch/b4f6571fc1fb1fbdaad60ff9282a54a1f1c455fa/experiments/Example/result_img/Result_01_test.png -------------------------------------------------------------------------------- /experiments/Example/result_img/Result_02_test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lee-zq/VesselSeg-Pytorch/b4f6571fc1fb1fbdaad60ff9282a54a1f1c455fa/experiments/Example/result_img/Result_02_test.png -------------------------------------------------------------------------------- /experiments/Example/result_img/Result_03_test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lee-zq/VesselSeg-Pytorch/b4f6571fc1fb1fbdaad60ff9282a54a1f1c455fa/experiments/Example/result_img/Result_03_test.png -------------------------------------------------------------------------------- /experiments/Example/result_img/Result_04_test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lee-zq/VesselSeg-Pytorch/b4f6571fc1fb1fbdaad60ff9282a54a1f1c455fa/experiments/Example/result_img/Result_04_test.png -------------------------------------------------------------------------------- /experiments/Example/result_img/Result_05_test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lee-zq/VesselSeg-Pytorch/b4f6571fc1fb1fbdaad60ff9282a54a1f1c455fa/experiments/Example/result_img/Result_05_test.png -------------------------------------------------------------------------------- /experiments/Example/sample_input_imgs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lee-zq/VesselSeg-Pytorch/b4f6571fc1fb1fbdaad60ff9282a54a1f1c455fa/experiments/Example/sample_input_imgs.png -------------------------------------------------------------------------------- /experiments/Example/sample_input_masks.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lee-zq/VesselSeg-Pytorch/b4f6571fc1fb1fbdaad60ff9282a54a1f1c455fa/experiments/Example/sample_input_masks.png -------------------------------------------------------------------------------- /experiments/Example/test_log.txt: -------------------------------------------------------------------------------- 1 | ==> Resuming from checkpoint.. 
2 | load data from ./prepare_dataset/data_path_list/DRIVE/test.txt  3 | ori data shape < ori_imgs:(20, 3, 584, 565) GTs:(20, 1, 584, 565) FOVs:(20, 1, 584, 565) 4 | imgs pixel range 0-255: 5 | GTs pixel range 0-255: 6 | FOVs pixel range 0-255: 7 | ==================data have loaded====================== 8 | 9 | the side H is not compatible with the selected stride of 16 10 | img_h 584, patch_h 96, stride_h 16 11 | (img_h - patch_h) MOD stride_h: 8 12 | So the H dim will be padded with additional 8 pixels 13 | the side W is not compatible with the selected stride of 16 14 | img_w 565, patch_w 96, stride_w 16 15 | (img_w - patch_w) MOD stride_w: 5 16 | So the W dim will be padded with additional 11 pixels 17 | new full images shape: (20, 1, 592, 576) 18 | 19 | Test images shape: (20, 1, 592, 576), vaule range (0.0 - 1.0): 20 | Number of patches on h : 32 21 | Number of patches on w : 31 22 | number of patches per image: 992, totally for this dataset: 19840 23 | test patches shape: (19840, 1, 96, 96), value range (0.0 - 1.0) 24 | N_patches_h: 32 25 | N_patches_w: 31 26 | N_patches_img: 992 27 | According to the dimension inserted, there are 20 images 28 | OrderedDict([('AUC_ROC', 0.980667), ('AUC_PR', 0.914875), ('f1-score', 0.82634), ('Acc', 0.957071), ('SE', 0.802406), ('SP', 0.979629), ('precision', 0.851746)]) 29 | -------------------------------------------------------------------------------- /experiments/Example/train_log.txt: -------------------------------------------------------------------------------- 1 | The computing device used is: GPU 2 | Total number of parameters: 34525954 3 | Architecture of Model have saved in Tensorboard! 4 | load data from /ssd/lzq/projects/vesselseg/prepare_dataset/data_path_list/DRIVE/train.txt  5 | ori data shape < ori_imgs:(20, 3, 584, 565) GTs:(20, 1, 584, 565) FOVs:(20, 1, 584, 565) 6 | imgs pixel range 0-255: 7 | GTs pixel range 0-255: 8 | FOVs pixel range 0-255: 9 | ==================data have loaded====================== 10 | 11 | Train images shape: (20, 1, 584, 565), vaule range (0.0 - 1.0): 12 | patches per image: 5000 13 | train patches shape: (100000, 1, 48, 48), value range (0.0 - 1.0) 14 | Validation on Testset!!! 
15 | load data from /ssd/lzq/projects/vesselseg/prepare_dataset/data_path_list/DRIVE/test.txt  16 | ori data shape < ori_imgs:(20, 3, 584, 565) GTs:(20, 1, 584, 565) FOVs:(20, 1, 584, 565) 17 | imgs pixel range 0-255: 18 | GTs pixel range 0-255: 19 | FOVs pixel range 0-255: 20 | ==================data have loaded====================== 21 | 22 | the side H is not compatible with the selected stride of 16 23 | img_h 584, patch_h 96, stride_h 16 24 | (img_h - patch_h) MOD stride_h: 8 25 | So the H dim will be padded with additional 8 pixels 26 | the side W is not compatible with the selected stride of 16 27 | img_w 565, patch_w 96, stride_w 16 28 | (img_w - patch_w) MOD stride_w: 5 29 | So the W dim will be padded with additional 11 pixels 30 | new full images shape: (20, 1, 592, 576) 31 | 32 | Test images shape: (20, 1, 592, 576), vaule range (0.0 - 1.0): 33 | Number of patches on h : 32 34 | Number of patches on w : 31 35 | number of patches per image: 992, totally for this dataset: 19840 36 | test patches shape: (19840, 1, 96, 96), value range (0.0 - 1.0) 37 | 38 | EPOCH: 1/50 --(learn_rate:0.000100) | Time: Fri Jan 22 21:37:23 2021 39 | N_patches_h: 32 40 | N_patches_w: 31 41 | N_patches_img: 992 42 | According to the dimension inserted, there are 20 full images (of 592x576 each) 43 | (20, 1, 592, 576) 44 | OrderedDict([('epoch', 1), ('train_loss', 0.149005), ('val_auc_roc', 0.978863), ('val_acc', 0.954974), ('val_f1', 0.822566)]) 45 | Saving best model! 46 | Best performance at Epoch: 1 | AUC_roc: 0.978863 47 | 48 | EPOCH: 2/50 --(learn_rate:0.000100) | Time: Fri Jan 22 21:39:56 2021 49 | N_patches_h: 32 50 | N_patches_w: 31 51 | N_patches_img: 992 52 | According to the dimension inserted, there are 20 full images (of 592x576 each) 53 | (20, 1, 592, 576) 54 | OrderedDict([('epoch', 2), ('train_loss', 0.112972), ('val_auc_roc', 0.980005), ('val_acc', 0.956922), ('val_f1', 0.826443)]) 55 | Saving best model! 56 | Best performance at Epoch: 2 | AUC_roc: 0.980005 57 | 58 | EPOCH: 3/50 --(learn_rate:0.000100) | Time: Fri Jan 22 21:42:30 2021 59 | N_patches_h: 32 60 | N_patches_w: 31 61 | N_patches_img: 992 62 | According to the dimension inserted, there are 20 full images (of 592x576 each) 63 | (20, 1, 592, 576) 64 | OrderedDict([('epoch', 3), ('train_loss', 0.107827), ('val_auc_roc', 0.980281), ('val_acc', 0.957215), ('val_f1', 0.82673)]) 65 | Saving best model! 66 | Best performance at Epoch: 3 | AUC_roc: 0.980281 67 | 68 | EPOCH: 4/50 --(learn_rate:0.000099) | Time: Fri Jan 22 21:45:05 2021 69 | N_patches_h: 32 70 | N_patches_w: 31 71 | N_patches_img: 992 72 | According to the dimension inserted, there are 20 full images (of 592x576 each) 73 | (20, 1, 592, 576) 74 | OrderedDict([('epoch', 4), ('train_loss', 0.104817), ('val_auc_roc', 0.980439), ('val_acc', 0.956755), ('val_f1', 0.828352)]) 75 | Saving best model! 76 | Best performance at Epoch: 4 | AUC_roc: 0.980439 77 | 78 | EPOCH: 5/50 --(learn_rate:0.000098) | Time: Fri Jan 22 21:47:39 2021 79 | N_patches_h: 32 80 | N_patches_w: 31 81 | N_patches_img: 992 82 | According to the dimension inserted, there are 20 full images (of 592x576 each) 83 | (20, 1, 592, 576) 84 | OrderedDict([('epoch', 5), ('train_loss', 0.102211), ('val_auc_roc', 0.980667), ('val_acc', 0.957071), ('val_f1', 0.826338)]) 85 | Saving best model! 
86 | Best performance at Epoch: 5 | AUC_roc: 0.980667 87 | 88 | EPOCH: 6/50 --(learn_rate:0.000098) | Time: Fri Jan 22 21:50:14 2021 89 | N_patches_h: 32 90 | N_patches_w: 31 91 | N_patches_img: 992 92 | According to the dimension inserted, there are 20 full images (of 592x576 each) 93 | (20, 1, 592, 576) 94 | OrderedDict([('epoch', 6), ('train_loss', 0.1), ('val_auc_roc', 0.980501), ('val_acc', 0.956719), ('val_f1', 0.827644)]) 95 | Best performance at Epoch: 5 | AUC_roc: 0.980667 96 | 97 | EPOCH: 7/50 --(learn_rate:0.000096) | Time: Fri Jan 22 21:52:47 2021 98 | N_patches_h: 32 99 | N_patches_w: 31 100 | N_patches_img: 992 101 | According to the dimension inserted, there are 20 full images (of 592x576 each) 102 | (20, 1, 592, 576) 103 | OrderedDict([('epoch', 7), ('train_loss', 0.097532), ('val_auc_roc', 0.97973), ('val_acc', 0.95682), ('val_f1', 0.826485)]) 104 | Best performance at Epoch: 5 | AUC_roc: 0.980667 105 | 106 | EPOCH: 8/50 --(learn_rate:0.000095) | Time: Fri Jan 22 21:55:20 2021 107 | N_patches_h: 32 108 | N_patches_w: 31 109 | N_patches_img: 992 110 | According to the dimension inserted, there are 20 full images (of 592x576 each) 111 | (20, 1, 592, 576) 112 | OrderedDict([('epoch', 8), ('train_loss', 0.095253), ('val_auc_roc', 0.979032), ('val_acc', 0.955755), ('val_f1', 0.822463)]) 113 | Best performance at Epoch: 5 | AUC_roc: 0.980667 114 | 115 | EPOCH: 9/50 --(learn_rate:0.000094) | Time: Fri Jan 22 21:57:53 2021 116 | N_patches_h: 32 117 | N_patches_w: 31 118 | N_patches_img: 992 119 | According to the dimension inserted, there are 20 full images (of 592x576 each) 120 | (20, 1, 592, 576) 121 | OrderedDict([('epoch', 9), ('train_loss', 0.09293), ('val_auc_roc', 0.979556), ('val_acc', 0.956356), ('val_f1', 0.824103)]) 122 | Best performance at Epoch: 5 | AUC_roc: 0.980667 123 | 124 | EPOCH: 10/50 --(learn_rate:0.000092) | Time: Fri Jan 22 22:00:26 2021 125 | N_patches_h: 32 126 | N_patches_w: 31 127 | N_patches_img: 992 128 | According to the dimension inserted, there are 20 full images (of 592x576 each) 129 | (20, 1, 592, 576) 130 | OrderedDict([('epoch', 10), ('train_loss', 0.09074), ('val_auc_roc', 0.978833), ('val_acc', 0.956238), ('val_f1', 0.822186)]) 131 | Best performance at Epoch: 5 | AUC_roc: 0.980667 132 | 133 | EPOCH: 11/50 --(learn_rate:0.000090) | Time: Fri Jan 22 22:02:59 2021 134 | N_patches_h: 32 135 | N_patches_w: 31 136 | N_patches_img: 992 137 | According to the dimension inserted, there are 20 full images (of 592x576 each) 138 | (20, 1, 592, 576) 139 | OrderedDict([('epoch', 11), ('train_loss', 0.088608), ('val_auc_roc', 0.978919), ('val_acc', 0.95645), ('val_f1', 0.822891)]) 140 | Best performance at Epoch: 5 | AUC_roc: 0.980667 141 | => early stopping 142 | -------------------------------------------------------------------------------- /experiments/README.md: -------------------------------------------------------------------------------- 1 | All experiments are saved in the `./experiments` folder and saved separately according to the name. 2 | The directory structure of each experiment is as follows: 3 | ``` 4 | ├── experiments # Experiment results folder 5 | │ ├── Example # An experiment 6 | │ ├── args.pkl # Saved configuration file (pkl format) 7 | │ ├── args.txt # Saved configuration file (txt format) 8 | │ ├── best_model.pth # Best performance model saved 9 | │ ├── events.out.tfevents.00.Server # Tensorboard log files (including loss, acc and auc, etc.) 
10 | │ ├── latest_model.pth # Latest model saved
11 | │ ├── log_2021-01-14-23-09.csv # csv log file (including val loss, acc and auc, etc.)
12 | │ ├── performance.txt # Performance on the testset
13 | │ ├── Precision_recall.png # P-R curve on the testset
14 | │ ├── result_img # Visualized results of the testset
15 | │ ├── result.npy # Pixel probability prediction result
16 | │ ├── ROC.png # ROC curve on the testset
17 | │ ├── sample_input_imgs.png # Input image patches example
18 | │ ├── sample_input_masks.png # Input label example
19 | │ ├── test_log.txt # Test process log
20 | │ └── train_log.txt # Training process log
21 | ```
22 | PS: The `./experiments/Example` folder is just an example, in which `best_model.pth`, `latest_model.pth` and `result.npy` are empty files and cannot be used directly. You need to retrain the model.
--------------------------------------------------------------------------------
/function.py:
--------------------------------------------------------------------------------
1 | import random
2 | from os.path import join
3 | from lib.extract_patches import get_data_train
4 | from lib.losses.loss import *
5 | from lib.visualize import group_images, save_img
6 | from lib.common import *
7 | from lib.dataset import TrainDataset
8 | from torch.utils.data import DataLoader
9 | from collections import OrderedDict
10 | from lib.metrics import Evaluate
11 | from lib.visualize import group_images, save_img
12 | from lib.extract_patches import get_data_train
13 | from lib.datasetV2 import data_preprocess,create_patch_idx,TrainDatasetV2
14 | from tqdm import tqdm
15 | 
16 | # ========================get dataloader==============================
17 | def get_dataloader(args):
18 |     """
19 |     This function loads the dataset and extracts all training patches into memory at once, so memory usage is high and may cause out-of-memory errors.
20 |     """
21 |     patches_imgs_train, patches_masks_train = get_data_train(
22 |         data_path_list = args.train_data_path_list,
23 |         patch_height = args.train_patch_height,
24 |         patch_width = args.train_patch_width,
25 |         N_patches = args.N_patches,
26 |         inside_FOV = args.inside_FOV # select patches only inside the FOV
27 |     )
28 |     val_ind = random.sample(range(patches_masks_train.shape[0]),int(np.floor(args.val_ratio*patches_masks_train.shape[0])))
29 |     train_ind = set(range(patches_masks_train.shape[0])) - set(val_ind)
30 |     train_ind = list(train_ind)
31 | 
32 |     train_set = TrainDataset(patches_imgs_train[train_ind,...],patches_masks_train[train_ind,...],mode="train")
33 |     train_loader = DataLoader(train_set, batch_size=args.batch_size,
34 |                               shuffle=True, num_workers=6)
35 | 
36 |     val_set = TrainDataset(patches_imgs_train[val_ind,...],patches_masks_train[val_ind,...],mode="val")
37 |     val_loader = DataLoader(val_set, batch_size=args.batch_size,
38 |                             shuffle=False, num_workers=6)
39 |     # Save some of the sample patches fed to the neural network
40 |     if args.sample_visualization:
41 |         N_sample = min(patches_imgs_train.shape[0], 50)
42 |         save_img(group_images((patches_imgs_train[0:N_sample, :, :, :]*255).astype(np.uint8), 10),
43 |                  join(args.outf, args.save, "sample_input_imgs.png"))
44 |         save_img(group_images((patches_masks_train[0:N_sample, :, :, :]*255).astype(np.uint8), 10),
45 |                  join(args.outf, args.save,"sample_input_masks.png"))
46 |     return train_loader,val_loader
47 | 
48 | 
49 | def get_dataloaderV2(args):
50 |     """
51 |     This function loads all dataset images into memory and only builds an index of patch extraction positions, so it uses much less memory.
52 |     Tests show that, compared with the original get_dataloader above, it does not reduce training efficiency.
53 |     """
54 |     imgs_train, masks_train, fovs_train = data_preprocess(data_path_list = args.train_data_path_list)
55 | 
56 |     patches_idx =
create_patch_idx(fovs_train, args) 57 | 58 | train_idx,val_idx = np.vsplit(patches_idx, (int(np.floor((1-args.val_ratio)*patches_idx.shape[0])),)) 59 | 60 | train_set = TrainDatasetV2(imgs_train, masks_train, fovs_train,train_idx,mode="train",args=args) 61 | train_loader = DataLoader(train_set, batch_size=args.batch_size, 62 | shuffle=True, num_workers=0) 63 | 64 | val_set = TrainDatasetV2(imgs_train, masks_train, fovs_train,val_idx,mode="val",args=args) 65 | val_loader = DataLoader(val_set, batch_size=args.batch_size, 66 | shuffle=False, num_workers=0) 67 | 68 | # Save some samples of feeding to the neural network 69 | if args.sample_visualization: 70 | visual_set = TrainDatasetV2(imgs_train, masks_train, fovs_train,val_idx,mode="val",args=args) 71 | visual_loader = DataLoader(visual_set, batch_size=1,shuffle=True, num_workers=0) 72 | N_sample = 50 73 | visual_imgs = np.empty((N_sample,1,args.train_patch_height, args.train_patch_width)) 74 | visual_masks = np.empty((N_sample,1,args.train_patch_height, args.train_patch_width)) 75 | 76 | for i, (img, mask) in tqdm(enumerate(visual_loader)): 77 | visual_imgs[i] = np.squeeze(img.numpy(),axis=0) 78 | visual_masks[i,0] = np.squeeze(mask.numpy(),axis=0) 79 | if i>=N_sample-1: 80 | break 81 | save_img(group_images((visual_imgs[0:N_sample, :, :, :]*255).astype(np.uint8), 10), 82 | join(args.outf, args.save, "sample_input_imgs.png")) 83 | save_img(group_images((visual_masks[0:N_sample, :, :, :]*255).astype(np.uint8), 10), 84 | join(args.outf, args.save,"sample_input_masks.png")) 85 | return train_loader,val_loader 86 | 87 | # =======================train======================== 88 | def train(train_loader,net,criterion,optimizer,device): 89 | net.train() 90 | train_loss = AverageMeter() 91 | 92 | for batch_idx, (inputs, targets) in tqdm(enumerate(train_loader), total=len(train_loader)): 93 | inputs, targets = inputs.to(device), targets.to(device) 94 | optimizer.zero_grad() 95 | 96 | outputs = net(inputs) 97 | loss = criterion(outputs, targets) 98 | loss.backward() 99 | optimizer.step() 100 | 101 | train_loss.update(loss.item(), inputs.size(0)) 102 | log = OrderedDict([('train_loss',train_loss.avg)]) 103 | return log 104 | 105 | # ========================val=============================== 106 | def val(val_loader,net,criterion,device): 107 | net.eval() 108 | val_loss = AverageMeter() 109 | evaluater = Evaluate() 110 | with torch.no_grad(): 111 | for batch_idx, (inputs, targets) in tqdm(enumerate(val_loader), total=len(val_loader)): 112 | inputs, targets = inputs.to(device), targets.to(device) 113 | outputs = net(inputs) 114 | loss = criterion(outputs, targets) 115 | val_loss.update(loss.item(), inputs.size(0)) 116 | 117 | outputs = outputs.data.cpu().numpy() 118 | targets = targets.data.cpu().numpy() 119 | evaluater.add_batch(targets,outputs[:,1]) 120 | log = OrderedDict([('val_loss', val_loss.avg), 121 | ('val_acc', evaluater.confusion_matrix()[1]), 122 | ('val_f1', evaluater.f1_score()), 123 | ('val_auc_roc', evaluater.auc_roc())]) 124 | return log -------------------------------------------------------------------------------- /lib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lee-zq/VesselSeg-Pytorch/b4f6571fc1fb1fbdaad60ff9282a54a1f1c455fa/lib/__init__.py -------------------------------------------------------------------------------- /lib/common.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import 
os,joblib 3 | import torch,random 4 | import torch.nn as nn 5 | import cv2,imageio,PIL 6 | from libtiff import TIFFfile 7 | 8 | def readImg(img_path): 9 | """ 10 | When reading local image data, because the format of the data set is not uniform, 11 | the reading method needs to be considered. 12 | Default using pillow to read the desired RGB format img 13 | """ 14 | img_format = img_path.split(".")[-1] 15 | try: 16 | #在win下读取tif格式图像在转np的时候异常终止,暂时没找到合适的读取方式,Linux下直接用PIl读取无问题 17 | img = PIL.Image.open(img_path) 18 | except Exception as e: 19 | ValueError("Reading failed, please check path of dataset,",img_path) 20 | return img 21 | 22 | def count_parameters(model): 23 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 24 | 25 | class AverageMeter(object): 26 | """Computes and stores the average and current value for calculate average loss""" 27 | def __init__(self): 28 | self.reset() 29 | 30 | def reset(self): 31 | self.val = 0 32 | self.avg = 0 33 | self.sum = 0 34 | self.count = 0 35 | 36 | def update(self, val, n): 37 | self.val = val 38 | self.sum += val * n 39 | self.count += n 40 | self.avg = self.sum / self.count 41 | # print(self.val) 42 | 43 | # formulate a learning rate decay strategy 44 | def make_lr_schedule(lr_epoch,lr_value): 45 | lr_schedule = np.zeros(lr_epoch[-1]) 46 | for l in range(len(lr_epoch)): 47 | if l == 0: 48 | lr_schedule[0:lr_epoch[l]] = lr_value[l] 49 | else: 50 | lr_schedule[lr_epoch[l - 1]:lr_epoch[l]] = lr_value[l] 51 | return lr_schedule 52 | 53 | # Save configuration information 54 | def save_args(args,save_path): 55 | if not os.path.exists(save_path): 56 | os.makedirs('%s' % save_path) 57 | 58 | print('Config info -----') 59 | for arg in vars(args): 60 | print('%s: %s' % (arg, getattr(args, arg))) 61 | with open('%s/args.txt' % save_path, 'w') as f: 62 | for arg in vars(args): 63 | print('%s: %s' % (arg, getattr(args, arg)), file=f) 64 | joblib.dump(args, '%s/args.pkl' % save_path) 65 | print('\033[0;33m================config infomation has been saved=================\033[0m') 66 | 67 | # Seed for repeatability 68 | def setpu_seed(seed): 69 | torch.manual_seed(seed) 70 | torch.cuda.manual_seed_all(seed) 71 | np.random.seed(seed) 72 | torch.backends.cudnn.deterministic=True 73 | random.seed(seed) 74 | 75 | # Round off 76 | def dict_round(dic,num): 77 | for key,value in dic.items(): 78 | dic[key] = round(value,num) 79 | return dic 80 | 81 | # params initialization 82 | def weight_initV1(m): 83 | if isinstance(m, nn.Conv2d): 84 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 85 | if m.bias is not None: 86 | nn.init.constant_(m.bias, 0) 87 | elif isinstance(m, nn.BatchNorm2d): 88 | nn.init.constant_(m.weight, 1) 89 | nn.init.constant_(m.bias, 0) 90 | elif isinstance(m, nn.Linear): 91 | nn.init.normal_(m.weight, 0, 0.01) 92 | if m.bias is not None: 93 | nn.init.constant_(m.bias, 0) 94 | 95 | def weight_initV2(m): 96 | if isinstance(m, nn.Linear): 97 | nn.init.xavier_normal_(m.weight) 98 | nn.init.constant_(m.bias, 0) 99 | elif isinstance(m, nn.Conv2d): 100 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 101 | elif isinstance(m, nn.BatchNorm2d): 102 | nn.init.constant_(m.weight, 1) 103 | nn.init.constant_(m.bias, 0) 104 | 105 | def weight_initV3(net, init_type='normal', gain=0.02): 106 | def init_func(m): 107 | classname = m.__class__.__name__ 108 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 109 | if init_type == 'normal': 110 | 
init.normal_(m.weight.data, 0.0, gain) 111 | elif init_type == 'xavier': 112 | init.xavier_normal_(m.weight.data, gain=gain) 113 | elif init_type == 'kaiming': 114 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 115 | elif init_type == 'orthogonal': 116 | init.orthogonal_(m.weight.data, gain=gain) 117 | else: 118 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 119 | if hasattr(m, 'bias') and m.bias is not None: 120 | init.constant_(m.bias.data, 0.0) 121 | elif classname.find('BatchNorm2d') != -1: 122 | init.normal_(m.weight.data, 1.0, gain) 123 | init.constant_(m.bias.data, 0.0) 124 | 125 | print('initialize network with %s' % init_type) 126 | net.apply(init_func) 127 | -------------------------------------------------------------------------------- /lib/dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | This part is based on the dataset class implemented by pytorch, 3 | including train_dataset and test_dataset, as well as data augmentation 4 | """ 5 | from torch.utils.data import Dataset 6 | import torch 7 | import numpy as np 8 | import random 9 | import torch.nn.functional as F 10 | from torchvision import transforms 11 | from torchvision.transforms.functional import normalize 12 | 13 | class TrainDataset(Dataset): 14 | def __init__(self, patches_imgs,patches_masks,mode="train"): 15 | self.imgs = patches_imgs 16 | self.masks = patches_masks 17 | self.transforms = None 18 | if mode == "train": 19 | self.transforms = Compose([ 20 | # RandomResize([56,72],[56,72]), 21 | RandomCrop((48, 48)), 22 | RandomFlip_LR(prob=0.5), 23 | RandomFlip_UD(prob=0.5), 24 | RandomRotate() 25 | ]) 26 | 27 | def __len__(self): 28 | return self.imgs.shape[0] 29 | 30 | def __getitem__(self, idx): 31 | mask = self.masks[idx] 32 | data = self.imgs[idx] 33 | 34 | data = torch.from_numpy(data).float() 35 | mask = torch.from_numpy(mask).long() 36 | 37 | if self.transforms: 38 | data, mask = self.transforms(data, mask) 39 | return data, mask.squeeze(0) 40 | 41 | #----------------------data augment------------------------------------------- 42 | class Resize: 43 | def __init__(self, shape): 44 | self.shape = [shape, shape] if isinstance(shape, int) else shape 45 | 46 | def __call__(self, img, mask): 47 | img, mask = img.unsqueeze(0), mask.unsqueeze(0).float() 48 | img = F.interpolate(img, size=self.shape, mode="bilinear", align_corners=False) 49 | mask = F.interpolate(mask, size=self.shape, mode="nearest") 50 | return img[0], mask[0].byte() 51 | 52 | class RandomResize: 53 | def __init__(self, w_rank,h_rank): 54 | self.w_rank = w_rank 55 | self.h_rank = h_rank 56 | 57 | def __call__(self, img, mask): 58 | random_w = random.randint(self.w_rank[0],self.w_rank[1]) 59 | random_h = random.randint(self.h_rank[0],self.h_rank[1]) 60 | self.shape = [random_w,random_h] 61 | img, mask = img.unsqueeze(0), mask.unsqueeze(0).float() 62 | img = F.interpolate(img, size=self.shape, mode="bilinear", align_corners=False) 63 | mask = F.interpolate(mask, size=self.shape, mode="nearest") 64 | return img[0], mask[0].long() 65 | 66 | class RandomCrop: 67 | def __init__(self, shape): 68 | self.shape = [shape, shape] if isinstance(shape, int) else shape 69 | self.fill = 0 70 | self.padding_mode = 'constant' 71 | 72 | def _get_range(self, shape, crop_shape): 73 | if shape == crop_shape: 74 | start = 0 75 | else: 76 | start = random.randint(0, shape - crop_shape) 77 | end = start + crop_shape 78 | return start, end 79 | 80 | def __call__(self, 
img, mask): 81 | _, h, w = img.shape 82 | sh, eh = self._get_range(h, self.shape[0]) 83 | sw, ew = self._get_range(w, self.shape[1]) 84 | return img[:, sh:eh, sw:ew], mask[:, sh:eh, sw:ew] 85 | 86 | class RandomFlip_LR: 87 | def __init__(self, prob=0.5): 88 | self.prob = prob 89 | 90 | def _flip(self, img, prob): 91 | if prob[0] <= self.prob: 92 | img = img.flip(2) 93 | return img 94 | 95 | def __call__(self, img, mask): 96 | prob = (random.uniform(0, 1), random.uniform(0, 1)) 97 | return self._flip(img, prob), self._flip(mask, prob) 98 | 99 | class RandomFlip_UD: 100 | def __init__(self, prob=0.5): 101 | self.prob = prob 102 | 103 | def _flip(self, img, prob): 104 | if prob[1] <= self.prob: 105 | img = img.flip(1) 106 | return img 107 | 108 | def __call__(self, img, mask): 109 | prob = (random.uniform(0, 1), random.uniform(0, 1)) 110 | return self._flip(img, prob), self._flip(mask, prob) 111 | 112 | class RandomRotate: 113 | def __init__(self, max_cnt=3): 114 | self.max_cnt = max_cnt 115 | 116 | def _rotate(self, img, cnt): 117 | img = torch.rot90(img,cnt,[1,2]) 118 | return img 119 | 120 | def __call__(self, img, mask): 121 | cnt = random.randint(0,self.max_cnt) 122 | return self._rotate(img, cnt), self._rotate(mask, cnt) 123 | 124 | 125 | class ToTensor: 126 | def __init__(self): 127 | self.to_tensor = transforms.ToTensor() 128 | 129 | def __call__(self, img, mask): 130 | img = self.to_tensor(img) 131 | mask = torch.from_numpy(np.array(mask)) 132 | return img, mask[None] 133 | 134 | 135 | class Normalize: 136 | def __init__(self, mean, std): 137 | self.mean = mean 138 | self.std = std 139 | 140 | def __call__(self, img, mask): 141 | return normalize(img, self.mean, self.std, False), mask 142 | 143 | 144 | class Compose: 145 | def __init__(self, transforms): 146 | self.transforms = transforms 147 | 148 | def __call__(self, img, mask): 149 | for t in self.transforms: 150 | img, mask = t(img, mask) 151 | return img, mask 152 | 153 | 154 | class TestDataset(Dataset): 155 | """Endovis 2018 dataset.""" 156 | 157 | def __init__(self, patches_imgs): 158 | self.imgs = patches_imgs 159 | 160 | def __len__(self): 161 | return self.imgs.shape[0] 162 | 163 | def __getitem__(self, idx): 164 | return torch.from_numpy(self.imgs[idx,...]).float() 165 | 166 | #----------------------image aug-------------------------------------- 167 | class TrainDataset_imgaug(Dataset): 168 | """Endovis 2018 dataset.""" 169 | 170 | def __init__(self, patches_imgs,patches_masks_train): 171 | self.imgs = patches_imgs 172 | self.masks = patches_masks_train 173 | self.seq = iaa.Sequential([ 174 | # iaa.Sharpen((0.1, 0.5)), 175 | iaa.flip.Fliplr(p=0.5), 176 | iaa.flip.Flipud(p=0.5), 177 | # sharpen the image 178 | # iaa.GaussianBlur(sigma=(0.0, 0.1)), # apply water effect (affects heatmaps) 179 | # iaa.Affine(rotate=(-20, 20)), 180 | # iaa.ElasticTransformation(alpha=16, sigma=8), # water-like effect 181 | ], random_order=True) 182 | 183 | def __len__(self): 184 | return self.imgs.shape[0] 185 | 186 | def __getitem__(self, idx): 187 | mask = self.masks[idx,0] 188 | data = self.imgs[idx] 189 | data = data.transpose((1,2,0)) 190 | mask = ia.SegmentationMapsOnImage(mask, shape=data.shape) 191 | 192 | # 这里可以通过加入循环的方式,对多张图进行数据增强。 193 | seq_det = self.seq.to_deterministic() # 确定一个数据增强的序列 194 | data = seq_det.augment_image(data).transpose((2,0,1))/255.0 # 将方法应用在原图像上 195 | mask = seq_det.augment_segmentation_maps([mask])[0].get_arr().astype(np.uint8) 196 | 197 | return torch.from_numpy(data).float(), torch.from_numpy(mask).long() 
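# --- Usage sketch (illustrative example, not part of the original file) ------
# A minimal demonstration of how TrainDataset and the Compose-based
# augmentations above can be fed to a torch DataLoader. The random arrays
# below merely stand in for the patch arrays normally produced by
# lib/extract_patches.get_data_train; real patches must be at least 48x48
# because "train" mode applies RandomCrop((48, 48)).
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    demo_imgs = np.random.rand(32, 1, 48, 48).astype(np.float32)         # [N, C, H, W], values in [0, 1]
    demo_masks = (np.random.rand(32, 1, 48, 48) > 0.5).astype(np.uint8)  # binary vessel labels

    demo_set = TrainDataset(demo_imgs, demo_masks, mode="train")
    demo_loader = DataLoader(demo_set, batch_size=8, shuffle=True)
    for img, mask in demo_loader:
        # img: float tensor [8, 1, 48, 48]; mask: long tensor [8, 48, 48]
        print(img.shape, mask.shape)
        break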
-------------------------------------------------------------------------------- /lib/datasetV2.py: -------------------------------------------------------------------------------- 1 | """ 2 | This dataset is used reduce memory usage during training 3 | """ 4 | from torch.utils.data import Dataset 5 | import torch 6 | import numpy as np 7 | import random 8 | import torch.nn.functional as F 9 | from torchvision import transforms 10 | from torchvision.transforms.functional import normalize 11 | 12 | from .extract_patches import load_data, my_PreProc, is_patch_inside_FOV 13 | from .dataset import RandomCrop, RandomFlip_LR, RandomFlip_UD, RandomRotate, Compose 14 | 15 | class TrainDatasetV2(Dataset): 16 | def __init__(self, imgs,masks,fovs,patches_idx,mode,args): 17 | self.imgs = imgs 18 | 19 | self.masks = masks 20 | self.fovs = fovs 21 | self.patch_h, self.patch_w = args.train_patch_height, args.train_patch_width 22 | self.patches_idx = patches_idx 23 | self.inside_FOV = args.inside_FOV 24 | self.transforms = None 25 | if mode == "train": 26 | self.transforms = Compose([ 27 | # RandomResize([56,72],[56,72]), 28 | RandomCrop((48, 48)), 29 | RandomFlip_LR(prob=0.5), 30 | RandomFlip_UD(prob=0.5), 31 | RandomRotate() 32 | ]) 33 | 34 | def __len__(self): 35 | return len(self.patches_idx) 36 | 37 | def __getitem__(self, idx): 38 | n, x_center, y_center = self.patches_idx[idx] 39 | 40 | data = self.imgs[n,:,y_center-int(self.patch_h/2):y_center+int(self.patch_h/2),x_center-int(self.patch_w/2):x_center+int(self.patch_w/2)] 41 | mask = self.masks[n,:,y_center-int(self.patch_h/2):y_center+int(self.patch_h/2),x_center-int(self.patch_w/2):x_center+int(self.patch_w/2)] 42 | 43 | data = torch.from_numpy(data).float() 44 | mask = torch.from_numpy(mask).long() 45 | 46 | if self.transforms: 47 | data, mask = self.transforms(data, mask) 48 | return data, mask.squeeze(0) 49 | 50 | 51 | #----------------------Related Methon-------------------------------------- 52 | def data_preprocess(data_path_list): 53 | train_imgs_original, train_masks, train_FOVs = load_data(data_path_list) 54 | # save_img(group_images(train_imgs_original[0:20,:,:,:],5),'imgs_train.png')#.show() #check original train imgs 55 | 56 | train_imgs = my_PreProc(train_imgs_original) 57 | train_masks = train_masks//255 58 | train_FOVs = train_FOVs//255 59 | return train_imgs, train_masks, train_FOVs 60 | 61 | def create_patch_idx(img_fovs, args): 62 | assert len(img_fovs.shape)==4 63 | N,C,img_h,img_w = img_fovs.shape 64 | res = np.empty((args.N_patches,3),dtype=int) 65 | print("") 66 | 67 | seed=2021 68 | count = 0 69 | while count < args.N_patches: 70 | random.seed(seed) # fuxian 71 | seed+=1 72 | n = random.randint(0,N-1) 73 | x_center = random.randint(0+int(args.train_patch_width/2),img_w-int(args.train_patch_width/2)) 74 | y_center = random.randint(0+int(args.train_patch_height/2),img_h-int(args.train_patch_height/2)) 75 | 76 | #check whether the patch is contained in the FOV 77 | if args.inside_FOV=='center' or args.inside_FOV == 'all': 78 | if not is_patch_inside_FOV(x_center,y_center,img_fovs[n,0],args.train_patch_height,args.train_patch_width,mode=args.inside_FOV): 79 | continue 80 | res[count] = np.asarray([n,x_center,y_center]) 81 | count+=1 82 | 83 | return res 84 | 85 | -------------------------------------------------------------------------------- /lib/extract_patches.py: -------------------------------------------------------------------------------- 1 | """ 2 | This part mainly contains functions related to extracting image 
patches. 3 | The image patches are randomly extracted in the fov(optional) during the training phase, 4 | and the test phase needs to be spliced after splitting 5 | """ 6 | import numpy as np 7 | import random 8 | import configparser 9 | 10 | from .visualize import save_img, group_images 11 | from .common import readImg 12 | from .pre_processing import my_PreProc 13 | 14 | #=================Load imgs from disk with txt files===================================== 15 | # Load data path index file 16 | def load_file_path_txt(file_path): 17 | img_list = [] 18 | gt_list = [] 19 | fov_list = [] 20 | with open(file_path, 'r') as file_to_read: 21 | while True: 22 | lines = file_to_read.readline().strip() # read a line 23 | if not lines: 24 | break 25 | img,gt,fov = lines.split(' ') 26 | img_list.append(img) 27 | gt_list.append(gt) 28 | fov_list.append(fov) 29 | return img_list,gt_list,fov_list 30 | 31 | # Load the original image, grroundtruth and FOV of the data set in order, and check the dimensions 32 | def load_data(data_path_list_file): 33 | print('\033[0;33mload data from {} \033[0m'.format(data_path_list_file)) 34 | img_list, gt_list, fov_list = load_file_path_txt(data_path_list_file) 35 | imgs = None 36 | groundTruth = None 37 | FOVs = None 38 | for i in range(len(img_list)): 39 | img = np.asarray(readImg(img_list[i])) 40 | gt = np.asarray(readImg(gt_list[i])) 41 | if len(gt.shape)==3: 42 | gt = gt[:,:,0] 43 | fov = np.asarray(readImg(fov_list[i])) 44 | if len(fov.shape)==3: 45 | fov = fov[:,:,0] 46 | 47 | imgs = np.expand_dims(img,0) if imgs is None else np.concatenate((imgs,np.expand_dims(img,0))) 48 | groundTruth = np.expand_dims(gt,0) if groundTruth is None else np.concatenate((groundTruth,np.expand_dims(gt,0))) 49 | FOVs = np.expand_dims(fov,0) if FOVs is None else np.concatenate((FOVs,np.expand_dims(fov,0))) 50 | 51 | assert(np.min(FOVs)==0 and np.max(FOVs)==255) 52 | assert((np.min(groundTruth)==0 and (np.max(groundTruth)==255 or np.max(groundTruth)==1))) # CHASE_DB1数据集GT图像为单通道二值(0和1)图像 53 | if np.max(groundTruth)==1: 54 | print("\033[0;31m Single channel binary image is multiplied by 255 \033[0m") 55 | groundTruth = groundTruth * 255 56 | 57 | #Convert the dimension of imgs to [N,C,H,W] 58 | imgs = np.transpose(imgs,(0,3,1,2)) 59 | groundTruth = np.expand_dims(groundTruth,1) 60 | FOVs = np.expand_dims(FOVs,1) 61 | print('ori data shape < ori_imgs:{} GTs:{} FOVs:{}'.format(imgs.shape,groundTruth.shape,FOVs.shape)) 62 | print("imgs pixel range %s-%s: " %(str(np.min(imgs)),str(np.max(imgs)))) 63 | print("GTs pixel range %s-%s: " %(str(np.min(groundTruth)),str(np.max(groundTruth)))) 64 | print("FOVs pixel range %s-%s: " %(str(np.min(FOVs)),str(np.max(FOVs)))) 65 | print("==================data have loaded======================") 66 | return imgs, groundTruth, FOVs 67 | 68 | #==============================Load train data============================================== 69 | #Load the original data and return the extracted patches for training 70 | def get_data_train(data_path_list,patch_height,patch_width,N_patches,inside_FOV): 71 | train_imgs_original, train_masks, train_FOVs = load_data(data_path_list) 72 | # save_img(group_images(train_imgs_original[0:20,:,:,:],5),'imgs_train.png')#.show() #check original train imgs 73 | 74 | train_imgs = my_PreProc(train_imgs_original) 75 | train_masks = train_masks/255. 
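    # Ground-truth masks come out of load_data() as {0, 255}; dividing by 255.
    # yields float labels in {0., 1.}, and the integer division of the FOVs just
    # below maps them to {0, 1}, matching the value-range asserts further down.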
76 | train_FOVs = train_FOVs//255 77 | 78 | # Crop edge (optional) 79 | # train_imgs = train_imgs[:,:,9:-9,9:-9] 80 | # train_masks = train_masks[:,:,9:-9,9:-9] 81 | # train_FOVs = train_FOVs[:,:,9:-9,9:-9] 82 | 83 | # Check dimensions 84 | data_dim_check(train_imgs,train_masks) 85 | assert(np.min(train_masks)==0 and np.max(train_masks)==1) 86 | assert(np.min(train_FOVs)==0 and np.max(train_FOVs)==1) 87 | #check masks are within 0-1 88 | print("\nTrain images shape: {}, vaule range ({} - {}):"\ 89 | .format(train_imgs.shape, str(np.min(train_imgs)), str(np.max(train_imgs)))) 90 | 91 | #extract the train patches from all images 92 | patches_imgs_train, patches_masks_train = extract_random(train_imgs,train_masks,train_FOVs,patch_height,patch_width,N_patches,inside_FOV) 93 | data_dim_check(patches_imgs_train, patches_masks_train) 94 | 95 | print("train patches shape: {}, value range ({} - {})"\ 96 | .format(patches_imgs_train.shape, str(np.min(patches_imgs_train)), str(np.max(patches_imgs_train)))) 97 | 98 | return patches_imgs_train, patches_masks_train 99 | 100 | # extract patches randomly in the training images 101 | def extract_random(full_imgs,full_masks,full_FOVs, patch_h,patch_w, N_patches, inside='not'): 102 | patch_per_img = int(N_patches/full_imgs.shape[0]) 103 | if (N_patches%full_imgs.shape[0] != 0): 104 | print("\033[0;31mRecommended N_patches be set as a multiple of train img numbers\033[0m") 105 | N_patches = patch_per_img * full_imgs.shape[0] 106 | print("patches per image: " +str(patch_per_img), " Total number of patches:", N_patches) 107 | patches = np.empty((N_patches,full_imgs.shape[1],patch_h,patch_w)) 108 | patches_masks = np.empty((N_patches,full_masks.shape[1],patch_h,patch_w), dtype=np.uint8) 109 | img_h = full_imgs.shape[2] #height of the image 110 | img_w = full_imgs.shape[3] #width of the image 111 | 112 | iter_tot = 0 #iter over the total numbe rof patches (N_patches) 113 | for i in range(full_imgs.shape[0]): #loop over the all images 114 | k=0 115 | while k division between integers 212 | N_patches_tot = N_patches_img*full_imgs.shape[0] 213 | print("Number of patches on h : " +str(((img_h-patch_h)//stride_h+1))) 214 | print("Number of patches on w : " +str(((img_w-patch_w)//stride_w+1))) 215 | print("number of patches per image: " +str(N_patches_img) +", totally for testset: " +str(N_patches_tot)) 216 | patches = np.empty((N_patches_tot,full_imgs.shape[1],patch_h,patch_w)) 217 | iter_tot = 0 #iter over the total number of patches (N_patches) 218 | for i in range(full_imgs.shape[0]): #loop over the full images 219 | for h in range((img_h-patch_h)//stride_h+1): 220 | for w in range((img_w-patch_w)//stride_w+1): 221 | patch = full_imgs[i,:,h*stride_h:(h*stride_h)+patch_h,w*stride_w:(w*stride_w)+patch_w] 222 | patches[iter_tot]=patch 223 | iter_tot +=1 #total 224 | assert (iter_tot==N_patches_tot) 225 | return patches #array with all the full_imgs divided in patches 226 | 227 | # recompone the prediction result patches to images 228 | def recompone_overlap(preds, img_h, img_w, stride_h, stride_w): 229 | assert (len(preds.shape)==4) #4D arrays 230 | assert (preds.shape[1]==1 or preds.shape[1]==3) #check the channel is 1 or 3 231 | patch_h = preds.shape[2] 232 | patch_w = preds.shape[3] 233 | N_patches_h = (img_h-patch_h)//stride_h+1 234 | N_patches_w = (img_w-patch_w)//stride_w+1 235 | N_patches_img = N_patches_h * N_patches_w 236 | # print("N_patches_h: " + str(N_patches_h)) 237 | # print("N_patches_w: " + str(N_patches_w)) 238 | # print("N_patches_img: " + 
str(N_patches_img)) 239 | assert (preds.shape[0]%N_patches_img==0) 240 | N_full_imgs = preds.shape[0]//N_patches_img 241 | print("There are " +str(N_full_imgs) +" images in Testset") 242 | full_prob = np.zeros((N_full_imgs,preds.shape[1],img_h,img_w)) 243 | full_sum = np.zeros((N_full_imgs,preds.shape[1],img_h,img_w)) 244 | 245 | k = 0 #iterator over all the patches 246 | for i in range(N_full_imgs): 247 | for h in range((img_h-patch_h)//stride_h+1): 248 | for w in range((img_w-patch_w)//stride_w+1): 249 | full_prob[i,:,h*stride_h:(h*stride_h)+patch_h,w*stride_w:(w*stride_w)+patch_w]+=preds[k] # Accumulate predicted values 250 | full_sum[i,:,h*stride_h:(h*stride_h)+patch_h,w*stride_w:(w*stride_w)+patch_w]+=1 # Accumulate the number of predictions 251 | k+=1 252 | assert(k==preds.shape[0]) 253 | assert(np.min(full_sum)>=1.0) 254 | final_avg = full_prob/full_sum # Take the average 255 | # print(final_avg.shape) 256 | assert(np.max(final_avg)<=1.0) # max value for a pixel is 1.0 257 | assert(np.min(final_avg)>=0.0) # min value for a pixel is 0.0 258 | return final_avg 259 | 260 | #return only the predicted pixels contained in the FOV, for both images and masks 261 | def pred_only_in_FOV(data_imgs,data_masks,FOVs): 262 | assert (len(data_imgs.shape)==4 and len(data_masks.shape)==4) #4D arrays 263 | height = data_imgs.shape[2] 264 | width = data_imgs.shape[3] 265 | new_pred_imgs = [] 266 | new_pred_masks = [] 267 | for i in range(data_imgs.shape[0]): #loop over the all test images 268 | for x in range(width): 269 | for y in range(height): 270 | if pixel_inside_FOV(i,x,y,FOVs): 271 | new_pred_imgs.append(data_imgs[i,:,y,x]) 272 | new_pred_masks.append(data_masks[i,:,y,x]) 273 | new_pred_imgs = np.asarray(new_pred_imgs) 274 | new_pred_masks = np.asarray(new_pred_masks) 275 | return new_pred_imgs, new_pred_masks 276 | 277 | # Set the pixel value outside FOV to 0, only for visualization 278 | def kill_border(data, FOVs): 279 | assert (len(data.shape)==4) #4D arrays 280 | assert (data.shape[1]==1 or data.shape[1]==3) #check the channel is 1 or 3 281 | height = data.shape[2] 282 | width = data.shape[3] 283 | for i in range(data.shape[0]): #loop over the full images 284 | for x in range(width): 285 | for y in range(height): 286 | if not pixel_inside_FOV(i,x,y,FOVs): 287 | data[i,:,y,x]=0.0 288 | 289 | # function to judge pixel(x,y) in FOV or not 290 | def pixel_inside_FOV(i, x, y, FOVs): 291 | assert (len(FOVs.shape)==4) #4D arrays 292 | assert (FOVs.shape[1]==1) 293 | if (x >= FOVs.shape[3] or y >= FOVs.shape[2]): # Pixel position is out of range 294 | return False 295 | return FOVs[i,0,y,x]>0 #0==black pixels 296 | 297 | -------------------------------------------------------------------------------- /lib/logger.py: -------------------------------------------------------------------------------- 1 | 2 | import sys 3 | import os 4 | import pandas as pd 5 | from tensorboardX import SummaryWriter 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | import time 9 | from collections import OrderedDict 10 | from .common import dict_round 11 | 12 | # Record data in tensorboard and .csv files during training stage 13 | class Logger(): 14 | def __init__(self,save_name): 15 | self.log = None 16 | self.summary = None 17 | self.name = save_name 18 | self.time_now = time.strftime('_%Y-%m-%d-%H-%M', time.localtime()) 19 | 20 | def update(self,epoch,train_log,val_log): 21 | item = OrderedDict({'epoch':epoch}) 22 | item.update(train_log) 23 | item.update(val_log) 24 | item = dict_round(item,6) # 
保留小数点后6位有效数字 25 | print(item) 26 | self.update_csv(item) 27 | self.update_tensorboard(item) 28 | 29 | def update_csv(self,item): 30 | tmp = pd.DataFrame(item,index=[0]) 31 | if self.log is not None: 32 | self.log = self.log.append(tmp, ignore_index=True) 33 | else: 34 | self.log = tmp 35 | self.log.to_csv('%s/log%s.csv' %(self.name,self.time_now), index=False) 36 | 37 | def update_tensorboard(self,item): 38 | if self.summary is None: 39 | self.summary = SummaryWriter('%s/' % self.name) 40 | epoch = item['epoch'] 41 | for key,value in item.items(): 42 | if key != 'epoch': self.summary.add_scalar(key, value, epoch) 43 | def save_graph(self,model,input): 44 | if self.summary is None: 45 | self.summary = SummaryWriter('%s/' % self.name) 46 | self.summary.add_graph(model, (input,)) 47 | print("Architecture of Model have saved in Tensorboard!") 48 | 49 | # Record the information printed in the terminal 50 | class Print_Logger(object): 51 | def __init__(self, filename="Default.log"): 52 | self.terminal = sys.stdout 53 | self.log = open(filename, "a") 54 | 55 | def write(self, message): 56 | self.terminal.write(message) 57 | self.log.write(message) 58 | 59 | def flush(self): 60 | pass 61 | # call by 62 | # sys.stdout = Logger(os.path.join(save_path,'test_log.txt')) 63 | 64 | -------------------------------------------------------------------------------- /lib/losses/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lee-zq/VesselSeg-Pytorch/b4f6571fc1fb1fbdaad60ff9282a54a1f1c455fa/lib/losses/__init__.py -------------------------------------------------------------------------------- /lib/losses/loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | This part is the available loss function 3 | """ 4 | 5 | import torch 6 | import numpy as np 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | class LossMulti: 11 | def __init__(self, jaccard_weight=0, class_weights=None, num_classes=1): 12 | if class_weights is not None: 13 | nll_weight = torch.from_numpy(class_weights.astype(np.float32)).cuda() 14 | else: 15 | nll_weight = None 16 | self.nll_loss = nn.NLLLoss(weight=nll_weight) # Not include softmax 17 | self.jaccard_weight = jaccard_weight 18 | self.num_classes = num_classes 19 | 20 | def __call__(self, outputs, targets): 21 | loss = (1 - self.jaccard_weight) * self.nll_loss(outputs, targets) 22 | 23 | if self.jaccard_weight: 24 | eps = 1e-15 25 | for cls in range(self.num_classes): 26 | jaccard_target = (targets == cls).float() 27 | jaccard_output = outputs[:, cls].exp() 28 | intersection = (jaccard_output * jaccard_target).sum() 29 | 30 | union = jaccard_output.sum() + jaccard_target.sum() 31 | loss -= torch.log((intersection + eps) / (union - intersection + eps)) * self.jaccard_weight 32 | return loss 33 | 34 | class CrossEntropyLoss2d(nn.Module): 35 | def __init__(self, weight=None, size_average=True, ignore_index=255): 36 | super(CrossEntropyLoss2d, self).__init__() 37 | self.nll_loss = nn.NLLLoss(weight, size_average, ignore_index) 38 | 39 | def forward(self, inputs, targets): 40 | return self.nll_loss(torch.log(inputs), targets) 41 | 42 | class FocalLoss2d(nn.Module): 43 | def __init__(self, gamma=2, weight=None, size_average=True, ignore_index=255): 44 | super(FocalLoss2d, self).__init__() 45 | self.gamma = gamma 46 | self.nll_loss = nn.NLLLoss(weight, size_average, ignore_index) 47 | 48 | def forward(self, inputs, targets): # 
包含了log_softmax函数,调用时网络输出层不需要加log_softmax 49 | return self.nll_loss((1 - F.softmax(inputs,1)) ** self.gamma * F.log_softmax(inputs,1), targets) -------------------------------------------------------------------------------- /lib/losses/loss_lab.py: -------------------------------------------------------------------------------- 1 | """ 2 | This part is mainly the important loss function in semantic segmentation. 3 | These functions are still in beta. 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch.autograd import Variable 10 | from torch import mean 11 | 12 | # one-hot encode of target(label) 13 | def to_one_hot(tensor, n_classes): 14 | n, h, w = tensor.size() 15 | one_hot = torch.zeros(n, n_classes, h, w).scatter_(1, tensor.view(n, 1, h, w), 1) 16 | return one_hot 17 | 18 | # ---------------------------Dice Loss------------------------- 19 | class DiceLoss(nn.Module): 20 | def __init__(self): 21 | super(DiceLoss, self).__init__() 22 | 23 | def forward(self, input, target): 24 | N = target.size(0) 25 | smooth = 1 26 | 27 | input_flat = input.view(N, -1) 28 | target_flat = target.view(N, -1) 29 | 30 | intersection = input_flat * target_flat 31 | 32 | loss = 2 * (intersection.sum(1) + smooth) / (input_flat.sum(1) + target_flat.sum(1) + smooth) 33 | loss = 1 - loss.sum() / N 34 | 35 | return loss 36 | 37 | class MulticlassDiceLoss(nn.Module): 38 | """ 39 | requires one hot encoded target. Applies DiceLoss on each class iteratively. 40 | requires input.shape[0:1] and target.shape[0:1] to be (N, C) where N is 41 | batch size and C is number of classes 42 | """ 43 | 44 | def __init__(self): 45 | super(MulticlassDiceLoss, self).__init__() 46 | 47 | def forward(self, input, target, weights=None): 48 | 49 | C = target.shape[1] 50 | 51 | if weights is None: 52 | weights = torch.ones(C) #uniform weights for all classes 53 | 54 | dice = DiceLoss() 55 | totalLoss = 0 56 | 57 | for i in range(C): 58 | diceLoss = dice(input[:, i], target[:, i]) 59 | if weights is not None: 60 | diceLoss *= weights[i] 61 | totalLoss += diceLoss 62 | 63 | return totalLoss 64 | 65 | # ---------------------------IoU Loss------------------------- 66 | class SoftIoULoss(nn.Module): 67 | def __init__(self, n_classes): 68 | super(SoftIoULoss, self).__init__() 69 | self.n_classes = n_classes 70 | 71 | @staticmethod 72 | def to_one_hot(tensor, n_classes): 73 | n, h, w = tensor.size() 74 | one_hot = torch.zeros(n, n_classes, h, w).scatter_(1, tensor.view(n, 1, h, w), 1) 75 | return one_hot 76 | 77 | def forward(self, input, target): 78 | # logit => N x Classes x H x W 79 | # target => N x H x W 80 | 81 | N = len(input) 82 | 83 | pred = F.softmax(input, dim=1) 84 | target_onehot = self.to_one_hot(target, self.n_classes) 85 | 86 | # Numerator Product 87 | inter = pred * target_onehot 88 | # Sum over all pixels N x C x H x W => N x C 89 | inter = inter.view(N, self.n_classes, -1).sum(2) 90 | 91 | # Denominator 92 | union = pred + target_onehot - (pred * target_onehot) 93 | # Sum over all pixels N x C x H x W => N x C 94 | union = union.view(N, self.n_classes, -1).sum(2) 95 | 96 | loss = inter / (union + 1e-16) 97 | 98 | # Return average loss over classes and batch 99 | return -loss.mean() 100 | 101 | # ---------------------------lovasz Loss and it's func------------------ 102 | def lovasz_grad(gt_sorted): 103 | """ 104 | Computes gradient of the Lovasz extension w.r.t sorted errors 105 | See Alg. 
1 in paper 106 | """ 107 | p = len(gt_sorted) 108 | gts = gt_sorted.sum() 109 | intersection = gts - gt_sorted.float().cumsum(0) 110 | union = gts + (1 - gt_sorted).float().cumsum(0) 111 | jaccard = 1. - intersection / union 112 | if p > 1: # cover 1-pixel case 113 | jaccard[1:p] = jaccard[1:p] - jaccard[0:-1] 114 | return jaccard 115 | 116 | # --------------------------- BINARY lovasz LOSSES --------------------------- 117 | def lovasz_hinge(logits, labels, per_image=True, ignore=None): 118 | """ 119 | Binary Lovasz hinge loss 120 | logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty) 121 | labels: [B, H, W] Tensor, binary ground truth masks (0 or 1) 122 | per_image: compute the loss per image instead of per batch 123 | ignore: void class id 124 | """ 125 | if per_image: 126 | loss = mean(lovasz_hinge_flat(*flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore)) 127 | for log, lab in zip(logits, labels)) 128 | else: 129 | loss = lovasz_hinge_flat(*flatten_binary_scores(logits, labels, ignore)) 130 | return loss 131 | 132 | def lovasz_hinge_flat(logits, labels): 133 | """ 134 | Binary Lovasz hinge loss 135 | logits: [P] Variable, logits at each prediction (between -\infty and +\infty) 136 | labels: [P] Tensor, binary ground truth labels (0 or 1) 137 | ignore: label to ignore 138 | """ 139 | if len(labels) == 0: 140 | # only void pixels, the gradients should be 0 141 | return logits.sum() * 0. 142 | signs = 2. * labels.float() - 1. 143 | errors = (1. - logits * Variable(signs)) 144 | errors_sorted, perm = torch.sort(errors, dim=0, descending=True) 145 | perm = perm.data 146 | gt_sorted = labels[perm] 147 | grad = lovasz_grad(gt_sorted) 148 | loss = torch.dot(F.relu(errors_sorted), Variable(grad)) 149 | return loss 150 | 151 | def flatten_binary_scores(scores, labels, ignore=None): 152 | """ 153 | Flattens predictions in the batch (binary case) 154 | Remove labels equal to 'ignore' 155 | """ 156 | scores = scores.view(-1) 157 | labels = labels.view(-1) 158 | if ignore is None: 159 | return scores, labels 160 | valid = (labels != ignore) 161 | vscores = scores[valid] 162 | vlabels = labels[valid] 163 | return vscores, vlabels 164 | 165 | # --------------------------- MULTICLASS lovasz LOSSES --------------------------- 166 | def lovasz_softmax(probas, labels, classes='present', per_image=False, ignore=None): 167 | """ 168 | Multi-class Lovasz-Softmax loss 169 | probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1). 170 | Interpreted as binary (sigmoid) output with outputs of size [B, H, W]. 171 | labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1) 172 | classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average. 
173 | per_image: compute the loss per image instead of per batch 174 | ignore: void class labels 175 | """ 176 | if per_image: 177 | loss = mean(lovasz_softmax_flat(*flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore), classes=classes) 178 | for prob, lab in zip(probas, labels)) 179 | else: 180 | loss = lovasz_softmax_flat(*flatten_probas(probas, labels, ignore), classes=classes) 181 | return loss 182 | 183 | def lovasz_softmax_flat(probas, labels, classes='present'): 184 | """ 185 | Multi-class Lovasz-Softmax loss 186 | probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1) 187 | labels: [P] Tensor, ground truth labels (between 0 and C - 1) 188 | classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average. 189 | """ 190 | if probas.numel() == 0: 191 | # only void pixels, the gradients should be 0 192 | return probas * 0. 193 | C = probas.size(1) 194 | losses = [] 195 | class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes 196 | for c in class_to_sum: 197 | fg = (labels == c).float() # foreground for class c 198 | if (classes is 'present' and fg.sum() == 0): 199 | continue 200 | if C == 1: 201 | if len(classes) > 1: 202 | raise ValueError('Sigmoid output possible only with 1 class') 203 | class_pred = probas[:, 0] 204 | else: 205 | class_pred = probas[:, c] 206 | errors = (Variable(fg) - class_pred).abs() 207 | errors_sorted, perm = torch.sort(errors, 0, descending=True) 208 | perm = perm.data 209 | fg_sorted = fg[perm] 210 | losses.append(torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted)))) 211 | return mean(torch.Tensor(losses)) 212 | 213 | def flatten_probas(probas, labels, ignore=None): 214 | """ 215 | Flattens predictions in the batch 216 | """ 217 | if probas.dim() == 3: 218 | # assumes output of a sigmoid layer 219 | B, H, W = probas.size() 220 | probas = probas.view(B, 1, H, W) 221 | B, C, H, W = probas.size() 222 | probas = probas.permute(0, 2, 3, 1).contiguous().view(-1, C) # B * H * W, C = P, C 223 | labels = labels.view(-1) 224 | if ignore is None: 225 | return probas, labels 226 | valid = (labels != ignore) 227 | vprobas = probas[valid.nonzero().squeeze()] 228 | vlabels = labels[valid] 229 | return vprobas, vlabels 230 | 231 | # --------------------------- Focal LOSS --------------------------- 232 | class FocalLoss(nn.Module): 233 | def __init__(self, alpha=0.25, gamma=2, weight=None, ignore_index=255): 234 | super(FocalLoss, self).__init__() 235 | self.alpha = alpha 236 | self.gamma = gamma 237 | self.weight = weight 238 | self.ignore_index = ignore_index 239 | self.bce_fn = nn.BCEWithLogitsLoss(weight=self.weight) 240 | 241 | def forward(self, preds, labels): 242 | if self.ignore_index is not None: 243 | mask = labels != self.ignore 244 | labels = labels[mask] 245 | preds = preds[mask] 246 | 247 | logpt = -self.bce_fn(preds, labels) 248 | pt = torch.exp(logpt) 249 | loss = -((1 - pt) ** self.gamma) * self.alpha * logpt 250 | return loss 251 | 252 | class MUlticlassFocalLoss(nn.Module): 253 | def __init__(self, alpha=0.5, gamma=2, weight=None, ignore_index=255): 254 | super().__init__() 255 | self.alpha = alpha 256 | self.gamma = gamma 257 | self.weight = weight 258 | self.ignore_index = ignore_index 259 | self.ce_fn = nn.CrossEntropyLoss(weight=self.weight, ignore_index=self.ignore_index) 260 | 261 | def forward(self, preds, labels): 262 | logpt = -self.ce_fn(preds, labels) 263 | pt = torch.exp(logpt) 264 | loss = -((1 - pt) ** self.gamma) * self.alpha * logpt 265 | return 
loss 266 | 267 | 268 | def OHEM(output, target, alpha, gamma, OHEM_percent): # online hard example mining 269 | output = output.contiguous().view(-1) 270 | target = target.contiguous().view(-1) 271 | 272 | max_val = (-output).clamp(min=0) 273 | loss = output - output * target + max_val + ((-max_val).exp() + (-output - max_val).exp()).log() 274 | 275 | # This formula gives us the log sigmoid of 1-p if y is 0 and of p if y is 1 276 | invprobs = F.logsigmoid(-output * (target * 2 - 1)) 277 | focal_loss = alpha * (invprobs * gamma).exp() * loss 278 | 279 | # Online Hard Example Mining: top x% losses (pixel-wise). Refer to http://www.robots.ox.ac.uk/~tvg/publications/2017/0026.pdf 280 | OHEM, _ = focal_loss.topk(k=int(OHEM_percent * [*focal_loss.shape][0])) 281 | return OHEM.mean() 282 | 283 | if __name__ == '__main__': 284 | out1 = torch.rand(10,3,64,64) 285 | target1 = torch.randint(0,3,(10,64,64)) 286 | # print(target1) 287 | one_hot_target = to_one_hot(target1,3) 288 | print(one_hot_target.size()) 289 | # loss = DiceLoss() 290 | # loss = MulticlassDiceLoss() 291 | # loss = lovasz_softmax(out1,target1) 292 | # loss_value = loss(out1,target1) 293 | # loss = MUlticlassFocalLoss() 294 | # loss = nn.CrossEntropyLoss() 295 | loss = SoftIoULoss(3) 296 | loss_val = loss(out1,target1) 297 | print(loss_val) -------------------------------------------------------------------------------- /lib/metrics.py: -------------------------------------------------------------------------------- 1 | """ 2 | This part contains functions related to the calculation of performance indicators 3 | """ 4 | from sklearn.metrics import roc_curve 5 | from sklearn.metrics import roc_auc_score 6 | from sklearn.metrics import confusion_matrix 7 | from sklearn.metrics import precision_recall_curve 8 | from sklearn.metrics import f1_score 9 | import os 10 | import torch 11 | from os.path import join 12 | import numpy as np 13 | from collections import OrderedDict 14 | import matplotlib.pylab as pylab 15 | import matplotlib.pyplot as plt 16 | params = {'legend.fontsize': 13, 17 | 'axes.labelsize': 15, 18 | 'axes.titlesize':15, 19 | 'xtick.labelsize':15, 20 | 'ytick.labelsize':15} # define pyplot parameters 21 | pylab.rcParams.update(params) 22 | #Area under the ROC curve 23 | 24 | class Evaluate(): 25 | def __init__(self,save_path=None): 26 | self.target = None 27 | self.output = None 28 | self.save_path = save_path 29 | if self.save_path is not None: 30 | if not os.path.exists(self.save_path): 31 | os.makedirs(self.save_path) 32 | self.threshold_confusion = 0.5 33 | 34 | # Add data pair (target and predicted value) 35 | def add_batch(self,batch_tar,batch_out): 36 | batch_tar = batch_tar.flatten() 37 | batch_out = batch_out.flatten() 38 | 39 | self.target = batch_tar if self.target is None else np.concatenate((self.target,batch_tar)) 40 | self.output = batch_out if self.output is None else np.concatenate((self.output,batch_out)) 41 | 42 | # Plot ROC and calculate AUC of ROC 43 | def auc_roc(self,plot=False): 44 | AUC_ROC = roc_auc_score(self.target, self.output) 45 | # print("\nAUC of ROC curve: " + str(AUC_ROC)) 46 | if plot and self.save_path is not None: 47 | fpr, tpr, thresholds = roc_curve(self.target, self.output) 48 | # print("\nArea under the ROC curve: " + str(AUC_ROC)) 49 | plt.figure() 50 | plt.plot(fpr, tpr, '-', label='Area Under the Curve (AUC = %0.4f)' % AUC_ROC) 51 | plt.title('ROC curve') 52 | plt.xlabel("FPR (False Positive Rate)") 53 | plt.ylabel("TPR (True Positive Rate)") 54 | plt.legend(loc="lower right") 
55 | plt.savefig(join(self.save_path , "ROC.png")) 56 | return AUC_ROC 57 | 58 | # Plot PR curve and calculate AUC of PR curve 59 | def auc_pr(self,plot=False): 60 | precision, recall, thresholds = precision_recall_curve(self.target, self.output) 61 | precision = np.fliplr([precision])[0] 62 | recall = np.fliplr([recall])[0] 63 | AUC_pr = np.trapz(precision, recall) 64 | # print("\nAUC of P-R curve: " + str(AUC_pr)) 65 | if plot and self.save_path is not None: 66 | 67 | plt.figure() 68 | plt.plot(recall, precision, '-', label='Area Under the Curve (AUC = %0.4f)' % AUC_pr) 69 | plt.title('Precision - Recall curve') 70 | plt.xlabel("Recall") 71 | plt.ylabel("Precision") 72 | plt.legend(loc="lower right") 73 | plt.savefig(join(self.save_path ,"Precision_recall.png")) 74 | return AUC_pr 75 | 76 | # Accuracy, specificity, sensitivity, precision can be obtained by calculating the confusion matrix 77 | def confusion_matrix(self): 78 | #Confusion matrix 79 | y_pred = self.output>=self.threshold_confusion 80 | confusion = confusion_matrix(self.target, y_pred) 81 | # print(confusion) 82 | accuracy = 0 83 | if float(np.sum(confusion))!=0: 84 | accuracy = float(confusion[0,0]+confusion[1,1])/float(np.sum(confusion)) 85 | # print("Global Accuracy: " +str(accuracy)) 86 | specificity = 0 87 | if float(confusion[0,0]+confusion[0,1])!=0: 88 | specificity = float(confusion[0,0])/float(confusion[0,0]+confusion[0,1]) 89 | # print("Specificity: " +str(specificity)) 90 | sensitivity = 0 91 | if float(confusion[1,1]+confusion[1,0])!=0: 92 | sensitivity = float(confusion[1,1])/float(confusion[1,1]+confusion[1,0]) 93 | # print("Sensitivity: " +str(sensitivity)) 94 | precision = 0 95 | if float(confusion[1,1]+confusion[0,1])!=0: 96 | precision = float(confusion[1,1])/float(confusion[1,1]+confusion[0,1]) 97 | # print("Precision: " +str(precision)) 98 | return confusion,accuracy,specificity,sensitivity,precision 99 | 100 | # Jaccard similarity index 101 | def jaccard_index(self): 102 | pass 103 | # jaccard_index = jaccard_similarity_score(y_true, y_pred, normalize=True) 104 | # print("\nJaccard similarity score: " +str(jaccard_index)) 105 | 106 | # calculating f1_score 107 | def f1_score(self): 108 | pred = self.output>=self.threshold_confusion 109 | F1_score = f1_score(self.target, pred, labels=None, average='binary', sample_weight=None) 110 | # print("F1 score (F-measure): " +str(F1_score)) 111 | return F1_score 112 | 113 | # Save performance results to specified file 114 | def save_all_result(self,plot_curve=True,save_name=None): 115 | #Save the results 116 | AUC_ROC = self.auc_roc(plot=plot_curve) 117 | AUC_pr = self.auc_pr(plot=plot_curve) 118 | F1_score = self.f1_score() 119 | confusion,accuracy, specificity, sensitivity, precision = self.confusion_matrix() 120 | if save_name is not None: 121 | file_perf = open(join(self.save_path, save_name), 'w') 122 | file_perf.write("AUC ROC curve: "+str(AUC_ROC) 123 | + "\nAUC PR curve: " +str(AUC_pr) 124 | # + "\nJaccard similarity score: " +str(jaccard_index) 125 | + "\nF1 score: " +str(F1_score) 126 | +"\nAccuracy: " +str(accuracy) 127 | +"\nSensitivity(SE): " +str(sensitivity) 128 | +"\nSpecificity(SP): " +str(specificity) 129 | +"\nPrecision: " +str(precision) 130 | + "\n\nConfusion matrix:" 131 | + str(confusion) 132 | ) 133 | file_perf.close() 134 | return OrderedDict([("AUC_ROC",AUC_ROC),("AUC_PR",AUC_pr), 135 | ("f1-score",F1_score),("Acc",accuracy), 136 | ("SE",sensitivity),("SP",specificity), 137 | ("precision",precision) 138 | ]) 
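# A minimal usage sketch of Evaluate (illustrative only; the output directory,
# random arrays and file name below are placeholders, not project defaults):
if __name__ == '__main__':
    ev = Evaluate(save_path='./tmp_eval')               # ROC.png / Precision_recall.png land here
    target = np.random.randint(0, 2, (4, 1, 64, 64))    # binary ground truth
    output = np.random.rand(4, 1, 64, 64)                # predicted vessel probabilities in [0, 1]
    ev.add_batch(target, output)                         # flattened and accumulated internally
    metrics = ev.save_all_result(plot_curve=True, save_name='performance.txt')
    print(metrics)   # OrderedDict: AUC_ROC, AUC_PR, f1-score, Acc, SE, SP, precision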
-------------------------------------------------------------------------------- /lib/pre_processing.py: -------------------------------------------------------------------------------- 1 | ################################################### 2 | # 3 | # Script to pre-process the original imgs 4 | # 5 | ################################################## 6 | 7 | import numpy as np 8 | import cv2 9 | 10 | #My pre processing (use for both training and testing!) 11 | def my_PreProc(data): 12 | assert(len(data.shape)==4) 13 | assert (data.shape[1]==3) #Use the original images 14 | #black-white conversion 15 | train_imgs = rgb2gray(data) 16 | #my preprocessing: 17 | train_imgs = dataset_normalized(train_imgs) 18 | train_imgs = clahe_equalized(train_imgs) 19 | train_imgs = adjust_gamma(train_imgs, 1.2) 20 | train_imgs = train_imgs/255. #reduce to 0-1 range 21 | return train_imgs 22 | 23 | #============================================================ 24 | #========= PRE PROCESSING FUNCTIONS ========================# 25 | #============================================================ 26 | 27 | #convert RGB image in black and white 28 | def rgb2gray(rgb): 29 | assert (len(rgb.shape)==4) #4D arrays 30 | assert (rgb.shape[1]==3) 31 | bn_imgs = rgb[:,0,:,:]*0.299 + rgb[:,1,:,:]*0.587 + rgb[:,2,:,:]*0.114 32 | bn_imgs = np.reshape(bn_imgs,(rgb.shape[0],1,rgb.shape[2],rgb.shape[3])) 33 | return bn_imgs 34 | 35 | #==== histogram equalization 36 | def histo_equalized(imgs): 37 | assert (len(imgs.shape)==4) #4D arrays 38 | assert (imgs.shape[1]==1) #check the channel is 1 39 | imgs_equalized = np.empty(imgs.shape) 40 | for i in range(imgs.shape[0]): 41 | imgs_equalized[i,0] = cv2.equalizeHist(np.array(imgs[i,0], dtype = np.uint8)) 42 | return imgs_equalized 43 | 44 | 45 | # CLAHE (Contrast Limited Adaptive Histogram Equalization) 46 | #adaptive histogram equalization is used. In this, image is divided into small blocks called "tiles" (tileSize is 8x8 by default in OpenCV). Then each of these blocks are histogram equalized as usual. So in a small area, histogram would confine to a small region (unless there is noise). If noise is there, it will be amplified. To avoid this, contrast limiting is applied. If any histogram bin is above the specified contrast limit (by default 40 in OpenCV), those pixels are clipped and distributed uniformly to other bins before applying histogram equalization. After equalization, to remove artifacts in tile borders, bilinear interpolation is applied 47 | def clahe_equalized(imgs): 48 | assert (len(imgs.shape)==4) #4D arrays 49 | assert (imgs.shape[1]==1) #check the channel is 1 50 | #create a CLAHE object (Arguments are optional). 
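    # clipLimit=2.0 clips histogram bins above the threshold to limit noise amplification;
    # tileGridSize=(8,8) equalizes the image over an 8x8 grid of tiles, which OpenCV then
    # blends with bilinear interpolation to hide tile borders.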
51 | clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8)) 52 | imgs_equalized = np.empty(imgs.shape) 53 | for i in range(imgs.shape[0]): 54 | imgs_equalized[i,0] = clahe.apply(np.array(imgs[i,0], dtype = np.uint8)) 55 | return imgs_equalized 56 | 57 | 58 | # ===== normalize over the dataset 59 | def dataset_normalized(imgs): 60 | assert (len(imgs.shape)==4) #4D arrays 61 | assert (imgs.shape[1]==1) #check the channel is 1 62 | imgs_normalized = np.empty(imgs.shape) 63 | imgs_std = np.std(imgs) 64 | imgs_mean = np.mean(imgs) 65 | imgs_normalized = (imgs-imgs_mean)/imgs_std 66 | for i in range(imgs.shape[0]): 67 | imgs_normalized[i] = ((imgs_normalized[i] - np.min(imgs_normalized[i])) / (np.max(imgs_normalized[i])-np.min(imgs_normalized[i])))*255 68 | return imgs_normalized 69 | 70 | 71 | def adjust_gamma(imgs, gamma=1.0): 72 | assert (len(imgs.shape)==4) #4D arrays 73 | assert (imgs.shape[1]==1) #check the channel is 1 74 | # build a lookup table mapping the pixel values [0, 255] to 75 | # their adjusted gamma values 76 | invGamma = 1.0 / gamma 77 | table = np.array([((i / 255.0) ** invGamma) * 255 for i in np.arange(0, 256)]).astype("uint8") 78 | # apply gamma correction using the lookup table 79 | new_imgs = np.empty(imgs.shape) 80 | for i in range(imgs.shape[0]): 81 | new_imgs[i,0] = cv2.LUT(np.array(imgs[i,0], dtype = np.uint8), table) 82 | return new_imgs 83 | 84 | -------------------------------------------------------------------------------- /lib/visualize.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import numpy as np 3 | from PIL import Image 4 | from matplotlib import pyplot as plt 5 | from copy import deepcopy 6 | 7 | #group a set of img patches 8 | def group_images(data,per_row): 9 | assert data.shape[0]%per_row==0 10 | assert (data.shape[1]==1 or data.shape[1]==3) 11 | data = np.transpose(data,(0,2,3,1)) 12 | all_stripe = [] 13 | for i in range(int(data.shape[0]/per_row)): 14 | stripe = data[i*per_row] 15 | for k in range(i*per_row+1, i*per_row+per_row): 16 | stripe = np.concatenate((stripe,data[k]),axis=1) 17 | all_stripe.append(stripe) 18 | totimg = all_stripe[0] 19 | for i in range(1,len(all_stripe)): 20 | totimg = np.concatenate((totimg,all_stripe[i]),axis=0) 21 | return totimg 22 | 23 | # Prediction result splicing (original img, predicted probability, binary img, groundtruth) 24 | def concat_result(ori_img,pred_res,gt): 25 | ori_img = data = np.transpose(ori_img,(1,2,0)) 26 | pred_res = data = np.transpose(pred_res,(1,2,0)) 27 | gt = data = np.transpose(gt,(1,2,0)) 28 | 29 | binary = deepcopy(pred_res) 30 | binary[binary>=0.5]=1 31 | binary[binary<0.5]=0 32 | 33 | if ori_img.shape[2]==3: 34 | pred_res = np.repeat((pred_res*255).astype(np.uint8),repeats=3,axis=2) 35 | binary = np.repeat((binary*255).astype(np.uint8),repeats=3,axis=2) 36 | gt = np.repeat((gt*255).astype(np.uint8),repeats=3,axis=2) 37 | total_img = np.concatenate((ori_img,pred_res,binary,gt),axis=1) 38 | return total_img 39 | 40 | #visualize image, save as PIL image 41 | def save_img(data,filename): 42 | assert (len(data.shape)==3) #height*width*channels 43 | if data.shape[2]==1: #in case it is black and white 44 | data = np.reshape(data,(data.shape[0],data.shape[1])) 45 | img = Image.fromarray(data.astype(np.uint8)) #the image is between 0-1 46 | img.save(filename) 47 | return img -------------------------------------------------------------------------------- /models/DenseUnet.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class Single_level_densenet(nn.Module): 7 | def __init__(self, filters, num_conv=4): 8 | super(Single_level_densenet, self).__init__() 9 | self.num_conv = num_conv 10 | self.conv_list = nn.ModuleList() 11 | self.bn_list = nn.ModuleList() 12 | for i in range(self.num_conv): 13 | self.conv_list.append(nn.Conv2d(filters, filters, 3, padding=1)) 14 | self.bn_list.append(nn.BatchNorm2d(filters)) 15 | 16 | def forward(self, x): 17 | outs = [] 18 | outs.append(x) 19 | for i in range(self.num_conv): 20 | temp_out = self.conv_list[i](outs[i]) 21 | if i > 0: 22 | for j in range(i): 23 | temp_out += outs[j] 24 | outs.append(F.relu(self.bn_list[i](temp_out))) 25 | out_final = outs[-1] 26 | del outs 27 | return out_final 28 | 29 | 30 | class Down_sample(nn.Module): 31 | def __init__(self, kernel_size=2, stride=2): 32 | super(Down_sample, self).__init__() 33 | self.down_sample_layer = nn.MaxPool2d(kernel_size, stride) 34 | 35 | def forward(self, x): 36 | y = self.down_sample_layer(x) 37 | return y, x 38 | 39 | 40 | class Upsample_n_Concat(nn.Module): 41 | def __init__(self, filters): 42 | super(Upsample_n_Concat, self).__init__() 43 | self.upsample_layer = nn.ConvTranspose2d(filters, filters, 4, padding=1, stride=2) 44 | self.conv = nn.Conv2d(2 * filters, filters, 3, padding=1) 45 | self.bn = nn.BatchNorm2d(filters) 46 | 47 | def forward(self, x, y): 48 | x = self.upsample_layer(x) 49 | x = torch.cat([x, y], dim=1) 50 | x = F.relu(self.bn(self.conv(x))) 51 | return x 52 | 53 | 54 | class Dense_Unet(nn.Module): 55 | def __init__(self, in_chan=1,out_chan=2,filters=128, num_conv=4): 56 | 57 | super(Dense_Unet, self).__init__() 58 | self.conv1 = nn.Conv2d(in_chan, filters, 1) 59 | self.d1 = Single_level_densenet(filters, num_conv) 60 | self.down1 = Down_sample() 61 | self.d2 = Single_level_densenet(filters, num_conv) 62 | self.down2 = Down_sample() 63 | self.d3 = Single_level_densenet(filters, num_conv) 64 | self.down3 = Down_sample() 65 | self.d4 = Single_level_densenet(filters, num_conv) 66 | self.down4 = Down_sample() 67 | self.bottom = Single_level_densenet(filters, num_conv) 68 | self.up4 = Upsample_n_Concat(filters) 69 | self.u4 = Single_level_densenet(filters, num_conv) 70 | self.up3 = Upsample_n_Concat(filters) 71 | self.u3 = Single_level_densenet(filters, num_conv) 72 | self.up2 = Upsample_n_Concat(filters) 73 | self.u2 = Single_level_densenet(filters, num_conv) 74 | self.up1 = Upsample_n_Concat(filters) 75 | self.u1 = Single_level_densenet(filters, num_conv) 76 | self.outconv = nn.Conv2d(filters, out_chan, 1) 77 | 78 | # self.outconvp1 = nn.Conv2d(filters,out_chan, 1) 79 | # self.outconvm1 = nn.Conv2d(filters,out_chan, 1) 80 | 81 | def forward(self, x): 82 | x = self.conv1(x) 83 | x, y1 = self.down1(self.d1(x)) 84 | x, y2 = self.down1(self.d2(x)) 85 | x, y3 = self.down1(self.d3(x)) 86 | x, y4 = self.down1(self.d4(x)) 87 | x = self.bottom(x) 88 | x = self.u4(self.up4(x, y4)) 89 | x = self.u3(self.up3(x, y3)) 90 | x = self.u2(self.up2(x, y2)) 91 | x = self.u1(self.up1(x, y1)) 92 | x1 = self.outconv(x) 93 | # xm1 = self.outconvm1(x) 94 | # xp1 = self.outconvp1(x) 95 | x1 = F.softmax(x1,dim=1) 96 | return x1 97 | 98 | if __name__ == '__main__': 99 | net = Dense_Unet(3,21,128).cuda() 100 | print(net) 101 | in1 = torch.randn(4,3,224,224).cuda() 102 | out = net(in1) 103 | print(out.size()) 
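    # Additional CPU-only check in the single-channel, two-class configuration used
    # for vessel patches (an illustrative sketch; filters reduced to keep it light):
    net_cpu = Dense_Unet(in_chan=1, out_chan=2, filters=64)
    patch = torch.randn(2, 1, 48, 48)
    prob = net_cpu(patch)
    print(prob.size())              # torch.Size([2, 2, 48, 48])
    print(prob.sum(dim=1).mean())   # ~1.0, since the output passes through softmax over classes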
-------------------------------------------------------------------------------- /models/LadderNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | 5 | drop = 0.25 6 | 7 | def conv3x3(in_planes, out_planes, stride=1): 8 | """3x3 convolution with padding""" 9 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 10 | padding=1, bias=True) 11 | 12 | 13 | class BasicBlock(nn.Module): 14 | expansion = 1 15 | 16 | def __init__(self, inplanes, planes, stride=1, downsample=None): 17 | super(BasicBlock, self).__init__() 18 | if inplanes!= planes: 19 | self.conv0 = conv3x3(inplanes,planes) 20 | 21 | self.inplanes = inplanes 22 | self.planes = planes 23 | 24 | self.conv1 = conv3x3(planes, planes, stride) 25 | #self.bn1 = nn.BatchNorm2d(planes) 26 | self.relu = nn.ReLU(inplace=True) 27 | #self.conv2 = conv3x3(planes, planes) 28 | #self.bn2 = nn.BatchNorm2d(planes) 29 | self.downsample = downsample 30 | self.stride = stride 31 | self.drop = nn.Dropout2d(p=drop) 32 | 33 | def forward(self, x): 34 | if self.inplanes != self.planes: 35 | x = self.conv0(x) 36 | x = F.relu(x) 37 | 38 | out = self.conv1(x) 39 | #out = self.bn1(out) 40 | out = self.relu(out) 41 | 42 | out = self.drop(out) 43 | 44 | out1 = self.conv1(out) 45 | #out1 = self.relu(out1) 46 | 47 | out2 = out1 + x 48 | 49 | return F.relu(out2) 50 | 51 | 52 | class Bottleneck(nn.Module): 53 | expansion = 4 54 | 55 | def __init__(self, inplanes, planes, stride=1, downsample=None): 56 | super(Bottleneck, self).__init__() 57 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 58 | self.bn1 = nn.BatchNorm2d(planes) 59 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 60 | padding=1, bias=False) 61 | self.bn2 = nn.BatchNorm2d(planes) 62 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 63 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 64 | self.relu = nn.ReLU(inplace=True) 65 | self.downsample = downsample 66 | self.stride = stride 67 | 68 | def forward(self, x): 69 | residual = x 70 | 71 | out = self.conv1(x) 72 | out = self.bn1(out) 73 | out = self.relu(out) 74 | 75 | out = self.conv2(out) 76 | out = self.bn2(out) 77 | out = self.relu(out) 78 | 79 | out = self.conv3(out) 80 | out = self.bn3(out) 81 | 82 | if self.downsample is not None: 83 | residual = self.downsample(x) 84 | 85 | out += residual 86 | out = self.relu(out) 87 | 88 | return out 89 | 90 | class Initial_LadderBlock(nn.Module): 91 | 92 | def __init__(self,planes,layers,kernel=3,block=BasicBlock,inplanes = 3): 93 | super().__init__() 94 | self.planes = planes 95 | self.layers = layers 96 | self.kernel = kernel 97 | 98 | self.padding = int((kernel-1)/2) 99 | self.inconv = nn.Conv2d(in_channels=inplanes,out_channels=planes, 100 | kernel_size=3,stride=1,padding=1,bias=True) 101 | 102 | # create module list for down branch 103 | self.down_module_list = nn.ModuleList() 104 | for i in range(0,layers): 105 | self.down_module_list.append(block(planes*(2**i),planes*(2**i))) 106 | 107 | # use strided conv instead of pooling 108 | self.down_conv_list = nn.ModuleList() 109 | for i in range(0,layers): 110 | self.down_conv_list.append(nn.Conv2d(planes*2**i,planes*2**(i+1),stride=2,kernel_size=kernel,padding=self.padding)) 111 | 112 | # create module for bottom block 113 | self.bottom = block(planes*(2**layers),planes*(2**layers)) 114 | 115 | # create module list for up branch 116 | 
self.up_conv_list = nn.ModuleList() 117 | self.up_dense_list = nn.ModuleList() 118 | for i in range(0, layers): 119 | self.up_conv_list.append(nn.ConvTranspose2d(in_channels=planes*2**(layers-i), out_channels=planes*2**max(0,layers-i-1), kernel_size=3, 120 | stride=2,padding=1,output_padding=1,bias=True)) 121 | self.up_dense_list.append(block(planes*2**max(0,layers-i-1),planes*2**max(0,layers-i-1))) 122 | 123 | 124 | def forward(self, x): 125 | out = self.inconv(x) 126 | out = F.relu(out) 127 | 128 | down_out = [] 129 | # down branch 130 | for i in range(0,self.layers): 131 | out = self.down_module_list[i](out) 132 | down_out.append(out) 133 | out = self.down_conv_list[i](out) 134 | out = F.relu(out) 135 | 136 | # bottom branch 137 | out = self.bottom(out) 138 | bottom = out 139 | 140 | # up branch 141 | up_out = [] 142 | up_out.append(bottom) 143 | 144 | for j in range(0,self.layers): 145 | out = self.up_conv_list[j](out) + down_out[self.layers-j-1] 146 | #out = F.relu(out) 147 | out = self.up_dense_list[j](out) 148 | up_out.append(out) 149 | 150 | return up_out 151 | 152 | class LadderBlock(nn.Module): 153 | 154 | def __init__(self,planes,layers,kernel=3,block=BasicBlock,inplanes = 3): 155 | super().__init__() 156 | self.planes = planes 157 | self.layers = layers 158 | self.kernel = kernel 159 | 160 | self.padding = int((kernel-1)/2) 161 | self.inconv = block(planes,planes) 162 | 163 | # create module list for down branch 164 | self.down_module_list = nn.ModuleList() 165 | for i in range(0,layers): 166 | self.down_module_list.append(block(planes*(2**i),planes*(2**i))) 167 | 168 | # use strided conv instead of poooling 169 | self.down_conv_list = nn.ModuleList() 170 | for i in range(0,layers): 171 | self.down_conv_list.append(nn.Conv2d(planes*2**i,planes*2**(i+1),stride=2,kernel_size=kernel,padding=self.padding)) 172 | 173 | # create module for bottom block 174 | self.bottom = block(planes*(2**layers),planes*(2**layers)) 175 | 176 | # create module list for up branch 177 | self.up_conv_list = nn.ModuleList() 178 | self.up_dense_list = nn.ModuleList() 179 | for i in range(0, layers): 180 | self.up_conv_list.append(nn.ConvTranspose2d(planes*2**(layers-i), planes*2**max(0,layers-i-1), kernel_size=3, 181 | stride=2,padding=1,output_padding=1,bias=True)) 182 | self.up_dense_list.append(block(planes*2**max(0,layers-i-1),planes*2**max(0,layers-i-1))) 183 | 184 | 185 | def forward(self, x): 186 | out = self.inconv(x[-1]) 187 | 188 | down_out = [] 189 | # down branch 190 | for i in range(0,self.layers): 191 | out = out + x[-i-1] 192 | out = self.down_module_list[i](out) 193 | down_out.append(out) 194 | 195 | out = self.down_conv_list[i](out) 196 | out = F.relu(out) 197 | 198 | # bottom branch 199 | out = self.bottom(out) 200 | bottom = out 201 | 202 | # up branch 203 | up_out = [] 204 | up_out.append(bottom) 205 | 206 | for j in range(0,self.layers): 207 | out = self.up_conv_list[j](out) + down_out[self.layers-j-1] 208 | #out = F.relu(out) 209 | out = self.up_dense_list[j](out) 210 | up_out.append(out) 211 | 212 | return up_out 213 | 214 | class Final_LadderBlock(nn.Module): 215 | 216 | def __init__(self,planes,layers,kernel=3,block=BasicBlock,inplanes = 3): 217 | super().__init__() 218 | self.block = LadderBlock(planes,layers,kernel=kernel,block=block) 219 | 220 | def forward(self, x): 221 | out = self.block(x) 222 | return out[-1] 223 | 224 | class LadderNet(nn.Module): 225 | def __init__(self,inplanes=1,num_classes=2,layers=4,filters=10,): 226 | super().__init__() 227 | self.initial_block = 
Initial_LadderBlock(planes=filters,layers=layers,inplanes=inplanes) 228 | #self.middle_block = LadderBlock(planes=filters,layers=layers) 229 | self.final_block = Final_LadderBlock(planes=filters,layers=layers) 230 | self.final = nn.Conv2d(in_channels=filters,out_channels=num_classes,kernel_size=1) 231 | 232 | def forward(self,x): 233 | out = self.initial_block(x) 234 | #out = self.middle_block(out) 235 | out = self.final_block(out) 236 | out = self.final(out) 237 | #out = F.relu(out) 238 | out = F.softmax(out,dim=1) 239 | return out -------------------------------------------------------------------------------- /models/UNetFamily.py: -------------------------------------------------------------------------------- 1 | """ 2 | This part contains UNet series models, 3 | including UNet, R2UNet, Attention UNet, R2Attention UNet, DenseUNet 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch.nn import init 10 | 11 | # ==========================Core Module================================ 12 | class conv_block(nn.Module): 13 | def __init__(self, ch_in, ch_out): 14 | super(conv_block, self).__init__() 15 | self.conv = nn.Sequential( 16 | nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True), 17 | nn.BatchNorm2d(ch_out), 18 | nn.ReLU(inplace=True), 19 | nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True), 20 | nn.BatchNorm2d(ch_out), 21 | nn.ReLU(inplace=True) 22 | ) 23 | 24 | def forward(self, x): 25 | x = self.conv(x) 26 | return x 27 | 28 | 29 | class up_conv(nn.Module): 30 | def __init__(self, ch_in, ch_out): 31 | super(up_conv, self).__init__() 32 | self.up = nn.Sequential( 33 | nn.Upsample(scale_factor=2), 34 | nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True), 35 | nn.BatchNorm2d(ch_out), 36 | nn.ReLU(inplace=True) 37 | ) 38 | 39 | def forward(self, x): 40 | x = self.up(x) 41 | return x 42 | 43 | 44 | class Recurrent_block(nn.Module): 45 | def __init__(self, ch_out, t=2): 46 | super(Recurrent_block, self).__init__() 47 | self.t = t 48 | self.ch_out = ch_out 49 | self.conv = nn.Sequential( 50 | nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True), 51 | nn.BatchNorm2d(ch_out), 52 | nn.ReLU(inplace=True) 53 | ) 54 | 55 | def forward(self, x): 56 | for i in range(self.t): 57 | 58 | if i == 0: 59 | x1 = self.conv(x) 60 | 61 | x1 = self.conv(x + x1) 62 | return x1 63 | 64 | 65 | class RRCNN_block(nn.Module): 66 | def __init__(self, ch_in, ch_out, t=2): 67 | super(RRCNN_block, self).__init__() 68 | self.RCNN = nn.Sequential( 69 | Recurrent_block(ch_out, t=t), 70 | Recurrent_block(ch_out, t=t) 71 | ) 72 | self.Conv_1x1 = nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1, padding=0) 73 | 74 | def forward(self, x): 75 | x = self.Conv_1x1(x) 76 | x1 = self.RCNN(x) 77 | return x + x1 78 | 79 | 80 | class single_conv(nn.Module): 81 | def __init__(self, ch_in, ch_out): 82 | super(single_conv, self).__init__() 83 | self.conv = nn.Sequential( 84 | nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True), 85 | nn.BatchNorm2d(ch_out), 86 | nn.ReLU(inplace=True) 87 | ) 88 | 89 | def forward(self, x): 90 | x = self.conv(x) 91 | return x 92 | 93 | 94 | class Attention_block(nn.Module): # attention Gate 95 | def __init__(self, F_g, F_l, F_int): 96 | super(Attention_block, self).__init__() 97 | self.W_g = nn.Sequential( 98 | nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True), 99 | nn.BatchNorm2d(F_int) 100 | ) 101 | 102 | self.W_x = nn.Sequential( 103 
| nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True), 104 | nn.BatchNorm2d(F_int) 105 | ) 106 | 107 | self.psi = nn.Sequential( 108 | nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True), 109 | nn.BatchNorm2d(1), 110 | nn.Sigmoid() 111 | ) 112 | 113 | self.relu = nn.ReLU(inplace=True) 114 | 115 | def forward(self, g, x): 116 | g1 = self.W_g(g) 117 | x1 = self.W_x(x) 118 | psi = self.relu(g1 + x1) 119 | psi = self.psi(psi) 120 | 121 | return x * psi 122 | 123 | # ================================================================== 124 | class U_Net(nn.Module): 125 | def __init__(self, img_ch=3, output_ch=1): 126 | super(U_Net, self).__init__() 127 | 128 | self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2) 129 | 130 | self.Conv1 = conv_block(ch_in=img_ch, ch_out=64) 131 | self.Conv2 = conv_block(ch_in=64, ch_out=128) 132 | self.Conv3 = conv_block(ch_in=128, ch_out=256) 133 | self.Conv4 = conv_block(ch_in=256, ch_out=512) 134 | self.Conv5 = conv_block(ch_in=512, ch_out=1024) 135 | 136 | self.Up5 = up_conv(ch_in=1024, ch_out=512) 137 | self.Up_conv5 = conv_block(ch_in=1024, ch_out=512) 138 | 139 | self.Up4 = up_conv(ch_in=512, ch_out=256) 140 | self.Up_conv4 = conv_block(ch_in=512, ch_out=256) 141 | 142 | self.Up3 = up_conv(ch_in=256, ch_out=128) 143 | self.Up_conv3 = conv_block(ch_in=256, ch_out=128) 144 | 145 | self.Up2 = up_conv(ch_in=128, ch_out=64) 146 | self.Up_conv2 = conv_block(ch_in=128, ch_out=64) 147 | 148 | self.Conv_1x1 = nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0) 149 | 150 | def forward(self, x): 151 | # encoding path 152 | x1 = self.Conv1(x) 153 | 154 | x2 = self.Maxpool(x1) 155 | x2 = self.Conv2(x2) 156 | 157 | x3 = self.Maxpool(x2) 158 | x3 = self.Conv3(x3) 159 | 160 | x4 = self.Maxpool(x3) 161 | x4 = self.Conv4(x4) 162 | 163 | x5 = self.Maxpool(x4) 164 | x5 = self.Conv5(x5) 165 | 166 | # decoding + concat path 167 | d5 = self.Up5(x5) 168 | d5 = torch.cat((x4, d5), dim=1) 169 | 170 | d5 = self.Up_conv5(d5) 171 | 172 | d4 = self.Up4(d5) 173 | d4 = torch.cat((x3, d4), dim=1) 174 | d4 = self.Up_conv4(d4) 175 | 176 | d3 = self.Up3(d4) 177 | d3 = torch.cat((x2, d3), dim=1) 178 | d3 = self.Up_conv3(d3) 179 | 180 | d2 = self.Up2(d3) 181 | d2 = torch.cat((x1, d2), dim=1) 182 | d2 = self.Up_conv2(d2) 183 | 184 | d1 = self.Conv_1x1(d2) 185 | d1 = F.softmax(d1,dim=1) # mine 186 | 187 | return d1 188 | 189 | # ============================================================ 190 | class R2U_Net(nn.Module): 191 | def __init__(self, img_ch=3, output_ch=1, t=2): 192 | super(R2U_Net, self).__init__() 193 | 194 | self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2) 195 | self.Upsample = nn.Upsample(scale_factor=2) 196 | 197 | self.RRCNN1 = RRCNN_block(ch_in=img_ch, ch_out=64, t=t) 198 | 199 | self.RRCNN2 = RRCNN_block(ch_in=64, ch_out=128, t=t) 200 | 201 | self.RRCNN3 = RRCNN_block(ch_in=128, ch_out=256, t=t) 202 | 203 | self.RRCNN4 = RRCNN_block(ch_in=256, ch_out=512, t=t) 204 | 205 | self.RRCNN5 = RRCNN_block(ch_in=512, ch_out=1024, t=t) 206 | 207 | self.Up5 = up_conv(ch_in=1024, ch_out=512) 208 | self.Up_RRCNN5 = RRCNN_block(ch_in=1024, ch_out=512, t=t) 209 | 210 | self.Up4 = up_conv(ch_in=512, ch_out=256) 211 | self.Up_RRCNN4 = RRCNN_block(ch_in=512, ch_out=256, t=t) 212 | 213 | self.Up3 = up_conv(ch_in=256, ch_out=128) 214 | self.Up_RRCNN3 = RRCNN_block(ch_in=256, ch_out=128, t=t) 215 | 216 | self.Up2 = up_conv(ch_in=128, ch_out=64) 217 | self.Up_RRCNN2 = RRCNN_block(ch_in=128, ch_out=64, t=t) 218 | 219 | self.Conv_1x1 = nn.Conv2d(64, 
output_ch, kernel_size=1, stride=1, padding=0) 220 | 221 | def forward(self, x): 222 | # encoding path 223 | x1 = self.RRCNN1(x) 224 | 225 | x2 = self.Maxpool(x1) 226 | x2 = self.RRCNN2(x2) 227 | 228 | x3 = self.Maxpool(x2) 229 | x3 = self.RRCNN3(x3) 230 | 231 | x4 = self.Maxpool(x3) 232 | x4 = self.RRCNN4(x4) 233 | 234 | x5 = self.Maxpool(x4) 235 | x5 = self.RRCNN5(x5) 236 | 237 | # decoding + concat path 238 | d5 = self.Up5(x5) 239 | d5 = torch.cat((x4, d5), dim=1) 240 | d5 = self.Up_RRCNN5(d5) 241 | 242 | d4 = self.Up4(d5) 243 | d4 = torch.cat((x3, d4), dim=1) 244 | d4 = self.Up_RRCNN4(d4) 245 | 246 | d3 = self.Up3(d4) 247 | d3 = torch.cat((x2, d3), dim=1) 248 | d3 = self.Up_RRCNN3(d3) 249 | 250 | d2 = self.Up2(d3) 251 | d2 = torch.cat((x1, d2), dim=1) 252 | d2 = self.Up_RRCNN2(d2) 253 | 254 | d1 = self.Conv_1x1(d2) 255 | d1 = F.softmax(d1,dim=1) 256 | 257 | return d1 258 | 259 | # =========================================================== 260 | class AttU_Net(nn.Module): 261 | def __init__(self, img_ch=3, output_ch=1): 262 | super(AttU_Net, self).__init__() 263 | 264 | self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2) 265 | 266 | self.Conv1 = conv_block(ch_in=img_ch, ch_out=64) 267 | self.Conv2 = conv_block(ch_in=64, ch_out=128) 268 | self.Conv3 = conv_block(ch_in=128, ch_out=256) 269 | self.Conv4 = conv_block(ch_in=256, ch_out=512) 270 | self.Conv5 = conv_block(ch_in=512, ch_out=1024) 271 | 272 | self.Up5 = up_conv(ch_in=1024, ch_out=512) 273 | self.Att5 = Attention_block(F_g=512, F_l=512, F_int=256) 274 | self.Up_conv5 = conv_block(ch_in=1024, ch_out=512) 275 | 276 | self.Up4 = up_conv(ch_in=512, ch_out=256) 277 | self.Att4 = Attention_block(F_g=256, F_l=256, F_int=128) 278 | self.Up_conv4 = conv_block(ch_in=512, ch_out=256) 279 | 280 | self.Up3 = up_conv(ch_in=256, ch_out=128) 281 | self.Att3 = Attention_block(F_g=128, F_l=128, F_int=64) 282 | self.Up_conv3 = conv_block(ch_in=256, ch_out=128) 283 | 284 | self.Up2 = up_conv(ch_in=128, ch_out=64) 285 | self.Att2 = Attention_block(F_g=64, F_l=64, F_int=32) 286 | self.Up_conv2 = conv_block(ch_in=128, ch_out=64) 287 | 288 | self.Conv_1x1 = nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0) 289 | 290 | def forward(self, x): 291 | # encoding path 292 | x1 = self.Conv1(x) 293 | 294 | x2 = self.Maxpool(x1) 295 | x2 = self.Conv2(x2) 296 | 297 | x3 = self.Maxpool(x2) 298 | x3 = self.Conv3(x3) 299 | 300 | x4 = self.Maxpool(x3) 301 | x4 = self.Conv4(x4) 302 | 303 | x5 = self.Maxpool(x4) 304 | x5 = self.Conv5(x5) 305 | 306 | # decoding + concat path 307 | d5 = self.Up5(x5) 308 | x4 = self.Att5(g=d5, x=x4) 309 | d5 = torch.cat((x4, d5), dim=1) 310 | d5 = self.Up_conv5(d5) 311 | 312 | d4 = self.Up4(d5) 313 | x3 = self.Att4(g=d4, x=x3) 314 | d4 = torch.cat((x3, d4), dim=1) 315 | d4 = self.Up_conv4(d4) 316 | 317 | d3 = self.Up3(d4) 318 | x2 = self.Att3(g=d3, x=x2) 319 | d3 = torch.cat((x2, d3), dim=1) 320 | d3 = self.Up_conv3(d3) 321 | 322 | d2 = self.Up2(d3) 323 | x1 = self.Att2(g=d2, x=x1) 324 | d2 = torch.cat((x1, d2), dim=1) 325 | d2 = self.Up_conv2(d2) 326 | 327 | d1 = self.Conv_1x1(d2) 328 | d1 = F.softmax(d1,dim=1) 329 | return d1 330 | 331 | # ============================================================== 332 | class R2AttU_Net(nn.Module): 333 | def __init__(self, img_ch=3, output_ch=1, t=2): 334 | super(R2AttU_Net, self).__init__() 335 | 336 | self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2) 337 | self.Upsample = nn.Upsample(scale_factor=2) 338 | 339 | self.RRCNN1 = RRCNN_block(ch_in=img_ch, ch_out=64, t=t) 340 | 341 | 
self.RRCNN2 = RRCNN_block(ch_in=64, ch_out=128, t=t) 342 | 343 | self.RRCNN3 = RRCNN_block(ch_in=128, ch_out=256, t=t) 344 | 345 | self.RRCNN4 = RRCNN_block(ch_in=256, ch_out=512, t=t) 346 | 347 | self.RRCNN5 = RRCNN_block(ch_in=512, ch_out=1024, t=t) 348 | 349 | self.Up5 = up_conv(ch_in=1024, ch_out=512) 350 | self.Att5 = Attention_block(F_g=512, F_l=512, F_int=256) 351 | self.Up_RRCNN5 = RRCNN_block(ch_in=1024, ch_out=512, t=t) 352 | 353 | self.Up4 = up_conv(ch_in=512, ch_out=256) 354 | self.Att4 = Attention_block(F_g=256, F_l=256, F_int=128) 355 | self.Up_RRCNN4 = RRCNN_block(ch_in=512, ch_out=256, t=t) 356 | 357 | self.Up3 = up_conv(ch_in=256, ch_out=128) 358 | self.Att3 = Attention_block(F_g=128, F_l=128, F_int=64) 359 | self.Up_RRCNN3 = RRCNN_block(ch_in=256, ch_out=128, t=t) 360 | 361 | self.Up2 = up_conv(ch_in=128, ch_out=64) 362 | self.Att2 = Attention_block(F_g=64, F_l=64, F_int=32) 363 | self.Up_RRCNN2 = RRCNN_block(ch_in=128, ch_out=64, t=t) 364 | 365 | self.Conv_1x1 = nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0) 366 | 367 | def forward(self, x): 368 | # encoding path 369 | x1 = self.RRCNN1(x) 370 | 371 | x2 = self.Maxpool(x1) 372 | x2 = self.RRCNN2(x2) 373 | 374 | x3 = self.Maxpool(x2) 375 | x3 = self.RRCNN3(x3) 376 | 377 | x4 = self.Maxpool(x3) 378 | x4 = self.RRCNN4(x4) 379 | 380 | x5 = self.Maxpool(x4) 381 | x5 = self.RRCNN5(x5) 382 | 383 | # decoding + concat path 384 | d5 = self.Up5(x5) 385 | x4 = self.Att5(g=d5, x=x4) 386 | d5 = torch.cat((x4, d5), dim=1) 387 | d5 = self.Up_RRCNN5(d5) 388 | 389 | d4 = self.Up4(d5) 390 | x3 = self.Att4(g=d4, x=x3) 391 | d4 = torch.cat((x3, d4), dim=1) 392 | d4 = self.Up_RRCNN4(d4) 393 | 394 | d3 = self.Up3(d4) 395 | x2 = self.Att3(g=d3, x=x2) 396 | d3 = torch.cat((x2, d3), dim=1) 397 | d3 = self.Up_RRCNN3(d3) 398 | 399 | d2 = self.Up2(d3) 400 | x1 = self.Att2(g=d2, x=x1) 401 | d2 = torch.cat((x1, d2), dim=1) 402 | d2 = self.Up_RRCNN2(d2) 403 | 404 | d1 = self.Conv_1x1(d2) 405 | d1 = F.softmax(d1, dim=1) 406 | 407 | return d1 408 | 409 | #==================DenseUNet===================================== 410 | class Single_level_densenet(nn.Module): 411 | def __init__(self, filters, num_conv=4): 412 | super(Single_level_densenet, self).__init__() 413 | self.num_conv = num_conv 414 | self.conv_list = nn.ModuleList() 415 | self.bn_list = nn.ModuleList() 416 | for i in range(self.num_conv): 417 | self.conv_list.append(nn.Conv2d(filters, filters, 3, padding=1)) 418 | self.bn_list.append(nn.BatchNorm2d(filters)) 419 | 420 | def forward(self, x): 421 | outs = [] 422 | outs.append(x) 423 | for i in range(self.num_conv): 424 | temp_out = self.conv_list[i](outs[i]) 425 | if i > 0: 426 | for j in range(i): 427 | temp_out += outs[j] 428 | outs.append(F.relu(self.bn_list[i](temp_out))) 429 | out_final = outs[-1] 430 | del outs 431 | return out_final 432 | 433 | 434 | class Down_sample(nn.Module): 435 | def __init__(self, kernel_size=2, stride=2): 436 | super(Down_sample, self).__init__() 437 | self.down_sample_layer = nn.MaxPool2d(kernel_size, stride) 438 | 439 | def forward(self, x): 440 | y = self.down_sample_layer(x) 441 | return y, x 442 | 443 | 444 | class Upsample_n_Concat(nn.Module): 445 | def __init__(self, filters): 446 | super(Upsample_n_Concat, self).__init__() 447 | self.upsample_layer = nn.ConvTranspose2d(filters, filters, 4, padding=1, stride=2) 448 | self.conv = nn.Conv2d(2 * filters, filters, 3, padding=1) 449 | self.bn = nn.BatchNorm2d(filters) 450 | 451 | def forward(self, x, y): 452 | x = self.upsample_layer(x) 453 | x 
= torch.cat([x, y], dim=1) 454 | x = F.relu(self.bn(self.conv(x))) 455 | return x 456 | 457 | 458 | class Dense_Unet(nn.Module): 459 | def __init__(self, in_chan=3,out_chan=2,filters=128, num_conv=4): 460 | 461 | super(Dense_Unet, self).__init__() 462 | self.conv1 = nn.Conv2d(in_chan, filters, 1) 463 | self.d1 = Single_level_densenet(filters, num_conv) 464 | self.down1 = Down_sample() 465 | self.d2 = Single_level_densenet(filters, num_conv) 466 | self.down2 = Down_sample() 467 | self.d3 = Single_level_densenet(filters, num_conv) 468 | self.down3 = Down_sample() 469 | self.d4 = Single_level_densenet(filters, num_conv) 470 | self.down4 = Down_sample() 471 | self.bottom = Single_level_densenet(filters, num_conv) 472 | self.up4 = Upsample_n_Concat(filters) 473 | self.u4 = Single_level_densenet(filters, num_conv) 474 | self.up3 = Upsample_n_Concat(filters) 475 | self.u3 = Single_level_densenet(filters, num_conv) 476 | self.up2 = Upsample_n_Concat(filters) 477 | self.u2 = Single_level_densenet(filters, num_conv) 478 | self.up1 = Upsample_n_Concat(filters) 479 | self.u1 = Single_level_densenet(filters, num_conv) 480 | self.outconv = nn.Conv2d(filters, out_chan, 1) 481 | 482 | # self.outconvp1 = nn.Conv2d(filters,out_chan, 1) 483 | # self.outconvm1 = nn.Conv2d(filters,out_chan, 1) 484 | 485 | def forward(self, x): 486 | x = self.conv1(x) 487 | x, y1 = self.down1(self.d1(x)) 488 | x, y2 = self.down1(self.d2(x)) 489 | x, y3 = self.down1(self.d3(x)) 490 | x, y4 = self.down1(self.d4(x)) 491 | x = self.bottom(x) 492 | x = self.u4(self.up4(x, y4)) 493 | x = self.u3(self.up3(x, y3)) 494 | x = self.u2(self.up2(x, y2)) 495 | x = self.u1(self.up1(x, y1)) 496 | x1 = self.outconv(x) 497 | # xm1 = self.outconvm1(x) 498 | # xp1 = self.outconvp1(x) 499 | x1 = F.softmax(x1,dim=1) 500 | return x1 501 | # ========================================================= 502 | 503 | if __name__ == '__main__': 504 | net = Dense_Unet(3,21,128).cuda() 505 | print(net) 506 | in1 = torch.randn(4,3,224,224).cuda() 507 | out = net(in1) 508 | print(out.size()) 509 | 510 | if __name__ == '__main__': 511 | # test network forward 512 | net = AttU_Net(1,2).cuda() 513 | print(net) 514 | in1 = torch.randn((4,1,48,48)).cuda() 515 | out1 = net(in1) 516 | print(out1.size()) -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .LadderNet import LadderNet 3 | from .UNetFamily import U_Net, R2U_Net, AttU_Net, R2AttU_Net, Dense_Unet -------------------------------------------------------------------------------- /models/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import * -------------------------------------------------------------------------------- /models/nn/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class ChannelAttention(nn.Module): 5 | def __init__(self, in_planes, ratio=4): 6 | super(ChannelAttention, self).__init__() 7 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 8 | self.max_pool = nn.AdaptiveMaxPool2d(1) 9 | 10 | self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False) 11 | self.relu1 = nn.ReLU() 12 | self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False) 13 | 14 | self.sigmoid = nn.Sigmoid() 15 | 16 | def forward(self, x): 17 | avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x)))) 18 | max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x)))) 19 | out = avg_out + max_out 20 | return self.sigmoid(out) * x 21 | 22 | class SpatialAttention(nn.Module): 23 | def __init__(self, kernel_size=3): 24 | super(SpatialAttention, self).__init__() 25 | 26 | assert kernel_size in (3, 7), 'kernel size must be 3 or 7' 27 | padding = 3 if kernel_size == 7 else 1 28 | 29 | self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False) 30 | self.sigmoid = nn.Sigmoid() 31 | 32 | def forward(self, x): 33 | avg_out = torch.mean(x, dim=1, keepdim=True) 34 | max_out, _ = torch.max(x, dim=1, keepdim=True) 35 | y = torch.cat([avg_out, max_out], dim=1) 36 | y = self.conv1(y) 37 | return self.sigmoid(y) * x 38 | 39 | if __name__ == "__main__": 40 | bs, c, h, w = 1, 32, 16, 16 41 | in_tensor = torch.ones(bs, c, h, w) 42 | 43 | net = ChannelAttention(32) 44 | print("in shape:",in_tensor.size()) 45 | out_tensor = net(in_tensor) 46 | print("out shape:", out_tensor.size()) -------------------------------------------------------------------------------- /prepare_dataset/chasedb1.py: -------------------------------------------------------------------------------- 1 | #========================================================= 2 | # 3 | # 4 | # 5 | #========================================================= 6 | import os 7 | from os.path import join 8 | 9 | def get_path_list(root_path,img_path,label_path,fov_path): 10 | tmp_list = [img_path,label_path,fov_path] 11 | res = [] 12 | for i in range(len(tmp_list)): 13 | data_path = join(data_root_path,tmp_list[i]) 14 | filename_list = os.listdir(data_path) 15 | filename_list.sort() 16 | res.append([join(data_path,j) for j in filename_list]) 17 | return res 18 | 19 | def write_path_list(name_list, save_path, file_name): 20 | f = open(join(save_path, file_name), 'w') 21 | for i in range(len(name_list[0])): 22 | f.write(str(name_list[0][i]) + " " + str(name_list[1][i]) + " " + str(name_list[2][i]) + '\n') 23 | f.close() 24 | 25 | if __name__ == "__main__": 26 | #------------Path of the dataset -------------------------------- 27 | data_root_path = '/ssd/lzq/projects/vesselseg/data' 28 | # if not os.path.exists(data_root_path): raise ValueError("data path is not exist, Please make sure your data path is correct") 29 | #train 30 | img = "CHASEDB1/images" 31 | gt = "CHASEDB1/1st_label" 32 | fov = "CHASEDB1/mask" 33 | #---------------save path----------------------------------------- 34 | save_path = "./prepare_dataset/data_path_list/CHASEDB1" 35 | if not os.path.exists(save_path): 36 | os.mkdir(save_path) 37 | #----------------------------------------------------------------- 38 | data_list = get_path_list(data_root_path,img,gt,fov) 39 | print('Numbers of all imgs:',len(data_list[0])) 40 | test_range = 
(0,7) # 测试集索引范围,左闭右开 41 | train_list = [data_list[i][:test_range[0]] + data_list[i][test_range[1]:] for i in range(len(data_list))] 42 | test_list = [data_list[i][test_range[0]:test_range[1]] for i in range(len(data_list))] 43 | 44 | print('Number of train imgs:',len(train_list[0])) 45 | write_path_list(train_list, save_path, 'train.txt') 46 | 47 | print('Number of test imgs:',len(test_list[0])) 48 | write_path_list(test_list, save_path, 'test.txt') 49 | 50 | print("Finish!") 51 | 52 | -------------------------------------------------------------------------------- /prepare_dataset/data_path_list/.gitignore: -------------------------------------------------------------------------------- 1 | # ignore path list files of datasets 2 | DRIVE/ 3 | CHASEDB1/ 4 | STARE/ 5 | -------------------------------------------------------------------------------- /prepare_dataset/drive.py: -------------------------------------------------------------------------------- 1 | #========================================================= 2 | # 3 | # 4 | # 5 | #========================================================= 6 | import os 7 | import h5py,cv2,imageio 8 | import numpy as np 9 | from PIL import Image 10 | from os.path import join 11 | 12 | def get_path_list(root_path,img_path,label_path,fov_path): 13 | tmp_list = [img_path,label_path,fov_path] 14 | res = [] 15 | for i in range(len(tmp_list)): 16 | data_path = join(data_root_path,tmp_list[i]) 17 | filename_list = os.listdir(data_path) 18 | filename_list.sort() 19 | res.append([join(data_path,j) for j in filename_list]) 20 | return res 21 | 22 | def write_path_list(name_list, save_path, file_name): 23 | f = open(join(save_path, file_name), 'w') 24 | for i in range(len(name_list[0])): 25 | f.write(str(name_list[0][i]) + " " + str(name_list[1][i]) + " " + str(name_list[2][i]) + '\n') 26 | f.close() 27 | 28 | if __name__ == "__main__": 29 | #------------Path of the dataset ------------------------- 30 | data_root_path = '/ssd/lzq/projects/vesselseg/data' 31 | # if not os.path.exists(data_root_path): raise ValueError("data path is not exist, Please make sure your data path is correct") 32 | #train 33 | img_train = "DRIVE/training/images/" 34 | gt_train = "DRIVE/training/1st_manual/" 35 | fov_train = "DRIVE/training/mask/" 36 | #test 37 | img_test = "DRIVE/test/images/" 38 | gt_test = "DRIVE/test/1st_manual/" 39 | fov_test = "DRIVE/test/mask/" 40 | #---------------------------------------------------------- 41 | save_path = "./prepare_dataset/data_path_list/DRIVE/" 42 | if not os.path.isdir(save_path): 43 | os.mkdir(save_path) 44 | 45 | train_list = get_path_list(data_root_path,img_train,gt_train,fov_train) 46 | print('Number of train imgs:',len(train_list[0])) 47 | write_path_list(train_list, save_path, 'train.txt') 48 | 49 | test_list = get_path_list(data_root_path,img_test,gt_test,fov_test) 50 | print('Number of test imgs:',len(test_list[0])) 51 | write_path_list(test_list, save_path, 'test.txt') 52 | 53 | print("Finish!") 54 | 55 | -------------------------------------------------------------------------------- /prepare_dataset/stare.py: -------------------------------------------------------------------------------- 1 | #========================================================= 2 | # 3 | # 4 | # 5 | #========================================================= 6 | import os 7 | from os.path import join 8 | 9 | def get_path_list(root_path,img_path,label_path,fov_path): 10 | tmp_list = [img_path,label_path,fov_path] 11 | res = [] 12 | for i in 
range(len(tmp_list)): 13 | data_path = join(data_root_path,tmp_list[i]) 14 | filename_list = os.listdir(data_path) 15 | filename_list.sort() 16 | res.append([join(data_path,j) for j in filename_list]) 17 | return res 18 | 19 | def write_path_list(name_list, save_path, file_name): 20 | f = open(join(save_path, file_name), 'w') 21 | for i in range(len(name_list[0])): 22 | f.write(str(name_list[0][i]) + " " + str(name_list[1][i]) + " " + str(name_list[2][i]) + '\n') 23 | f.close() 24 | 25 | if __name__ == "__main__": 26 | #------------Path of the dataset -------------------------------- 27 | data_root_path = '/ssd/lzq/projects/vesselseg/data' 28 | # if not os.path.exists(data_root_path): raise ValueError("data path is not exist, Please make sure your data path is correct") 29 | #train 30 | img = "STARE/images" 31 | gt = "STARE/1st_labels_ah" 32 | fov = "STARE/mask" 33 | #---------------save path----------------------------------------- 34 | save_path = "./prepare_dataset/data_path_list/STARE" 35 | if not os.path.exists(save_path): 36 | os.mkdir(save_path) 37 | #----------------------------------------------------------------- 38 | data_list = get_path_list(data_root_path,img,gt,fov) 39 | print('Numbers of all imgs:',len(data_list[0])) 40 | test_range = (0,5) # 测试集索引范围,左闭右开 41 | train_list = [data_list[i][:test_range[0]] + data_list[i][test_range[1]:] for i in range(len(data_list))] 42 | test_list = [data_list[i][test_range[0]:test_range[1]] for i in range(len(data_list))] 43 | 44 | print('Number of train imgs:',len(train_list[0])) 45 | write_path_list(train_list, save_path, 'train.txt') 46 | 47 | print('Number of test imgs:',len(test_list[0])) 48 | write_path_list(test_list, save_path, 'test.txt') 49 | 50 | print("Finish!") 51 | 52 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import joblib,copy 2 | import torch.backends.cudnn as cudnn 3 | from torch.utils.data import DataLoader 4 | import torch,sys 5 | from tqdm import tqdm 6 | 7 | from collections import OrderedDict 8 | from lib.visualize import save_img,group_images,concat_result 9 | import os 10 | import argparse 11 | from lib.logger import Logger, Print_Logger 12 | from lib.extract_patches import * 13 | from os.path import join 14 | from lib.dataset import TestDataset 15 | from lib.metrics import Evaluate 16 | import models 17 | from lib.common import setpu_seed,dict_round 18 | from config import parse_args 19 | from lib.pre_processing import my_PreProc 20 | 21 | setpu_seed(2021) 22 | 23 | class Test(): 24 | def __init__(self, args): 25 | self.args = args 26 | assert (args.stride_height <= args.test_patch_height and args.stride_width <= args.test_patch_width) 27 | # save path 28 | self.path_experiment = join(args.outf, args.save) 29 | 30 | self.patches_imgs_test, self.test_imgs, self.test_masks, self.test_FOVs, self.new_height, self.new_width = get_data_test_overlap( 31 | test_data_path_list=args.test_data_path_list, 32 | patch_height=args.test_patch_height, 33 | patch_width=args.test_patch_width, 34 | stride_height=args.stride_height, 35 | stride_width=args.stride_width 36 | ) 37 | self.img_height = self.test_imgs.shape[2] 38 | self.img_width = self.test_imgs.shape[3] 39 | 40 | test_set = TestDataset(self.patches_imgs_test) 41 | self.test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False, num_workers=3) 42 | 43 | # Inference prediction process 44 | def inference(self, net): 45 | 
net.eval() 46 | preds = [] 47 | with torch.no_grad(): 48 | for batch_idx, inputs in tqdm(enumerate(self.test_loader), total=len(self.test_loader)): 49 | inputs = inputs.cuda() 50 | outputs = net(inputs) 51 | outputs = outputs[:,1].data.cpu().numpy() 52 | preds.append(outputs) 53 | predictions = np.concatenate(preds, axis=0) 54 | self.pred_patches = np.expand_dims(predictions,axis=1) 55 | 56 | # Evaluate ate and visualize the predicted images 57 | def evaluate(self): 58 | self.pred_imgs = recompone_overlap( 59 | self.pred_patches, self.new_height, self.new_width, self.args.stride_height, self.args.stride_width) 60 | ## restore to original dimensions 61 | self.pred_imgs = self.pred_imgs[:, :, 0:self.img_height, 0:self.img_width] 62 | 63 | #predictions only inside the FOV 64 | y_scores, y_true = pred_only_in_FOV(self.pred_imgs, self.test_masks, self.test_FOVs) 65 | eval = Evaluate(save_path=self.path_experiment) 66 | eval.add_batch(y_true, y_scores) 67 | log = eval.save_all_result(plot_curve=True,save_name="performance.txt") 68 | # save labels and probs for plot ROC and PR curve when k-fold Cross-validation 69 | np.save('{}/result.npy'.format(self.path_experiment), np.asarray([y_true, y_scores])) 70 | return dict_round(log, 6) 71 | 72 | # save segmentation imgs 73 | def save_segmentation_result(self): 74 | img_path_list, _, _ = load_file_path_txt(self.args.test_data_path_list) 75 | img_name_list = [item.split('/')[-1].split('.')[0] for item in img_path_list] 76 | 77 | kill_border(self.pred_imgs, self.test_FOVs) # only for visualization 78 | self.save_img_path = join(self.path_experiment,'result_img') 79 | if not os.path.exists(join(self.save_img_path)): 80 | os.makedirs(self.save_img_path) 81 | # self.test_imgs = my_PreProc(self.test_imgs) # Uncomment to save the pre processed image 82 | for i in range(self.test_imgs.shape[0]): 83 | total_img = concat_result(self.test_imgs[i],self.pred_imgs[i],self.test_masks[i]) 84 | save_img(total_img,join(self.save_img_path, "Result_"+img_name_list[i]+'.png')) 85 | 86 | # Val on the test set at each epoch 87 | def val(self): 88 | self.pred_imgs = recompone_overlap( 89 | self.pred_patches, self.new_height, self.new_width, self.args.stride_height, self.args.stride_width) 90 | ## recover to original dimensions 91 | self.pred_imgs = self.pred_imgs[:, :, 0:self.img_height, 0:self.img_width] 92 | 93 | #predictions only inside the FOV 94 | y_scores, y_true = pred_only_in_FOV(self.pred_imgs, self.test_masks, self.test_FOVs) 95 | eval = Evaluate(save_path=self.path_experiment) 96 | eval.add_batch(y_true, y_scores) 97 | confusion,accuracy,specificity,sensitivity,precision = eval.confusion_matrix() 98 | log = OrderedDict([('val_auc_roc', eval.auc_roc()), 99 | ('val_f1', eval.f1_score()), 100 | ('val_acc', accuracy), 101 | ('SE', sensitivity), 102 | ('SP', specificity)]) 103 | return dict_round(log, 6) 104 | 105 | if __name__ == '__main__': 106 | args = parse_args() 107 | save_path = join(args.outf, args.save) 108 | sys.stdout = Print_Logger(os.path.join(save_path, 'test_log.txt')) 109 | device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") 110 | 111 | 112 | # net = models.UNetFamily.Dense_Unet(1,2).to(device) 113 | net = models.LadderNet(inplanes=1, num_classes=2, layers=3, filters=16).to(device) 114 | cudnn.benchmark = True 115 | 116 | # Load checkpoint 117 | print('==> Loading checkpoint...') 118 | checkpoint = torch.load(join(save_path, 'best_model.pth')) 119 | net.load_state_dict(checkpoint['net']) 120 | 121 | eval = Test(args) 122 | 
eval.inference(net) 123 | print(eval.evaluate()) 124 | eval.save_segmentation_result() 125 | -------------------------------------------------------------------------------- /tools/README.md: -------------------------------------------------------------------------------- 1 | `./Tools` is an experimental module that implements functions related to visualization and cross-validation 2 | ## This part of the function needs to be improved, please use it with caution. -------------------------------------------------------------------------------- /tools/ablation/ablation_plot.py: -------------------------------------------------------------------------------- 1 | from sklearn.metrics import roc_curve 2 | from sklearn.metrics import roc_auc_score 3 | from sklearn.metrics import confusion_matrix 4 | from sklearn.metrics import precision_recall_curve 5 | from sklearn.metrics import f1_score 6 | 7 | import matplotlib.pyplot as plt 8 | import matplotlib.pylab as pylab 9 | params = {'legend.fontsize': 13, 10 | 'axes.labelsize': 15, 11 | 'axes.titlesize':15, 12 | 'xtick.labelsize':15, 13 | 'ytick.labelsize':15} 14 | pylab.rcParams.update(params) 15 | import os 16 | import torch 17 | from os.path import join 18 | from collections import OrderedDict 19 | import numpy as np 20 | 21 | result_list = {"d_base":'./experiments/db1_new/result.npy', 22 | "d_up1":'./experiments/d_up1/result.npy', 23 | "d_total":'./experiments/d_total/result.npy'} 24 | save_path = './experiments/Drive_ablation' 25 | if not os.path.exists(save_path): os.makedirs(save_path) 26 | 27 | # ===============AUC ROC=============== 28 | plt.figure() 29 | for name,path in result_list.items(): 30 | data = np.load(path) 31 | target, output = data 32 | AUC_ROC= roc_auc_score(target, output) 33 | print(name + ": AUC of ROC curve: " + str(AUC_ROC)) 34 | 35 | fpr, tpr, thresholds = roc_curve(target, output) 36 | plt.plot(fpr, tpr, '-', label=name + ' (AUC = %0.4f)' % AUC_ROC) 37 | plt.legend(loc="lower right") 38 | 39 | plt.title('ROC curve') 40 | plt.xlabel("FPR (False Positive Rate)") 41 | plt.ylabel("TPR (True Positive Rate)") 42 | 43 | plt.savefig("{}/ROC.png".format(save_path)) 44 | 45 | # ===============AUC_PR================== 46 | plt.figure() 47 | for name,path in result_list.items(): 48 | data = np.load(path) 49 | target, output = data 50 | 51 | precision, recall, thresholds = precision_recall_curve(target, output) 52 | precision = np.fliplr([precision])[0] # so the array is increasing (you won't get negative AUC) 53 | recall = np.fliplr([recall])[0] # so the array is increasing (you won't get negative AUC) 54 | AUC_pr = np.trapz(precision, recall) 55 | print(name + ": AUC of P-R curve: " + str(AUC_pr)) 56 | 57 | plt.plot(recall, precision, '-', label=name+' (AUC = %0.4f)' % AUC_pr) 58 | plt.legend(loc="lower right") 59 | 60 | plt.title('Precision-Recall curve') 61 | plt.xlabel("Recall") 62 | plt.ylabel("Precision") 63 | plt.savefig("{}/PRC.png".format(save_path)) 64 | -------------------------------------------------------------------------------- /tools/ablation/ablation_plot_with_detail.py: -------------------------------------------------------------------------------- 1 | from sklearn.metrics import roc_curve 2 | from sklearn.metrics import roc_auc_score 3 | from sklearn.metrics import confusion_matrix 4 | from sklearn.metrics import precision_recall_curve 5 | from sklearn.metrics import f1_score 6 | import matplotlib.pyplot as plt 7 | from mpl_toolkits.axes_grid1.inset_locator import mark_inset 8 | from 
mpl_toolkits.axes_grid1.inset_locator import inset_axes 9 | import matplotlib.pylab as pylab 10 | import os 11 | from os.path import join 12 | from collections import OrderedDict 13 | import numpy as np 14 | 15 | params = {'legend.fontsize': 12, 16 | 'axes.labelsize': 15, 17 | 'axes.titlesize':15, 18 | 'xtick.labelsize':13, 19 | 'ytick.labelsize':13} # define pyplot parameters 20 | pylab.rcParams.update(params) 21 | #Area under the ROC curve 22 | 23 | result_list = {"d_base": ('./experiments/db1_new/result.npy', '-'), 24 | "d_up1": ('./experiments/d_up1/result.npy', '--'), 25 | "d_total": ('./experiments/d_total/result.npy', ':')} # value: (result .npy path, matplotlib line style) 26 | save_path = './experiments/Drive_ablation' 27 | 28 | if not os.path.exists(save_path): os.makedirs(save_path) 29 | 30 | # ===============AUC ROC=============== 31 | fig, ax = plt.subplots(1, 1) 32 | axins = ax.inset_axes((0.45, 0.45, 0.45, 0.45)) # inset axes: lower-left corner (x0, y0) and size (width, height) 33 | axins.set_xlim(0.07, 0.11) 34 | axins.set_ylim(0.93, 0.96) 35 | 36 | mark_inset(ax, axins, loc1=3, loc2=1, fc="none", ec='k', lw=1) # draw connector lines; loc1/loc2 take values 1,2,3,4 for upper-right, upper-left, lower-left, lower-right 37 | for name, (path, linestyle) in result_list.items(): 38 | data = np.load(path) 39 | target, output = data 40 | AUC_ROC= roc_auc_score(target, output) 41 | print(name + ": AUC of ROC curve: " + str(AUC_ROC)) 42 | 43 | fpr, tpr, thresholds = roc_curve(target, output) 44 | ax.plot(fpr, tpr, linestyle=linestyle, label=name + ' (AUC = %0.4f)' % AUC_ROC) 45 | axins.plot(fpr, tpr, linestyle=linestyle, label=name + ' (AUC = %0.4f)' % AUC_ROC) 46 | ax.legend(loc="lower right") 47 | 48 | plt.title('ROC curve') 49 | plt.xlabel("FPR (False Positive Rate)") 50 | plt.ylabel("TPR (True Positive Rate)") 51 | plt.savefig("{}/ROC1.png".format(save_path)) 52 | 53 | # ===============AUC_PR================== 54 | fig, ax = plt.subplots(1, 1) 55 | axins = ax.inset_axes((0.15, 0.45, 0.4, 0.4)) # inset axes: lower-left corner (x0, y0) and size (width, height) 56 | axins.set_xlim(0.83,0.89) 57 | axins.set_ylim(0.75,0.82) 58 | 59 | mark_inset(ax, axins, loc1=1, loc2=4, fc="none", ec='k', lw=1) # draw connector lines; loc1/loc2 take values 1,2,3,4 for upper-right, upper-left, lower-left, lower-right 60 | for name, (path, linestyle) in result_list.items(): 61 | data = np.load(path) 62 | target, output = data 63 | 64 | precision, recall, thresholds = precision_recall_curve(target, output) 65 | precision = np.fliplr([precision])[0] # so the array is increasing (you won't get negative AUC) 66 | recall = np.fliplr([recall])[0] # so the array is increasing (you won't get negative AUC) 67 | AUC_pr = np.trapz(precision, recall) 68 | print(name + ": AUC of P-R curve: " + str(AUC_pr)) 69 | 70 | ax.plot(recall, precision, linestyle=linestyle, label=name+' (AUC = %0.4f)' % AUC_pr) 71 | axins.plot(recall, precision, linestyle=linestyle, label=name+' (AUC = %0.4f)' % AUC_pr) 72 | ax.legend(loc="lower left") 73 | 74 | plt.title('Precision-Recall curve') 75 | plt.xlabel("Recall") 76 | plt.ylabel("Precision") 77 | plt.savefig("{}/PRC1.png".format(save_path)) 78 | -------------------------------------------------------------------------------- /tools/merge_k-flod_plot.py: -------------------------------------------------------------------------------- 1 | """ 2 | Because the STARE and CHASE_DB1 datasets are not divided into trainset 3 | and testset, cross-validation should be used. The script combines the 4 | test results of k-fold cross-validation on a dataset to give the final performance.
5 | 6 | Attention: Run from the root directory of the project, otherwise the package may not be found 7 | """ 8 | import sys 9 | sys.path.append('./') # The root directory of the project 10 | from lib.metrics import Evaluate 11 | import numpy as np 12 | 13 | # second = ['stare1','stare2','stare3','stare4'] 14 | # second = ['chase1','chase2','chase3','chase4'] 15 | second = ['s_1','s_2','s_3','s_4'] 16 | # second = ['c_1','c_2','c_3','c_4'] 17 | save_path = './experiments/STARE' 18 | agent = Evaluate(save_path=save_path) 19 | for i in second: 20 | data = np.load('./experiments/{}/result.npy'.format(i)) 21 | y_true, y_prob = data[0],data[1] 22 | agent.add_batch(y_true,y_prob) 23 | np.save('{}/result.npy'.format(save_path),np.asarray([agent.target,agent.output])) 24 | agent.save_all_result(plot_curve=True) -------------------------------------------------------------------------------- /tools/visualization/detail_comparison.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import os,sys 3 | sys.path.append('../') 4 | import cv2 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | import imageio 8 | 9 | def readImg(im_fn): 10 | im = np.array(Image.open(im_fn)) # load as a writable NumPy array so .shape and cv2 drawing work 11 | return im 12 | 13 | def split_result(result,count=4): 14 | res = [] 15 | h,w,c = result.shape 16 | w = w//4 17 | for i in range(count): 18 | img = result[:,w*i:w*(i+1)] 19 | res.append(img) # convert to three channels 20 | return res 21 | 22 | def crop_and_resize(img,center,crop_length,target_shape): 23 | crop_img = img[center[1]-crop_length//2:center[1]+crop_length//2, 24 | center[0]-crop_length//2:center[0]+crop_length//2] 25 | return cv2.resize(crop_img, target_shape, interpolation=cv2.INTER_CUBIC) 26 | 27 | if __name__ == "__main__": 28 | result = readImg('/ssd/lzq/sf3/output/s_3/result_img/Original_GroundTruth_Prediction2.png') 29 | ori_img = readImg('/ssd/lzq/sf3/data/STARE/images/im0081.ppm') 30 | save_path = '/ssd/lzq/sf3/result_detail_visualization' 31 | w0 = 580 32 | h0 = 360 33 | block = 90 34 | 35 | result_list = split_result(result) 36 | _, prob_block, bin_block, gt_block= [crop_and_resize(img,(w0,h0),block,ori_img.shape[::-1][1:]) for img in result_list] 37 | 38 | ori_block = crop_and_resize(ori_img,(w0,h0),block,ori_img.shape[::-1][1:]) 39 | 40 | # concatenate the panels side by side 41 | res = np.concatenate((ori_img,ori_block,prob_block,bin_block,gt_block),axis=1) 42 | 43 | cv2.rectangle(res, (w0-block//2,h0-block//2),(w0+block//2,h0+block//2),255,2) 44 | cv2.line(res, (w0-block//2,h0-block//2),(ori_img.shape[1],0), 255, 2) 45 | cv2.line(res, (w0-block//2,h0+block//2),ori_img.shape[::-1][1:], 255, 2) 46 | cv2.imwrite(save_path+'/detail_comp_result_0081.png', res) 47 | 48 | -------------------------------------------------------------------------------- /tools/visualization/detail_comparison2.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import os,sys 3 | import cv2 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | import imageio 7 | 8 | def readImg(im_fn): 9 | im = np.array(Image.open(im_fn)) # load as a writable NumPy array so .shape and cv2 drawing work 10 | return im 11 | 12 | def split_result(result,count=4): 13 | res = [] 14 | h,w,c = result.shape 15 | w = w//4 16 | for i in range(count): 17 | img = result[:,w*i:w*(i+1)] 18 | res.append(img) # convert to three channels 19 | return res 20 | 21 | def crop_and_resize(img,center,crop_length,target_shape,inter = None): 22 | crop_img = img[center[1]-crop_length//2:center[1]+crop_length//2, 23 | center[0]-crop_length//2:center[0]+crop_length//2] 24 | if inter ==
1: # nearest-neighbor interpolation 25 | return cv2.resize(crop_img, target_shape, interpolation=cv2.INTER_NEAREST) 26 | elif inter == 2: # bilinear interpolation 27 | return cv2.resize(crop_img, target_shape, interpolation=cv2.INTER_LINEAR) 28 | elif inter == 3: # bicubic interpolation 29 | return cv2.resize(crop_img, target_shape, interpolation=cv2.INTER_CUBIC) 30 | else: 31 | raise TypeError("please choose a valid interpolation mode!") 32 | 33 | if __name__ == "__main__": 34 | result = readImg('/ssd/lzq/sf3/output/s_3/result_img/Original_GroundTruth_Prediction2.png') 35 | img_path = '/ssd/lzq/sf3/data/STARE/images/im0081.ppm' 36 | save_path = '/ssd/lzq/sf3/visualization/result_detail_visualization' 37 | w0 = 590 38 | h0 = 360 39 | block = 110 40 | 41 | ori_img = readImg(img_path) 42 | idx = img_path.split('/')[-1].split('.')[0] 43 | 44 | result_list = split_result(result) 45 | prob_block = crop_and_resize(result_list[1],(w0,h0),block,(ori_img.shape[0],ori_img.shape[0]),inter=2) 46 | bin_block = crop_and_resize(result_list[2],(w0,h0),block,(ori_img.shape[0],ori_img.shape[0]),inter=1) 47 | gt_block = crop_and_resize(result_list[3],(w0,h0),block,(ori_img.shape[0],ori_img.shape[0]),inter=1) 48 | 49 | cv2.imwrite(save_path+'/{}_prob_block.png'.format(idx), prob_block) 50 | cv2.imwrite(save_path+'/{}_bin_block.png'.format(idx), bin_block) 51 | cv2.imwrite(save_path+'/{}_gt_block.png'.format(idx), gt_block) 52 | 53 | ori_block = crop_and_resize(ori_img,(w0,h0),block,(ori_img.shape[0],ori_img.shape[0]),inter=2) 54 | cv2.imwrite(save_path+'/{}_ori__block.png'.format(idx), ori_block) 55 | # concatenate the panels side by side 56 | # res = np.concatenate((ori_img,ori_block,prob_block,bin_block,gt_block),axis=1) 57 | 58 | cv2.rectangle(ori_img, (w0-block//2,h0-block//2),(w0+block//2,h0+block//2),255,5) 59 | cv2.line(ori_img, (w0-block//2,h0-block//2),(ori_img.shape[1],0), 255, 5) 60 | cv2.line(ori_img, (w0-block//2,h0+block//2),ori_img.shape[::-1][1:], 255, 5) 61 | cv2.imwrite(save_path+'/{}_ori_img.png'.format(idx), ori_img) 62 | 63 | -------------------------------------------------------------------------------- /tools/visualization/preprocess_visualization.py: -------------------------------------------------------------------------------- 1 | ################################################### 2 | # 3 | # Script to pre-process the original imgs 4 | # 5 | ################################################## 6 | import numpy as np 7 | from PIL import Image 8 | import cv2,sys 9 | 10 | def readImg(im_fn): 11 | im = Image.open(im_fn) 12 | return im 13 | #============================================================ 14 | #convert RGB image to grayscale 15 | def rgb2gray(rgb): 16 | assert (len(rgb.shape)==4) #4D arrays 17 | assert (rgb.shape[1]==3) 18 | bn_imgs = rgb[:,0,:,:]*0.299 + rgb[:,1,:,:]*0.587 + rgb[:,2,:,:]*0.114 19 | bn_imgs = np.reshape(bn_imgs,(rgb.shape[0],1,rgb.shape[2],rgb.shape[3])) 20 | return bn_imgs 21 | 22 | #==== histogram equalization 23 | def histo_equalized(imgs): 24 | assert (len(imgs.shape)==4) #4D arrays 25 | assert (imgs.shape[1]==1) #check the channel is 1 26 | imgs_equalized = np.empty(imgs.shape) 27 | for i in range(imgs.shape[0]): 28 | imgs_equalized[i,0] = cv2.equalizeHist(np.array(imgs[i,0], dtype = np.uint8)) 29 | return imgs_equalized 30 | 31 | 32 | # CLAHE (Contrast Limited Adaptive Histogram Equalization) 33 | #adaptive histogram equalization is used. In this, image is divided into small blocks called "tiles" (tileSize is 8x8 by default in OpenCV). Then each of these blocks are histogram equalized as usual.
So in a small area, histogram would confine to a small region (unless there is noise). If noise is there, it will be amplified. To avoid this, contrast limiting is applied. If any histogram bin is above the specified contrast limit (by default 40 in OpenCV), those pixels are clipped and distributed uniformly to other bins before applying histogram equalization. After equalization, to remove artifacts in tile borders, bilinear interpolation is applied 34 | def clahe_equalized(imgs): 35 | assert (len(imgs.shape)==4) #4D arrays 36 | assert (imgs.shape[1]==1) #check the channel is 1 37 | #create a CLAHE object (Arguments are optional). 38 | clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8)) 39 | imgs_equalized = np.empty(imgs.shape) 40 | for i in range(imgs.shape[0]): 41 | imgs_equalized[i,0] = clahe.apply(np.array(imgs[i,0], dtype = np.uint8)) 42 | return imgs_equalized 43 | 44 | # ===== normalize over the dataset 45 | def dataset_normalized(imgs): 46 | assert (len(imgs.shape)==4) #4D arrays 47 | assert (imgs.shape[1]==1) #check the channel is 1 48 | imgs_normalized = np.empty(imgs.shape) 49 | imgs_std = np.std(imgs) 50 | imgs_mean = np.mean(imgs) 51 | imgs_normalized = (imgs-imgs_mean)/imgs_std 52 | for i in range(imgs.shape[0]): 53 | imgs_normalized[i] = ((imgs_normalized[i] - np.min(imgs_normalized[i])) / (np.max(imgs_normalized[i])-np.min(imgs_normalized[i])))*255 54 | return imgs_normalized 55 | 56 | 57 | def adjust_gamma(imgs, gamma=1.0): 58 | assert (len(imgs.shape)==4) #4D arrays 59 | assert (imgs.shape[1]==1) #check the channel is 1 60 | # build a lookup table mapping the pixel values [0, 255] to 61 | # their adjusted gamma values 62 | invGamma = 1.0 / gamma 63 | table = np.array([((i / 255.0) ** invGamma) * 255 for i in np.arange(0, 256)]).astype("uint8") 64 | # apply gamma correction using the lookup table 65 | new_imgs = np.empty(imgs.shape) 66 | for i in range(imgs.shape[0]): 67 | new_imgs[i,0] = cv2.LUT(np.array(imgs[i,0], dtype = np.uint8), table) 68 | return new_imgs 69 | 70 | #My pre processing 71 | def my_PreProc(img_path,save_path,idx): 72 | data = np.asarray(Image.open(img_path)) 73 | data = np.flip(data,axis=2) # BGR to RGB 74 | data = np.expand_dims(data,axis=0) 75 | data = np.transpose(data,(0,3,1,2)) 76 | assert(len(data.shape)==4) 77 | assert (data.shape[1]==3) #Use the original images 78 | #black-white conversion 79 | train_imgs = rgb2gray(data) 80 | cv2.imwrite(save_path+'/img_gray_{}.png'.format(idx), train_imgs[0,0]) 81 | 82 | train_imgs = dataset_normalized(train_imgs) 83 | cv2.imwrite(save_path+'/img_normalized_{}.png'.format(idx), train_imgs[0,0]) 84 | 85 | train_imgs = clahe_equalized(train_imgs) 86 | cv2.imwrite(save_path+'/img_clahe_{}.png'.format(idx), train_imgs[0,0]) 87 | 88 | train_imgs = adjust_gamma(train_imgs, 1.2) 89 | cv2.imwrite(save_path+'/img_gamma_{}.png'.format(idx), train_imgs[0,0]) 90 | 91 | return train_imgs 92 | 93 | if __name__ == "__main__": 94 | img_path = '/ssd/lzq/sf3/data/STARE/images/im0005.ppm' 95 | save_path = '/ssd/lzq/sf3/visualization/preprocess_visual_result' 96 | my_PreProc(img_path,save_path,idx=5) 97 | 98 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch.backends.cudnn as cudnn 2 | import torch.optim as optim 3 | import sys,time 4 | from os.path import join 5 | import torch 6 | from lib.losses.loss import * 7 | from lib.common import * 8 | from config import parse_args 9 | from 
lib.logger import Logger, Print_Logger 10 | import models 11 | from test import Test 12 | 13 | from function import get_dataloader, train, val, get_dataloaderV2 14 | 15 | 16 | def main(): 17 | setpu_seed(2021) 18 | args = parse_args() 19 | save_path = join(args.outf, args.save) 20 | save_args(args,save_path) 21 | 22 | device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") 23 | cudnn.benchmark = True 24 | 25 | log = Logger(save_path) 26 | sys.stdout = Print_Logger(os.path.join(save_path,'train_log.txt')) 27 | print('The computing device used is: ','GPU' if device.type=='cuda' else 'CPU') 28 | 29 | # net = models.UNetFamily.U_Net(1,2).to(device) 30 | net = models.LadderNet(inplanes=args.in_channels, num_classes=args.classes, layers=3, filters=16).to(device) 31 | print("Total number of parameters: " + str(count_parameters(net))) 32 | 33 | log.save_graph(net,torch.randn((1,1,48,48)).to(device)) # Save the model structure to the tensorboard file 34 | # torch.nn.init.kaiming_normal(net, mode='fan_out') # Modify default initialization method 35 | # net.apply(weight_init) 36 | 37 | # criterion = LossMulti(jaccard_weight=0,class_weights=np.array([0.5,0.5])) 38 | criterion = CrossEntropyLoss2d() # Initialize loss function 39 | optimizer = optim.Adam(net.parameters(), lr=args.lr) # create the optimizer before its state may be restored from a checkpoint 40 | 41 | # The training speed of this task is fast, so pre-training is not recommended 42 | if args.pre_trained is not None: 43 | # Load checkpoint. 44 | print('==> Resuming from checkpoint..') 45 | checkpoint = torch.load(args.outf + '%s/latest_model.pth' % args.pre_trained) 46 | net.load_state_dict(checkpoint['net']) 47 | optimizer.load_state_dict(checkpoint['optimizer']) 48 | args.start_epoch = checkpoint['epoch']+1 49 | 50 | # create a list of learning rate with epochs 51 | # lr_schedule = make_lr_schedule(np.array([50, args.N_epochs]),np.array([0.001, 0.0001])) 52 | # lr_scheduler = optim.lr_scheduler.StepLR(optimizer,step_size=10,gamma=0.5) 53 | lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.N_epochs, eta_min=0) 54 | train_loader, val_loader = get_dataloaderV2(args) # create dataloader 55 | # train_loader, val_loader = get_dataloader(args) 56 | 57 | if args.val_on_test: 58 | print('\033[0;32m===============Validation on Testset!!!===============\033[0m') 59 | val_tool = Test(args) 60 | 61 | best = {'epoch':0,'AUC_roc':0.5} # Initialize the best epoch and performance (AUC of ROC) 62 | trigger = 0 # Early-stop counter 63 | for epoch in range(args.start_epoch,args.N_epochs+1): 64 | print('\nEPOCH: %d/%d --(learn_rate:%.6f) | Time: %s' % \ 65 | (epoch, args.N_epochs,optimizer.state_dict()['param_groups'][0]['lr'], time.asctime())) 66 | 67 | # train stage 68 | train_log = train(train_loader,net,criterion, optimizer,device) 69 | # val stage 70 | if not args.val_on_test: 71 | val_log = val(val_loader,net,criterion,device) 72 | else: 73 | val_tool.inference(net) 74 | val_log = val_tool.val() 75 | 76 | log.update(epoch,train_log,val_log) # Add log information 77 | lr_scheduler.step() 78 | 79 | # Save checkpoint of latest and best model.
80 | state = {'net': net.state_dict(),'optimizer':optimizer.state_dict(),'epoch': epoch} 81 | torch.save(state, join(save_path, 'latest_model.pth')) 82 | trigger += 1 83 | if val_log['val_auc_roc'] > best['AUC_roc']: 84 | print('\033[0;33mSaving best model!\033[0m') 85 | torch.save(state, join(save_path, 'best_model.pth')) 86 | best['epoch'] = epoch 87 | best['AUC_roc'] = val_log['val_auc_roc'] 88 | trigger = 0 89 | print('Best performance at Epoch: {} | AUC_roc: {}'.format(best['epoch'],best['AUC_roc'])) 90 | # early stopping 91 | if not args.early_stop is None: 92 | if trigger >= args.early_stop: 93 | print("=> early stopping") 94 | break 95 | torch.cuda.empty_cache() 96 | if __name__ == '__main__': 97 | main() 98 | --------------------------------------------------------------------------------
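A minimal sketch of how a saved result.npy (written by test.py as [y_true, y_scores] restricted to the FOV) can be reloaded to recompute the ROC and Precision-Recall areas. The experiment path below is only an example (the bundled experiments/Example run ships such a file), and the sklearn/numpy calls mirror what the tools/ scripts already use.

import numpy as np
from sklearn.metrics import roc_auc_score, precision_recall_curve

result_path = './experiments/Example/result.npy'  # example path; any experiment folder produced by test.py works
y_true, y_scores = np.load(result_path)           # row 0: ground-truth labels, row 1: predicted vessel probabilities

auc_roc = roc_auc_score(y_true, y_scores)
precision, recall, _ = precision_recall_curve(y_true, y_scores)
auc_pr = np.trapz(np.flip(precision), np.flip(recall))  # flip so recall is increasing before integrating

print("AUC of ROC curve:", auc_roc)
print("AUC of P-R curve:", auc_pr)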