├── .editorconfig
├── .github
│   └── workflows
│       └── stale.yml
├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── eval.py
├── examples
│   ├── alias_for_plotting.json
│   ├── config_dataset_json_example.json
│   ├── config_method_json_example.json
│   ├── converter_config.yaml
│   ├── rgbd_aliases.yaml
│   ├── single_row_style.yml
│   └── two_row_style.yml
├── metrics
│   ├── __init__.py
│   ├── draw_curves.py
│   ├── image_metrics.py
│   └── video_metrics.py
├── plot.py
├── pyproject.toml
├── readme.md
├── readme_zh.md
├── requirements.txt
├── tools
│   ├── append_results.py
│   ├── check_path.py
│   ├── converter.py
│   ├── info_py_to_json.py
│   ├── readme.md
│   └── rename.py
└── utils
    ├── __init__.py
    ├── generate_info.py
    ├── misc.py
    ├── print_formatter.py
    └── recorders
        ├── __init__.py
        ├── curve_drawer.py
        ├── excel_recorder.py
        ├── metric_recorder.py
        └── txt_recorder.py

--------------------------------------------------------------------------------
/.editorconfig:
--------------------------------------------------------------------------------
1 | # https://editorconfig.org/
2 | 
3 | root = true
4 | 
5 | [*]
6 | indent_style = space
7 | indent_size = 4
8 | insert_final_newline = true
9 | trim_trailing_whitespace = true
10 | end_of_line = lf
11 | charset = utf-8
12 | 
13 | # Docstrings and comments use max_line_length = 79
14 | [*.py]
15 | max_line_length = 99
16 | 
17 | # Use 2 spaces for the HTML files
18 | [*.html]
19 | indent_size = 2
20 | 
21 | # The JSON files contain newlines inconsistently
22 | [*.json]
23 | indent_size = 2
24 | insert_final_newline = ignore
25 | 
26 | [**/admin/js/vendor/**]
27 | indent_style = ignore
28 | indent_size = ignore
29 | 
30 | # Minified JavaScript files shouldn't be changed
31 | [**.min.js]
32 | indent_style = ignore
33 | insert_final_newline = ignore
34 | 
35 | # Makefiles always use tabs for indentation
36 | [Makefile]
37 | indent_style = tab
38 | 
39 | # Batch files use tabs for indentation
40 | [*.bat]
41 | indent_style = tab
42 | 
43 | [docs/**.txt]
44 | max_line_length = 119
45 | 
--------------------------------------------------------------------------------
/.github/workflows/stale.yml:
--------------------------------------------------------------------------------
1 | # This workflow warns and then closes issues and PRs that have had no activity for a specified amount of time.
2 | #
3 | # You can adjust the behavior by modifying this file.
4 | # For more information, see:
5 | # https://github.com/actions/stale
6 | name: 'Close stale issues and PRs'
7 | on:
8 |   schedule:
9 |     - cron: '0 14 * * *'
10 | 
11 | jobs:
12 |   stale:
13 |     runs-on: ubuntu-latest
14 |     permissions:
15 |       issues: write
16 |       pull-requests: write
17 |     steps:
18 |       - uses: actions/stale@v9
19 |         with:
20 |           repo-token: ${{ secrets.GITHUB_TOKEN }}
21 |           stale-issue-label: 'no-issue-activity'
22 |           stale-issue-message: 'This issue is stale because it has been open 7 days with no activity. Remove stale label or comment or this will be closed in 5 days.'
23 |           close-issue-message: 'This issue was closed because it has been stalled for 5 days with no activity.'
24 |           days-before-stale: 7
25 |           days-before-close: 5
26 | 
27 |           stale-pr-label: 'no-pr-activity'
28 |           stale-pr-message: 'This PR is stale because it has been open 14 days with no activity. Remove stale label or comment or this will be closed in 10 days.'
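          # Note: the PR-specific values below override the generic
          # days-before-stale/close settings for pull requests only
          # (standard actions/stale behavior, not specific to this repo).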
29 |           days-before-pr-stale: 14
30 |           days-before-pr-close: 10
31 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Big files
2 | **/*.png
3 | **/*.pdf
4 | **/*.jpg
5 | **/*.bmp
6 | **/*.zip
7 | **/*.7z
8 | **/*.rar
9 | **/*.tar*
10 | 
11 | ### Python template
12 | # Byte-compiled / optimized / DLL files
13 | __pycache__/
14 | *.py[cod]
15 | *$py.class
16 | *.npy
17 | 
18 | # C extensions
19 | *.so
20 | 
21 | # Distribution / packaging
22 | .Python
23 | build/
24 | .idea/
25 | .vscode/
26 | develop-eggs/
27 | dist/
28 | downloads/
29 | eggs/
30 | .eggs/
31 | lib/
32 | lib64/
33 | parts/
34 | sdist/
35 | var/
36 | wheels/
37 | pip-wheel-metadata/
38 | share/python-wheels/
39 | *.egg-info/
40 | .installed.cfg
41 | *.egg
42 | MANIFEST
43 | 
44 | # PyInstaller
45 | # Usually these files are written by a python script from a template
46 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
47 | *.manifest
48 | *.spec
49 | 
50 | # Installer logs
51 | pip-log.txt
52 | pip-delete-this-directory.txt
53 | 
54 | # Unit test / coverage reports
55 | htmlcov/
56 | .tox/
57 | .nox/
58 | .coverage
59 | .coverage.*
60 | .cache
61 | nosetests.xml
62 | coverage.xml
63 | *.cover
64 | *.py,cover
65 | .hypothesis/
66 | .pytest_cache/
67 | cover/
68 | 
69 | # Translations
70 | *.mo
71 | *.pot
72 | 
73 | # Django stuff:
74 | *.log
75 | local_settings.py
76 | db.sqlite3
77 | db.sqlite3-journal
78 | 
79 | # Flask stuff:
80 | instance/
81 | .webassets-cache
82 | 
83 | # Scrapy stuff:
84 | .scrapy
85 | 
86 | # Sphinx documentation
87 | docs/_build/
88 | 
89 | # PyBuilder
90 | .pybuilder/
91 | target/
92 | 
93 | # Jupyter Notebook
94 | .ipynb_checkpoints
95 | 
96 | # IPython
97 | profile_default/
98 | ipython_config.py
99 | 
100 | # pyenv
101 | # For a library or package, you might want to ignore these files since the code is
102 | # intended to run in multiple environments; otherwise, check them in:
103 | # .python-version
104 | 
105 | # pipenv
106 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
107 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
108 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
109 | # install all needed dependencies.
110 | #Pipfile.lock
111 | 
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
113 | __pypackages__/
114 | 
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 | 
119 | # SageMath parsed files
120 | *.sage.py
121 | 
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 | 
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 | 
135 | # Rope project settings
136 | .ropeproject
137 | 
138 | # mkdocs documentation
139 | /site
140 | 
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 | 
146 | # Pyre type checker
147 | .pyre/
148 | 
149 | # pytype static type analyzer
150 | .pytype/
151 | 
152 | # Cython debug symbols
153 | cython_debug/
154 | 
155 | ### Example user template
156 | 
157 | # IntelliJ project files
158 | .idea
159 | *.iml
160 | out
161 | gen
162 | 
163 | # private files
164 | /output/
165 | /untracked/
166 | /configs/
167 | # /*.py
168 | /*.sh
169 | /*.ps1
170 | /*.bat
171 | /results/rgb_sod.md
172 | /results/htmls/*.html
173 | !/.github/assets/*.jpg
174 | 
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | # See https://pre-commit.com for more information
2 | # See https://pre-commit.com/hooks.html for more hooks
3 | repos:
4 |   - repo: https://github.com/pre-commit/pre-commit-hooks
5 |     rev: v3.2.0
6 |     hooks:
7 |       - id: trailing-whitespace
8 |       - id: end-of-file-fixer
9 |       - id: check-yaml
10 |       - id: check-toml
11 |       - id: check-added-large-files
12 |       - id: fix-encoding-pragma
13 |       - id: mixed-line-ending
14 |   - repo: https://github.com/pycqa/isort
15 |     rev: 5.6.4
16 |     hooks:
17 |       - id: isort
18 |   - repo: https://github.com/psf/black
19 |     rev: 20.8b1
20 |     # Replace by any tag/version: https://github.com/psf/black/tags
21 |     hooks:
22 |       - id: black
23 |         language_version: python3
24 |         # Should be a command that runs python3.6+
25 | 
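For reference, the hooks above are driven by the standard pre-commit workflow
(generic pre-commit usage, not specific to this repo):

    pip install pre-commit      # install the tool
    pre-commit install          # register the git hook once per clone
    pre-commit run --all-files  # optionally check the whole tree in one go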
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2020 MY_
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import argparse
3 | import os
4 | import textwrap
5 | import warnings
6 | 
7 | from metrics import image_metrics, video_metrics
8 | from utils.generate_info import get_datasets_info, get_methods_info
9 | from utils.recorders import SUPPORTED_METRICS
10 | 
11 | 
12 | def get_args():
13 |     parser = argparse.ArgumentParser(
14 |         description=textwrap.dedent(
15 |             r"""
16 |             A Powerful Evaluation Toolkit based on PySODMetrics.
17 | 
18 |             INCLUDE: More metrics can be set in `utils/recorders/metric_recorder.py`
19 | 
20 |             - F-measure-Threshold Curve
21 |             - Precision-Recall Curve
22 |             - MAE
23 |             - weighted F-measure
24 |             - S-measure
25 |             - max/average/adaptive/binary F-measure
26 |             - max/average/adaptive/binary E-measure
27 |             - max/average/adaptive/binary Precision
28 |             - max/average/adaptive/binary Recall
29 |             - max/average/adaptive/binary Sensitivity
30 |             - max/average/adaptive/binary Specificity
31 |             - max/average/adaptive/binary Dice
32 |             - max/average/adaptive/binary IoU
33 | 
34 |             NOTE:
35 | 
36 |             - Our method automatically calculates the intersection of `pre` and `gt`.
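              (e.g., if the GT folder provides names {a, b, c} and a method's result
              folder provides {b, c, d}, only b and c are scored for that method)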
37 |             - Currently supported pre naming rules: `prefix + gt_name_wo_ext + suffix_w_ext`
38 | 
39 |             EXAMPLES:
40 | 
41 |             python eval.py \
42 |                 --dataset-json configs/datasets/rgbd_sod.json \
43 |                 --method-json \
44 |                     configs/methods/json/rgbd_other_methods.json \
45 |                     configs/methods/json/rgbd_our_method.json \
46 |                 --metric-names sm wfm mae fmeasure em \
47 |                 --num-bits 4 \
48 |                 --num-workers 4 \
49 |                 --metric-npy output/rgbd_metrics.npy \
50 |                 --curves-npy output/rgbd_curves.npy \
51 |                 --record-txt output/rgbd_results.txt \
52 |                 --to-overwrite \
53 |                 --record-xlsx output/test-metric.xlsx \
54 |                 --include-datasets \
55 |                     dataset-name1-from-dataset-json \
56 |                     dataset-name2-from-dataset-json \
57 |                     dataset-name3-from-dataset-json \
58 |                 --include-methods \
59 |                     method-name1-from-method-json \
60 |                     method-name2-from-method-json \
61 |                     method-name3-from-method-json
62 |             """
63 |         ),
64 |         formatter_class=argparse.RawTextHelpFormatter,
65 |     )
66 |     # fmt: off
67 |     parser.add_argument("--dataset-json", required=True, type=str, help="Json file for datasets.")
68 |     parser.add_argument("--method-json", required=True, nargs="+", type=str, help="Json file for methods.")
69 |     parser.add_argument("--metric-npy", type=str, help="Npy file for saving metric results.")
70 |     parser.add_argument("--curves-npy", type=str, help="Npy file for saving curve results.")
71 |     parser.add_argument("--record-txt", type=str, help="Txt file for saving metric results.")
72 |     parser.add_argument("--to-overwrite", action="store_true", help="To overwrite the txt file.")
73 |     parser.add_argument("--record-xlsx", type=str, help="Xlsx file for saving metric results.")
74 |     parser.add_argument("--include-methods", type=str, nargs="+", help="Names of only specific methods you want to evaluate.")
75 |     parser.add_argument("--exclude-methods", type=str, nargs="+", help="Names of some specific methods you do not want to evaluate.")
76 |     parser.add_argument("--include-datasets", type=str, nargs="+", help="Names of only specific datasets you want to evaluate.")
77 |     parser.add_argument("--exclude-datasets", type=str, nargs="+", help="Names of some specific datasets you do not want to evaluate.")
78 |     parser.add_argument("--num-workers", type=int, default=4, help="Number of workers for multi-threading or multi-processing. Default: 4")
79 |     parser.add_argument("--num-bits", type=int, default=3, help="Number of decimal places for showing results. Default: 3")
80 |     parser.add_argument("--metric-names", type=str, nargs="+", default=["sm", "wfm", "mae", "fmeasure", "em", "precision", "recall", "msiou"], choices=SUPPORTED_METRICS, help="Names of metrics")
81 |     parser.add_argument("--data-type", type=str, default="image", choices=["image", "video"], help="Type of data.")
82 | 
83 |     known_args = parser.parse_known_args()[0]
84 |     if known_args.data_type == "video":
85 |         parser.add_argument("--valid-frame-start", type=int, default=0, help="Valid start index of the frame in each gt video. Defaults to 0, i.e., no frames are skipped at the start; setting it to 1 will skip the first frame. Negative values are clamped to 0.")
86 |         parser.add_argument("--valid-frame-end", type=int, default=0, help="Valid end index of the frame in each gt video. Defaults to 0, i.e., no frames are skipped at the end; setting it to -1 will skip the last frame. Positive values are treated as 0.")
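        # A sketch of the intended semantics, assuming the video recorder slices each
        # sequence like frames[valid_frame_start:valid_frame_end] (the slicing itself
        # lives in metrics/video_metrics.py, which is not fully shown here):
        # `--valid-frame-start 1 --valid-frame-end -1` drops the first and the last
        # frame of every video, while the defaults (0 and 0 -> None) keep all frames.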
87 |     # fmt: on
88 |     args = parser.parse_args()
89 | 
90 |     if args.data_type == "video":
91 |         args.valid_frame_start = max(args.valid_frame_start, 0)
92 |         args.valid_frame_end = min(args.valid_frame_end, 0)
93 |         if args.valid_frame_end == 0:
94 |             args.valid_frame_end = None
95 | 
96 |     if args.metric_npy:
97 |         os.makedirs(os.path.dirname(args.metric_npy), exist_ok=True)
98 |     if args.curves_npy:
99 |         os.makedirs(os.path.dirname(args.curves_npy), exist_ok=True)
100 |     if args.record_txt:
101 |         os.makedirs(os.path.dirname(args.record_txt), exist_ok=True)
102 |     if args.record_xlsx:
103 |         os.makedirs(os.path.dirname(args.record_xlsx), exist_ok=True)
104 |     if args.to_overwrite and not args.record_txt:
105 |         warnings.warn("--to-overwrite only works with a valid --record-txt")
106 |     return args
107 | 
108 | 
109 | def main():
110 |     args = get_args()
111 | 
112 |     # A dict holding the information of all datasets
113 |     datasets_info = get_datasets_info(
114 |         datastes_info_json=args.dataset_json,
115 |         include_datasets=args.include_datasets,
116 |         exclude_datasets=args.exclude_datasets,
117 |     )
118 |     # A dict holding the result information of all methods to be compared
119 |     methods_info = get_methods_info(
120 |         methods_info_jsons=args.method_json,
121 |         for_drawing=True,
122 |         include_methods=args.include_methods,
123 |         exclude_methods=args.exclude_methods,
124 |     )
125 | 
126 |     if args.data_type == "image":
127 |         image_metrics.cal_metrics(
128 |             sheet_name="Results",
129 |             to_append=not args.to_overwrite,
130 |             txt_path=args.record_txt,
131 |             xlsx_path=args.record_xlsx,
132 |             methods_info=methods_info,
133 |             datasets_info=datasets_info,
134 |             curves_npy_path=args.curves_npy,
135 |             metrics_npy_path=args.metric_npy,
136 |             num_bits=args.num_bits,
137 |             num_workers=args.num_workers,
138 |             metric_names=args.metric_names,
139 |         )
140 |     else:
141 |         video_metrics.cal_metrics(
142 |             sheet_name="Results",
143 |             to_append=not args.to_overwrite,
144 |             txt_path=args.record_txt,
145 |             xlsx_path=args.record_xlsx,
146 |             methods_info=methods_info,
147 |             datasets_info=datasets_info,
148 |             curves_npy_path=args.curves_npy,
149 |             metrics_npy_path=args.metric_npy,
150 |             num_bits=args.num_bits,
151 |             num_workers=args.num_workers,
152 |             metric_names=args.metric_names,
153 |             return_group=False,
154 |             start_idx=args.valid_frame_start,
155 |             end_idx=args.valid_frame_end,
156 |         )
157 | 
158 | 
159 | # Guard the entry point so that multiprocessing also works properly on Windows
160 | if __name__ == "__main__":
161 |     main()
162 | 
--------------------------------------------------------------------------------
/examples/alias_for_plotting.json:
--------------------------------------------------------------------------------
1 | {
2 |   "dataset": {
3 |     "Name_In_Json": "Name_In_SubFigure",
4 |     "NJUD": "NJUD",
5 |     "NLPR": "NLPR",
6 |     "DUTRGBD": "DUTRGBD",
7 |     "STEREO1000": "STEREO",
8 |     "RGBD135": "RGBD135",
9 |     "SSD": "SSD",
10 |     "SIP": "SIP"
11 |   },
12 |   "method": {
13 |     "Name_In_Json": "Name_In_Legend",
14 |     "GateNet_2020": "GateNet",
15 |     "MINet_R50_2020": "MINet"
16 |   }
17 | }
18 | 
--------------------------------------------------------------------------------
/examples/config_dataset_json_example.json:
--------------------------------------------------------------------------------
1 | {
2 |   "LFSD": {
3 |     "image": {
4 |       "path": "Path_Of_RGBDSOD_Datasets/LFSD/Image",
5 |       "prefix": "some_gt_prefix",
6 |       "suffix": ".jpg"
7 |     },
8 |     "mask": {
9 |       "path": "Path_Of_RGBDSOD_Datasets/LFSD/Mask",
10 |       "prefix": "some_gt_prefix",
11 |       "suffix": ".png"
12 |     }
13 |   },
14 |   "NJUD": {
15 |     "image": {
16 |       "path":
"Path_Of_RGBDSOD_Datasets/NJUD_FULL/Image", 17 | "suffix": ".jpg" 18 | }, 19 | "mask": { 20 | "path": "Path_Of_RGBDSOD_Datasets/NJUD_FULL/Mask", 21 | "suffix": ".png" 22 | } 23 | }, 24 | "NLPR": { 25 | "image": { 26 | "path": "Path_Of_RGBDSOD_Datasets/NLPR_FULL/Image", 27 | "suffix": ".jpg" 28 | }, 29 | "mask": { 30 | "path": "Path_Of_RGBDSOD_Datasets/NLPR_FULL/Mask", 31 | "suffix": ".png" 32 | } 33 | }, 34 | "RGBD135": { 35 | "image": { 36 | "path": "Path_Of_RGBDSOD_Datasets/RGBD135/Image", 37 | "suffix": ".jpg" 38 | }, 39 | "mask": { 40 | "path": "Path_Of_RGBDSOD_Datasets/RGBD135/Mask", 41 | "suffix": ".png" 42 | } 43 | }, 44 | "SIP": { 45 | "image": { 46 | "path": "Path_Of_RGBDSOD_Datasets/SIP/Image", 47 | "suffix": ".jpg" 48 | }, 49 | "mask": { 50 | "path": "Path_Of_RGBDSOD_Datasets/SIP/Mask", 51 | "suffix": ".png" 52 | } 53 | }, 54 | "SSD": { 55 | "image": { 56 | "path": "Path_Of_RGBDSOD_Datasets/SSD/Image", 57 | "suffix": ".jpg" 58 | }, 59 | "mask": { 60 | "path": "Path_Of_RGBDSOD_Datasets/SSD/Mask", 61 | "suffix": ".png" 62 | } 63 | }, 64 | "STEREO797": { 65 | "image": { 66 | "path": "Path_Of_RGBDSOD_Datasets/STEREO797/Image", 67 | "suffix": ".jpg" 68 | }, 69 | "mask": { 70 | "path": "Path_Of_RGBDSOD_Datasets/STEREO797/Mask", 71 | "suffix": ".png" 72 | } 73 | }, 74 | "STEREO1000": { 75 | "image": { 76 | "path": "Path_Of_RGBDSOD_Datasets/STEREO1000/Image", 77 | "suffix": ".jpg" 78 | }, 79 | "mask": { 80 | "path": "Path_Of_RGBDSOD_Datasets/STEREO1000/Mask", 81 | "suffix": ".png" 82 | } 83 | }, 84 | "DUTRGBD": { 85 | "image": { 86 | "path": "Path_Of_RGBDSOD_Datasets/DUT-RGBD/Test/Image", 87 | "suffix": ".jpg" 88 | }, 89 | "mask": { 90 | "path": "Path_Of_RGBDSOD_Datasets/DUT-RGBD/Test/Mask", 91 | "suffix": ".png" 92 | } 93 | } 94 | } 95 | -------------------------------------------------------------------------------- /examples/config_method_json_example.json: -------------------------------------------------------------------------------- 1 | { 2 | "Method1": { 3 | "PASCAL-S": { 4 | "path": "Path_Of_Method1/PASCAL-S/DGRL", 5 | "prefix": "some_method_prefix", 6 | "suffix": ".png" 7 | }, 8 | "ECSSD": { 9 | "path": "Path_Of_Method1/ECSSD/DGRL", 10 | "prefix": "some_method_prefix", 11 | "suffix": ".png" 12 | }, 13 | "HKU-IS": { 14 | "path": "Path_Of_Method1/HKU-IS/DGRL", 15 | "prefix": "some_method_prefix", 16 | "suffix": ".png" 17 | }, 18 | "DUT-OMRON": { 19 | "path": "Path_Of_Method1/DUT-OMRON/DGRL", 20 | "prefix": "some_method_prefix", 21 | "suffix": ".png" 22 | }, 23 | "DUTS-TE": { 24 | "path": "Path_Of_Method1/DUTS-TE/DGRL", 25 | "suffix": ".png" 26 | } 27 | }, 28 | "Method2": { 29 | "PASCAL-S": { 30 | "path": "Path_Of_Method2/pascal", 31 | "prefix": "pascal_", 32 | "suffix": ".png" 33 | }, 34 | "ECSSD": { 35 | "path": "Path_Of_Method2/ecssd", 36 | "prefix": "ecssd_", 37 | "suffix": ".png" 38 | }, 39 | "HKU-IS": { 40 | "path": "Path_Of_Method2/hku", 41 | "prefix": "hku_", 42 | "suffix": ".png" 43 | }, 44 | "DUT-OMRON": { 45 | "path": "Path_Of_Method2/duto", 46 | "prefix": "duto_", 47 | "suffix": ".png" 48 | }, 49 | "DUTS-TE": { 50 | "path": "Path_Of_Method2/dut_te", 51 | "prefix": "dut_te_", 52 | "suffix": ".png" 53 | } 54 | }, 55 | "Method3": { 56 | "PASCAL-S": { 57 | "path": "Path_Of_Method3/pascal", 58 | "prefix": "pascal_", 59 | "suffix": "_fused_sod.png" 60 | }, 61 | "ECSSD": { 62 | "path": "Path_Of_Method3/ecssd", 63 | "prefix": "ecssd_", 64 | "suffix": "_fused_sod.png" 65 | }, 66 | "HKU-IS": { 67 | "path": "Path_Of_Method3/hku", 68 | "prefix": "hku_", 69 | "suffix": 
"_fused_sod.png" 70 | }, 71 | "DUT-OMRON": { 72 | "path": "Path_Of_Method3/duto", 73 | "prefix": "duto_", 74 | "suffix": "_fused_sod.png" 75 | }, 76 | "DUTS-TE": { 77 | "path": "Path_Of_Method3/dut_te", 78 | "prefix": "dut_te_", 79 | "suffix": "_fused_sod.png" 80 | } 81 | } 82 | } 83 | -------------------------------------------------------------------------------- /examples/converter_config.yaml: -------------------------------------------------------------------------------- 1 | dataset_names: [ 2 | 'NJUD', 3 | 'NLPR', 4 | 'SIP', 5 | 'STEREO1000', 6 | 'SSD', 7 | 'LFSD', 8 | 'RGBD135', 9 | 'DUTRGBD' 10 | ] 11 | 12 | # 使用单引号保证不被转义 13 | method_names: { 14 | '2020-ECCV-DANetV19': 'DANet$_{20}$', 15 | '2020-ECCV-HDFNetR50': 'HDFNet$_{20}$', 16 | '2022-AAAI-SSLSOD-ImageNet': 'SSLSOD$_{22}$', 17 | } 18 | 19 | # 使用单引号保证不被转义 20 | metric_names: { 21 | 'sm': '$S_{m}~\uparrow$', 22 | 'wfm': '$F^{\omega}_{\beta}~\uparrow$', 23 | 'mae': '$MAE~\downarrow$', 24 | 'adpf': '$F^{adp}_{\beta}~\uparrow$', 25 | 'avgf': '$F^{avg}_{\beta}~\uparrow$', 26 | 'maxf': '$F^{max}_{\beta}~\uparrow$', 27 | 'adpe': '$E^{adp}_{m}~\uparrow$', 28 | 'avge': '$E^{avg}_{m}~\uparrow$', 29 | 'maxe': '$E^{max}_{m}~\uparrow$', 30 | } 31 | -------------------------------------------------------------------------------- /examples/rgbd_aliases.yaml: -------------------------------------------------------------------------------- 1 | dataset: { 2 | "NJUD": "NJUD", 3 | "NLPR": "NLPR", 4 | "SIP": "SIP", 5 | "STEREO1000": "STEREO1000", 6 | "RGBD135": "RGBD135", 7 | "SSD": "SSD", 8 | "LFSD": "LFSD", 9 | "DUTRGBD": "DUTRGBD", 10 | } 11 | 12 | method: { 13 | '2020-ECCV-DANetV19': 'DANet$_{20}$', 14 | '2020-ECCV-HDFNetR50': 'HDFNet$_{20}$', 15 | '2022-AAAI-SSLSOD-ImageNet': 'SSLSOD$_{22}$', 16 | } 17 | -------------------------------------------------------------------------------- /examples/single_row_style.yml: -------------------------------------------------------------------------------- 1 | # Based: 2 | # - https://matplotlib.org/stable/tutorials/introductory/customizing.html#the-default-matplotlibrc-file 3 | # - https://github.com/rougier/scientific-visualization-book/blob/master/code/defaults/mystyle.txt 4 | 5 | 6 | ## *************************************************************************** 7 | ## * LINES * 8 | ## *************************************************************************** 9 | ## See https://matplotlib.org/api/artist_api.html#module-matplotlib.lines 10 | ## for more information on line properties. 11 | lines.linewidth: 2 12 | lines.markersize: 5 13 | 14 | 15 | ## *************************************************************************** 16 | ## * FONT * 17 | ## *************************************************************************** 18 | ## The font properties used by `text.Text`. 19 | ## See https://matplotlib.org/api/font_manager_api.html for more information 20 | ## on font properties. The 6 font properties used for font matching are 21 | ## given below with their default values. 22 | ## 23 | ## The font.family property can take either a concrete font name (not supported 24 | ## when rendering text with usetex), or one of the following five generic 25 | ## values: 26 | ## - 'serif' (e.g., Times), 27 | ## - 'sans-serif' (e.g., Helvetica), 28 | ## - 'cursive' (e.g., Zapf-Chancery), 29 | ## - 'fantasy' (e.g., Western), and 30 | ## - 'monospace' (e.g., Courier). 31 | ## Each of these values has a corresponding default list of font names 32 | ## (font.serif, etc.); the first available font in the list is used. 
Note that 33 | ## for font.serif, font.sans-serif, and font.monospace, the first element of 34 | ## the list (a DejaVu font) will always be used because DejaVu is shipped with 35 | ## Matplotlib and is thus guaranteed to be available; the other entries are 36 | ## left as examples of other possible values. 37 | ## 38 | ## The font.style property has three values: normal (or roman), italic 39 | ## or oblique. The oblique style will be used for italic, if it is not 40 | ## present. 41 | ## 42 | ## The font.variant property has two values: normal or small-caps. For 43 | ## TrueType fonts, which are scalable fonts, small-caps is equivalent 44 | ## to using a font size of 'smaller', or about 83%% of the current font 45 | ## size. 46 | ## 47 | ## The font.weight property has effectively 13 values: normal, bold, 48 | ## bolder, lighter, 100, 200, 300, ..., 900. Normal is the same as 49 | ## 400, and bold is 700. bolder and lighter are relative values with 50 | ## respect to the current weight. 51 | ## 52 | ## The font.stretch property has 11 values: ultra-condensed, 53 | ## extra-condensed, condensed, semi-condensed, normal, semi-expanded, 54 | ## expanded, extra-expanded, ultra-expanded, wider, and narrower. This 55 | ## property is not currently implemented. 56 | ## 57 | ## The font.size property is the default font size for text, given in points. 58 | ## 10 pt is the standard value. 59 | ## 60 | ## Note that font.size controls default text sizes. To configure 61 | ## special text sizes tick labels, axes, labels, title, etc., see the rc 62 | ## settings for axes and ticks. Special text sizes can be defined 63 | ## relative to font.size, using the following values: xx-small, x-small, 64 | ## small, medium, large, x-large, xx-large, larger, or smaller 65 | font.family: sans-serif 66 | font.style: normal 67 | font.variant: normal 68 | font.weight: normal 69 | # font.stretch: normal 70 | font.size: 12.0 71 | 72 | #font.serif: DejaVu Serif, Bitstream Vera Serif, Computer Modern Roman, New Century Schoolbook, Century Schoolbook L, Utopia, ITC Bookman, Bookman, Nimbus Roman No9 L, Times New Roman, Times, Palatino, Charter, serif 73 | font.sans-serif: Trebuchet MS, DejaVu Sans, Bitstream Vera Sans, Computer Modern Sans Serif, Lucida Grande, Verdana, Geneva, Lucid, Arial, Helvetica, Avant Garde, sans-serif 74 | #font.cursive: Apple Chancery, Textile, Zapf Chancery, Sand, Script MT, Felipa, Comic Neue, Comic Sans MS, cursive 75 | #font.fantasy: Chicago, Charcoal, Impact, Western, Humor Sans, xkcd, fantasy 76 | #font.monospace: DejaVu Sans Mono, Bitstream Vera Sans Mono, Computer Modern Typewriter, Andale Mono, Nimbus Mono L, Courier New, Courier, Fixed, Terminal, monospace 77 | 78 | 79 | ## *************************************************************************** 80 | ## * AXES * 81 | ## *************************************************************************** 82 | ## Following are default face and edge colors, default tick sizes, 83 | ## default font sizes for tick labels, and so on. 
See 84 | ## https://matplotlib.org/api/axes_api.html#module-matplotlib.axes 85 | axes.linewidth: 1 86 | axes.grid: True 87 | axes.ymargin: 0.1 88 | 89 | axes.titlelocation: center # alignment of the title: {left, right, center} 90 | axes.titlesize: x-large # font size of the axes title 91 | axes.titleweight: bold # font weight of title 92 | axes.titlecolor: black # color of the axes title, auto falls back to text.color as default value 93 | 94 | axes.spines.left: True 95 | axes.spines.bottom: True 96 | axes.spines.right: False 97 | axes.spines.top: False 98 | 99 | axes.labelsize: medium # font size of the x and y labels 100 | axes.labelpad: 2.0 # space between label and axis 101 | axes.labelweight: normal # weight of the x and y labels 102 | axes.labelcolor: black 103 | axes.axisbelow: True # draw axis gridlines and ticks: 104 | # - below patches (True) 105 | # - above patches but below lines ('line') 106 | # - above all (False) 107 | 108 | 109 | ## *************************************************************************** 110 | ## * TICKS * 111 | ## *************************************************************************** 112 | ## See https://matplotlib.org/api/axis_api.html#matplotlib.axis.Tick 113 | xtick.bottom: True 114 | xtick.top: False 115 | xtick.direction: out 116 | xtick.major.size: 5 117 | xtick.major.width: 1 118 | xtick.minor.size: 3 119 | xtick.minor.width: 0.5 120 | xtick.minor.visible: False 121 | xtick.alignment: center # alignment of xticks 122 | 123 | ytick.left: True 124 | ytick.right: False 125 | ytick.direction: out 126 | ytick.major.size: 5 127 | ytick.major.width: 1 128 | ytick.minor.size: 3 129 | ytick.minor.width: 0.5 130 | ytick.minor.visible: False 131 | ytick.alignment: center_baseline # alignment of yticks 132 | 133 | 134 | ## *************************************************************************** 135 | ## * GRIDS * 136 | ## *************************************************************************** 137 | grid.color: black 138 | grid.linewidth: 0.1 139 | grid.alpha: 0.4 # transparency, between 0.0 and 1.0 140 | 141 | 142 | ## *************************************************************************** 143 | ## * LEGEND * 144 | ## *************************************************************************** 145 | legend.fancybox: True # if True, use a rounded box for the legend background, else a rectangle 146 | legend.shadow: False # if True, give background a shadow effect 147 | legend.numpoints: 1 # the number of marker points in the legend line 148 | legend.scatterpoints: 1 # number of scatter points 149 | legend.markerscale: 1.0 # the relative size of legend markers vs. 
original 150 | legend.fontsize: large 151 | legend.framealpha: 0.9 152 | 153 | # Dimensions as fraction of font size: 154 | legend.borderpad: 0.4 # border whitespace 155 | legend.labelspacing: 0.5 # the vertical space between the legend entries 156 | legend.handlelength: 2.0 # the length of the legend lines 157 | legend.handleheight: 0.7 # the height of the legend handle 158 | legend.handletextpad: 0.5 # the space between the legend line and legend text 159 | legend.borderaxespad: 0.5 # the border between the axes and legend edge 160 | legend.columnspacing: 0.5 # column separation 161 | 162 | 163 | ## *************************************************************************** 164 | ## * FIGURE * 165 | ## *************************************************************************** 166 | ## See https://matplotlib.org/api/figure_api.html#matplotlib.figure.Figure 167 | figure.titlesize: large # size of the figure title (``Figure.suptitle()``) 168 | figure.titleweight: normal # weight of the figure title 169 | figure.figsize: 16,4 # figure size in inches 170 | figure.dpi: 600 # figure dots per inch 171 | figure.facecolor: white # figure face color 172 | figure.edgecolor: white # figure edge color 173 | 174 | # The figure subplot parameters. All dimensions are a fraction of the figure width and height. 175 | figure.subplot.left: 0.00 # the left side of the subplots of the figure 176 | figure.subplot.right: 1.00 # the right side of the subplots of the figure 177 | figure.subplot.bottom: 0.00 # the bottom of the subplots of the figure 178 | figure.subplot.top: 1.00 # the top of the subplots of the figure 179 | figure.subplot.wspace: 0.10 # the amount of width reserved for space between subplots, expressed as a fraction of the average axis width 180 | figure.subplot.hspace: 0.10 # the amount of height reserved for space between subplots, expressed as a fraction of the average axis height 181 | 182 | ## Figure layout 183 | figure.autolayout: False # When True, automatically adjust subplot parameters to make the plot fit the figure using `tight_layout` 184 | 185 | 186 | ## *************************************************************************** 187 | ## * IMAGES * 188 | ## *************************************************************************** 189 | image.interpolation: antialiased # see help(imshow) for options 190 | image.cmap: gray # A colormap name, gray etc... 
191 | image.lut: 256 # the size of the colormap lookup table 192 | 193 | 194 | ## *************************************************************************** 195 | ## * SAVING FIGURES * 196 | ## *************************************************************************** 197 | ## The default savefig parameters can be different from the display parameters 198 | ## e.g., you may want a higher resolution, or to make the figure 199 | ## background white 200 | savefig.dpi: figure # figure dots per inch or 'figure' 201 | savefig.format: pdf # {png, ps, pdf, svg} 202 | 203 | ## PDF backend params 204 | pdf.compression: 6 # integer from 0 to 9 0 disables compression (good for debugging) 205 | pdf.fonttype: 3 # Output Type 3 (Type3) or Type 42 (TrueType) 206 | -------------------------------------------------------------------------------- /examples/two_row_style.yml: -------------------------------------------------------------------------------- 1 | # Based: 2 | # - https://matplotlib.org/stable/tutorials/introductory/customizing.html#the-default-matplotlibrc-file 3 | # - https://github.com/rougier/scientific-visualization-book/blob/master/code/defaults/mystyle.txt 4 | 5 | 6 | ## *************************************************************************** 7 | ## * LINES * 8 | ## *************************************************************************** 9 | ## See https://matplotlib.org/api/artist_api.html#module-matplotlib.lines 10 | ## for more information on line properties. 11 | lines.linewidth: 2 12 | lines.markersize: 5 13 | 14 | 15 | ## *************************************************************************** 16 | ## * FONT * 17 | ## *************************************************************************** 18 | ## The font properties used by `text.Text`. 19 | ## See https://matplotlib.org/api/font_manager_api.html for more information 20 | ## on font properties. The 6 font properties used for font matching are 21 | ## given below with their default values. 22 | ## 23 | ## The font.family property can take either a concrete font name (not supported 24 | ## when rendering text with usetex), or one of the following five generic 25 | ## values: 26 | ## - 'serif' (e.g., Times), 27 | ## - 'sans-serif' (e.g., Helvetica), 28 | ## - 'cursive' (e.g., Zapf-Chancery), 29 | ## - 'fantasy' (e.g., Western), and 30 | ## - 'monospace' (e.g., Courier). 31 | ## Each of these values has a corresponding default list of font names 32 | ## (font.serif, etc.); the first available font in the list is used. Note that 33 | ## for font.serif, font.sans-serif, and font.monospace, the first element of 34 | ## the list (a DejaVu font) will always be used because DejaVu is shipped with 35 | ## Matplotlib and is thus guaranteed to be available; the other entries are 36 | ## left as examples of other possible values. 37 | ## 38 | ## The font.style property has three values: normal (or roman), italic 39 | ## or oblique. The oblique style will be used for italic, if it is not 40 | ## present. 41 | ## 42 | ## The font.variant property has two values: normal or small-caps. For 43 | ## TrueType fonts, which are scalable fonts, small-caps is equivalent 44 | ## to using a font size of 'smaller', or about 83%% of the current font 45 | ## size. 46 | ## 47 | ## The font.weight property has effectively 13 values: normal, bold, 48 | ## bolder, lighter, 100, 200, 300, ..., 900. Normal is the same as 49 | ## 400, and bold is 700. bolder and lighter are relative values with 50 | ## respect to the current weight. 
51 | ## 52 | ## The font.stretch property has 11 values: ultra-condensed, 53 | ## extra-condensed, condensed, semi-condensed, normal, semi-expanded, 54 | ## expanded, extra-expanded, ultra-expanded, wider, and narrower. This 55 | ## property is not currently implemented. 56 | ## 57 | ## The font.size property is the default font size for text, given in points. 58 | ## 10 pt is the standard value. 59 | ## 60 | ## Note that font.size controls default text sizes. To configure 61 | ## special text sizes tick labels, axes, labels, title, etc., see the rc 62 | ## settings for axes and ticks. Special text sizes can be defined 63 | ## relative to font.size, using the following values: xx-small, x-small, 64 | ## small, medium, large, x-large, xx-large, larger, or smaller 65 | font.family: sans-serif 66 | font.style: normal 67 | font.variant: normal 68 | font.weight: normal 69 | # font.stretch: normal 70 | font.size: 12.0 71 | 72 | #font.serif: DejaVu Serif, Bitstream Vera Serif, Computer Modern Roman, New Century Schoolbook, Century Schoolbook L, Utopia, ITC Bookman, Bookman, Nimbus Roman No9 L, Times New Roman, Times, Palatino, Charter, serif 73 | font.sans-serif: Trebuchet MS, DejaVu Sans, Bitstream Vera Sans, Computer Modern Sans Serif, Lucida Grande, Verdana, Geneva, Lucid, Arial, Helvetica, Avant Garde, sans-serif 74 | #font.cursive: Apple Chancery, Textile, Zapf Chancery, Sand, Script MT, Felipa, Comic Neue, Comic Sans MS, cursive 75 | #font.fantasy: Chicago, Charcoal, Impact, Western, Humor Sans, xkcd, fantasy 76 | #font.monospace: DejaVu Sans Mono, Bitstream Vera Sans Mono, Computer Modern Typewriter, Andale Mono, Nimbus Mono L, Courier New, Courier, Fixed, Terminal, monospace 77 | 78 | 79 | ## *************************************************************************** 80 | ## * AXES * 81 | ## *************************************************************************** 82 | ## Following are default face and edge colors, default tick sizes, 83 | ## default font sizes for tick labels, and so on. 
See 84 | ## https://matplotlib.org/api/axes_api.html#module-matplotlib.axes 85 | axes.linewidth: 1 86 | axes.grid: True 87 | axes.ymargin: 0.1 88 | 89 | axes.titlelocation: center # alignment of the title: {left, right, center} 90 | axes.titlesize: x-large # font size of the axes title 91 | axes.titleweight: bold # font weight of title 92 | axes.titlecolor: black # color of the axes title, auto falls back to text.color as default value 93 | 94 | axes.spines.left: True 95 | axes.spines.bottom: True 96 | axes.spines.right: False 97 | axes.spines.top: False 98 | 99 | axes.labelsize: medium # font size of the x and y labels 100 | axes.labelpad: 2.0 # space between label and axis 101 | axes.labelweight: normal # weight of the x and y labels 102 | axes.labelcolor: black 103 | axes.axisbelow: True # draw axis gridlines and ticks: 104 | # - below patches (True) 105 | # - above patches but below lines ('line') 106 | # - above all (False) 107 | 108 | 109 | ## *************************************************************************** 110 | ## * TICKS * 111 | ## *************************************************************************** 112 | ## See https://matplotlib.org/api/axis_api.html#matplotlib.axis.Tick 113 | xtick.bottom: True 114 | xtick.top: False 115 | xtick.direction: out 116 | xtick.major.size: 5 117 | xtick.major.width: 1 118 | xtick.minor.size: 3 119 | xtick.minor.width: 0.5 120 | xtick.minor.visible: False 121 | xtick.alignment: center # alignment of xticks 122 | 123 | ytick.left: True 124 | ytick.right: False 125 | ytick.direction: out 126 | ytick.major.size: 5 127 | ytick.major.width: 1 128 | ytick.minor.size: 3 129 | ytick.minor.width: 0.5 130 | ytick.minor.visible: False 131 | ytick.alignment: center_baseline # alignment of yticks 132 | 133 | 134 | ## *************************************************************************** 135 | ## * GRIDS * 136 | ## *************************************************************************** 137 | grid.color: black 138 | grid.linewidth: 0.1 139 | grid.alpha: 0.4 # transparency, between 0.0 and 1.0 140 | 141 | 142 | ## *************************************************************************** 143 | ## * LEGEND * 144 | ## *************************************************************************** 145 | legend.fancybox: True # if True, use a rounded box for the legend background, else a rectangle 146 | legend.shadow: False # if True, give background a shadow effect 147 | legend.numpoints: 1 # the number of marker points in the legend line 148 | legend.scatterpoints: 1 # number of scatter points 149 | legend.markerscale: 1.0 # the relative size of legend markers vs. 
original 150 | legend.fontsize: large 151 | legend.framealpha: 1 152 | 153 | # Dimensions as fraction of font size: 154 | legend.borderpad: 0.4 # border whitespace 155 | legend.labelspacing: 0.5 # the vertical space between the legend entries 156 | legend.handlelength: 2.0 # the length of the legend lines 157 | legend.handleheight: 0.7 # the height of the legend handle 158 | legend.handletextpad: 0.5 # the space between the legend line and legend text 159 | legend.borderaxespad: 0.5 # the border between the axes and legend edge 160 | legend.columnspacing: 0.5 # column separation 161 | 162 | 163 | ## *************************************************************************** 164 | ## * FIGURE * 165 | ## *************************************************************************** 166 | ## See https://matplotlib.org/api/figure_api.html#matplotlib.figure.Figure 167 | figure.titlesize: large # size of the figure title (``Figure.suptitle()``) 168 | figure.titleweight: normal # weight of the figure title 169 | figure.figsize: 16,8 # figure size in inches 170 | figure.dpi: 600 # figure dots per inch 171 | figure.facecolor: white # figure face color 172 | figure.edgecolor: white # figure edge color 173 | 174 | # The figure subplot parameters. All dimensions are a fraction of the figure width and height. 175 | figure.subplot.left: 0.00 # the left side of the subplots of the figure 176 | figure.subplot.right: 1.00 # the right side of the subplots of the figure 177 | figure.subplot.bottom: 0.00 # the bottom of the subplots of the figure 178 | figure.subplot.top: 1.00 # the top of the subplots of the figure 179 | figure.subplot.wspace: 0.10 # the amount of width reserved for space between subplots, expressed as a fraction of the average axis width 180 | figure.subplot.hspace: 0.10 # the amount of height reserved for space between subplots, expressed as a fraction of the average axis height 181 | 182 | ## Figure layout 183 | figure.autolayout: False # When True, automatically adjust subplot parameters to make the plot fit the figure using `tight_layout` 184 | 185 | 186 | ## *************************************************************************** 187 | ## * IMAGES * 188 | ## *************************************************************************** 189 | image.interpolation: antialiased # see help(imshow) for options 190 | image.cmap: gray # A colormap name, gray etc... 
191 | image.lut: 256  # the size of the colormap lookup table
192 | 
193 | 
194 | ## ***************************************************************************
195 | ## * SAVING FIGURES *
196 | ## ***************************************************************************
197 | ## The default savefig parameters can be different from the display parameters
198 | ## e.g., you may want a higher resolution, or to make the figure
199 | ## background white
200 | savefig.dpi: figure  # figure dots per inch or 'figure'
201 | savefig.format: pdf  # {png, ps, pdf, svg}
202 | 
203 | ## PDF backend params
204 | pdf.compression: 6  # integer from 0 to 9; 0 disables compression (good for debugging)
205 | pdf.fonttype: 3  # Output Type 3 (Type3) or Type 42 (TrueType)
--------------------------------------------------------------------------------
/metrics/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lartpang/PySODEvalToolkit/a52dbccb90ec3da8f52a6e0b97b055e1d27b6319/metrics/__init__.py
--------------------------------------------------------------------------------
/metrics/draw_curves.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | 
3 | from collections import OrderedDict
4 | 
5 | import numpy as np
6 | from matplotlib import colors
7 | 
8 | from utils.recorders import CurveDrawer
9 | 
10 | # Align the mode with those in GRAYSCALE_METRICS
11 | _YX_AXIS_NAMES = {
12 |     "pr": ("precision", "recall"),
13 |     "fm": ("fmeasure", None),
14 |     "fmeasure": ("fmeasure", None),
15 |     "em": ("em", None),
16 |     "iou": ("iou", None),
17 |     "dice": ("dice", None),
18 | }
19 | 
20 | 
21 | def draw_curves(
22 |     mode: str,
23 |     axes_setting: dict = None,
24 |     curves_npy_path: list = None,
25 |     row_num: int = 1,
26 |     our_methods: list = None,
27 |     method_aliases: OrderedDict = None,
28 |     dataset_aliases: OrderedDict = None,
29 |     style_cfg: dict = None,
30 |     ncol_of_legend: int = 1,
31 |     separated_legend: bool = False,
32 |     sharey: bool = False,
33 |     line_width=3,
34 |     save_name=None,
35 | ):
36 |     """A better curve painter!
37 | 
38 |     Args:
39 |         mode (str): `pr` for PR curves, `fm`/`fmeasure` for F-measure curves, `em` for E-measure curves, and `iou`/`dice` for IoU/Dice curves.
40 |         axes_setting (dict, optional): Setting for axes. Defaults to None.
41 |         curves_npy_path (list, optional): Paths of curve npy files. Defaults to None.
42 |         row_num (int, optional): Number of rows. Defaults to 1.
43 |         our_methods (list, optional): Names of our methods. Defaults to None.
44 |         method_aliases (OrderedDict, optional): Aliases of methods. Defaults to None.
45 |         dataset_aliases (OrderedDict, optional): Aliases of datasets. Defaults to None.
46 |         style_cfg (dict, optional): Config file for the style of matplotlib. Defaults to None.
47 |         ncol_of_legend (int, optional): Number of columns for the legend. Defaults to 1.
48 |         separated_legend (bool, optional): Use the separated legend. Defaults to False.
49 |         sharey (bool, optional): Use a shared y-axis. Defaults to False.
50 |         line_width (int, optional): Width of lines. Defaults to 3.
51 |         save_name (str, optional): Name or path (without the extension format). Defaults to None.
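    Example:
        A minimal usage sketch (the npy path and alias entries here are
        hypothetical, and the keys accepted inside ``axes_setting[mode]`` depend
        on ``CurveDrawer.set_axis_property``, which is defined elsewhere):

        >>> draw_curves(
        ...     mode="pr",
        ...     axes_setting={"pr": {}},
        ...     curves_npy_path=["output/rgbd_curves.npy"],
        ...     method_aliases=OrderedDict([("GateNet_2020", "GateNet")]),
        ...     dataset_aliases=OrderedDict([("NJUD", "NJUD")]),
        ...     save_name="output/pr_curves",
        ... )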
52 | """ 53 | save_name = save_name or mode 54 | y_axis_name, x_axis_name = _YX_AXIS_NAMES[mode] 55 | 56 | assert curves_npy_path 57 | if not isinstance(curves_npy_path, (list, tuple)): 58 | curves_npy_path = [curves_npy_path] 59 | 60 | curves = {} 61 | unique_method_names_from_npy = [] 62 | for p in curves_npy_path: 63 | single_curves = np.load(p, allow_pickle=True).item() 64 | for dataset_name, method_infos in single_curves.items(): 65 | curves.setdefault(dataset_name, {}) 66 | for method_name, method_info in method_infos.items(): 67 | curves[dataset_name][method_name] = method_info 68 | if method_name not in unique_method_names_from_npy: 69 | unique_method_names_from_npy.append(method_name) 70 | dataset_names_from_npy = list(curves.keys()) 71 | 72 | if dataset_aliases is None: 73 | dataset_aliases = OrderedDict({k: k for k in dataset_names_from_npy}) 74 | else: 75 | for x in dataset_aliases.keys(): 76 | if x not in dataset_names_from_npy: 77 | raise ValueError(f"{x} must be contained in\n{dataset_names_from_npy}") 78 | 79 | if method_aliases is not None: 80 | target_unique_method_names = [] 81 | for x in method_aliases: 82 | if x in unique_method_names_from_npy: 83 | target_unique_method_names.append(x) 84 | # Only consider the name in npy is also in alias config. 85 | # if x not in unique_method_names_from_npy: 86 | # raise ValueError( 87 | # f"{x} must be contained in\n{sorted(unique_method_names_from_npy)}" 88 | # ) 89 | else: 90 | method_aliases = {} 91 | target_unique_method_names = unique_method_names_from_npy 92 | 93 | if our_methods is not None: 94 | our_methods.reverse() 95 | for x in our_methods: 96 | if x not in target_unique_method_names: 97 | raise ValueError(f"{x} must be contained in\n{target_unique_method_names}") 98 | # Put our methods into the head of the name list 99 | target_unique_method_names.pop(target_unique_method_names.index(x)) 100 | target_unique_method_names.insert(0, x) 101 | # assert len(our_methods) <= len(line_styles) 102 | else: 103 | our_methods = [] 104 | num_our_methods = len(our_methods) 105 | 106 | # Give each method a unique color and style. 
107 |     color_table = sorted(
108 |         [
109 |             color
110 |             for name, color in colors.cnames.items()
111 |             # skip red (reserved for our methods), white, and the faint "light*" colors (except the light grays)
112 |             if name not in ["red", "white"] and (not name.startswith("light") or "gray" in name)
113 |         ]
114 |     )
115 |     style_table = ["-", "--", "-.", ":"]  # "." is a marker, not a valid matplotlib linestyle
116 | 
117 |     unique_method_settings = OrderedDict()
118 |     for i, method_name in enumerate(target_unique_method_names):
119 |         if i < num_our_methods:
120 |             line_color = "red"
121 |             line_style = style_table[i % len(style_table)]
122 |         else:
123 |             other_idx = i - num_our_methods
124 |             line_color = color_table[other_idx]
125 |             line_style = style_table[other_idx % 2]
126 | 
127 |         unique_method_settings[method_name] = {
128 |             "line_color": line_color,
129 |             "line_label": method_aliases.get(method_name, method_name),
130 |             "line_style": line_style,
131 |             "line_width": line_width,
132 |         }
133 |     # ensure that our methods are drawn last to avoid being overwritten by other methods
134 |     target_unique_method_names.reverse()
135 | 
136 |     curve_drawer = CurveDrawer(
137 |         row_num=row_num,
138 |         num_subplots=len(dataset_aliases),
139 |         style_cfg=style_cfg,
140 |         ncol_of_legend=ncol_of_legend,
141 |         separated_legend=separated_legend,
142 |         sharey=sharey,
143 |     )
144 | 
145 |     for idx, (dataset_name, dataset_alias) in enumerate(dataset_aliases.items()):
146 |         dataset_results = curves[dataset_name]
147 | 
148 |         for method_name in target_unique_method_names:
149 |             method_setting = unique_method_settings[method_name]
150 | 
151 |             if method_name not in dataset_results:
152 |                 print(f"{method_name} will be skipped for {dataset_name}!")
153 |                 continue
154 | 
155 |             method_results = dataset_results[method_name]
156 | 
157 |             if y_axis_name is None:
158 |                 y_data = np.linspace(0, 1, 256)
159 |             else:
160 |                 y_data = method_results[y_axis_name]
161 |                 assert isinstance(y_data, (list, tuple)), (method_name, method_results.keys())
162 | 
163 |             if x_axis_name is None:
164 |                 x_data = np.linspace(0, 1, 256)
165 |             else:
166 |                 x_data = method_results[x_axis_name]
167 |                 assert isinstance(x_data, (list, tuple)), (method_name, method_results.keys())
168 | 
169 |             curve_drawer.plot_at_axis(idx, method_setting, x_data=x_data, y_data=y_data)
170 |         curve_drawer.set_axis_property(idx, dataset_alias, **axes_setting[mode])
171 |     curve_drawer.save(path=save_name)
172 | 
--------------------------------------------------------------------------------
/metrics/image_metrics.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | 
3 | import os
4 | from collections import defaultdict
5 | from functools import partial
6 | from multiprocessing import pool
7 | from threading import RLock as TRLock
8 | 
9 | import numpy as np
10 | from tqdm import tqdm
11 | 
12 | from utils.misc import get_gt_pre_with_name, get_name_list, make_dir
13 | from utils.print_formatter import formatter_for_tabulate
14 | from utils.recorders import (
15 |     BINARY_METRIC_MAPPING,
16 |     GRAYSCALE_METRICS,
17 |     BinaryMetricRecorder,
18 |     GrayscaleMetricRecorder,
19 |     MetricExcelRecorder,
20 |     TxtRecorder,
21 | )
22 | 
23 | 
24 | class Recorder:
25 |     def __init__(
26 |         self,
27 |         method_names,
28 |         dataset_names,
29 |         metric_names,
30 |         *,
31 |         txt_path,
32 |         to_append,
33 |         xlsx_path,
34 |         sheet_name,
35 |     ):
36 |         self.curves = defaultdict(dict)  # curve metrics (e.g., the PR and Fm curves)
37 |         self.metrics = defaultdict(dict)  # numerical metrics
38 |         self.method_names = method_names
39 |         self.dataset_names = dataset_names
40 | 
41 |         self.txt_recorder = None
42 |         if txt_path:
43 |             self.txt_recorder = TxtRecorder(
44 |                 txt_path=txt_path,
45 |                 to_append=to_append,
46 |                 max_method_name_width=max([len(x) for x in method_names]),  # display the full method name
47 |             )
48 | 
49 |         self.excel_recorder = None
50 |         if xlsx_path:
51 |             excel_metric_names = []
52 |             for x in metric_names:
53 |                 if x in GRAYSCALE_METRICS:
54 |                     if x == "em":
55 |                         excel_metric_names.append(f"max{x}")
56 |                         excel_metric_names.append(f"avg{x}")
57 |                         excel_metric_names.append(f"adp{x}")
58 |                     else:
59 |                         config = BINARY_METRIC_MAPPING[x]
60 |                         if config["kwargs"]["with_dynamic"]:
61 |                             excel_metric_names.append(f"max{x}")
62 |                             excel_metric_names.append(f"avg{x}")
63 |                         if config["kwargs"]["with_adaptive"]:
64 |                             excel_metric_names.append(f"adp{x}")
65 |                 else:
66 |                     excel_metric_names.append(x)
67 | 
68 |             self.excel_recorder = MetricExcelRecorder(
69 |                 xlsx_path=xlsx_path,
70 |                 sheet_name=sheet_name,
71 |                 row_header=["methods"],
72 |                 dataset_names=dataset_names,
73 |                 metric_names=excel_metric_names,
74 |             )
75 | 
76 |     def record(self, method_results, dataset_name, method_name):
77 |         """Record results"""
78 |         method_curves = method_results.get("sequential")
79 |         method_metrics = method_results["numerical"]
80 |         self.curves[dataset_name][method_name] = method_curves
81 |         self.metrics[dataset_name][method_name] = method_metrics
82 | 
83 |     def export(self):
84 |         """After evaluating all methods, export results to ensure the order of names."""
85 |         for dataset_name in self.dataset_names:
86 |             if dataset_name not in self.metrics:
87 |                 continue
88 | 
89 |             for method_name in self.method_names:
90 |                 dataset_results = self.metrics[dataset_name]
91 |                 method_results = dataset_results.get(method_name)
92 |                 if method_results is None:
93 |                     continue
94 | 
95 |                 if self.txt_recorder:
96 |                     self.txt_recorder.add_row(row_name="Dataset", row_data=dataset_name)
97 |                     self.txt_recorder(method_results=method_results, method_name=method_name)
98 |                 if self.excel_recorder:
99 |                     self.excel_recorder(
100 |                         row_data=method_results, dataset_name=dataset_name, method_name=method_name
101 |                     )
102 | 
103 | 
104 | def cal_metrics(
105 |     sheet_name: str = "results",
106 |     txt_path: str = "",
107 |     to_append: bool = True,
108 |     xlsx_path: str = "",
109 |     methods_info: dict = None,
110 |     datasets_info: dict = None,
111 |     curves_npy_path: str = "./curves.npy",
112 |     metrics_npy_path: str = "./metrics.npy",
113 |     num_bits: int = 3,
114 |     num_workers: int = 2,
115 |     metric_names: tuple = ("sm", "wfm", "mae", "fmeasure", "em"),
116 | ):
117 |     """Save the results of all models on different datasets in a `npy` file in the form of a
118 |     dictionary.
119 | 
120 |     Args:
121 |         sheet_name (str, optional): The type of the sheet in xlsx file. Defaults to "results".
122 |         txt_path (str, optional): The path of the txt for saving results. Defaults to "".
123 |         to_append (bool, optional): Whether to append results to the original record. Defaults to True.
124 |         xlsx_path (str, optional): The path of the xlsx file for saving results. Defaults to "".
125 |         methods_info (dict, optional): The method information. Defaults to None.
126 |         datasets_info (dict, optional): The dataset information. Defaults to None.
127 |         curves_npy_path (str, optional): The npy file path for saving curve data. Defaults to "./curves.npy".
128 |         metrics_npy_path (str, optional): The npy file path for saving metric values. Defaults to "./metrics.npy".
129 |         num_bits (int, optional): The number of decimal places used to format results. Defaults to 3.
130 |         num_workers (int, optional): The number of workers of multiprocessing or multithreading. Defaults to 2.
131 |         metric_names (tuple, optional): Names of metrics. Defaults to ("sm", "wfm", "mae", "fmeasure", "em").
132 | 
133 |     Returns:
134 |         {
135 |             dataset1:{
136 |                 method1:[fm, em, p, r],
137 |                 method2:[fm, em, p, r],
138 |                 .....
139 |             },
140 |             dataset2:{
141 |                 method1:[fm, em, p, r],
142 |                 method2:[fm, em, p, r],
143 |                 .....
144 |             },
145 |             ....
146 |         }
147 | 
148 |     """
149 |     if all([x in BinaryMetricRecorder.suppoted_metrics for x in metric_names]):
150 |         metric_class = BinaryMetricRecorder
151 |     elif all([x in GrayscaleMetricRecorder.suppoted_metrics for x in metric_names]):
152 |         metric_class = GrayscaleMetricRecorder
153 |     else:
154 |         raise ValueError(metric_names)
155 | 
156 |     method_names = tuple(methods_info.keys())
157 |     dataset_names = tuple(datasets_info.keys())
158 |     recorder = Recorder(
159 |         method_names=method_names,
160 |         dataset_names=dataset_names,
161 |         metric_names=metric_names,
162 |         txt_path=txt_path,
163 |         to_append=to_append,
164 |         xlsx_path=xlsx_path,
165 |         sheet_name=sheet_name,
166 |     )
167 | 
168 |     tqdm.set_lock(TRLock())
169 |     procs = pool.ThreadPool(
170 |         processes=num_workers, initializer=tqdm.set_lock, initargs=(tqdm.get_lock(),)
171 |     )
172 |     print(f"Created a thread pool: {procs}.")
173 | 
174 |     for dataset_name, dataset_path in datasets_info.items():
175 |         # 获取真值图片信息
176 |         gt_info = dataset_path["mask"]
177 |         gt_root = gt_info["path"]
178 |         gt_prefix = gt_info.get("prefix", "")
179 |         gt_suffix = gt_info["suffix"]
180 |         # 真值名字列表
181 |         gt_index_file = dataset_path.get("index_file")
182 |         if gt_index_file:
183 |             gt_name_list = get_name_list(
184 |                 data_path=gt_index_file, name_prefix=gt_prefix, name_suffix=gt_suffix
185 |             )
186 |         else:
187 |             gt_name_list = get_name_list(
188 |                 data_path=gt_root, name_prefix=gt_prefix, name_suffix=gt_suffix
189 |             )
190 |         gt_info_pair = (gt_root, gt_prefix, gt_suffix)
191 |         assert len(gt_name_list) > 0, "there is no ground truth."
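        # NOTE: `gt_info_pair` (and `pre_info_pair` below) packs (root, prefix, suffix)
        # into one tuple so that each worker thread can rebuild full image paths on its
        # own inside `evaluate` via `get_gt_pre_with_name`.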
192 | 193 | # ==>> test the intersection between pre and gt for each method <<== 194 | for method_name, method_info in methods_info.items(): 195 | method_root = method_info["path_dict"] 196 | method_dataset_info = method_root.get(dataset_name, None) 197 | if method_dataset_info is None: 198 | tqdm.write(f"{method_name} does not have results on {dataset_name}") 199 | continue 200 | 201 | # 预测结果存放路径下的图片文件名字列表和扩展名称 202 | pre_prefix = method_dataset_info.get("prefix", "") 203 | pre_suffix = method_dataset_info["suffix"] 204 | pre_root = method_dataset_info["path"] 205 | pre_name_list = get_name_list( 206 | data_path=pre_root, name_prefix=pre_prefix, name_suffix=pre_suffix 207 | ) 208 | pre_info_pair = (pre_root, pre_prefix, pre_suffix) 209 | 210 | # get the intersection 211 | eval_name_list = sorted(set(gt_name_list).intersection(pre_name_list)) 212 | if len(eval_name_list) == 0: 213 | tqdm.write(f"{method_name} does not have results on {dataset_name}") 214 | continue 215 | 216 | desc = f"[{dataset_name}({len(gt_name_list)}):{method_name}({len(pre_name_list)})]" 217 | kwargs = dict( 218 | names=eval_name_list, 219 | num_bits=num_bits, 220 | pre_info_pair=pre_info_pair, 221 | gt_info_pair=gt_info_pair, 222 | metric_names=metric_names, 223 | metric_class=metric_class, 224 | desc=desc, 225 | ) 226 | callback = partial(recorder.record, dataset_name=dataset_name, method_name=method_name) 227 | procs.apply_async(func=evaluate, kwds=kwargs, callback=callback) 228 | # print(" -------------------- [DEBUG] -------------------- ") 229 | # callback(evaluate(**kwargs), dataset_name=dataset_name, method_name=method_name) 230 | procs.close() 231 | procs.join() 232 | 233 | recorder.export() 234 | if curves_npy_path: 235 | make_dir(os.path.dirname(curves_npy_path)) 236 | np.save(curves_npy_path, recorder.curves) 237 | tqdm.write(f"All curves has been saved in {curves_npy_path}") 238 | if metrics_npy_path: 239 | make_dir(os.path.dirname(metrics_npy_path)) 240 | np.save(metrics_npy_path, recorder.metrics) 241 | tqdm.write(f"All metrics has been saved in {metrics_npy_path}") 242 | formatted_string = formatter_for_tabulate(recorder.metrics, method_names, dataset_names) 243 | tqdm.write(f"All methods have been evaluated:\n{formatted_string}") 244 | 245 | 246 | def evaluate(names, num_bits, pre_info_pair, gt_info_pair, metric_class, metric_names, desc=""): 247 | metric_recoder = metric_class(metric_names=metric_names) 248 | # https://github.com/tqdm/tqdm#parameters 249 | # https://github.com/tqdm/tqdm/blob/master/examples/parallel_bars.py 250 | for name in tqdm(names, total=len(names), desc=desc, ncols=79, lock_args=(False,)): 251 | gt, pre = get_gt_pre_with_name( 252 | img_name=name, 253 | pre_root=pre_info_pair[0], 254 | pre_prefix=pre_info_pair[1], 255 | pre_suffix=pre_info_pair[2], 256 | gt_root=gt_info_pair[0], 257 | gt_prefix=gt_info_pair[1], 258 | gt_suffix=gt_info_pair[2], 259 | to_normalize=False, 260 | ) 261 | metric_recoder.step(pre=pre, gt=gt, gt_path=os.path.join(gt_info_pair[0], name)) 262 | 263 | method_results = metric_recoder.show(num_bits=num_bits, return_ndarray=False) 264 | return method_results 265 | -------------------------------------------------------------------------------- /metrics/video_metrics.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | from collections import defaultdict 5 | from functools import partial 6 | from multiprocessing import pool 7 | from threading import RLock as TRLock 8 | 9 | import 
numpy as np 10 | from tqdm import tqdm 11 | 12 | from utils.misc import ( 13 | get_gt_pre_with_name_and_group, 14 | get_name_with_group_list, 15 | make_dir, 16 | ) 17 | from utils.print_formatter import formatter_for_tabulate 18 | from utils.recorders import ( 19 | BINARY_METRIC_MAPPING, 20 | GRAYSCALE_METRICS, 21 | GroupedMetricRecorder, 22 | MetricExcelRecorder, 23 | TxtRecorder, 24 | ) 25 | 26 | 27 | class Recorder: 28 | def __init__( 29 | self, 30 | method_names, 31 | dataset_names, 32 | metric_names, 33 | *, 34 | txt_path, 35 | to_append, 36 | xlsx_path, 37 | sheet_name, 38 | ): 39 | self.curves = defaultdict(dict) # Two curve metrics 40 | self.metrics = defaultdict(dict) # Six numerical metrics 41 | self.method_names = method_names 42 | self.dataset_names = dataset_names 43 | 44 | self.txt_recorder = None 45 | if txt_path: 46 | self.txt_recorder = TxtRecorder( 47 | txt_path=txt_path, 48 | to_append=to_append, 49 | max_method_name_width=max([len(x) for x in method_names]), # 显示完整名字 50 | ) 51 | 52 | self.excel_recorder = None 53 | if xlsx_path: 54 | excel_metric_names = [] 55 | for x in metric_names: 56 | if x in GRAYSCALE_METRICS: 57 | if x == "em": 58 | excel_metric_names.append(f"max{x}") 59 | excel_metric_names.append(f"avg{x}") 60 | excel_metric_names.append(f"adp{x}") 61 | else: 62 | config = BINARY_METRIC_MAPPING[x] 63 | if config["kwargs"]["with_dynamic"]: 64 | excel_metric_names.append(f"max{x}") 65 | excel_metric_names.append(f"avg{x}") 66 | if config["kwargs"]["with_adaptive"]: 67 | excel_metric_names.append(f"adp{x}") 68 | else: 69 | excel_metric_names.append(x) 70 | 71 | self.excel_recorder = MetricExcelRecorder( 72 | xlsx_path=xlsx_path, 73 | sheet_name=sheet_name, 74 | row_header=["methods"], 75 | dataset_names=dataset_names, 76 | metric_names=excel_metric_names, 77 | ) 78 | 79 | def record(self, method_results, dataset_name, method_name): 80 | """Record results""" 81 | method_curves = method_results.get("sequential") 82 | method_metrics = method_results["numerical"] 83 | self.curves[dataset_name][method_name] = method_curves 84 | self.metrics[dataset_name][method_name] = method_metrics 85 | 86 | def export(self): 87 | """After evaluating all methods, export results to ensure the order of names.""" 88 | for dataset_name in self.dataset_names: 89 | if dataset_name not in self.metrics: 90 | continue 91 | 92 | for method_name in self.method_names: 93 | dataset_results = self.metrics[dataset_name] 94 | method_results = dataset_results.get(method_name) 95 | if method_results is None: 96 | continue 97 | 98 | if self.txt_recorder: 99 | self.txt_recorder.add_row(row_name="Dataset", row_data=dataset_name) 100 | self.txt_recorder(method_results=method_results, method_name=method_name) 101 | if self.excel_recorder: 102 | self.excel_recorder( 103 | row_data=method_results, dataset_name=dataset_name, method_name=method_name 104 | ) 105 | 106 | 107 | def cal_metrics( 108 | sheet_name: str = "results", 109 | txt_path: str = "", 110 | to_append: bool = True, 111 | xlsx_path: str = "", 112 | methods_info: dict = None, 113 | datasets_info: dict = None, 114 | curves_npy_path: str = "./curves.npy", 115 | metrics_npy_path: str = "./metrics.npy", 116 | num_bits: int = 3, 117 | num_workers: int = 2, 118 | metric_names: tuple = ("sm", "wfm", "mae", "avgdice", "avgiou", "adpe", "avge", "maxe"), 119 | return_group: bool = False, 120 | start_idx: int = 1, 121 | end_idx: int = -1, 122 | ): 123 | """Save the results of all models on different datasets in a `npy` file in the form of a 124 | 
dictionary.
125 | 
126 |     Args:
127 |         sheet_name (str, optional): The type of the sheet in xlsx file. Defaults to "results".
128 |         txt_path (str, optional): The path of the txt for saving results. Defaults to "".
129 |         to_append (bool, optional): Whether to append results to the original record. Defaults to True.
130 |         xlsx_path (str, optional): The path of the xlsx file for saving results. Defaults to "".
131 |         methods_info (dict, optional): The method information. Defaults to None.
132 |         datasets_info (dict, optional): The dataset information. Defaults to None.
133 |         curves_npy_path (str, optional): The npy file path for saving curve data. Defaults to "./curves.npy".
134 |         metrics_npy_path (str, optional): The npy file path for saving metric values. Defaults to "./metrics.npy".
135 |         num_bits (int, optional): The number of bits used to format results. Defaults to 3.
136 |         num_workers (int, optional): The number of workers of multiprocessing or multithreading. Defaults to 2.
137 |         metric_names (tuple, optional): Names of metrics. Defaults to ("sm", "wfm", "mae", "avgdice", "avgiou", "adpe", "avge", "maxe").
138 |         return_group (bool, optional): Whether to return the grouped results. Defaults to False.
139 |         start_idx (int, optional): The index of the first frame in each gt sequence. Defaults to 1, which skips the first frame. If it is set to None, no frames are skipped.
140 |         end_idx (int, optional): The index of the last frame in each gt sequence. Defaults to -1, which skips the last frame. If it is set to None, no frames are skipped.
141 | 
142 |     Returns:
143 |         {
144 |             dataset1:{
145 |                 method1:[fm, em, p, r],
146 |                 method2:[fm, em, p, r],
147 |                 .....
148 |             },
149 |             dataset2:{
150 |                 method1:[fm, em, p, r],
151 |                 method2:[fm, em, p, r],
152 |                 .....
153 |             },
154 |             ....
155 |         }
156 | 
157 |     """
158 |     metric_class = GroupedMetricRecorder
159 | 
160 |     method_names = tuple(methods_info.keys())
161 |     dataset_names = tuple(datasets_info.keys())
162 |     recorder = Recorder(
163 |         method_names=method_names,
164 |         dataset_names=dataset_names,
165 |         metric_names=metric_names,
166 |         txt_path=txt_path,
167 |         to_append=to_append,
168 |         xlsx_path=xlsx_path,
169 |         sheet_name=sheet_name,
170 |     )
171 | 
172 |     tqdm.set_lock(TRLock())
173 |     procs = pool.ThreadPool(processes=num_workers, initializer=tqdm.set_lock,
174 |                             initargs=(tqdm.get_lock(),))
175 |     print(f"Created a thread pool: {procs}.")
176 |     name_sep = ""
177 | 
178 |     for dataset_name, dataset_path in datasets_info.items():
179 |         # 获取真值图片信息
180 |         gt_info = dataset_path["mask"]
181 |         gt_root = gt_info["path"]
182 |         gt_prefix = gt_info.get("prefix", "")
183 |         gt_suffix = gt_info["suffix"]
184 |         # 真值名字列表
185 |         gt_index_file = dataset_path.get("index_file")
186 |         if gt_index_file:
187 |             gt_name_list = get_name_with_group_list(
188 |                 data_path=gt_index_file, name_prefix=gt_prefix, name_suffix=gt_suffix, sep=name_sep
189 |             )
190 |         else:
191 |             gt_name_list = get_name_with_group_list(
192 |                 data_path=gt_root,
193 |                 name_prefix=gt_prefix,
194 |                 name_suffix=gt_suffix,
195 |                 start_idx=start_idx,
196 |                 end_idx=end_idx,
197 |                 sep=name_sep,
198 |             )
199 |         gt_info_pair = (gt_root, gt_prefix, gt_suffix)
200 |         assert len(gt_name_list) > 0, f"there is no ground truth in {dataset_path}."
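        # NOTE: names from `get_name_with_group_list` carry a leading sequence (group)
        # id, which is split off again in `evaluate` below; per-frame scores can thus be
        # grouped by video, presumably aggregated per group by `GroupedMetricRecorder`.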
201 | 202 | # ==>> test the intersection between pre and gt for each method <<== 203 | for method_name, method_info in methods_info.items(): 204 | method_root = method_info["path_dict"] 205 | method_dataset_info = method_root.get(dataset_name, None) 206 | if method_dataset_info is None: 207 | tqdm.write(f"{method_name} does not have results on {dataset_name}") 208 | continue 209 | 210 | # 预测结果存放路径下的图片文件名字列表和扩展名称 211 | pre_prefix = method_dataset_info.get("prefix", "") 212 | pre_suffix = method_dataset_info["suffix"] 213 | pre_root = method_dataset_info["path"] 214 | pre_name_list = get_name_with_group_list( 215 | data_path=pre_root, name_prefix=pre_prefix, name_suffix=pre_suffix, sep=name_sep 216 | ) 217 | pre_info_pair = (pre_root, pre_prefix, pre_suffix) 218 | 219 | # get the intersection 220 | eval_name_list = sorted(set(gt_name_list).intersection(pre_name_list)) 221 | if len(eval_name_list) == 0: 222 | tqdm.write(f"{method_name} does not have results on {dataset_name}") 223 | continue 224 | 225 | desc = f"[{dataset_name}({len(gt_name_list)}):{method_name}({len(pre_name_list)}->{len(eval_name_list)})]" 226 | kwargs = dict( 227 | names=eval_name_list, 228 | num_bits=num_bits, 229 | pre_info_pair=pre_info_pair, 230 | gt_info_pair=gt_info_pair, 231 | metric_names=metric_names, 232 | metric_class=metric_class, 233 | return_group=return_group, 234 | sep=name_sep, 235 | desc=desc, 236 | ) 237 | callback = partial(recorder.record, dataset_name=dataset_name, method_name=method_name) 238 | procs.apply_async(func=evaluate, kwds=kwargs, callback=callback) 239 | # print(" -------------------- [DEBUG] -------------------- ") 240 | # callback(evaluate(**kwargs), dataset_name=dataset_name, method_name=method_name) 241 | procs.close() 242 | procs.join() 243 | 244 | recorder.export() 245 | if curves_npy_path: 246 | make_dir(os.path.dirname(curves_npy_path)) 247 | np.save(curves_npy_path, recorder.curves) 248 | tqdm.write(f"All curves has been saved in {curves_npy_path}") 249 | if metrics_npy_path: 250 | make_dir(os.path.dirname(metrics_npy_path)) 251 | np.save(metrics_npy_path, recorder.metrics) 252 | tqdm.write(f"All metrics has been saved in {metrics_npy_path}") 253 | formatted_string = formatter_for_tabulate(recorder.metrics, method_names, dataset_names) 254 | tqdm.write(f"All methods have been evaluated:\n{formatted_string}") 255 | 256 | 257 | def evaluate( 258 | names, 259 | num_bits, 260 | pre_info_pair, 261 | gt_info_pair, 262 | metric_class, 263 | metric_names, 264 | return_group=False, 265 | sep="", 266 | desc="", 267 | ): 268 | group_names = sorted(set([n.split(sep)[0] for n in names])) 269 | metric_recoder = metric_class(group_names=group_names, metric_names=metric_names) 270 | # https://github.com/tqdm/tqdm#parameters 271 | # https://github.com/tqdm/tqdm/blob/master/examples/parallel_bars.py 272 | tqdm_bar = tqdm(names, total=len(names), desc=desc, ncols=78, lock_args=(False,)) 273 | for name in tqdm_bar: 274 | group_name = name.split(sep)[0] 275 | gt, pre = get_gt_pre_with_name_and_group( 276 | img_name=name, 277 | pre_root=pre_info_pair[0], 278 | pre_prefix=pre_info_pair[1], 279 | pre_suffix=pre_info_pair[2], 280 | gt_root=gt_info_pair[0], 281 | gt_prefix=gt_info_pair[1], 282 | gt_suffix=gt_info_pair[2], 283 | to_normalize=False, 284 | sep=sep, 285 | ) 286 | metric_recoder.step( 287 | group_name=group_name, pre=pre, gt=gt, gt_path=os.path.join(gt_info_pair[0], name) 288 | ) 289 | 290 | # TODO: 打印的形式有待进一步完善 291 | method_results = metric_recoder.show(num_bits=num_bits, 
return_group=return_group)
292 |     return method_results
293 | 
--------------------------------------------------------------------------------
/plot.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import argparse
3 | import textwrap
4 | 
5 | import numpy as np
6 | import yaml
7 | 
8 | from metrics import draw_curves
9 | 
10 | 
11 | def get_args():
12 |     parser = argparse.ArgumentParser(
13 |         description=textwrap.dedent(
14 |             r"""
15 |             INCLUDE:
16 | 
17 |             - Fm Curve
18 |             - PR Curves
19 | 
20 |             NOTE:
21 | 
22 |             - Our method automatically calculates the intersection of `pre` and `gt`.
23 |             - Currently supported pre naming rules: `prefix + gt_name_wo_ext + suffix_w_ext`
24 | 
25 |             EXAMPLES:
26 | 
27 |             python plot.py \
28 |                 --curves-npys output/rgbd-othermethods-curves.npy output/rgbd-ours-curves.npy \  # use the information from these npy files to draw curves
29 |                 --num-rows 1 \  # set the number of rows of the figure to 1
30 |                 --style-cfg configs/single_row_style.yml \  # specify the configuration file for the style of matplotlib
31 |                 --num-col-legend 1 \  # set the number of the columns of the legend in the figure to 1
32 |                 --mode pr \  # draw `pr` curves
33 |                 --our-methods Ours \  # specify the names of our own methods, they must be contained in the npy file
34 |                 --save-name ./output/rgbd-pr-curves  # save the figure as `./output/rgbd-pr-curves.<ext_name>`, where `<ext_name>` is specified by the item `savefig.format` in the `--style-cfg`
35 | 
36 |             python plot.py \
37 |                 --curves-npys output/rgbd-othermethods-curves.npy output/rgbd-ours-curves.npy \
38 |                 --num-rows 2 \
39 |                 --style-cfg configs/two_row_style.yml \
40 |                 --num-col-legend 2 \
41 |                 --separated-legend \  # use a separated legend
42 |                 --mode fm \  # draw `fm` curves
43 |                 --our-methods OursV0 OursV1 \  # specify the names of our own methods, they must be contained in the npy file
44 |                 --save-name output/rgbd-fm \
45 |                 --alias-yaml configs/rgbd_aliases.yaml  # aliases corresponding to methods and datasets you want to use
46 |             """
47 |         ),
48 |         formatter_class=argparse.RawTextHelpFormatter,
49 |     )
50 |     # fmt: off
51 |     parser.add_argument("--alias-yaml", type=str, help="Yaml file for dataset and method aliases.")
52 |     parser.add_argument("--style-cfg", type=str, required=True, help="Yaml file for plotting curves.")
53 |     parser.add_argument("--curves-npys", required=True, type=str, nargs="+", help="Npy file for saving curve results.")
54 |     parser.add_argument("--our-methods", type=str, nargs="+", help="Names of our methods to highlight.")
55 |     parser.add_argument("--num-rows", type=int, default=1, help="Number of rows for subplots. Default: 1")
56 |     parser.add_argument("--num-col-legend", type=int, default=1, help="Number of columns in the legend. Default: 1")
57 |     parser.add_argument("--mode", type=str, choices=["pr", "fm", "em", "iou", "dice"], default="pr", help="Mode for plotting. 
Default: pr") 58 | parser.add_argument("--separated-legend", action="store_true", help="Use the separated legend.") 59 | parser.add_argument("--sharey", action="store_true", help="Use the shared y-axis.") 60 | parser.add_argument("--save-name", type=str, help="the exported file path") 61 | # fmt: on 62 | args = parser.parse_args() 63 | 64 | return args 65 | 66 | 67 | def main(args): 68 | method_aliases = dataset_aliases = None 69 | if args.alias_yaml: 70 | with open(args.alias_yaml, mode="r", encoding="utf-8") as f: 71 | aliases = yaml.safe_load(f) 72 | method_aliases = aliases.get("method") 73 | dataset_aliases = aliases.get("dataset") 74 | 75 | # TODO: Better method to set axes_setting 76 | axes_setting = { 77 | # pr curve 78 | "pr": { 79 | "x_label": "Recall", 80 | "y_label": "Precision", 81 | "x_ticks": np.linspace(0.5, 1, 6), 82 | "y_ticks": np.linspace(0.7, 1, 6), 83 | }, 84 | # fm curve 85 | "fm": { 86 | "x_label": "Threshold", 87 | "y_label": r"F$_{\beta}$", 88 | "x_ticks": np.linspace(0, 1, 6), 89 | "y_ticks": np.linspace(0.6, 1, 6), 90 | }, 91 | # em curve 92 | "em": { 93 | "x_label": "Threshold", 94 | "y_label": r"E$_{m}$", 95 | "x_ticks": np.linspace(0, 1, 6), 96 | "y_ticks": np.linspace(0.7, 1, 6), 97 | }, 98 | # iou curve 99 | "iou": { 100 | "x_label": "Threshold", 101 | "y_label": "IoU", 102 | "x_ticks": np.linspace(0, 1, 6), 103 | "y_ticks": np.linspace(0.4, 1, 6), 104 | }, 105 | # dice curve 106 | "dice": { 107 | "x_label": "Threshold", 108 | "y_label": "Dice", 109 | "x_ticks": np.linspace(0, 1, 6), 110 | "y_ticks": np.linspace(0.4, 1, 6), 111 | }, 112 | } 113 | 114 | draw_curves.draw_curves( 115 | mode=args.mode, 116 | axes_setting=axes_setting, 117 | curves_npy_path=args.curves_npys, 118 | row_num=args.num_rows, 119 | method_aliases=method_aliases, 120 | dataset_aliases=dataset_aliases, 121 | style_cfg=args.style_cfg, 122 | ncol_of_legend=args.num_col_legend, 123 | separated_legend=args.separated_legend, 124 | sharey=args.sharey, 125 | our_methods=args.our_methods, 126 | save_name=args.save_name, 127 | ) 128 | 129 | 130 | if __name__ == "__main__": 131 | args = get_args() 132 | main(args) 133 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.ruff] 2 | # Exclude a variety of commonly ignored directories. 3 | exclude = [ 4 | ".bzr", 5 | ".direnv", 6 | ".eggs", 7 | ".git", 8 | ".git-rewrite", 9 | ".hg", 10 | ".ipynb_checkpoints", 11 | ".mypy_cache", 12 | ".nox", 13 | ".pants.d", 14 | ".pyenv", 15 | ".pytest_cache", 16 | ".pytype", 17 | ".ruff_cache", 18 | ".svn", 19 | ".tox", 20 | ".venv", 21 | ".vscode", 22 | "__pypackages__", 23 | "_build", 24 | "buck-out", 25 | "build", 26 | "dist", 27 | "node_modules", 28 | "site-packages", 29 | "venv", 30 | ] 31 | 32 | # Same as Black. 33 | line-length = 99 34 | indent-width = 4 35 | 36 | [tool.ruff.lint] 37 | # Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default. 38 | select = ["E4", "E7", "E9", "F"] 39 | ignore = [] 40 | 41 | # Allow fix for all enabled rules (when `--fix`) is provided. 42 | fixable = ["ALL"] 43 | unfixable = [] 44 | 45 | # Allow unused variables when underscore-prefixed. 46 | dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" 47 | 48 | [tool.ruff.format] 49 | # Like Black, use double quotes for strings. 50 | quote-style = "double" 51 | 52 | # Like Black, indent with spaces, rather than tabs. 
53 | indent-style = "space"
54 | 
55 | # Like Black, respect magic trailing commas.
56 | skip-magic-trailing-comma = false
57 | 
58 | # Like Black, automatically detect the appropriate line ending.
59 | line-ending = "auto"
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | 
2 | # A Python-based image grayscale/binary segmentation evaluation toolbox.
3 | 
4 | [中文文档](./readme_zh.md)
5 | 
6 | ## TODO
7 | 
8 | - More flexible configuration script.
9 |   - [ ] Use a yaml file that [meets matplotlib requirements](https://matplotlib.org/stable/tutorials/introductory/customizing.html#the-default-matplotlibrc-file) to control the drawing format.
10 |   - [ ] Replace the json with a more flexible configuration format, such as yaml or toml.
11 | - [ ] Add test scripts.
12 | - [ ] Add more detailed comments.
13 | - Optimize the code for exporting evaluation results.
14 |   - [x] Implement code to export results to XLSX files.
15 |   - [ ] Optimize the code for exporting to XLSX files.
16 |   - [ ] Consider whether a text format like CSV would be better: it can be opened as a plain text file and also organized with Excel.
17 | - [ ] Replace `os.path` with `pathlib.Path`.
18 | - [x] Improve the code for grouping data, supporting tasks like CoSOD, Video Binary Segmentation, etc.
19 | - [x] Support a concurrency strategy to speed up computation. Multi-threading support is retained; the previous multi-process code has been removed.
20 |   - [ ] Currently, due to the use of multi-threading, extra log information may be written, which needs further optimization.
21 | - [X] Separate USVOS code into another repository [PyDavis16EvalToolbox](https://github.com/lartpang/PyDavis16EvalToolbox).
22 | - [X] Use the more rapid and accurate metric code [PySODMetrics](https://github.com/lartpang/PySODMetrics) as the evaluation benchmark.
23 | 
24 | > [!tip]
25 | > - Some methods provide result names that do not match the original dataset's ground truth names.
26 | >   - [Note] (2021-11-18) Currently, support is provided for both `prefix` and `suffix` names, so users generally do not need to change the names themselves.
27 | >   - [Optional] The provided script `tools/rename.py` can be used to rename files in bulk. *Please use it carefully to avoid data overwriting.*
28 | >   - [Optional] Other tools, such as `rename` on Linux, and [`Microsoft PowerToys`](https://github.com/microsoft/PowerToys) on Windows.
29 | 
30 | ## Features
31 | 
32 | - Benefiting from PySODMetrics, it supports a richer set of metrics. For more details, see `utils/recorders/metric_recorder.py`.
33 | - Supports evaluating *grayscale images*, such as predictions from salient object detection (SOD) and camouflaged object detection (COD) tasks.
34 |   - MAE
35 |   - Emeasure
36 |   - Smeasure
37 |   - Weighted Fmeasure
38 |   - Maximum/Average/Adaptive Fmeasure
39 |   - Maximum/Average/Adaptive Precision
40 |   - Maximum/Average/Adaptive Recall
41 |   - Maximum/Average/Adaptive IoU
42 |   - Maximum/Average/Adaptive Dice
43 |   - Maximum/Average/Adaptive Specificity
44 |   - Maximum/Average/Adaptive BER
45 |   - Fmeasure-Threshold Curve (run `eval.py` with the metric `fmeasure`)
46 |   - Emeasure-Threshold Curve (run `eval.py` with the metric `em`)
47 |   - Precision-Recall Curve (run `eval.py` with the metrics `precision` and `recall`; this differs from previous versions, as the calculation of `precision` and `recall` has been separated from `fmeasure`)
48 | - Supports evaluating *binary images*, such as common binary segmentation tasks.
49 |   - Binary Fmeasure
50 |   - Binary Precision
51 |   - Binary Recall
52 |   - Binary IoU
53 |   - Binary Dice
54 |   - Binary Specificity
55 |   - Binary BER
56 | - Richer functions.
57 |   - Supports evaluating models according to the configuration.
58 |   - Supports drawing `PR curves`, `F-measure curves` and `E-measure curves` based on configuration and evaluation results.
59 |   - Supports exporting results to TXT files.
60 |   - Supports exporting results to XLSX files (re-supported on January 4, 2021).
61 |   - Supports exporting LaTeX table code from generated `.npy` files, and marks the top three methods with different colors.
62 |   - … :>.
63 | 
64 | ## How to Use
65 | 
66 | ### Installing Dependencies
67 | 
68 | Install the required libraries: `pip install -r requirements.txt`
69 | 
70 | The metric evaluation is based on another project of mine: [PySODMetrics](https://github.com/lartpang/PySODMetrics). Bug reports are welcome!
71 | 
72 | ### Configuring Paths for Datasets and Method Predictions
73 | 
74 | This project relies on json files to store data. Examples for dataset and method configurations are provided in `./examples`: `config_dataset_json_example.json` and `config_method_json_example.json`. You can directly modify them for subsequent steps.
75 | 
76 | > [!note]
77 | > - Please note that since this project relies on OpenCV to read images, ensure that the path strings do not contain non-ASCII characters.
78 | > - Make sure that *the name of the dataset in the dataset configuration file* matches *the name of the dataset in the method configuration file*. After preparing the json files, it is recommended to use the provided `tools/check_path.py` to check if the path information in the json files is correct.
79 | 
80 | 
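If you want a quick sanity check against the non-ASCII pitfall mentioned above before running `tools/check_path.py`, a few lines like the following are enough. This is a standalone sketch, not part of the toolbox itself; it only assumes the example file shipped in `./examples`:

```python
import json

def warn_non_ascii(node, trail=""):
    """Recursively report non-ASCII strings in a loaded config."""
    if isinstance(node, dict):
        for key, value in node.items():
            warn_non_ascii(value, f"{trail}/{key}")
    elif isinstance(node, str) and not node.isascii():
        print(f"non-ASCII string at {trail}: {node!r}")

with open("examples/config_dataset_json_example.json", encoding="utf-8") as f:
    warn_non_ascii(json.load(f))
```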
81 | 82 | More Details on Configuration 83 | 84 | 85 | Example 1: Dataset Configuration 86 | 87 | Note, "image" is not necessary here. The actual evaluation only reads "mask". 88 | 89 | ```json 90 | { 91 | "LFSD": { 92 | "image": { 93 | "path": "Path_Of_RGBDSOD_Datasets/LFSD/Image", 94 | "prefix": "some_gt_prefix", 95 | "suffix": ".jpg" 96 | }, 97 | "mask": { 98 | "path": "Path_Of_RGBDSOD_Datasets/LFSD/Mask", 99 | "prefix": "some_gt_prefix", 100 | "suffix": ".png" 101 | } 102 | } 103 | } 104 | ``` 105 | 106 | Example 2: Method Configuration 107 | 108 | ```json 109 | { 110 | "Method1": { 111 | "PASCAL-S": { 112 | "path": "Path_Of_Method1/PASCAL-S", 113 | "prefix": "some_method_prefix", 114 | "suffix": ".png" 115 | }, 116 | "ECSSD": { 117 | "path": "Path_Of_Method1/ECSSD", 118 | "prefix": "some_method_prefix", 119 | "suffix": ".png" 120 | }, 121 | "HKU-IS": { 122 | "path": "Path_Of_Method1/HKU-IS", 123 | "prefix": "some_method_prefix", 124 | "suffix": ".png" 125 | }, 126 | "DUT-OMRON": { 127 | "path": "Path_Of_Method1/DUT-OMRON", 128 | "prefix": "some_method_prefix", 129 | "suffix": ".png" 130 | }, 131 | "DUTS-TE": { 132 | "path": "Path_Of_Method1/DUTS-TE", 133 | "suffix": ".png" 134 | } 135 | } 136 | } 137 | ``` 138 | 139 | Here, `path` represents the directory where image data is stored. `prefix` and `suffix` refer to the prefix and suffix *outside the common part in the names* of the predicted images and the actual ground truth images. 140 | 141 | During the evaluation process, the matching of method predictions and dataset ground truths is based on the shared part of the file names. Their naming patterns are preset as `[prefix]+[shared-string]+[suffix]`. For example, if there are predicted images like `method1_00001.jpg`, `method1_00002.jpg`, `method1_00003.jpg` and ground truth images `gt_00001.png`, `gt_00002.png`, `gt_00003.png`, then we can configure it as follows: 142 | 143 | Example 3: Dataset Configuration 144 | 145 | ```json 146 | { 147 | "dataset1": { 148 | "mask": { 149 | "path": "path/Mask", 150 | "prefix": "gt_", 151 | "suffix": ".png" 152 | } 153 | } 154 | } 155 | ``` 156 | 157 | Example 4: Method Configuration 158 | 159 | ```json 160 | { 161 | "method1": { 162 | "dataset1": { 163 | "path": "path/dataset1", 164 | "prefix": "method1_", 165 | "suffix": ".jpg" 166 | } 167 | } 168 | } 169 | ``` 170 | 171 |
172 | 173 | ### Running the Evaluation 174 | 175 | - Once all the previous steps are correctly completed, you can begin the evaluation. For usage of the evaluation script, refer to the output of the command `python eval.py --help`. 176 | - Add configuration options according to your needs and execute the command. If there are no exceptions, it will generate result files with the specified filename. 177 | - If not all files are specified, it will directly output the results, as detailed in the help information of `eval.py`. 178 | - If `--curves-npy` is specified, the metrics information related to drawing will be saved in the corresponding `.npy` file. 179 | - [Optional] You can use `tools/converter.py` to directly export the LaTeX table code from the generated npy files. 180 | 181 | ### Plotting Curves for Grayscale Image Evaluation 182 | 183 | You can use `plot.py` to read the `.npy` file to organize and draw `PR`, `F-measure`, and `E-measure` curves for specified methods and datasets as needed. The usage of this script can be seen in the output of `python plot.py --help`. Add configuration items as per your requirement and execute the command. 184 | 185 | The most basic instruction is to specify the values in the `figure.figsize` item in the configuration file according to the number of subplots reasonably. 186 | 187 | ### A Basic Execution Process 188 | 189 | Here I'll use the RGB SOD configuration in my local configs folder as an example (necessary modifications should be made according to the actual situation). 190 | 191 | ```shell 192 | # Check Configuration Files 193 | python tools/check_path.py --method-jsons configs/methods/rgb-sod/rgb_sod_methods.json --dataset-jsons configs/datasets/rgb_sod.json 194 | 195 | # After ensuring there's nothing unreasonable in the output information, you can begin the evaluation with the following commands: 196 | # --dataset-json: Set `configs/datasets/rgb_sod.json` as dataset configuration file 197 | # --method-json: Set `configs/methods/rgb-sod/rgb_sod_methods.json` as method configuration file 198 | # --metric-npy: Set `output/rgb_sod/metrics.npy` to store the metrics information in npy format 199 | # --curves-npy: Set `output/rgb_sod/curves.npy` to store the curves information in npy format 200 | # --record-txt: Set `output/rgb_sod/results.txt` to store the results information in text format 201 | # --record-xlsx: Set `output/rgb_sod/results.xlsx` to store the results information in Excel format 202 | # --metric-names: Specify `fmeasure em precision recall` as the metrics to be calculated 203 | # --include-methods: Specify the methods from `configs/methods/rgb-sod/rgb_sod_methods.json` to be evaluated 204 | # --include-datasets: Specify the datasets from `configs/datasets/rgb_sod.json` to be evaluated 205 | python eval.py --dataset-json configs/datasets/rgb_sod.json --method-json configs/methods/rgb-sod/rgb_sod_methods.json --metric-npy output/rgb_sod/metrics.npy --curves-npy output/rgb_sod/curves.npy --record-txt output/rgb_sod/results.txt --record-xlsx output/rgb_sod/results.xlsx --metric-names sm wfm mae fmeasure em precision recall --include-methods MINet_R50_2020 GateNet_2020 --include-datasets PASCAL-S ECSSD 206 | 207 | # Once you've obtained the curve data file, which in this case is the 'output/rgb_sod/curves.npy' file, you can start drawing the plot. 
208 | 
209 | # For a simple example, after executing the command below, the result will be saved as 'output/rgb_sod/simple_curve_pr.pdf':
210 | # --style-cfg: Specify the style configuration file `examples/single_row_style.yml`; since there are only a few subplots, you can directly use a single-row configuration.
211 | # --num-rows: The number of rows of subplots in the figure.
212 | # --curves-npys: Use the curve data file `output/rgb_sod/curves.npy` to draw the plot.
213 | # --mode: Use `pr` to draw the `pr` curve, `em` to draw the `E-measure` curve, and `fm` to draw the `F-measure` curve.
214 | # --save-name: Just provide the image save path without the file extension; the code will append the file extension as specified by the `savefig.format` in the `--style-cfg` you designated earlier.
215 | # --alias-yaml: A yaml file that specifies the method and dataset aliases to be used in the plot.
216 | python plot.py --style-cfg examples/single_row_style.yml --num-rows 1 --curves-npys output/rgb_sod/curves.npy --mode pr --save-name output/rgb_sod/simple_curve_pr --alias-yaml configs/rgb_aliases.yaml
217 | 
218 | # A more complex example: after executing the command below, the result will be saved as 'output/rgb_sod/complex_curve_pr.pdf'.
219 | 
220 | # --style-cfg: Specify the style configuration file `examples/single_row_style.yml`; since there are only a few subplots, you can directly use a single-row configuration.
221 | # --num-rows: The number of rows of subplots in the figure.
222 | # --curves-npys: Use the curve data file `output/rgb_sod/curves.npy` to draw the plot.
223 | # --our-methods: The specified method, `MINet_R50_2020`, is highlighted with a bold red solid line in the plot.
224 | # --num-col-legend: The number of columns in the legend.
225 | # --mode: Use `pr` to draw the `pr` curve, `em` to draw the `E-measure` curve, and `fm` to draw the `F-measure` curve.
226 | # --separated-legend: Draw a shared single legend.
227 | # --sharey: Share the y-axis, which will only display the scale values on the first graph in each row.
228 | # --save-name: Just provide the image save path without the file extension; the code will append the file extension as specified by the `savefig.format` in the `--style-cfg` you designated earlier.
229 | python plot.py --style-cfg examples/single_row_style.yml --num-rows 1 --curves-npys output/rgb_sod/curves.npy --our-methods MINet_R50_2020 --num-col-legend 1 --mode pr --separated-legend --sharey --save-name output/rgb_sod/complex_curve_pr 230 | ``` 231 | 232 | ## Corresponding Results 233 | 234 | **Precision-Recall Curve**: 235 | 236 | ![PRCurves](https://user-images.githubusercontent.com/26847524/227249768-a41ef076-6355-4b96-a291-fc0e071d9d35.jpg) 237 | 238 | **F-measure Curve**: 239 | 240 | ![fm-curves](https://user-images.githubusercontent.com/26847524/227249746-f61d7540-bb73-464d-bccf-9a36323dec47.jpg) 241 | 242 | **E-measure Curve**: 243 | 244 | ![em-curves](https://user-images.githubusercontent.com/26847524/227249727-8323d5cf-ddd7-427b-8152-b8f47781c4e3.jpg) 245 | 246 | ## Programming Reference 247 | 248 | * `openpyxl` library: 249 | * `re` module: 250 | 251 | ## Relevant Literature 252 | 253 | ```text 254 | @inproceedings{Fmeasure, 255 | title={Frequency-tuned salient region detection}, 256 | author={Achanta, Radhakrishna and Hemami, Sheila and Estrada, Francisco and S{\"u}sstrunk, Sabine}, 257 | booktitle=CVPR, 258 | number={CONF}, 259 | pages={1597--1604}, 260 | year={2009} 261 | } 262 | 263 | @inproceedings{MAE, 264 | title={Saliency filters: Contrast based filtering for salient region detection}, 265 | author={Perazzi, Federico and Kr{\"a}henb{\"u}hl, Philipp and Pritch, Yael and Hornung, Alexander}, 266 | booktitle=CVPR, 267 | pages={733--740}, 268 | year={2012} 269 | } 270 | 271 | @inproceedings{Smeasure, 272 | title={Structure-measure: A new way to eval foreground maps}, 273 | author={Fan, Deng-Ping and Cheng, Ming-Ming and Liu, Yun and Li, Tao and Borji, Ali}, 274 | booktitle=ICCV, 275 | pages={4548--4557}, 276 | year={2017} 277 | } 278 | 279 | @inproceedings{Emeasure, 280 | title="Enhanced-alignment Measure for Binary Foreground Map Evaluation", 281 | author="Deng-Ping {Fan} and Cheng {Gong} and Yang {Cao} and Bo {Ren} and Ming-Ming {Cheng} and Ali {Borji}", 282 | booktitle=IJCAI, 283 | pages="698--704", 284 | year={2018} 285 | } 286 | 287 | @inproceedings{wFmeasure, 288 | title={How to eval foreground maps?}, 289 | author={Margolin, Ran and Zelnik-Manor, Lihi and Tal, Ayellet}, 290 | booktitle=CVPR, 291 | pages={248--255}, 292 | year={2014} 293 | } 294 | ``` 295 | -------------------------------------------------------------------------------- /readme_zh.md: -------------------------------------------------------------------------------- 1 | 2 | # 基于 Python 的图像灰度/二值分割测评工具箱 3 | 4 | ## 一些规划 5 | 6 | - 更灵活的配置脚本. 7 | - [ ] 使用 [符合matplotlib要求的](https://matplotlib.org/stable/tutorials/introductory/customizing.html#the-default-matplotlibrc-file) 的 yaml 文件来控制绘图格式. 8 | - [ ] 使用更加灵活的配置文件格式, 例如 yaml 或者 toml 替换 json. 9 | - [ ] 添加测试脚本. 10 | - [ ] 添加更详细的注释. 11 | - 优化导出评估结果的代码. 12 | - [x] 实现导出结果到 XLSX 文件的代码. 13 | - [ ] 优化导出到 XLSX 文件的代码. 14 | - [ ] 是否应该使用 CSV 这样的文本格式更好些? 既可以当做文本文件打开, 亦可使用 Excel 来进行整理. 15 | - [ ] 使用 `pathlib.Path` 替换 `os.path`. 16 | - [x] 完善关于分组数据的代码, 即 CoSOD、Video Binary Segmentation 等任务的支持. 17 | - [x] 支持并发策略加速计算. 目前保留了多线程支持, 剔除了之前的多进程代码. 18 | - [ ] 目前由于多线程的使用, 存在提示信息额外写入的问题, 有待优化. 19 | - [X] 剥离 USVOS 代码到另一个仓库 [PyDavis16EvalToolbox](https://github.com/lartpang/PyDavis16EvalToolbox). 20 | - [X] 使用更加快速和准确的指标代码 [PySODMetrics](https://github.com/lartpang/PySODMetrics) 作为评估基准. 21 | 22 | > [!tip] 23 | > - 一些方法提供的结果名字预原始数据集真值的名字不一致 24 | > - [注意] (2021-11-18) 当前同时提供了对名称前缀与后缀的支持, 所以基本不用用户自己改名字了. 
25 | > - [可选] 可以使用提供的脚本 `tools/rename.py` 来批量修改文件名.**请小心使用, 以避免数据被覆盖.** 26 | > - [可选] 其他的工具: 例如 Linux 上的 `rename`, Windows 上的 [`Microsoft PowerToys`](https://github.com/microsoft/PowerToys) 27 | 28 | ## 特性 29 | 30 | - 受益于 PySODMetrics, 从而获得了更加丰富的指标的支持. 更多细节可见 `utils/recorders/metric_recorder.py`. 31 | - 支持评估*灰度图像*, 例如来自显著性目标检测任务的预测. 32 | - MAE 33 | - Emeasure 34 | - Smeasure 35 | - Weighted Fmeasure 36 | - Maximum/Average/Adaptive Fmeasure 37 | - Maximum/Average/Adaptive Precision 38 | - Maximum/Average/Adaptive Recall 39 | - Maximum/Average/Adaptive IoU 40 | - Maximum/Average/Adaptive Dice 41 | - Maximum/Average/Adaptive Specificity 42 | - Maximum/Average/Adaptive BER 43 | - Fmeasure-Threshold Curve (执行 `eval.py` 请指定指标 `fmeasure`) 44 | - Emeasure-Threshold Curve (执行 `eval.py` 请指定指标 `em`) 45 | - Precision-Recall Curve (执行 `eval.py` 请指定指标 `precision` 和 `recall`,这一点不同于以前的版本,因为 `precision` 和 `recall` 的计算被从 `fmeasure` 中独立出来了) 46 | - 支持评估*二值图像*, 例如常见的二值分割任务. 47 | - Binary Fmeasure 48 | - Binary Precision 49 | - Binary Recall 50 | - Binary IoU 51 | - Binary Dice 52 | - Binary Specificity 53 | - Binary BER 54 | - 更丰富的功能. 55 | - 支持根据配置评估模型. 56 | - 支持根据配置和评估结果绘制 PR 曲线和 F-measure 曲线. 57 | - 支持导出结果到 TXT 文件中. 58 | - 支持导出结果到 XLSX 文件 (2021 年 01 月 04 日重新提供支持). 59 | - 支持从生成的 `.npy` 文件导出 LaTeX 表格代码, 同时支持对最优的前三个方法用不同颜色进行标记. 60 | - … :>. 61 | 62 | ## 使用方法 63 | 64 | ### 安装依赖 65 | 66 | 安装相关依赖库: `pip install -r requirements.txt` . 67 | 68 | 其中指标库是我的另一个项目: [PySODMetrics](https://github.com/lartpang/PySODMetrics), 欢迎捉 BUG! 69 | 70 | ### 配置数据集与方法预测的路径信息 71 | 72 | 本项目依赖于 json 文件存放数据, `./examples` 中已经提供了数据集和方法配置的例子: `config_dataset_json_example.json` 和 `config_method_json_example.json` , 可以至直接修改他们用于后续步骤. 73 | 74 | > [!note] 75 | > - 请注意, 由于本项目依赖于 OpenCV 读取图片, 所以请确保路径字符串不包含非 ASCII 字符. 76 | > - 请务必确保*数据集配置文件中数据集的名字*和*方法配置文件中数据集的名字*一致. 准备好 json 文件后, 建议使用提供的 `tools/check_path.py` 来检查下 json 文件中的路径信息是否正常. 77 | 78 |
79 | 80 | 关于配置的更多细节 81 | 82 | 83 | 例子 1: 数据集配置 84 | 85 | 注意, 这里的 "image" 并不是必要的. 实际评估仅仅读取 "mask". 86 | 87 | ```json 88 | { 89 | "LFSD": { 90 | "image": { 91 | "path": "Path_Of_RGBDSOD_Datasets/LFSD/Image", 92 | "prefix": "some_gt_prefix", 93 | "suffix": ".jpg" 94 | }, 95 | "mask": { 96 | "path": "Path_Of_RGBDSOD_Datasets/LFSD/Mask", 97 | "prefix": "some_gt_prefix", 98 | "suffix": ".png" 99 | } 100 | } 101 | } 102 | ``` 103 | 104 | 例子 2: 方法配置 105 | 106 | ```json 107 | { 108 | "Method1": { 109 | "PASCAL-S": { 110 | "path": "Path_Of_Method1/PASCAL-S", 111 | "prefix": "some_method_prefix", 112 | "suffix": ".png" 113 | }, 114 | "ECSSD": { 115 | "path": "Path_Of_Method1/ECSSD", 116 | "prefix": "some_method_prefix", 117 | "suffix": ".png" 118 | }, 119 | "HKU-IS": { 120 | "path": "Path_Of_Method1/HKU-IS", 121 | "prefix": "some_method_prefix", 122 | "suffix": ".png" 123 | }, 124 | "DUT-OMRON": { 125 | "path": "Path_Of_Method1/DUT-OMRON", 126 | "prefix": "some_method_prefix", 127 | "suffix": ".png" 128 | }, 129 | "DUTS-TE": { 130 | "path": "Path_Of_Method1/DUTS-TE", 131 | "suffix": ".png" 132 | } 133 | } 134 | } 135 | ``` 136 | 137 | 这里 `path` 表示存放图像数据的目录. 而 `prefix` 和 `suffix` 表示实际预测图像和真值图像*名字中除去共有部分外*的前缀预后缀内容. 138 | 139 | 评估过程中, 方法预测和数据集真值匹配的方式是基于文件名字的共有部分. 二者的名字模式预设为 `[prefix]+[shared-string]+[suffix]` . 例如假如有这样的预测图像 `method1_00001.jpg` , `method1_00002.jpg` , `method1_00003.jpg` 和真值图像 `gt_00001.png` , `gt_00002.png` , `gt_00003.png` . 则我们可以配置如下: 140 | 141 | 例子 3: 数据集配置 142 | 143 | ```json 144 | { 145 | "dataset1": { 146 | "mask": { 147 | "path": "path/Mask", 148 | "prefix": "gt_", 149 | "suffix": ".png" 150 | } 151 | } 152 | } 153 | ``` 154 | 155 | 例子 4: 方法配置 156 | 157 | ```json 158 | { 159 | "method1": { 160 | "dataset1": { 161 | "path": "path/dataset1", 162 | "prefix": "method1_", 163 | "suffix": ".jpg" 164 | } 165 | } 166 | } 167 | ``` 168 | 169 |
170 | 171 | ### 执行评估过程 172 | 173 | - 前述步骤一切正常后, 可以开始评估了. 评估脚本用法可参考命令 `python eval.py --help` 的输出. 174 | - 根据自己需求添加配置项并执行即可. 如无异常, 会生成指定文件名的结果文件. 175 | - 如果不指定所有的文件, 那么就直接输出结果, 具体可见 `eval.py` 的帮助信息. 176 | - 如指定 `--curves-npy`, 绘图相关的指标信息将会保存到对应的 `.npy` 文件中. 177 | - [可选] 可以使用 `tools/converter.py` 直接从生成的 npy 文件中导出 latex 表格代码. 178 | 179 | ### 为灰度图像的评估绘制曲线 180 | 181 | 可以使用 `plot.py` 来读取 `.npy` 文件按需对指定方法和数据集的结果整理并绘制 `PR` , `F-measure` 和 `E-measure` 曲线. 该脚本用法可见 `python plot.py --help` 的输出. 按照自己需求添加配置项并执行即可. 182 | 183 | 最基本的一条是请按照子图数量, 合理地指定配置文件中的 `figure.figsize` 项的数值. 184 | 185 | ### 一个基本的执行流程 186 | 187 | 这里以我自己本地的 configs 文件夹中的 RGB SOD 的配置 (需要根据实际情况进行必要的修改) 为例. 188 | 189 | ```shell 190 | # 检查配置文件 191 | python tools/check_path.py --method-jsons configs/methods/rgb-sod/rgb_sod_methods.json --dataset-jsons configs/datasets/rgb_sod.json 192 | 193 | # 在输出信息中没有不合理的地方后,开始进行评估 194 | # --dataset-json 数据集配置文件 configs/datasets/rgb_sod.json 195 | # --method-json 方法配置文件 configs/methods/rgb-sod/rgb_sod_methods.json 196 | # --metric-npy 输出评估结果数据到 output/rgb_sod/metrics.npy 197 | # --curves-npy 输出曲线数据到 output/rgb_sod/curves.npy 198 | # --record-txt 输出评估结果文本到 output/rgb_sod/results.txt 199 | # --record-xlsx 输出评估结果到excel文档 output/rgb_sod/results.xlsx 200 | # --metric-names 所有结果仅包含给定指标的信息, 涉及到曲线的四个指标分别为 fmeasure em precision recall 201 | # --include-methods 评估过程仅包含 configs/methods/rgb-sod/rgb_sod_methods.json 中的给定方法 202 | # --include-datasets 评估过程仅包含 configs/datasets/rgb_sod.json 中的给定数据集 203 | python eval.py --dataset-json configs/datasets/rgb_sod.json --method-json configs/methods/rgb-sod/rgb_sod_methods.json --metric-npy output/rgb_sod/metrics.npy --curves-npy output/rgb_sod/curves.npy --record-txt output/rgb_sod/results.txt --record-xlsx output/rgb_sod/results.xlsx --metric-names sm wfm mae fmeasure em precision recall --include-methods MINet_R50_2020 GateNet_2020 --include-datasets PASCAL-S ECSSD 204 | 205 | # 得到曲线数据文件,即这里的 output/rgb_sod/curves.npy 文件后,就可以开始绘制图像了 206 | 207 | # 简单的例子,下面指令执行后,结果保存为 output/rgb_sod/simple_curve_pr.pdf 208 | # --style-cfg 使用图像风格配置文件 examples/single_row_style.yml,这里子图较少,直接使用单行的配置 209 | # --num-rows 图像子图都位于一行 210 | # --curves-npys 将使用曲线数据文件 output/rgb_sod/curves.npy 来绘图 211 | # --mode pr: 绘制是pr曲线;fm: 绘制的是fm曲线 212 | # --save-name 图像保存路径,只需写出名字,代码会加上由前面指定的 --style-cfg 中的 `savefig.format` 项指定的格式后缀名 213 | # --alias-yaml: 使用 yaml 文件指定绘图中使用的方法别名和数据集别名 214 | python plot.py --style-cfg examples/single_row_style.yml --num-rows 1 --curves-npys output/rgb_sod/curves.npy --mode pr --save-name output/rgb_sod/simple_curve_pr --alias-yaml configs/rgb_aliases.yaml 215 | 216 | # 复杂的例子,下面指令执行后,结果保存为 output/rgb_sod/complex_curve_pr.pdf 217 | 218 | # --style-cfg 使用图像风格配置文件 examples/single_row_style.yml,这里子图较少,直接使用单行的配置 219 | # --num-rows 图像子图都位于一行 220 | # --curves-npys 将使用曲线数据文件 output/rgb_sod/curves.npy 来绘图 221 | # --our-methods 在图中使用红色实线加粗标注指定的方法 MINet_R50_2020 222 | # --num-col-legend 图像子图图示中信息的列数 223 | # --mode pr: 绘制是pr曲线;fm: 绘制的是fm曲线 224 | # --separated-legend 使用共享的单个图示 225 | # --sharey 使用共享的 y 轴刻度,这将仅在每行的第一个图上显示刻度值 226 | # --save-name 图像保存路径,只需写出名字,代码会加上由前面指定的 --style-cfg 中的 `savefig.format` 项指定的格式后缀名 227 | python plot.py --style-cfg examples/single_row_style.yml --num-rows 1 --curves-npys output/rgb_sod/curves.npy --our-methods MINet_R50_2020 --num-col-legend 1 --mode pr --separated-legend --sharey --save-name output/rgb_sod/complex_curve_pr 228 | ``` 229 | 230 | ## 绘图示例 231 | 232 | **Precision-Recall Curve**: 233 | 234 | 
![PRCurves](https://user-images.githubusercontent.com/26847524/227249768-a41ef076-6355-4b96-a291-fc0e071d9d35.jpg) 235 | 236 | **F-measure Curve**: 237 | 238 | ![fm-curves](https://user-images.githubusercontent.com/26847524/227249746-f61d7540-bb73-464d-bccf-9a36323dec47.jpg) 239 | 240 | **E-measure Curve**: 241 | 242 | ![em-curves](https://user-images.githubusercontent.com/26847524/227249727-8323d5cf-ddd7-427b-8152-b8f47781c4e3.jpg) 243 | 244 | ## 编程参考 245 | 246 | - `openpyxl` 库: 247 | - `re` 模块: 248 | 249 | ## 相关文献 250 | 251 | ```text 252 | @inproceedings{Fmeasure, 253 | title={Frequency-tuned salient region detection}, 254 | author={Achanta, Radhakrishna and Hemami, Sheila and Estrada, Francisco and S{\"u}sstrunk, Sabine}, 255 | booktitle=CVPR, 256 | number={CONF}, 257 | pages={1597--1604}, 258 | year={2009} 259 | } 260 | 261 | @inproceedings{MAE, 262 | title={Saliency filters: Contrast based filtering for salient region detection}, 263 | author={Perazzi, Federico and Kr{\"a}henb{\"u}hl, Philipp and Pritch, Yael and Hornung, Alexander}, 264 | booktitle=CVPR, 265 | pages={733--740}, 266 | year={2012} 267 | } 268 | 269 | @inproceedings{Smeasure, 270 | title={Structure-measure: A new way to eval foreground maps}, 271 | author={Fan, Deng-Ping and Cheng, Ming-Ming and Liu, Yun and Li, Tao and Borji, Ali}, 272 | booktitle=ICCV, 273 | pages={4548--4557}, 274 | year={2017} 275 | } 276 | 277 | @inproceedings{Emeasure, 278 | title="Enhanced-alignment Measure for Binary Foreground Map Evaluation", 279 | author="Deng-Ping {Fan} and Cheng {Gong} and Yang {Cao} and Bo {Ren} and Ming-Ming {Cheng} and Ali {Borji}", 280 | booktitle=IJCAI, 281 | pages="698--704", 282 | year={2018} 283 | } 284 | 285 | @inproceedings{wFmeasure, 286 | title={How to eval foreground maps?}, 287 | author={Margolin, Ran and Zelnik-Manor, Lihi and Tal, Ayellet}, 288 | booktitle=CVPR, 289 | pages={248--255}, 290 | year={2014} 291 | } 292 | ``` 293 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # Automatically generated by https://github.com/damnever/pigar. 2 | 3 | matplotlib 4 | numpy 5 | opencv-python 6 | openpyxl 7 | pysodmetrics==1.4.2 # Our Metric Libirary 8 | PyYAML==6.0 9 | tabulate 10 | tqdm 11 | -------------------------------------------------------------------------------- /tools/append_results.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import argparse 3 | 4 | import numpy as np 5 | 6 | 7 | def get_args(): 8 | parser = argparse.ArgumentParser( 9 | description="""A simple tool for merging two npy file. 10 | Patch the method items corresponding to the `--method-names` and `--dataset-names` of `--new-npy` into `--old-npy`, 11 | and output the whole container to `--out-npy`. 
12 | """ 13 | ) 14 | parser.add_argument("--old-npy", type=str, required=True) 15 | parser.add_argument("--new-npy", type=str, required=True) 16 | parser.add_argument("--method-names", type=str, nargs="+") 17 | parser.add_argument("--dataset-names", type=str, nargs="+") 18 | parser.add_argument("--out-npy", type=str, required=True) 19 | args = parser.parse_args() 20 | return args 21 | 22 | 23 | def main(): 24 | args = get_args() 25 | new_npy: dict = np.load(args.new_npy, allow_pickle=True).item() 26 | old_npy: dict = np.load(args.old_npy, allow_pickle=True).item() 27 | 28 | for dataset_name, methods_info in new_npy.items(): 29 | if args.dataset_names and dataset_name not in args.dataset_names: 30 | continue 31 | 32 | print(f"[PROCESSING INFORMATION ABOUT DATASET {dataset_name}...]") 33 | old_methods_info = old_npy.get(dataset_name) 34 | if not old_methods_info: 35 | raise KeyError(f"{old_npy} doesn't contain the information about {dataset_name}.") 36 | 37 | print(f"OLD_NPY: {list(old_methods_info.keys())}") 38 | print(f"NEW_NPY: {list(methods_info.keys())}") 39 | 40 | for method_name, method_info in methods_info.items(): 41 | if args.method_names and method_name not in args.method_names: 42 | continue 43 | 44 | if method_name not in old_npy[dataset_name]: 45 | old_methods_info[method_name] = method_info 46 | print(f"MERGED_NPY: {list(old_methods_info.keys())}") 47 | 48 | np.save(file=args.out_npy, arr=old_npy) 49 | print(f"THE MERGED_NPY IS SAVED INTO {args.out_npy}") 50 | 51 | 52 | if __name__ == "__main__": 53 | main() 54 | -------------------------------------------------------------------------------- /tools/check_path.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import argparse 4 | import json 5 | import os 6 | from collections import OrderedDict 7 | 8 | parser = argparse.ArgumentParser(description="A simple tool for checking your json config file.") 9 | parser.add_argument( 10 | "-m", "--method-jsons", nargs="+", required=True, help="The json file about all methods." 11 | ) 12 | parser.add_argument( 13 | "-d", "--dataset-jsons", nargs="+", required=True, help="The json file about all datasets." 
14 | ) 15 | args = parser.parse_args() 16 | 17 | for method_json, dataset_json in zip(args.method_jsons, args.dataset_jsons): 18 | with open(method_json, encoding="utf-8", mode="r") as f: 19 | methods_info = json.load(f, object_hook=OrderedDict) # 有序载入 20 | with open(dataset_json, encoding="utf-8", mode="r") as f: 21 | datasets_info = json.load(f, object_hook=OrderedDict) # 有序载入 22 | 23 | total_msgs = [] 24 | for method_name, method_info in methods_info.items(): 25 | print(f"Checking for {method_name} ...") 26 | for dataset_name, results_info in method_info.items(): 27 | if results_info is None: 28 | continue 29 | 30 | dataset_mask_info = datasets_info[dataset_name]["mask"] 31 | mask_path = dataset_mask_info["path"] 32 | mask_suffix = dataset_mask_info["suffix"] 33 | 34 | dir_path = results_info["path"] 35 | file_prefix = results_info.get("prefix", "") 36 | file_suffix = results_info["suffix"] 37 | 38 | if not os.path.exists(dir_path): 39 | total_msgs.append(f"{dir_path} 不存在") 40 | continue 41 | elif not os.path.isdir(dir_path): 42 | total_msgs.append(f"{dir_path} 不是正常的文件夹路径") 43 | continue 44 | else: 45 | pred_names = [ 46 | name[len(file_prefix) : -len(file_suffix)] 47 | for name in os.listdir(dir_path) 48 | if name.startswith(file_prefix) and name.endswith(file_suffix) 49 | ] 50 | if len(pred_names) == 0: 51 | total_msgs.append(f"{dir_path} 中不包含前缀为{file_prefix}且后缀为{file_suffix}的文件") 52 | continue 53 | 54 | mask_names = [ 55 | name[: -len(mask_suffix)] 56 | for name in os.listdir(mask_path) 57 | if name.endswith(mask_suffix) 58 | ] 59 | intersection_names = set(mask_names).intersection(set(pred_names)) 60 | if len(intersection_names) == 0: 61 | total_msgs.append(f"{dir_path} 中数据名字与真值 {mask_path} 不匹配") 62 | elif len(intersection_names) != len(mask_names): 63 | difference_names = set(mask_names).difference(pred_names) 64 | total_msgs.append( 65 | f"{dir_path} 中数据({len(list(pred_names))})与真值({len(list(mask_names))})不一致" 66 | ) 67 | 68 | if total_msgs: 69 | print(*total_msgs, sep="\n") 70 | else: 71 | print(f"{method_json} & {dataset_json} 基本正常") 72 | -------------------------------------------------------------------------------- /tools/converter.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2021/8/25 3 | # @Author : Lart Pang 4 | # @GitHub : https://github.com/lartpang 5 | 6 | import argparse 7 | import re 8 | from itertools import chain 9 | 10 | import numpy as np 11 | import yaml 12 | 13 | # fmt: off 14 | parser = argparse.ArgumentParser(description="A useful and convenient tool to convert your .npy results into the table code in latex.") 15 | parser.add_argument("-i", "--result-file", required=True, nargs="+", action="extend", help="The path of the *_metrics.npy file.") 16 | parser.add_argument("-o", "--tex-file", required=True, type=str, help="The path of the exported tex file.") 17 | parser.add_argument("-c", "--config-file", type=str, help="The path of the customized config yaml file.") 18 | parser.add_argument("--contain-table-env", action="store_true", help="Whether to containe the table env in the exported code.") 19 | parser.add_argument("--num-bits", type=int, default=3, help="Number of valid digits.") 20 | parser.add_argument("--transpose", action="store_true", help="Whether to transpose the table.") 21 | # fmt: on 22 | args = parser.parse_args() 23 | 24 | arg_head = f"%% Generated by: {vars(args)}" 25 | 26 | 27 | def update_dict(parent_dict, sub_dict): 28 | for sub_k, sub_v in sub_dict.items(): 29 
| if sub_k in parent_dict: 30 | if sub_v is not None and isinstance(sub_v, dict): 31 | update_dict(parent_dict=parent_dict[sub_k], sub_dict=sub_v) 32 | continue 33 | parent_dict.update(sub_dict) 34 | 35 | 36 | results = {} 37 | for result_file in args.result_file: 38 | result = np.load(file=result_file, allow_pickle=True).item() 39 | for dataset_name, method_infos in result.items(): 40 | results.setdefault(dataset_name, {}) 41 | for method_name, method_info in method_infos.items(): 42 | new_method_info = {} 43 | for metric_name, metric_result in method_info.items(): 44 | if "fmeasure" in metric_name: 45 | metric_name = metric_name.replace("fmeasure", "f") 46 | new_method_info[metric_name] = metric_result 47 | results[dataset_name][method_name] = new_method_info 48 | 49 | IMPOSSIBLE_UP_BOUND = 1 50 | IMPOSSIBLE_DOWN_BOUND = 0 51 | 52 | # 读取数据 53 | dataset_names = sorted(list(results.keys())) 54 | metric_names = ["SM", "wFm", "MAE", "adpE", "avgE", "maxE", "adpF", "avgF", "maxF"] 55 | method_names = sorted(list(set(chain(*[list(results[n].keys()) for n in dataset_names])))) 56 | 57 | if args.config_file is not None: 58 | assert args.config_file.endswith(".yaml") or args.config_file.endswith("yml") 59 | with open(args.config_file, mode="r", encoding="utf-8") as f: 60 | cfg = yaml.safe_load(f) 61 | 62 | if "dataset_names" not in cfg: 63 | print("`dataset_names` not in the config file, use the default config.") 64 | else: 65 | dataset_names = cfg["dataset_names"] 66 | if "metric_names" not in cfg: 67 | print("`metric_names` not in the config file, use the default config.") 68 | else: 69 | metric_names = cfg["metric_names"] 70 | if "method_names" not in cfg: 71 | print("`method_names` not in the config file, use the default config.") 72 | else: 73 | method_names = cfg["method_names"] 74 | 75 | print( 76 | f"CONFIG INFORMATION:" 77 | f"\n- DATASETS ({len(dataset_names)}): {dataset_names}]" 78 | f"\n- METRICS ({len(metric_names)}): {metric_names}" 79 | f"\n- METHODS ({len(method_names)}): {method_names}" 80 | ) 81 | 82 | if isinstance(metric_names, (list, tuple)): 83 | ori_metric_names = metric_names 84 | elif isinstance(metric_names, dict): 85 | ori_metric_names, metric_names = list(zip(*list(metric_names.items()))) 86 | else: 87 | raise NotImplementedError 88 | 89 | if isinstance(method_names, (list, tuple)): 90 | ori_method_names = method_names 91 | elif isinstance(method_names, dict): 92 | ori_method_names, method_names = list(zip(*list(method_names.items()))) 93 | else: 94 | raise NotImplementedError 95 | 96 | # 整理表格 97 | ori_columns = [] 98 | column_for_index = [] 99 | for dataset_idx, dataset_name in enumerate(dataset_names): 100 | for metric_idx, ori_metric_name in enumerate(ori_metric_names): 101 | filled_value = ( 102 | IMPOSSIBLE_UP_BOUND if ori_metric_name.lower() == "mae" else IMPOSSIBLE_DOWN_BOUND 103 | ) 104 | filled_dict = {k: filled_value for k in ori_metric_names} 105 | ori_column = [] 106 | for method_name in ori_method_names: 107 | method_result = results[dataset_name].get(method_name, filled_dict) 108 | if ori_metric_name not in method_result: 109 | raise KeyError( 110 | f"{ori_metric_name} must be contained in {list(method_result.keys())}" 111 | ) 112 | ori_column.append(method_result[ori_metric_name]) 113 | 114 | column_for_index.append([x * round(1 - filled_value * 2) for x in ori_column]) 115 | ori_columns.append(ori_column) 116 | 117 | style_templates = dict( 118 | method_row_body="& {method_name}", 119 | method_column_body=" {method_name}", 120 | dataset_row_body="& 
117 | style_templates = dict(
118 |     method_row_body="& {method_name}",
119 |     method_column_body=" {method_name}",
120 |     dataset_row_body="& \\multicolumn{{{num_metrics}}}{{c}}{{\\textbf{{{dataset_name}}}}}",
121 |     dataset_column_body="\\multirow{{-{num_metrics}}}{{*}}{{\\rotatebox{{90}}{{\\textbf{{{dataset_name}}}}}}}",
122 |     dataset_head=" ",
123 |     metric_body="& {metric_name}",
124 |     metric_row_head=" ",
125 |     metric_column_head="& ",
126 |     body=[
127 |         "& \\first{{{txt:.03f}}}",  # style for top1
128 |         "& \\second{{{txt:.03f}}}",  # style for top2
129 |         "& \\third{{{txt:.03f}}}",  # style for top3
130 |         "& {txt:.03f}",  # style for other
131 |     ],
132 | )
133 | 
134 | 
135 | # rank the entries and apply the cell styles
136 | def replace_cell(ori_value, k):
137 |     if ori_value == IMPOSSIBLE_UP_BOUND or ori_value == IMPOSSIBLE_DOWN_BOUND:
138 |         new_value = "& "
139 |     else:
140 |         new_value = style_templates["body"][k].format(txt=ori_value)
141 |     return new_value
142 | 
143 | 
144 | for col, ori_col in zip(column_for_index, ori_columns):
145 |     col_array = np.array(col).reshape(-1).round(args.num_bits)
146 |     sorted_col_array = np.sort(np.unique(col_array), axis=-1)[-3:][::-1]
147 |     # [top1_idxes, top2_idxes, top3_idxes]
148 |     top_k_idxes = [np.argwhere(col_array == x).tolist() for x in sorted_col_array]
149 |     for k, idxes in enumerate(top_k_idxes):
150 |         for row_idx in idxes:
151 |             ori_col[row_idx[0]] = replace_cell(ori_col[row_idx[0]], k)
152 | 
153 |     for idx, x in enumerate(ori_col):
154 |         if not isinstance(x, str):
155 |             ori_col[idx] = replace_cell(x, -1)
156 | 
157 | # build the table header
158 | num_datasets = len(dataset_names)
159 | num_metrics = len(metric_names)
160 | num_methods = len(method_names)
161 | 
162 | # build the leading columns first, then assemble the leading rows as a whole
163 | latex_table_head = []
164 | latex_table_tail = []
165 | 
166 | 
167 | def remove_latex_chars_out_of_mathenv(string: str):
168 |     string_splits = string.split("$")  # 'abcd$efg$hij$' -> ['abcd', 'efg', 'hij', '']
169 |     for i, s in enumerate(string_splits):
170 |         if i % 2 == 0:
171 |             string_splits[i] = re.sub(pattern=r"_", repl=r"-", string=s)
172 |     string = "$".join(string_splits)
173 |     return string
174 | 
175 | 
176 | method_names = [remove_latex_chars_out_of_mathenv(x) for x in method_names]
177 | dataset_names = [remove_latex_chars_out_of_mathenv(x) for x in dataset_names]
178 | if not args.transpose:
179 |     dataset_row = (
180 |         [style_templates["dataset_head"]]
181 |         + [
182 |             style_templates["dataset_row_body"].format(num_metrics=num_metrics, dataset_name=x)
183 |             for x in dataset_names
184 |         ]
185 |         + [r"\\"]
186 |     )
187 |     metric_row = (
188 |         [style_templates["metric_row_head"]]
189 |         + [style_templates["metric_body"].format(metric_name=x) for x in metric_names]
190 |         * num_datasets
191 |         + [r"\\"]
192 |     )
193 |     additional_rows = [dataset_row, metric_row]
194 | 
195 |     # build the first column
196 |     method_column = [
197 |         style_templates["method_column_body"].format(method_name=x) for x in method_names
198 |     ]
199 |     additional_columns = [method_column]
200 | 
201 |     columns = additional_columns + ori_columns
202 |     rows = [list(row) + [r"\\"] for row in zip(*columns)]
203 |     rows = additional_rows + rows
204 | 
205 |     if args.contain_table_env:
206 |         column_style = "|".join([f"*{num_metrics}{{c}}"] * len(dataset_names))
207 |         latex_table_head = [
208 |             f"\\begin{{tabular}}{{l|{column_style}}}\n",
209 |             "\\toprule[2pt]",
210 |         ]
211 | else:
212 |     dataset_column = []
213 |     for x in dataset_names:
214 |         blank_cells = [" "] * (num_metrics - 1)
215 |         dataset_cell = [
216 |             style_templates["dataset_column_body"].format(num_metrics=num_metrics, dataset_name=x)
217 |         ]
218 |         dataset_column.extend(blank_cells + dataset_cell)
219 |     metric_column = [
220 |         
style_templates["metric_body"].format(metric_name=x) for x in metric_names 221 | ] * num_datasets 222 | additional_columns = [dataset_column, metric_column] 223 | 224 | method_row = ( 225 | [style_templates["dataset_head"], style_templates["metric_column_head"]] 226 | + [style_templates["method_row_body"].format(method_name=x) for x in method_names] 227 | + [r"\\"] 228 | ) 229 | additional_rows = [method_row] 230 | 231 | additional_columns = [list(x) for x in zip(*additional_columns)] 232 | rows = [cells + row + [r"\\"] for cells, row in zip(additional_columns, ori_columns)] 233 | rows = additional_rows + rows 234 | 235 | if args.contain_table_env: 236 | column_style = "".join([f"*{{{num_methods}}}{{c}}"]) 237 | latex_table_head = [ 238 | f"\\begin{{tabular}}{{cc|{column_style}}}\n", 239 | "\\toprule[2pt]", 240 | ] 241 | 242 | if args.contain_table_env: 243 | latex_table_tail = [ 244 | "\\bottomrule[2pt]\n", 245 | "\\end{tabular}", 246 | ] 247 | 248 | rows = [arg_head, latex_table_head] + rows + [latex_table_tail] 249 | 250 | with open(args.tex_file, mode="w", encoding="utf-8") as f: 251 | for row in rows: 252 | f.write("".join(row) + "\n") 253 | -------------------------------------------------------------------------------- /tools/info_py_to_json.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2021/3/14 3 | # @Author : Lart Pang 4 | # @GitHub : https://github.com/lartpang 5 | import argparse 6 | import ast 7 | import json 8 | import os 9 | import sys 10 | from importlib import import_module 11 | 12 | 13 | def validate_py_syntax(filename): 14 | with open(filename, "r") as f: 15 | content = f.read() 16 | try: 17 | ast.parse(content) 18 | except SyntaxError as e: 19 | raise SyntaxError("There are syntax errors in config " f"file {filename}: {e}") 20 | 21 | 22 | def convert_py_to_json(source_config_root, target_config_root): 23 | if not os.path.isdir(source_config_root): 24 | raise NotADirectoryError(source_config_root) 25 | if not os.path.exists(target_config_root): 26 | os.makedirs(target_config_root) 27 | else: 28 | if not os.path.isdir(target_config_root): 29 | raise NotADirectoryError(target_config_root) 30 | 31 | sys.path.insert(0, source_config_root) 32 | source_config_files = os.listdir(source_config_root) 33 | for source_config_file in source_config_files: 34 | source_config_path = os.path.join(source_config_root, source_config_file) 35 | if not (os.path.isfile(source_config_path) and source_config_path.endswith(".py")): 36 | continue 37 | validate_py_syntax(source_config_path) 38 | print(source_config_path) 39 | 40 | temp_module_name = os.path.splitext(source_config_file)[0] 41 | mod = import_module(temp_module_name) 42 | 43 | total_dict = {} 44 | for name, value in mod.__dict__.items(): 45 | if not name.startswith("_") and isinstance(value, dict): 46 | total_dict[name] = value 47 | 48 | # delete imported module 49 | del sys.modules[temp_module_name] 50 | 51 | with open( 52 | os.path.join(target_config_root, os.path.basename(temp_module_name) + ".json"), 53 | encoding="utf-8", 54 | mode="w", 55 | ) as f: 56 | json.dump(total_dict, f, indent=2) 57 | 58 | 59 | def get_args(): 60 | parser = argparse.ArgumentParser() 61 | parser.add_argument("-i", "--source-py-root", required=True, type=str) 62 | parser.add_argument("-o", "--target-json-root", required=True, type=str) 63 | args = parser.parse_args() 64 | return args 65 | 66 | 67 | if __name__ == "__main__": 68 | args = get_args() 69 | convert_py_to_json( 70 
        source_config_root=args.source_py_root, target_config_root=args.target_json_root
71 |     )
72 | 
--------------------------------------------------------------------------------
/tools/readme.md:
--------------------------------------------------------------------------------
1 | # Useful tools
2 | 
3 | ## `append_results.py`
4 | 
5 | Merges a newly generated npy file with an old npy file into a new npy file.
6 | 
7 | ```shell
8 | $ python append_results.py --help
9 | usage: append_results.py [-h] --old-npy OLD_NPY --new-npy NEW_NPY [--method-names METHOD_NAMES [METHOD_NAMES ...]]
10 |                          [--dataset-names DATASET_NAMES [DATASET_NAMES ...]] --out-npy OUT_NPY
11 | 
12 | A simple tool for merging two npy files. Patch the method items corresponding to the `--method-names` and `--dataset-names` of `--new-npy`
13 | into `--old-npy`, and output the whole container to `--out-npy`.
14 | 
15 | optional arguments:
16 |   -h, --help show this help message and exit
17 |   --old-npy OLD_NPY
18 |   --new-npy NEW_NPY
19 |   --method-names METHOD_NAMES [METHOD_NAMES ...]
20 |   --dataset-names DATASET_NAMES [DATASET_NAMES ...]
21 |   --out-npy OUT_NPY
22 | ```
23 | 
24 | Example use case:
25 | 
26 | For RGB SOD data, I have already generated an npy file containing the results of a batch of methods: `old_rgb_sod_curves.npy`.
27 | After re-evaluating one method, I obtained a new npy file (the file names differ, so nothing is overwritten): `new_rgb_sod_curves.npy`.
28 | Now I want to merge these two results into a single file: `finalnew_rgb_sod_curves.npy`.
29 | This can be done with the following command.
30 | 
31 | ```shell
32 | python tools/append_results.py --old-npy output/old_rgb_sod_curves.npy \
33 |     --new-npy output/new_rgb_sod_curves.npy \
34 |     --out-npy output/finalnew_rgb_sod_curves.npy
35 | ```
36 | 
37 | 
38 | ## `converter.py`
39 | 
40 | Exports the information in the generated `*_metrics.npy` files as LaTeX table code.
41 | 
42 | You can write a config by hand following `examples/converter_config.yaml` in the examples folder to generate exactly the LaTeX table code you need; a sketch of such a config is shown after this section.
43 | 
44 | ```shell
45 | $ python converter.py --help
46 | usage: converter.py [-h] -i RESULT_FILE [RESULT_FILE ...] -o TEX_FILE [-c CONFIG_FILE] [--contain-table-env] [--transpose]
47 | 
48 | A useful and convenient tool to convert your .npy results into the table code in latex.
49 | 
50 | optional arguments:
51 |   -h, --help show this help message and exit
52 |   -i RESULT_FILE [RESULT_FILE ...], --result-file RESULT_FILE [RESULT_FILE ...]
53 |                         The path of the *_metrics.npy file.
54 |   -o TEX_FILE, --tex-file TEX_FILE
55 |                         The path of the exported tex file.
56 |   -c CONFIG_FILE, --config-file CONFIG_FILE
57 |                         The path of the customized config yaml file.
58 |   --contain-table-env Whether to contain the table env in the exported code.
59 |   --transpose Whether to transpose the table.
60 | ```
61 | 
62 | A usage example:
63 | 
64 | ```shell
65 | $ python tools/converter.py -i output/your_metrics_1.npy output/your_metrics_2.npy -o output/your_metrics.tex -c ./examples/converter_config.yaml --transpose --contain-table-env
66 | ```
67 | 
68 | This command reads data from several npy files (if you only have one, just pass that single npy file), processes it, and exports it to the specified tex file. The specified config file selects the datasets, metrics, and methods to include.
69 | 
70 | In addition, the exported table code here uses the transposed (vertical) layout and includes the `tabular` environment of the table.
71 | 
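A minimal sketch of building such a config from Python with PyYAML (which `converter.py` reads back with `yaml.safe_load`); the dataset and method names below are placeholders, not values shipped with the repo:

```python
# Hypothetical converter config. Per the code in converter.py, metric_names and
# method_names may be plain lists, or dicts mapping the stored name to the name
# displayed in the table.
import yaml

cfg = {
    "dataset_names": ["ECSSD", "DUT-OMRON"],
    "metric_names": {"SM": "$S_m$", "MAE": "$M$", "maxF": "$F_{max}$"},
    "method_names": {"MINet_R50_2020": "MINet", "GateNet_2020": "GateNet"},
}
with open("my_converter_config.yaml", "w", encoding="utf-8") as f:
    yaml.safe_dump(cfg, f, allow_unicode=True, sort_keys=False)
```
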
72 | ## `check_path.py`
73 | 
74 | Checks for problems by matching the information in the json files against the actual paths on the file system.
75 | 
76 | ```shell
77 | $ python check_path.py --help
78 | usage: check_path.py [-h] -m METHOD_JSONS [METHOD_JSONS ...] -d DATASET_JSONS [DATASET_JSONS ...]
79 | 
80 | A simple tool for checking your json config file.
81 | 
82 | optional arguments:
83 |   -h, --help show this help message and exit
84 |   -m METHOD_JSONS [METHOD_JSONS ...], --method-jsons METHOD_JSONS [METHOD_JSONS ...]
85 |                         The json file about all methods.
86 |   -d DATASET_JSONS [DATASET_JSONS ...], --dataset-jsons DATASET_JSONS [DATASET_JSONS ...]
87 | ```
88 | 
89 | A usage example follows. Make sure that the file paths after `-m` and `-d` correspond one-to-one: the code uses `zip` to iterate over the two lists in lockstep, so this pairing decides which method file is checked against which dataset file.
90 | 
91 | ```shell
92 | $ python check_path.py -m ../configs/methods/json/rgb_sod_methods.json ../configs/methods/json/rgbd_sod_methods.json \
93 |     -d ../configs/datasets/json/rgb_sod.json ../configs/datasets/json/rgbd_sod.json
94 | ```
95 | 
96 | ## `info_py_to_json.py`
97 | 
98 | Converts python-based config files into more portable json files.
99 | 
100 | ```shell
101 | $ python info_py_to_json.py --help
102 | usage: info_py_to_json.py [-h] -i SOURCE_PY_ROOT -o TARGET_JSON_ROOT
103 | 
104 | optional arguments:
105 |   -h, --help show this help message and exit
106 |   -i SOURCE_PY_ROOT, --source-py-root SOURCE_PY_ROOT
107 |   -o TARGET_JSON_ROOT, --target-json-root TARGET_JSON_ROOT
108 | ```
109 | 
110 | There are two required options: the input directory holding the python config files (`-i`) and the output directory for the generated json files (`-o`).
111 | 
112 | Each python file in the input directory is imported, and by default every module-level dict whose name does not start with `_` is collected as the config info of a dataset or a method; see the sketch after this section.
113 | 
114 | Finally, all of these dicts are gathered into one complete dict, which is dumped directly to a json file saved in the output directory.
115 | 
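For instance, a source config such as the following (a hypothetical `rgb_sod.py` placed in the `-i` directory) would be exported as `rgb_sod.json`; the underscore-prefixed name and the non-dict value are both skipped by the filter in `convert_py_to_json`:

```python
# rgb_sod.py -- a made-up source config for illustration
_base_root = "/home/user/data"  # skipped: starts with "_" and is not a dict

ECSSD = {
    "image": {"path": f"{_base_root}/ECSSD/images", "suffix": ".jpg"},
    "mask": {"path": f"{_base_root}/ECSSD/masks", "suffix": ".png"},
}
# info_py_to_json.py imports this module and writes {"ECSSD": {...}} to rgb_sod.json
```
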
116 | ## `rename.py`
117 | 
118 | Batch renaming.
119 | 
120 | It is recommended to read through the code before using this tool, and to use it carefully to avoid losing files to accidental overwrites.
--------------------------------------------------------------------------------
/tools/rename.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2021/4/24
3 | # @Author : Lart Pang
4 | # @GitHub : https://github.com/lartpang
5 | 
6 | import glob
7 | import os
8 | import re
9 | import shutil
10 | 
11 | 
12 | def path_join(base_path, sub_path):
13 |     if sub_path.startswith(os.sep):
14 |         sub_path = sub_path[len(os.sep) :]
15 |     return os.path.join(base_path, sub_path)
16 | 
17 | 
18 | def rename_all_files(src_pattern, dst_pattern, src_name, src_dir, dst_dir=None):
19 |     """
20 |     :param src_pattern: the regex matching the original file names
21 |     :param dst_pattern: the corresponding replacement string
22 |     :param src_name: a glob-style pattern
23 |     :param src_dir: the directory of the original data; combined with src_name it forms the glob pattern used to search for files
24 |     :param dst_dir: the directory for the renamed data; defaults to None, which modifies the original data in place
25 |     """
26 |     assert os.path.isdir(src_dir)
27 | 
28 |     if dst_dir is None:
29 |         dst_dir = src_dir
30 |         rename_func = os.replace
31 |     else:
32 |         assert os.path.isdir(dst_dir)
33 |         if dst_dir == src_dir:
34 |             rename_func = os.replace
35 |         else:
36 |             rename_func = shutil.copy
37 |     print(f"{rename_func.__name__} will be used to modify the data")
38 | 
39 |     src_dir = os.path.abspath(src_dir)
40 |     dst_dir = os.path.abspath(dst_dir)
41 |     src_data_paths = glob.glob(path_join(src_dir, src_name))
42 | 
43 |     print(f"Start replacing the data in {src_dir}")
44 |     for idx, src_data_path in enumerate(src_data_paths, start=1):
45 |         src_name_w_dir_name = src_data_path[len(src_dir) + 1 :]
46 |         dst_name_w_dir_name = re.sub(src_pattern, repl=dst_pattern, string=src_name_w_dir_name)
47 |         if dst_name_w_dir_name == src_name_w_dir_name:
48 |             continue
49 |         dst_data_path = path_join(dst_dir, dst_name_w_dir_name)
50 | 
51 |         dst_data_dir = os.path.dirname(dst_data_path)
52 |         if not os.path.exists(dst_data_dir):
53 |             print(f"{idx}: {dst_data_dir} does not exist, creating it")
54 |             os.makedirs(dst_data_dir)
55 |         rename_func(src=src_data_path, dst=dst_data_path)
56 |         print(f"{src_data_path} -> {dst_data_path}")
57 | 
58 |     print("OK...")
59 | 
60 | 
61 | if __name__ == "__main__":
62 |     rename_all_files(
63 |         src_pattern=r"",
64 |         dst_pattern="",
65 |         src_name="*/*.png",
66 |         src_dir="",
67 |         dst_dir="",
68 |     )
69 | 
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lartpang/PySODEvalToolkit/a52dbccb90ec3da8f52a6e0b97b055e1d27b6319/utils/__init__.py
--------------------------------------------------------------------------------
/utils/generate_info.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | 
3 | import json
4 | import os
5 | from collections import OrderedDict
6 | 
7 | from matplotlib import colors
8 | 
9 | # max = 148
10 | _COLOR_Generator = iter(
11 |     sorted(
12 |         [
13 |             color
14 |             for name, color in colors.cnames.items()
15 |             if name not in ["red", "white"] or not name.startswith("light") or "gray" in name
16 |         ]
17 |     )
18 | )
19 | 
20 | 
21 | def curve_info_generator():
22 |     # TODO: currently `-` and `--` are simply assigned to the methods in turn, but when a method
23 |     # only has results on some datasets, two neighboring methods may end up with the same style on the datasets with missing results
24 |     line_style_flag = True
25 | 
26 |     def _template_generator(
27 |         method_info: dict, method_name: str, line_color: str = None, line_width: int = 2
28 |     ) -> dict:
29 |         nonlocal line_style_flag
30 |         template_info = dict(
31 |             path_dict=method_info,
32 |             curve_setting=dict(
33 |                 line_style="-" if line_style_flag else "--",
34 |                 line_label=method_name,
35 |                 line_width=line_width,
36 |             ),
37 |         )
38 |         if line_color is not None:
39 |             template_info["curve_setting"]["line_color"] = line_color
40 |         else:
41 |             template_info["curve_setting"]["line_color"] = next(_COLOR_Generator)
42 | 
43 |         line_style_flag = not line_style_flag
44 |         return template_info
45 | 
46 |     return _template_generator
47 | 
48 | 
49 | def simple_info_generator():
50 |     def _template_generator(method_info: dict, method_name: str) -> dict:
51 |         template_info = dict(path_dict=method_info, label=method_name)
52 |         return template_info
53 | 
54 |     return _template_generator
55 | 
56 | 
57 | def get_valid_elements(source: list, include_elements: list, exclude_elements: list) -> list:
58 |     targeted_elements = []
59 |     if include_elements and not exclude_elements:  # only include_elements is not [] and not None
60 |         for element in include_elements:
61 |             assert element in source, element
62 |             targeted_elements.append(element)
63 | 
64 |     elif not include_elements and exclude_elements:  # only exclude_elements is not [] and not None
65 |         for element in exclude_elements:
66 |             assert element in source, element
67 |         for element in source:
68 |             if element not in exclude_elements:
69 |                 targeted_elements.append(element)
70 | 
71 |     elif not include_elements and not exclude_elements:
72 |         targeted_elements = source
73 | 
74 |     else:
75 |         raise ValueError(
76 |             f"include_elements: {include_elements}\nexclude_elements: {exclude_elements}"
77 |         )
78 | 
79 |     if not targeted_elements:
80 |         print(source, include_elements, exclude_elements)
81 |         raise ValueError("targeted_elements must be a valid and non-empty list.")
82 |     return targeted_elements
83 | 
84 | 
85 | def get_methods_info(
86 |     methods_info_jsons: list,
87 |     include_methods: list,
88 |     exclude_methods: list,
89 |     *,
90 |     for_drawing: bool = False,
91 |     our_name: str = "",
92 | ) -> OrderedDict:
93 |     """
94 |     The keys of the method dicts stored in the json files are used directly for plotting.
95 | 
96 |     :param methods_info_jsons: the json file(s) holding the method info; multiple files can be combined and are read in the given order
97 |     :param for_drawing: whether the info is used for drawing curves; True adds some plotting settings
98 |     :param our_name: when drawing, our_name can be given to emphasize the curve of that method with a bold red solid line
99 |     :param include_methods: only return the info of the listed methods; None returns all
100 |     :param exclude_methods: exclude the info of the listed methods; None returns all; exactly one of include_methods/exclude_methods may be non-None
101 |     :return: methods_full_info
102 |     """
103 |     if not isinstance(methods_info_jsons, (list, tuple)):
104 |         methods_info_jsons = [methods_info_jsons]
105 | 
106 |     methods_info = {}
107 |     for f in methods_info_jsons:
108 |         if not os.path.isfile(f):
109 |             raise FileNotFoundError(f"{f} was not found!")
110 | 
111 |         with open(f, encoding="utf-8", mode="r") as fp:
112 |             methods_info.update(json.load(fp, object_hook=OrderedDict))  # load in order
113 | 
114 |     if our_name:
115 |         assert our_name in methods_info, f"{our_name} is not in the json file."
116 | 
117 |     targeted_methods = get_valid_elements(
118 |         source=list(methods_info.keys()),
119 |         include_elements=include_methods,
120 |         exclude_elements=exclude_methods,
121 |     )
122 |     if our_name and our_name in targeted_methods:
123 |         targeted_methods.pop(targeted_methods.index(our_name))
124 |         targeted_methods.insert(0, our_name)
125 | 
126 |     if for_drawing:
127 |         info_generator = curve_info_generator()
128 |     else:
129 |         info_generator = simple_info_generator()
130 | 
131 |     methods_full_info = []
132 |     for method_name in targeted_methods:
133 |         method_path = methods_info[method_name]
134 | 
135 |         if for_drawing and our_name and our_name == method_name:
136 |             method_info = info_generator(method_path, method_name, line_color="red", line_width=3)
137 |         else:
138 |             method_info = info_generator(method_path, method_name)
139 |         methods_full_info.append((method_name, method_info))
140 |     return OrderedDict(methods_full_info)
141 | 
142 | 
143 | def get_datasets_info(
144 |     datastes_info_json: str, include_datasets: list, exclude_datasets: list
145 | ) -> OrderedDict:
146 |     """
147 |     All dataset info stored in the json file is exported directly into one dict.
148 | 
149 |     :param datastes_info_json: the json file holding the dataset info
150 |     :param include_datasets: only read the info of the listed datasets; None reads all
151 |     :param exclude_datasets: exclude the info of the listed datasets; None reads all; exactly one of include_datasets/exclude_datasets may be non-None
152 |     :return: datastes_full_info
153 |     """
154 | 
155 |     assert os.path.isfile(datastes_info_json), datastes_info_json
156 |     with open(datastes_info_json, encoding="utf-8", mode="r") as f:
157 |         datasets_info = json.load(f, object_hook=OrderedDict)  # load in order
158 | 
159 |     targeted_datasets = get_valid_elements(
160 |         source=list(datasets_info.keys()),
161 |         include_elements=include_datasets,
162 |         exclude_elements=exclude_datasets,
163 |     )
164 | 
165 |     datasets_full_info = []
166 |     for dataset_name in targeted_datasets:
167 |         data_path = datasets_info[dataset_name]
168 | 
169 |         datasets_full_info.append((dataset_name, data_path))
170 |     return OrderedDict(datasets_full_info)
171 | 
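A minimal sketch of how the two helpers above are typically combined; the json paths and the method name are placeholders:

```python
methods = get_methods_info(
    methods_info_jsons=["configs/methods/json/rgb_sod_methods.json"],
    include_methods=None,
    exclude_methods=None,
    for_drawing=True,  # attach line style/color/label settings for plotting
    our_name="MINet",  # moved to the front and drawn as a bold red solid line
)
datasets = get_datasets_info(
    datastes_info_json="configs/datasets/json/rgb_sod.json",
    include_datasets=["ECSSD"],
    exclude_datasets=None,
)
```

--------------------------------------------------------------------------------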
<<++ ") 31 | with open(data_path, mode="r", encoding="utf-8") as file: 32 | line = file.readline() 33 | while line: 34 | img_name = os.path.basename(line.split()[0]) 35 | file_ext = os.path.splitext(img_name)[1] 36 | name_list.append(os.path.splitext(img_name)[0]) 37 | line = file.readline() 38 | if file_ext == "": 39 | # 默认为png 40 | file_ext = ".png" 41 | else: 42 | print(f" ++>> {data_path} is a folder. <<++ ") 43 | data_list = os.listdir(data_path) 44 | file_ext = get_ext(data_list) 45 | name_list = [os.path.splitext(f)[0] for f in data_list if f.endswith(file_ext)] 46 | name_list = list(set(name_list)) 47 | return name_list, file_ext 48 | 49 | 50 | def get_name_list(data_path: str, name_prefix: str = "", name_suffix: str = "") -> list: 51 | if os.path.isfile(data_path): 52 | assert data_path.endswith((".txt", ".lst")) 53 | data_list = [] 54 | with open(data_path, encoding="utf-8", mode="r") as f: 55 | line = f.readline().strip() 56 | while line: 57 | data_list.append(line) 58 | line = f.readline().strip() 59 | else: 60 | data_list = os.listdir(data_path) 61 | 62 | name_list = data_list 63 | if not name_prefix and not name_suffix: 64 | name_list = [os.path.splitext(f)[0] for f in name_list] 65 | else: 66 | name_list = [ 67 | f[len(name_prefix) : -len(name_suffix)] 68 | for f in name_list 69 | if f.startswith(name_prefix) and f.endswith(name_suffix) 70 | ] 71 | 72 | name_list = list(set(name_list)) 73 | return name_list 74 | 75 | 76 | def get_number_from_tail(string): 77 | tail_number = re.findall(pattern="\d+$", string=string)[0] 78 | return int(tail_number) 79 | 80 | 81 | def get_name_with_group_list( 82 | data_path: str, 83 | name_prefix: str = "", 84 | name_suffix: str = "", 85 | start_idx: int = 0, 86 | end_idx: int = None, 87 | sep: str = "", 88 | ): 89 | """get file names with the group name 90 | 91 | Args: 92 | data_path (str): The path of data. 93 | name_prefix (str, optional): The prefix of the file name. Defaults to "". 94 | name_suffix (str, optional): The suffix of the file name. Defaults to "". 95 | start_idx (int, optional): The index of the first frame in each group. Defaults to 0, it will not skip any frames. 96 | end_idx (int, optional): The index of the last frame in each group. Defaults to None, it will not skip any frames. 97 | sep (str, optional): The returned name is a string containing group_name and file_name separated by `sep`. 98 | 99 | Raises: 100 | NotImplementedError: Undefined. 101 | 102 | Returns: 103 | list: The name (with the group name) list and the original number of image in the dataset. 
104 | """ 105 | name_list = [] 106 | if os.path.isfile(data_path): 107 | # 暂未遇到类似的设定 108 | raise NotImplementedError 109 | else: 110 | if "*" in data_path: # for VCOD 111 | group_paths = glob.glob(data_path, recursive=False) 112 | group_name_start_idx = data_path.find("*") 113 | for group_path in group_paths: 114 | group_name = group_path[group_name_start_idx:].split(os.sep)[0] 115 | 116 | file_names = sorted( 117 | [ 118 | n[len(name_prefix) : -len(name_suffix)] 119 | for n in os.listdir(group_path) 120 | if n.startswith(name_prefix) and n.endswith(name_suffix) 121 | ], 122 | key=lambda item: get_number_from_tail(item), 123 | ) 124 | 125 | for file_name in file_names[start_idx:end_idx]: 126 | name_list.append(f"{group_name}{sep}{file_name}") 127 | 128 | else: # for CoSOD 129 | group_names = os.listdir(data_path) 130 | group_paths = [os.path.join(data_path, n) for n in group_names] 131 | for group_path in group_paths: 132 | group_name = os.path.basename(group_path) 133 | 134 | file_names = sorted( 135 | [ 136 | n[len(name_prefix) : -len(name_suffix)] 137 | for n in os.listdir(group_path) 138 | if n.startswith(name_prefix) and n.endswith(name_suffix) 139 | ], 140 | key=lambda item: get_number_from_tail(item), 141 | ) 142 | 143 | for file_name in file_names[start_idx:end_idx]: 144 | name_list.append(f"{group_name}{sep}{file_name}") 145 | name_list = sorted(set(name_list)) # 去重 146 | return name_list 147 | 148 | 149 | def get_list_with_suffix(dataset_path: str, suffix: str): 150 | name_list = [] 151 | if os.path.isfile(dataset_path): 152 | print(f" ++>> {dataset_path} is a file. <<++ ") 153 | with open(dataset_path, mode="r", encoding="utf-8") as file: 154 | line = file.readline() 155 | while line: 156 | img_name = os.path.basename(line.split()[0]) 157 | name_list.append(os.path.splitext(img_name)[0]) 158 | line = file.readline() 159 | else: 160 | print(f" ++>> {dataset_path} is a folder. 
<<++ ") 161 | name_list = [ 162 | os.path.splitext(f)[0] for f in os.listdir(dataset_path) if f.endswith(suffix) 163 | ] 164 | name_list = list(set(name_list)) 165 | return name_list 166 | 167 | 168 | def make_dir(path): 169 | if not os.path.exists(path): 170 | print(f"`{path}` does not exist,we will create it.") 171 | os.makedirs(path) 172 | else: 173 | assert os.path.isdir(path), f"`{path}` should be a folder" 174 | print(f"`{path}`已存在") 175 | 176 | 177 | def imread_with_checking(path, for_color: bool = True) -> np.ndarray: 178 | assert os.path.exists(path=path) and os.path.isfile(path=path), path 179 | if for_color: 180 | data = cv2.imread(path, flags=cv2.IMREAD_COLOR) 181 | data = cv2.cvtColor(data, cv2.COLOR_BGR2RGB) 182 | else: 183 | data = cv2.imread(path, flags=cv2.IMREAD_GRAYSCALE) 184 | return data 185 | 186 | 187 | def get_gt_pre_with_name( 188 | img_name: str, 189 | gt_root: str, 190 | pre_root: str, 191 | *, 192 | gt_prefix: str = "", 193 | pre_prefix: str = "", 194 | gt_suffix: str = ".png", 195 | pre_suffix: str = "", 196 | to_normalize: bool = False, 197 | ): 198 | img_path = os.path.join(pre_root, pre_prefix + img_name + pre_suffix) 199 | gt_path = os.path.join(gt_root, gt_prefix + img_name + gt_suffix) 200 | 201 | pre = imread_with_checking(img_path, for_color=False) 202 | gt = imread_with_checking(gt_path, for_color=False) 203 | 204 | if pre.shape != gt.shape: 205 | pre = cv2.resize(pre, dsize=gt.shape[::-1], interpolation=cv2.INTER_LINEAR).astype( 206 | np.uint8 207 | ) 208 | 209 | if to_normalize: 210 | gt = normalize_array(gt, to_binary=True, max_eq_255=True) 211 | pre = normalize_array(pre, to_binary=False, max_eq_255=True) 212 | return gt, pre 213 | 214 | 215 | def get_gt_pre_with_name_and_group( 216 | img_name: str, 217 | gt_root: str, 218 | pre_root: str, 219 | *, 220 | gt_prefix: str = "", 221 | pre_prefix: str = "", 222 | gt_suffix: str = ".png", 223 | pre_suffix: str = "", 224 | to_normalize: bool = False, 225 | interpolation: int = cv2.INTER_CUBIC, 226 | sep: str = "", 227 | ): 228 | group_name, file_name = img_name.split(sep) 229 | if "*" in gt_root: 230 | gt_root = gt_root.replace("*", group_name) 231 | else: 232 | gt_root = os.path.join(gt_root, group_name) 233 | if "*" in pre_root: 234 | pre_root = pre_root.replace("*", group_name) 235 | else: 236 | pre_root = os.path.join(pre_root, group_name) 237 | img_path = os.path.join(pre_root, pre_prefix + file_name + pre_suffix) 238 | gt_path = os.path.join(gt_root, gt_prefix + file_name + gt_suffix) 239 | 240 | pre = imread_with_checking(img_path, for_color=False) 241 | gt = imread_with_checking(gt_path, for_color=False) 242 | 243 | if pre.shape != gt.shape: 244 | pre = cv2.resize(pre, dsize=gt.shape[::-1], interpolation=interpolation).astype(np.uint8) 245 | 246 | if to_normalize: 247 | gt = normalize_array(gt, to_binary=True, max_eq_255=True) 248 | pre = normalize_array(pre, to_binary=False, max_eq_255=True) 249 | return gt, pre 250 | 251 | 252 | def normalize_array( 253 | data: np.ndarray, to_binary: bool = False, max_eq_255: bool = True 254 | ) -> np.ndarray: 255 | if max_eq_255: 256 | data = data / 255 257 | # else: data is in [0, 1] 258 | if to_binary: 259 | data = (data > 0.5).astype(np.uint8) 260 | else: 261 | if data.max() != data.min(): 262 | data = (data - data.min()) / (data.max() - data.min()) 263 | data = data.astype(np.float32) 264 | return data 265 | 266 | 267 | def get_valid_key_name(data_dict: dict, key_name: str) -> str: 268 | if data_dict.get(key_name.lower(), "keyerror") == "keyerror": 269 | 
269 |         key_name = key_name.upper()
270 |     else:
271 |         key_name = key_name.lower()
272 |     return key_name
273 | 
274 | 
275 | def get_target_key(target_dict: dict, key: str) -> str:
276 |     """
277 |     From the keys of the target_dict, get the valid key name corresponding to the `key`.
278 |     If there is no valid name, return None.
279 |     """
280 |     target_keys = {k.lower(): k for k in target_dict.keys()}
281 |     return target_keys.get(key.lower(), None)
282 | 
283 | 
284 | def colored_print(msg: str, mode: str = "general"):
285 |     """
286 |     Provide some customized display formats for printing different kinds of string messages.
287 | 
288 |     :param msg: the message string to print
289 |     :param mode: the corresponding print mode; currently supports general/warning/error
290 |     :return:
291 |     """
292 |     if mode == "general":
293 |         msg = msg
294 |     elif mode == "warning":
295 |         msg = f"\033[5;31m{msg}\033[0m"
296 |     elif mode == "error":
297 |         msg = f"\033[1;31m{msg}\033[0m"
298 |     else:
299 |         raise ValueError(f"{mode} is an invalid mode.")
300 |     print(msg)
301 | 
302 | 
303 | class ColoredPrinter:
304 |     """
305 |     Provide some customized display formats for printing different kinds of string messages.
306 |     """
307 | 
308 |     @staticmethod
309 |     def info(msg):
310 |         print(msg)
311 | 
312 |     @staticmethod
313 |     def warn(msg):
314 |         msg = f"\033[5;31m{msg}\033[0m"
315 |         print(msg)
316 | 
317 |     @staticmethod
318 |     def error(msg):
319 |         msg = f"\033[1;31m{msg}\033[0m"
320 |         print(msg)
321 | 
322 | 
323 | def update_info(source_info: dict, new_info: dict):
324 |     for name, info in source_info.items():
325 |         if name in new_info:
326 |             if isinstance(info, dict):
327 |                 update_info(source_info=info, new_info=new_info[name])
328 |             else:  # int, float, list, tuple
329 |                 info = new_info[name]
330 |             source_info[name] = info
331 |     return source_info
--------------------------------------------------------------------------------
/utils/print_formatter.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | 
3 | from tabulate import tabulate
4 | 
5 | 
6 | def print_formatter(
7 |     results: dict, method_name_length=10, metric_name_length=5, metric_value_length=5
8 | ):
9 |     dataset_regions = []
10 |     for dataset_name, dataset_metrics in results.items():
11 |         dataset_head_row = f" Dataset: {dataset_name} "
12 |         dataset_region = [dataset_head_row]
13 |         for method_name, metric_info in dataset_metrics.items():
14 |             showed_method_name = clip_string(
15 |                 method_name, max_length=method_name_length, mode="left"
16 |             )
17 |             method_row_head = f"{showed_method_name} "
18 |             method_row_body = []
19 |             for metric_name, metric_value in metric_info.items():
20 |                 showed_metric_name = clip_string(
21 |                     metric_name, max_length=metric_name_length, mode="right"
22 |                 )
23 |                 showed_value_string = clip_string(
24 |                     str(metric_value), max_length=metric_value_length, mode="left"
25 |                 )
26 |                 method_row_body.append(f"{showed_metric_name}: {showed_value_string}")
27 |             method_row = method_row_head + ", ".join(method_row_body)
28 |             dataset_region.append(method_row)
29 |         dataset_region_string = "\n".join(dataset_region)
30 |         dataset_regions.append(dataset_region_string)
31 |     dividing_line = "\n" + "-" * len(dataset_region[-1]) + "\n"  # use the length of the last row of the last dataset region as the length of the dividing line
32 |     formatted_string = dividing_line.join(dataset_regions)
33 |     return formatted_string
34 | 
string + f"{padding_char}" * padding_length 44 | elif mode == "center": 45 | left_padding_str = f"{padding_char}" * (padding_length // 2) 46 | right_padding_str = f"{padding_char}" * (padding_length - padding_length // 2) 47 | clipped_string = left_padding_str + string + right_padding_str 48 | elif mode == "right": 49 | clipped_string = f"{padding_char}" * padding_length + string 50 | else: 51 | raise NotImplementedError 52 | else: 53 | clipped_string = string[:max_length] 54 | 55 | return clipped_string 56 | 57 | 58 | def formatter_for_tabulate( 59 | results: dict, 60 | method_names: tuple, 61 | dataset_names: tuple, 62 | dataset_titlefmt: str = "Dataset: {}", 63 | method_name_length=None, 64 | metric_value_length=None, 65 | tablefmt="github", 66 | ): 67 | """ 68 | tabulate format: 69 | 70 | :: 71 | 72 | table = [["spam",42],["eggs",451],["bacon",0]] 73 | headers = ["item", "qty"] 74 | print(tabulate(table, headers, tablefmt="github")) 75 | 76 | | item | qty | 77 | |--------|-------| 78 | | spam | 42 | 79 | | eggs | 451 | 80 | | bacon | 0 | 81 | 82 | 本函数的作用: 83 | 针对不同的数据集各自构造符合tabulate格式的列表并使用换行符间隔串联起来返回 84 | """ 85 | all_tables = [] 86 | for dataset_name in dataset_names: 87 | dataset_metrics = results[dataset_name] 88 | all_tables.append(dataset_titlefmt.format(dataset_name)) 89 | 90 | table = [] 91 | headers = ["methods"] 92 | for method_name in method_names: 93 | metric_info = dataset_metrics.get(method_name) 94 | if metric_info is None: 95 | continue 96 | 97 | if method_name_length: 98 | method_name = clip_string(method_name, max_length=method_name_length, mode="left") 99 | method_row = [method_name] 100 | 101 | for metric_name, metric_value in metric_info.items(): 102 | if metric_value_length: 103 | metric_value = clip_string( 104 | str(metric_value), max_length=metric_value_length, mode="center" 105 | ) 106 | if metric_name not in headers: 107 | headers.append(metric_name) 108 | method_row.append(metric_value) 109 | table.append(method_row) 110 | all_tables.append(tabulate(table, headers, tablefmt=tablefmt)) 111 | 112 | formatted_string = "\n".join(all_tables) 113 | return formatted_string 114 | -------------------------------------------------------------------------------- /utils/recorders/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from .curve_drawer import CurveDrawer 3 | from .excel_recorder import MetricExcelRecorder 4 | from .metric_recorder import ( 5 | BINARY_METRIC_MAPPING, 6 | GRAYSCALE_METRICS, 7 | SUPPORTED_METRICS, 8 | BinaryMetricRecorder, 9 | GrayscaleMetricRecorder, 10 | GroupedMetricRecorder, 11 | ) 12 | from .txt_recorder import TxtRecorder 13 | -------------------------------------------------------------------------------- /utils/recorders/curve_drawer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2021/1/4 3 | # @Author : Lart Pang 4 | # @GitHub : https://github.com/lartpang 5 | import math 6 | import os 7 | 8 | import matplotlib.pyplot as plt 9 | import numpy as np 10 | 11 | 12 | class CurveDrawer(object): 13 | def __init__( 14 | self, 15 | row_num, 16 | num_subplots, 17 | style_cfg=None, 18 | ncol_of_legend=1, 19 | separated_legend=False, 20 | sharey=False, 21 | ): 22 | """A better wrapper of matplotlib for me. 23 | 24 | Args: 25 | row_num (int): Number of rows. 26 | num_subplots (int): Number of subplots. 27 | style_cfg (str, optional): Style yaml file path for matplotlib. Defaults to None. 
28 |             ncol_of_legend (int, optional): Number of columns of the legend. Defaults to 1.
29 |             separated_legend (bool, optional): Use the separated legend. Defaults to False.
30 |             sharey (bool, optional): Use a shared y-axis. Defaults to False.
31 |         """
32 |         if style_cfg is not None:
33 |             assert os.path.isfile(style_cfg)
34 |             plt.style.use(style_cfg)
35 | 
36 |         self.ncol_of_legend = ncol_of_legend
37 |         self.separated_legend = separated_legend
38 |         if self.separated_legend:
39 |             num_subplots += 1
40 |         self.num_subplots = num_subplots
41 |         self.sharey = sharey
42 | 
43 |         fig, axes = plt.subplots(
44 |             nrows=row_num, ncols=math.ceil(self.num_subplots / row_num), sharey=self.sharey
45 |         )
46 |         self.fig = fig
47 |         self.axes = axes
48 |         if isinstance(self.axes, np.ndarray):
49 |             self.axes = self.axes.flatten()
50 |         else:
51 |             self.axes = [self.axes]
52 | 
53 |         self.init_subplots()
54 |         self.dummy_data = {}
55 | 
56 |     def init_subplots(self):
57 |         for ax in self.axes:
58 |             ax.set_axis_off()
59 | 
60 |     def plot_at_axis(self, idx, method_curve_setting, x_data, y_data):
61 |         """
62 |         :param method_curve_setting: {
63 |             "line_color": "color"(str),
64 |             "line_style": "style"(str),
65 |             "line_label": "label"(str),
66 |             "line_width": width(int),
67 |         }
68 |         """
69 |         assert isinstance(idx, int) and 0 <= idx < self.num_subplots
70 |         self.axes[idx].plot(
71 |             x_data,
72 |             y_data,
73 |             linewidth=method_curve_setting["line_width"],
74 |             label=method_curve_setting["line_label"],
75 |             color=method_curve_setting["line_color"],
76 |             linestyle=method_curve_setting["line_style"],
77 |         )
78 | 
79 |         if self.separated_legend:
80 |             self.dummy_data[method_curve_setting["line_label"]] = method_curve_setting
81 | 
82 |     def set_axis_property(
83 |         self, idx, title=None, x_label=None, y_label=None, x_ticks=None, y_ticks=None
84 |     ):
85 |         ax = self.axes[idx]
86 | 
87 |         ax.set_axis_on()
88 | 
89 |         # give plot a title
90 |         ax.set_title(title)
91 | 
92 |         # make axis labels
93 |         ax.set_xlabel(x_label)
94 |         ax.set_ylabel(y_label)
95 | 
96 |         # settings for the axis ticks
97 |         x_ticks = [] if x_ticks is None else x_ticks
98 |         y_ticks = [] if y_ticks is None else y_ticks
99 |         ax.set_xlim((min(x_ticks), max(x_ticks)))
100 |         ax.set_ylim((min(y_ticks), max(y_ticks)))
101 |         ax.set_xticks(x_ticks)
102 |         ax.set_yticks(y_ticks)
103 |         ax.set_xticklabels(labels=[f"{x:.2f}" for x in x_ticks])
104 |         ax.set_yticklabels(labels=[f"{y:.2f}" for y in y_ticks])
105 | 
106 |     def _plot(self):
107 |         if self.sharey:
108 |             for ax in self.axes[1:]:
109 |                 ax.set_ylabel(None)
110 |                 ax.tick_params(bottom=True, top=False, left=False, right=False)
111 | 
112 |         if self.separated_legend:
113 |             # settings for the legend axis
114 |             for method_label, method_info in self.dummy_data.items():
115 |                 self.plot_at_axis(
116 |                     idx=self.num_subplots - 1, method_curve_setting=method_info, x_data=0, y_data=0
117 |                 )
118 |             ax = self.axes[self.num_subplots - 1]
119 |             ax.set_axis_off()
120 |             ax.legend(loc=10, ncol=self.ncol_of_legend, facecolor="white", edgecolor="white")
121 |         else:
122 |             # settings for the legends of all common subplots
123 |             for ax in self.axes:
124 |                 # with loc=0, matplotlib would pick the best position automatically
125 |                 ax.legend(loc=3, ncol=self.ncol_of_legend, facecolor="white", edgecolor="white")
126 | 
127 |     def show(self):
128 |         self._plot()
129 |         plt.tight_layout()
130 |         plt.show()
131 | 
132 |     def save(self, path):
133 |         self._plot()
134 |         plt.tight_layout()
135 |         plt.savefig(path)
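A minimal usage sketch of `CurveDrawer` with a single subplot and made-up PR-style data; the curve-setting dict mirrors the one documented in `plot_at_axis`:

```python
import numpy as np

drawer = CurveDrawer(row_num=1, num_subplots=1)
setting = {"line_color": "red", "line_style": "-", "line_label": "MINet", "line_width": 3}
x = np.linspace(0, 1, 256)
drawer.plot_at_axis(0, setting, x_data=x, y_data=x ** 0.5)  # placeholder curve
drawer.set_axis_property(
    0, title="ECSSD", x_label="Recall", y_label="Precision",
    x_ticks=np.linspace(0, 1, 6), y_ticks=np.linspace(0, 1, 6),
)
drawer.save("pr_curves.png")
```

--------------------------------------------------------------------------------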
/utils/recorders/excel_recorder.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2021/1/3
3 | # @Author : Lart Pang
4 | # @GitHub : https://github.com/lartpang
5 | 
6 | import os
7 | import re
8 | 
9 | from openpyxl import Workbook, load_workbook
10 | from openpyxl.utils import get_column_letter
11 | from openpyxl.worksheet.worksheet import Worksheet
12 | 
13 | 
14 | # Thanks:
15 | # - Python_Openpyxl: https://www.cnblogs.com/programmer-tlh/p/10461353.html
16 | # - Python之re模块: https://www.cnblogs.com/shenjianping/p/11647473.html
17 | class _BaseExcelRecorder(object):
18 |     def __init__(self, xlsx_path: str):
19 |         """
20 |         Base class providing the ability to write xlsx files; mainly a more convenient wrapper around openpyxl.
21 | 
22 |         :param xlsx_path: the path of the xlsx file.
23 |         """
24 |         self.xlsx_path = xlsx_path
25 |         if not os.path.exists(self.xlsx_path):
26 |             print("We have created a new excel file!!!")
27 |             self._initial_xlsx()
28 |         else:
29 |             print("Excel file has existed!")
30 | 
31 |     def _initial_xlsx(self):
32 |         Workbook().save(self.xlsx_path)
33 | 
34 |     def load_sheet(self, sheet_name: str):
35 |         wb = load_workbook(self.xlsx_path)
36 |         if sheet_name not in wb.sheetnames:
37 |             wb.create_sheet(title=sheet_name, index=0)
38 |         sheet = wb[sheet_name]
39 | 
40 |         return wb, sheet
41 | 
42 |     def append_row(self, sheet: Worksheet, row_data):
43 |         assert isinstance(row_data, (tuple, list))
44 |         sheet.append(row_data)
45 | 
46 |     def insert_row(self, sheet: Worksheet, row_data, row_id, min_col=1, interval=0):
47 |         assert isinstance(row_id, int) and isinstance(min_col, int) and row_id > 0 and min_col > 0
48 |         assert isinstance(row_data, (tuple, list)), row_data
49 | 
50 |         num_elements = len(row_data)
51 |         row_data = iter(row_data)
52 |         for row in sheet.iter_rows(
53 |             min_row=row_id,
54 |             max_row=row_id,
55 |             min_col=min_col,
56 |             max_col=min_col + (interval + 1) * (num_elements - 1),
57 |         ):
58 |             for i, cell in enumerate(row):
59 |                 if i % (interval + 1) == 0:
60 |                     sheet.cell(row=row_id, column=cell.column, value=next(row_data))
61 | 
62 |     @staticmethod
63 |     def merge_region(sheet: Worksheet, min_row, max_row, min_col, max_col):
64 |         assert max_row >= min_row > 0 and max_col >= min_col > 0
65 | 
66 |         merged_region = (
67 |             f"{get_column_letter(min_col)}{min_row}:{get_column_letter(max_col)}{max_row}"
68 |         )
69 |         sheet.merge_cells(merged_region)
70 | 
71 |     @staticmethod
72 |     def get_col_id_with_row_id(sheet: Worksheet, col_name: str, row_id):
73 |         """
74 |         Search the row with the given id for the given column name and return the corresponding column index.
75 |         """
76 |         assert isinstance(row_id, int) and row_id > 0
77 | 
78 |         for cell in sheet[row_id]:
79 |             if cell.value == col_name:
80 |                 return cell.column
81 |         raise ValueError(f"In row {row_id}, there is no column {col_name}!")
82 | 
83 |     def get_row_id_with_col_name(self, sheet: Worksheet, row_name: str, col_name: str):
84 |         """
85 |         In the column with the given column name, search for the given row; returns (row_id, col_id), is_new_row.
86 |         """
87 |         is_new_row = True
88 |         col_id = self.get_col_id_with_row_id(sheet=sheet, col_name=col_name, row_id=1)
89 | 
90 |         row_id = 0
91 |         for cell in sheet[get_column_letter(col_id)]:
92 |             row_id = cell.row
93 |             if cell.value == row_name:
94 |                 return (row_id, col_id), not is_new_row
95 |         return (row_id + 1, col_id), is_new_row
96 | 
97 |     @staticmethod
98 |     def get_row_id_with_col_id(sheet: Worksheet, row_name: str, col_id: int):
99 |         """
100 |         In the column with the given index, search for the given row.
101 |         """
102 |         assert isinstance(col_id, int) and col_id > 0
103 | 
104 |         is_new_row = True
105 |         row_id = 0
106 |         for cell in sheet[get_column_letter(col_id)]:
107 |             row_id = cell.row
108 |             if cell.value == row_name:
109 |                 return row_id, not is_new_row
110 |         return row_id + 1, is_new_row
111 | 
112 |     @staticmethod
113 |     def format_string_with_config(string: str, repalce_config: dict = None):
114 |         assert repalce_config is not None
115 |         if repalce_config.get("lower"):
116 |             string = string.lower()
117 |         elif repalce_config.get("upper"):
118 |             string = string.upper()
119 |         elif repalce_config.get("title"):
120 |             string = string.title()
121 | 
122 |         sub_rule = repalce_config.get("replace")
123 |         if sub_rule:
124 |             string = re.sub(pattern=sub_rule[0], repl=sub_rule[1], string=string)
125 |         return string
126 | 
127 | 
128 | class MetricExcelRecorder(_BaseExcelRecorder):
129 |     def __init__(
130 |         self,
131 |         xlsx_path: str,
132 |         sheet_name: str = None,
133 |         repalce_config=None,
134 |         row_header=None,
135 |         dataset_names=None,
136 |         metric_names=None,
137 |     ):
138 |         """
139 |         Class for writing data into an xlsx file.
140 | 
141 |         :param xlsx_path: the path of the xlsx file
142 |         :param sheet_name: the name of the sheet the data is written to;
143 |             defaults to `results`
144 |         :param repalce_config: the pattern used to rewrite the keys of the data dict, applied via re.sub;
145 |             defaults to dict(lower=True, replace=(r"[_-]", ""))
146 |         :param row_header: the content of the top-left cells of the sheet; defaults to `["methods", "num_data"]`
147 |         :param dataset_names: the list of dataset names;
148 |             defaults to the rgb sod datasets ["pascals", "ecssd", "hkuis", "dutste", "dutomron"]
149 |         :param metric_names: the list of metric names;
150 |             defaults to ["smeasure","wfmeasure","mae","adpfm","meanfm","maxfm","adpem","meanem","maxem"]
151 |         """
152 |         super().__init__(xlsx_path=xlsx_path)
153 |         if sheet_name is None:
154 |             sheet_name = "results"
155 | 
156 |         if repalce_config is None:
157 |             self.repalce_config = dict(lower=True, replace=(r"[_-]", ""))
158 |         else:
159 |             self.repalce_config = repalce_config
160 | 
161 |         if row_header is None:
162 |             row_header = ["methods", "num_data"]
163 |         self.row_header = [
164 |             self.format_string_with_config(s, self.repalce_config) for s in row_header
165 |         ]
166 |         if dataset_names is None:
167 |             dataset_names = ["pascals", "ecssd", "hkuis", "dutste", "dutomron"]
168 |         self.dataset_names = [
169 |             self.format_string_with_config(s, self.repalce_config) for s in dataset_names
170 |         ]
171 |         if metric_names is None:
172 |             metric_names = [
173 |                 "smeasure",
174 |                 "wfmeasure",
175 |                 "mae",
176 |                 "adpfm",
177 |                 "meanfm",
178 |                 "maxfm",
179 |                 "adpem",
180 |                 "meanem",
181 |                 "maxem",
182 |             ]
183 |         self.metric_names = [
184 |             self.format_string_with_config(s, self.repalce_config) for s in metric_names
185 |         ]
186 | 
187 |         self.sheet_name = sheet_name
188 |         self._initial_table()
189 | 
190 |     def _initial_table(self):
191 |         wb, sheet = self.load_sheet(sheet_name=self.sheet_name)
192 | 
193 |         # insert the row header
194 |         self.insert_row(sheet=sheet, row_data=self.row_header, row_id=1, min_col=1)
195 |         # merge the cells of the row header
196 |         for col_id in range(len(self.row_header)):
197 |             self.merge_region(
198 |                 sheet=sheet, min_row=1, max_row=2, min_col=col_id + 1, max_col=col_id + 1
199 |             )
200 | 
201 |         # insert the dataset info
202 |         self.insert_row(
203 |             sheet=sheet,
204 |             row_data=self.dataset_names,
205 |             row_id=1,
206 |             min_col=len(self.row_header) + 1,
207 |             interval=len(self.metric_names) - 1,
208 |         )
209 | 
210 |         # insert the metric info
211 |         start_col = len(self.row_header) + 1
212 |         for i in range(len(self.dataset_names)):
213 |             self.insert_row(
214 |                 sheet=sheet,
215 |                 row_data=self.metric_names,
216 |                 row_id=2,
217 |                 min_col=start_col + i * len(self.metric_names),
218 |             )
219 |         wb.save(self.xlsx_path)
220 | 
221 |     def _format_row_data(self, row_data: dict) -> list:
222 |         row_data = {
223 |             self.format_string_with_config(k, self.repalce_config): v for k, v in row_data.items()
224 |         }
225 |         return [row_data[n] for n in self.metric_names]
226 | 
227 |     def __call__(self, row_data: dict, dataset_name: str, method_name: str):
228 |         dataset_name = self.format_string_with_config(dataset_name, self.repalce_config)
229 |         assert (
230 |             dataset_name in self.dataset_names
231 |         ), f"{dataset_name} is not contained in {self.dataset_names}"
232 | 
233 |         # 1. load the sheet
234 |         wb, sheet = self.load_sheet(sheet_name=self.sheet_name)
235 |         # 2. check whether method_name already exists: if so, find its row/column coordinates; if not, use a new row
236 |         dataset_col_start_id = self.get_col_id_with_row_id(
237 |             sheet=sheet, col_name=dataset_name, row_id=1
238 |         )
239 |         (method_row_id, method_col_id), is_new_row = self.get_row_id_with_col_name(
240 |             sheet=sheet, row_name=method_name, col_name="methods"
241 |         )
242 |         # 3. insert the method name at the corresponding position
243 |         if is_new_row:
244 |             sheet.cell(row=method_row_id, column=method_col_id, value=method_name)
245 |         # 4. format the metric data into a reasonable form and insert it into the table
246 |         row_data = self._format_row_data(row_data=row_data)
247 |         self.insert_row(
248 |             sheet=sheet, row_data=row_data, row_id=method_row_id, min_col=dataset_col_start_id
249 |         )
250 |         # 5. save the updated sheet
251 |         wb.save(self.xlsx_path)
--------------------------------------------------------------------------------
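A minimal usage sketch of `MetricExcelRecorder` (the path and values are made up; the keys are normalized by `repalce_config`, so e.g. `S-measure` would become `smeasure`):

```python
recorder = MetricExcelRecorder(xlsx_path="output/results.xlsx")
recorder(
    row_data={"smeasure": 0.91, "wfmeasure": 0.87, "mae": 0.035, "adpfm": 0.88,
              "meanfm": 0.89, "maxfm": 0.90, "adpem": 0.92, "meanem": 0.93, "maxem": 0.94},
    dataset_name="ECSSD",
    method_name="MINet",
)
```

--------------------------------------------------------------------------------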
/utils/recorders/metric_recorder.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2021/1/4
3 | # @Author : Lart Pang
4 | # @GitHub : https://github.com/lartpang
5 | from collections import OrderedDict
6 | 
7 | import numpy as np
8 | import py_sod_metrics
9 | 
10 | 
11 | def ndarray_to_basetype(data):
12 |     """
13 |     Convert a single ndarray, or the ndarrays inside a tuple, list, or dict, into basic data types,
14 |     i.e. lists (.tolist()) and python scalars.
15 |     """
16 | 
17 |     def _to_list_or_scalar(item):
18 |         listed_item = item.tolist()
19 |         if isinstance(listed_item, list) and len(listed_item) == 1:
20 |             listed_item = listed_item[0]
21 |         return listed_item
22 | 
23 |     if isinstance(data, (tuple, list)):
24 |         results = [_to_list_or_scalar(item) for item in data]
25 |     elif isinstance(data, dict):
26 |         results = {k: _to_list_or_scalar(item) for k, item in data.items()}
27 |     else:
28 |         assert isinstance(data, np.ndarray)
29 |         results = _to_list_or_scalar(data)
30 |     return results
31 | 
32 | 
33 | def round_w_zero_padding(x, bit_width):
34 |     x = str(round(x, bit_width))
35 |     x += "0" * (bit_width - len(x.split(".")[-1]))
36 |     return x
37 | 
38 | 
39 | INDIVADUAL_METRIC_MAPPING = {
40 |     "mae": py_sod_metrics.MAE,
41 |     # "fm": py_sod_metrics.Fmeasure,
42 |     "em": py_sod_metrics.Emeasure,
43 |     "sm": py_sod_metrics.Smeasure,
44 |     "wfm": py_sod_metrics.WeightedFmeasure,
45 |     "msiou": py_sod_metrics.MSIoU,
46 | }
47 | 
48 | # fmt: off
49 | gray_metric_kwargs = dict(with_dynamic=True, with_adaptive=True, with_binary=False, sample_based=True)
50 | binary_metric_kwargs = dict(with_dynamic=False, with_adaptive=False, with_binary=True, sample_based=False)
51 | BINARY_METRIC_MAPPING = {
52 |     "fmeasure": {"handler": py_sod_metrics.FmeasureHandler, "kwargs": dict(**gray_metric_kwargs, beta=0.3)},
53 |     "f1": {"handler": py_sod_metrics.FmeasureHandler, "kwargs": dict(**gray_metric_kwargs, beta=1)},
54 |     "precision": {"handler": py_sod_metrics.PrecisionHandler, "kwargs": gray_metric_kwargs},
55 |     "recall": {"handler": py_sod_metrics.RecallHandler, "kwargs": gray_metric_kwargs},
56 |     "iou": {"handler": py_sod_metrics.IOUHandler, "kwargs": gray_metric_kwargs},
57 |     "dice": {"handler": py_sod_metrics.DICEHandler, "kwargs": gray_metric_kwargs},
58 |     "specificity": {"handler": py_sod_metrics.SpecificityHandler, "kwargs": gray_metric_kwargs},
59 |     #
60 |     "bif1": {"handler": py_sod_metrics.FmeasureHandler, "kwargs": dict(**binary_metric_kwargs, beta=1)},
61 |     "biprecision": {"handler": py_sod_metrics.PrecisionHandler, "kwargs": binary_metric_kwargs},
62 |     "birecall": {"handler": py_sod_metrics.RecallHandler, "kwargs": binary_metric_kwargs},
63 |     "biiou": {"handler": py_sod_metrics.IOUHandler, "kwargs": binary_metric_kwargs},
64 |     "bioa": {"handler": py_sod_metrics.OverallAccuracyHandler, "kwargs": binary_metric_kwargs},
65 |     "bikappa": {"handler": py_sod_metrics.KappaHandler, "kwargs": binary_metric_kwargs},
66 | }
67 | GRAYSCALE_METRICS = ["em"] + [k for k in BINARY_METRIC_MAPPING.keys() if not k.startswith("bi")]
68 | SUPPORTED_METRICS = ["mae", "em", "sm", "wfm", "msiou"] + sorted(BINARY_METRIC_MAPPING.keys())
69 | # fmt: on
70 | 
71 | 
72 | class GrayscaleMetricRecorder:
73 |     # 'fm' is replaced by 'fmeasure' in BINARY_METRIC_MAPPING
74 |     suppoted_metrics = ["mae", "em", "sm", "wfm", "msiou"] + sorted(
75 |         [k for k in BINARY_METRIC_MAPPING.keys() if not k.startswith("bi")]
76 |     )
77 | 
78 |     def __init__(self, metric_names=None):
79 |         """
80 |         Class for accumulating various metrics.
81 |         """
82 |         if not metric_names:
83 |             metric_names = self.suppoted_metrics
84 |         assert all(
85 |             [m in self.suppoted_metrics for m in metric_names]
86 |         ), f"Only support: {self.suppoted_metrics}"
87 | 
88 |         self.metric_objs = {}
89 |         has_existed = False
90 |         for metric_name in metric_names:
91 |             if metric_name in INDIVADUAL_METRIC_MAPPING:
92 |                 self.metric_objs[metric_name] = INDIVADUAL_METRIC_MAPPING[metric_name]()
93 |             else:  # metric_name in BINARY_METRIC_MAPPING
94 |                 if not has_existed:  # only init once
95 |                     self.metric_objs["fmeasurev2"] = py_sod_metrics.FmeasureV2()
96 |                     has_existed = True
97 |                 metric_handler = BINARY_METRIC_MAPPING[metric_name]
98 |                 self.metric_objs["fmeasurev2"].add_handler(
99 |                     handler_name=metric_name,
100 |                     metric_handler=metric_handler["handler"](**metric_handler["kwargs"]),
101 |                 )
102 | 
103 |     def step(self, pre: np.ndarray, gt: np.ndarray, gt_path: str):
104 |         assert pre.shape == gt.shape, (pre.shape, gt.shape, gt_path)
105 |         assert pre.dtype == gt.dtype == np.uint8, (pre.dtype, gt.dtype, gt_path)
106 | 
107 |         for m_obj in self.metric_objs.values():
108 |             m_obj.step(pre, gt)
109 | 
110 |     def show(self, num_bits: int = 3, return_ndarray: bool = False) -> dict:
111 |         """
112 |         Return the metric results:
113 | 
114 |         - curve data (sequential)
115 |         - numerical metrics (numerical)
116 |         """
117 |         sequential_results = {}
118 |         numerical_results = {}
119 |         for m_name, m_obj in self.metric_objs.items():
120 |             info = m_obj.get_results()
121 |             if m_name == "fmeasurev2":
122 |                 for _name, results in info.items():
123 |                     dynamic_results = results.get("dynamic")
124 |                     adaptive_results = results.get("adaptive")
125 |                     if dynamic_results is not None:
126 |                         sequential_results[_name] = np.flip(dynamic_results)
127 |                         numerical_results[f"max{_name}"] = dynamic_results.max()
128 |                         numerical_results[f"avg{_name}"] = dynamic_results.mean()
129 |                     if adaptive_results is not None:
numerical_results[f"adp{_name}"] = adaptive_results 131 | else: 132 | results = info[m_name] 133 | if m_name in ("wfm", "sm", "mae", "msiou"): 134 | numerical_results[m_name] = results 135 | elif m_name == "em": 136 | sequential_results[m_name] = np.flip(results["curve"]) 137 | numerical_results.update( 138 | { 139 | "maxem": results["curve"].max(), 140 | "avgem": results["curve"].mean(), 141 | "adpem": results["adp"], 142 | } 143 | ) 144 | else: 145 | raise NotImplementedError(m_name) 146 | 147 | if num_bits is not None and isinstance(num_bits, int): 148 | numerical_results = {k: v.round(num_bits) for k, v in numerical_results.items()} 149 | if not return_ndarray: 150 | sequential_results = ndarray_to_basetype(sequential_results) 151 | numerical_results = ndarray_to_basetype(numerical_results) 152 | return {"sequential": sequential_results, "numerical": numerical_results} 153 | 154 | 155 | class BinaryMetricRecorder: 156 | suppoted_metrics = sorted([k for k in BINARY_METRIC_MAPPING.keys() if k.startswith("bi")]) 157 | 158 | def __init__(self, metric_names=("bif1", "biprecision", "birecall", "biiou", "bioa")): 159 | """ 160 | 用于统计各种指标的类 161 | """ 162 | if not metric_names: 163 | metric_names = self.suppoted_metrics 164 | assert all( 165 | [m in self.suppoted_metrics for m in metric_names] 166 | ), f"Only support: {self.suppoted_metrics}" 167 | 168 | self.metric_objs = {"fmeasurev2": py_sod_metrics.FmeasureV2()} 169 | for metric_name in metric_names: 170 | # metric_name in BINARY_CLASSIFICATION_METRIC_MAPPING 171 | metric_handler = BINARY_METRIC_MAPPING[metric_name] 172 | self.metric_objs["fmeasurev2"].add_handler( 173 | handler_name=metric_name, 174 | metric_handler=metric_handler["handler"](**metric_handler["kwargs"]), 175 | ) 176 | 177 | def step(self, pre: np.ndarray, gt: np.ndarray, gt_path: str): 178 | assert pre.shape == gt.shape, (pre.shape, gt.shape, gt_path) 179 | assert pre.dtype == gt.dtype == np.uint8, (pre.dtype, gt.dtype, gt_path) 180 | 181 | for m_name, m_obj in self.metric_objs.items(): 182 | m_obj.step(pre, gt, normalize=True) 183 | 184 | def show(self, num_bits: int = 3, return_ndarray: bool = False) -> dict: 185 | numerical_results = {} 186 | for m_name, m_obj in self.metric_objs.items(): 187 | info = m_obj.get_results() 188 | assert m_name == "fmeasurev2" 189 | for _name, results in info.items(): 190 | binary_results = results.get("binary") 191 | if binary_results is not None: 192 | numerical_results[_name] = binary_results 193 | 194 | if num_bits is not None and isinstance(num_bits, int): 195 | numerical_results = {k: v.round(num_bits) for k, v in numerical_results.items()} 196 | if not return_ndarray: 197 | numerical_results = ndarray_to_basetype(numerical_results) 198 | return {"numerical": numerical_results} 199 | 200 | 201 | class GroupedMetricRecorder: 202 | def __init__( 203 | self, group_names=None, metric_names=("sm", "wfm", "mae", "fmeasure", "em", "iou", "dice") 204 | ): 205 | self.group_names = group_names 206 | self.metric_names = metric_names 207 | self.zero() 208 | 209 | def zero(self): 210 | self.metric_recorders = {} 211 | if self.group_names is not None: 212 | self.metric_recorders.update( 213 | { 214 | n: GrayscaleMetricRecorder(metric_names=self.metric_names) 215 | for n in self.group_names 216 | } 217 | ) 218 | 219 | def step(self, group_name: str, pre: np.ndarray, gt: np.ndarray, gt_path: str): 220 | if group_name not in self.metric_recorders: 221 | self.metric_recorders[group_name] = GrayscaleMetricRecorder( 222 | metric_names=self.metric_names 
223 |             )
224 |         self.metric_recorders[group_name].step(pre, gt, gt_path)
225 | 
226 |     def show(self, num_bits: int = 3, return_group: bool = False):
227 |         groups_metrics = {
228 |             n: r.show(num_bits=None, return_ndarray=True) for n, r in self.metric_recorders.items()
229 |         }
230 | 
231 |         results = {}  # collect all group metrics into a list
232 |         for group_name, group_metrics in groups_metrics.items():
233 |             for metric_type, metric_group in group_metrics.items():
234 |                 # metric_type: sequential and numerical
235 |                 results.setdefault(metric_type, {})
236 |                 for metric_name, metric_array in metric_group.items():
237 |                     results[metric_type].setdefault(metric_name, []).append(metric_array)
238 | 
239 |         numerical_results = {}
240 |         sequential_results = {}
241 |         for metric_type, metric_group in results.items():
242 |             for metric_name, metric_arrays in metric_group.items():
243 |                 metric_array = np.mean(np.vstack(metric_arrays), axis=0)  # average over all groups
244 | 
245 |                 if metric_name in BINARY_METRIC_MAPPING or metric_name == "em":
246 |                     if metric_type == "sequential":
247 |                         numerical_results[f"max{metric_name}"] = metric_array.max()
248 |                         numerical_results[f"avg{metric_name}"] = metric_array.mean()
249 |                         sequential_results[metric_name] = metric_array
250 |                 else:
251 |                     if metric_type == "numerical":
252 |                         if metric_name.startswith(("max", "avg")):
253 |                             # metrics (maxfm, avgfm, maxem, avgem) will be recomputed within the group
254 |                             continue
255 |                         numerical_results[metric_name] = metric_array
256 | 
257 |         sequential_results = ndarray_to_basetype(sequential_results)
258 |         if not return_group:
259 |             numerical_results = {k: v.round(num_bits) for k, v in numerical_results.items()}
260 |             numerical_results = ndarray_to_basetype(numerical_results)
261 |             numerical_results = self.sort_results(numerical_results)
262 |             return {"sequential": sequential_results, "numerical": numerical_results}
263 |         else:
264 |             group_numerical_results = {}
265 |             for group_name, group_metric in groups_metrics.items():
266 |                 group_metric = {k: v.round(num_bits) for k, v in group_metric["numerical"].items()}
267 |                 group_metric = ndarray_to_basetype(group_metric)
268 |                 group_numerical_results[group_name] = self.sort_results(group_metric)
269 |             return {"sequential": sequential_results, "numerical": group_numerical_results}
270 | 
271 |     def sort_results(self, results: dict) -> OrderedDict:
272 |         """for a single group of metrics"""
273 |         sorted_results = OrderedDict()
274 |         all_keys = sorted(results.keys(), key=lambda item: item[::-1])
275 |         for name in self.metric_names:
276 |             for key in all_keys:
277 |                 if key.endswith(name):
278 |                     sorted_results[key] = results[key]
279 |         return sorted_results
--------------------------------------------------------------------------------
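A minimal usage sketch of the recorders above (made-up arrays; real callers feed images loaded via the helpers in utils/misc.py):

```python
import numpy as np

recorder = GrayscaleMetricRecorder(metric_names=["sm", "mae", "fmeasure"])
pre = np.random.randint(0, 256, size=(64, 64), dtype=np.uint8)  # grayscale prediction
gt = (np.random.rand(64, 64) > 0.5).astype(np.uint8) * 255      # binary mask
recorder.step(pre, gt, gt_path="fake/mask.png")
results = recorder.show(num_bits=3)
# results["numerical"] holds e.g. "sm", "mae", "maxfmeasure", "avgfmeasure", "adpfmeasure";
# results["sequential"] holds the per-threshold f-measure curve
```

--------------------------------------------------------------------------------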
/utils/recorders/txt_recorder.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2021/1/4
3 | # @Author : Lart Pang
4 | # @GitHub : https://github.com/lartpang
5 | 
6 | from datetime import datetime
7 | 
8 | 
9 | class TxtRecorder:
10 |     def __init__(self, txt_path, to_append=True, max_method_name_width=10):
11 |         """
12 |         Class for writing data into a txt file.
13 | 
14 |         :param txt_path: the path of the txt file
15 |         :param to_append: whether to keep using an existing file; if it does not exist, a new one is created
16 |         :param max_method_name_width: the maximum width of the method name string
17 |         """
18 |         self.txt_path = txt_path
19 |         self.max_method_name_width = max_method_name_width
20 |         mode = "a" if to_append else "w"
21 |         with open(txt_path, mode=mode, encoding="utf-8") as f:
22 |             f.write(f"\n ========>> Date: {datetime.now()} <<======== \n")
23 |         self.row_names = []
24 | 
25 |     def add_row(self, row_name, row_data, row_start_str="", row_end_str="\n"):
26 |         self.row_names.append(row_name)
27 |         with open(self.txt_path, mode="a", encoding="utf-8") as f:
28 |             f.write(f"{row_start_str} ========>> {row_name}: {row_data} <<======== {row_end_str}")
29 | 
30 |     def __call__(
31 |         self,
32 |         method_results: dict,
33 |         method_name: str = "",
34 |         row_start_str="",
35 |         row_end_str="\n",
36 |         value_width=6,
37 |     ):
38 |         msg = row_start_str
39 |         if len(method_name) > self.max_method_name_width:
40 |             method_name = method_name[: self.max_method_name_width - 3] + "..."
41 |         else:
42 |             method_name += " " * (self.max_method_name_width - len(method_name))
43 |         msg += f"[{method_name}] "
44 |         for metric_name, metric_value in method_results.items():
45 |             assert isinstance(metric_value, float)
46 |             msg += f"{metric_name}: "
47 |             real_width = len(str(metric_value))
48 |             if value_width > real_width:
49 |                 # pad with trailing spaces
50 |                 msg += f"{metric_value}" + " " * (value_width - real_width)
51 |             else:
52 |                 # the real value is longer than the limit, so round it to fit
53 |                 # keep the given number of digits; since the values are all in [0, 1], the leading `0.` is excluded from the width when rounding
54 |                 msg += f"{round(metric_value, ndigits=value_width - 2)}"
55 |             msg += " "
56 |         msg += row_end_str
57 |         with open(self.txt_path, mode="a", encoding="utf-8") as f:
58 |             f.write(msg)
--------------------------------------------------------------------------------
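Finally, a minimal usage sketch of `TxtRecorder` (hypothetical path and values):

```python
recorder = TxtRecorder(txt_path="output/results.txt", to_append=True)
recorder.add_row(row_name="Dataset", row_data="ECSSD")
recorder({"sm": 0.912, "mae": 0.035, "maxfmeasure": 0.904}, method_name="MINet")
# output/results.txt gains a dated header, a "Dataset: ECSSD" row, and one aligned metric row
```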