├── .gitignore ├── CITATION.cff ├── LICENSE ├── README.md ├── build ├── __init__.py ├── build_criterion.py ├── build_dataset.py ├── build_model.py ├── build_optimizer.py └── build_scheduler.py ├── configs ├── _base_ │ ├── datasets │ │ ├── irstd1k.py │ │ ├── nuaa.py │ │ ├── nudt.py │ │ └── sirstaug.py │ ├── default_runtime.py │ ├── models │ │ └── unetseries.py │ └── schedules │ │ └── schedule_500e.py └── unetseries │ ├── unetseries_serankdet_256x256_500e_nudt.py │ ├── unetseries_serankdet_256x256_500e_sirstaug.py │ ├── unetseries_serankdet_512x512_500e_irstd1k.py │ └── unetseries_serankdet_512x512_800e_nuaa.py ├── model ├── SeRankDet │ ├── SeRankDet.py │ ├── __init__.py │ └── tools.py ├── __init__.py └── build_segmentor.py ├── test.py ├── train.py └── utils ├── __init__.py ├── data.py ├── drawing.py ├── logs.py ├── loss.py ├── metric.py ├── save_model.py ├── scheduler.py ├── tools.py └── visual.py /.gitignore: -------------------------------------------------------------------------------- 1 | # MacOS 2 | *.DS_Store 3 | .vscode/ 4 | 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | cover/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | .pybuilder/ 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | # For a library or package, you might want to ignore these files since the code is 91 | # intended to run in multiple environments; otherwise, check them in: 92 | # .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # poetry 102 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 103 | # This is especially recommended for binary packages to ensure reproducibility, and is more 104 | # commonly ignored for libraries. 105 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 106 | #poetry.lock 107 | 108 | # pdm 109 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
110 | #pdm.lock 111 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 112 | # in version control. 113 | # https://pdm.fming.dev/#use-with-ide 114 | .pdm.toml 115 | 116 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 117 | __pypackages__/ 118 | 119 | # Celery stuff 120 | celerybeat-schedule 121 | celerybeat.pid 122 | 123 | # SageMath parsed files 124 | *.sage.py 125 | 126 | # Environments 127 | .env 128 | .venv 129 | env/ 130 | venv/ 131 | ENV/ 132 | env.bak/ 133 | venv.bak/ 134 | 135 | # Spyder project settings 136 | .spyderproject 137 | .spyproject 138 | 139 | # Rope project settings 140 | .ropeproject 141 | 142 | # mkdocs documentation 143 | /site 144 | 145 | # mypy 146 | .mypy_cache/ 147 | .dmypy.json 148 | dmypy.json 149 | 150 | # Pyre type checker 151 | .pyre/ 152 | 153 | # pytype static type analyzer 154 | .pytype/ 155 | 156 | # Cython debug symbols 157 | cython_debug/ 158 | 159 | # PyCharm 160 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 161 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 162 | # and can be added to the global gitignore or merged into this file. For a more nuclear 163 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 164 | #.idea/ 165 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you use this software, please cite it as below." 
3 | authors: 4 | - name: "DeepInfrared Contributors" 5 | title: "code_repo_template Toolbox and Benchmark" 6 | date-released: 2022-12-01 7 | url: "https://github.com/YimianDai/open-code_repo_template" 8 | license: Apache-2.0 9 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2022-2023 GrokCV. All rights reserved. 2 | 3 | Apache License 4 | Version 2.0, January 2004 5 | http://www.apache.org/licenses/ 6 | 7 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 8 | 9 | 1. Definitions. 10 | 11 | "License" shall mean the terms and conditions for use, reproduction, 12 | and distribution as defined by Sections 1 through 9 of this document. 13 | 14 | "Licensor" shall mean the copyright owner or entity authorized by 15 | the copyright owner that is granting the License. 16 | 17 | "Legal Entity" shall mean the union of the acting entity and all 18 | other entities that control, are controlled by, or are under common 19 | control with that entity. For the purposes of this definition, 20 | "control" means (i) the power, direct or indirect, to cause the 21 | direction or management of such entity, whether by contract or 22 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 23 | outstanding shares, or (iii) beneficial ownership of such entity. 24 | 25 | "You" (or "Your") shall mean an individual or Legal Entity 26 | exercising permissions granted by this License. 27 | 28 | "Source" form shall mean the preferred form for making modifications, 29 | including but not limited to software source code, documentation 30 | source, and configuration files. 31 | 32 | "Object" form shall mean any form resulting from mechanical 33 | transformation or translation of a Source form, including but 34 | not limited to compiled object code, generated documentation, 35 | and conversions to other media types. 
36 | 37 | "Work" shall mean the work of authorship, whether in Source or 38 | Object form, made available under the License, as indicated by a 39 | copyright notice that is included in or attached to the work 40 | (an example is provided in the Appendix below). 41 | 42 | "Derivative Works" shall mean any work, whether in Source or Object 43 | form, that is based on (or derived from) the Work and for which the 44 | editorial revisions, annotations, elaborations, or other modifications 45 | represent, as a whole, an original work of authorship. For the purposes 46 | of this License, Derivative Works shall not include works that remain 47 | separable from, or merely link (or bind by name) to the interfaces of, 48 | the Work and Derivative Works thereof. 49 | 50 | "Contribution" shall mean any work of authorship, including 51 | the original version of the Work and any modifications or additions 52 | to that Work or Derivative Works thereof, that is intentionally 53 | submitted to Licensor for inclusion in the Work by the copyright owner 54 | or by an individual or Legal Entity authorized to submit on behalf of 55 | the copyright owner. For the purposes of this definition, "submitted" 56 | means any form of electronic, verbal, or written communication sent 57 | to the Licensor or its representatives, including but not limited to 58 | communication on electronic mailing lists, source code control systems, 59 | and issue tracking systems that are managed by, or on behalf of, the 60 | Licensor for the purpose of discussing and improving the Work, but 61 | excluding communication that is conspicuously marked or otherwise 62 | designated in writing by the copyright owner as "Not a Contribution." 63 | 64 | "Contributor" shall mean Licensor and any individual or Legal Entity 65 | on behalf of whom a Contribution has been received by Licensor and 66 | subsequently incorporated within the Work. 67 | 68 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 69 | this License, each Contributor hereby grants to You a perpetual, 70 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 71 | copyright license to reproduce, prepare Derivative Works of, 72 | publicly display, publicly perform, sublicense, and distribute the 73 | Work and such Derivative Works in Source or Object form. 74 | 75 | 3. Grant of Patent License. Subject to the terms and conditions of 76 | this License, each Contributor hereby grants to You a perpetual, 77 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 78 | (except as stated in this section) patent license to make, have made, 79 | use, offer to sell, sell, import, and otherwise transfer the Work, 80 | where such license applies only to those patent claims licensable 81 | by such Contributor that are necessarily infringed by their 82 | Contribution(s) alone or by combination of their Contribution(s) 83 | with the Work to which such Contribution(s) was submitted. If You 84 | institute patent litigation against any entity (including a 85 | cross-claim or counterclaim in a lawsuit) alleging that the Work 86 | or a Contribution incorporated within the Work constitutes direct 87 | or contributory patent infringement, then any patent licenses 88 | granted to You under this License for that Work shall terminate 89 | as of the date such litigation is filed. 90 | 91 | 4. Redistribution. 
You may reproduce and distribute copies of the 92 | Work or Derivative Works thereof in any medium, with or without 93 | modifications, and in Source or Object form, provided that You 94 | meet the following conditions: 95 | 96 | (a) You must give any other recipients of the Work or 97 | Derivative Works a copy of this License; and 98 | 99 | (b) You must cause any modified files to carry prominent notices 100 | stating that You changed the files; and 101 | 102 | (c) You must retain, in the Source form of any Derivative Works 103 | that You distribute, all copyright, patent, trademark, and 104 | attribution notices from the Source form of the Work, 105 | excluding those notices that do not pertain to any part of 106 | the Derivative Works; and 107 | 108 | (d) If the Work includes a "NOTICE" text file as part of its 109 | distribution, then any Derivative Works that You distribute must 110 | include a readable copy of the attribution notices contained 111 | within such NOTICE file, excluding those notices that do not 112 | pertain to any part of the Derivative Works, in at least one 113 | of the following places: within a NOTICE text file distributed 114 | as part of the Derivative Works; within the Source form or 115 | documentation, if provided along with the Derivative Works; or, 116 | within a display generated by the Derivative Works, if and 117 | wherever such third-party notices normally appear. The contents 118 | of the NOTICE file are for informational purposes only and 119 | do not modify the License. You may add Your own attribution 120 | notices within Derivative Works that You distribute, alongside 121 | or as an addendum to the NOTICE text from the Work, provided 122 | that such additional attribution notices cannot be construed 123 | as modifying the License. 
124 | 125 | You may add Your own copyright statement to Your modifications and 126 | may provide additional or different license terms and conditions 127 | for use, reproduction, or distribution of Your modifications, or 128 | for any such Derivative Works as a whole, provided Your use, 129 | reproduction, and distribution of the Work otherwise complies with 130 | the conditions stated in this License. 131 | 132 | 5. Submission of Contributions. Unless You explicitly state otherwise, 133 | any Contribution intentionally submitted for inclusion in the Work 134 | by You to the Licensor shall be under the terms and conditions of 135 | this License, without any additional terms or conditions. 136 | Notwithstanding the above, nothing herein shall supersede or modify 137 | the terms of any separate license agreement you may have executed 138 | with Licensor regarding such Contributions. 139 | 140 | 6. Trademarks. This License does not grant permission to use the trade 141 | names, trademarks, service marks, or product names of the Licensor, 142 | except as required for reasonable and customary use in describing the 143 | origin of the Work and reproducing the content of the NOTICE file. 144 | 145 | 7. Disclaimer of Warranty. Unless required by applicable law or 146 | agreed to in writing, Licensor provides the Work (and each 147 | Contributor provides its Contributions) on an "AS IS" BASIS, 148 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 149 | implied, including, without limitation, any warranties or conditions 150 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 151 | PARTICULAR PURPOSE. You are solely responsible for determining the 152 | appropriateness of using or redistributing the Work and assume any 153 | risks associated with Your exercise of permissions under this License. 154 | 155 | 8. Limitation of Liability. 
In no event and under no legal theory, 156 | whether in tort (including negligence), contract, or otherwise, 157 | unless required by applicable law (such as deliberate and grossly 158 | negligent acts) or agreed to in writing, shall any Contributor be 159 | liable to You for damages, including any direct, indirect, special, 160 | incidental, or consequential damages of any character arising as a 161 | result of this License or out of the use or inability to use the 162 | Work (including but not limited to damages for loss of goodwill, 163 | work stoppage, computer failure or malfunction, or any and all 164 | other commercial damages or losses), even if such Contributor 165 | has been advised of the possibility of such damages. 166 | 167 | 9. Accepting Warranty or Additional Liability. While redistributing 168 | the Work or Derivative Works thereof, You may choose to offer, 169 | and charge a fee for, acceptance of support, warranty, indemnity, 170 | or other liability obligations and/or rights consistent with this 171 | License. However, in accepting such obligations, You may act only 172 | on Your own behalf and on Your sole responsibility, not on behalf 173 | of any other Contributor, and only if You agree to indemnify, 174 | defend, and hold each Contributor harmless for any liability 175 | incurred by, or claims asserted against, such Contributor by reason 176 | of your accepting any such warranty or additional liability. 177 | 178 | END OF TERMS AND CONDITIONS 179 | 180 | APPENDIX: How to apply the Apache License to your work. 181 | 182 | To apply the Apache License to your work, attach the following 183 | boilerplate notice, with the fields enclosed by brackets "[]" 184 | replaced with your own identifying information. (Don't include 185 | the brackets!) The text should be enclosed in the appropriate 186 | comment syntax for the file format. 
We also recommend that a 187 | file or class name and description of purpose be included on the 188 | same "printed page" as the copyright notice for easier 189 | identification within third-party archives. 190 | 191 | Copyright 2022-2023 GrokCV. 192 | 193 | Licensed under the Apache License, Version 2.0 (the "License"); 194 | you may not use this file except in compliance with the License. 195 | You may obtain a copy of the License at 196 | 197 | http://www.apache.org/licenses/LICENSE-2.0 198 | 199 | Unless required by applicable law or agreed to in writing, software 200 | distributed under the License is distributed on an "AS IS" BASIS, 201 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 202 | See the License for the specific language governing permissions and 203 | limitations under the License. 204 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🔥🔥 Pick of the Bunch: Detecting Infrared Small Targets Beyond Hit-Miss Trade-Offs via Selective Rank-Aware Attention 🔥🔥 2 | 3 | Yimian Dai, Peiwen Pan, Yulei Qian, Yuxuan Li, Xiang Li, Jian Yang, Huan Wang 4 | 5 | This repository is the official site for "Pick of the Bunch: Detecting Infrared Small Targets Beyond Hit-Miss Trade-Offs via Selective Rank-Aware Attention". 6 | 7 | ## Abstract 8 | 9 | Infrared small target detection faces the inherent challenge of precisely localizing dim targets amidst complex background clutter. Traditional approaches struggle to balance detection precision and false alarm rates. To break this dilemma, we propose SeRankDet, a deep network that achieves high accuracy beyond the conventional hit-miss trade-off, by following the ``Pick of the Bunch'' principle. 
At its core lies our Selective Rank-Aware Attention (SeRank) module, employing a non-linear Top-K selection process that preserves the most salient responses, preventing target signal dilution while maintaining constant complexity. Furthermore, we replace the static concatenation typical in U-Net structures with our Large Selective Feature Fusion (LSFF) module, a dynamic fusion strategy that empowers SeRankDet with adaptive feature integration, enhancing its ability to discriminate true targets from false alarms. The network's discernment is further refined by our Dilated Difference Convolution (DDC) module, which merges differential convolution aimed at amplifying subtle target characteristics with dilated convolution to expand the receptive field, thereby substantially improving target-background separation. Despite its lightweight architecture, the proposed SeRankDet sets new benchmarks in state-of-the-art performance across multiple public datasets. The code is available at . 10 | 11 | - [Abstract](#abstract) 12 | - [Installation](#installation) 13 | - [Step 1: Create a conda environment](#step-1-create-a-conda-environment) 14 | - [Step 2: Install PyTorch](#step-2-install-pytorch) 15 | - [Step 3: Install OpenMMLab Codebases](#step-3-install-openmmlab-codebases) 16 | - [Dataset Preparation](#dataset-preparation) 17 | - [File Structure](#file-structure) 18 | - [Datasets Link](#datasets-link) 19 | - [Training](#training) 20 | - [Single GPU Training](#single-gpu-training) 21 | - [Multi GPU Training](#multi-gpu-training) 22 | - [Notes](#notes) 23 | - [Test](#test) 24 | - [Model Zoo and Benchmark](#model-zoo-and-benchmark) 25 | - [Leaderboard](#leaderboard) 26 | - [Model Zoo](#model-zoo) 27 | - [Citation](#citation) 28 | 29 | 30 | 31 | ## Installation 32 | 33 | ### Step 1: Create a conda environment 34 | 35 | ```shell 36 | $ conda create --name SeRankDet python=3.8 37 | $ source activate SeRankDet 38 | ``` 39 | 40 | ### Step 2: Install PyTorch 41 | 42 | ```shell 43 | 
conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch 44 | ``` 45 | 46 | ### Step 3: Install OpenMMLab Codebases 47 | 48 | ```shell 49 | # openmmlab codebases 50 | pip install -U openmim 51 | mim install mmcv-full==1.4.0 52 | mim install mmdet==2.25.0 53 | mim install mmsegmentation==0.28.0 54 | ``` 55 | 56 | **Note**: make sure you have `cd` to the root directory of `SeRankDet` 57 | 58 | ```shell 59 | $ git clone git@github.com:GrokCV/SeRankDet.git 60 | $ cd SeRankDet 61 | ``` 62 | 63 | ## Dataset Preparation 64 | ### File Structure 65 | ```angular2html 66 | |- datasets 67 | |- NUAA 68 | |-trainval 69 | |-images 70 | |-Misc_1.png 71 | ...... 72 | |-masks 73 | |-Misc_1.png 74 | ...... 75 | |-test 76 | |-images 77 | |-Misc_50.png 78 | ...... 79 | |-masks 80 | |-Misc_50.png 81 | ...... 82 | |-IRSTD1k 83 | |-NUDT 84 | |-SIRSTAUG 85 | 86 | ``` 87 | Please make sure that the path of your dataset is consistent with the `data_root` in `configs/_base_/datasets/dataset_name.py` 88 | 89 | ### Datasets Link 90 | 91 | The datasets used in this project and the dataset split files can be downloaded from the following links: 92 | 93 | * NoisySIRST Dataset 94 | * [Baidu Netdisk](https://pan.baidu.com/s/15RUYw23RSC20Xk1c1dMKYA?pwd=grok) 95 | * [OneDrive](https://1drv.ms/f/s!AmElF7K4aY9pgYEae4JdbbMd--tzNQ?e=yKwxa3) 96 | * SIRST Dataset 97 | * [Baidu Netdisk](https://pan.baidu.com/s/1LgnBKcE8Cqlay5GnXfUaLA?pwd=grok) 98 | * [OneDrive](https://1drv.ms/f/s!AmElF7K4aY9pgYEgG0VEoH3nDbiWDA?e=gkUW2W) 99 | * SIRST-AUG Dataset 100 | * [Baidu Netdisk](https://pan.baidu.com/s/1_kAocokYSclQNf_ZLWPIhQ?pwd=grok) 101 | * [OneDrive](https://1drv.ms/f/s!AmElF7K4aY9pgYEfdtbrZhLsbd0ITg?e=thyA6h) 102 | * NUDT-SIRST Dataset 103 | * [Baidu Netdisk](https://pan.baidu.com/s/16BbL9H38cIcvaBh4tPNTCw?pwd=grok) 104 | * [OneDrive](https://1drv.ms/f/s!AmElF7K4aY9pgYEdBMrQDFM1Vi24DQ?e=vBNoN4) 105 | * IRSTD1K Dataset 106 | * [Baidu 
Netdisk](https://pan.baidu.com/s/1nRoZu1eI9BLnpmsxw0Kdwg?pwd=grok) 107 | * [OneDrive](https://1drv.ms/f/s!AmElF7K4aY9pgYEepi2ipymni0amNQ?e=XZILFh) 108 | 109 | 110 | 111 | 113 | 114 | ## Training 115 | ### Single GPU Training 116 | 117 | ``` 118 | python train.py 119 | ``` 120 | 121 | For example: 122 | 123 | ``` 124 | python train.py configs/unetseries/unetseries_serankdet_512x512_500e_irstd1k.py 125 | ``` 126 | 127 | ### Multi GPU Training 128 | 129 | ```nproc_per_node``` is the number of gpus you are using. 130 | 131 | ``` 132 | python -m torch.distributed.launch --nproc_per_node=[GPU_NUMS] train.py 133 | ``` 134 | 135 | For example: 136 | 137 | ``` 138 | python -m torch.distributed.launch --nproc_per_node=4 train.py configs/unetseries/unetseries_serankdet_512x512_500e_irstd1k.py 139 | ``` 140 | 141 | ### Notes 142 | * Be sure to set args.local_rank to 0 if using Multi-GPU training. 143 | 144 | ## Test 145 | 146 | ``` 147 | python test.py 148 | ``` 149 | 150 | For example: 151 | 152 | ``` 153 | python test.py configs/unetseries/unetseries_serankdet_512x512_500e_irstd1k.py work_dirs/unetseries_serankdet_512x512_500e_irstd1k/20221009_231431/best_mIoU.pth.tar 154 | ``` 155 | 156 | If you want to visualize the result, you only need to add ```--show``` at the end of the above command. 157 | 158 | The default image save path is under the work directory. You can use `--work-dir` to specify the test log path, and the image save path is under this path by default. Of course, you can also use `--show-dir` to specify the image save path. 159 | 160 | ## Model Zoo and Benchmark 161 | 162 | **Note: Both passwords for BaiduYun and OneDrive are `grok`**. 
163 | 164 | ### Leaderboard 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | 191 | 192 | 193 | 194 | 195 | 196 | 197 | 198 | 199 | 200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 | 208 | 209 | 210 | 211 | 212 | 213 | 214 | 215 | 216 | 217 | 218 | 219 | 220 | 221 | 222 | 223 | 224 | 225 | 226 | 227 | 228 | 229 | 230 | 231 | 232 | 233 | 234 | 235 | 236 | 237 | 238 | 239 | 240 | 241 | 242 | 243 | 244 | 245 | 246 | 247 | 248 | 249 | 250 | 251 | 252 | 253 | 254 | 255 | 256 | 257 | 258 | 259 | 260 | 261 | 262 | 263 | 264 | 265 | 266 | 267 | 268 | 269 | 270 | 271 | 272 | 273 | 274 | 275 |
ModelSIRSTIRSTD1kSIRSTAUGNUDT-SIRST
IoUnIoUIoUnIoUIoUnIoUIoUnIoU
ACM72.8872.1763.3960.8173.8469.8368.4869.26
RDIAN72.8573.9664.3764.9074.1969.8081.0681.72
AGPCNet77.1375.1968.8166.1874.7171.4988.7187.48
DNANet75.5575.9068.8767.5374.8870.2392.6792.09
MTUNet78.7576.8267.5066.1574.7070.6687.4987.70
UIUNet80.0878.0969.1367.1974.2470.5790.7790.17
ABC81.0179.0072.0268.8176.1271.8392.8592.45
SeRankDet81.2779.6673.6669.1176.4971.9894.2893.69
276 | 277 | ### Model Zoo 278 | Checkpoint and Train log: [BaiduCloud](https://pan.baidu.com/s/1iyv6Q8N23ywy1g6jGm9SLQ?pwd=grok) 279 | 280 | ## Citation 281 | 282 | If you use this toolbox or benchmark in your research, please cite this project. 283 | 284 | ```bibtex 285 | @article{dai2024SeRankDet, 286 | title={Pick of the Bunch: Detecting Infrared Small Targets Beyond Hit-Miss Trade-Offs via Selective Rank-Aware Attention}, 287 | journal={IEEE Transactions on Geoscience and Remote Sensing}, 288 | author={Dai, Yimian and Pan, Peiwen and Qian, Yulei and Li, Yuxuan and Li, Xiang and Yang, Jian and Wang, Huan}, 289 | year={2024}, 290 | volume={62}, 291 | number={}, 292 | pages={1-15} 293 | } 294 | ``` -------------------------------------------------------------------------------- /build/__init__.py: -------------------------------------------------------------------------------- 1 | # @Time : 2024/4/1 19:17 2 | # @Author : PEIWEN PAN 3 | # @Email : 121106022690@njust.edu.cn 4 | # @File : __init__.py.py 5 | # @Software: PyCharm 6 | -------------------------------------------------------------------------------- /build/build_criterion.py: -------------------------------------------------------------------------------- 1 | # @Time : 2022/9/14 20:12 2 | # @Author : PEIWEN PAN 3 | # @Email : 121106022690@njust.edu.cn 4 | # @File : build_criterion.py 5 | # @Software: PyCharm 6 | from utils.loss import * 7 | 8 | __all__ = ['build_criterion', 'SoftIoULoss', 'BCEWithLogits', 'CrossEntropy'] 9 | 10 | 11 | # TODO Multiple loss functions 12 | def build_criterion(cfg): 13 | criterion_name = cfg.model['loss']['type'] 14 | criterion_class = globals()[criterion_name] 15 | criterion = criterion_class(**cfg.model['loss']) 16 | return criterion 17 | -------------------------------------------------------------------------------- /build/build_dataset.py: -------------------------------------------------------------------------------- 1 | # @Time : 2022/9/14 20:31 2 | # @Author : PEIWEN 
PAN 3 | # @Email : 121106022690@njust.edu.cn 4 | # @File : build_dataset.py 5 | # @Software: PyCharm 6 | from utils.data import * 7 | from torch.utils.data import DataLoader 8 | from torch.utils.data.distributed import DistributedSampler 9 | 10 | 11 | def build_dataset(args, cfg): 12 | trainset = DatasetLoad(mode='train', **cfg.data) 13 | testset = DatasetLoad(mode='test', **cfg.data) 14 | train_sample = None 15 | if args.local_rank != -1: 16 | train_sample = DistributedSampler(trainset) 17 | train_data = DataLoader(dataset=trainset, batch_size=cfg.data['train_batch'], sampler=train_sample, 18 | pin_memory=True, num_workers=cfg.data['num_workers'], drop_last=True) 19 | test_data = DataLoader(dataset=testset, batch_size=cfg.data['test_batch'], shuffle=False, 20 | num_workers=cfg.data['num_workers'], drop_last=False) 21 | else: 22 | train_data = DataLoader(dataset=trainset, batch_size=cfg.data['train_batch'], shuffle=True, 23 | num_workers=cfg.data['num_workers'], drop_last=True) 24 | test_data = DataLoader(dataset=testset, batch_size=cfg.data['test_batch'], shuffle=False, 25 | num_workers=cfg.data['num_workers'], drop_last=False) 26 | return [train_sample, train_data, test_data, trainset.__len__(), testset.__len__()] \ 27 | if args.local_rank != -1 else [train_data, test_data, trainset.__len__(), testset.__len__()] 28 | -------------------------------------------------------------------------------- /build/build_model.py: -------------------------------------------------------------------------------- 1 | # @Time : 2022/9/14 20:10 2 | # @Author : PEIWEN PAN 3 | # @Email : 121106022690@njust.edu.cn 4 | # @File : build_model.py 5 | # @Software: PyCharm 6 | from mmcv import Config 7 | from model.build_segmentor import Model 8 | 9 | 10 | def build_model(cfg): 11 | model = Model(cfg) 12 | return model 13 | 14 | 15 | if __name__ == '__main__': 16 | config = '/data1/ppw/works/All_ISTD/configs/segnext/segnext_tiny_512x512_800e_nuaa.py' 17 | cfg = Config.fromfile(config) 18 
| model = build_model(cfg) 19 | total = sum([param.nelement() for param in model.parameters()]) 20 | print(total) 21 | import torch 22 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 23 | x = torch.rand(8, 3, 512, 512) 24 | model = model.to(device) 25 | x = x.to(device) 26 | out = model(x) 27 | print(out.size()) 28 | -------------------------------------------------------------------------------- /build/build_optimizer.py: -------------------------------------------------------------------------------- 1 | # @Time : 2022/9/14 20:17 2 | # @Author : PEIWEN PAN 3 | # @Email : 121106022690@njust.edu.cn 4 | # @File : build_optimizer.py 5 | # @Software: PyCharm 6 | from torch.optim import * 7 | 8 | __all__ = ['build_optimizer', 'Adagrad', 'Adadelta', 'Adam', 'AdamW', 'Adamax', 'ASGD', 'LBFGS', 'NAdam', 'RAdam', 9 | 'RMSprop', 'Rprop', 'SGD', 'SparseAdam'] 10 | 11 | 12 | # TODO Solve the problem that **kwargs cannot be passed 13 | def build_optimizer(model, cfg): 14 | optimizer_name = cfg.optimizer['type'] 15 | optimizer_class = globals()[optimizer_name] 16 | return optimizer_class(model.parameters(), **cfg.optimizer['setting']) 17 | -------------------------------------------------------------------------------- /build/build_scheduler.py: -------------------------------------------------------------------------------- 1 | # @Time : 2022/9/15 20:39 2 | # @Author : PEIWEN PAN 3 | # @Email : 121106022690@njust.edu.cn 4 | # @File : build_scheduler.py 5 | # @Software: PyCharm 6 | from utils.scheduler import * 7 | 8 | __all__ = ['build_scheduler', 'PolyLR', 'CosineAnnealingLR', 'StepLR', 'linear'] 9 | 10 | 11 | # TODO Add scheduler and warmup 12 | def build_scheduler(optimizer, cfg): 13 | scheduler_name = cfg.lr_config['policy'] 14 | scheduler_class = globals()[scheduler_name] 15 | scheduler = scheduler_class(optimizer=optimizer, base_lr=cfg.optimizer['setting']['lr'], 16 | num_epochs=cfg.runner['max_epochs'], **cfg.lr_config) 17 | return scheduler 18 
| -------------------------------------------------------------------------------- /configs/_base_/datasets/irstd1k.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | data = dict( 3 | dataset_type='IRSTD1k', 4 | data_root='/data1/ppw/works/All_ISTD/datasets/IRSTD-1k', 5 | base_size=512, 6 | crop_size=512, 7 | data_aug=True, 8 | suffix='png', 9 | num_workers=8, 10 | train_batch=8, 11 | test_batch=8, 12 | train_dir='trainval', 13 | test_dir='test' 14 | ) 15 | -------------------------------------------------------------------------------- /configs/_base_/datasets/nuaa.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | data = dict( 3 | dataset_type='NUAA', 4 | data_root='/data1/ppw/works/All_ISTD/datasets/NUAA', 5 | base_size=512, 6 | crop_size=512, 7 | data_aug=True, 8 | suffix='png', 9 | num_workers=8, 10 | train_batch=8, 11 | test_batch=8, 12 | train_dir='trainval', 13 | test_dir='test' 14 | ) 15 | -------------------------------------------------------------------------------- /configs/_base_/datasets/nudt.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | data = dict( 3 | dataset_type='NUDT', 4 | data_root='/data1/ppw/works/All_ISTD/datasets/NUDT', 5 | base_size=256, 6 | crop_size=256, 7 | data_aug=True, 8 | suffix='png', 9 | num_workers=8, 10 | train_batch=32, 11 | test_batch=32, 12 | train_dir='trainval', 13 | test_dir='test' 14 | ) 15 | -------------------------------------------------------------------------------- /configs/_base_/datasets/sirstaug.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | data = dict( 3 | dataset_type='SIRSTAUG', 4 | data_root='/data1/ppw/works/All_ISTD/datasets/SIRST_AUG', 5 | base_size=256, 6 | crop_size=256, 7 | data_aug=True, 8 | suffix='png', 9 | num_workers=8, 10 | train_batch=32, 11 | 
test_batch=32, 12 | train_dir='trainval', 13 | test_dir='test' 14 | ) 15 | -------------------------------------------------------------------------------- /configs/_base_/default_runtime.py: -------------------------------------------------------------------------------- 1 | # yapf:disable 2 | log_config = dict( 3 | interval=10, 4 | hooks=[ 5 | dict(type='TextLoggerHook', by_epoch=True), 6 | dict(type='TensorboardLoggerHook') 7 | ]) 8 | # yapf:enable 9 | dist_params = dict(backend='nccl') 10 | log_level = 'INFO' 11 | load_from = None 12 | resume_from = None 13 | workflow = [('train', 1)] 14 | cudnn_benchmark = True 15 | find_unused_parameters = False 16 | random_seed = 42 17 | gpus = 1 18 | -------------------------------------------------------------------------------- /configs/_base_/models/unetseries.py: -------------------------------------------------------------------------------- 1 | model = dict( 2 | name='UNet', 3 | type='EncoderDecoder', 4 | pretrained=None, 5 | backbone=dict( 6 | type=None 7 | ), 8 | decode_head=dict( 9 | type='U_Net', 10 | deep_supervision=True 11 | ), 12 | loss=dict(type='SoftIoULoss') 13 | ) 14 | -------------------------------------------------------------------------------- /configs/_base_/schedules/schedule_500e.py: -------------------------------------------------------------------------------- 1 | # optimizer 2 | optimizer = dict( 3 | type='SGD', 4 | setting=dict(lr=0.01, weight_decay=0.0005) 5 | ) 6 | optimizer_config = dict() 7 | # learning policy 8 | # TODO warmup only 'linear' 9 | lr_config = dict(policy='PolyLR', warmup='linear', power=0.9, min_lr=1e-4, warmup_epochs=5) 10 | # runtime settings 11 | runner = dict(type='EpochBasedRunner', max_epochs=500) 12 | checkpoint_config = dict(by_epoch=False, interval=1) 13 | evaluation = dict(epochval=1, metric='mIoU', pre_eval=True) 14 | -------------------------------------------------------------------------------- /configs/unetseries/unetseries_serankdet_256x256_500e_nudt.py: 
-------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/datasets/nudt.py', 3 | '../_base_/default_runtime.py', 4 | '../_base_/schedules/schedule_500e.py', 5 | '../_base_/models/unetseries.py' 6 | ] 7 | 8 | model = dict( 9 | decode_head=dict( 10 | type='SeRankDet' 11 | ) 12 | ) 13 | 14 | optimizer = dict( 15 | type='AdamW', 16 | setting=dict(lr=0.0001, weight_decay=0.01, betas=(0.9, 0.999)) 17 | ) 18 | runner = dict(type='EpochBasedRunner', max_epochs=1500) 19 | data = dict( 20 | train_batch=8, 21 | test_batch=8) 22 | develop = dict(source_file_root='/data1/ppw/works/All_ISTD/model/UNetSeries/SeRankDet.py') 23 | -------------------------------------------------------------------------------- /configs/unetseries/unetseries_serankdet_256x256_500e_sirstaug.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/datasets/sirstaug.py', 3 | '../_base_/default_runtime.py', 4 | '../_base_/schedules/schedule_500e.py', 5 | '../_base_/models/unetseries.py' 6 | ] 7 | 8 | model = dict( 9 | decode_head=dict( 10 | type='SeRankDet' 11 | ) 12 | ) 13 | 14 | optimizer = dict( 15 | type='AdamW', 16 | setting=dict(lr=0.0001, weight_decay=0.01, betas=(0.9, 0.999)) 17 | ) 18 | runner = dict(type='EpochBasedRunner', max_epochs=300) 19 | data = dict( 20 | train_batch=8, 21 | test_batch=8) 22 | develop = dict(source_file_root='/data1/ppw/works/All_ISTD/model/UNetSeries/SeRankDet.py') 23 | -------------------------------------------------------------------------------- /configs/unetseries/unetseries_serankdet_512x512_500e_irstd1k.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/datasets/irstd1k.py', 3 | '../_base_/default_runtime.py', 4 | '../_base_/schedules/schedule_500e.py', 5 | '../_base_/models/unetseries.py' 6 | ] 7 | 8 | model = dict( 9 | decode_head=dict( 10 | type='SeRankDet' 11 | ) 12 | ) 13 | 14 | 
optimizer = dict( 15 | type='AdamW', 16 | setting=dict(lr=0.0001, weight_decay=0.01, betas=(0.9, 0.999)) 17 | ) 18 | 19 | runner = dict(type='EpochBasedRunner', max_epochs=500) 20 | data = dict( 21 | train_batch=4, 22 | test_batch=4) 23 | develop = dict(source_file_root='/data1/ppw/works/All_ISTD/model/UNetSeries/SeRankDet.py') 24 | # random_seed = 64 25 | find_unused_parameters = True 26 | 27 | # python rebuild_train.py configs/unetseries/unetseries_serankdet_512x512_500e_irstd1k.py 28 | -------------------------------------------------------------------------------- /configs/unetseries/unetseries_serankdet_512x512_800e_nuaa.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/datasets/nuaa.py', 3 | '../_base_/default_runtime.py', 4 | '../_base_/schedules/schedule_500e.py', 5 | '../_base_/models/unetseries.py' 6 | ] 7 | 8 | model = dict( 9 | decode_head=dict( 10 | type='SeRankDet', 11 | deep_supervision=True 12 | ) 13 | ) 14 | 15 | optimizer = dict( 16 | type='AdamW', 17 | setting=dict(lr=0.0001, weight_decay=0.01, betas=(0.9, 0.999)) 18 | ) 19 | 20 | runner = dict(type='EpochBasedRunner', max_epochs=1500) 21 | data = dict( 22 | train_batch=4, 23 | test_batch=4) 24 | develop = dict(source_file_root='/data1/ppw/works/All_ISTD/model/UNetSeries/SeRankDet.py') 25 | find_unused_parameters = True 26 | random_seed = 42 27 | # python -m torch.distributed.launch --nproc_per_node=4 --master_port=25642 rebuild_train.py configs/unetseries/unetseries_serankdet_512x512_500e_irstd1k.py 28 | # python rebuild_train.py configs/unetseries/unetseries_serankdet_512x512_800e_nuaa.py 29 | -------------------------------------------------------------------------------- /model/SeRankDet/SeRankDet.py: -------------------------------------------------------------------------------- 1 | # @Time : 2023/12/26 13:14 2 | # @Author : PEIWEN PAN 3 | # @Email : 121106022690@njust.edu.cn 4 | # @File : SeRankDet.py 5 | # @Software: PyCharm 
class conv_block(nn.Module):
    """Three-stage convolution block.

    A plain 3x3 convolution, then a central-difference convolution
    (``CDC_conv``), then another plain 3x3 convolution — each stage
    followed by BatchNorm and an in-place ReLU. Spatial size is preserved.
    """

    def __init__(self, in_ch, out_ch):
        super(conv_block, self).__init__()
        stages = [
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            CDC_conv(out_ch, out_ch),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv(x)
class Neck(nn.Module):
    """Fuse an encoder skip feature with the matching decoder feature.

    Average- and max-pooled channel statistics of the concatenated pair
    are squeezed through a 7x7 conv into two sigmoid selection maps that
    reweight ``enc`` and ``dec``; the reweighted fusion modulates the
    plain concatenation through a learnable ``gamma``. Since ``gamma``
    is zero-initialised, the module behaves exactly like
    ``torch.cat([enc, dec], dim=1)`` at initialisation.

    Args:
        in_ch:  channel count of each input (``enc`` and ``dec`` match).
        out_ch: channel count of the fused output (= 2 * in_ch in this file).
    """

    def __init__(self, in_ch, out_ch):
        super(Neck, self).__init__()
        # Squeeze the 2-channel (avg, max) statistics into two selection maps.
        self.conv_squeeze = nn.Conv2d(2, 2, 7, padding=3)
        # Lift the reweighted fusion back to the concatenated width.
        self.conv = nn.Conv2d(in_ch, out_ch, 1)
        # Zero-init: the attention branch starts disabled.
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, enc, dec):
        # The original built this identical tensor three times (x, out_x,
        # attn); compute it once and reuse it.
        x = torch.cat([enc, dec], dim=1)
        avg_attn = torch.mean(x, dim=1, keepdim=True)
        max_attn, _ = torch.max(x, dim=1, keepdim=True)
        agg = torch.cat([avg_attn, max_attn], dim=1)
        sig = self.conv_squeeze(agg).sigmoid()
        # Per-pixel soft selection between the encoder and decoder inputs.
        attn = enc * sig[:, 0, :, :].unsqueeze(1) + dec * sig[:, 1, :, :].unsqueeze(1)
        attn = self.conv(attn)
        return self.gamma * (x * attn) + x
class up_conv(nn.Module):
    """Decoder upsampling stage: 2x bilinear upsample, then 3x3 conv,
    BatchNorm and in-place ReLU. Doubles the spatial resolution and maps
    ``in_ch`` channels to ``out_ch``."""

    def __init__(self, in_ch, out_ch):
        super(up_conv, self).__init__()
        stages = [
            nn.Upsample(scale_factor=2, mode="bilinear"),
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        ]
        self.up = nn.Sequential(*stages)

    def forward(self, feat):
        return self.up(feat)
def compute_index(x, y, max_x):
    # Flatten a 2-D (x, y) grid coordinate into a single linear index.
    # NOTE(review): currently unused — its only call site in
    # MaxChannel.forward is commented out.
    return x + y * max_x


class MaxChannel(nn.Module):
    """Sparse channel self-attention with a residual connection.

    ``PositionalEncoding`` produces, per channel, a top-k summary of the
    strongest spatial activations (k looked up from the channel count) with
    a positional code added. Query/key projections of that summary yield a
    channel-by-channel attention map that reweights the flattened input
    feature map; the result is added back onto the input.
    """

    def __init__(self, in_ch, num_embeddings):
        # in_ch: size of the last dimension fed to fc1/fc2 — NOTE(review):
        # that dimension is the top-k count chosen by PositionalEncoding,
        # so in_ch must equal pos_enc's k for the given channel width;
        # confirm against the constructor calls in SeRankDet.__init__.
        super(MaxChannel, self).__init__()
        # Query and key projections over the sparse per-channel summaries.
        self.fc1 = nn.Linear(in_features=in_ch, out_features=in_ch * 2, bias=False)
        self.fc2 = nn.Linear(in_features=in_ch, out_features=in_ch * 2, bias=False)
        self.attend = nn.Softmax(dim=-1)
        self.pos_enc = PositionalEncoding(num_embeddings, num_embeddings)
        # self.embedding_layer = nn.Embedding(num_embeddings, 1)
        # self.num_emb = num_embeddings
        # channel count -> top-k size; duplicates the table inside
        # PositionalEncoding. NOTE(review): unused here — the lookup
        # happens inside pos_enc.
        self.mp = {64: 512, 128: 256, 256: 128, 512: 64, 1024: 32}

    def forward(self, x):
        b, c, h, w = x.size()
        # max_feature = x.view(b, c, -1)
        # _, topk_indices = torch.topk(max_feature, k=self.mp[c], dim=2)
        # topk_indices, _ = torch.sort(topk_indices, dim=2)
        # max_feature = torch.gather(max_feature, 2, topk_indices)
        # (b, c, k): top-k activations per channel plus positional code.
        max_feature = self.pos_enc(x)

        # indices_x = topk_indices // w
        # indices_y = topk_indices % w
        # indices = compute_index(indices_x, indices_y, 2)
        # output = self.embedding_layer(topk_indices // self.num_emb)
        # output = output.squeeze(dim=3)
        # max_feature = output + max_feature

        q = self.fc1(max_feature)
        k = self.fc2(max_feature)
        k = einops.rearrange(k, "b c m -> b m c")
        # (b, c, c) channel-affinity logits.
        attend = torch.matmul(q, k)
        # Standardise the logits (zero mean, unit std) before the softmax;
        # the 1e-5 guards against a zero std.
        attend = (attend - torch.mean(attend)) / (torch.std(attend) + 1e-5)
        attention = self.attend(attend)
        new_x = einops.rearrange(x, "b c h w -> b c (h w)")
        # Mix channels of the full feature map with the attention weights.
        res = torch.matmul(attention, new_x)
        # Residual: fold the attended features back onto the input.
        return einops.rearrange(res, "b c (h w) -> b c h w", h=h) + x
class SeRankDet(nn.Module):
    """
    UNet - Basic Implementation
    Paper : https://arxiv.org/abs/1505.04597

    U-Net-shaped encoder/decoder. Each encoder stage is a
    ``new_conv_block`` followed by ``MaxChannel`` attention; decoder
    skip connections are fused by ``Neck`` modules instead of plain
    concatenation. With ``deep_supervision`` enabled, forward() returns
    one prediction map per decoder scale plus the final 1x1-conv output.
    """

    def __init__(self, in_ch=3, out_ch=1, deep_supervision=True, **kwargs):
        # in_ch:  input image channels (RGB by default).
        # out_ch: output prediction channels (1 = binary mask logits).
        # **kwargs absorbs unused config keys (e.g. the 'type' entry
        # passed through from the mmcv-style config).
        super(SeRankDet, self).__init__()
        self.deep_supervision = deep_supervision
        n1 = 64
        # Channel widths at the five encoder depths: 64..1024.
        filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16]

        # One 2x max-pool between consecutive encoder stages.
        self.Maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder feature extractors.
        self.Conv1 = new_conv_block(in_ch, filters[0])
        self.Conv2 = new_conv_block(filters[0], filters[1])
        self.Conv3 = new_conv_block(filters[1], filters[2])
        self.Conv4 = new_conv_block(filters[2], filters[3])
        self.Conv5 = new_conv_block(filters[3], filters[4])

        # Per-depth channel attention. NOTE(review): the hard-coded sizes
        # (512 down to 32) track the spatial resolution of a 512-px input
        # halved per stage, not the channel widths — confirm behaviour for
        # other input resolutions (e.g. the 256x256 configs).
        self.max_channel1 = MaxChannel(512, 512)
        self.max_channel2 = MaxChannel(256, 256)
        self.max_channel3 = MaxChannel(128, 128)
        self.max_channel4 = MaxChannel(64, 64)
        self.max_channel5 = MaxChannel(32, 32)

        # Attention-gated skip fusion replacing plain concatenation.
        self.neck5 = Neck(filters[3], filters[4])
        self.neck4 = Neck(filters[2], filters[3])
        self.neck3 = Neck(filters[1], filters[2])
        self.neck2 = Neck(filters[0], filters[1])

        # Decoder: upsample, fuse with the skip, refine.
        self.Up5 = up_conv(filters[4], filters[3])
        self.Up_conv5 = conv_block(filters[4], filters[3])

        self.Up4 = up_conv(filters[3], filters[2])
        self.Up_conv4 = conv_block(filters[3], filters[2])

        self.Up3 = up_conv(filters[2], filters[1])
        self.Up_conv3 = conv_block(filters[2], filters[1])

        self.Up2 = up_conv(filters[1], filters[0])
        self.Up_conv2 = conv_block(filters[1], filters[0])

        # Final 1x1 prediction head.
        self.Conv = nn.Conv2d(filters[0], out_ch, kernel_size=1, stride=1, padding=0)

        # --------------------------------------------------------------------------------------------------------------
        # Deep-supervision heads: one 3x3 conv per decoder scale.
        self.conv5 = nn.Conv2d(filters[4], out_ch, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(filters[3], out_ch, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(filters[2], out_ch, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(filters[1], out_ch, kernel_size=3, stride=1, padding=1)
        self.conv1 = nn.Conv2d(filters[0], out_ch, kernel_size=3, stride=1, padding=1)
        # --------------------------------------------------------------------------------------------------------------

    def forward(self, x):
        # ---- Encoder: conv block + channel attention at each depth ----
        e1 = self.Conv1(x)
        e1 = self.max_channel1(e1)

        e2 = self.Maxpool1(e1)
        e2 = self.Conv2(e2)
        e2 = self.max_channel2(e2)

        e3 = self.Maxpool2(e2)
        e3 = self.Conv3(e3)
        e3 = self.max_channel3(e3)

        e4 = self.Maxpool3(e3)
        e4 = self.Conv4(e4)
        e4 = self.max_channel4(e4)

        e5 = self.Maxpool4(e4)
        e5 = self.Conv5(e5)
        e5 = self.max_channel5(e5)

        # ---- Decoder: upsample, Neck-fuse with the skip, refine ----
        d5 = self.Up5(e5)
        # d5 = torch.cat((e4, d5), dim=1)
        d5 = self.neck5(e4, d5)
        d5 = self.Up_conv5(d5)

        d4 = self.Up4(d5)
        # d4 = torch.cat((e3, d4), dim=1)
        d4 = self.neck4(e3, d4)
        d4 = self.Up_conv4(d4)

        d3 = self.Up3(d4)
        # d3 = torch.cat((e2, d3), dim=1)
        d3 = self.neck3(e2, d3)
        d3 = self.Up_conv3(d3)

        d2 = self.Up2(d3)
        # d2 = torch.cat((e1, d2), dim=1)
        d2 = self.neck2(e1, d2)
        d2 = self.Up_conv2(d2)

        # Final prediction at full resolution.
        out = self.Conv(d2)

        # Side outputs, all upsampled to the resolution of d_s1.
        d_s1 = self.conv1(d2)
        d_s2 = self.conv2(d3)
        d_s2 = _upsample_like(d_s2, d_s1)
        d_s3 = self.conv3(d4)
        d_s3 = _upsample_like(d_s3, d_s1)
        d_s4 = self.conv4(d5)
        d_s4 = _upsample_like(d_s4, d_s1)
        d_s5 = self.conv5(e5)
        d_s5 = _upsample_like(d_s5, d_s1)
        if self.deep_supervision:
            # Coarse-to-fine side outputs plus the final prediction.
            outs = [d_s1, d_s2, d_s3, d_s4, d_s5, out]
        else:
            outs = out
        # d1 = self.active(out)

        return outs


if __name__ == "__main__":
    # Smoke test: run a random batch through the network and print the
    # output shapes (six maps when deep supervision is on).
    x = torch.rand(8, 3, 512, 512)
    model = SeRankDet()
    outs = model(x)
    for out in outs:
        print(out.size())
class CDC_conv(nn.Module):
    """Central-difference convolution.

    Computes a standard convolution and subtracts ``theta`` times a
    "difference" term: a 1x1 convolution whose weight is each kernel of
    the main convolution summed over its spatial extent. With
    ``theta=0`` this reduces to a plain Conv2d; larger theta emphasises
    local intensity differences.

    Args:
        in_channels / out_channels: as for ``nn.Conv2d``.
        bias, kernel_size, padding, dilation: forwarded to the main conv.
        theta: weight of the central-difference term (default 0.7).
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        bias=True,
        kernel_size=3,
        padding=1,
        dilation=1,
        theta=0.7,
    ):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )
        self.theta = theta

    def forward(self, x):
        norm_out = self.conv(x)
        # Collapse each kernel to its spatial sum -> an equivalent 1x1
        # "difference" kernel. (Replaces the old unused shape unpacking
        # that rebound `kernel_size` twice, and the chained sum(2).sum(2).)
        kernel_diff = self.conv.weight.sum(dim=(2, 3))[:, :, None, None]
        diff_out = F.conv2d(
            input=x,
            weight=kernel_diff,
            bias=self.conv.bias,
            stride=self.conv.stride,
            padding=0,
        )
        return norm_out - self.theta * diff_out
cfg.model["backbone"]["type"] else None 18 | ) 19 | decode_name = cfg.model["decode_head"]["type"] 20 | backbone_class = globals()[backbone_name] if backbone_name else None 21 | decode_class = globals()[decode_name] 22 | self.backbone = ( 23 | backbone_class(**cfg.model["backbone"]) if backbone_name else None 24 | ) 25 | self.decode_head = decode_class(**cfg.model["decode_head"]) 26 | 27 | def forward(self, x): 28 | if self.backbone: 29 | x = self.backbone(x) 30 | out = self.decode_head(x) 31 | return out 32 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # @Time : 2023/10/1 20:01 2 | # @Author : PEIWEN PAN 3 | # @Email : 121106022690@njust.edu.cn 4 | # @File : rebuild_test.py 5 | # @Software: PyCharm 6 | import argparse 7 | import os 8 | from mmcv import Config 9 | from tqdm import tqdm 10 | from build.build_model import build_model 11 | from build.build_criterion import build_criterion 12 | from build.build_dataset import build_dataset 13 | 14 | from utils.visual import * 15 | from utils.tools import * 16 | 17 | 18 | def parse_args(): 19 | parser = argparse.ArgumentParser( 20 | description='mmseg test (and eval) a model') 21 | parser.add_argument('config', help='test config file path') 22 | parser.add_argument('checkpoint', help='checkpoint file') 23 | parser.add_argument( 24 | '--work-dir', 25 | help=('if specified, the evaluation metric results will be dumped' 26 | 'into the directory as txt')) 27 | parser.add_argument('--show', action='store_true', help='show results') 28 | parser.add_argument( 29 | '--show-dir', help='directory where painted images will be saved') 30 | parser.add_argument( 31 | '--gpu-id', 32 | type=int, 33 | default=4, 34 | help='id of gpu to use ' 35 | '(only applicable to non-distributed testing)') 36 | parser.add_argument('--local_rank', type=int, default=-1) 37 | args = parser.parse_args() 38 | if 'LOCAL_RANK' 
class Test(object):
    """One-shot evaluation driver.

    NOTE(review): the entire evaluation runs inside ``__init__`` — the
    constructor loads the checkpoint, iterates the test set, accumulates
    the metrics and writes the result files. If the test loader is empty,
    the loop-scoped names (mIoU, nIoU, ...) are never bound and the tail
    of the method raises NameError — confirm whether that case can occur.
    """

    def __init__(self, args, cfg):
        super(Test, self).__init__()
        logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s: %(message)s', datefmt='%F %T')
        self.cfg = cfg
        # Whether the model returns a list of side outputs (see SeRankDet).
        self.deep_supervision = 'deep_supervision' in self.cfg.model['decode_head']
        # Results default to the checkpoint's directory unless --work-dir given.
        self.save_dir = args.work_dir if args.work_dir else os.path.dirname(os.path.abspath(args.checkpoint))
        self.show_dir = args.show_dir if args.show_dir else os.path.join(self.save_dir, 'show')
        # Only create the visualisation directory when --show is requested.
        make_show_dir(self.show_dir) if args.show else empty_function()
        _, self.test_data, _, self.img_num = build_dataset(args, self.cfg)
        self.criterion = build_criterion(self.cfg)
        self.model = build_model(self.cfg)
        # Segmentation metrics: dataset-level IoU, per-sample IoU, ROC, Pd/Fa.
        self.mIoU_metric = SigmoidMetric()
        self.nIoU_metric = SamplewiseSigmoidMetric(1, score_thresh=0.5)
        self.ROC = ROCMetric(1, 10)
        self.PD_FA = PD_FA(1, 10, cfg)
        # NOTE(review): best_recall / best_precision appear unused below.
        self.best_recall = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
        self.best_precision = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
        self.mIoU_metric.reset()
        self.nIoU_metric.reset()
        self.PD_FA.reset()

        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        checkpoint = torch.load(args.checkpoint)
        self.model.load_state_dict(checkpoint['state_dict'])
        logging.info("Model Initializing")
        self.model = self.model.to(self.device)
        self.model.eval()
        tbar = tqdm(self.test_data)
        losses = []

        with torch.no_grad():
            for i, (img, mask) in enumerate(tbar):
                img, mask = data2device(args, (img, mask), self.device)
                preds = self.model(img)
                # compute_loss also reduces deep-supervision outputs to a
                # single prediction map.
                loss, preds = compute_loss(preds, mask, self.deep_supervision, cfg, self.criterion)
                losses.append(loss.item())

                self.ROC.update(preds, mask)
                self.mIoU_metric.update(preds, mask)
                self.nIoU_metric.update(preds, mask)
                self.PD_FA.update(preds, mask)
                # Running values for the progress bar (recomputed per batch).
                _, mIoU = self.mIoU_metric.get()
                _, nIoU = self.nIoU_metric.get()
                # (sic) 'ture_positive_rate' — typo kept; renaming would
                # touch the result-writer call signature below.
                ture_positive_rate, false_positive_rate, recall, precision, F1_score = self.ROC.get()
                tbar.set_description(
                    'Loss %.4f, mIoU %.4f, nIoU %.4f, F1-score %.4f' % (np.mean(losses), mIoU, nIoU, F1_score))
                if args.show:
                    # Offset the saved filenames by the batch index.
                    save_Pred_GT(preds, mask, self.show_dir, cfg.data['test_batch'] * i, cfg)
        FA, PD = self.PD_FA.get(self.img_num)
        save_test_config(cfg, self.save_dir)
        save_result_for_test(self.save_dir, mIoU, nIoU, recall, precision, FA, PD, F1_score, ture_positive_rate,
                             false_positive_rate)
        if args.show:
            total_show_generation(self.show_dir, cfg)
        logging.info('Finishing')
        logging.info('mIoU: %.4f, nIoU: %.4f, F1-score: %.4f' % (mIoU, nIoU, F1_score))


def main(args):
    # Load the mmcv-style config and run the evaluation (constructing
    # Test performs the whole run as a side effect).
    cfg = Config.fromfile(args.config)
    tester = Test(args, cfg)


if __name__ == "__main__":
    args = parse_args()
    main(args)
class Train(object):
    """Training driver: builds data/model/optimizer/scheduler from the
    config and exposes per-epoch ``training`` / ``testing`` methods.

    NOTE(review): both methods read the module-level global ``args``
    (bound in ``__main__``) rather than storing it — confirm before
    reusing this class outside the script entry point.
    """

    def __init__(self, args, cfg):
        super(Train, self).__init__()
        self.cfg = cfg
        # Distributed run (local_rank set) uses every visible GPU.
        self.cfg.gpus = torch.cuda.device_count() if args.local_rank != -1 else 1
        self.resume = args.resume_from
        self.deep_supervision = 'deep_supervision' in self.cfg.model['decode_head']

        self.device = init_devices(args, self.cfg)

        data = build_dataset(args, self.cfg)
        self.data = init_data(args, data)

        # TODO Initialized inside each model
        model = build_model(self.cfg)
        # TODO Add model struct to tensorboard
        self.model, checkpoint = init_model(args, self.cfg, model, self.device)
        self.criterion = build_criterion(self.cfg)
        optimizer = build_optimizer(self.model, self.cfg)
        # Scheduler is optional: only built when a policy is configured.
        if self.cfg.lr_config['policy']:
            self.scheduler = build_scheduler(optimizer, self.cfg)

        # When resuming, optimizer state and metric history come from the
        # checkpoint; otherwise they start fresh.
        self.optimizer, self.metrics = init_metrics(args, optimizer, checkpoint if args.resume_from else None)
        self.save_dir, self.train_log_file, self.write = save_log(args, self.cfg, self.model)

    def training(self, epoch):
        """Run one training epoch (1-indexed) and log/save afterwards."""
        self.model.train()
        losses = []
        if args.local_rank != -1:
            # Reshuffle the distributed sampler each epoch.
            self.data['train_sample'].set_epoch(epoch)
        # FIXME
        # NOTE(review): the scheduler is stepped once per epoch with
        # (epoch - 1) and skipped entirely when resuming — confirm that
        # resumed runs are expected to freeze the LR schedule.
        if not self.resume and self.cfg.lr_config['policy']:
            self.scheduler.step(epoch - 1)

        for i, data in enumerate(self.data['train_data']):
            since = time.time()
            img, mask = data2device(args, data, self.device)
            preds = self.model(img)
            loss, _ = compute_loss(preds, mask, self.deep_supervision, self.cfg, self.criterion)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            losses.append(loss.item())
            time_elapsed = time.time() - since
            show_log('train', args, self.cfg, epoch, losses, self.save_dir, self.train_log_file, i=i, data=self.data,
                     time_elapsed=time_elapsed, optimizer=self.optimizer)
        save_model('train', args, self.cfg, epoch, self.model, losses, self.optimizer, self.metrics, self.save_dir,
                   self.train_log_file)
        update_log('train', args, self.metrics, self.write, losses, epoch, optimizer=self.optimizer)

    def testing(self, epoch):
        """Evaluate on the test split after the given training epoch."""
        self.model.eval()
        reset_metrics(self.metrics)
        eval_losses = []
        with torch.no_grad():
            for i, data in enumerate(self.data['test_data']):
                img, mask = data2device(args, data, self.device)
                preds = self.model(img)
                # compute_loss also collapses deep-supervision outputs.
                loss, preds = compute_loss(preds, mask, self.deep_supervision, self.cfg, self.criterion)
                eval_losses.append(loss.item())
                # Running metric values; the final iteration's values are
                # used for logging below.
                IoU, nIoU, F1_score = update_metrics(preds, mask, self.metrics)
        show_log('test', args, self.cfg, epoch, eval_losses, self.save_dir, self.train_log_file, IoU=IoU, nIoU=nIoU,
                 F1_score=F1_score, metrics=self.metrics)
        append_metrics(args, self.metrics, eval_losses, IoU, nIoU, F1_score)
        save_model('test', args, self.cfg, epoch, self.model, eval_losses, self.optimizer, self.metrics,
                   self.save_dir, self.train_log_file, IoU=IoU, nIoU=nIoU)
        draw(args, self.metrics, self.save_dir, self.train_log_file)
        update_log('test', args, self.metrics, self.write, eval_losses, epoch, IoU=IoU, nIoU=nIoU,
                   F1_score=F1_score)


def main(args):
    # Build the trainer, then loop train/test per epoch; resuming starts
    # from the checkpoint's epoch + 1.
    cfg = Config.fromfile(args.config)
    trainer = Train(args, cfg)
    if args.local_rank != -1:
        torch.distributed.barrier()
    start = torch.load(args.resume_from)['epoch'] + 1 if args.resume_from else 1
    for i in range(start, cfg.runner['max_epochs'] + 1):
        trainer.training(i)
        trainer.testing(i)


if __name__ == '__main__':
    args = parse_args()
    main(args)
def _sync_transform(self, img, mask):
    """Jointly augment or resize a PIL (img, mask) pair.

    Training with data_aug: shared random h-flip, random rescale of the
    long side into [0.5, 2.0] * base_size, zero-pad up to crop_size, a
    shared random crop, and an optional Gaussian blur on the image only.
    Otherwise: plain resize of both to base_size.
    Returns (normalised image tensor, mask tensor).
    """
    if self.mode == 'train' and self.data_aug:
        # One coin toss flips image and mask together.
        if random.random() < 0.5:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
            mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
        crop_size = self.crop_size
        long_size = random.randint(
            int(self.base_size * 0.5), int(self.base_size * 2.0))
        w, h = img.size
        if h > w:
            oh, ow = long_size, int(1.0 * w * long_size / h + 0.5)
            short_size = ow
        else:
            ow, oh = long_size, int(1.0 * h * long_size / w + 0.5)
            short_size = oh
        img = img.resize((ow, oh), Image.BILINEAR)
        mask = mask.resize((ow, oh), Image.NEAREST)
        # Pad right/bottom so both dimensions reach crop_size.
        if short_size < crop_size:
            padh = max(crop_size - oh, 0)
            padw = max(crop_size - ow, 0)
            img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0)
            mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=0)
        # Shared random crop window.
        w, h = img.size
        x1 = random.randint(0, w - crop_size)
        y1 = random.randint(0, h - crop_size)
        box = (x1, y1, x1 + crop_size, y1 + crop_size)
        img, mask = img.crop(box), mask.crop(box)
        if random.random() < 0.5:
            img = img.filter(ImageFilter.GaussianBlur(radius=random.random()))
    else:
        img = img.resize((self.base_size, self.base_size), Image.BILINEAR)
        mask = mask.resize((self.base_size, self.base_size), Image.NEAREST)
    img, mask = np.array(img), np.array(mask)
    img = self.rgb_transform(img) if self.rgb else self.gray_transform(img)
    mask = transforms.ToTensor()(mask)
    return img, mask
def save_train_args_log(args, save_dir):
    """Append a timestamp plus every command-line argument to the run's train log.

    Args:
        args: parsed argparse.Namespace (any object with a __dict__ works).
        save_dir: sub-directory of work_dirs/ that holds train_log.txt.
    """
    with open('work_dirs/%s/train_log.txt' % save_dir, 'a') as f:
        now = datetime.now()
        f.write("time:--")
        f.write(now.strftime("%Y/%m/%d %H:%M:%S "))
        f.write('\n')
        # Iterate the mapping directly instead of building parallel
        # key/value lists and indexing them with range(len(...)).
        for key, value in vars(args).items():
            f.write('%s:--%s\n' % (key, str(value)))
        f.write('\n')
    return
def save_train_log(save_dir, file_name, epoch, epochs, iter, iters, loss, lr, time):
    """Append one timestamped training-iteration record to the run's train_log.txt.

    NOTE: parameter names `iter`/`time` shadow builtins but are kept for
    caller compatibility.
    """
    log_path = 'work_dirs/%s/%s/train_log.txt' % (save_dir, file_name)
    stamp = datetime.now().strftime("%Y/%m/%d %H:%M:%S ")
    record = 'Epoch: [%d/%d] Iter[%d/%d] Loss: %.4f Lr: %.5f Time: %.5f' % (
        epoch, epochs, iter, iters, loss, lr, time)
    with open(log_path, 'a') as f:
        f.write(stamp + record + '\n')
    return
def make_log_dir(config, log_file):
    """Create work_dirs/<config>/<log_file>/ if it does not already exist.

    os.makedirs creates all intermediate directories, so the previous
    separate call for work_dirs/<config> was redundant.
    """
    os.makedirs('work_dirs/%s/%s' % (config, log_file), exist_ok=True)
class CrossEntropy(nn.Module):
    """Cross-entropy segmentation loss wrapping nn.CrossEntropyLoss.

    Expects pred of shape (N, C, ...) and an integer-class target of
    shape (N, 1, ...) or already (N, ...).
    """

    def __init__(self, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean',
                 label_smoothing=0.0, **kwargs):
        super(CrossEntropy, self).__init__()
        self.crit = nn.CrossEntropyLoss(weight, size_average, ignore_index, reduce, reduction, label_smoothing)

    def forward(self, pred, target):
        # BUG FIX: Tensor.squeeze is not in-place; the original discarded the
        # result, so the channel dim was never removed and CrossEntropyLoss
        # rejected (N, 1, H, W) targets. Only squeeze when the target still
        # carries the channel dimension.
        if target.dim() == pred.dim():
            target = target.squeeze(dim=1)
        loss = self.crit(pred, target)
        return loss
def batch_pix_accuracy(self, output, target):
    """Return (TP pixel count, labeled-positive pixel count) for one batch.

    output is thresholded at self.score_thresh; target is assumed binary.
    """
    assert output.shape == target.shape
    pred_np = output.cpu().detach().numpy()
    gt_np = target.cpu().detach().numpy()

    binarized = (pred_np > self.score_thresh).astype('int64')   # P
    pixel_labeled = np.sum(gt_np > 0)                           # T
    pixel_correct = np.sum((binarized == gt_np) * (gt_np > 0))  # TP
    assert pixel_correct <= pixel_labeled
    return pixel_correct, pixel_labeled
def get(self):
    """Return (per-sample IoU array, mean IoU) from the accumulated counts."""
    # np.spacing(1) guards against division by zero for empty unions.
    denom = np.spacing(1) + self.total_union
    IoU = 1.0 * self.total_inter / denom
    return IoU, IoU.mean()
class ROCMetric():
    """Accumulates TP/FP statistics at bins+1 evenly spaced score thresholds
    to build ROC / precision-recall curves for binary (sigmoid) output."""

    def __init__(self, nclass, bins):
        super(ROCMetric, self).__init__()
        self.nclass = nclass
        self.bins = bins
        self.tp_arr = np.zeros(self.bins + 1)
        self.pos_arr = np.zeros(self.bins + 1)
        self.fp_arr = np.zeros(self.bins + 1)
        self.neg_arr = np.zeros(self.bins + 1)
        self.class_pos = np.zeros(self.bins + 1)

    def update(self, preds, labels):
        """Accumulate confusion counts for every threshold bin."""
        for iBin in range(self.bins + 1):
            score_thresh = (iBin + 0.0) / self.bins
            i_tp, i_pos, i_fp, i_neg, i_class_pos = cal_tp_pos_fp_neg(preds, labels, self.nclass, score_thresh)
            self.tp_arr[iBin] += i_tp
            self.pos_arr[iBin] += i_pos
            self.fp_arr[iBin] += i_fp
            self.neg_arr[iBin] += i_neg
            self.class_pos[iBin] += i_class_pos

    def get(self):
        """Return (tpr, fpr, recall, precision, f1-at-mid-threshold)."""
        tp_rates = self.tp_arr / (self.pos_arr + 0.001)
        fp_rates = self.fp_arr / (self.neg_arr + 0.001)
        recall = self.tp_arr / (self.pos_arr + 0.001)
        precision = self.tp_arr / (self.class_pos + 0.001)
        # F1 at the middle threshold (0.5 when bins == 10). `bins // 2`
        # replaces the hard-coded index 5 so other bin counts do not
        # index out of range; identical for the default bins=10.
        mid = self.bins // 2
        f1_score = (2.0 * recall[mid] * precision[mid]) / (recall[mid] + precision[mid] + 0.00001)
        return tp_rates, fp_rates, recall, precision, f1_score

    def reset(self):
        # BUG FIX: was np.zeros([11]), which silently broke update() for
        # any bins != 10; size the arrays from self.bins like __init__.
        self.tp_arr = np.zeros(self.bins + 1)
        self.pos_arr = np.zeros(self.bins + 1)
        self.fp_arr = np.zeros(self.bins + 1)
        self.neg_arr = np.zeros(self.bins + 1)
        self.class_pos = np.zeros(self.bins + 1)
def __init__(self, nclass, bins, cfg):
    """State for detection-level probability-of-detection / false-alarm rates.

    nclass/bins mirror ROCMetric; cfg supplies data['crop_size'], used by
    update() for reshaping and by get() to normalise FA by image area.
    """
    super(PD_FA, self).__init__()
    self.nclass = nclass
    self.bins = bins
    self.cfg = cfg
    # Scratch lists refilled per image inside update().
    self.image_area_total = []
    self.image_area_match = []
    # One accumulator slot per threshold bin (thresholds span 0..255).
    self.FA = np.zeros(self.bins + 1)
    self.PD = np.zeros(self.bins + 1)
    self.target = np.zeros(self.bins + 1)
def cal_tp_pos_fp_neg(output, target, nclass, score_thresh):
    """Binary confusion counts for sigmoid logits at a given threshold.

    Args:
        output: raw logits, shape (N, 1, H, W).
        target: ground truth in {0, 1}, shape (N, 1, H, W) or (N, H, W).
        nclass: unused; kept for signature compatibility with callers.
        score_thresh: probability threshold applied after sigmoid.

    Returns:
        (tp, pos, fp, neg, class_pos) as 0-dim tensors, where
        pos = tp + fn, neg = fp + tn, class_pos = tp + fp.
    """
    predict = (torch.sigmoid(output) > score_thresh).float()
    if len(target.shape) == 3:
        # BUG FIX: np.expand_dims silently converted the tensor to a numpy
        # array (and fails on CUDA tensors); unsqueeze keeps it a tensor
        # on the same device.
        target = target.float().unsqueeze(1)
    elif len(target.shape) == 4:
        target = target.float()
    else:
        raise ValueError("Unknown target dimension")
    match = (predict == target).float()
    tp = (predict * match).sum()
    fp = (predict * (1 - match)).sum()
    tn = ((1 - predict) * match).sum()
    fn = ((1 - predict) * (1 - match)).sum()
    return tp, tp + fn, fp, fp + tn, tp + fp
def _write_metric_row(f, title, values):
    # Shared formatter for the previously duplicated recall/precision rows:
    # each value is written as ' <rounded> ' after the row title.
    f.write(title)
    for v in values:
        f.write(' %s ' % str(round(v, 8)))
    f.write('\n')


def save_model_and_result(dt_string, epoch, train_loss, test_loss, best_iou, recall, precision, save_mIoU_dir,
                          save_other_metric_dir):
    """Append the epoch summary to the mIoU log and the per-threshold
    recall/precision rows to the other-metric log.

    Args:
        dt_string: preformatted timestamp.
        epoch: current epoch number.
        train_loss / test_loss / best_iou: scalar metrics for the summary line.
        recall / precision: per-threshold sequences.
        save_mIoU_dir / save_other_metric_dir: destination log file paths.
    """
    with open(save_mIoU_dir, 'a') as f:
        f.write('{} - {:04d}:\t - train_loss: {:04f}:\t - test_loss: {:04f}:\t mIoU {:.4f}\n'.format(
            dt_string, epoch, train_loss, test_loss, best_iou))
    with open(save_other_metric_dir, 'a') as f:
        f.write('%s-%s\n' % (dt_string, epoch))
        _write_metric_row(f, 'Recall-----:', recall)
        _write_metric_row(f, 'Precision--:', precision)
class PolyLR(object):
    """Polynomial learning-rate decay with optional warmup.

    `warmup`, when truthy, is the *name* of a module-level warmup function
    (e.g. 'linear') resolved via globals() and used for the first
    warmup_epochs epochs; afterwards lr = base_lr * (1 - t)^power.
    """

    def __init__(self, optimizer, num_epochs, base_lr, warmup, power=0.9, warmup_epochs=5, **kwargs):
        super(PolyLR, self).__init__()
        self.optimizer = optimizer
        self.num_epochs = num_epochs
        self.base_lr = base_lr
        self.warmup = warmup
        self.warmup_epoch = warmup_epochs if self.warmup else 0
        self.power = power

    def step(self, epoch):
        """Update every param group's lr for the given epoch."""
        if self.warmup and epoch <= self.warmup_epoch:
            globals()[self.warmup](self.optimizer, epoch, self.base_lr, self.warmup_epoch)
            return
        progress = (epoch - self.warmup_epoch) / self.num_epochs
        new_lr = self.base_lr * (1 - progress) ** self.power
        for group in self.optimizer.param_groups:
            group['lr'] = new_lr
class StepLR(object):
    """Multiplies the LR by `gamma` at each epoch listed in `step`,
    with the same optional named-warmup behaviour as the other schedulers."""

    def __init__(self, optimizer, step, base_lr, warmup, gamma=0.1, warmup_epochs=5, **kwargs):
        super(StepLR, self).__init__()
        self.optimizer = optimizer
        # BUG FIX: the original stored `self.step = step`, which shadowed the
        # step() method below, so scheduler.step(epoch) raised
        # "'list' object is not callable". Store the milestones under a
        # different attribute name.
        self.milestones = step
        self.gamma = gamma
        self.base_lr = base_lr
        self.warmup = warmup
        self.warmup_epoch = warmup_epochs if self.warmup else 0

    def step(self, epoch):
        """Apply warmup or a gamma decay if `epoch` is a milestone."""
        if self.warmup and epoch <= self.warmup_epoch:
            globals()[self.warmup](self.optimizer, epoch, self.base_lr, self.warmup_epoch)
        else:
            if epoch in self.milestones:
                for param_group in self.optimizer.param_groups:
                    param_group['lr'] *= self.gamma
def init_data(args, data):
    """Unpack the dataset tuple produced by the data builder into a dict.

    Distributed runs (args.local_rank != -1) prepend a sampler to the
    tuple; single-GPU runs do not, so train_sample stays None there.
    """
    if args.local_rank != -1:
        train_sample, train_data, test_data, train_data_len, test_data_len = data
    else:
        train_sample = None
        train_data, test_data, train_data_len, test_data_len = data
    return {'train_sample': train_sample, 'train_data': train_data, 'test_data': test_data,
            'train_data_len': train_data_len, 'test_data_len': test_data_len}
def data2device(args, data, device):
    """Move an (img, mask) pair onto the target device.

    Distributed workers use .cuda() (the device was pinned earlier via
    torch.cuda.set_device); otherwise both tensors go to `device`.
    """
    img, mask = data
    if args.local_rank == -1:
        return img.to(device), mask.to(device)
    return img.cuda(), mask.cuda()
def show_log(mode, args, cfg, epoch, losses, save_dir, train_log_file, **kwargs):
    """Print a train/test progress line and persist it (rank 0 only).

    kwargs for 'train': i, data, optimizer, time_elapsed.
    kwargs for 'test' : IoU, nIoU, F1_score, metrics.

    Raises:
        ValueError: when mode is neither 'train' nor 'test'.
    """
    if mode not in ['train', 'test']:
        raise ValueError('The parameter "mode" input should be "train" or "test"')
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s: %(message)s', datefmt='%F %T')
    # Only rank 0 (or the single-GPU process, local_rank == -1) logs.
    if args.local_rank > 0:
        return
    mean_loss = np.mean(losses)
    epochs = cfg.runner['max_epochs']
    if mode == 'train':
        iters = kwargs['data']['train_data_len'] / cfg.data['train_batch'] / cfg.gpus
        lr = kwargs['optimizer'].state_dict()['param_groups'][0]['lr']
        msg = 'Epoch %d/%d, Iter %d/%d, train loss %.4f, lr %.5f, time %.5f' % (
            epoch, epochs, kwargs['i'] + 1, iters, mean_loss, lr, kwargs['time_elapsed'])
        logging.info(msg)
        if (kwargs['i'] + 1) % cfg.log_config['interval'] == 0:
            save_train_log(save_dir, train_log_file, epoch, epochs, kwargs['i'] + 1, iters,
                           mean_loss, lr, kwargs['time_elapsed'])
    else:
        msg = 'Epoch %d/%d, test loss %.4f, mIoU %.4f, nIoU %.4f, F1-score %.4f, best_mIoU %.4f, ' \
              'best_nIoU %.4f, best_F1-score %.4f' % (
                  epoch, epochs, mean_loss, kwargs['IoU'], kwargs['nIoU'], kwargs['F1_score'],
                  kwargs['metrics']['best_mIoU'], kwargs['metrics']['best_nIoU'], kwargs['metrics']['best_f1'])
        logging.info(msg)
        save_test_log(save_dir, train_log_file, epoch, epochs, mean_loss, kwargs['IoU'], kwargs['nIoU'],
                      kwargs['F1_score'], kwargs['metrics']['best_mIoU'], kwargs['metrics']['best_nIoU'],
                      kwargs['metrics']['best_f1'])
train_log_file, **kwargs): 175 | if mode not in ['train', 'test']: 176 | raise ValueError('The parameter "mode" input should be "train" or "test"') 177 | if args.local_rank <= 0: 178 | ckpt_info = { 179 | 'epoch': epoch, 180 | 'state_dict': model.module.state_dict() if args.local_rank != -1 else model.state_dict(), 181 | 'loss': np.mean(losses), 182 | 'optimizer': optimizer.state_dict(), 183 | 'train_loss': metrics['train_loss'], 184 | 'test_loss': metrics['test_loss'], 185 | 'num_epoch': metrics['num_epoch'], 186 | 'best_mIoU': metrics['best_mIoU'], 187 | 'best_nIoU': metrics['best_nIoU'], 188 | 'best_f1': metrics['best_f1'], 189 | 'mIoU': metrics['mIoU'], 190 | 'nIoU': metrics['nIoU'], 191 | 'f1': metrics['f1'] 192 | } 193 | if mode == 'train': 194 | save_ckpt(ckpt_info, save_path='work_dirs/' + save_dir + '/' + train_log_file, filename='last.pth.tar') 195 | if cfg.checkpoint_config['by_epoch'] and epoch % cfg.checkpoint_config['interval'] == 0: 196 | save_ckpt(ckpt_info, save_path='work_dirs/' + save_dir + '/' + train_log_file, 197 | filename='epoch_%d' % epoch + '.pth.tar') 198 | else: 199 | if kwargs['IoU'] > metrics['best_mIoU'] or kwargs['nIoU'] > metrics['best_nIoU']: 200 | save_ckpt(ckpt_info, save_path='work_dirs/' + save_dir + '/' + train_log_file, filename='best.pth.tar') 201 | if kwargs['IoU'] > metrics['best_mIoU']: 202 | save_ckpt(ckpt_info, save_path='work_dirs/' + save_dir + '/' + train_log_file, 203 | filename='best_mIoU.pth.tar') 204 | if kwargs['nIoU'] > metrics['best_nIoU']: 205 | save_ckpt(ckpt_info, save_path='work_dirs/' + save_dir + '/' + train_log_file, 206 | filename='best_nIoU.pth.tar') 207 | 208 | 209 | def update_log(mode, args, metrics, write, losses, epoch, **kwargs): 210 | if mode not in ['train', 'test']: 211 | raise ValueError('The parameter "mode" input should be "train" or "test"') 212 | if args.local_rank <= 0: 213 | if mode == 'train': 214 | metrics['train_loss'].append(np.mean(losses)) 215 | metrics['num_epoch'].append(epoch) 
216 | write.add_scalar('train/train_loss', np.mean(losses), epoch) 217 | write.add_scalar('train/lr', kwargs['optimizer'].state_dict()['param_groups'][0]['lr'], epoch) 218 | else: 219 | metrics['best_mIoU'] = max(kwargs['IoU'], metrics['best_mIoU']) 220 | metrics['best_nIoU'] = max(kwargs['nIoU'], metrics['best_nIoU']) 221 | metrics['best_f1'] = max(kwargs['F1_score'], metrics['best_f1']) 222 | write.add_scalar('train/test_loss', np.mean(losses), epoch) 223 | write.add_scalar('test/mIoU', kwargs['IoU'], epoch) 224 | write.add_scalar('test/nIoU', kwargs['nIoU'], epoch) 225 | write.add_scalar('test/F1-score', kwargs['F1_score'], epoch) 226 | 227 | 228 | def reset_metrics(metrics): 229 | metrics['iou_metric'].reset() 230 | metrics['nIoU_metric'].reset() 231 | metrics['ROC'].reset() 232 | 233 | 234 | def update_metrics(preds, mask, metrics): 235 | metrics['iou_metric'].update(preds, mask) 236 | metrics['nIoU_metric'].update(preds, mask) 237 | metrics['ROC'].update(preds, mask) 238 | _, IoU = metrics['iou_metric'].get() 239 | _, nIoU = metrics['nIoU_metric'].get() 240 | _, _, _, _, F1_score = metrics['ROC'].get() 241 | return IoU, nIoU, F1_score 242 | 243 | 244 | def append_metrics(args, metrics, losses, IoU, nIoU, F1_score): 245 | if args.local_rank <= 0: 246 | metrics['test_loss'].append(np.mean(losses)) 247 | metrics['mIoU'].append(IoU) 248 | metrics['nIoU'].append(nIoU) 249 | metrics['f1'].append(F1_score) 250 | 251 | 252 | def draw(args, metrics, save_dir, train_log_file): 253 | if args.local_rank <= 0: 254 | drawing_loss(metrics['num_epoch'], metrics['train_loss'], metrics['test_loss'], save_dir, train_log_file) 255 | drawing_iou(metrics['num_epoch'], metrics['mIoU'], metrics['nIoU'], save_dir, train_log_file) 256 | drawing_f1(metrics['num_epoch'], metrics['f1'], save_dir, train_log_file) 257 | -------------------------------------------------------------------------------- /utils/visual.py: 
# @Time    : 2022/4/7 17:01
# @Author  : PEIWEN PAN
# @Email   : 121106022690@njust.edu.cn
# @File    : visual.py
# @Software: PyCharm
"""Visualization helpers: dump per-image prediction/GT masks and build fused
side-by-side comparison figures for the test set."""
import os
import shutil

import numpy as np
import torch
from PIL import Image
from matplotlib import pyplot as plt


def make_show_dir(show_dir):
    """Create `show_dir` with fresh, empty `result/` and `fuse/` sub-directories.

    Any previous contents of the two sub-directories are deleted.
    """
    # makedirs(..., exist_ok=True) also succeeds when intermediate path
    # components are missing, where a plain os.mkdir would raise.
    os.makedirs(show_dir, exist_ok=True)

    for sub in ('result', 'fuse'):
        sub_dir = os.path.join(show_dir, sub)
        if os.path.exists(sub_dir):
            shutil.rmtree(sub_dir)  # remove the directory tree, including all files inside
        os.mkdir(sub_dir)


def save_Pred_GT(preds, labels, show_dir, num, cfg):
    """Save binarized predictions and ground-truth masks of one batch as images.

    `num` is the index of the first sample of this batch within the test set;
    file names come from the images in the test directory.
    NOTE(review): os.listdir order is platform-dependent — this assumes the
    data loader enumerates test images in the same order; verify against the
    dataset implementation before trusting the name/mask pairing.
    """
    img_name = os.listdir(os.path.join(cfg.data['data_root'], cfg.data['test_dir'], 'images'))
    val_img_ids = [img.split('.')[0] for img in img_name]
    batch = preds.size()[0]
    for b in range(batch):
        # Threshold logits at 0 (equivalent to sigmoid > 0.5) and scale to 0/255.
        predsss = np.array((preds[b, :, :, :] > 0).cpu()).astype('int64') * 255
        predsss = np.uint8(predsss)
        labelsss = labels[b, :, :, :] * 255
        labelsss = np.uint8(labelsss.cpu())

        img = Image.fromarray(predsss.reshape(cfg.data['crop_size'], cfg.data['crop_size']))
        img.save(show_dir + '/result/' + '%s_Pred' % (val_img_ids[num + b]) + '.' + cfg.data['suffix'])
        img = Image.fromarray(labelsss.reshape(cfg.data['crop_size'], cfg.data['crop_size']))
        img.save(show_dir + '/result/' + '%s_GT' % (val_img_ids[num + b]) + '.' + cfg.data['suffix'])


def save_Pred_GT_visulize(pred, img_demo_dir, img_demo_index, suffix, cfg):
    """Save one binarized prediction and a side-by-side raw/prediction figure.

    `suffix` is expected to include the leading dot (e.g. '.png').
    """
    predsss = np.array((pred > 0).cpu()).astype('int64') * 255
    predsss = np.uint8(predsss)

    img = Image.fromarray(predsss.reshape(cfg.data['crop_size'], cfg.data['crop_size']))
    img.save(img_demo_dir + '/' + '%s_Pred' % (img_demo_index) + suffix)

    plt.figure(figsize=(10, 6))
    plt.subplot(1, 2, 1)
    img = plt.imread(img_demo_dir + '/' + img_demo_index + suffix)
    plt.imshow(img, cmap='gray')
    plt.xlabel("Raw Image", size=11)  # fixed label typo: was "Raw Imamge"

    plt.subplot(1, 2, 2)
    img = plt.imread(img_demo_dir + '/' + '%s_Pred' % (img_demo_index) + suffix)
    plt.imshow(img, cmap='gray')
    plt.xlabel("Predicts", size=11)

    plt.savefig(img_demo_dir + '/' + img_demo_index + "_fuse" + suffix, facecolor='w', edgecolor='red')
    plt.show()


def total_show_generation(show_dir, cfg):
    """Copy raw test images next to the saved masks and write 3-panel fuse figures.

    Expects `show_dir/result` to already contain the `<id>_GT` / `<id>_Pred`
    images written by save_Pred_GT.
    """
    source_image_path = os.path.join(cfg.data['data_root'], cfg.data['test_dir'], 'images')
    ids = [img.split('.')[0] for img in os.listdir(source_image_path)]

    # Copy every raw image into show_dir/result next to its _GT/_Pred masks.
    for i in range(len(ids)):
        source_image = source_image_path + '/' + ids[i] + '.' + cfg.data['suffix']
        target_image = show_dir + '/result/' + ids[i] + '.' + cfg.data['suffix']
        shutil.copy(source_image, target_image)
    # Resize the copies to the network input size so all three panels match.
    for i in range(len(ids)):
        source_image = show_dir + '/result/' + ids[i] + '.' + cfg.data['suffix']
        img = Image.open(source_image)
        # Image.ANTIALIAS was removed in Pillow 10; Image.LANCZOS is the same
        # filter under its long-standing name and works on old Pillow too.
        img = img.resize((cfg.data['crop_size'], cfg.data['crop_size']), Image.LANCZOS)
        img.save(source_image)
    for m in range(len(ids)):
        print('Processing the %d image' % (m + 1))
        plt.figure(figsize=(10, 6))
        plt.subplot(1, 3, 1)
        img = plt.imread(show_dir + '/result/' + ids[m] + '.' + cfg.data['suffix'])
        plt.imshow(img, cmap='gray')
        plt.xlabel("Raw Image", size=11)

        plt.subplot(1, 3, 2)
        img = plt.imread(show_dir + '/result/' + ids[m] + '_GT' + '.' + cfg.data['suffix'])
        plt.imshow(img, cmap='gray')
        plt.xlabel("Ground Truth", size=11)

        plt.subplot(1, 3, 3)
        img = plt.imread(show_dir + '/result/' + ids[m] + '_Pred' + '.' + cfg.data['suffix'])
        plt.imshow(img, cmap='gray')
        plt.xlabel("Predicts", size=11)
        plt.savefig(show_dir + '/fuse/' + ids[m].split('.')[0] + "_fuse" + '.' + cfg.data['suffix'],
                    facecolor='w', edgecolor='red')