├── .gitignore ├── LICENSE ├── README.md ├── assets ├── benchmark_camo.png ├── benchmark_cod10k.png ├── benchmark_nc4k.png ├── cds2k-benchmark.png ├── cds2k-statistics.png ├── cds2k.png ├── cos_quali_viz.png ├── csu-logo.png ├── dataset_sample_gallery.png ├── reviewed_datasets.png ├── reviewed_image_methods.png ├── reviewed_video_methods.png └── task_definition.png └── cos_eval_toolbox ├── .editorconfig ├── .github └── workflows │ └── publish.yml ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── eval.py ├── examples_COS ├── config_cos_dataset_py_example.json ├── config_cos_dataset_py_example.py ├── config_cos_method_py_example.json └── config_cos_method_py_example.py ├── metrics ├── __init__.py ├── cal_cosod_matrics.py ├── cal_sod_matrics.py ├── draw_curves.py └── extra_metrics.py ├── output_COS ├── cos_curves.npy ├── cos_metrics.npy └── cos_results.txt ├── plot.py ├── pyproject.toml ├── requirements.txt ├── tools ├── .backup │ ├── individual_metrics.py │ └── ranking_per_images.py ├── append_results.py ├── cal_avg_resolution.py ├── check_path.py ├── converter.py ├── generate_cos_config_files.py ├── generate_latex_code.py ├── info_py_to_json.py ├── markdown2html.py ├── readme.md └── rename.py └── utils ├── __init__.py ├── generate_info.py ├── misc.py ├── print_formatter.py └── recorders ├── __init__.py ├── curve_drawer.py ├── excel_recorder.py ├── metric_recorder.py └── txt_recorder.py /.gitignore: -------------------------------------------------------------------------------- 1 | cos_eval_toolbox/benchmark/ 2 | cos_eval_toolbox/.backup/ 3 | cos_eval_toolbox/dataset/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Deng-Ping Fan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 
| in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Advances in Deep Concealed Scene Understanding 2 | 3 | 4 | 5 | This repository contains a collection of research papers, an evaluation toolbox, and benchmarking results for the task of concealed object segmentation (COS) in images. In addition, to evaluate the generalizability of COS approaches, we reorganize a concealed defect segmentation dataset named CDS2K. 6 | 7 | - Paper link: [arXiv](https://arxiv.org/abs/2304.11234) 8 | - This project is under construction. Contributions are welcome! If you would like to contribute to this repository, please submit a pull request.
9 | 10 | ## Table of Contents 11 | 12 | - [Advances in Deep Concealed Scene Understanding](#advances-in-deep-concealed-scene-understanding) 13 | - [Table of Contents](#table-of-contents) 14 | - [CSU Background](#csu-background) 15 | - [CSU Taxonomy](#csu-taxonomy) 16 | - [CSU Survey](#csu-survey) 17 | - [CSU Benchmark](#csu-benchmark) 18 | - [Defect Segmentation Dataset -- CDS2K](#defect-segmentation-dataset----cds2k) 19 | - [Citation](#citation) 20 | 21 | ## CSU Background 22 | 23 | Concealed scene understanding (CSU) is an active computer vision topic that aims to perceive objects with camouflaged properties. The current boom in its advanced techniques and novel applications makes this a timely moment for an up-to-date survey that gives researchers a global picture of the CSU field, including both current achievements and major challenges. 24 | 25 |

26 |
27 | 28 | Figure 1: Sample gallery of concealed scenarios. (a-d) show natural animals. (e) depicts a concealed human in art. (f) features a synthesized "lion". 29 | 30 |

31 | 32 | This paper makes four contributions: 33 | - For the first time, we present **a comprehensive survey** of deep learning techniques for CSU, including a background with its taxonomy, task-unique challenges, and a review of its development in the deep learning era through a survey of existing datasets and deep techniques. 34 | - For a quantitative comparison of state-of-the-art methods, we contribute **the largest and latest benchmark** for Concealed Object Segmentation (COS). 35 | - To evaluate the transferability of deep CSU in practical scenarios, 36 | we reorganize **the largest concealed defect segmentation dataset**, termed CDS2K, with hard cases from diverse industrial scenarios, on which we construct a comprehensive benchmark. 37 | - We **discuss open problems and potential research directions** for this community. 38 | 39 | ## CSU Taxonomy 40 | 41 | We introduce a taxonomy of seven popular CSU tasks. Please refer to Section 2.1 of our paper for more details. 42 | - Five of these are image-level tasks: (a) concealed object segmentation (COS), (b) concealed object localization (COL), (c) concealed instance ranking (CIR), (d) concealed instance segmentation (CIS), and (e) concealed object counting (COC). 43 | - The remaining two are video-level tasks: (f) video concealed object segmentation (VCOS) and (g) video concealed object detection (VCOD). 44 | 45 | We illustrate each task with its corresponding annotation visualization. 46 | 47 |

48 |
49 | 50 | Figure 2: Illustration of representative CSU tasks. 51 | 52 |

53 | 54 | ## CSU Survey 55 | 56 | We review 50 recent image-based CSU papers. 57 | 58 |

59 |
60 | 61 | Table 1: Essential characteristics of reviewed image-level CSU methods. 62 | 63 |

64 | 65 | We also review nine recent video-based CSU methods. 66 | 67 |

68 |
69 | 70 | Table 2: Essential characteristics of reviewed video-level CSU methods. 71 | 72 |

73 | 74 | The following are ten datasets collected for several CSU-related tasks. 75 | 76 |

77 |
78 | 79 | Table 3: Essential characteristics of reviewed CSU datasets. 80 | 81 |

82 | 83 | 84 | ## CSU Benchmark 85 | 86 | Our benchmark is built on the COS task, since this topic is relatively well established and offers a variety of competing approaches. **What do we provide here?** 87 | 88 | - First, we provide a one-key [evaluation toolbox](https://github.com/DengPingFan/CSU/tree/main/cos_eval_toolbox) for CSU. Follow its instructions to reproduce the results. 89 | - Second, we run COS approaches on three popular benchmarks (CAMO, NC4K, and COD10K), which we organize into a standard format (`*.png`) ([Google Drive, 1.16GB](https://drive.google.com/file/d/1v5AZ37YlSjKiBMrfYXhZ9wa9dJLyFuVD/view?usp=sharing)). The resulting prediction masks are publicly available [here (Google Drive, 4.82GB)](https://drive.google.com/file/d/1BPyE6KtQvi8f1gL0IVsILlkv5C00GZYg/view?usp=sharing) for convenient research. 90 | - The benchmark results on nine evaluation metrics are reported in Tables 4, 5, and 6 below. The corresponding text file is available [here](https://github.com/DengPingFan/CSU/tree/main/cos_eval_toolbox/output_COS). 91 | 92 |
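To make the metric names more concrete, here is a minimal NumPy sketch of two of the nine metrics, MAE and the adaptive F-measure, using their commonly cited definitions (β² = 0.3; adaptive threshold = twice the mean prediction value, capped at 1). The helper names `mae` and `adaptive_fmeasure` are hypothetical, and this is not the toolbox's own implementation, which builds on PySODEvalToolkit and may differ in details such as input normalization and empty-mask handling.

```python
import numpy as np


def mae(pred: np.ndarray, gt: np.ndarray) -> float:
    """Mean absolute error between a prediction in [0, 1] and a binary mask."""
    gt = gt.astype(bool)
    return float(np.mean(np.abs(pred.astype(np.float64) - gt)))


def adaptive_fmeasure(pred: np.ndarray, gt: np.ndarray, beta2: float = 0.3) -> float:
    """F-measure after binarizing `pred` at min(2 * mean(pred), 1)."""
    gt = gt.astype(bool)
    threshold = min(2.0 * float(pred.mean()), 1.0)
    binary = pred >= threshold
    tp = np.logical_and(binary, gt).sum()
    # Guard against division by zero for empty predictions or empty masks.
    precision = tp / max(int(binary.sum()), 1)
    recall = tp / max(int(gt.sum()), 1)
    if precision + recall == 0:
        return 0.0
    return float((1 + beta2) * precision * recall / (beta2 * precision + recall))


pred = np.array([[0.75, 0.5], [0.25, 0.0]])
gt = np.array([[True, True], [False, False]])
print(mae(pred, gt))                          # 0.25
print(round(adaptive_fmeasure(pred, gt), 4))  # 0.8125
```

Here the adaptive threshold is 0.75, so only the top-left pixel is kept: precision is 1.0, recall is 0.5, and the weighted harmonic mean gives 0.8125.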

93 |
94 | 95 | Table 4: Quantitative comparison on the CAMO testing set. 96 | 97 |

98 | 99 |

100 |
101 | 102 | Table 5: Quantitative comparison on the NC4K testing set. 103 | 104 |

105 | 106 |

107 |
108 | 109 | Table 6: Quantitative comparison on the COD10K testing set. 110 | 111 |

112 | 113 | - Lastly, we provide attribute-based analyses on the COD10K dataset. 114 | 115 |

116 |
117 | 118 | Figure 3: Qualitative results of ten COS approaches. For descriptions of the visual attributes in each column, refer to Section 5.6 of the paper. 119 | 120 |

121 | 122 | ## Defect Segmentation Dataset -- CDS2K 123 | 124 | We organize a concealed defect segmentation dataset ([Google Drive, 159MB](https://drive.google.com/file/d/1OGPR34qCNWHVYwyf9OY6IH-7WHzPkC7-/view?usp=sharing)) from five well-known defect segmentation databases. As shown in Figure 4, CDS2K comprises five sub-databases: (a-l) MVTecAD, (m-o) NEU, (p) CrackForest, (q) KolektorSDD, and (r) MagneticTile, with defective regions highlighted by red rectangles. The top-right panel shows a word-cloud visualization of CDS2K, and the bottom panel shows the number of positive/negative samples in each category of CDS2K. 125 | 126 |
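The positive/negative statistics above rest on a simple rule: a sample counts as positive when its ground-truth mask contains defective pixels. A minimal sketch of that tally follows. It assumes the masks are already loaded as NumPy arrays and paired with a category name; the loading step, the `count_samples` helper, and the category names in the usage example are all hypothetical, since the CDS2K folder layout is not described here.

```python
from collections import Counter
from typing import Dict, Iterable, Tuple

import numpy as np


def count_samples(samples: Iterable[Tuple[str, np.ndarray]]) -> Dict[str, Tuple[int, int]]:
    """Return {category: (num_positive, num_negative)} for (category, mask) pairs.

    A mask with at least one non-zero pixel counts as a positive (defective)
    sample; an all-zero mask counts as a negative sample.
    """
    positives: Counter = Counter()
    negatives: Counter = Counter()
    for category, mask in samples:
        if np.any(mask > 0):
            positives[category] += 1
        else:
            negatives[category] += 1
    categories = sorted(set(positives) | set(negatives))
    # Counter returns 0 for categories that never appeared on one side.
    return {c: (positives[c], negatives[c]) for c in categories}


samples = [
    ("crack", np.array([[0, 255], [0, 0]], dtype=np.uint8)),
    ("crack", np.zeros((2, 2), dtype=np.uint8)),
    ("tile", np.full((2, 2), 255, dtype=np.uint8)),
]
print(count_samples(samples))  # {'crack': (1, 1), 'tile': (1, 0)}
```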

127 |
128 | 129 | Figure 4: Sample gallery of our CDS2K. 130 | 131 |

132 | 133 | Table 7 presents the average ratio of defective regions for each category, indicating that most defective regions are relatively small. 134 | 135 |
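As a rough sketch of how such a per-category ratio can be computed, the snippet below divides the number of defective (non-zero) pixels by the image area and then averages within each category. Restricting the average to positive samples is an assumption of this sketch, as is the hypothetical `average_defect_ratio` helper. The exact protocol behind Table 7 may differ.

```python
from collections import defaultdict
from typing import Dict, Iterable, List, Tuple

import numpy as np


def average_defect_ratio(samples: Iterable[Tuple[str, np.ndarray]]) -> Dict[str, float]:
    """Mean fraction of defective pixels per category, over positive samples only."""
    ratios: Dict[str, List[float]] = defaultdict(list)
    for category, mask in samples:
        defect = mask > 0
        if defect.any():  # assumption: skip negative (defect-free) samples
            ratios[category].append(float(defect.mean()))
    return {c: float(np.mean(r)) for c, r in sorted(ratios.items())}


samples = [
    ("crack", np.array([[1, 0], [0, 0]])),  # 1/4 of the pixels are defective
    ("crack", np.array([[1, 1], [0, 0]])),  # 2/4 of the pixels are defective
    ("crack", np.zeros((2, 2))),            # negative sample, ignored
    ("tile", np.ones((4, 4))),
]
print(average_defect_ratio(samples))  # {'crack': 0.375, 'tile': 1.0}
```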

136 |
137 | 138 | Table 7: Average ratio of defective regions for each category of CDS2K. 139 | 140 |

141 | 142 | Next, we report a quantitative comparison on the positive samples of CDS2K. The corresponding result maps can be downloaded from [Google Drive (116.6MB)](https://drive.google.com/file/d/1GIP0hdppaBJxV1SSRSOIccn1UgFm0J-l/view?usp=sharing). 143 | 144 |
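To score your own predictions on CDS2K with the evaluation toolbox, the dataset has to be described in the same JSON schema as `examples_COS/config_cos_dataset_py_example.json`: a `root` plus `image`/`mask` entries, each with a `path` and a `suffix`. The sketch below builds such an entry. The `./dataset/CDS2K` root, the `Imgs`/`GT` sub-folder names, the suffixes, and the `make_dataset_entry` helper are all assumptions; adapt them to the actual layout of the download.

```python
import json


def make_dataset_entry(
    root: str,
    image_dir: str = "Imgs",
    mask_dir: str = "GT",
    image_suffix: str = ".jpg",
    mask_suffix: str = ".png",
) -> dict:
    """Build one dataset entry in the toolbox's dataset-config JSON schema."""
    return {
        "root": root,
        "image": {"path": f"{root}/{image_dir}", "suffix": image_suffix},
        "mask": {"path": f"{root}/{mask_dir}", "suffix": mask_suffix},
    }


# Hypothetical location of the CDS2K positive samples.
config = {"CDS2K": make_dataset_entry("./dataset/CDS2K")}
print(json.dumps(config, indent=2))
```

Save the printed JSON to a file (for example, a hypothetical `config_cds2k_dataset.json`) and pass it to `eval.py` through `--dataset-json`, together with a method JSON that maps each method name to its CDS2K prediction folder.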

145 |
146 | 147 | Table 8: Quantitative comparison on the positive samples of CDS2K. 148 | 149 |

150 | 151 | 152 | 153 | ## Citation 154 | 155 | Please cite our paper if you find the work useful: 156 | 157 | @article{fan2023csu, 158 | title={Advances in Deep Concealed Scene Understanding}, 159 | author={Fan, Deng-Ping and Ji, Ge-Peng and Xu, Peng and Cheng, Ming-Ming and Sakaridis, Christos and Van Gool, Luc}, 160 | journal={Visual Intelligence (VI)}, 161 | year={2023} 162 | } 163 | -------------------------------------------------------------------------------- /assets/benchmark_camo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DengPingFan/CSU/892e7bf716e75dd1506a97be80b1d04b03b21965/assets/benchmark_camo.png -------------------------------------------------------------------------------- /assets/benchmark_cod10k.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DengPingFan/CSU/892e7bf716e75dd1506a97be80b1d04b03b21965/assets/benchmark_cod10k.png -------------------------------------------------------------------------------- /assets/benchmark_nc4k.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DengPingFan/CSU/892e7bf716e75dd1506a97be80b1d04b03b21965/assets/benchmark_nc4k.png -------------------------------------------------------------------------------- /assets/cds2k-benchmark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DengPingFan/CSU/892e7bf716e75dd1506a97be80b1d04b03b21965/assets/cds2k-benchmark.png -------------------------------------------------------------------------------- /assets/cds2k-statistics.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DengPingFan/CSU/892e7bf716e75dd1506a97be80b1d04b03b21965/assets/cds2k-statistics.png 
-------------------------------------------------------------------------------- /assets/cds2k.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DengPingFan/CSU/892e7bf716e75dd1506a97be80b1d04b03b21965/assets/cds2k.png -------------------------------------------------------------------------------- /assets/cos_quali_viz.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DengPingFan/CSU/892e7bf716e75dd1506a97be80b1d04b03b21965/assets/cos_quali_viz.png -------------------------------------------------------------------------------- /assets/csu-logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DengPingFan/CSU/892e7bf716e75dd1506a97be80b1d04b03b21965/assets/csu-logo.png -------------------------------------------------------------------------------- /assets/dataset_sample_gallery.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DengPingFan/CSU/892e7bf716e75dd1506a97be80b1d04b03b21965/assets/dataset_sample_gallery.png -------------------------------------------------------------------------------- /assets/reviewed_datasets.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DengPingFan/CSU/892e7bf716e75dd1506a97be80b1d04b03b21965/assets/reviewed_datasets.png -------------------------------------------------------------------------------- /assets/reviewed_image_methods.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DengPingFan/CSU/892e7bf716e75dd1506a97be80b1d04b03b21965/assets/reviewed_image_methods.png -------------------------------------------------------------------------------- /assets/reviewed_video_methods.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/DengPingFan/CSU/892e7bf716e75dd1506a97be80b1d04b03b21965/assets/reviewed_video_methods.png -------------------------------------------------------------------------------- /assets/task_definition.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DengPingFan/CSU/892e7bf716e75dd1506a97be80b1d04b03b21965/assets/task_definition.png -------------------------------------------------------------------------------- /cos_eval_toolbox/.editorconfig: -------------------------------------------------------------------------------- 1 | # https://editorconfig.org/ 2 | 3 | root = true 4 | 5 | [*] 6 | indent_style = space 7 | indent_size = 4 8 | insert_final_newline = true 9 | trim_trailing_whitespace = true 10 | end_of_line = lf 11 | charset = utf-8 12 | 13 | # Docstrings and comments use max_line_length = 79 14 | [*.py] 15 | max_line_length = 99 16 | 17 | # Use 2 spaces for the HTML files 18 | [*.html] 19 | indent_size = 2 20 | 21 | # The JSON files contain newlines inconsistently 22 | [*.json] 23 | indent_size = 2 24 | insert_final_newline = ignore 25 | 26 | [**/admin/js/vendor/**] 27 | indent_style = ignore 28 | indent_size = ignore 29 | 30 | # Minified JavaScript files shouldn't be changed 31 | [**.min.js] 32 | indent_style = ignore 33 | insert_final_newline = ignore 34 | 35 | # Makefiles always use tabs for indentation 36 | [Makefile] 37 | indent_style = tab 38 | 39 | # Batch files use tabs for indentation 40 | [*.bat] 41 | indent_style = tab 42 | 43 | [docs/**.txt] 44 | max_line_length = 119 45 | -------------------------------------------------------------------------------- /cos_eval_toolbox/.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: auto-publish 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | 8 | jobs:
9 | auto-update: # job_id 10 | name: Markdown2HTML # Name of the job as displayed on GitHub. 11 | runs-on: ubuntu-latest 12 | steps: 13 | - name: Check out repository 14 | uses: actions/checkout@v2 15 | - name: Setup Python 16 | uses: actions/setup-python@v2 17 | with: 18 | python-version: 3.8 # optional, default is 3.x 19 | - name: Convert markdown to html 20 | run: | 21 | pip install markdown2 22 | cd tools 23 | python markdown2html.py 24 | cd .. 25 | - name: Deploy 26 | uses: peaceiris/actions-gh-pages@v3 27 | with: 28 | github_token: ${{ secrets.GITHUB_TOKEN }} 29 | publish_dir: ./results/htmls 30 | -------------------------------------------------------------------------------- /cos_eval_toolbox/.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # See https://pre-commit.com for more information 2 | # See https://pre-commit.com/hooks.html for more hooks 3 | repos: 4 | - repo: https://github.com/pre-commit/pre-commit-hooks 5 | rev: v3.2.0 6 | hooks: 7 | - id: trailing-whitespace 8 | - id: end-of-file-fixer 9 | - id: check-yaml 10 | - id: check-toml 11 | - id: check-added-large-files 12 | - id: fix-encoding-pragma 13 | - id: mixed-line-ending 14 | - repo: https://github.com/pycqa/isort 15 | rev: 5.6.4 16 | hooks: 17 | - id: isort 18 | - repo: https://github.com/psf/black 19 | rev: 20.8b1 20 | # Replace by any tag/version: https://github.com/psf/black/tags 21 | hooks: 22 | - id: black 23 | language_version: python3 24 | # Should be a command that runs python3.6+ 25 | -------------------------------------------------------------------------------- /cos_eval_toolbox/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 MY_ 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without
limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /cos_eval_toolbox/README.md: -------------------------------------------------------------------------------- 1 | # COS Evaluation Toolbox 2 | 3 | A Python-based concealed object segmentation (COS) evaluation toolbox. 4 | 5 | # Features 6 | 7 | This repo provides one-key evaluation of nine metrics: 8 | 9 | - MAE 10 | - weighted F-measure 11 | - S-measure 12 | - max/average/adaptive F-measure 13 | - max/average/adaptive E-measure 14 | 15 | # One-command evaluation 16 | 17 | To evaluate concealed object segmentation (COS) approaches, first install the required libraries with `pip install -r requirements.txt`.
Then, download the benchmark datasets ([OneDrive, 1.16GB](https://anu365-my.sharepoint.com/:u:/g/personal/u7248002_anu_edu_au/EWQU8s3I1cxLvQuYEt2g6gkBO4uwJ2bZq6Vuf9V1Hum7Lg?e=xMMcAr)) and prediction masks ([OneDrive, 4.82GB](https://anu365-my.sharepoint.com/:u:/g/personal/u7248002_anu_edu_au/Edk5mzHO5JNMv0LHDFBdTq4Bgrg_wmsmYg9hjOzh6-nAjw?e=xdVrT4)), and run: 18 | 19 | ```bash 20 | python eval.py --dataset-json examples_COS/config_cos_dataset_py_example.json \ 21 | --method-json examples_COS/config_cos_method_py_example.json \ 22 | --metric-npy output_COS/cos_metrics.npy \ 23 | --curves-npy output_COS/cos_curves.npy \ 24 | --record-txt output_COS/cos_results.txt 25 | ``` 26 | 27 | Your results will be stored in `./cos_eval_toolbox/output_COS/cos_results.txt`. 28 | 29 | 30 | # Customize your evaluation 31 | 32 | 1. Put your prediction masks in a custom directory such as `./benchmark/COS-Benchmarking`, and prepare your dataset in a layout such as `./cos_eval_toolbox/dataset/COD10K/`. Then, generate the Python-style configs via 33 | 34 | ```bash 35 | python tools/generate_cos_config_files.py 36 | ``` 37 | 38 | 2. Generate the JSON-style files via 39 | 40 | ```bash 41 | python tools/info_py_to_json.py -i ./examples_COS -o ./examples_COS 42 | ``` 43 | 44 | 3. Check the file paths via 45 | 46 | ```bash 47 | python tools/check_path.py -m examples_COS/config_cos_method_py_example.json -d examples_COS/config_cos_dataset_py_example.json 48 | ``` 49 | 50 | 4.
Start the evaluation 51 | 52 | ```bash 53 | python eval.py --dataset-json examples_COS/config_cos_dataset_py_example.json \ 54 | --method-json examples_COS/config_cos_method_py_example.json \ 55 | --metric-npy output_COS/cos_metrics.npy \ 56 | --curves-npy output_COS/cos_curves.npy \ 57 | --record-txt output_COS/cos_results.txt 58 | ``` 59 | 60 | # Citations 61 | 62 | ```text 63 | @inproceedings{Fmeasure, 64 | title={Frequency-tuned salient region detection}, 65 | author={Achanta, Radhakrishna and Hemami, Sheila and Estrada, Francisco and S{\"u}sstrunk, Sabine}, 66 | booktitle=CVPR, 67 | number={CONF}, 68 | pages={1597--1604}, 69 | year={2009} 70 | } 71 | 72 | @inproceedings{MAE, 73 | title={Saliency filters: Contrast based filtering for salient region detection}, 74 | author={Perazzi, Federico and Kr{\"a}henb{\"u}hl, Philipp and Pritch, Yael and Hornung, Alexander}, 75 | booktitle=CVPR, 76 | pages={733--740}, 77 | year={2012} 78 | } 79 | 80 | @inproceedings{Smeasure, 81 | title={Structure-measure: A new way to evaluate foreground maps}, 82 | author={Fan, Deng-Ping and Cheng, Ming-Ming and Liu, Yun and Li, Tao and Borji, Ali}, 83 | booktitle=ICCV, 84 | pages={4548--4557}, 85 | year={2017} 86 | } 87 | 88 | @inproceedings{Emeasure, 89 | title={Enhanced-alignment Measure for Binary Foreground Map Evaluation}, 90 | author={Fan, Deng-Ping and Gong, Cheng and Cao, Yang and Ren, Bo and Cheng, Ming-Ming and Borji, Ali}, 91 | booktitle=IJCAI, 92 | pages={698--704}, 93 | year={2018} 94 | } 95 | 96 | @inproceedings{wFmeasure, 97 | title={How to evaluate foreground maps?}, 98 | author={Margolin, Ran and Zelnik-Manor, Lihi and Tal, Ayellet}, 99 | booktitle=CVPR, 100 | pages={248--255}, 101 | year={2014} 102 | } 103 | ``` 104 | 105 | # Acknowledgements 106 | 107 | This repo is built on [PySODEvalToolkit](https://github.com/lartpang/PySODEvalToolkit).
We thank Dr. Pang for his excellent work; please refer to its [README.md](https://github.com/lartpang/PySODEvalToolkit/blob/master/readme.md) for more usage details. -------------------------------------------------------------------------------- /cos_eval_toolbox/eval.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import argparse 3 | import os 4 | import textwrap 5 | import warnings 6 | 7 | from metrics import cal_sod_matrics 8 | from utils.generate_info import get_datasets_info, get_methods_info 9 | from utils.misc import make_dir 10 | from utils.recorders import METRIC_MAPPING 11 | 12 | 13 | def get_args(): 14 | parser = argparse.ArgumentParser( 15 | description=textwrap.dedent( 16 | r""" 17 | INCLUDE: 18 | 19 | - F-measure-Threshold Curve 20 | - Precision-Recall Curve 21 | - MAE 22 | - weighted F-measure 23 | - S-measure 24 | - max/average/adaptive F-measure 25 | - max/average/adaptive E-measure 26 | - max/average Precision 27 | - max/average Sensitivity 28 | - max/average Specificity 29 | - max/average F-measure 30 | - max/average Dice 31 | - max/average IoU 32 | 33 | NOTE: 34 | 35 | - Our method automatically calculates the intersection of `pre` and `gt`. 36 | - Currently supported pre naming rules: `prefix + gt_name_wo_ext + suffix_w_ext` 37 | 38 | EXAMPLES: 39 | 40 | python eval.py \ 41 | --dataset-json configs/datasets/json/rgbd_sod.json \ 42 | --method-json configs/methods/json/rgbd_other_methods.json configs/methods/json/rgbd_our_method.json --metric-npy output/rgbd_metrics.npy \ 43 | --curves-npy output/rgbd_curves.npy \ 44 | --record-txt output/rgbd_results.txt 45 | """ 46 | ), 47 | formatter_class=argparse.RawTextHelpFormatter, 48 | ) 49 | parser.add_argument("--dataset-json", required=True, type=str, help="Json file for datasets.") 50 | parser.add_argument( 51 | "--method-json", required=True, nargs="+", type=str, help="Json file for methods."
52 | ) 53 | parser.add_argument("--metric-npy", type=str, help="Npy file for saving metric results.") 54 | parser.add_argument("--curves-npy", type=str, help="Npy file for saving curve results.") 55 | parser.add_argument("--record-txt", type=str, help="Txt file for saving metric results.") 56 | parser.add_argument("--to-overwrite", action="store_true", help="To overwrite the txt file.") 57 | parser.add_argument("--record-xlsx", type=str, help="Xlsx file for saving metric results.") 58 | parser.add_argument( 59 | "--include-methods", 60 | type=str, 61 | nargs="+", 62 | help="Names of only specific methods you want to evaluate.", 63 | ) 64 | parser.add_argument( 65 | "--exclude-methods", 66 | type=str, 67 | nargs="+", 68 | help="Names of some specific methods you do not want to evaluate.", 69 | ) 70 | parser.add_argument( 71 | "--include-datasets", 72 | type=str, 73 | nargs="+", 74 | help="Names of only specific datasets you want to evaluate.", 75 | ) 76 | parser.add_argument( 77 | "--exclude-datasets", 78 | type=str, 79 | nargs="+", 80 | help="Names of some specific datasets you do not want to evaluate.", 81 | ) 82 | parser.add_argument( 83 | "--num-workers", 84 | type=int, 85 | default=8, 86 | help="Number of workers for multi-threading or multi-processing. Default: 8", 87 | ) 88 | parser.add_argument( 89 | "--num-bits", 90 | type=int, 91 | default=3, 92 | help="Number of decimal places for showing results.
Default: 3", 93 | ) 94 | parser.add_argument( 95 | "--metric-names", 96 | type=str, 97 | nargs="+", 98 | default=["mae", "fm", "em", "sm", "wfm"], 99 | choices=METRIC_MAPPING.keys(), 100 | help="Names of metrics", 101 | ) 102 | args = parser.parse_args() 103 | 104 | if args.metric_npy is not None: 105 | make_dir(os.path.dirname(args.metric_npy)) 106 | if args.curves_npy is not None: 107 | make_dir(os.path.dirname(args.curves_npy)) 108 | if args.record_txt is not None: 109 | make_dir(os.path.dirname(args.record_txt)) 110 | if args.record_xlsx is not None: 111 | make_dir(os.path.dirname(args.record_xlsx)) 112 | if args.to_overwrite and not args.record_txt: 113 | warnings.warn("--to-overwrite only works with a valid --record-txt") 114 | return args 115 | 116 | 117 | def main(): 118 | args = get_args() 119 | 120 | # Dictionary containing information about all datasets 121 | datasets_info = get_datasets_info( 122 | datastes_info_json=args.dataset_json, 123 | include_datasets=args.include_datasets, 124 | exclude_datasets=args.exclude_datasets, 125 | ) 126 | # Dictionary containing information about the results of all methods to be compared 127 | methods_info = get_methods_info( 128 | methods_info_jsons=args.method_json, 129 | for_drawing=True, 130 | include_methods=args.include_methods, 131 | exclude_methods=args.exclude_methods, 132 | ) 133 | 134 | # Ensure that multiprocessing also works properly on Windows 135 | cal_sod_matrics.cal_sod_matrics( 136 | sheet_name="Results", 137 | to_append=not args.to_overwrite, 138 | txt_path=args.record_txt, 139 | xlsx_path=args.record_xlsx, 140 | methods_info=methods_info, 141 | datasets_info=datasets_info, 142 | curves_npy_path=args.curves_npy, 143 | metrics_npy_path=args.metric_npy, 144 | num_bits=args.num_bits, 145 | num_workers=args.num_workers, 146 | use_mp=False, 147 | metric_names=args.metric_names, 148 | ncols_tqdm=119, 149 | ) 150 | 151 | 152 | if __name__ == "__main__": 153 | main() 154 | -------------------------------------------------------------------------------- /cos_eval_toolbox/examples_COS/config_cos_dataset_py_example.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "CAMO": { 3 | "root": "./dataset/CAMO", 4 | "image": { 5 | "path": "./dataset/CAMO/Imgs", 6 | "suffix": ".jpg" 7 | }, 8 | "mask": { 9 | "path": "./dataset/CAMO/GT", 10 | "suffix": ".png" 11 | } 12 | }, 13 | "COD10K": { 14 | "root": "./dataset/COD10K", 15 | "image": { 16 | "path": "./dataset/COD10K/Imgs", 17 | "suffix": ".jpg" 18 | }, 19 | "mask": { 20 | "path": "./dataset/COD10K/GT", 21 | "suffix": ".png" 22 | } 23 | }, 24 | "NC4K": { 25 | "root": "./dataset/NC4K", 26 | "image": { 27 | "path": "./dataset/NC4K/Imgs", 28 | "suffix": ".jpg" 29 | }, 30 | "mask": { 31 | "path": "./dataset/NC4K/GT", 32 | "suffix": ".png" 33 | } 34 | } 35 | } -------------------------------------------------------------------------------- /cos_eval_toolbox/examples_COS/config_cos_dataset_py_example.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | 4 | _COD_DATA_ROOT = "./dataset" 5 | 6 | CAMO = dict( 7 | root=os.path.join(_COD_DATA_ROOT, "CAMO"), 8 | image=dict(path=os.path.join(_COD_DATA_ROOT, "CAMO", "Imgs"), suffix=".jpg"), 9 | mask=dict(path=os.path.join(_COD_DATA_ROOT, "CAMO", "GT"), suffix=".png"), 10 | ) 11 | 12 | COD10K = dict( 13 | root=os.path.join(_COD_DATA_ROOT, "COD10K"), 14 | image=dict(path=os.path.join(_COD_DATA_ROOT, "COD10K", "Imgs"), suffix=".jpg"), 15 | mask=dict(path=os.path.join(_COD_DATA_ROOT, "COD10K", "GT"), suffix=".png"), 16 | ) 17 | 18 | NC4K = dict( 19 | root=os.path.join(_COD_DATA_ROOT, "NC4K"), 20 | image=dict(path=os.path.join(_COD_DATA_ROOT, "NC4K", "Imgs"), suffix=".jpg"), 21 | mask=dict(path=os.path.join(_COD_DATA_ROOT, "NC4K", "GT"), suffix=".png"), 22 | ) -------------------------------------------------------------------------------- /cos_eval_toolbox/examples_COS/config_cos_method_py_example.json: -------------------------------------------------------------------------------- 1 | 
{ 2 | "PFNet": { 3 | "CAMO": { 4 | "path": "./benchmark/COS-Benchmarking/2021-CVPR-PFNet/CAMO", 5 | "suffix": ".png" 6 | }, 7 | "NC4K": { 8 | "path": "./benchmark/COS-Benchmarking/2021-CVPR-PFNet/NC4K", 9 | "suffix": ".png" 10 | }, 11 | "COD10K": { 12 | "path": "./benchmark/COS-Benchmarking/2021-CVPR-PFNet/COD10K", 13 | "suffix": ".png" 14 | } 15 | }, 16 | "C2FNetV2": { 17 | "CAMO": { 18 | "path": "./benchmark/COS-Benchmarking/2022-TCSVT-C2FNetV2/CAMO", 19 | "suffix": ".png" 20 | }, 21 | "NC4K": { 22 | "path": "./benchmark/COS-Benchmarking/2022-TCSVT-C2FNetV2/NC4K", 23 | "suffix": ".png" 24 | }, 25 | "COD10K": { 26 | "path": "./benchmark/COS-Benchmarking/2022-TCSVT-C2FNetV2/COD10K", 27 | "suffix": ".png" 28 | } 29 | }, 30 | "CamoFormerS": { 31 | "CAMO": { 32 | "path": "./benchmark/COS-Benchmarking/2023-arXiv-CamoFormerS/CAMO", 33 | "suffix": ".png" 34 | }, 35 | "NC4K": { 36 | "path": "./benchmark/COS-Benchmarking/2023-arXiv-CamoFormerS/NC4K", 37 | "suffix": ".png" 38 | }, 39 | "COD10K": { 40 | "path": "./benchmark/COS-Benchmarking/2023-arXiv-CamoFormerS/COD10K", 41 | "suffix": ".png" 42 | } 43 | }, 44 | "SINet": { 45 | "CAMO": { 46 | "path": "./benchmark/COS-Benchmarking/2020-CVPR-SINet/CAMO", 47 | "suffix": ".png" 48 | }, 49 | "NC4K": { 50 | "path": "./benchmark/COS-Benchmarking/2020-CVPR-SINet/NC4K", 51 | "suffix": ".png" 52 | }, 53 | "COD10K": { 54 | "path": "./benchmark/COS-Benchmarking/2020-CVPR-SINet/COD10K", 55 | "suffix": ".png" 56 | } 57 | }, 58 | "C2FNet": { 59 | "CAMO": { 60 | "path": "./benchmark/COS-Benchmarking/2021-IJCAI-C2FNet/CAMO", 61 | "suffix": ".png" 62 | }, 63 | "NC4K": { 64 | "path": "./benchmark/COS-Benchmarking/2021-IJCAI-C2FNet/NC4K", 65 | "suffix": ".png" 66 | }, 67 | "COD10K": { 68 | "path": "./benchmark/COS-Benchmarking/2021-IJCAI-C2FNet/COD10K", 69 | "suffix": ".png" 70 | } 71 | }, 72 | "ZoomNet": { 73 | "CAMO": { 74 | "path": "./benchmark/COS-Benchmarking/2022-CVPR-ZoomNet/CAMO", 75 | "suffix": ".png" 76 | }, 77 | "NC4K": { 78 | 
"path": "./benchmark/COS-Benchmarking/2022-CVPR-ZoomNet/NC4K", 79 | "suffix": ".png" 80 | }, 81 | "COD10K": { 82 | "path": "./benchmark/COS-Benchmarking/2022-CVPR-ZoomNet/COD10K", 83 | "suffix": ".png" 84 | } 85 | }, 86 | "CamoFormerR": { 87 | "CAMO": { 88 | "path": "./benchmark/COS-Benchmarking/2023-arXiv-CamoFormerR/CAMO", 89 | "suffix": ".png" 90 | }, 91 | "NC4K": { 92 | "path": "./benchmark/COS-Benchmarking/2023-arXiv-CamoFormerR/NC4K", 93 | "suffix": ".png" 94 | }, 95 | "COD10K": { 96 | "path": "./benchmark/COS-Benchmarking/2023-arXiv-CamoFormerR/COD10K", 97 | "suffix": ".png" 98 | } 99 | }, 100 | "SMGL": { 101 | "CAMO": { 102 | "path": "./benchmark/COS-Benchmarking/2021-CVPR-SMGL/CAMO", 103 | "suffix": ".png" 104 | }, 105 | "NC4K": { 106 | "path": "./benchmark/COS-Benchmarking/2021-CVPR-SMGL/NC4K", 107 | "suffix": ".png" 108 | }, 109 | "COD10K": { 110 | "path": "./benchmark/COS-Benchmarking/2021-CVPR-SMGL/COD10K", 111 | "suffix": ".png" 112 | } 113 | }, 114 | "OCENet": { 115 | "CAMO": { 116 | "path": "./benchmark/COS-Benchmarking/2022-WACV-OCENet/CAMO", 117 | "suffix": ".png" 118 | }, 119 | "NC4K": { 120 | "path": "./benchmark/COS-Benchmarking/2022-WACV-OCENet/NC4K", 121 | "suffix": ".png" 122 | }, 123 | "COD10K": { 124 | "path": "./benchmark/COS-Benchmarking/2022-WACV-OCENet/COD10K", 125 | "suffix": ".png" 126 | } 127 | }, 128 | "FAPNet": { 129 | "CAMO": { 130 | "path": "./benchmark/COS-Benchmarking/2022-TIP-FAPNet/CAMO", 131 | "suffix": ".png" 132 | }, 133 | "NC4K": { 134 | "path": "./benchmark/COS-Benchmarking/2022-TIP-FAPNet/NC4K", 135 | "suffix": ".png" 136 | }, 137 | "COD10K": { 138 | "path": "./benchmark/COS-Benchmarking/2022-TIP-FAPNet/COD10K", 139 | "suffix": ".png" 140 | } 141 | }, 142 | "DTINet": { 143 | "CAMO": { 144 | "path": "./benchmark/COS-Benchmarking/2022-ICPR-DTINet/CAMO", 145 | "suffix": ".png" 146 | }, 147 | "NC4K": { 148 | "path": "./benchmark/COS-Benchmarking/2022-ICPR-DTINet/NC4K", 149 | "suffix": ".png" 150 | }, 151 | "COD10K": { 152 
| "path": "./benchmark/COS-Benchmarking/2022-ICPR-DTINet/COD10K", 153 | "suffix": ".png" 154 | } 155 | }, 156 | "RMGL": { 157 | "CAMO": { 158 | "path": "./benchmark/COS-Benchmarking/2021-CVPR-RMGL/CAMO", 159 | "suffix": ".png" 160 | }, 161 | "NC4K": { 162 | "path": "./benchmark/COS-Benchmarking/2021-CVPR-RMGL/NC4K", 163 | "suffix": ".png" 164 | }, 165 | "COD10K": { 166 | "path": "./benchmark/COS-Benchmarking/2021-CVPR-RMGL/COD10K", 167 | "suffix": ".png" 168 | } 169 | }, 170 | "TPRNet": { 171 | "CAMO": { 172 | "path": "./benchmark/COS-Benchmarking/2022-TVCJ-TPRNet/CAMO", 173 | "suffix": ".png" 174 | }, 175 | "NC4K": { 176 | "path": "./benchmark/COS-Benchmarking/2022-TVCJ-TPRNet/NC4K", 177 | "suffix": ".png" 178 | }, 179 | "COD10K": { 180 | "path": "./benchmark/COS-Benchmarking/2022-TVCJ-TPRNet/COD10K", 181 | "suffix": ".png" 182 | } 183 | }, 184 | "BSANet": { 185 | "CAMO": { 186 | "path": "./benchmark/COS-Benchmarking/2022-AAAI-BSANet/CAMO", 187 | "suffix": ".png" 188 | }, 189 | "NC4K": { 190 | "path": "./benchmark/COS-Benchmarking/2022-AAAI-BSANet/NC4K", 191 | "suffix": ".png" 192 | }, 193 | "COD10K": { 194 | "path": "./benchmark/COS-Benchmarking/2022-AAAI-BSANet/COD10K", 195 | "suffix": ".png" 196 | } 197 | }, 198 | "PopNet": { 199 | "CAMO": { 200 | "path": "./benchmark/COS-Benchmarking/2023-arXiv-PopNet/CAMO", 201 | "suffix": ".png" 202 | }, 203 | "NC4K": { 204 | "path": "./benchmark/COS-Benchmarking/2023-arXiv-PopNet/NC4K", 205 | "suffix": ".png" 206 | }, 207 | "COD10K": { 208 | "path": "./benchmark/COS-Benchmarking/2023-arXiv-PopNet/COD10K", 209 | "suffix": ".png" 210 | } 211 | }, 212 | "LSR": { 213 | "CAMO": { 214 | "path": "./benchmark/COS-Benchmarking/2021-CVPR-LSR/CAMO", 215 | "suffix": ".png" 216 | }, 217 | "NC4K": { 218 | "path": "./benchmark/COS-Benchmarking/2021-CVPR-LSR/NC4K", 219 | "suffix": ".png" 220 | }, 221 | "COD10K": { 222 | "path": "./benchmark/COS-Benchmarking/2021-CVPR-LSR/COD10K", 223 | "suffix": ".png" 224 | } 225 | }, 226 | "CubeNet": { 
227 | "CAMO": { 228 | "path": "./benchmark/COS-Benchmarking/2022-PR-CubeNet/CAMO", 229 | "suffix": ".png" 230 | }, 231 | "NC4K": null, 232 | "COD10K": { 233 | "path": "./benchmark/COS-Benchmarking/2022-PR-CubeNet/COD10K", 234 | "suffix": ".png" 235 | } 236 | }, 237 | "JSCOD": { 238 | "CAMO": { 239 | "path": "./benchmark/COS-Benchmarking/2021-CVPR-JSCOD/CAMO", 240 | "suffix": ".png" 241 | }, 242 | "NC4K": { 243 | "path": "./benchmark/COS-Benchmarking/2021-CVPR-JSCOD/NC4K", 244 | "suffix": ".png" 245 | }, 246 | "COD10K": { 247 | "path": "./benchmark/COS-Benchmarking/2021-CVPR-JSCOD/COD10K", 248 | "suffix": ".png" 249 | } 250 | }, 251 | "PFNetPlus": { 252 | "CAMO": { 253 | "path": "./benchmark/COS-Benchmarking/2023-SSCI-PFNetPlus/CAMO", 254 | "suffix": ".png" 255 | }, 256 | "NC4K": null, 257 | "COD10K": { 258 | "path": "./benchmark/COS-Benchmarking/2023-SSCI-PFNetPlus/COD10K", 259 | "suffix": ".png" 260 | } 261 | }, 262 | "BAS": { 263 | "CAMO": { 264 | "path": "./benchmark/COS-Benchmarking/2022-arXiv-BAS/CAMO", 265 | "suffix": ".png" 266 | }, 267 | "NC4K": { 268 | "path": "./benchmark/COS-Benchmarking/2022-arXiv-BAS/NC4K", 269 | "suffix": ".png" 270 | }, 271 | "COD10K": { 272 | "path": "./benchmark/COS-Benchmarking/2022-arXiv-BAS/COD10K", 273 | "suffix": ".png" 274 | } 275 | }, 276 | "D2CNet": { 277 | "CAMO": { 278 | "path": "./benchmark/COS-Benchmarking/2021-TIE-D2CNet/CAMO", 279 | "suffix": ".png" 280 | }, 281 | "NC4K": null, 282 | "COD10K": { 283 | "path": "./benchmark/COS-Benchmarking/2021-TIE-D2CNet/COD10K", 284 | "suffix": ".png" 285 | } 286 | }, 287 | "PreyNet": { 288 | "CAMO": { 289 | "path": "./benchmark/COS-Benchmarking/2022-MM-PreyNet/CAMO", 290 | "suffix": ".png" 291 | }, 292 | "NC4K": { 293 | "path": "./benchmark/COS-Benchmarking/2022-MM-PreyNet/NC4K", 294 | "suffix": ".png" 295 | }, 296 | "COD10K": { 297 | "path": "./benchmark/COS-Benchmarking/2022-MM-PreyNet/COD10K", 298 | "suffix": ".png" 299 | } 300 | }, 301 | "ERRNet": { 302 | "CAMO": { 303 | "path": 
"./benchmark/COS-Benchmarking/2022-PR-ERRNet/CAMO", 304 | "suffix": ".png" 305 | }, 306 | "NC4K": { 307 | "path": "./benchmark/COS-Benchmarking/2022-PR-ERRNet/NC4K", 308 | "suffix": ".png" 309 | }, 310 | "COD10K": { 311 | "path": "./benchmark/COS-Benchmarking/2022-PR-ERRNet/COD10K", 312 | "suffix": ".png" 313 | } 314 | }, 315 | "SINetV2": { 316 | "CAMO": { 317 | "path": "./benchmark/COS-Benchmarking/2022-TPAMI-SINetV2/CAMO", 318 | "suffix": ".png" 319 | }, 320 | "NC4K": { 321 | "path": "./benchmark/COS-Benchmarking/2022-TPAMI-SINetV2/NC4K", 322 | "suffix": ".png" 323 | }, 324 | "COD10K": { 325 | "path": "./benchmark/COS-Benchmarking/2022-TPAMI-SINetV2/COD10K", 326 | "suffix": ".png" 327 | } 328 | }, 329 | "SegMaR": { 330 | "CAMO": { 331 | "path": "./benchmark/COS-Benchmarking/2022-CVPR-SegMaR/CAMO", 332 | "suffix": ".png" 333 | }, 334 | "NC4K": { 335 | "path": "./benchmark/COS-Benchmarking/2022-CVPR-SegMaR/NC4K", 336 | "suffix": ".png" 337 | }, 338 | "COD10K": { 339 | "path": "./benchmark/COS-Benchmarking/2022-CVPR-SegMaR/COD10K", 340 | "suffix": ".png" 341 | } 342 | }, 343 | "HitNet": { 344 | "CAMO": { 345 | "path": "./benchmark/COS-Benchmarking/2023-AAAI-HitNet/CAMO", 346 | "suffix": ".png" 347 | }, 348 | "NC4K": { 349 | "path": "./benchmark/COS-Benchmarking/2023-AAAI-HitNet/NC4K", 350 | "suffix": ".png" 351 | }, 352 | "COD10K": { 353 | "path": "./benchmark/COS-Benchmarking/2023-AAAI-HitNet/COD10K", 354 | "suffix": ".png" 355 | } 356 | }, 357 | "TINet": { 358 | "CAMO": { 359 | "path": "./benchmark/COS-Benchmarking/2021-AAAI-TINet/CAMO", 360 | "suffix": ".png" 361 | }, 362 | "NC4K": { 363 | "path": "./benchmark/COS-Benchmarking/2021-AAAI-TINet/NC4K", 364 | "suffix": ".png" 365 | }, 366 | "COD10K": { 367 | "path": "./benchmark/COS-Benchmarking/2021-AAAI-TINet/COD10K", 368 | "suffix": ".png" 369 | } 370 | }, 371 | "CRNet": { 372 | "CAMO": { 373 | "path": "./benchmark/COS-Benchmarking/2023-AAAI-CRNet/CAMO", 374 | "suffix": ".png" 375 | }, 376 | "NC4K": null, 377 | 
"COD10K": { 378 | "path": "./benchmark/COS-Benchmarking/2023-AAAI-CRNet/COD10K", 379 | "suffix": ".png" 380 | } 381 | }, 382 | "BGNet": { 383 | "CAMO": { 384 | "path": "./benchmark/COS-Benchmarking/2022-IJCAI-BGNet/CAMO", 385 | "suffix": ".png" 386 | }, 387 | "NC4K": { 388 | "path": "./benchmark/COS-Benchmarking/2022-IJCAI-BGNet/NC4K", 389 | "suffix": ".png" 390 | }, 391 | "COD10K": { 392 | "path": "./benchmark/COS-Benchmarking/2022-IJCAI-BGNet/COD10K", 393 | "suffix": ".png" 394 | } 395 | }, 396 | "UGTR": { 397 | "CAMO": { 398 | "path": "./benchmark/COS-Benchmarking/2021-ICCV-UGTR/CAMO", 399 | "suffix": ".png" 400 | }, 401 | "NC4K": { 402 | "path": "./benchmark/COS-Benchmarking/2021-ICCV-UGTR/NC4K", 403 | "suffix": ".png" 404 | }, 405 | "COD10K": { 406 | "path": "./benchmark/COS-Benchmarking/2021-ICCV-UGTR/COD10K", 407 | "suffix": ".png" 408 | } 409 | }, 410 | "DGNet": { 411 | "CAMO": { 412 | "path": "./benchmark/COS-Benchmarking/2023-MIR-DGNet/CAMO", 413 | "suffix": ".png" 414 | }, 415 | "NC4K": { 416 | "path": "./benchmark/COS-Benchmarking/2023-MIR-DGNet/NC4K", 417 | "suffix": ".png" 418 | }, 419 | "COD10K": { 420 | "path": "./benchmark/COS-Benchmarking/2023-MIR-DGNet/COD10K", 421 | "suffix": ".png" 422 | } 423 | }, 424 | "CamoFormerP": { 425 | "CAMO": { 426 | "path": "./benchmark/COS-Benchmarking/2023-arXiv-CamoFormerP/CAMO", 427 | "suffix": ".png" 428 | }, 429 | "NC4K": { 430 | "path": "./benchmark/COS-Benchmarking/2023-arXiv-CamoFormerP/NC4K", 431 | "suffix": ".png" 432 | }, 433 | "COD10K": { 434 | "path": "./benchmark/COS-Benchmarking/2023-arXiv-CamoFormerP/COD10K", 435 | "suffix": ".png" 436 | } 437 | }, 438 | "NCHIT": { 439 | "CAMO": { 440 | "path": "./benchmark/COS-Benchmarking/2022-CVIU-NCHIT/CAMO", 441 | "suffix": ".png" 442 | }, 443 | "NC4K": { 444 | "path": "./benchmark/COS-Benchmarking/2022-CVIU-NCHIT/NC4K", 445 | "suffix": ".png" 446 | }, 447 | "COD10K": { 448 | "path": "./benchmark/COS-Benchmarking/2022-CVIU-NCHIT/COD10K", 449 | "suffix": ".png" 
450 | } 451 | }, 452 | "FDNet": { 453 | "CAMO": { 454 | "path": "./benchmark/COS-Benchmarking/2022-CVPR-FDNet/CAMO", 455 | "suffix": ".png" 456 | }, 457 | "NC4K": { 458 | "path": "./benchmark/COS-Benchmarking/2022-CVPR-FDNet/NC4K", 459 | "suffix": ".png" 460 | }, 461 | "COD10K": { 462 | "path": "./benchmark/COS-Benchmarking/2022-CVPR-FDNet/COD10K", 463 | "suffix": ".png" 464 | } 465 | }, 466 | "DGNetS": { 467 | "CAMO": { 468 | "path": "./benchmark/COS-Benchmarking/2023-MIR-DGNetS/CAMO", 469 | "suffix": ".png" 470 | }, 471 | "NC4K": { 472 | "path": "./benchmark/COS-Benchmarking/2023-MIR-DGNetS/NC4K", 473 | "suffix": ".png" 474 | }, 475 | "COD10K": { 476 | "path": "./benchmark/COS-Benchmarking/2023-MIR-DGNetS/COD10K", 477 | "suffix": ".png" 478 | } 479 | }, 480 | "CamoFormerC": { 481 | "CAMO": { 482 | "path": "./benchmark/COS-Benchmarking/2023-arXiv-CamoFormerC/CAMO", 483 | "suffix": ".png" 484 | }, 485 | "NC4K": { 486 | "path": "./benchmark/COS-Benchmarking/2023-arXiv-CamoFormerC/NC4K", 487 | "suffix": ".png" 488 | }, 489 | "COD10K": { 490 | "path": "./benchmark/COS-Benchmarking/2023-arXiv-CamoFormerC/COD10K", 491 | "suffix": ".png" 492 | } 493 | } 494 | } -------------------------------------------------------------------------------- /cos_eval_toolbox/examples_COS/config_cos_method_py_example.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | 4 | PFNet_root = "./benchmark/COS-Benchmarking/2021-CVPR-PFNet" 5 | PFNet = { 6 | "CAMO": dict(path=os.path.join(PFNet_root, "CAMO"), suffix=".png"), 7 | "NC4K": dict(path=os.path.join(PFNet_root, "NC4K"), suffix=".png"), 8 | "COD10K": dict(path=os.path.join(PFNet_root, "COD10K"), suffix=".png"), 9 | } 10 | 11 | C2FNetV2_root = "./benchmark/COS-Benchmarking/2022-TCSVT-C2FNetV2" 12 | C2FNetV2 = { 13 | "CAMO": dict(path=os.path.join(C2FNetV2_root, "CAMO"), suffix=".png"), 14 | "NC4K": dict(path=os.path.join(C2FNetV2_root, "NC4K"), suffix=".png"), 15 
| "COD10K": dict(path=os.path.join(C2FNetV2_root, "COD10K"), suffix=".png"), 16 | } 17 | 18 | CamoFormerS_root = "./benchmark/COS-Benchmarking/2023-arXiv-CamoFormerS" 19 | CamoFormerS = { 20 | "CAMO": dict(path=os.path.join(CamoFormerS_root, "CAMO"), suffix=".png"), 21 | "NC4K": dict(path=os.path.join(CamoFormerS_root, "NC4K"), suffix=".png"), 22 | "COD10K": dict(path=os.path.join(CamoFormerS_root, "COD10K"), suffix=".png"), 23 | } 24 | 25 | SINet_root = "./benchmark/COS-Benchmarking/2020-CVPR-SINet" 26 | SINet = { 27 | "CAMO": dict(path=os.path.join(SINet_root, "CAMO"), suffix=".png"), 28 | "NC4K": dict(path=os.path.join(SINet_root, "NC4K"), suffix=".png"), 29 | "COD10K": dict(path=os.path.join(SINet_root, "COD10K"), suffix=".png"), 30 | } 31 | 32 | C2FNet_root = "./benchmark/COS-Benchmarking/2021-IJCAI-C2FNet" 33 | C2FNet = { 34 | "CAMO": dict(path=os.path.join(C2FNet_root, "CAMO"), suffix=".png"), 35 | "NC4K": dict(path=os.path.join(C2FNet_root, "NC4K"), suffix=".png"), 36 | "COD10K": dict(path=os.path.join(C2FNet_root, "COD10K"), suffix=".png"), 37 | } 38 | 39 | ZoomNet_root = "./benchmark/COS-Benchmarking/2022-CVPR-ZoomNet" 40 | ZoomNet = { 41 | "CAMO": dict(path=os.path.join(ZoomNet_root, "CAMO"), suffix=".png"), 42 | "NC4K": dict(path=os.path.join(ZoomNet_root, "NC4K"), suffix=".png"), 43 | "COD10K": dict(path=os.path.join(ZoomNet_root, "COD10K"), suffix=".png"), 44 | } 45 | 46 | CamoFormerR_root = "./benchmark/COS-Benchmarking/2023-arXiv-CamoFormerR" 47 | CamoFormerR = { 48 | "CAMO": dict(path=os.path.join(CamoFormerR_root, "CAMO"), suffix=".png"), 49 | "NC4K": dict(path=os.path.join(CamoFormerR_root, "NC4K"), suffix=".png"), 50 | "COD10K": dict(path=os.path.join(CamoFormerR_root, "COD10K"), suffix=".png"), 51 | } 52 | 53 | SMGL_root = "./benchmark/COS-Benchmarking/2021-CVPR-SMGL" 54 | SMGL = { 55 | "CAMO": dict(path=os.path.join(SMGL_root, "CAMO"), suffix=".png"), 56 | "NC4K": dict(path=os.path.join(SMGL_root, "NC4K"), suffix=".png"), 57 | "COD10K": 
dict(path=os.path.join(SMGL_root, "COD10K"), suffix=".png"), 58 | } 59 | 60 | OCENet_root = "./benchmark/COS-Benchmarking/2022-WACV-OCENet" 61 | OCENet = { 62 | "CAMO": dict(path=os.path.join(OCENet_root, "CAMO"), suffix=".png"), 63 | "NC4K": dict(path=os.path.join(OCENet_root, "NC4K"), suffix=".png"), 64 | "COD10K": dict(path=os.path.join(OCENet_root, "COD10K"), suffix=".png"), 65 | } 66 | 67 | FAPNet_root = "./benchmark/COS-Benchmarking/2022-TIP-FAPNet" 68 | FAPNet = { 69 | "CAMO": dict(path=os.path.join(FAPNet_root, "CAMO"), suffix=".png"), 70 | "NC4K": dict(path=os.path.join(FAPNet_root, "NC4K"), suffix=".png"), 71 | "COD10K": dict(path=os.path.join(FAPNet_root, "COD10K"), suffix=".png"), 72 | } 73 | 74 | DTINet_root = "./benchmark/COS-Benchmarking/2022-ICPR-DTINet" 75 | DTINet = { 76 | "CAMO": dict(path=os.path.join(DTINet_root, "CAMO"), suffix=".png"), 77 | "NC4K": dict(path=os.path.join(DTINet_root, "NC4K"), suffix=".png"), 78 | "COD10K": dict(path=os.path.join(DTINet_root, "COD10K"), suffix=".png"), 79 | } 80 | 81 | RMGL_root = "./benchmark/COS-Benchmarking/2021-CVPR-RMGL" 82 | RMGL = { 83 | "CAMO": dict(path=os.path.join(RMGL_root, "CAMO"), suffix=".png"), 84 | "NC4K": dict(path=os.path.join(RMGL_root, "NC4K"), suffix=".png"), 85 | "COD10K": dict(path=os.path.join(RMGL_root, "COD10K"), suffix=".png"), 86 | } 87 | 88 | TPRNet_root = "./benchmark/COS-Benchmarking/2022-TVCJ-TPRNet" 89 | TPRNet = { 90 | "CAMO": dict(path=os.path.join(TPRNet_root, "CAMO"), suffix=".png"), 91 | "NC4K": dict(path=os.path.join(TPRNet_root, "NC4K"), suffix=".png"), 92 | "COD10K": dict(path=os.path.join(TPRNet_root, "COD10K"), suffix=".png"), 93 | } 94 | 95 | BSANet_root = "./benchmark/COS-Benchmarking/2022-AAAI-BSANet" 96 | BSANet = { 97 | "CAMO": dict(path=os.path.join(BSANet_root, "CAMO"), suffix=".png"), 98 | "NC4K": dict(path=os.path.join(BSANet_root, "NC4K"), suffix=".png"), 99 | "COD10K": dict(path=os.path.join(BSANet_root, "COD10K"), suffix=".png"), 100 | } 101 | 102 | 
PopNet_root = "./benchmark/COS-Benchmarking/2023-arXiv-PopNet" 103 | PopNet = { 104 | "CAMO": dict(path=os.path.join(PopNet_root, "CAMO"), suffix=".png"), 105 | "NC4K": dict(path=os.path.join(PopNet_root, "NC4K"), suffix=".png"), 106 | "COD10K": dict(path=os.path.join(PopNet_root, "COD10K"), suffix=".png"), 107 | } 108 | 109 | LSR_root = "./benchmark/COS-Benchmarking/2021-CVPR-LSR" 110 | LSR = { 111 | "CAMO": dict(path=os.path.join(LSR_root, "CAMO"), suffix=".png"), 112 | "NC4K": dict(path=os.path.join(LSR_root, "NC4K"), suffix=".png"), 113 | "COD10K": dict(path=os.path.join(LSR_root, "COD10K"), suffix=".png"), 114 | } 115 | 116 | CubeNet_root = "./benchmark/COS-Benchmarking/2022-PR-CubeNet" 117 | CubeNet = { 118 | "CAMO": dict(path=os.path.join(CubeNet_root, "CAMO"), suffix=".png"), 119 | "NC4K": None, 120 | "COD10K": dict(path=os.path.join(CubeNet_root, "COD10K"), suffix=".png"), 121 | } 122 | 123 | JSCOD_root = "./benchmark/COS-Benchmarking/2021-CVPR-JSCOD" 124 | JSCOD = { 125 | "CAMO": dict(path=os.path.join(JSCOD_root, "CAMO"), suffix=".png"), 126 | "NC4K": dict(path=os.path.join(JSCOD_root, "NC4K"), suffix=".png"), 127 | "COD10K": dict(path=os.path.join(JSCOD_root, "COD10K"), suffix=".png"), 128 | } 129 | 130 | PFNetPlus_root = "./benchmark/COS-Benchmarking/2023-SSCI-PFNetPlus" 131 | PFNetPlus = { 132 | "CAMO": dict(path=os.path.join(PFNetPlus_root, "CAMO"), suffix=".png"), 133 | "NC4K": None, 134 | "COD10K": dict(path=os.path.join(PFNetPlus_root, "COD10K"), suffix=".png"), 135 | } 136 | 137 | BAS_root = "./benchmark/COS-Benchmarking/2022-arXiv-BAS" 138 | BAS = { 139 | "CAMO": dict(path=os.path.join(BAS_root, "CAMO"), suffix=".png"), 140 | "NC4K": dict(path=os.path.join(BAS_root, "NC4K"), suffix=".png"), 141 | "COD10K": dict(path=os.path.join(BAS_root, "COD10K"), suffix=".png"), 142 | } 143 | 144 | D2CNet_root = "./benchmark/COS-Benchmarking/2021-TIE-D2CNet" 145 | D2CNet = { 146 | "CAMO": dict(path=os.path.join(D2CNet_root, "CAMO"), suffix=".png"), 147 | 
"NC4K": None, 148 | "COD10K": dict(path=os.path.join(D2CNet_root, "COD10K"), suffix=".png"), 149 | } 150 | 151 | PreyNet_root = "./benchmark/COS-Benchmarking/2022-MM-PreyNet" 152 | PreyNet = { 153 | "CAMO": dict(path=os.path.join(PreyNet_root, "CAMO"), suffix=".png"), 154 | "NC4K": dict(path=os.path.join(PreyNet_root, "NC4K"), suffix=".png"), 155 | "COD10K": dict(path=os.path.join(PreyNet_root, "COD10K"), suffix=".png"), 156 | } 157 | 158 | ERRNet_root = "./benchmark/COS-Benchmarking/2022-PR-ERRNet" 159 | ERRNet = { 160 | "CAMO": dict(path=os.path.join(ERRNet_root, "CAMO"), suffix=".png"), 161 | "NC4K": dict(path=os.path.join(ERRNet_root, "NC4K"), suffix=".png"), 162 | "COD10K": dict(path=os.path.join(ERRNet_root, "COD10K"), suffix=".png"), 163 | } 164 | 165 | SINetV2_root = "./benchmark/COS-Benchmarking/2022-TPAMI-SINetV2" 166 | SINetV2 = { 167 | "CAMO": dict(path=os.path.join(SINetV2_root, "CAMO"), suffix=".png"), 168 | "NC4K": dict(path=os.path.join(SINetV2_root, "NC4K"), suffix=".png"), 169 | "COD10K": dict(path=os.path.join(SINetV2_root, "COD10K"), suffix=".png"), 170 | } 171 | 172 | SegMaR_root = "./benchmark/COS-Benchmarking/2022-CVPR-SegMaR" 173 | SegMaR = { 174 | "CAMO": dict(path=os.path.join(SegMaR_root, "CAMO"), suffix=".png"), 175 | "NC4K": dict(path=os.path.join(SegMaR_root, "NC4K"), suffix=".png"), 176 | "COD10K": dict(path=os.path.join(SegMaR_root, "COD10K"), suffix=".png"), 177 | } 178 | 179 | HitNet_root = "./benchmark/COS-Benchmarking/2023-AAAI-HitNet" 180 | HitNet = { 181 | "CAMO": dict(path=os.path.join(HitNet_root, "CAMO"), suffix=".png"), 182 | "NC4K": dict(path=os.path.join(HitNet_root, "NC4K"), suffix=".png"), 183 | "COD10K": dict(path=os.path.join(HitNet_root, "COD10K"), suffix=".png"), 184 | } 185 | 186 | TINet_root = "./benchmark/COS-Benchmarking/2021-AAAI-TINet" 187 | TINet = { 188 | "CAMO": dict(path=os.path.join(TINet_root, "CAMO"), suffix=".png"), 189 | "NC4K": dict(path=os.path.join(TINet_root, "NC4K"), suffix=".png"), 190 | 
"COD10K": dict(path=os.path.join(TINet_root, "COD10K"), suffix=".png"), 191 | } 192 | 193 | CRNet_root = "./benchmark/COS-Benchmarking/2023-AAAI-CRNet" 194 | CRNet = { 195 | "CAMO": dict(path=os.path.join(CRNet_root, "CAMO"), suffix=".png"), 196 | "NC4K": None, 197 | "COD10K": dict(path=os.path.join(CRNet_root, "COD10K"), suffix=".png"), 198 | } 199 | 200 | BGNet_root = "./benchmark/COS-Benchmarking/2022-IJCAI-BGNet" 201 | BGNet = { 202 | "CAMO": dict(path=os.path.join(BGNet_root, "CAMO"), suffix=".png"), 203 | "NC4K": dict(path=os.path.join(BGNet_root, "NC4K"), suffix=".png"), 204 | "COD10K": dict(path=os.path.join(BGNet_root, "COD10K"), suffix=".png"), 205 | } 206 | 207 | UGTR_root = "./benchmark/COS-Benchmarking/2021-ICCV-UGTR" 208 | UGTR = { 209 | "CAMO": dict(path=os.path.join(UGTR_root, "CAMO"), suffix=".png"), 210 | "NC4K": dict(path=os.path.join(UGTR_root, "NC4K"), suffix=".png"), 211 | "COD10K": dict(path=os.path.join(UGTR_root, "COD10K"), suffix=".png"), 212 | } 213 | 214 | DGNet_root = "./benchmark/COS-Benchmarking/2023-MIR-DGNet" 215 | DGNet = { 216 | "CAMO": dict(path=os.path.join(DGNet_root, "CAMO"), suffix=".png"), 217 | "NC4K": dict(path=os.path.join(DGNet_root, "NC4K"), suffix=".png"), 218 | "COD10K": dict(path=os.path.join(DGNet_root, "COD10K"), suffix=".png"), 219 | } 220 | 221 | CamoFormerP_root = "./benchmark/COS-Benchmarking/2023-arXiv-CamoFormerP" 222 | CamoFormerP = { 223 | "CAMO": dict(path=os.path.join(CamoFormerP_root, "CAMO"), suffix=".png"), 224 | "NC4K": dict(path=os.path.join(CamoFormerP_root, "NC4K"), suffix=".png"), 225 | "COD10K": dict(path=os.path.join(CamoFormerP_root, "COD10K"), suffix=".png"), 226 | } 227 | 228 | NCHIT_root = "./benchmark/COS-Benchmarking/2022-CVIU-NCHIT" 229 | NCHIT = { 230 | "CAMO": dict(path=os.path.join(NCHIT_root, "CAMO"), suffix=".png"), 231 | "NC4K": dict(path=os.path.join(NCHIT_root, "NC4K"), suffix=".png"), 232 | "COD10K": dict(path=os.path.join(NCHIT_root, "COD10K"), suffix=".png"), 233 | } 234 | 235 
| FDNet_root = "./benchmark/COS-Benchmarking/2022-CVPR-FDNet" 236 | FDNet = { 237 | "CAMO": dict(path=os.path.join(FDNet_root, "CAMO"), suffix=".png"), 238 | "NC4K": dict(path=os.path.join(FDNet_root, "NC4K"), suffix=".png"), 239 | "COD10K": dict(path=os.path.join(FDNet_root, "COD10K"), suffix=".png"), 240 | } 241 | 242 | DGNetS_root = "./benchmark/COS-Benchmarking/2023-MIR-DGNetS" 243 | DGNetS = { 244 | "CAMO": dict(path=os.path.join(DGNetS_root, "CAMO"), suffix=".png"), 245 | "NC4K": dict(path=os.path.join(DGNetS_root, "NC4K"), suffix=".png"), 246 | "COD10K": dict(path=os.path.join(DGNetS_root, "COD10K"), suffix=".png"), 247 | } 248 | 249 | CamoFormerC_root = "./benchmark/COS-Benchmarking/2023-arXiv-CamoFormerC" 250 | CamoFormerC = { 251 | "CAMO": dict(path=os.path.join(CamoFormerC_root, "CAMO"), suffix=".png"), 252 | "NC4K": dict(path=os.path.join(CamoFormerC_root, "NC4K"), suffix=".png"), 253 | "COD10K": dict(path=os.path.join(CamoFormerC_root, "COD10K"), suffix=".png"), 254 | } 255 | 256 | -------------------------------------------------------------------------------- /cos_eval_toolbox/metrics/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DengPingFan/CSU/892e7bf716e75dd1506a97be80b1d04b03b21965/cos_eval_toolbox/metrics/__init__.py -------------------------------------------------------------------------------- /cos_eval_toolbox/metrics/cal_cosod_matrics.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | from collections import defaultdict 5 | 6 | import numpy as np 7 | from tqdm import tqdm 8 | 9 | from utils.misc import ( 10 | colored_print, 11 | get_gt_pre_with_name, 12 | get_name_with_group_list, 13 | make_dir, 14 | ) 15 | from utils.print_formatter import formatter_for_tabulate 16 | from utils.recorders import GroupedMetricRecorder, MetricExcelRecorder, TxtRecorder 17 | 18 | 19 | def 
group_names(names: list) -> dict: 20 | grouped_data = defaultdict(list) 21 | for name in names: 22 | group_name, file_name = name.split("/") 23 | grouped_data[group_name].append(file_name) 24 | return grouped_data 25 | 26 | 27 | def cal_cosod_matrics( 28 | data_type: str = "rgb_sod", 29 | txt_path: str = "", 30 | to_append: bool = True, 31 | xlsx_path: str = "", 32 | drawing_info: dict = None, 33 | dataset_info: dict = None, 34 | save_npy: bool = True, 35 | curves_npy_path: str = "./curves.npy", 36 | metrics_npy_path: str = "./metrics.npy", 37 | num_bits: int = 3, 38 | ): 39 | """ 40 | Save the results of all models on different datasets in a `npy` file in the form of a 41 | dictionary. 42 | 43 | :: 44 | 45 | { 46 | dataset1:{ 47 | method1:[fm, em, p, r], 48 | method2:[fm, em, p, r], 49 | ..... 50 | }, 51 | dataset2:{ 52 | method1:[fm, em, p, r], 53 | method2:[fm, em, p, r], 54 | ..... 55 | }, 56 | .... 57 | } 58 | 59 | :param data_type: the type of data 60 | :param txt_path: the path of the txt file for saving results 61 | :param to_append: whether to append results to the original record 62 | :param xlsx_path: the path of the xlsx file for saving results 63 | :param drawing_info: the method information for plotting figures 64 | :param dataset_info: the dataset information 65 | :param save_npy: whether to save results into npy files 66 | :param curves_npy_path: the npy file path for saving curve data 67 | :param metrics_npy_path: the npy file path for saving metric values 68 | :param num_bits: the number of bits used to format results 69 | """ 70 | curves = defaultdict(dict) # Two curve metrics 71 | metrics = defaultdict(dict) # Six numerical metrics 72 | 73 | txt_recoder = TxtRecorder( 74 | txt_path=txt_path, 75 | to_append=to_append, 76 | max_method_name_width=max([len(x) for x in drawing_info.keys()]), # show full method names 77 | ) 78 | excel_recorder = MetricExcelRecorder( 79 | xlsx_path=xlsx_path, 80 | sheet_name=data_type, 81 | row_header=["methods"], 82 | 
dataset_names=sorted(list(dataset_info.keys())), 83 | metric_names=["sm", "wfm", "mae", "adpf", "avgf", "maxf", "adpe", "avge", "maxe"], 84 | ) 85 | 86 | for dataset_name, dataset_path in dataset_info.items(): 87 | txt_recoder.add_row(row_name="Dataset", row_data=dataset_name, row_start_str="\n") 88 | 89 | # get the ground-truth mask information 90 | gt_info = dataset_path["mask"] 91 | gt_root = gt_info["path"] 92 | gt_ext = gt_info["suffix"] 93 | # list of ground-truth image names 94 | gt_index_file = dataset_path.get("index_file") 95 | if gt_index_file: 96 | gt_name_list = get_name_with_group_list(data_path=gt_index_file, file_ext=gt_ext) 97 | else: 98 | gt_name_list = get_name_with_group_list(data_path=gt_root, file_ext=gt_ext) 99 | assert len(gt_name_list) > 0, "there is no ground truth." 100 | 101 | # ==>> test the intersection between pre and gt for each method <<== 102 | for method_name, method_info in drawing_info.items(): 103 | method_root = method_info["path_dict"] 104 | method_dataset_info = method_root.get(dataset_name, None) 105 | if method_dataset_info is None: 106 | colored_print( 107 | msg=f"{method_name} does not have results on {dataset_name}", mode="warning" 108 | ) 109 | continue 110 | 111 | # file extension and list of image names in the method's prediction directory 112 | pre_ext = method_dataset_info["suffix"] 113 | pre_root = method_dataset_info["path"] 114 | pre_name_list = get_name_with_group_list(data_path=pre_root, file_ext=pre_ext) 115 | 116 | # get the intersection 117 | eval_name_list = sorted(list(set(gt_name_list).intersection(set(pre_name_list)))) 118 | num_names = len(eval_name_list) 119 | 120 | if num_names == 0: 121 | colored_print( 122 | msg=f"{method_name} does not have results on {dataset_name}", mode="warning" 123 | ) 124 | continue 125 | 126 | grouped_data = group_names(names=eval_name_list) 127 | num_groups = len(grouped_data) 128 | 129 | colored_print( 130 | f"Evaluating {method_name} with 
{num_names} images and {num_groups} groups" 131 | f" (G:{len(gt_name_list)},P:{len(pre_name_list)}) on dataset {dataset_name}" 132 
| ) 133 | 134 | group_metric_recorder = GroupedMetricRecorder() 135 | inter_group_bar = tqdm( 136 | grouped_data.items(), 137 | total=num_groups, 138 | leave=False, 139 | ncols=79, 140 | desc=f"[{dataset_name}]", 141 | ) 142 | for group_name, names_in_group in inter_group_bar: 143 | intra_group_bar = tqdm( 144 | names_in_group, 145 | total=len(names_in_group), 146 | leave=False, 147 | ncols=79, 148 | desc=f"({group_name})", 149 | ) 150 | for img_name in intra_group_bar: 151 | img_name_with_group = os.path.join(group_name, img_name) 152 | gt, pre = get_gt_pre_with_name( 153 | gt_root=gt_root, 154 | pre_root=pre_root, 155 | img_name=img_name_with_group, 156 | pre_ext=pre_ext, 157 | gt_ext=gt_ext, 158 | to_normalize=False, 159 | ) 160 | group_metric_recorder.update(group_name=group_name, pre=pre, gt=gt) 161 | method_results = group_metric_recorder.show(num_bits=num_bits, return_ndarray=False) 162 | method_curves = method_results["sequential"] 163 | method_metrics = method_results["numerical"] 164 | curves[dataset_name][method_name] = method_curves 165 | metrics[dataset_name][method_name] = method_metrics 166 | 167 | excel_recorder( 168 | row_data=method_metrics, dataset_name=dataset_name, method_name=method_name 169 | ) 170 | txt_recoder(method_results=method_metrics, method_name=method_name) 171 | 172 | if save_npy: 173 | make_dir(os.path.dirname(curves_npy_path)) 174 | np.save(curves_npy_path, curves) 175 | np.save(metrics_npy_path, metrics) 176 | colored_print(f"all methods have been saved in {curves_npy_path} and {metrics_npy_path}") 177 | formatted_string = formatter_for_tabulate(metrics) 178 | colored_print(f"all methods have been tested:\n{formatted_string}") 179 | -------------------------------------------------------------------------------- /cos_eval_toolbox/metrics/cal_sod_matrics.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | from collections 
import defaultdict 5 | 
from functools import partial 6 | from multiprocessing import RLock, pool 7 | 8 | import numpy as np 9 | from tqdm import tqdm 10 | 11 | from utils.misc import get_gt_pre_with_name, get_name_list, make_dir 12 | from utils.print_formatter import formatter_for_tabulate 13 | from utils.recorders import MetricExcelRecorder, MetricRecorder_V2, TxtRecorder 14 | 15 | 16 | class Recorder: 17 | def __init__( 18 | self, 19 | txt_path, 20 | to_append, 21 | max_method_name_width, 22 | xlsx_path, 23 | sheet_name, 24 | dataset_names, 25 | metric_names, 26 | ): 27 | self.curves = defaultdict(dict) # Two curve metrics 28 | self.metrics = defaultdict(dict) # Six numerical metrics 29 | self.dataset_name = None 30 | 31 | self.txt_recorder = None 32 | if txt_path: 33 | self.txt_recorder = TxtRecorder( 34 | txt_path=txt_path, 35 | to_append=to_append, 36 | max_method_name_width=max_method_name_width, 37 | ) 38 | 39 | self.excel_recorder = None 40 | if xlsx_path: 41 | self.excel_recorder = MetricExcelRecorder( 42 | xlsx_path=xlsx_path, 43 | sheet_name=sheet_name, 44 | row_header=["methods"], 45 | dataset_names=dataset_names, 46 | metric_names=metric_names, 47 | ) 48 | 49 | def record_dataset_name(self, dataset_name): 50 | self.dataset_name = dataset_name 51 | if self.txt_recorder: 52 | self.txt_recorder.add_row( 53 | row_name="Dataset", row_data=dataset_name, row_start_str="\n" 54 | ) 55 | 56 | def record(self, method_results, method_name): 57 | method_curves = method_results["sequential"] 58 | method_metrics = method_results["numerical"] 59 | 60 | self.curves[self.dataset_name][method_name] = method_curves 61 | self.metrics[self.dataset_name][method_name] = method_metrics 62 | 63 | if self.excel_recorder: 64 | self.excel_recorder( 65 | row_data=method_metrics, dataset_name=self.dataset_name, method_name=method_name 66 | ) 67 | if self.txt_recorder: 68 | self.txt_recorder(method_results=method_metrics, method_name=method_name) 69 | 70 | 71 | def cal_sod_matrics( 72 | sheet_name: str = 
"results", 73 | txt_path: str = "", 74 | to_append: bool = True, 75 | xlsx_path: str = "", 76 | methods_info: dict = None, 77 | datasets_info: dict = None, 78 | curves_npy_path: str = "./curves.npy", 79 | metrics_npy_path: str = "./metrics.npy", 80 | num_bits: int = 3, 81 | num_workers: int = 2, 82 | use_mp: bool = False, 83 | metric_names: tuple = ("mae", "fm", "em", "sm", "wfm"), 84 | ncols_tqdm: int = 79, 85 | ): 86 | """ 87 | Save the results of all models on different datasets in a `npy` file in the form of a 88 | dictionary. 89 | 90 | :: 91 | 92 | { 93 | dataset1:{ 94 | method1:[fm, em, p, r], 95 | method2:[fm, em, p, r], 96 | ..... 97 | }, 98 | dataset2:{ 99 | method1:[fm, em, p, r], 100 | method2:[fm, em, p, r], 101 | ..... 102 | }, 103 | .... 104 | } 105 | 106 | :param sheet_name: the name of the sheet in the xlsx file 107 | :param txt_path: the path of the txt file for saving results 108 | :param to_append: whether to append results to the original record 109 | :param xlsx_path: the path of the xlsx file for saving results 110 | :param methods_info: the method information 111 | :param datasets_info: the dataset information 112 | :param curves_npy_path: the npy file path for saving curve data 113 | :param metrics_npy_path: the npy file path for saving metric values 114 | :param num_bits: the number of bits used to format results 115 | :param num_workers: the number of workers of multiprocessing or multithreading 116 | :param use_mp: use multiprocessing if True, otherwise multithreading 117 | :param metric_names: names of metrics 118 | :param ncols_tqdm: number of columns for tqdm 119 | """ 120 | recorder = Recorder( 121 | txt_path=txt_path, 122 | to_append=to_append, 123 | max_method_name_width=max([len(x) for x in methods_info.keys()]), # show full method names 124 | xlsx_path=xlsx_path, 125 | sheet_name=sheet_name, 126 | dataset_names=sorted(datasets_info.keys()), 127 | metric_names=["sm", "wfm", "mae", "adpf", "avgf", "maxf", "adpe", "avge", "maxe"], 
128 | ) 129 | 130 | for dataset_name,
dataset_path in datasets_info.items(): 131 | recorder.record_dataset_name(dataset_name) 132 | 133 | # get the ground-truth mask information 134 | gt_info = dataset_path["mask"] 135 | gt_root = gt_info["path"] 136 | gt_ext = gt_info["suffix"] 137 | # list of ground-truth image names 138 | gt_index_file = dataset_path.get("index_file") 139 | if gt_index_file: 140 | gt_name_list = get_name_list(data_path=gt_index_file, name_suffix=gt_ext) 141 | else: 142 | gt_name_list = get_name_list(data_path=gt_root, name_suffix=gt_ext) 143 | assert len(gt_name_list) > 0, "there is no ground truth." 144 | 145 | # ==>> test the intersection between pre and gt for each method <<== 146 | tqdm.set_lock(RLock()) 147 | pool_cls = pool.Pool if use_mp else pool.ThreadPool 148 | procs = pool_cls( 149 | processes=num_workers, initializer=tqdm.set_lock, initargs=(tqdm.get_lock(),) 150 | ) 151 | procs_idx = 0 152 | for method_name, method_info in methods_info.items(): 153 | method_root = method_info["path_dict"] 154 | method_dataset_info = method_root.get(dataset_name, None) 155 | if method_dataset_info is None: 156 | tqdm.write(f"{method_name} does not have results on {dataset_name}") 157 | continue 158 | 159 | # prefix, suffix, and list of image names in the method's prediction directory 160 | pre_prefix = method_dataset_info.get("prefix", "") 161 | pre_suffix = method_dataset_info["suffix"] 162 | pre_root = method_dataset_info["path"] 163 | pre_name_list = get_name_list( 164 | data_path=pre_root, 165 | name_prefix=pre_prefix, 166 | name_suffix=pre_suffix, 167 | ) 168 | 169 | # get the intersection 170 | eval_name_list = sorted(list(set(gt_name_list).intersection(pre_name_list))) 171 | if len(eval_name_list) == 0: 172 | tqdm.write(f"{method_name} does not have results on {dataset_name}") 173 | continue 174 | 175 | procs.apply_async( 176 | func=evaluate_data, 177 | kwds=dict( 178 | names=eval_name_list, 179 | num_bits=num_bits, 180 | pre_root=pre_root, 181 | pre_prefix=pre_prefix, 182 | pre_suffix=pre_suffix, 183 | 
gt_root=gt_root, 184 | gt_ext=gt_ext, 185 | desc=f"Dataset: {dataset_name}
({len(gt_name_list)}) | Method: {method_name} ({len(pre_name_list)})", 186 | proc_idx=procs_idx, 187 | blocking=use_mp, 188 | metric_names=metric_names, 189 | ncols_tqdm=ncols_tqdm, 190 | ), 191 | callback=partial(recorder.record, method_name=method_name), 192 | ) 193 | procs_idx += 1 194 | procs.close() 195 | procs.join() 196 | 197 | if curves_npy_path: 198 | make_dir(os.path.dirname(curves_npy_path)) 199 | np.save(curves_npy_path, recorder.curves) 200 | print(f"All curves have been saved to {curves_npy_path}") 201 | if metrics_npy_path: 202 | make_dir(os.path.dirname(metrics_npy_path)) 203 | np.save(metrics_npy_path, recorder.metrics) 204 | print(f"All metrics have been saved to {metrics_npy_path}") 205 | formatted_string = formatter_for_tabulate(recorder.metrics) 206 | print(f"All methods have been evaluated:\n{formatted_string}") 207 | 208 | 209 | def evaluate_data( 210 | names, 211 | num_bits, 212 | gt_root, 213 | gt_ext, 214 | pre_root, 215 | pre_prefix, 216 | pre_suffix, 217 | desc="", 218 | proc_idx=None, 219 | blocking=True, 220 | metric_names=None, 221 | ncols_tqdm=79, 222 | ): 223 | metric_recorder = MetricRecorder_V2(metric_names=metric_names) 224 | # https://github.com/tqdm/tqdm#parameters 225 | # https://github.com/tqdm/tqdm/blob/master/examples/parallel_bars.py 226 | tqdm_bar = tqdm( 227 | names, 228 | total=len(names), 229 | desc=desc, 230 | position=proc_idx, 231 | ncols=ncols_tqdm, 232 | lock_args=None if blocking else (False,), 233 | ) 234 | for name in tqdm_bar: 235 | gt, pre = get_gt_pre_with_name( 236 | gt_root=gt_root, 237 | pre_root=pre_root, 238 | img_name=name, 239 | pre_prefix=pre_prefix, 240 | pre_suffix=pre_suffix, 241 | gt_ext=gt_ext, 242 | to_normalize=False, 243 | ) 244 | metric_recorder.update(pre=pre, gt=gt) 245 | method_results = metric_recorder.show(num_bits=num_bits, return_ndarray=False) 246 | return method_results 247 | -------------------------------------------------------------------------------- 
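The evaluation loop in `cal_sod_matrics.py` above submits one `evaluate_data` job per (dataset, method) pair with `apply_async` and collects each result through `callback=partial(recorder.record, method_name=method_name)`. The sketch below isolates that fan-out/callback pattern under stated assumptions; it is not the toolbox's code. `fake_evaluate`, `run_all`, the `results` dict, and the method names are hypothetical stand-ins for `evaluate_data`, the evaluation driver, and `Recorder`. It also adds an `error_callback`, which the excerpt above does not pass; without one (or a later `.get()` on the returned `AsyncResult`), an exception raised inside a job is never reported, and that method would simply be missing from the results.

```python
from functools import partial
from multiprocessing import pool


def fake_evaluate(names, scale):
    # Hypothetical stand-in for `evaluate_data`: returns a metric dict for one method.
    return {"mae": round(sum(names) * scale, 3)}


def run_all(methods, names, num_workers=2):
    results = {}  # hypothetical stand-in for `Recorder`: method_name -> metrics

    def record(method_results, method_name):
        # `callback` receives the job's return value as its only positional
        # argument; `partial` binds the method name, mirroring the toolbox.
        results[method_name] = method_results

    def report(exc):
        # Called instead of `callback` when the job raises an exception.
        print(f"worker failed: {exc!r}")

    procs = pool.ThreadPool(processes=num_workers)
    for method_name, scale in methods.items():
        procs.apply_async(
            func=fake_evaluate,
            kwds=dict(names=names, scale=scale),
            callback=partial(record, method_name=method_name),
            error_callback=report,
        )
    procs.close()  # no more tasks will be submitted
    procs.join()  # wait until every task and its callback has finished
    return results


if __name__ == "__main__":
    print(run_all({"MethodA": 0.01, "MethodB": 0.02}, names=[1, 2, 3]))
```

Callbacks always run in the parent process, so the nested `record` is fine even with `pool.Pool` (what `use_mp=True` selects); only the job function and its arguments need to be picklable in that case. The order in which callbacks fire, and hence the insertion order of `results`, depends on which job finishes first.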
/cos_eval_toolbox/metrics/draw_curves.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from collections import OrderedDict 4 | 5 | import numpy as np 6 | from matplotlib import colors 7 | 8 | from utils.recorders import CurveDrawer 9 | 10 | 11 | def draw_curves( 12 | for_pr: bool = True, 13 | axes_setting: dict = None, 14 | curves_npy_path: list = None, 15 | row_num: int = 1, 16 | our_methods: list = None, 17 | method_aliases: OrderedDict = None, 18 | dataset_aliases: OrderedDict = None, 19 | style_cfg: dict = None, 20 | ncol_of_legend: int = 1, 21 | separated_legend: bool = False, 22 | sharey: bool = False, 23 | line_styles=("-", "--"), 24 | line_width=3, 25 | save_name=None, 26 | ): 27 | """A better curve painter! 28 | 29 | Args: 30 | for_pr (bool, optional): Plot for PR curves or FM curves. Defaults to True. 31 | axes_setting (dict, optional): Setting for axes. Defaults to None. 32 | curves_npy_path (list, optional): Paths of curve npy files. Defaults to None. 33 | row_num (int, optional): Number of rows. Defaults to 1. 34 | our_methods (list, optional): Names of our methods. Defaults to None. 35 | method_aliases (OrderedDict, optional): Aliases of methods. Defaults to None. 36 | dataset_aliases (OrderedDict, optional): Aliases of datasets. Defaults to None. 37 | style_cfg (dict, optional): Config file for the style of matplotlib. Defaults to None. 38 | ncol_of_legend (int, optional): Number of columns for the legend. Defaults to 1. 39 | separated_legend (bool, optional): Use the separated legend. Defaults to False. 40 | sharey (bool, optional): Use a shared y-axis. Defaults to False. 41 | line_styles (tuple, optional): Styles of lines. Defaults to ("-", "--"). 42 | line_width (int, optional): Width of lines. Defaults to 3. 43 | save_name (str, optional): Name or path (without the extension format). Defaults to None. 
44 | """ 45 | mode = "pr" if for_pr else "fm" 46 | save_name = save_name or mode 47 | mode_axes_setting = axes_setting[mode] 48 | 49 | x_label, y_label = mode_axes_setting["x_label"], mode_axes_setting["y_label"] 50 | x_ticks, y_ticks = mode_axes_setting["x_ticks"], mode_axes_setting["y_ticks"] 51 | 52 | assert curves_npy_path 53 | if not isinstance(curves_npy_path, (list, tuple)): 54 | curves_npy_path = [curves_npy_path] 55 | 56 | curves = {} 57 | unique_method_names_from_npy = [] 58 | for p in curves_npy_path: 59 | single_curves = np.load(p, allow_pickle=True).item() 60 | for dataset_name, method_infos in single_curves.items(): 61 | curves.setdefault(dataset_name, {}) 62 | for method_name, method_info in method_infos.items(): 63 | curves[dataset_name][method_name] = method_info 64 | if method_name not in unique_method_names_from_npy: 65 | unique_method_names_from_npy.append(method_name) 66 | dataset_names_from_npy = list(curves.keys()) 67 | 68 | if dataset_aliases is None: 69 | dataset_aliases = {k: k for k in dataset_names_from_npy} 70 | else: 71 | for x in dataset_aliases.keys(): 72 | if x not in dataset_names_from_npy: 73 | raise ValueError(f"{x} must be contained in\n{dataset_names_from_npy}") 74 | 75 | if method_aliases is not None: 76 | target_unique_method_names = list(method_aliases.keys()) 77 | for x in target_unique_method_names: 78 | if x not in unique_method_names_from_npy: 79 | raise ValueError( 80 | f"{x} must be contained in\n{sorted(unique_method_names_from_npy)}" 81 | ) 82 | else: 83 | target_unique_method_names = unique_method_names_from_npy 84 | 85 | if our_methods is not None: 86 | for x in our_methods: 87 | if x not in target_unique_method_names: 88 | raise ValueError(f"{x} must be contained in\n{target_unique_method_names}") 89 | assert len(our_methods) <= len(line_styles) 90 | else: 91 | our_methods = [] 92 | 93 | # Give each method a unique color. 
94 | color_table = sorted( 95 | [ 96 | color 97 | for name, color in colors.cnames.items() 98 | if name not in ["red", "white"] and not name.startswith("light") and "gray" not in name 99 | ] 100 | ) 101 | unique_method_settings = {} 102 | for i, method_name in enumerate(target_unique_method_names): 103 | if method_name in our_methods: 104 | line_color = "red" 105 | line_style = line_styles[our_methods.index(method_name)] 106 | else: 107 | line_color = color_table[i] 108 | line_style = line_styles[i % 2] 109 | line_label = ( 110 | method_name if method_aliases is None else method_aliases.get(method_name, method_name) 111 | ) 112 | 113 | unique_method_settings[method_name] = { 114 | "line_color": line_color, 115 | "line_style": line_style, 116 | "line_label": line_label, 117 | "line_width": line_width, 118 | } 119 | 120 | curve_drawer = CurveDrawer( 121 | row_num=row_num, 122 | num_subplots=len(dataset_aliases), 123 | style_cfg=style_cfg, 124 | ncol_of_legend=ncol_of_legend, 125 | separated_legend=separated_legend, 126 | sharey=sharey, 127 | ) 128 | 129 | for idx, (dataset_name, dataset_alias) in enumerate(dataset_aliases.items()): 130 | dataset_results = curves[dataset_name] 131 | curve_drawer.set_axis_property( 132 | idx=idx, 133 | title=dataset_alias.upper(), 134 | x_label=x_label, 135 | y_label=y_label, 136 | x_ticks=x_ticks, 137 | y_ticks=y_ticks, 138 | ) 139 | 140 | for method_name, method_setting in unique_method_settings.items(): 141 | if method_name not in dataset_results: 142 | raise KeyError(f"{method_name} not in {sorted(dataset_results.keys())}") 143 | method_results = dataset_results[method_name] 144 | if mode == "pr": 145 | assert isinstance(method_results["p"], (list, tuple)) 146 | assert isinstance(method_results["r"], (list, tuple)) 147 | y_data = method_results["p"] 148 | x_data = method_results["r"] 149 | else: 150 | assert isinstance(method_results["fm"], (list, tuple)) 151 | y_data = method_results["fm"] 152 | x_data = np.linspace(0, 1, 256) 153 | 154
| curve_drawer.plot_at_axis( 155 | idx=idx, method_curve_setting=method_setting, x_data=x_data, y_data=y_data 156 | ) 157 | curve_drawer.save(path=save_name) 158 | -------------------------------------------------------------------------------- /cos_eval_toolbox/metrics/extra_metrics.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import numpy as np 3 | from py_sod_metrics.sod_metrics import _TYPE, _prepare_data 4 | 5 | 6 | class ExtraSegMeasure(object): 7 | def __init__(self): 8 | self.precisions = [] 9 | self.recalls = [] 10 | self.specificities = [] 11 | self.dices = [] 12 | self.fmeasures = [] 13 | self.ious = [] 14 | 15 | def step(self, pred: np.ndarray, gt: np.ndarray): 16 | pred, gt = _prepare_data(pred, gt) 17 | 18 | precisions, recalls, specificities, dices, fmeasures, ious = self.cal_metrics( 19 | pred=pred, gt=gt 20 | ) 21 | 22 | self.precisions.append(precisions) 23 | self.recalls.append(recalls) 24 | self.specificities.append(specificities) 25 | self.dices.append(dices) 26 | self.fmeasures.append(fmeasures) 27 | self.ious.append(ious) 28 | 29 | def cal_metrics(self, pred: np.ndarray, gt: np.ndarray) -> tuple: 30 | """ 31 | Calculate precision, recall, specificity, Dice, F-measure (beta = 1) and IoU for every threshold from 0 to 255. 32 | 33 | These per-threshold values can be used to obtain the mean and maximum of each metric, 34 | as well as the precision-recall curve and the metric-threshold curves. 35 | 36 | Notation: NumAnd = TP, NumRec = num_pred = TP + FP (predicted foreground pixels), 37 | num_obj = TP + FN (ground-truth foreground pixels). 38 | 39 | IoU = NumAnd / (FN + NumRec) 40 | PreFtem = NumAnd / NumRec 41 | RecallFtem = NumAnd / num_obj 42 | SpecifTem = TN / (TN + FP) 43 | Dice = 2 * NumAnd / (num_obj + num_pred) 44 | FmeasureF = (2.0 * PreFtem * RecallFtem) / (PreFtem + RecallFtem) 45 | 46 | :return: precisions, recalls, specificities, dices, fmeasures, ious 47 | """ 48 | # 1. 
Histograms of the prediction inside the GT foreground and the GT background
49 | pred: np.ndarray = (pred * 255).astype(np.uint8) 50 | bins: np.ndarray = np.linspace(0, 256, 257) 51 | tp_hist, _ = np.histogram(pred[gt], bins=bins) # the last bin is [255, 256] 52 | fp_hist, _ = np.histogram(pred[~gt], bins=bins) 53 | # 2. Use cumulative histograms to count, for each threshold, the GT-foreground/background pixels whose value is >= that threshold 54 | # cumsum over the flipped histogram gives the ">= threshold" counts for all thresholds at once; only pixels predicted as foreground are counted here 55 | tp_w_thrs = np.cumsum(np.flip(tp_hist), axis=0) 56 | fp_w_thrs = np.cumsum(np.flip(fp_hist), axis=0) 57 | # 3. Compute precision and recall (and the other metrics) for every threshold 58 | # p and r share the numerator (pred==1 & gt==1) and differ only in the denominator: pred==1 for p, gt==1 for r 59 | # histogram & flip & cumsum yield these foreground counts for all thresholds simultaneously 60 | TPs = tp_w_thrs 61 | FPs = fp_w_thrs 62 | T = np.count_nonzero(gt) # T=TPs+FNs 63 | FNs = T - TPs 64 | Ps = TPs + FPs # p_w_thrs 65 | Ns = pred.size - Ps 66 | TNs = Ns - FNs 67 | 68 | ious = np.where(Ps + FNs == 0, 0, TPs / (Ps + FNs)) 69 | specificities = np.where(TNs + FPs == 0, 0, TNs / (TNs + FPs)) 70 | dices = np.where(TPs + FPs == 0, 0, 2 * TPs / (T + Ps)) 71 | precisions = np.where(Ps == 0, 0, TPs / Ps) 72 | recalls = np.where(TPs == 0, 0, TPs / T) 73 | fmeasures = np.where( 74 | precisions + recalls == 0, 0, (2 * precisions * recalls) / (precisions + recalls) 75 | ) 76 | return precisions, recalls, specificities, dices, fmeasures, ious 77 | 78 | def get_results(self) -> dict: 79 | """ 80 | Return each metric averaged over all samples, per threshold.
81 | 82 | :return: dict(pre=precision, sen=recall, spec=specificity, fm=fmeasure, dice=dice, iou=iou), each an array of 256 per-threshold values 83 | """ 84 | precision = np.mean(np.array(self.precisions, dtype=_TYPE), axis=0) 85 | recall = np.mean(np.array(self.recalls, dtype=_TYPE), axis=0) 86 | specificity = np.mean(np.array(self.specificities, dtype=_TYPE), axis=0) 87 | fmeasure = np.mean(np.array(self.fmeasures, dtype=_TYPE), axis=0) 88 | dice = np.mean(np.array(self.dices, dtype=_TYPE), axis=0) 89 | iou = np.mean(np.array(self.ious, dtype=_TYPE), axis=0) 90 | return dict(pre=precision, sen=recall, spec=specificity, fm=fmeasure, dice=dice, iou=iou) 91 | -------------------------------------------------------------------------------- /cos_eval_toolbox/output_COS/cos_curves.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DengPingFan/CSU/892e7bf716e75dd1506a97be80b1d04b03b21965/cos_eval_toolbox/output_COS/cos_curves.npy -------------------------------------------------------------------------------- /cos_eval_toolbox/output_COS/cos_metrics.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DengPingFan/CSU/892e7bf716e75dd1506a97be80b1d04b03b21965/cos_eval_toolbox/output_COS/cos_metrics.npy -------------------------------------------------------------------------------- /cos_eval_toolbox/output_COS/cos_results.txt: -------------------------------------------------------------------------------- 1 | 2 | ========>> Date: 2023-04-22 08:07:14.980489 <<======== 3 | 4 | ========>> Dataset: CAMO <<======== 5 | [TINet ] mae: 0.087 maxf: 0.745 avgf: 0.728 adpf: 0.729 maxe: 0.848 avge: 0.836 adpe: 0.847 sm: 0.781 wfm: 0.678 6 | [PreyNet ] mae: 0.077 maxf: 0.765 avgf: 0.757 adpf: 0.763 maxe: 0.857 avge: 0.842 adpe: 0.856 sm: 0.79 wfm: 0.708 7 | [TPRNet ] mae: 0.074 maxf: 0.785 avgf: 0.772 adpf: 0.777 maxe: 0.883 avge: 0.861 adpe: 0.88 sm: 0.807 wfm: 0.725 8 | [ERRNet ] mae: 
0.085 maxf: 0.742 avgf: 0.729 adpf: 0.731 maxe: 0.858 avge: 0.842 adpe: 0.855 sm: 0.779 wfm: 0.679 9 | [PopNet ] mae: 0.077 maxf: 0.792 avgf: 0.784 adpf: 0.79 maxe: 0.874 avge: 0.859 adpe: 0.871 sm: 0.808 wfm: 0.744 10 | [FAPNet ] mae: 0.076 maxf: 0.792 avgf: 0.776 adpf: 0.776 maxe: 0.88 avge: 0.865 adpe: 0.877 sm: 0.815 wfm: 0.734 11 | [SMGL ] mae: 0.089 maxf: 0.739 avgf: 0.721 adpf: 0.733 maxe: 0.842 avge: 0.807 adpe: 0.85 sm: 0.772 wfm: 0.664 12 | [UGTR ] mae: 0.086 maxf: 0.754 avgf: 0.738 adpf: 0.749 maxe: 0.854 avge: 0.823 adpe: 0.861 sm: 0.785 wfm: 0.686 13 | [SINet ] mae: 0.092 maxf: 0.708 avgf: 0.702 adpf: 0.712 maxe: 0.829 avge: 0.804 adpe: 0.825 sm: 0.745 wfm: 0.644 14 | [RMGL ] mae: 0.088 maxf: 0.74 avgf: 0.726 adpf: 0.738 maxe: 0.842 avge: 0.812 adpe: 0.848 sm: 0.775 wfm: 0.673 15 | [PFNetPlus ] mae: 0.08 maxf: 0.77 avgf: 0.761 adpf: 0.764 maxe: 0.865 avge: 0.85 adpe: 0.862 sm: 0.791 wfm: 0.713 16 | [CRNet ] mae: 0.092 maxf: 0.707 avgf: 0.701 adpf: 0.709 maxe: 0.83 avge: 0.815 adpe: 0.829 sm: 0.735 wfm: 0.641 17 | [C2FNet ] mae: 0.08 maxf: 0.771 avgf: 0.762 adpf: 0.764 maxe: 0.864 avge: 0.854 adpe: 0.865 sm: 0.796 wfm: 0.719 18 | [NCHIT ] mae: 0.088 maxf: 0.739 avgf: 0.707 adpf: 0.723 maxe: 0.84 avge: 0.805 adpe: 0.841 sm: 0.784 wfm: 0.652 19 | [CamoFormerR] mae: 0.076 maxf: 0.813 avgf: 0.745 adpf: 0.735 maxe: 0.916 avge: 0.874 adpe: 0.863 sm: 0.816 wfm: 0.712 20 | [SegMaR ] mae: 0.071 maxf: 0.803 avgf: 0.795 adpf: 0.795 maxe: 0.884 avge: 0.874 adpe: 0.881 sm: 0.815 wfm: 0.753 21 | [DTINet ] mae: 0.05 maxf: 0.843 avgf: 0.823 adpf: 0.821 maxe: 0.927 avge: 0.916 adpe: 0.918 sm: 0.856 wfm: 0.796 22 | [CamoFormerC] mae: 0.05 maxf: 0.855 avgf: 0.842 adpf: 0.842 maxe: 0.92 avge: 0.913 adpe: 0.919 sm: 0.859 wfm: 0.812 23 | [CamoFormerP] mae: 0.046 maxf: 0.868 avgf: 0.854 adpf: 0.853 maxe: 0.938 avge: 0.929 adpe: 0.931 sm: 0.872 wfm: 0.831 24 | [CubeNet ] mae: 0.085 maxf: 0.75 avgf: 0.732 adpf: 0.734 maxe: 0.86 avge: 0.838 adpe: 0.852 sm: 0.788 wfm: 0.682 25 | 
[DGNetS ] mae: 0.063 maxf: 0.81 avgf: 0.792 adpf: 0.786 maxe: 0.907 avge: 0.893 adpe: 0.896 sm: 0.826 wfm: 0.754 26 | [HitNet ] mae: 0.055 maxf: 0.838 avgf: 0.831 adpf: 0.833 maxe: 0.91 avge: 0.906 adpe: 0.91 sm: 0.849 wfm: 0.809 27 | [LSR ] mae: 0.08 maxf: 0.753 avgf: 0.744 adpf: 0.756 maxe: 0.854 avge: 0.838 adpe: 0.859 sm: 0.787 wfm: 0.696 28 | [DGNet ] mae: 0.057 maxf: 0.822 avgf: 0.806 adpf: 0.804 maxe: 0.915 avge: 0.901 adpe: 0.906 sm: 0.839 wfm: 0.769 29 | [BAS ] mae: 0.096 maxf: 0.703 avgf: 0.692 adpf: 0.696 maxe: 0.808 avge: 0.796 adpe: 0.808 sm: 0.749 wfm: 0.646 30 | [FDNet ] mae: 0.063 maxf: 0.826 avgf: 0.807 adpf: 0.803 maxe: 0.908 avge: 0.895 adpe: 0.901 sm: 0.841 wfm: 0.775 31 | [SINetV2 ] mae: 0.07 maxf: 0.801 avgf: 0.782 adpf: 0.779 maxe: 0.895 avge: 0.882 adpe: 0.884 sm: 0.82 wfm: 0.743 32 | [PFNet ] mae: 0.085 maxf: 0.758 avgf: 0.746 adpf: 0.751 maxe: 0.855 avge: 0.841 adpe: 0.855 sm: 0.782 wfm: 0.695 33 | [JSCOD ] mae: 0.073 maxf: 0.779 avgf: 0.772 adpf: 0.779 maxe: 0.873 avge: 0.859 adpe: 0.872 sm: 0.8 wfm: 0.728 34 | [ZoomNet ] mae: 0.066 maxf: 0.805 avgf: 0.794 adpf: 0.792 maxe: 0.892 avge: 0.877 adpe: 0.883 sm: 0.82 wfm: 0.752 35 | [C2FNetV2 ] mae: 0.077 maxf: 0.779 avgf: 0.77 adpf: 0.777 maxe: 0.869 avge: 0.859 adpe: 0.869 sm: 0.799 wfm: 0.73 36 | [BSANet ] mae: 0.079 maxf: 0.77 avgf: 0.763 adpf: 0.768 maxe: 0.867 avge: 0.851 adpe: 0.866 sm: 0.794 wfm: 0.717 37 | [OCENet ] mae: 0.08 maxf: 0.777 avgf: 0.766 adpf: 0.776 maxe: 0.865 avge: 0.852 adpe: 0.866 sm: 0.802 wfm: 0.723 38 | [BGNet ] mae: 0.073 maxf: 0.799 avgf: 0.789 adpf: 0.786 maxe: 0.882 avge: 0.87 adpe: 0.876 sm: 0.812 wfm: 0.749 39 | [D2CNet ] mae: 0.087 maxf: 0.743 avgf: 0.735 adpf: 0.747 maxe: 0.838 avge: 0.818 adpe: 0.844 sm: 0.774 wfm: 0.683 40 | [CamoFormerS] mae: 0.043 maxf: 0.871 avgf: 0.856 adpf: 0.856 maxe: 0.938 avge: 0.93 adpe: 0.935 sm: 0.876 wfm: 0.832 41 | 42 | ========>> Dataset: COD10K <<======== 43 | [PopNet ] mae: 0.028 maxf: 0.802 avgf: 0.786 adpf: 0.771 maxe: 
0.919 avge: 0.91 adpe: 0.91 sm: 0.851 wfm: 0.757 44 | [PreyNet ] mae: 0.034 maxf: 0.747 avgf: 0.736 adpf: 0.731 maxe: 0.891 avge: 0.881 adpe: 0.894 sm: 0.813 wfm: 0.697 45 | [TINet ] mae: 0.042 maxf: 0.712 avgf: 0.679 adpf: 0.652 maxe: 0.878 avge: 0.861 adpe: 0.848 sm: 0.793 wfm: 0.635 46 | [ERRNet ] mae: 0.043 maxf: 0.702 avgf: 0.675 adpf: 0.646 maxe: 0.886 avge: 0.867 adpe: 0.845 sm: 0.786 wfm: 0.63 47 | [UGTR ] mae: 0.035 maxf: 0.742 avgf: 0.712 adpf: 0.671 maxe: 0.891 avge: 0.853 adpe: 0.85 sm: 0.818 wfm: 0.667 48 | [FAPNet ] mae: 0.036 maxf: 0.758 avgf: 0.731 adpf: 0.707 maxe: 0.902 avge: 0.888 adpe: 0.875 sm: 0.822 wfm: 0.694 49 | [SMGL ] mae: 0.037 maxf: 0.733 avgf: 0.702 adpf: 0.667 maxe: 0.889 avge: 0.845 adpe: 0.851 sm: 0.811 wfm: 0.655 50 | [TPRNet ] mae: 0.036 maxf: 0.748 avgf: 0.724 adpf: 0.694 maxe: 0.903 avge: 0.887 adpe: 0.869 sm: 0.817 wfm: 0.683 51 | [CRNet ] mae: 0.049 maxf: 0.636 avgf: 0.633 adpf: 0.637 maxe: 0.845 avge: 0.832 adpe: 0.845 sm: 0.733 wfm: 0.576 52 | [SINet ] mae: 0.043 maxf: 0.691 avgf: 0.679 adpf: 0.667 maxe: 0.874 avge: 0.864 adpe: 0.867 sm: 0.776 wfm: 0.631 53 | [C2FNet ] mae: 0.036 maxf: 0.743 avgf: 0.723 adpf: 0.703 maxe: 0.9 avge: 0.89 adpe: 0.886 sm: 0.813 wfm: 0.686 54 | [RMGL ] mae: 0.035 maxf: 0.738 avgf: 0.711 adpf: 0.681 maxe: 0.89 avge: 0.852 adpe: 0.865 sm: 0.814 wfm: 0.666 55 | [PFNetPlus ] mae: 0.037 maxf: 0.734 avgf: 0.716 adpf: 0.698 maxe: 0.895 avge: 0.884 adpe: 0.88 sm: 0.806 wfm: 0.677 56 | [CamoFormerR] mae: 0.029 maxf: 0.786 avgf: 0.753 adpf: 0.721 maxe: 0.93 avge: 0.916 adpe: 0.9 sm: 0.838 wfm: 0.724 57 | [NCHIT ] mae: 0.046 maxf: 0.698 avgf: 0.649 adpf: 0.596 maxe: 0.879 avge: 0.819 adpe: 0.794 sm: 0.792 wfm: 0.591 58 | [SegMaR ] mae: 0.034 maxf: 0.774 avgf: 0.757 adpf: 0.739 maxe: 0.906 avge: 0.899 adpe: 0.893 sm: 0.833 wfm: 0.724 59 | [CamoFormerP] mae: 0.023 maxf: 0.829 avgf: 0.811 adpf: 0.794 maxe: 0.939 avge: 0.932 adpe: 0.931 sm: 0.869 wfm: 0.786 60 | [CamoFormerC] mae: 0.024 maxf: 0.818 avgf: 0.798 
adpf: 0.778 maxe: 0.935 avge: 0.926 adpe: 0.926 sm: 0.86 wfm: 0.77 61 | [HitNet ] mae: 0.023 maxf: 0.838 avgf: 0.823 adpf: 0.818 maxe: 0.938 avge: 0.935 adpe: 0.936 sm: 0.871 wfm: 0.806 62 | [DGNetS ] mae: 0.036 maxf: 0.743 avgf: 0.71 adpf: 0.68 maxe: 0.905 avge: 0.888 adpe: 0.869 sm: 0.81 wfm: 0.672 63 | [LSR ] mae: 0.037 maxf: 0.732 avgf: 0.715 adpf: 0.699 maxe: 0.892 avge: 0.88 adpe: 0.883 sm: 0.804 wfm: 0.673 64 | [DGNet ] mae: 0.033 maxf: 0.759 avgf: 0.728 adpf: 0.698 maxe: 0.911 avge: 0.896 adpe: 0.879 sm: 0.822 wfm: 0.693 65 | [DTINet ] mae: 0.034 maxf: 0.754 avgf: 0.726 adpf: 0.702 maxe: 0.911 avge: 0.896 adpe: 0.881 sm: 0.824 wfm: 0.695 66 | [CubeNet ] mae: 0.041 maxf: 0.715 avgf: 0.692 adpf: 0.669 maxe: 0.883 avge: 0.865 adpe: 0.862 sm: 0.795 wfm: 0.643 67 | [FDNet ] mae: 0.03 maxf: 0.788 avgf: 0.757 adpf: 0.728 maxe: 0.935 avge: 0.919 adpe: 0.906 sm: 0.84 wfm: 0.729 68 | [JSCOD ] mae: 0.035 maxf: 0.738 avgf: 0.721 adpf: 0.705 maxe: 0.891 avge: 0.884 adpe: 0.882 sm: 0.809 wfm: 0.684 69 | [SINetV2 ] mae: 0.037 maxf: 0.752 avgf: 0.718 adpf: 0.682 maxe: 0.906 avge: 0.887 adpe: 0.864 sm: 0.815 wfm: 0.68 70 | [BAS ] mae: 0.038 maxf: 0.729 avgf: 0.715 adpf: 0.707 maxe: 0.87 avge: 0.855 adpe: 0.869 sm: 0.802 wfm: 0.677 71 | [PFNet ] mae: 0.04 maxf: 0.725 avgf: 0.701 adpf: 0.676 maxe: 0.89 avge: 0.877 adpe: 0.868 sm: 0.8 wfm: 0.66 72 | [ZoomNet ] mae: 0.029 maxf: 0.78 avgf: 0.766 adpf: 0.741 maxe: 0.911 avge: 0.888 adpe: 0.893 sm: 0.838 wfm: 0.729 73 | [C2FNetV2 ] mae: 0.036 maxf: 0.742 avgf: 0.725 adpf: 0.718 maxe: 0.896 avge: 0.887 adpe: 0.89 sm: 0.811 wfm: 0.691 74 | [BSANet ] mae: 0.034 maxf: 0.753 avgf: 0.738 adpf: 0.723 maxe: 0.901 avge: 0.891 adpe: 0.894 sm: 0.818 wfm: 0.699 75 | [BGNet ] mae: 0.033 maxf: 0.774 avgf: 0.753 adpf: 0.739 maxe: 0.911 avge: 0.901 adpe: 0.902 sm: 0.831 wfm: 0.722 76 | [OCENet ] mae: 0.033 maxf: 0.764 avgf: 0.741 adpf: 0.718 maxe: 0.905 avge: 0.894 adpe: 0.885 sm: 0.827 wfm: 0.707 77 | [CamoFormerS] mae: 0.024 maxf: 0.818 avgf: 
0.799 adpf: 0.78 maxe: 0.941 avge: 0.931 adpe: 0.932 sm: 0.862 wfm: 0.772 78 | [D2CNet ] mae: 0.037 maxf: 0.736 avgf: 0.72 adpf: 0.702 maxe: 0.887 avge: 0.876 adpe: 0.879 sm: 0.807 wfm: 0.68 79 | 80 | ========>> Dataset: NC4K <<======== 81 | [TPRNet ] mae: 0.048 maxf: 0.82 avgf: 0.805 adpf: 0.798 maxe: 0.911 avge: 0.898 adpe: 0.901 sm: 0.846 wfm: 0.768 82 | [SMGL ] mae: 0.055 maxf: 0.797 avgf: 0.777 adpf: 0.771 maxe: 0.893 avge: 0.863 adpe: 0.885 sm: 0.829 wfm: 0.731 83 | [ERRNet ] mae: 0.054 maxf: 0.794 avgf: 0.778 adpf: 0.769 maxe: 0.901 avge: 0.887 adpe: 0.892 sm: 0.827 wfm: 0.737 84 | [PreyNet ] mae: 0.05 maxf: 0.811 avgf: 0.803 adpf: 0.805 maxe: 0.899 avge: 0.887 adpe: 0.899 sm: 0.834 wfm: 0.763 85 | [FAPNet ] mae: 0.047 maxf: 0.826 avgf: 0.81 adpf: 0.804 maxe: 0.91 avge: 0.899 adpe: 0.903 sm: 0.851 wfm: 0.775 86 | [UGTR ] mae: 0.052 maxf: 0.807 avgf: 0.787 adpf: 0.779 maxe: 0.899 avge: 0.874 adpe: 0.889 sm: 0.839 wfm: 0.747 87 | [PopNet ] mae: 0.042 maxf: 0.843 avgf: 0.833 adpf: 0.83 maxe: 0.919 avge: 0.909 adpe: 0.915 sm: 0.861 wfm: 0.802 88 | [TINet ] mae: 0.055 maxf: 0.793 avgf: 0.773 adpf: 0.766 maxe: 0.89 avge: 0.879 adpe: 0.882 sm: 0.829 wfm: 0.734 89 | [SINet ] mae: 0.058 maxf: 0.775 avgf: 0.769 adpf: 0.768 maxe: 0.883 avge: 0.871 adpe: 0.883 sm: 0.808 wfm: 0.723 90 | [SegMaR ] mae: 0.046 maxf: 0.826 avgf: 0.821 adpf: 0.821 maxe: 0.907 avge: 0.896 adpe: 0.905 sm: 0.841 wfm: 0.781 91 | [RMGL ] mae: 0.052 maxf: 0.8 avgf: 0.782 adpf: 0.778 maxe: 0.893 avge: 0.867 adpe: 0.89 sm: 0.833 wfm: 0.74 92 | [CamoFormerR] mae: 0.042 maxf: 0.83 avgf: 0.821 adpf: 0.82 maxe: 0.914 avge: 0.9 adpe: 0.913 sm: 0.855 wfm: 0.788 93 | [C2FNet ] mae: 0.049 maxf: 0.81 avgf: 0.795 adpf: 0.788 maxe: 0.904 avge: 0.897 adpe: 0.901 sm: 0.838 wfm: 0.762 94 | [CamoFormerP] mae: 0.03 maxf: 0.88 avgf: 0.868 adpf: 0.863 maxe: 0.946 avge: 0.939 adpe: 0.941 sm: 0.892 wfm: 0.847 95 | [NCHIT ] mae: 0.058 maxf: 0.792 avgf: 0.758 adpf: 0.751 maxe: 0.894 avge: 0.851 adpe: 0.872 sm: 0.83 wfm: 
0.71 96 | [CamoFormerC] mae: 0.032 maxf: 0.87 avgf: 0.857 adpf: 0.851 maxe: 0.94 avge: 0.933 adpe: 0.937 sm: 0.883 wfm: 0.834 97 | [HitNet ] mae: 0.037 maxf: 0.863 avgf: 0.853 adpf: 0.854 maxe: 0.929 avge: 0.926 adpe: 0.928 sm: 0.875 wfm: 0.834 98 | [DTINet ] mae: 0.041 maxf: 0.836 avgf: 0.818 adpf: 0.809 maxe: 0.926 avge: 0.917 adpe: 0.914 sm: 0.863 wfm: 0.792 99 | [DGNet ] mae: 0.042 maxf: 0.833 avgf: 0.814 adpf: 0.803 maxe: 0.922 avge: 0.911 adpe: 0.91 sm: 0.857 wfm: 0.784 100 | [DGNetS ] mae: 0.047 maxf: 0.819 avgf: 0.799 adpf: 0.789 maxe: 0.913 avge: 0.902 adpe: 0.902 sm: 0.845 wfm: 0.764 101 | [LSR ] mae: 0.048 maxf: 0.815 avgf: 0.804 adpf: 0.802 maxe: 0.907 avge: 0.895 adpe: 0.904 sm: 0.84 wfm: 0.766 102 | [JSCOD ] mae: 0.047 maxf: 0.816 avgf: 0.806 adpf: 0.803 maxe: 0.907 avge: 0.898 adpe: 0.906 sm: 0.842 wfm: 0.771 103 | [SINetV2 ] mae: 0.048 maxf: 0.823 avgf: 0.805 adpf: 0.792 maxe: 0.914 avge: 0.903 adpe: 0.901 sm: 0.847 wfm: 0.77 104 | [BAS ] mae: 0.058 maxf: 0.782 avgf: 0.772 adpf: 0.767 maxe: 0.872 avge: 0.859 adpe: 0.868 sm: 0.817 wfm: 0.732 105 | [FDNet ] mae: 0.052 maxf: 0.804 avgf: 0.784 adpf: 0.774 maxe: 0.905 avge: 0.893 adpe: 0.895 sm: 0.834 wfm: 0.75 106 | [PFNet ] mae: 0.053 maxf: 0.799 avgf: 0.784 adpf: 0.779 maxe: 0.898 avge: 0.887 adpe: 0.894 sm: 0.829 wfm: 0.745 107 | [ZoomNet ] mae: 0.043 maxf: 0.828 avgf: 0.818 adpf: 0.814 maxe: 0.912 avge: 0.896 adpe: 0.907 sm: 0.853 wfm: 0.784 108 | [C2FNetV2 ] mae: 0.048 maxf: 0.814 avgf: 0.802 adpf: 0.799 maxe: 0.904 avge: 0.896 adpe: 0.9 sm: 0.84 wfm: 0.77 109 | [BSANet ] mae: 0.048 maxf: 0.817 avgf: 0.808 adpf: 0.805 maxe: 0.907 avge: 0.897 adpe: 0.906 sm: 0.841 wfm: 0.771 110 | [BGNet ] mae: 0.044 maxf: 0.833 avgf: 0.82 adpf: 0.813 maxe: 0.916 avge: 0.907 adpe: 0.911 sm: 0.851 wfm: 0.788 111 | [OCENet ] mae: 0.045 maxf: 0.831 avgf: 0.818 adpf: 0.812 maxe: 0.913 avge: 0.902 adpe: 0.908 sm: 0.853 wfm: 0.785 112 | [CamoFormerS] mae: 0.031 maxf: 0.877 avgf: 0.863 adpf: 0.857 maxe: 0.946 avge: 0.937 
adpe: 0.941 sm: 0.888 wfm: 0.84 113 | -------------------------------------------------------------------------------- /cos_eval_toolbox/plot.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import argparse 3 | import textwrap 4 | 5 | import numpy as np 6 | import yaml 7 | 8 | from metrics import draw_curves 9 | 10 | 11 | def get_args(): 12 | parser = argparse.ArgumentParser( 13 | description=textwrap.dedent( 14 | r""" 15 | INCLUDE: 16 | 17 | - Fm Curves 18 | - PR Curves 19 | 20 | NOTE: 21 | 22 | - Our method automatically calculates the intersection of `pre` and `gt`. 23 | - Currently supported pre naming rules: `prefix + gt_name_wo_ext + suffix_w_ext` 24 | 25 | EXAMPLES: 26 | 27 | python plot.py \ 28 | --curves-npys output/rgbd-othermethods-curves.npy output/rgbd-ours-curves.npy \ # use the information from these npy files to draw curves 29 | --num-rows 1 \ # set the number of rows of the figure to 1 30 | --style-cfg configs/single_row_style.yml \ # specify the configuration file for the style of matplotlib 31 | --num-col-legend 1 \ # set the number of columns of the legend in the figure to 1 32 | --mode pr \ # draw `pr` curves 33 | --our-methods Ours \ # specify the names of our own methods; they must be contained in the npy files 34 | --save-name ./output/rgbd-pr-curves # save the figure as `./output/rgbd-pr-curves.<ext_name>`, where `ext_name` is taken from the `savefig.format` item in the `--style-cfg` file 35 | 36 | python plot.py \ 37 | --curves-npys output/rgbd-othermethods-curves.npy output/rgbd-ours-curves.npy \ 38 | --num-rows 2 \ 39 | --style-cfg configs/two_row_style.yml \ 40 | --num-col-legend 2 \ 41 | --separated-legend \ # use a separated legend 42 | --mode fm \ # draw `fm` curves 43 | --our-methods OursV0 OursV1 \ # specify the names of our own methods; they must be contained in the npy files 44 | --save-name output/rgbd-fm \ 45 | --alias-yaml configs/rgbd_aliases.yaml # aliases
corresponding to methods and datasets you want to use 46 | """ 47 | ), 48 | formatter_class=argparse.RawTextHelpFormatter, 49 | ) 50 | parser.add_argument("--alias-yaml", type=str, help="Yaml file with aliases for datasets and methods.") 51 | parser.add_argument( 52 | "--style-cfg", 53 | type=str, 54 | required=True, 55 | help="Yaml file with the matplotlib style for plotting curves.", 56 | ) 57 | parser.add_argument( 58 | "--curves-npys", 59 | required=True, 60 | type=str, 61 | nargs="+", 62 | help="Npy files containing the saved curve results.", 63 | ) 64 | parser.add_argument( 65 | "--our-methods", type=str, nargs="+", help="Names of our methods, to highlight them." 66 | ) 67 | parser.add_argument( 68 | "--num-rows", type=int, default=1, help="Number of rows for subplots. Default: 1" 69 | ) 70 | parser.add_argument( 71 | "--num-col-legend", type=int, default=1, help="Number of columns in the legend. Default: 1" 72 | ) 73 | parser.add_argument( 74 | "--mode", 75 | type=str, 76 | choices=["pr", "fm"], 77 | default="pr", 78 | help="Mode for plotting. Default: pr", 79 | ) 80 | parser.add_argument( 81 | "--separated-legend", action="store_true", help="Use a separated legend."
82 | ) 83 | parser.add_argument("--sharey", action="store_true", help="Use a shared y-axis.") 84 | parser.add_argument("--save-name", type=str, help="Path of the exported figure (without the extension).") 85 | args = parser.parse_args() 86 | 87 | return args 88 | 89 | 90 | def main(args): 91 | method_aliases = dataset_aliases = None 92 | if args.alias_yaml: 93 | with open(args.alias_yaml, mode="r", encoding="utf-8") as f: 94 | aliases = yaml.safe_load(f) 95 | method_aliases = aliases.get("method") 96 | dataset_aliases = aliases.get("dataset") 97 | 98 | draw_curves.draw_curves( 99 | for_pr=args.mode == "pr", 100 | # plotting configuration for each type of curve 101 | axes_setting={ 102 | # configuration for PR curves 103 | "pr": { 104 | "x_label": "Recall", 105 | "y_label": "Precision", 106 | "x_ticks": np.linspace(0.5, 1, 6), 107 | "y_ticks": np.linspace(0.7, 1, 6), 108 | }, 109 | # configuration for F-measure curves 110 | "fm": { 111 | "x_label": "Threshold", 112 | "y_label": r"F$_{\beta}$", 113 | "x_ticks": np.linspace(0, 1, 6), 114 | "y_ticks": np.linspace(0.6, 1, 6), 115 | }, 116 | }, 117 | curves_npy_path=args.curves_npys, 118 | row_num=args.num_rows, 119 | method_aliases=method_aliases, 120 | dataset_aliases=dataset_aliases, 121 | style_cfg=args.style_cfg, 122 | ncol_of_legend=args.num_col_legend, 123 | separated_legend=args.separated_legend, 124 | sharey=args.sharey, 125 | our_methods=args.our_methods, 126 | save_name=args.save_name, 127 | ) 128 | 129 | 130 | if __name__ == "__main__": 131 | args = get_args() 132 | main(args) 133 | -------------------------------------------------------------------------------- /cos_eval_toolbox/pyproject.toml: -------------------------------------------------------------------------------- 1 | # https://github.com/LongTengDao/TOML/ 2 | 3 | [tool.isort] 4 | # https://pycqa.github.io/isort/docs/configuration/options/ 5 | profile = "black" 6 | multi_line_output = 3 7 | filter_files = true 8 | supported_extensions = "py" 9 | 10 | [tool.black] 11 | line-length = 99 12 | include = '\.pyi?$' 13 | exclude = ''' 14 | /( 15 | \.eggs 16 | |
\.git 17 | | \.idea 18 | | \.vscode 19 | | \.hg 20 | | \.mypy_cache 21 | | \.tox 22 | | \.venv 23 | | _build 24 | | buck-out 25 | | build 26 | | dist 27 | | output 28 | )/ 29 | ''' 30 | -------------------------------------------------------------------------------- /cos_eval_toolbox/requirements.txt: -------------------------------------------------------------------------------- 1 | # Automatically generated by https://github.com/damnever/pigar. 2 | 3 | # PySODEvalToolkit/utils/misc.py: 6 4 | Pillow == 8.1.2 5 | 6 | # PySODEvalToolkit/plot.py: 6 7 | # PySODEvalToolkit/tools/converter.py: 10 8 | PyYAML == 5.4.1 9 | 10 | # PySODEvalToolkit/tools/markdown2html.py: 4 11 | markdown2 == 2.4.1 12 | 13 | # PySODEvalToolkit/metrics/draw_curves.py: 6 14 | # PySODEvalToolkit/utils/generate_info.py: 7 15 | # PySODEvalToolkit/utils/recorders/curve_drawer.py: 8 16 | matplotlib == 3.4.2 17 | 18 | # PySODEvalToolkit/metrics/cal_cosod_matrics.py: 6 19 | # PySODEvalToolkit/metrics/cal_sod_matrics.py: 8 20 | # PySODEvalToolkit/metrics/draw_curves.py: 5 21 | # PySODEvalToolkit/metrics/extra_metrics.py: 2 22 | # PySODEvalToolkit/plot.py: 5 23 | # PySODEvalToolkit/tools/append_results.py: 4 24 | # PySODEvalToolkit/tools/converter.py: 9 25 | # PySODEvalToolkit/untracked/collect_results.py: 8 26 | # PySODEvalToolkit/untracked/find_bad_prediction.py: 4 27 | # PySODEvalToolkit/untracked/modify_npy.py: 1 28 | # PySODEvalToolkit/untracked/plot_results.py: 4 29 | # PySODEvalToolkit/untracked/plot_sota_cmp.py: 4 30 | # PySODEvalToolkit/utils/misc.py: 5 31 | # PySODEvalToolkit/utils/recorders/metric_recorder.py: 6 32 | numpy == 1.19.2 33 | 34 | # PySODEvalToolkit/untracked/collect_results.py: 7 35 | # PySODEvalToolkit/untracked/plot_sota_cmp.py: 3 36 | # PySODEvalToolkit/utils/misc.py: 4 37 | opencv_python_headless == 4.5.1.48 38 | 39 | # PySODEvalToolkit/utils/recorders/excel_recorder.py: 9,10,11 40 | openpyxl == 3.0.7 41 | 42 | # PySODEvalToolkit/metrics/extra_metrics.py: 3 43 | # 
PySODEvalToolkit/untracked/collect_results.py: 9 44 | # PySODEvalToolkit/utils/recorders/metric_recorder.py: 7 45 | pysodmetrics == 1.3.0 46 | 47 | # PySODEvalToolkit/utils/print_formatter.py: 3 48 | tabulate == 0.8.9 49 | 50 | # PySODEvalToolkit/metrics/cal_cosod_matrics.py: 7 51 | # PySODEvalToolkit/metrics/cal_sod_matrics.py: 9 52 | # PySODEvalToolkit/untracked/collect_results.py: 10 53 | # PySODEvalToolkit/untracked/find_bad_prediction.py: 5 54 | tqdm == 4.59.0 55 | -------------------------------------------------------------------------------- /cos_eval_toolbox/tools/.backup/individual_metrics.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2021/09/13 3 | # @Author : Johnson-Chou 4 | # @Email : johnson111788@gmail.com 5 | # @FileName : metrics.py 6 | # @Reference: https://github.com/lartpang/PySODMetrics and https://github.com/mczhuge/SOCToolbox 7 | 8 | import torch 9 | import numpy as np 10 | import torch.nn as nn 11 | from sklearn import metrics 12 | from torch.autograd import Function 13 | from scipy.ndimage import convolve, distance_transform_edt as bwdist 14 | 15 | 16 | _EPS = np.spacing(1) 17 | _TYPE = np.float64 18 | 19 | 20 | def _prepare_data(pred: np.ndarray, gt: np.ndarray) -> tuple: 21 | gt = gt > 128 22 | pred = pred / 255 23 | if pred.max() != pred.min(): 24 | pred = (pred - pred.min()) / (pred.max() - pred.min()) 25 | return pred, gt 26 | 27 | 28 | def _get_adaptive_threshold(matrix: np.ndarray, max_value: float = 1) -> float: 29 | return min(2 * matrix.mean(), max_value) 30 | 31 | 32 | class Fmeasure(object): 33 | def __init__(self, beta: float = 0.3): 34 | self.beta = beta 35 | self.precisions = [] 36 | self.recalls = [] 37 | self.adaptive_fms = [] 38 | self.changeable_fms = [] 39 | 40 | def step(self, pred: np.ndarray, gt: np.ndarray): 41 | pred, gt = _prepare_data(pred, gt) 42 | 43 | adaptive_fm = self.cal_adaptive_fm(pred=pred, gt=gt) 44 | 
self.adaptive_fms.append(adaptive_fm) 45 | 46 | precisions, recalls, changeable_fms = self.cal_pr(pred=pred, gt=gt) 47 | self.precisions.append(precisions) 48 | self.recalls.append(recalls) 49 | self.changeable_fms.append(changeable_fms) 50 | 51 | def cal_adaptive_fm(self, pred: np.ndarray, gt: np.ndarray) -> float: 52 | adaptive_threshold = _get_adaptive_threshold(pred, max_value=1) 53 | binary_predcition = pred >= adaptive_threshold 54 | area_intersection = binary_predcition[gt].sum() 55 | if area_intersection == 0: 56 | adaptive_fm = 0 57 | else: 58 | pre = area_intersection / np.count_nonzero(binary_predcition) 59 | rec = area_intersection / np.count_nonzero(gt) 60 | adaptive_fm = (1 + self.beta) * pre * rec / (self.beta * pre + rec) 61 | return adaptive_fm 62 | 63 | def cal_pr(self, pred: np.ndarray, gt: np.ndarray) -> tuple: 64 | pred = (pred * 255).astype(np.uint8) 65 | bins = np.linspace(0, 256, 257) 66 | fg_hist, _ = np.histogram(pred[gt], bins=bins) 67 | bg_hist, _ = np.histogram(pred[~gt], bins=bins) 68 | fg_w_thrs = np.cumsum(np.flip(fg_hist), axis=0) 69 | bg_w_thrs = np.cumsum(np.flip(bg_hist), axis=0) 70 | TPs = fg_w_thrs 71 | Ps = fg_w_thrs + bg_w_thrs 72 | Ps[Ps == 0] = 1 73 | T = max(np.count_nonzero(gt), 1) 74 | precisions = TPs / Ps 75 | recalls = TPs / T 76 | numerator = (1 + self.beta) * precisions * recalls 77 | denominator = np.where(numerator == 0, 1, self.beta * precisions + recalls) 78 | changeable_fms = numerator / denominator 79 | return precisions, recalls, changeable_fms 80 | 81 | def get_results(self) -> dict: 82 | adaptive_fm = np.mean(np.array(self.adaptive_fms, _TYPE)) 83 | changeable_fm = np.mean(np.array(self.changeable_fms, dtype=_TYPE), axis=0) 84 | precision = np.mean(np.array(self.precisions, dtype=_TYPE), axis=0) # N, 256 85 | recall = np.mean(np.array(self.recalls, dtype=_TYPE), axis=0) # N, 256 86 | return dict(fm=dict(adp=adaptive_fm, curve=changeable_fm), 87 | pr=dict(p=precision, r=recall)) 88 | 89 | 90 | class 
MAE(object): 91 | def __init__(self): 92 | self.maes = [] 93 | 94 | def step(self, pred: np.ndarray, gt: np.ndarray): 95 | pred, gt = _prepare_data(pred, gt) 96 | 97 | mae = self.cal_mae(pred, gt) 98 | self.maes.append(mae) 99 | 100 | def cal_mae(self, pred: np.ndarray, gt: np.ndarray) -> float: 101 | mae = np.mean(np.abs(pred - gt)) 102 | return mae 103 | 104 | def get_results(self) -> dict: 105 | mae = np.mean(np.array(self.maes, _TYPE)) 106 | return dict(mae=mae) 107 | 108 | 109 | class Smeasure(object): 110 | def __init__(self, alpha: float = 0.5): 111 | self.sms = [] 112 | self.alpha = alpha 113 | 114 | def step(self, pred: np.ndarray, gt: np.ndarray): 115 | pred, gt = _prepare_data(pred=pred, gt=gt) 116 | 117 | sm = self.cal_sm(pred, gt) 118 | self.sms.append(sm) 119 | 120 | def cal_sm(self, pred: np.ndarray, gt: np.ndarray) -> float: 121 | y = np.mean(gt) 122 | if y == 0: 123 | sm = 1 - np.mean(pred) 124 | elif y == 1: 125 | sm = np.mean(pred) 126 | else: 127 | sm = self.alpha * self.object(pred, gt) + (1 - self.alpha) * self.region(pred, gt) 128 | sm = max(0, sm) 129 | return sm 130 | 131 | def object(self, pred: np.ndarray, gt: np.ndarray) -> float: 132 | fg = pred * gt 133 | bg = (1 - pred) * (1 - gt) 134 | u = np.mean(gt) 135 | object_score = u * self.s_object(fg, gt) + (1 - u) * self.s_object(bg, 1 - gt) 136 | return object_score 137 | 138 | def s_object(self, pred: np.ndarray, gt: np.ndarray) -> float: 139 | x = np.mean(pred[gt == 1]) 140 | sigma_x = np.std(pred[gt == 1], ddof=1) 141 | score = 2 * x / (np.power(x, 2) + 1 + sigma_x + _EPS) 142 | return score 143 | 144 | def region(self, pred: np.ndarray, gt: np.ndarray) -> float: 145 | x, y = self.centroid(gt) 146 | part_info = self.divide_with_xy(pred, gt, x, y) 147 | w1, w2, w3, w4 = part_info['weight'] 148 | pred1, pred2, pred3, pred4 = part_info['pred'] 149 | gt1, gt2, gt3, gt4 = part_info['gt'] 150 | score1 = self.ssim(pred1, gt1) 151 | score2 = self.ssim(pred2, gt2) 152 | score3 = self.ssim(pred3, 
gt3) 153 | score4 = self.ssim(pred4, gt4) 154 | 155 | return w1 * score1 + w2 * score2 + w3 * score3 + w4 * score4 156 | 157 | def centroid(self, matrix: np.ndarray) -> tuple: 158 | h, w = matrix.shape 159 | if matrix.sum() == 0: 160 | x = np.round(w / 2) 161 | y = np.round(h / 2) 162 | else: 163 | area_object = np.sum(matrix) 164 | row_ids = np.arange(h) 165 | col_ids = np.arange(w) 166 | x = np.round(np.sum(np.sum(matrix, axis=0) * col_ids) / area_object) 167 | y = np.round(np.sum(np.sum(matrix, axis=1) * row_ids) / area_object) 168 | return int(x) + 1, int(y) + 1 169 | 170 | def divide_with_xy(self, pred: np.ndarray, gt: np.ndarray, x, y) -> dict: 171 | h, w = gt.shape 172 | area = h * w 173 | 174 | gt_LT = gt[0:y, 0:x] 175 | gt_RT = gt[0:y, x:w] 176 | gt_LB = gt[y:h, 0:x] 177 | gt_RB = gt[y:h, x:w] 178 | 179 | pred_LT = pred[0:y, 0:x] 180 | pred_RT = pred[0:y, x:w] 181 | pred_LB = pred[y:h, 0:x] 182 | pred_RB = pred[y:h, x:w] 183 | 184 | w1 = x * y / area 185 | w2 = y * (w - x) / area 186 | w3 = (h - y) * x / area 187 | w4 = 1 - w1 - w2 - w3 188 | 189 | return dict(gt=(gt_LT, gt_RT, gt_LB, gt_RB), 190 | pred=(pred_LT, pred_RT, pred_LB, pred_RB), 191 | weight=(w1, w2, w3, w4)) 192 | 193 | def ssim(self, pred: np.ndarray, gt: np.ndarray) -> float: 194 | h, w = pred.shape 195 | N = h * w 196 | 197 | x = np.mean(pred) 198 | y = np.mean(gt) 199 | 200 | sigma_x = np.sum((pred - x) ** 2) / (N - 1) 201 | sigma_y = np.sum((gt - y) ** 2) / (N - 1) 202 | sigma_xy = np.sum((pred - x) * (gt - y)) / (N - 1) 203 | 204 | alpha = 4 * x * y * sigma_xy 205 | beta = (x ** 2 + y ** 2) * (sigma_x + sigma_y) 206 | 207 | if alpha != 0: 208 | score = alpha / (beta + _EPS) 209 | elif alpha == 0 and beta == 0: 210 | score = 1 211 | else: 212 | score = 0 213 | return score 214 | 215 | def get_results(self) -> dict: 216 | sm = np.mean(np.array(self.sms, dtype=_TYPE)) 217 | return dict(sm=sm) 218 | 219 | 220 | class Emeasure(object): 221 | def __init__(self): 222 | self.adaptive_ems = [] 
223 | self.changeable_ems = [] 224 | 225 | def step(self, pred: np.ndarray, gt: np.ndarray): 226 | pred, gt = _prepare_data(pred=pred, gt=gt) 227 | self.gt_fg_numel = np.count_nonzero(gt) 228 | self.gt_size = gt.shape[0] * gt.shape[1] 229 | 230 | changeable_ems = self.cal_changeable_em(pred, gt) 231 | self.changeable_ems.append(changeable_ems) 232 | adaptive_em = self.cal_adaptive_em(pred, gt) 233 | self.adaptive_ems.append(adaptive_em) 234 | 235 | def cal_adaptive_em(self, pred: np.ndarray, gt: np.ndarray) -> float: 236 | adaptive_threshold = _get_adaptive_threshold(pred, max_value=1) 237 | adaptive_em = self.cal_em_with_threshold(pred, gt, threshold=adaptive_threshold) 238 | return adaptive_em 239 | 240 | def cal_changeable_em(self, pred: np.ndarray, gt: np.ndarray) -> np.ndarray: 241 | changeable_ems = self.cal_em_with_cumsumhistogram(pred, gt) 242 | return changeable_ems 243 | 244 | def cal_em_with_threshold(self, pred: np.ndarray, gt: np.ndarray, threshold: float) -> float: 245 | binarized_pred = pred >= threshold 246 | fg_fg_numel = np.count_nonzero(binarized_pred & gt) 247 | fg_bg_numel = np.count_nonzero(binarized_pred & ~gt) 248 | 249 | fg___numel = fg_fg_numel + fg_bg_numel 250 | bg___numel = self.gt_size - fg___numel 251 | 252 | if self.gt_fg_numel == 0: 253 | enhanced_matrix_sum = bg___numel 254 | elif self.gt_fg_numel == self.gt_size: 255 | enhanced_matrix_sum = fg___numel 256 | else: 257 | parts_numel, combinations = self.generate_parts_numel_combinations( 258 | fg_fg_numel=fg_fg_numel, fg_bg_numel=fg_bg_numel, 259 | pred_fg_numel=fg___numel, pred_bg_numel=bg___numel, 260 | ) 261 | 262 | results_parts = [] 263 | for i, (part_numel, combination) in enumerate(zip(parts_numel, combinations)): 264 | align_matrix_value = 2 * (combination[0] * combination[1]) / \ 265 | (combination[0] ** 2 + combination[1] ** 2 + _EPS) 266 | enhanced_matrix_value = (align_matrix_value + 1) ** 2 / 4 267 | results_parts.append(enhanced_matrix_value * part_numel) 268 | 
enhanced_matrix_sum = sum(results_parts) 269 | 270 | em = enhanced_matrix_sum / (self.gt_size - 1 + _EPS) 271 | return em 272 | 273 | def cal_em_with_cumsumhistogram(self, pred: np.ndarray, gt: np.ndarray) -> np.ndarray: 274 | pred = (pred * 255).astype(np.uint8) 275 | bins = np.linspace(0, 256, 257) 276 | fg_fg_hist, _ = np.histogram(pred[gt], bins=bins) 277 | fg_bg_hist, _ = np.histogram(pred[~gt], bins=bins) 278 | fg_fg_numel_w_thrs = np.cumsum(np.flip(fg_fg_hist), axis=0) 279 | fg_bg_numel_w_thrs = np.cumsum(np.flip(fg_bg_hist), axis=0) 280 | 281 | fg___numel_w_thrs = fg_fg_numel_w_thrs + fg_bg_numel_w_thrs 282 | bg___numel_w_thrs = self.gt_size - fg___numel_w_thrs 283 | 284 | if self.gt_fg_numel == 0: 285 | enhanced_matrix_sum = bg___numel_w_thrs 286 | elif self.gt_fg_numel == self.gt_size: 287 | enhanced_matrix_sum = fg___numel_w_thrs 288 | else: 289 | parts_numel_w_thrs, combinations = self.generate_parts_numel_combinations( 290 | fg_fg_numel=fg_fg_numel_w_thrs, fg_bg_numel=fg_bg_numel_w_thrs, 291 | pred_fg_numel=fg___numel_w_thrs, pred_bg_numel=bg___numel_w_thrs, 292 | ) 293 | 294 | results_parts = np.empty(shape=(4, 256), dtype=np.float64) 295 | for i, (part_numel, combination) in enumerate(zip(parts_numel_w_thrs, combinations)): 296 | align_matrix_value = 2 * (combination[0] * combination[1]) / \ 297 | (combination[0] ** 2 + combination[1] ** 2 + _EPS) 298 | enhanced_matrix_value = (align_matrix_value + 1) ** 2 / 4 299 | results_parts[i] = enhanced_matrix_value * part_numel 300 | enhanced_matrix_sum = results_parts.sum(axis=0) 301 | 302 | em = enhanced_matrix_sum / (self.gt_size - 1 + _EPS) 303 | return em 304 | 305 | def generate_parts_numel_combinations(self, fg_fg_numel, fg_bg_numel, pred_fg_numel, pred_bg_numel): 306 | bg_fg_numel = self.gt_fg_numel - fg_fg_numel 307 | bg_bg_numel = pred_bg_numel - bg_fg_numel 308 | 309 | parts_numel = [fg_fg_numel, fg_bg_numel, bg_fg_numel, bg_bg_numel] 310 | 311 | mean_pred_value = pred_fg_numel / self.gt_size 312 | 
mean_gt_value = self.gt_fg_numel / self.gt_size 313 | 314 | demeaned_pred_fg_value = 1 - mean_pred_value 315 | demeaned_pred_bg_value = 0 - mean_pred_value 316 | demeaned_gt_fg_value = 1 - mean_gt_value 317 | demeaned_gt_bg_value = 0 - mean_gt_value 318 | 319 | combinations = [ 320 | (demeaned_pred_fg_value, demeaned_gt_fg_value), 321 | (demeaned_pred_fg_value, demeaned_gt_bg_value), 322 | (demeaned_pred_bg_value, demeaned_gt_fg_value), 323 | (demeaned_pred_bg_value, demeaned_gt_bg_value) 324 | ] 325 | return parts_numel, combinations 326 | 327 | def get_results(self) -> dict: 328 | adaptive_em = np.mean(np.array(self.adaptive_ems, dtype=_TYPE)) 329 | changeable_em = np.mean(np.array(self.changeable_ems, dtype=_TYPE), axis=0) 330 | return dict(em=dict(adp=adaptive_em, curve=changeable_em)) 331 | 332 | 333 | class WeightedFmeasure(object): 334 | def __init__(self, beta: float = 1): 335 | self.beta = beta 336 | self.weighted_fms = [] 337 | 338 | def step(self, pred: np.ndarray, gt: np.ndarray): 339 | pred, gt = _prepare_data(pred=pred, gt=gt) 340 | 341 | if np.all(~gt): 342 | wfm = 0 343 | else: 344 | wfm = self.cal_wfm(pred, gt) 345 | self.weighted_fms.append(wfm) 346 | 347 | def cal_wfm(self, pred: np.ndarray, gt: np.ndarray) -> float: 348 | # [Dst,IDXT] = bwdist(dGT); 349 | Dst, Idxt = bwdist(gt == 0, return_indices=True) 350 | 351 | # %Pixel dependency 352 | # E = abs(FG-dGT); 353 | E = np.abs(pred - gt) 354 | Et = np.copy(E) 355 | Et[gt == 0] = Et[Idxt[0][gt == 0], Idxt[1][gt == 0]] 356 | 357 | # K = fspecial('gaussian',7,5); 358 | # EA = imfilter(Et,K); 359 | K = self.matlab_style_gauss2D((7, 7), sigma=5) 360 | EA = convolve(Et, weights=K, mode="constant", cval=0) 361 | # MIN_E_EA = E; 362 | # MIN_E_EA(GT & EA np.ndarray: 382 | """ 383 | 2D gaussian mask - should give the same result as MATLAB's 384 | fspecial('gaussian',[shape],[sigma]) 385 | """ 386 | m, n = [(ss - 1) / 2 for ss in shape] 387 | y, x = np.ogrid[-m: m + 1, -n: n + 1] 388 | h = np.exp(-(x * x + 
y * y) / (2 * sigma * sigma)) 389 | h[h < np.finfo(h.dtype).eps * h.max()] = 0 390 | sumh = h.sum() 391 | if sumh != 0: 392 | h /= sumh 393 | return h 394 | 395 | def get_results(self) -> dict: 396 | weighted_fm = np.mean(np.array(self.weighted_fms, dtype=_TYPE)) 397 | return dict(wfm=weighted_fm) 398 | 399 | 400 | class DICE(object): 401 | def __init__(self): 402 | self.dice = [] 403 | 404 | def step(self, pred: np.ndarray, gt: np.ndarray): 405 | # pred, gt = _prepare_data(pred=pred, gt=gt) 406 | dice = self.cal_dice(pred, gt) 407 | self.dice.append(dice) 408 | return dice 409 | 410 | def cal_dice(self, pred: np.ndarray, gt: np.ndarray): 411 | # N = gt.size(0) 412 | smooth = 1 413 | 414 | pred_flat = pred.reshape(-1) 415 | gt_flat = gt.reshape(-1) 416 | 417 | intersection = pred_flat * gt_flat 418 | 419 | dice = 2 * (intersection.sum() + smooth) / (pred_flat.sum() + gt_flat.sum() + smooth) 420 | dice = 1 - dice.sum() 421 | 422 | return dice 423 | 424 | def get_results(self): 425 | dice = np.mean(np.array(self.dice, dtype=_TYPE)) 426 | return dice 427 | 428 | 429 | class BinarizedF(Function): 430 | ''' 431 | @ Reference: https://blog.csdn.net/weixin_42696356/article/details/100899711 432 | ''' 433 | @staticmethod 434 | def forward(ctx, input): 435 | ctx.save_for_backward(input) 436 | a = torch.ones_like(input) 437 | b = torch.zeros_like(input) 438 | output = torch.where(input>=0.5,a,b) 439 | return output 440 | 441 | @staticmethod 442 | def backward(ctx, output_grad): 443 | input, = ctx.saved_tensors 444 | input_abs = torch.abs(input) 445 | ones = torch.ones_like(input) 446 | zeros = torch.zeros_like(input) 447 | input_grad = torch.where(input_abs<=1,ones, zeros) 448 | return input_grad 449 | 450 | 451 | class BinarizedModule(nn.Module): 452 | ''' 453 | @ Reference: https://www.flyai.com/article/art7714fcddbf30a9ff5a35633f?type=e 454 | ''' 455 | def __init__(self): 456 | super(BinarizedModule, self).__init__() 457 | self.BF = BinarizedF() 458 | 459 | def 
forward(self, input): 460 | output =self.BF.apply(torch.Tensor(input)) 461 | return output 462 | 463 | 464 | class IoU(object): 465 | def __init__(self): 466 | self.iou = [] 467 | self.n_classes = 2 468 | self.bin = BinarizedModule() 469 | 470 | def step(self, pred: np.ndarray, gt: np.ndarray): 471 | iou = self.cal_iou(pred, gt) 472 | self.iou.append(iou) 473 | return iou 474 | 475 | def _cal_iou(self, pred: np.ndarray, gt: np.ndarray): 476 | def cal_cm(y_true, y_pred): 477 | y_true = y_true.reshape(1, -1).squeeze() 478 | y_pred = y_pred.reshape(1, -1).squeeze() 479 | cm = metrics.confusion_matrix(y_true, y_pred) 480 | return cm 481 | pred = self.bin(pred) 482 | confusion_matrix = cal_cm(pred, gt) 483 | intersection = np.diag(confusion_matrix) # intersection 484 | union = np.sum(confusion_matrix, axis=1) + np.sum(confusion_matrix, axis=0) - np.diag(confusion_matrix) # union 485 | IoU = intersection / union # intersection over union (IoU) 486 | return IoU 487 | 488 | def cal_iou(self, pred: np.ndarray, target: np.ndarray): 489 | Iand1 = np.sum(target * pred) 490 | Ior1 = np.sum(target) + np.sum(pred) - Iand1 491 | IoU1 = Iand1 / Ior1 492 | return IoU1 493 | 494 | def get_results(self): 495 | iou = np.mean(np.array(self.iou, dtype=_TYPE)) 496 | return iou -------------------------------------------------------------------------------- /cos_eval_toolbox/tools/.backup/ranking_per_images.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import individual_metrics as Measure 4 | 5 | model_list = ['2023-arXiv-CamoFormerS', '2023-arXiv-CamoFormerP', '2023-AAAI-HitNet', '2023-arXiv-PopNet', '2022-CVPR-FDNet'] 6 | 7 | pred_root = '/home/users/u7248002/project/PySODEvalToolkit/cos_benchmark/COS-Benchmarking' 8 | gt_root = '/home/users/u7248002/project/PySODEvalToolkit/cos_benchmark/TestDataset/COD10K/GT' 9 | 10 | SM = Measure.Smeasure() 11 | EM = Measure.Emeasure() 12 | MAE = Measure.MAE() 13 | 14 | for file_name in os.listdir(gt_root): 15 | 
s_list, e_list, r_mae_list, total_list = [], [], [], [] 16 | t_counter = 0 17 | for model_name in model_list: 18 | pred_pth = os.path.join(pred_root, model_name, 'COD10K', file_name) 19 | gt_pth = os.path.join(gt_root, file_name) 20 | 21 | assert os.path.isfile(gt_pth) and os.path.isfile(pred_pth) 22 | pred_ary = cv2.imread(pred_pth, cv2.IMREAD_GRAYSCALE) 23 | gt_ary = cv2.imread(gt_pth, cv2.IMREAD_GRAYSCALE) 24 | 25 | assert len(pred_ary.shape) == 2 and len(gt_ary.shape) == 2 26 | if pred_ary.shape != gt_ary.shape: 27 | pred_ary = cv2.resize(pred_ary, (gt_ary.shape[1], gt_ary.shape[0]), interpolation=cv2.INTER_NEAREST) 28 | 29 | SM.step(pred=pred_ary, gt=gt_ary) 30 | EM.step(pred=pred_ary, gt=gt_ary) 31 | MAE.step(pred=pred_ary, gt=gt_ary) 32 | 33 | sm = SM.sms[-1] # score of the current image only; get_results() would average over every image stepped so far 34 | em = EM.changeable_ems[-1].max() 35 | r_mae = 1 - MAE.maes[-1] 36 | 37 | total_score = sm + em + r_mae 38 | 39 | s_list.append(sm) 40 | e_list.append(em) 41 | r_mae_list.append(r_mae) 42 | total_list.append(total_score) 43 | t_counter += total_score 44 | 45 | print(f'{file_name} | {total_list[0]} | {total_list[1]} | {total_list[2]} | {total_list[3]} | {total_list[4]} | {t_counter}') -------------------------------------------------------------------------------- /cos_eval_toolbox/tools/append_results.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import argparse 3 | 4 | import numpy as np 5 | 6 | 7 | def get_args(): 8 | parser = argparse.ArgumentParser( 9 | description="""A simple tool for merging two npy files. 10 | Patch the method items corresponding to the `--method-names` and `--dataset-names` of `--new-npy` into `--old-npy`, 11 | and output the whole container to `--out-npy`.
12 | """ 13 | ) 14 | parser.add_argument("--old-npy", type=str, required=True) 15 | parser.add_argument("--new-npy", type=str, required=True) 16 | parser.add_argument("--method-names", type=str, nargs="+") 17 | parser.add_argument("--dataset-names", type=str, nargs="+") 18 | parser.add_argument("--out-npy", type=str, required=True) 19 | args = parser.parse_args() 20 | return args 21 | 22 | 23 | def main(): 24 | args = get_args() 25 | new_npy: dict = np.load(args.new_npy, allow_pickle=True).item() 26 | old_npy: dict = np.load(args.old_npy, allow_pickle=True).item() 27 | 28 | for dataset_name, methods_info in new_npy.items(): 29 | if args.dataset_names and dataset_name not in args.dataset_names: 30 | continue 31 | 32 | print(f"[PROCESSING INFORMATION ABOUT DATASET {dataset_name}...]") 33 | old_methods_info = old_npy.get(dataset_name) 34 | if not old_methods_info: 35 | raise KeyError(f"{args.old_npy} doesn't contain the information about {dataset_name}.") 36 | 37 | print(f"OLD_NPY: {list(old_methods_info.keys())}") 38 | print(f"NEW_NPY: {list(methods_info.keys())}") 39 | 40 | for method_name, method_info in methods_info.items(): 41 | if args.method_names and method_name not in args.method_names: 42 | continue 43 | 44 | if method_name not in old_npy[dataset_name]: 45 | old_methods_info[method_name] = method_info 46 | print(f"MERGED_NPY: {list(old_methods_info.keys())}") 47 | 48 | np.save(file=args.out_npy, arr=old_npy) 49 | print(f"THE MERGED_NPY IS SAVED INTO {args.out_npy}") 50 | 51 | 52 | if __name__ == "__main__": 53 | main() 54 | -------------------------------------------------------------------------------- /cos_eval_toolbox/tools/cal_avg_resolution.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2023/4/22 3 | # @Author : Daniel Ji 4 | 5 | import os 6 | from PIL import Image 7 | 8 | 9 | def cal_avg_res(): 10 | """ 11 | This function calculates the average resolution (H & W) of the GT masks in each dataset.
12 | """ 13 | root = './cos_benchmark/TestDataset' 14 | 15 | for data_name in os.listdir(root): 16 | data_root = os.path.join(root, data_name, 'GT') 17 | H_avg, W_avg, count = 0, 0, 0 18 | 19 | for file_name in os.listdir(data_root): 20 | img_path = os.path.join(data_root, file_name) 21 | img = Image.open(img_path) 22 | size = img.size # PIL returns (width, height) 23 | H_avg += size[1] 24 | W_avg += size[0] 25 | count += 1 26 | 27 | print(f'{data_name}, H-average is {H_avg/count} and W-average is {W_avg/count}') 28 | -------------------------------------------------------------------------------- /cos_eval_toolbox/tools/check_path.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import argparse 4 | import json 5 | import os 6 | from collections import OrderedDict 7 | 8 | parser = argparse.ArgumentParser(description="A simple tool for checking your JSON config files.") 9 | parser.add_argument( 10 | "-m", "--method-jsons", nargs="+", required=True, help="The JSON files describing all methods." 11 | ) 12 | parser.add_argument( 13 | "-d", "--dataset-jsons", nargs="+", required=True, help="The JSON files describing all datasets."
14 | ) 15 | args = parser.parse_args() 16 | 17 | for method_json, dataset_json in zip(args.method_jsons, args.dataset_jsons): 18 | with open(method_json, encoding="utf-8", mode="r") as f: 19 | methods_info = json.load(f, object_hook=OrderedDict) # load while preserving key order 20 | with open(dataset_json, encoding="utf-8", mode="r") as f: 21 | datasets_info = json.load(f, object_hook=OrderedDict) # load while preserving key order 22 | 23 | total_msgs = [] 24 | for method_name, method_info in methods_info.items(): 25 | print(f"Checking for {method_name} ...") 26 | for dataset_name, results_info in method_info.items(): 27 | if results_info is None: 28 | continue 29 | 30 | dataset_mask_info = datasets_info[dataset_name]["mask"] 31 | mask_path = dataset_mask_info["path"] 32 | mask_suffix = dataset_mask_info["suffix"] 33 | 34 | dir_path = results_info["path"] 35 | file_prefix = results_info.get("prefix", "") 36 | file_suffix = results_info["suffix"] 37 | 38 | if not os.path.exists(dir_path): 39 | total_msgs.append(f"{dir_path} does not exist") 40 | continue 41 | elif not os.path.isdir(dir_path): 42 | total_msgs.append(f"{dir_path} is not a valid directory") 43 | continue 44 | else: 45 | pred_names = [ 46 | name[len(file_prefix) : -len(file_suffix)] 47 | for name in os.listdir(dir_path) 48 | if name.startswith(file_prefix) and name.endswith(file_suffix) 49 | ] 50 | if len(pred_names) == 0: 51 | total_msgs.append(f"{dir_path} contains no files with prefix '{file_prefix}' and suffix '{file_suffix}'") 52 | continue 53 | 54 | mask_names = [ 55 | name[: -len(mask_suffix)] 56 | for name in os.listdir(mask_path) 57 | if name.endswith(mask_suffix) 58 | ] 59 | intersection_names = set(mask_names).intersection(set(pred_names)) 60 | if len(intersection_names) == 0: 61 | total_msgs.append(f"The file names in {dir_path} do not match the ground truth in {mask_path}") 62 | elif len(intersection_names) != len(mask_names): 63 | difference_names = set(mask_names).difference(pred_names) 64 | total_msgs.append( 65 | f"The predictions in {dir_path} ({len(list(pred_names))}) do not cover all ground-truth masks ({len(list(mask_names))})" 66 | ) 67 | 68 | if total_msgs: 69 | 
print(*total_msgs, sep="\n") 70 | else: 71 | print(f"{method_json} & {dataset_json} look fine") 72 | -------------------------------------------------------------------------------- /cos_eval_toolbox/tools/converter.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2021/8/25 3 | # @Author : Lart Pang 4 | # @GitHub : https://github.com/lartpang 5 | 6 | import argparse 7 | from itertools import chain 8 | 9 | import numpy as np 10 | import yaml 11 | 12 | parser = argparse.ArgumentParser( 13 | description="A tool to convert your .npy results into LaTeX table code." 14 | ) 15 | parser.add_argument( 16 | "-i", 17 | "--result-file", 18 | required=True, 19 | nargs="+", 20 | action="extend", 21 | help="The path of the *_metrics.npy file.", 22 | ) 23 | parser.add_argument( 24 | "-o", "--tex-file", required=True, type=str, help="The path of the exported tex file." 25 | ) 26 | parser.add_argument( 27 | "-c", "--config-file", type=str, help="The path of the customized config yaml file."
28 | ) 29 | parser.add_argument( 30 | "--contain-table-env", 31 | action="store_true", 32 | help="Whether to contain the table env in the exported code.", 33 | ) 34 | parser.add_argument( 35 | "--transpose", 36 | action="store_true", 37 | help="Whether to transpose the table.", 38 | ) 39 | args = parser.parse_args() 40 | 41 | 42 | def update_dict(parent_dict, sub_dict): 43 | for sub_k, sub_v in sub_dict.items(): 44 | if sub_k in parent_dict: 45 | if sub_v is not None and isinstance(sub_v, dict): 46 | update_dict(parent_dict=parent_dict[sub_k], sub_dict=sub_v) 47 | continue 48 | parent_dict.update(sub_dict) 49 | 50 | 51 | results = {} 52 | for result_file in args.result_file: 53 | result = np.load(file=result_file, allow_pickle=True).item() 54 | for dataset_name, method_infos in result.items(): 55 | results.setdefault(dataset_name, {}) 56 | for method_name, method_info in method_infos.items(): 57 | results[dataset_name][method_name] = method_info 58 | 59 | IMPOSSIBLE_UP_BOUND = 1 60 | IMPOSSIBLE_DOWN_BOUND = 0 61 | 62 | # Read the default dataset, metric, and method names 63 | dataset_names = sorted(list(results.keys())) 64 | metric_names = ["SM", "wFm", "MAE", "adpE", "avgE", "maxE", "adpF", "avgF", "maxF"] 65 | method_names = sorted(list(set(chain(*[list(results[n].keys()) for n in dataset_names])))) 66 | 67 | if args.config_file is not None: 68 | assert args.config_file.endswith(".yaml") or args.config_file.endswith(".yml") 69 | with open(args.config_file, mode="r", encoding="utf-8") as f: 70 | cfg = yaml.safe_load(f) 71 | 72 | if "dataset_names" not in cfg: 73 | print( 74 | "`dataset_names` is not contained in your config file, so the default config is used." 75 | ) 76 | else: 77 | dataset_names = cfg["dataset_names"] 78 | if "metric_names" not in cfg: 79 | print( 80 | "`metric_names` is not contained in your config file, so the default config is used."
81 | ) 82 | else: 83 | metric_names = cfg["metric_names"] 84 | if "method_names" not in cfg: 85 | print( 86 | "`method_names` is not contained in your config file, so the default config is used." 87 | ) 88 | else: 89 | method_names = cfg["method_names"] 90 | 91 | print( 92 | f"CONFIG INFORMATION:" 93 | f"\n- DATASETS ({len(dataset_names)}): {dataset_names}" 94 | f"\n- METRICS ({len(metric_names)}): {metric_names}" 95 | f"\n- METHODS ({len(method_names)}): {method_names}" 96 | ) 97 | 98 | if isinstance(metric_names, (list, tuple)): 99 | ori_metric_names = metric_names 100 | elif isinstance(metric_names, dict): 101 | ori_metric_names, metric_names = list(zip(*list(metric_names.items()))) 102 | else: 103 | raise NotImplementedError 104 | 105 | if isinstance(method_names, (list, tuple)): 106 | ori_method_names = method_names 107 | elif isinstance(method_names, dict): 108 | ori_method_names, method_names = list(zip(*list(method_names.items()))) 109 | else: 110 | raise NotImplementedError 111 | 112 | # Build the table columns 113 | ori_columns = [] 114 | column_for_index = [] 115 | for dataset_idx, dataset_name in enumerate(dataset_names): 116 | for metric_idx, ori_metric_name in enumerate(ori_metric_names): 117 | filled_value = ( 118 | IMPOSSIBLE_UP_BOUND if ori_metric_name.lower() == "mae" else IMPOSSIBLE_DOWN_BOUND 119 | ) 120 | filled_dict = {k: filled_value for k in ori_metric_names} 121 | ori_column = [] 122 | for method_name in ori_method_names: 123 | method_result = results[dataset_name].get(method_name, filled_dict) 124 | if ori_metric_name not in method_result: 125 | raise KeyError( 126 | f"{ori_metric_name} must be contained in {list(method_result.keys())}" 127 | ) 128 | ori_column.append(method_result[ori_metric_name]) 129 | 130 | column_for_index.append([x * round(1 - filled_value * 2) for x in ori_column]) 131 | ori_columns.append(ori_column) 132 | 133 | style_templates = dict( 134 | method_row_body="& {method_name}", 135 | method_column_body=" {method_name}",
136 | dataset_row_body="& \\multicolumn{{{num_metrics}}}{{c}}{{\\textbf{{{dataset_name}}}}}", 137 | dataset_column_body="\\multirow{{-{num_metrics}}}{{*}}{{\\rotatebox{{90}}{{\\textbf{{{dataset_name}}}}}}}", 138 | dataset_head=" ", 139 | metric_body="& {metric_name}", 140 | metric_row_head=" ", 141 | metric_column_head="& ", 142 | body=[ 143 | "& {{\\color{{reda}} \\textbf{{{txt:.03f}}}}}", # top1 144 | "& {{\\color{{mygreen}} \\textbf{{{txt:.03f}}}}}", # top2 145 | "& {{\\color{{myblue}} \\textbf{{{txt:.03f}}}}}", # top3 146 | "& {txt:.03f}", # other 147 | ], 148 | ) 149 | 150 | 151 | # Rank the values and apply the top-3 styles 152 | def replace_cell(ori_value, k): 153 | if ori_value == IMPOSSIBLE_UP_BOUND or ori_value == IMPOSSIBLE_DOWN_BOUND: 154 | new_value = "& " 155 | else: 156 | new_value = style_templates["body"][k].format(txt=ori_value) 157 | return new_value 158 | 159 | 160 | for col, ori_col in zip(column_for_index, ori_columns): 161 | col_array = np.array(col).reshape(-1) 162 | sorted_col_array = np.sort(np.unique(col_array), axis=-1)[-3:][::-1] 163 | # [top1_idxes, top2_idxes, top3_idxes] 164 | top_k_idxes = [np.argwhere(col_array == x).tolist() for x in sorted_col_array] 165 | for k, idxes in enumerate(top_k_idxes): 166 | for row_idx in idxes: 167 | ori_col[row_idx[0]] = replace_cell(ori_col[row_idx[0]], k) 168 | 169 | for idx, x in enumerate(ori_col): 170 | if not isinstance(x, str): 171 | ori_col[idx] = replace_cell(x, -1) 172 | 173 | # Build the table header 174 | num_datasets = len(dataset_names) 175 | num_metrics = len(metric_names) 176 | num_methods = len(method_names) 177 | 178 | # First build the leading columns, then the leading rows 179 | latex_table_head = [] 180 | latex_table_tail = [] 181 | 182 | if not args.transpose: 183 | dataset_row = ( 184 | [style_templates["dataset_head"]] 185 | + [ 186 | style_templates["dataset_row_body"].format(num_metrics=num_metrics, dataset_name=x) 187 | for x in dataset_names 188 | ] 189 | + [r"\\"] 190 | ) 191 | metric_row = ( 192 | [style_templates["metric_row_head"]] 193 | 
[style_templates["metric_body"].format(metric_name=x) for x in metric_names] 194 | * num_datasets 195 | + [r"\\"] 196 | ) 197 | additional_rows = [dataset_row, metric_row] 198 | 199 | # Build the first column (method names) 200 | method_column = [ 201 | style_templates["method_column_body"].format(method_name=x) for x in method_names 202 | ] 203 | additional_columns = [method_column] 204 | 205 | columns = additional_columns + ori_columns 206 | rows = [list(row) + [r"\\"] for row in zip(*columns)] 207 | rows = additional_rows + rows 208 | 209 | if args.contain_table_env: 210 | column_style = "|".join([f"*{num_metrics}{{c}}"] * len(dataset_names)) 211 | latex_table_head = [ 212 | f"\\begin{{tabular}}{{l|{column_style}}}\n", 213 | "\\toprule[2pt]", 214 | ] 215 | else: 216 | dataset_column = [] 217 | for x in dataset_names: 218 | blank_cells = [" "] * (num_metrics - 1) 219 | dataset_cell = [ 220 | style_templates["dataset_column_body"].format(num_metrics=num_metrics, dataset_name=x) 221 | ] 222 | dataset_column.extend(blank_cells + dataset_cell) 223 | metric_column = [ 224 | style_templates["metric_body"].format(metric_name=x) for x in metric_names 225 | ] * num_datasets 226 | additional_columns = [dataset_column, metric_column] 227 | 228 | method_row = ( 229 | [style_templates["dataset_head"], style_templates["metric_column_head"]] 230 | + [style_templates["method_row_body"].format(method_name=x) for x in method_names] 231 | + [r"\\"] 232 | ) 233 | additional_rows = [method_row] 234 | 235 | additional_columns = [list(x) for x in zip(*additional_columns)] 236 | rows = [cells + row + [r"\\"] for cells, row in zip(additional_columns, ori_columns)] 237 | rows = additional_rows + rows 238 | 239 | if args.contain_table_env: 240 | column_style = "".join([f"*{{{num_methods}}}{{c}}"]) 241 | latex_table_head = [ 242 | f"\\begin{{tabular}}{{cc|{column_style}}}\n", 243 | "\\toprule[2pt]", 244 | ] 245 | 246 | if args.contain_table_env: 247 | latex_table_tail = [ 248 | "\\bottomrule[2pt]\n", 249 | 
"\\end{tabular}", 250 | ] 251 | 252 | rows = [latex_table_head] + rows + [latex_table_tail] 253 | 254 | with open(args.tex_file, mode="w", encoding="utf-8") as f: 255 | for row in rows: 256 | f.write("".join(row) + "\n") 257 | -------------------------------------------------------------------------------- /cos_eval_toolbox/tools/generate_cos_config_files.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2023/4/22 3 | # @Author : Daniel Ji 4 | 5 | import os 6 | 7 | def check_integraty(): 8 | """ 9 | Check whether the prediction masks match the ground-truth masks file by file 10 | """ 11 | pred_path = './benchmark/COS-Benchmarking' 12 | gt_path = './dataset' 13 | 14 | def returnNotMatches(a, b): 15 | return [[x for x in a if x not in b], [x for x in b if x not in a]] 16 | 17 | for model_name in os.listdir(pred_path): 18 | for data_name in ['CAMO', 'COD10K', 'NC4K']: 19 | model_data_path = os.path.join(pred_path, model_name, data_name) 20 | gt_data_path = os.path.join(gt_path, data_name, 'GT') 21 | 22 | if os.path.exists(model_data_path): 23 | if sorted(os.listdir(model_data_path)) != sorted(os.listdir(gt_data_path)): # os.listdir order is arbitrary 24 | info = returnNotMatches(os.listdir(model_data_path), os.listdir(gt_data_path)) 25 | print('not match', model_name, data_name) 26 | print(info) 27 | else: 28 | print('not exist', model_name, data_name) 29 | 30 | 31 | def generate_model_config_py(): 32 | root = './benchmark/COS-Benchmarking' 33 | write_txt_path = './examples_COS' 34 | os.makedirs(write_txt_path, exist_ok=True) 35 | 36 | # Open the output config file in write mode 37 | with open(f'{write_txt_path}/config_cos_method_py_example_all.py', 'w') as f: 38 | # Write the file header 39 | f.write('# -*- coding: utf-8 -*-\n') 40 | f.write('import os\n\n') 41 | 42 | for model_name in os.listdir(root): 43 | model_path = os.path.join(root, model_name) 44 | print('{}_root = \"{}\"'.format(model_name.split('-')[-1], model_path)) 45 | f.write('{}_root = 
\"{}\" \n'.format(model_name.split('-')[-1], model_path)) 46 | print('{} = {{'.format(model_name.split('-')[-1])) 47 | f.write('{} = {{ \n'.format(model_name.split('-')[-1])) 48 | 49 | for data_name in ['CAMO', 'NC4K', 'COD10K']: 50 | model_data_path = os.path.join(model_path, data_name) 51 | if os.path.exists(model_data_path): 52 | print('\t\"{}\": dict(path=os.path.join({}_root, \"{}\"), suffix=".png"),'.format(data_name, model_name.split('-')[-1], data_name)) 53 | f.write('\t\"{}\": dict(path=os.path.join({}_root, \"{}\"), suffix=".png"), \n'.format(data_name, model_name.split('-')[-1], data_name)) 54 | else: 55 | print('\t\"{}\": None,'.format(data_name)) 56 | f.write('\t\"{}\": None, \n'.format(data_name)) 57 | print('}\n') 58 | f.write('}\n\n') 59 | 60 | generate_model_config_py() -------------------------------------------------------------------------------- /cos_eval_toolbox/tools/generate_latex_code.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2023/4/22 3 | # @Author : Daniel Ji 4 | 5 | eval_txt_path = './output_CDS2K/cos_results.txt' 6 | 7 | with open(eval_txt_path) as f: 8 | 9 | model_list, mae_list, maxf_list, avgf_list, adpf_list, maxe_list, avge_list, adpe_list, sm_list, wfm_list, sum_list = list(), list(), list(), list(), list(), list(), list(), list(), list(), list(), list() 10 | 11 | for line in f.readlines(): 12 | if line.startswith('['): 13 | model = line.split('[')[1].split(']')[0] 14 | mae = float(line.split('mae: ')[1].split(' ')[0]) 15 | maxf = float(line.split('maxf: ')[1].split(' ')[0]) 16 | avgf = float(line.split('avgf: ')[1].split(' ')[0]) 17 | adpf = float(line.split('adpf: ')[1].split(' ')[0]) 18 | maxe = float(line.split('maxe: ')[1].split(' ')[0]) 19 | avge = float(line.split('avge: ')[1].split(' ')[0]) 20 | adpe = float(line.split('adpe: ')[1].split(' ')[0]) 21 | sm = float(line.split('sm: ')[1].split(' ')[0]) 22 | wfm = float(line.split('wfm: 
')[1].split(' ')[0]) 23 | total = (1 - mae) + maxf + avgf + adpf + maxe + avge + adpe + sm + wfm # optional overall score; MAE is inverted so that higher is better 24 | 25 | print(f'{model}\n&{sm:.3f} &{wfm:.3f} &{mae:.3f} &{adpe:.3f} &{avge:.3f} &{maxe:.3f} &{adpf:.3f} &{avgf:.3f} &{maxf:.3f}') 26 | 27 | model_list.append(model) 28 | mae_list.append(mae) 29 | maxf_list.append(maxf) 30 | avgf_list.append(avgf) 31 | adpf_list.append(adpf) 32 | maxe_list.append(maxe) 33 | avge_list.append(avge) 34 | adpe_list.append(adpe) 35 | sm_list.append(sm) 36 | wfm_list.append(wfm) 37 | sum_list.append(total) 38 | 39 | # print('\n', max(sm_list), max(wfm_list), min(mae_list), max(adpe_list), max(avge_list), max(maxe_list), max(adpf_list), max(avgf_list), max(maxf_list), max(sum_list)) 40 | 41 | # sm_list.sort() 42 | # wfm_list.sort() 43 | # mae_list.sort(reverse=True) 44 | # adpe_list.sort() 45 | # avge_list.sort() 46 | # maxe_list.sort() 47 | # adpf_list.sort() 48 | # avgf_list.sort() 49 | # maxf_list.sort() 50 | # sum_list.sort() 51 | 52 | # print('\n', sm_list[-1], wfm_list[-1], mae_list[-1], adpe_list[-1], avge_list[-1], maxe_list[-1], adpf_list[-1], avgf_list[-1], maxf_list[-1]) 53 | # print('\n', sm_list[-2], wfm_list[-2], mae_list[-2], adpe_list[-2], avge_list[-2], maxe_list[-2], adpf_list[-2], avgf_list[-2], maxf_list[-2]) 54 | # print('\n', sm_list[-3], wfm_list[-3], mae_list[-3], adpe_list[-3], avge_list[-3], maxe_list[-3], adpf_list[-3], avgf_list[-3], maxf_list[-3]) 55 | -------------------------------------------------------------------------------- /cos_eval_toolbox/tools/info_py_to_json.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2021/3/14 3 | # @Author : Lart Pang 4 | # @GitHub : https://github.com/lartpang 5 | import argparse 6 | import ast 7 | import json 8 | import os 9 | import sys 10 | from importlib import import_module 11 | 12 | 13 | def validate_py_syntax(filename): 14 | with open(filename, "r", encoding="utf-8") as f: 15 | content = f.read() 16 | 
try: 17 | ast.parse(content) 18 | except SyntaxError as e: 19 | raise SyntaxError("There are syntax errors in config " f"file {filename}: {e}") 20 | 21 | 22 | def convert_py_to_json(source_config_root, target_config_root): 23 | if not os.path.isdir(source_config_root): 24 | raise NotADirectoryError(source_config_root) 25 | if not os.path.exists(target_config_root): 26 | os.makedirs(target_config_root) 27 | else: 28 | if not os.path.isdir(target_config_root): 29 | raise NotADirectoryError(target_config_root) 30 | 31 | sys.path.insert(0, source_config_root) 32 | source_config_files = os.listdir(source_config_root) 33 | for source_config_file in source_config_files: 34 | source_config_path = os.path.join(source_config_root, source_config_file) 35 | if not (os.path.isfile(source_config_path) and source_config_path.endswith(".py")): 36 | continue 37 | validate_py_syntax(source_config_path) 38 | print(source_config_path) 39 | 40 | temp_module_name = os.path.splitext(source_config_file)[0] 41 | mod = import_module(temp_module_name) 42 | 43 | total_dict = {} 44 | for name, value in mod.__dict__.items(): 45 | if not name.startswith("_") and isinstance(value, dict): 46 | total_dict[name] = value 47 | 48 | # delete imported module 49 | del sys.modules[temp_module_name] 50 | 51 | with open( 52 | os.path.join(target_config_root, os.path.basename(temp_module_name) + ".json"), 53 | encoding="utf-8", 54 | mode="w", 55 | ) as f: 56 | json.dump(total_dict, f, indent=2) 57 | 58 | 59 | def get_args(): 60 | parser = argparse.ArgumentParser() 61 | parser.add_argument("-i", "--source-py-root", required=True, type=str) 62 | parser.add_argument("-o", "--target-json-root", required=True, type=str) 63 | args = parser.parse_args() 64 | return args 65 | 66 | 67 | if __name__ == "__main__": 68 | args = get_args() 69 | convert_py_to_json( 70 | source_config_root=args.source_py_root, target_config_root=args.target_json_root 71 | ) 72 | 
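The loop in `convert_py_to_json` applies a simple filtering rule to each loaded module. It keeps only module-level objects that are dicts and whose names do not start with `_`. The sketch below shows that rule on its own. To stay self-contained, it builds the module in memory with `types.ModuleType` and `exec` instead of importing a file from disk with `import_module`. The helper name `public_dicts_to_json` and the sample config are hypothetical and not part of the toolbox.

```python
import json
import types


def public_dicts_to_json(module: types.ModuleType) -> str:
    """Hypothetical helper: dump module-level dicts whose names do not start with '_'."""
    total_dict = {
        name: value
        for name, value in vars(module).items()
        # skips dunder entries, private names, imported modules and plain strings
        if not name.startswith("_") and isinstance(value, dict)
    }
    return json.dumps(total_dict, indent=2)


# Hypothetical config source, written in the style of the example config files.
config_source = '''
import os
_root = "/data/COS"
CAMO = dict(path=os.path.join(_root, "CAMO"), suffix=".png")
NC4K = dict(path=os.path.join(_root, "NC4K"), suffix=".png")
_private = dict(skip=True)
'''

module = types.ModuleType("demo_config")
exec(config_source, module.__dict__)  # executes the config, just as importing it would
print(public_dicts_to_json(module))
```

Only `CAMO` and `NC4K` end up in the JSON. `os`, `_root` and `_private` are dropped by the name and type checks. Like `import_module`, `exec` runs arbitrary code, so this approach is only suitable for trusted config files.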
-------------------------------------------------------------------------------- /cos_eval_toolbox/tools/markdown2html.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | 4 | import markdown2 5 | 6 | index_template = """ 7 | 8 | 9 | 10 | 11 | {html_title} 12 | 13 | 14 | 15 | 16 |

{html_urls}

17 | {html_body} 18 | 19 | 20 | 21 | """ 22 | 23 | html_template = """ 24 | 25 | 26 | 27 | 28 | {html_title} 29 | 30 | 31 | 32 | 33 | 34 |

{html_urls}

35 | {html_body} 36 | 37 | 38 | 39 | """ 40 | list_item_template = "
  • {item_name}
  • " 41 | 42 | 43 | def save_html(html_text, html_path): 44 | with open(html_path, encoding="utf-8", mode="w") as f: 45 | f.write(html_text) 46 | 47 | 48 | md_root = "../results" 49 | html_root = "../results/htmls" 50 | 51 | md_names = [ 52 | name[:-3] 53 | for name in os.listdir(md_root) 54 | if os.path.isfile(os.path.join(md_root, name)) and name.endswith(".md") 55 | ] 56 | 57 | url_list = "\n".join( 58 | [""] 61 | ) 62 | 63 | index_html = index_template.format(html_title="Index", html_urls=url_list, html_body="Updating...") 64 | save_html(html_text=index_html, html_path=os.path.join(html_root, "index.html")) 65 | 66 | for md_name in md_names: 67 | html_body = markdown2.markdown_path( 68 | os.path.join(md_root, md_name + ".md"), 69 | extras={"tables": True, "html-classes": {"table": "sortable"}}, 70 | ) 71 | html = html_template.format(html_title=md_name, html_urls=url_list, html_body=html_body) 72 | save_html(html_text=html, html_path=os.path.join(html_root, md_name + ".html")) 73 | -------------------------------------------------------------------------------- /cos_eval_toolbox/tools/readme.md: -------------------------------------------------------------------------------- 1 | # Useful tools 2 | 3 | ## `append_results.py` 4 | 5 | Merges a newly generated npy file and an existing npy file into a new npy file. 6 | 7 | ```shell 8 | $ python append_results.py --help 9 | usage: append_results.py [-h] --old-npy OLD_NPY --new-npy NEW_NPY [--method-names METHOD_NAMES [METHOD_NAMES ...]] 10 | [--dataset-names DATASET_NAMES [DATASET_NAMES ...]] --out-npy OUT_NPY 11 | 12 | A simple tool for merging two npy file. Patch the method items corresponding to the `--method-names` and `--dataset-names` of `--new-npy` 13 | into `--old-npy`, and output the whole container to `--out-npy`. 14 | 15 | optional arguments: 16 | -h, --help show this help message and exit 17 | --old-npy OLD_NPY 18 | --new-npy NEW_NPY 19 | --method-names METHOD_NAMES [METHOD_NAMES ...] 20 | --dataset-names DATASET_NAMES [DATASET_NAMES ...]
21 | --out-npy OUT_NPY 22 | ``` 23 | 24 | Use case: 25 | 26 | For RGB SOD data, I have already generated an npy file that contains the results of a batch of methods: `old_rgb_sod_curves.npy`. 27 | After re-evaluating one method, I obtained a new npy file. It has a different name, so it does not overwrite the old one: `new_rgb_sod_curves.npy`. 28 | Now I want to merge these two results into one file: `finalnew_rgb_sod_curves.npy`. 29 | This can be done with the following command: 30 | 31 | ```shell 32 | python tools/append_results.py --old-npy output/old_rgb_sod_curves.npy \ 33 | --new-npy output/new_rgb_sod_curves.npy \ 34 | --out-npy output/finalnew_rgb_sod_curves.npy 35 | ``` 36 | 37 | 38 | ## `converter.py` 39 | 40 | Exports the information in the generated `*_metrics.npy` files as LaTeX tables. 41 | 42 | You can configure it manually, following `examples/converter_config.py` in the examples folder, to generate LaTeX table code for exactly the content you need. 43 | 44 | ```shell 45 | $ python converter.py --help 46 | usage: converter.py [-h] -i RESULT_FILE [RESULT_FILE ...] -o TEX_FILE [-c CONFIG_FILE] [--contain-table-env] [--transpose] 47 | 48 | A useful and convenient tool to convert your .npy results into the table code in latex. 49 | 50 | optional arguments: 51 | -h, --help show this help message and exit 52 | -i RESULT_FILE [RESULT_FILE ...], --result-file RESULT_FILE [RESULT_FILE ...] 53 | The path of the *_metrics.npy file. 54 | -o TEX_FILE, --tex-file TEX_FILE 55 | The path of the exported tex file. 56 | -c CONFIG_FILE, --config-file CONFIG_FILE 57 | The path of the customized config yaml file. 58 | --contain-table-env Whether to containe the table env in the exported code. 59 | --transpose Whether to transpose the table. 60 | ``` 61 | 62 | Example usage: 63 | 64 | ```shell 65 | $ python tools/converter.py -i output/your_metrics_1.npy output/your_metrics_2.npy -o output/your_metrics.tex -c ./examples/converter_config.yaml --transpose --contain-table-env 66 | ``` 67 | 68 | This command reads data from multiple npy files (if you only have one, pass a single npy file), processes it, and exports the result to the specified tex file. The specified config file selects the datasets, metrics, and methods to include. 69 | 70 | In addition, `--transpose` makes the exported code a transposed (vertical) table, and `--contain-table-env` adds the code of the `tabular` environment. 71 | 72 | ## `check_path.py` 73 | 74 | Checks for inconsistencies by matching the information in the JSON files against the actual paths on the system.
75 | 76 | ```shell 77 | $ python check_path.py --help 78 | usage: check_path.py [-h] -m METHOD_JSONS [METHOD_JSONS ...] -d DATASET_JSONS [DATASET_JSONS ...] 79 | 80 | A simple tool for checking your json config file. 81 | 82 | optional arguments: 83 | -h, --help show this help message and exit 84 | -m METHOD_JSONS [METHOD_JSONS ...], --method-jsons METHOD_JSONS [METHOD_JSONS ...] 85 | The json file about all methods. 86 | -d DATASET_JSONS [DATASET_JSONS ...], --dataset-jsons DATASET_JSONS [DATASET_JSONS ...] 87 | ``` 88 | 89 | Example usage is shown below. Make sure the file paths after `-m` and `-d` correspond one to one. The code iterates over the two lists in lockstep with `zip` to pair the files up. 90 | 91 | ```shell 92 | $ python check_path.py -m ../configs/methods/json/rgb_sod_methods.json ../configs/methods/json/rgbd_sod_methods.json \ 93 | -d ../configs/datasets/json/rgb_sod.json ../configs/datasets/json/rgbd_sod.json 94 | ``` 95 | 96 | ## `info_py_to_json.py` 97 | 98 | Converts Python config files into more portable JSON files. 99 | 100 | ```shell 101 | $ python info_py_to_json.py --help 102 | usage: info_py_to_json.py [-h] -i SOURCE_PY_ROOT -o TARGET_JSON_ROOT 103 | 104 | optional arguments: 105 | -h, --help show this help message and exit 106 | -i SOURCE_PY_ROOT, --source-py-root SOURCE_PY_ROOT 107 | -o TARGET_JSON_ROOT, --target-json-root TARGET_JSON_ROOT 108 | ``` 109 | 110 | Two options are required: the input directory `-i`, which holds the Python config files, and the output directory `-o`, where the generated JSON files are written. 111 | 112 | Each Python file in the input directory is loaded. Every dict object in it whose name does not start with `_` is taken as the config of a dataset or method. 113 | 114 | These dicts are then collected into one dictionary per Python file and dumped to a JSON file with the same base name in the output directory. 115 | 116 | ## `markdown2html.py` 117 | 118 | Converts the Markdown result files stored in `results` into HTML files, for display on a static site hosted on GitHub Pages. 119 | 120 | ## `rename.py` 121 | 122 | Batch renaming. 123 | 124 | Read through the code before using it, and use it with care to avoid losing data through accidental overwrites.
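`rename.py` applies `re.sub(src_pattern, dst_pattern, ...)` to each relative path, then moves or copies the file. A dry run that only computes the new names and flags collisions can therefore catch overwrites before any file is touched. The sketch below is a hypothetical helper, `preview_renames`, which is not part of the toolbox. It applies the same renaming rule to a plain list of relative paths. It only detects collisions within that list, not files that already exist elsewhere in `dst_dir`.

```python
import re
from collections import Counter


def preview_renames(rel_paths, src_pattern, dst_pattern):
    """Hypothetical dry run of the rename rule: return (planned moves, conflicting targets)."""
    moves = {}
    for rel_path in rel_paths:
        new_path = re.sub(src_pattern, dst_pattern, rel_path)
        if new_path != rel_path:  # unchanged names are skipped, as in rename.py
            moves[rel_path] = new_path
    # A target is unsafe if several sources map to it, or if it is an existing
    # file in the list that is not itself being renamed away.
    counts = Counter(moves.values())
    existing = set(rel_paths) - set(moves)
    conflicts = sorted(t for t, n in counts.items() if n > 1 or t in existing)
    return moves, conflicts


files = ["CAMO/img_001.png", "CAMO/img_002.png", "CAMO/001.png"]
moves, conflicts = preview_renames(files, r"img_(\d+)", r"\1")
print(conflicts)  # ['CAMO/001.png']
```

Here `CAMO/img_001.png` would be renamed to `CAMO/001.png`, a file that already exists. With `os.replace`, rename.py would overwrite it silently.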
125 | -------------------------------------------------------------------------------- /cos_eval_toolbox/tools/rename.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2021/4/24 3 | # @Author : Lart Pang 4 | # @GitHub : https://github.com/lartpang 5 | 6 | import glob 7 | import os 8 | import re 9 | import shutil 10 | 11 | 12 | def path_join(base_path, sub_path): 13 | if sub_path.startswith(os.sep): 14 | sub_path = sub_path[len(os.sep) :] 15 | return os.path.join(base_path, sub_path) 16 | 17 | 18 | def rename_all_files(src_pattern, dst_pattern, src_name, src_dir, dst_dir=None): 19 | """ 20 | :param src_pattern: regular expression matching the original file names 21 | :param dst_pattern: the corresponding replacement string (passed to `re.sub` as `repl`) 22 | :param src_name: glob-style pattern 23 | :param src_dir: folder holding the original data; combined with src_name to build the path pattern searched with glob 24 | :param dst_dir: folder for the renamed data; defaults to None, meaning the original data is modified in place 25 | """ 26 | assert os.path.isdir(src_dir) 27 | 28 | if dst_dir is None: 29 | dst_dir = src_dir 30 | rename_func = os.replace 31 | else: 32 | assert os.path.isdir(dst_dir) 33 | if dst_dir == src_dir: 34 | rename_func = os.replace 35 | else: 36 | rename_func = shutil.copy 37 | print(f"Using {rename_func.__name__} to modify the data") 38 | 39 | src_dir = os.path.abspath(src_dir) 40 | dst_dir = os.path.abspath(dst_dir) 41 | src_data_paths = glob.glob(path_join(src_dir, src_name)) 42 | 43 | print(f"Start replacing the data in {src_dir}") 44 | for idx, src_data_path in enumerate(src_data_paths, start=1): 45 | src_name_w_dir_name = src_data_path[len(src_dir) + 1 :] 46 | dst_name_w_dir_name = re.sub(src_pattern, repl=dst_pattern, string=src_name_w_dir_name) 47 | if dst_name_w_dir_name == src_name_w_dir_name: 48 | continue 49 | dst_data_path = path_join(dst_dir, dst_name_w_dir_name) 50 | 51 | dst_data_dir = os.path.dirname(dst_data_path) 52 | if not os.path.exists(dst_data_dir): 53 | print(f"{idx}: {dst_data_dir} does not exist, creating it") 54 | os.makedirs(dst_data_dir) 55 | rename_func(src=src_data_path, dst=dst_data_path) 56 |
print(f"{src_data_path} -> {dst_data_path}") 57 | 58 | print("OK...") 59 | 60 | 61 | if __name__ == "__main__": 62 | rename_all_files( 63 | src_pattern=r"", 64 | dst_pattern="", 65 | src_name="*/*.png", 66 | src_dir="", 67 | dst_dir="", 68 | ) 69 | -------------------------------------------------------------------------------- /cos_eval_toolbox/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DengPingFan/CSU/892e7bf716e75dd1506a97be80b1d04b03b21965/cos_eval_toolbox/utils/__init__.py -------------------------------------------------------------------------------- /cos_eval_toolbox/utils/generate_info.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import json 4 | import os 5 | from collections import OrderedDict 6 | 7 | from matplotlib import colors 8 | 9 | # color pool for the curves; red is reserved for the highlighted method (see get_methods_info) 10 | _COLOR_Genarator = iter( 11 | sorted( 12 | [ 13 | color 14 | for name, color in colors.cnames.items() 15 | if name not in ["red", "white"] and not name.startswith("light") and "gray" not in name 16 | ] 17 | ) 18 | ) 19 | 20 | 21 | def curve_info_generator(): 22 | # TODO: methods are currently assigned `-` and `--` alternately. When a method only has results on some datasets, 23 | # the two methods adjacent to it on the datasets where its results are missing end up with the same line style 24 | line_style_flag = True 25 | 26 | def _template_generator( 27 | method_info: dict, method_name: str, line_color: str = None, line_width: int = 2 28 | ) -> dict: 29 | nonlocal line_style_flag 30 | template_info = dict( 31 | path_dict=method_info, 32 | curve_setting=dict( 33 | line_style="-" if line_style_flag else "--", 34 | line_label=method_name, 35 | line_width=line_width, 36 | ), 37 | ) 38 | if line_color is not None: 39 | template_info["curve_setting"]["line_color"] = line_color 40 | else: 41 | template_info["curve_setting"]["line_color"] = next(_COLOR_Genarator) 42 | 43 | line_style_flag = not line_style_flag 44 | return template_info 45 | 46 | return _template_generator 47 | 48 | 49 | def 
simple_info_generator(): 50 | def _template_generator(method_info: dict, method_name: str) -> dict: 51 | template_info = dict(path_dict=method_info, label=method_name) 52 | return template_info 53 | 54 | return _template_generator 55 | 56 | 57 | def get_valid_elements( 58 | source: list, 59 | include_elements: list = None, 60 | exclude_elements: list = None, 61 | ): 62 | if include_elements is None: 63 | include_elements = [] 64 | if exclude_elements is None: 65 | exclude_elements = [] 66 | assert not set(include_elements).intersection( 67 | exclude_elements 68 | ), "`include_elements` and `exclude_elements` must have no intersection." 69 | 70 | targeted = set(source).difference(exclude_elements) 71 | assert targeted, "`exclude_elements` must not exclude all elements." 72 | 73 | if include_elements: 74 | # include_elements: [] or [dataset1_name, dataset2_name, ...] 75 | # only the latter is used to select elements from `targeted` 76 | targeted = targeted.intersection(include_elements) 77 | 78 | return [x for x in source if x in targeted] # keep the order of `source`; set order varies between runs 79 | 80 | 81 | def get_methods_info( 82 | methods_info_jsons: list, 83 | for_drawing: bool = False, 84 | our_name: str = None, 85 | include_methods: list = None, 86 | exclude_methods: list = None, 87 | ) -> OrderedDict: 88 | """ 89 | The method entries (name -> paths) stored in the JSON files are used directly for drawing; the names become the curve labels. 90 | 91 | :param methods_info_jsons: JSON file(s) holding the method info; multiple files are supported and are read in the given order 92 | :param for_drawing: whether the info is used for drawing curves; True adds the plotting settings 93 | :param our_name: when drawing, the curve of the method named our_name is emphasized with a thick red solid line 94 | :param include_methods: only return the info of the methods in this list; if None, return all 95 | :param exclude_methods: skip the methods in this list; if None, skip none. Must not overlap with include_methods 96 | :return: methods_full_info 97 | """ 98 | if not isinstance(methods_info_jsons, (list, tuple)): 99 | methods_info_jsons = [methods_info_jsons] 100 | 101 | methods_info = {} 102 | for f in methods_info_jsons: 103 | if not os.path.isfile(f): 104 | raise FileNotFoundError(f"{f} was not found!") 105 | 106 | with open(f, encoding="utf-8", mode="r") as f: 107 | 
methods_info.update(json.load(f, object_hook=OrderedDict)) # preserve key order 108 | 109 | if our_name: 110 | assert our_name in methods_info, f"{our_name} is not in the json file." 111 | 112 | targeted_methods = get_valid_elements( 113 | source=list(methods_info.keys()), 114 | include_elements=include_methods, 115 | exclude_elements=exclude_methods, 116 | ) 117 | if our_name and our_name in targeted_methods: 118 | targeted_methods.pop(targeted_methods.index(our_name)) 119 | targeted_methods.sort() 120 | targeted_methods.insert(0, our_name) 121 | 122 | if for_drawing: 123 | info_generator = curve_info_generator() 124 | else: 125 | info_generator = simple_info_generator() 126 | 127 | methods_full_info = [] 128 | for method_name in targeted_methods: 129 | method_path = methods_info[method_name] 130 | 131 | if for_drawing and our_name and our_name == method_name: 132 | method_info = info_generator(method_path, method_name, line_color="red", line_width=3) 133 | else: 134 | method_info = info_generator(method_path, method_name) 135 | methods_full_info.append((method_name, method_info)) 136 | return OrderedDict(methods_full_info) 137 | 138 | 139 | def get_datasets_info( 140 | datastes_info_json: str, 141 | include_datasets: list = None, 142 | exclude_datasets: list = None, 143 | ) -> OrderedDict: 144 | """ 145 | All dataset info stored in the JSON file is loaded directly into a dict. 146 | 147 | :param datastes_info_json: JSON file holding the dataset info 148 | :param include_datasets: names of the datasets to read; if None, read all 149 | :param exclude_datasets: names of the datasets to skip; if None, skip none. Must not overlap with include_datasets 150 | :return: datasets_full_info 151 | """ 152 | 153 | assert os.path.isfile(datastes_info_json), datastes_info_json 154 | with open(datastes_info_json, encoding="utf-8", mode="r") as f: 155 | datasets_info = json.load(f, object_hook=OrderedDict) # preserve key order 156 | 157 | targeted_datasets = get_valid_elements( 158 | source=list(datasets_info.keys()), 159 | include_elements=include_datasets, 160 | exclude_elements=exclude_datasets, 161 | ) 162 |
targeted_datasets.sort() 163 | 164 | datasets_full_info = [] 165 | for dataset_name in targeted_datasets: 166 | data_path = datasets_info[dataset_name] 167 | 168 | datasets_full_info.append((dataset_name, data_path)) 169 | return OrderedDict(datasets_full_info) 170 | -------------------------------------------------------------------------------- /cos_eval_toolbox/utils/misc.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | 4 | import cv2 5 | import numpy as np 6 | from PIL import Image 7 | 8 | 9 | def get_ext(path_list): 10 | ext_list = list(set([os.path.splitext(p)[1] for p in path_list])) 11 | if len(ext_list) != 1: 12 | if ".png" in ext_list: 13 | ext = ".png" 14 | elif ".jpg" in ext_list: 15 | ext = ".jpg" 16 | elif ".bmp" in ext_list: 17 | ext = ".bmp" 18 | else: 19 | raise NotImplementedError 20 | print(f"The prediction folder contains multiple file extensions; only {ext} files will be used") 21 | else: 22 | ext = ext_list[0] 23 | return ext 24 | 25 | 26 | def get_name_list_and_suffix(data_path: str) -> tuple: 27 | name_list, file_ext = [], "" 28 | if os.path.isfile(data_path): 29 | print(f" ++>> {data_path} is a file. <<++ ") 30 | with open(data_path, mode="r", encoding="utf-8") as file: 31 | line = file.readline() 32 | while line: 33 | img_name = os.path.basename(line.split()[0]) 34 | file_ext = os.path.splitext(img_name)[1] 35 | name_list.append(os.path.splitext(img_name)[0]) 36 | line = file.readline() 37 | if file_ext == "": 38 | # default to .png 39 | file_ext = ".png" 40 | else: 41 | print(f" ++>> {data_path} is a folder.
<<++ ") 42 | data_list = os.listdir(data_path) 43 | file_ext = get_ext(data_list) 44 | name_list = [os.path.splitext(f)[0] for f in data_list if f.endswith(file_ext)] 45 | name_list = list(set(name_list)) 46 | return name_list, file_ext 47 | 48 | 49 | def get_name_list(data_path: str, name_prefix: str = "", name_suffix: str = "") -> list: 50 | if os.path.isfile(data_path): 51 | assert data_path.endswith((".txt", ".lst")) 52 | data_list = [] 53 | with open(data_path, encoding="utf-8", mode="r") as f: 54 | line = f.readline().strip() 55 | while line: 56 | data_list.append(line) 57 | line = f.readline().strip() 58 | else: 59 | data_list = os.listdir(data_path) 60 | 61 | name_list = data_list 62 | if not name_prefix and not name_suffix: 63 | name_list = [os.path.splitext(f)[0] for f in name_list] 64 | else: 65 | name_list = [ 66 | f[len(name_prefix) : len(f) - len(name_suffix)] # `-len("")` would be `-0` and yield "" 67 | for f in name_list 68 | if f.startswith(name_prefix) and f.endswith(name_suffix) 69 | ] 70 | 71 | name_list = list(set(name_list)) 72 | return name_list 73 | 74 | 75 | def get_name_with_group_list(data_path: str, file_ext: str = None) -> list: 76 | name_list = [] 77 | if os.path.isfile(data_path): 78 | print(f" ++>> {data_path} is a file. <<++ ") 79 | with open(data_path, mode="r", encoding="utf-8") as file: 80 | line = file.readline() 81 | while line: 82 | img_name_with_group = line.split()[0] # e.g. "group_name/file_name.ext" 83 | name_list.append(os.path.splitext(img_name_with_group)[0]) 84 | line = file.readline() 85 | else: 86 | print(f" ++>> {data_path} is a folder.
<<++ ") 87 | group_names = sorted(os.listdir(data_path)) 88 | for group_name in group_names: 89 | image_names = [ 90 | "/".join([group_name, x]) 91 | for x in sorted(os.listdir(os.path.join(data_path, group_name))) 92 | ] 93 | if file_ext is not None: 94 | name_list += [os.path.splitext(f)[0] for f in image_names if f.endswith(file_ext)] 95 | else: 96 | name_list += [os.path.splitext(f)[0] for f in image_names] 97 | # group_name/file_name.ext 98 | name_list = list(set(name_list)) # deduplicate 99 | return name_list 100 | 101 | 102 | def get_list_with_postfix(dataset_path: str, postfix: str): 103 | name_list = [] 104 | if os.path.isfile(dataset_path): 105 | print(f" ++>> {dataset_path} is a file. <<++ ") 106 | with open(dataset_path, mode="r", encoding="utf-8") as file: 107 | line = file.readline() 108 | while line: 109 | img_name = os.path.basename(line.split()[0]) 110 | name_list.append(os.path.splitext(img_name)[0]) 111 | line = file.readline() 112 | else: 113 | print(f" ++>> {dataset_path} is a folder. <<++ ") 114 | name_list = [ 115 | os.path.splitext(f)[0] for f in os.listdir(dataset_path) if f.endswith(postfix) 116 | ] 117 | name_list = list(set(name_list)) 118 | return name_list 119 | 120 | 121 | def rgb_loader(path): 122 | with open(path, "rb") as f: 123 | img = Image.open(f) 124 | return img.convert("RGB") 125 | 126 | 127 | def binary_loader(path): 128 | assert os.path.exists(path), f"`{path}` does not exist."
129 | with open(path, "rb") as f: 130 | img = Image.open(f) 131 | return img.convert("L") 132 | 133 | 134 | def load_data(pre_root, gt_root, name, postfixs): 135 | pre = binary_loader(os.path.join(pre_root, name + postfixs[0])) 136 | gt = binary_loader(os.path.join(gt_root, name + postfixs[1])) 137 | return pre, gt 138 | 139 | 140 | def normalize_pil(pre, gt): 141 | gt = np.asarray(gt) 142 | pre = np.asarray(pre) 143 | gt = gt / (gt.max() + 1e-8) 144 | gt = np.where(gt > 0.5, 1, 0) 145 | max_pre = pre.max() 146 | min_pre = pre.min() 147 | if max_pre == min_pre: 148 | pre = pre / 255 149 | else: 150 | pre = (pre - min_pre) / (max_pre - min_pre) 151 | return pre, gt 152 | 153 | 154 | def make_dir(path): 155 | if not os.path.exists(path): 156 | print(f"`{path}` does not exist, creating it.") 157 | os.makedirs(path) 158 | else: 159 | assert os.path.isdir(path), f"`{path}` should be a folder" 160 | print(f"`{path}` already exists.") 161 | 162 | 163 | def imread_wich_checking(path, for_color: bool = True, with_cv2: bool = True) -> np.ndarray: 164 | assert os.path.exists(path=path) and os.path.isfile(path=path), path 165 | if with_cv2: 166 | if for_color: 167 | data = cv2.imread(path, flags=cv2.IMREAD_COLOR) 168 | data = cv2.cvtColor(data, cv2.COLOR_BGR2RGB) 169 | else: 170 | data = cv2.imread(path, flags=cv2.IMREAD_GRAYSCALE) 171 | else: 172 | data = np.array(Image.open(path).convert("RGB" if for_color else "L")) 173 | return data 174 | 175 | 176 | def get_gt_pre_with_name( 177 | gt_root: str, 178 | pre_root: str, 179 | img_name: str, 180 | pre_prefix: str, 181 | pre_suffix: str, 182 | gt_ext: str = ".png", 183 | to_normalize: bool = False, 184 | ): 185 | img_path = os.path.join(pre_root, pre_prefix + img_name + pre_suffix) 186 | gt_path = os.path.join(gt_root, img_name + gt_ext) 187 | 188 | pre = imread_wich_checking(img_path, for_color=False) 189 | gt = imread_wich_checking(gt_path, for_color=False) 190 | 191 | if pre.shape != gt.shape: 192 | pre = cv2.resize(pre,
dsize=gt.shape[::-1], interpolation=cv2.INTER_LINEAR).astype( 193 | np.uint8 194 | ) 195 | 196 | if to_normalize: 197 | gt = normalize_array(gt, to_binary=True, max_eq_255=True) 198 | pre = normalize_array(pre, to_binary=False, max_eq_255=True) 199 | return gt, pre 200 | 201 | 202 | def normalize_array( 203 | data: np.ndarray, to_binary: bool = False, max_eq_255: bool = True 204 | ) -> np.ndarray: 205 | if max_eq_255: 206 | data = data / 255 207 | # else: data is in [0, 1] 208 | if to_binary: 209 | data = (data > 0.5).astype(np.uint8) 210 | else: 211 | if data.max() != data.min(): 212 | data = (data - data.min()) / (data.max() - data.min()) 213 | data = data.astype(np.float32) 214 | return data 215 | 216 | 217 | def get_valid_key_name(data_dict: dict, key_name: str) -> str: 218 | if data_dict.get(key_name.lower(), "keyerror") == "keyerror": 219 | key_name = key_name.upper() 220 | else: 221 | key_name = key_name.lower() 222 | return key_name 223 | 224 | 225 | def get_target_key(target_dict: dict, key: str) -> str: 226 | """ 227 | Among the keys of `target_dict`, find the one that matches `key` case-insensitively. 228 | If there is no such key, return None. 229 | """ 230 | target_keys = {k.lower(): k for k in target_dict.keys()} 231 | return target_keys.get(key.lower(), None) 232 | 233 | 234 | def colored_print(msg: str, mode: str = "general"): 235 | """ 236 | Print a string message with a display style chosen by its type. 237 | 238 | :param msg: the string message to print 239 | :param mode: the print mode; general/warning/error are currently supported 240 | :return: 241 | """ 242 | if mode == "general": 243 | msg = msg 244 | elif mode == "warning": 245 | msg = f"\033[5;31m{msg}\033[0m" 246 | elif mode == "error": 247 | msg = f"\033[1;31m{msg}\033[0m" 248 | else: 249 | raise ValueError(f"{mode} is an invalid mode.") 250 | print(msg) 251 | 252 | 253 | class ColoredPrinter: 254 | """ 255 | Print string messages with a display style chosen by their type. 256 | """ 257 | 258 | @staticmethod 259 | def info(msg): 260 | print(msg) 261 | 262 | @staticmethod 263 | def warn(msg): 264 | msg = 
f"\033[5;31m{msg}\033[0m" 265 | print(msg) 266 | 267 | @staticmethod 268 | def error(msg): 269 | msg = f"\033[1;31m{msg}\033[0m" 270 | print(msg) 271 | 272 | 273 | def update_info(source_info: dict, new_info: dict): 274 | for name, info in source_info.items(): 275 | if name in new_info: 276 | if isinstance(info, dict): 277 | update_info(source_info=info, new_info=new_info[name]) 278 | else: # int, float, list, tuple 279 | info = new_info[name] 280 | source_info[name] = info 281 | return source_info 282 | -------------------------------------------------------------------------------- /cos_eval_toolbox/utils/print_formatter.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from tabulate import tabulate 4 | 5 | 6 | def print_formatter( 7 | results: dict, method_name_length=10, metric_name_length=5, metric_value_length=5 8 | ): 9 | dataset_regions = [] 10 | for dataset_name, dataset_metrics in results.items(): 11 | dataset_head_row = f" Dataset: {dataset_name} " 12 | dataset_region = [dataset_head_row] 13 | for method_name, metric_info in dataset_metrics.items(): 14 | showed_method_name = clip_string( 15 | method_name, max_length=method_name_length, mode="left" 16 | ) 17 | method_row_head = f"{showed_method_name} " 18 | method_row_body = [] 19 | for metric_name, metric_value in metric_info.items(): 20 | showed_metric_name = clip_string( 21 | metric_name, max_length=metric_name_length, mode="right" 22 | ) 23 | showed_value_string = clip_string( 24 | str(metric_value), max_length=metric_value_length, mode="left" 25 | ) 26 | method_row_body.append(f"{showed_metric_name}: {showed_value_string}") 27 | method_row = method_row_head + ", ".join(method_row_body) 28 | dataset_region.append(method_row) 29 | dataset_region_string = "\n".join(dataset_region) 30 | dataset_regions.append(dataset_region_string) 31 | dividing_line = "\n" + "-" * len(dataset_region[-1]) + "\n" # use the length of the last row of the last dataset region as the length of the dividing line 32 |
formatted_string = dividing_line.join(dataset_regions) 33 | return formatted_string 34 | 35 | 36 | def clip_string(string: str, max_length: int, padding_char: str = " ", mode: str = "left"): 37 | assert isinstance(max_length, int), f"{max_length} must be `int` type" 38 | 39 | real_length = len(string) 40 | if real_length <= max_length: 41 | padding_length = max_length - real_length 42 | if mode == "left": 43 | clipped_string = string + f"{padding_char}" * padding_length 44 | elif mode == "center": 45 | left_padding_str = f"{padding_char}" * (padding_length // 2) 46 | right_padding_str = f"{padding_char}" * (padding_length - padding_length // 2) 47 | clipped_string = left_padding_str + string + right_padding_str 48 | elif mode == "right": 49 | clipped_string = f"{padding_char}" * padding_length + string 50 | else: 51 | raise NotImplementedError 52 | else: 53 | clipped_string = string[:max_length] 54 | 55 | return clipped_string 56 | 57 | 58 | def formatter_for_tabulate( 59 | results: dict, 60 | dataset_titlefmt: str = "Dataset: {}", 61 | method_name_length=None, 62 | metric_value_length=None, 63 | tablefmt="github", 64 | ): 65 | """ 66 | tabulate format: 67 | 68 | :: 69 | 70 | table = [["spam",42],["eggs",451],["bacon",0]] 71 | headers = ["item", "qty"] 72 | print(tabulate(table, headers, tablefmt="github")) 73 | 74 | | item | qty | 75 | |--------|-------| 76 | | spam | 42 | 77 | | eggs | 451 | 78 | | bacon | 0 | 79 | 80 | Purpose of this function: 81 | build a tabulate-style table for each dataset separately, then join all the tables with newlines and return the resulting string. 82 | """ 83 | all_tables = [] 84 | for dataset_name, dataset_metrics in results.items(): 85 | all_tables.append(dataset_titlefmt.format(dataset_name)) 86 | 87 | table = [] 88 | headers = ["methods"] 89 | for method_name, metric_info in dataset_metrics.items(): 90 | if method_name_length: 91 | method_name = clip_string(method_name, max_length=method_name_length, mode="left") 92 | method_row = [method_name] 93 | # Guarantee a consistent column order: Python 3.7+ dicts already preserve insertion order, but the results may have been exported and re-imported, so sort explicitly to be safe. 94 | for
metric_name, metric_value in sorted(metric_info.items(), key=lambda item: item[0]): 95 | if metric_value_length: 96 | metric_value = clip_string( 97 | str(metric_value), max_length=metric_value_length, mode="center" 98 | ) 99 | if metric_name not in headers: 100 | headers.append(metric_name) 101 | method_row.append(metric_value) 102 | table.append(method_row) 103 | all_tables.append(tabulate(table, headers, tablefmt=tablefmt)) 104 | 105 | formatted_string = "\n".join(all_tables) 106 | return formatted_string 107 | -------------------------------------------------------------------------------- /cos_eval_toolbox/utils/recorders/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from .curve_drawer import CurveDrawer 3 | from .excel_recorder import MetricExcelRecorder 4 | from .metric_recorder import ( 5 | METRIC_MAPPING, 6 | GroupedMetricRecorder, 7 | MetricRecorder, 8 | MetricRecorder_V2, 9 | ) 10 | from .txt_recorder import TxtRecorder 11 | -------------------------------------------------------------------------------- /cos_eval_toolbox/utils/recorders/curve_drawer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2021/1/4 3 | # @Author : Lart Pang 4 | # @GitHub : https://github.com/lartpang 5 | import math 6 | import os 7 | 8 | import matplotlib.pyplot as plt 9 | 10 | 11 | class CurveDrawer(object): 12 | def __init__( 13 | self, 14 | row_num, 15 | num_subplots, 16 | style_cfg=None, 17 | ncol_of_legend=1, 18 | separated_legend=False, 19 | sharey=False, 20 | ): 21 | """A better wrapper of matplotlib for me. 22 | 23 | Args: 24 | row_num (int): Number of rows. 25 | num_subplots (int): Number of subplots. 26 | style_cfg (str, optional): Style yaml file path for matplotlib. Defaults to None. 27 | ncol_of_legend (int, optional): Number of columns of the legend. Defaults to 1. 
28 | separated_legend (bool, optional): Use the separated legend. Defaults to False. 29 | sharey (bool, optional): Use a shared y-axis. Defaults to False. 30 | """ 31 | if style_cfg is not None: 32 | assert os.path.isfile(style_cfg) 33 | plt.style.use(style_cfg) 34 | 35 | self.ncol_of_legend = ncol_of_legend 36 | self.separated_legend = separated_legend 37 | if self.separated_legend: 38 | num_subplots += 1 39 | self.num_subplots = num_subplots 40 | self.sharey = sharey 41 | 42 | fig, axes = plt.subplots( 43 | nrows=row_num, ncols=math.ceil(self.num_subplots / row_num), sharey=self.sharey, squeeze=False 44 | ) 45 | self.fig = fig 46 | self.axes = axes.flatten() 47 | 48 | self.init_subplots() 49 | self.dummy_data = {} 50 | 51 | def init_subplots(self): 52 | for ax in self.axes: 53 | ax.set_axis_off() 54 | 55 | def plot_at_axis(self, idx, method_curve_setting, x_data, y_data): 56 | """ 57 | :param method_curve_setting: { 58 | "line_color": "color"(str), 59 | "line_style": "style"(str), 60 | "line_label": "label"(str), 61 | "line_width": width(int), 62 | } 63 | """ 64 | assert isinstance(idx, int) and 0 <= idx < self.num_subplots 65 | self.axes[idx].plot( 66 | x_data, 67 | y_data, 68 | linewidth=method_curve_setting["line_width"], 69 | label=method_curve_setting["line_label"], 70 | color=method_curve_setting["line_color"], 71 | linestyle=method_curve_setting["line_style"], 72 | ) 73 | 74 | if self.separated_legend: 75 | self.dummy_data[method_curve_setting["line_label"]] = method_curve_setting 76 | 77 | def set_axis_property( 78 | self, idx, title=None, x_label=None, y_label=None, x_ticks=None, y_ticks=None 79 | ): 80 | ax = self.axes[idx] 81 | 82 | ax.set_axis_on() 83 | 84 | # give plot a title 85 | ax.set_title(title) 86 | 87 | # make axis labels 88 | ax.set_xlabel(x_label) 89 | ax.set_ylabel(y_label) 90 | 91 | # configure the axis ranges and ticks 92 | x_ticks = [] if x_ticks is None else x_ticks 93 | y_ticks = [] if y_ticks is None else y_ticks 94 | 
ax.set_xlim((min(x_ticks), max(x_ticks))) 95 |
ax.set_ylim((min(y_ticks), max(y_ticks))) 96 | ax.set_xticks(x_ticks) 97 | ax.set_yticks(y_ticks) 98 | ax.set_xticklabels(labels=[f"{x:.2f}" for x in x_ticks]) 99 | ax.set_yticklabels(labels=[f"{y:.2f}" for y in y_ticks]) 100 | 101 | def _plot(self): 102 | if self.sharey: 103 | for ax in self.axes[1:]: 104 | ax.set_ylabel(None) 105 | ax.tick_params(bottom=True, top=False, left=False, right=False) 106 | 107 | if self.separated_legend: 108 | # settings for the legend axis 109 | for method_label, method_info in self.dummy_data.items(): 110 | self.plot_at_axis( 111 | idx=self.num_subplots - 1, method_curve_setting=method_info, x_data=0, y_data=0 112 | ) 113 | ax = self.axes[self.num_subplots - 1] 114 | ax.set_axis_off() 115 | ax.legend(loc=10, ncol=self.ncol_of_legend, facecolor="white", edgecolor="white") 116 | else: 117 | # settings for the legends of all common subplots. 118 | for ax in self.axes: 119 | # loc=0 would let matplotlib pick the best position automatically; loc=3 pins the legend to the lower left 120 | ax.legend(loc=3, ncol=self.ncol_of_legend, facecolor="white", edgecolor="white") 121 | 122 | def show(self): 123 | self._plot() 124 | plt.tight_layout() 125 | plt.show() 126 | 127 | def save(self, path): 128 | self._plot() 129 | plt.tight_layout() 130 | plt.savefig(path) 131 | -------------------------------------------------------------------------------- /cos_eval_toolbox/utils/recorders/excel_recorder.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2021/1/3 3 | # @Author : Lart Pang 4 | # @GitHub : https://github.com/lartpang 5 | 6 | import os 7 | import re 8 | 9 | from openpyxl import Workbook, load_workbook 10 | from openpyxl.utils import get_column_letter 11 | from openpyxl.worksheet.worksheet import Worksheet 12 | 13 | 14 | # Thanks: 15 | # - Python_Openpyxl: https://www.cnblogs.com/programmer-tlh/p/10461353.html 16 | # - Python's re module: 
https://www.cnblogs.com/shenjianping/p/11647473.html 17 | class _BaseExcelRecorder(object): 18 | def __init__(self,
xlsx_path: str): 19 | """ 20 | Base class for writing xlsx files: a more convenient wrapper built mainly on top of openpyxl. 21 | 22 | :param xlsx_path: path to the xlsx file. 23 | """ 24 | self.xlsx_path = xlsx_path 25 | if not os.path.exists(self.xlsx_path): 26 | print("A new Excel file has been created!") 27 | self._initial_xlsx() 28 | else: 29 | print("The Excel file already exists!") 30 | 31 | def _initial_xlsx(self): 32 | Workbook().save(self.xlsx_path) 33 | 34 | def load_sheet(self, sheet_name: str): 35 | wb = load_workbook(self.xlsx_path) 36 | if sheet_name not in wb.sheetnames: 37 | wb.create_sheet(title=sheet_name, index=0) 38 | sheet = wb[sheet_name] 39 | 40 | return wb, sheet 41 | 42 | def append_row(self, sheet: Worksheet, row_data): 43 | assert isinstance(row_data, (tuple, list)) 44 | sheet.append(row_data) 45 | 46 | def insert_row(self, sheet: Worksheet, row_data, row_id, min_col=1, interval=0): 47 | assert isinstance(row_id, int) and isinstance(min_col, int) and row_id > 0 and min_col > 0 48 | assert isinstance(row_data, (tuple, list)), row_data 49 | 50 | num_elements = len(row_data) 51 | row_data = iter(row_data) 52 | for row in sheet.iter_rows( 53 | min_row=row_id, 54 | max_row=row_id, 55 | min_col=min_col, 56 | max_col=min_col + (interval + 1) * (num_elements - 1), 57 | ): 58 | for i, cell in enumerate(row): 59 | if i % (interval + 1) == 0: 60 | sheet.cell(row=row_id, column=cell.column, value=next(row_data)) 61 | 62 | @staticmethod 63 | def merge_region(sheet: Worksheet, min_row, max_row, min_col, max_col): 64 | assert max_row >= min_row > 0 and max_col >= min_col > 0 65 | 66 | merged_region = ( 67 | f"{get_column_letter(min_col)}{min_row}:{get_column_letter(max_col)}{max_row}" 68 | ) 69 | sheet.merge_cells(merged_region) 70 | 71 | @staticmethod 72 | def get_col_id_with_row_id(sheet: Worksheet, col_name: str, row_id): 73 | """ 74 | Search the given row for the specified column name and return the corresponding column index. 
75 | """ 76 | assert isinstance(row_id, int) and row_id > 0 77 | 78 | for cell in sheet[row_id]: 79 | if cell.value == col_name: 80 | return cell.column 81 | raise
ValueError(f"In row {row_id}, there is no column named {col_name}!") 82 | 83 | def get_row_id_with_col_name(self, sheet: Worksheet, row_name: str, col_name: str): 84 | """ 85 | Search the column with the given column name for the specified row; return (row_id, col_id) and is_new_row. 86 | """ 87 | is_new_row = True 88 | col_id = self.get_col_id_with_row_id(sheet=sheet, col_name=col_name, row_id=1) 89 | 90 | row_id = 0 91 | for cell in sheet[get_column_letter(col_id)]: 92 | row_id = cell.row 93 | if cell.value == row_name: 94 | return (row_id, col_id), not is_new_row 95 | return (row_id + 1, col_id), is_new_row 96 | 97 | @staticmethod 98 | def get_row_id_with_col_id(sheet: Worksheet, row_name: str, col_id: int): 99 | """ 100 | Search the column with the given index for the specified row. 101 | """ 102 | assert isinstance(col_id, int) and col_id > 0 103 | 104 | is_new_row = True 105 | row_id = 0 106 | for cell in sheet[get_column_letter(col_id)]: 107 | row_id = cell.row 108 | if cell.value == row_name: 109 | return row_id, not is_new_row 110 | return row_id + 1, is_new_row 111 | 112 | @staticmethod 113 | def format_string_with_config(string: str, repalce_config: dict = None): 114 | assert repalce_config is not None 115 | if repalce_config.get("lower"): 116 | string = string.lower() 117 | elif repalce_config.get("upper"): 118 | string = string.upper() 119 | elif repalce_config.get("title"): 120 | string = string.title() 121 | 122 | sub_rule = repalce_config.get("replace") 123 | if sub_rule: 124 | string = re.sub(pattern=sub_rule[0], repl=sub_rule[1], string=string) 125 | return string 126 | 127 | 128 | class MetricExcelRecorder(_BaseExcelRecorder): 129 | def __init__( 130 | self, 131 | xlsx_path: str, 132 | sheet_name: str = None, 133 | repalce_config=None, 134 | row_header=None, 135 | dataset_names=None, 136 | metric_names=None, 137 | ): 138 | """ 139 | Class for writing metric results to an xlsx file. 
140 | 141 | :param xlsx_path: path to the target xlsx file 142 | :param sheet_name: name of the sheet to write the data into 143 | defaults to `results` 144 | :param repalce_config: rules used to normalize the keys of the data dict; the `replace` rule is applied with re.sub 145 | defaults to dict(lower=True, replace=(r"[_-]",
"")) 146 | :param row_header: contents of the top-left header cells of the worksheet, defaults to `["methods", "num_data"]` 147 | :param dataset_names: list of dataset names 148 | defaults to the RGB SOD datasets ["pascals", "ecssd", "hkuis", "dutste", "dutomron"] 149 | :param metric_names: list of metric names 150 | defaults to ["smeasure","wfmeasure","mae","adpfm","meanfm","maxfm","adpem","meanem","maxem"] 151 | """ 152 | super().__init__(xlsx_path=xlsx_path) 153 | if sheet_name is None: 154 | sheet_name = "results" 155 | 156 | if repalce_config is None: 157 | self.repalce_config = dict(lower=True, replace=(r"[_-]", "")) 158 | else: 159 | self.repalce_config = repalce_config 160 | 161 | if row_header is None: 162 | row_header = ["methods", "num_data"] 163 | self.row_header = [ 164 | self.format_string_with_config(s, self.repalce_config) for s in row_header 165 | ] 166 | if dataset_names is None: 167 | dataset_names = ["pascals", "ecssd", "hkuis", "dutste", "dutomron"] 168 | self.dataset_names = [ 169 | self.format_string_with_config(s, self.repalce_config) for s in dataset_names 170 | ] 171 | if metric_names is None: 172 | metric_names = [ 173 | "smeasure", 174 | "wfmeasure", 175 | "mae", 176 | "adpfm", 177 | "meanfm", 178 | "maxfm", 179 | "adpem", 180 | "meanem", 181 | "maxem", 182 | ] 183 | self.metric_names = [ 184 | self.format_string_with_config(s, self.repalce_config) for s in metric_names 185 | ] 186 | 187 | self.sheet_name = sheet_name 188 | self._initial_table() 189 | 190 | def _initial_table(self): 191 | wb, sheet = self.load_sheet(sheet_name=self.sheet_name) 192 | 193 | # insert row_header 194 | self.insert_row(sheet=sheet, row_data=self.row_header, row_id=1, min_col=1) 195 | # merge each row_header cell across rows 1-2 196 | for col_id in range(len(self.row_header)): 197 | self.merge_region( 198 | sheet=sheet, min_row=1, max_row=2, min_col=col_id + 1, max_col=col_id + 1 199 | ) 200 | 201 | # insert the dataset names 202 | self.insert_row( 203 | sheet=sheet, 
204 | row_data=self.dataset_names, 205 | row_id=1, 206 | min_col=len(self.row_header) + 1, 207 | interval=len(self.metric_names) - 1,
208 | ) 209 | 210 | # insert the metric names under every dataset 211 | start_col = len(self.row_header) + 1 212 | for i in range(len(self.dataset_names)): 213 | self.insert_row( 214 | sheet=sheet, 215 | row_data=self.metric_names, 216 | row_id=2, 217 | min_col=start_col + i * len(self.metric_names), 218 | ) 219 | wb.save(self.xlsx_path) 220 | 221 | def _format_row_data(self, row_data: dict) -> list: 222 | row_data = { 223 | self.format_string_with_config(k, self.repalce_config): v for k, v in row_data.items() 224 | } 225 | return [row_data[n] for n in self.metric_names] 226 | 227 | def __call__(self, row_data: dict, dataset_name: str, method_name: str): 228 | dataset_name = self.format_string_with_config(dataset_name, self.repalce_config) 229 | assert ( 230 | dataset_name in self.dataset_names 231 | ), f"{dataset_name} is not contained in {self.dataset_names}" 232 | 233 | # 1 load the worksheet 234 | wb, sheet = self.load_sheet(sheet_name=self.sheet_name) 235 | # 2 look up method_name: if it already exists, reuse its row and column; otherwise use a new row 236 | dataset_col_start_id = self.get_col_id_with_row_id( 237 | sheet=sheet, col_name=dataset_name, row_id=1 238 | ) 239 | (method_row_id, method_col_id), is_new_row = self.get_row_id_with_col_name( 240 | sheet=sheet, row_name=method_name, col_name="methods" 241 | ) 242 | # 3 write the method name into its cell 243 | if is_new_row: 244 | sheet.cell(row=method_row_id, column=method_col_id, value=method_name) 245 | # 4 convert the metric data into a list ordered like metric_names and insert it into the table 246 | row_data = self._format_row_data(row_data=row_data) 247 | self.insert_row( 248 | sheet=sheet, row_data=row_data, row_id=method_row_id, min_col=dataset_col_start_id 249 | ) 250 | # 5 save the workbook 251 | wb.save(self.xlsx_path) 252 | -------------------------------------------------------------------------------- /cos_eval_toolbox/utils/recorders/metric_recorder.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2021/1/4 3 | # @Author : Lart Pang 4 | # @GitHub : https://github.com/lartpang 5 | from collections import defaultdict 6 | import numpy as np 7 |
from py_sod_metrics.sod_metrics import ( 8 | MAE, 9 | Emeasure, 10 | Fmeasure, 11 | Smeasure, 12 | WeightedFmeasure, 13 | ) 14 | 15 | from metrics.extra_metrics import ExtraSegMeasure 16 | 17 | 18 | def ndarray_to_basetype(data): 19 | """ 20 | Convert a single ndarray, or the ndarrays inside a tuple, list, or dict, into basic Python types, 21 | i.e. lists (via .tolist()) and Python scalars 22 | """ 23 | 24 | def _to_list_or_scalar(item): 25 | listed_item = item.tolist() 26 | if isinstance(listed_item, list) and len(listed_item) == 1: 27 | listed_item = listed_item[0] 28 | return listed_item 29 | 30 | if isinstance(data, (tuple, list)): 31 | results = [_to_list_or_scalar(item) for item in data] 32 | elif isinstance(data, dict): 33 | results = {k: _to_list_or_scalar(item) for k, item in data.items()} 34 | else: 35 | assert isinstance(data, np.ndarray) 36 | results = _to_list_or_scalar(data) 37 | return results 38 | 39 | 40 | METRIC_MAPPING = { 41 | "mae": MAE, 42 | "fm": Fmeasure, 43 | "em": Emeasure, 44 | "sm": Smeasure, 45 | "wfm": WeightedFmeasure, 46 | "extra": ExtraSegMeasure, 47 | } 48 | 49 | 50 | class MetricRecorder_V2(object): 51 | def __init__(self, metric_names=None): 52 | """ 53 | Class that accumulates and reports the selected metrics 54 | """ 55 | if metric_names is None: 56 | metric_names = ("mae", "fm", "em", "sm", "wfm") 57 | self.metric_objs = {} 58 | for metric_name in metric_names: 59 | self.metric_objs[metric_name] = METRIC_MAPPING[metric_name]() 60 | 61 | def update(self, pre: np.ndarray, gt: np.ndarray): 62 | assert pre.shape == gt.shape 63 | assert pre.dtype == np.uint8 64 | assert gt.dtype == np.uint8 65 | for m_name, m_obj in self.metric_objs.items(): 66 | m_obj.step(pre, gt) 67 | 68 | def show(self, num_bits: int = 3, return_ndarray: bool = False) -> dict: 69 | """ 70 | Return the computed metrics: 71 | 72 | - curve data (sequential) 73 | - scalar metrics (numerical) 74 | """ 75 | sequential_results = {} 76 | numerical_results = {} 77 | 
for m_name, m_obj in self.metric_objs.items(): 78 | info = m_obj.get_results() 79 | if m_name == "fm": 80 | fm = info["fm"] 81 | pr = info["pr"] 82 |
sequential_results.update( 83 | { 84 | "fm": np.flip(fm["curve"]), 85 | "p": np.flip(pr["p"]), 86 | "r": np.flip(pr["r"]), 87 | } 88 | ) 89 | numerical_results.update( 90 | {"maxf": fm["curve"].max(), "avgf": fm["curve"].mean(), "adpf": fm["adp"]} 91 | ) 92 | elif m_name == "wfm": 93 | wfm = info["wfm"] 94 | numerical_results["wfm"] = wfm 95 | elif m_name == "sm": 96 | sm = info["sm"] 97 | numerical_results["sm"] = sm 98 | elif m_name == "em": 99 | em = info["em"] 100 | sequential_results["em"] = np.flip(em["curve"]) 101 | numerical_results.update( 102 | {"maxe": em["curve"].max(), "avge": em["curve"].mean(), "adpe": em["adp"]} 103 | ) 104 | elif m_name == "mae": 105 | mae = info["mae"] 106 | numerical_results["mae"] = mae 107 | elif m_name == "extra": 108 | pre = info["pre"] 109 | sen = info["sen"] 110 | spec = info["spec"] 111 | fm_std = info["fm"] 112 | dice = info["dice"] 113 | iou = info["iou"] 114 | numerical_results.update( 115 | { 116 | "maxpre": pre.max(), 117 | "avgpre": pre.mean(), 118 | "maxsen": sen.max(), 119 | "avgsen": sen.mean(), 120 | "maxspec": spec.max(), 121 | "avgspec": spec.mean(), 122 | "maxfm_std": fm_std.max(), 123 | "avgfm_std": fm_std.mean(), 124 | "maxdice": dice.max(), 125 | "avgdice": dice.mean(), 126 | "maxiou": iou.max(), 127 | "avgiou": iou.mean(), 128 | } 129 | ) 130 | else: 131 | raise NotImplementedError 132 | 133 | if num_bits is not None and isinstance(num_bits, int): 134 | numerical_results = {k: v.round(num_bits) for k, v in numerical_results.items()} 135 | if not return_ndarray: 136 | sequential_results = ndarray_to_basetype(sequential_results) 137 | numerical_results = ndarray_to_basetype(numerical_results) 138 | return {"sequential": sequential_results, "numerical": numerical_results} 139 | 140 | 141 | class MetricRecorder(object): 142 | def __init__(self): 143 | """ 144 | Class that accumulates and reports the standard SOD metrics 145 | """ 146 | self.mae = MAE() 147 | self.fm = Fmeasure() 148 | self.sm = Smeasure() 149 | 
self.em = Emeasure() 150 | self.wfm =
WeightedFmeasure() 151 | 152 | def update(self, pre: np.ndarray, gt: np.ndarray): 153 | assert pre.shape == gt.shape 154 | assert pre.dtype == np.uint8 155 | assert gt.dtype == np.uint8 156 | 157 | self.mae.step(pre, gt) 158 | self.sm.step(pre, gt) 159 | self.fm.step(pre, gt) 160 | self.em.step(pre, gt) 161 | self.wfm.step(pre, gt) 162 | 163 | def show(self, num_bits: int = 3, return_ndarray: bool = False) -> dict: 164 | """ 165 | Return the computed metrics: 166 | 167 | - curve data (sequential): fm/em/p/r 168 | - scalar metrics (numerical): SM/MAE/maxE/avgE/adpE/maxF/avgF/adpF/wFm 169 | """ 170 | fm_info = self.fm.get_results() 171 | fm = fm_info["fm"] 172 | pr = fm_info["pr"] 173 | wfm = self.wfm.get_results()["wfm"] 174 | sm = self.sm.get_results()["sm"] 175 | em = self.em.get_results()["em"] 176 | mae = self.mae.get_results()["mae"] 177 | 178 | sequential_results = { 179 | "fm": np.flip(fm["curve"]), 180 | "em": np.flip(em["curve"]), 181 | "p": np.flip(pr["p"]), 182 | "r": np.flip(pr["r"]), 183 | } 184 | numerical_results = { 185 | "SM": sm, 186 | "MAE": mae, 187 | "maxE": em["curve"].max(), 188 | "avgE": em["curve"].mean(), 189 | "adpE": em["adp"], 190 | "maxF": fm["curve"].max(), 191 | "avgF": fm["curve"].mean(), 192 | "adpF": fm["adp"], 193 | "wFm": wfm, 194 | } 195 | if num_bits is not None and isinstance(num_bits, int): 196 | numerical_results = {k: v.round(num_bits) for k, v in numerical_results.items()} 197 | if not return_ndarray: 198 | sequential_results = ndarray_to_basetype(sequential_results) 199 | numerical_results = ndarray_to_basetype(numerical_results) 200 | return {"sequential": sequential_results, "numerical": numerical_results} 201 | 202 | 203 | class GroupedMetricRecorder(object): 204 | def __init__(self): 205 | self.metric_recorders = {} 206 | # these metrics are recomputed from the curves averaged over all groups 207 | self.re_cal_metrics = ["maxE", "avgE", "maxF", "avgF"] 208 | 209 | def update(self, group_name: str, pre: np.ndarray, gt: 
np.ndarray): 210 | if group_name not in self.metric_recorders: 211 |
self.metric_recorders[group_name] = MetricRecorder() 212 | self.metric_recorders[group_name].update(pre, gt) 213 | 214 | def show(self, num_bits: int = 3, return_ndarray: bool = False) -> dict: 215 | """ 216 | Return the computed metrics: 217 | 218 | - curve data (sequential): fm/em/p/r 219 | - scalar metrics (numerical): SM/MAE/maxE/avgE/adpE/maxF/avgF/adpF/wFm 220 | """ 221 | group_metrics = {} 222 | for k, v in self.metric_recorders.items(): 223 | group_metric = v.show(num_bits=None, return_ndarray=True) 224 | group_metrics[k] = { 225 | **group_metric["sequential"], 226 | **{ 227 | metric_name: metric_value 228 | for metric_name, metric_value in group_metric["numerical"].items() 229 | if metric_name not in self.re_cal_metrics 230 | }, 231 | } 232 | avg_results = self.average_group_metrics(group_metrics=group_metrics) 233 | 234 | sequential_results = { 235 | "fm": avg_results["fm"], 236 | "em": avg_results["em"], 237 | "p": avg_results["p"], 238 | "r": avg_results["r"], 239 | } 240 | numerical_results = { 241 | "SM": avg_results["SM"], 242 | "MAE": avg_results["MAE"], 243 | "maxE": avg_results["em"].max(), 244 | "avgE": avg_results["em"].mean(), 245 | "adpE": avg_results["adpE"], 246 | "maxF": avg_results["fm"].max(), 247 | "avgF": avg_results["fm"].mean(), 248 | "adpF": avg_results["adpF"], 249 | "wFm": avg_results["wFm"], 250 | } 251 | if num_bits is not None and isinstance(num_bits, int): 252 | numerical_results = {k: v.round(num_bits) for k, v in numerical_results.items()} 253 | if not return_ndarray: 254 | sequential_results = ndarray_to_basetype(sequential_results) 255 | numerical_results = ndarray_to_basetype(numerical_results) 256 | return {"sequential": sequential_results, "numerical": numerical_results} 257 | 258 | @staticmethod 259 | def average_group_metrics(group_metrics: dict) -> dict: 260 | recorder = defaultdict(list) 261 | for group_name, metrics in group_metrics.items(): 262 | for metric_name, metric_array in metrics.items(): 263 | 
recorder[metric_name].append(metric_array) 264 |
results = {k: np.mean(np.vstack(v), axis=0) for k, v in recorder.items()} 265 | return results 266 | -------------------------------------------------------------------------------- /cos_eval_toolbox/utils/recorders/txt_recorder.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2021/1/4 3 | # @Author : Lart Pang 4 | # @GitHub : https://github.com/lartpang 5 | 6 | from datetime import datetime 7 | 8 | 9 | class TxtRecorder: 10 | def __init__(self, txt_path, to_append=True, max_method_name_width=10): 11 | """ 12 | Class for writing results to a txt file. 13 | 14 | :param txt_path: path to the txt file 15 | :param to_append: whether to append to the existing file (it is created if it does not exist); if False, the file is overwritten 16 | :param max_method_name_width: maximum width of the method-name field 17 | """ 18 | self.txt_path = txt_path 19 | self.max_method_name_width = max_method_name_width 20 | mode = "a" if to_append else "w" 21 | with open(txt_path, mode=mode, encoding="utf-8") as f: 22 | f.write(f"\n ========>> Date: {datetime.now()} <<======== \n") 23 | 24 | def add_row(self, row_name, row_data, row_start_str="", row_end_str="\n"): 25 | with open(self.txt_path, mode="a", encoding="utf-8") as f: 26 | f.write(f"{row_start_str} ========>> {row_name}: {row_data} <<======== {row_end_str}") 27 | 28 | def __call__( 29 | self, 30 | method_results: dict, 31 | method_name: str = "", 32 | row_start_str="", 33 | row_end_str="\n", 34 | value_width=6, 35 | ): 36 | msg = row_start_str 37 | if len(method_name) > self.max_method_name_width: 38 | method_name = method_name[: self.max_method_name_width - 3] + "..."
39 | else: 40 | method_name += " " * (self.max_method_name_width - len(method_name)) 41 | msg += f"[{method_name}] " 42 | for metric_name, metric_value in method_results.items(): 43 | assert isinstance(metric_value, float) 44 | msg += f"{metric_name}: " 45 | real_width = len(str(metric_value)) 46 | if value_width > real_width: 47 | # pad with trailing spaces 48 | msg += f"{metric_value}" + " " * (value_width - real_width) 49 | else: 50 | # the value is wider than the limit, so round it instead 51 | # keep value_width - 2 decimal places: all values lie in [0, 1], so the leading `0.` takes up two characters 52 | msg += f"{round(metric_value, ndigits=value_width - 2)}" 53 | msg += " " 54 | msg += row_end_str 55 | with open(self.txt_path, mode="a", encoding="utf-8") as f: 56 | f.write(msg) 57 | --------------------------------------------------------------------------------
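The following is a minimal, self-contained sketch, not part of the toolbox, of how `GroupedMetricRecorder.average_group_metrics` in `metric_recorder.py` combines per-group results. Every metric is stacked with `np.vstack` and averaged element-wise over groups, and the entries in `re_cal_metrics` (`maxE`, `avgE`, `maxF`, `avgF`) are then recomputed from the averaged curves. The group names and toy curve values are hypothetical. They were chosen only to show that the maximum of the averaged curve can differ from the average of the per-group maxima.

```python
from collections import defaultdict

import numpy as np


def average_group_metrics(group_metrics: dict) -> dict:
    """Average every metric element-wise across groups (same approach as the toolbox)."""
    recorder = defaultdict(list)
    for metrics in group_metrics.values():
        for metric_name, metric_array in metrics.items():
            recorder[metric_name].append(metric_array)
    # np.vstack turns a 1-D curve into a (1, N) row and a scalar into a (1, 1) row,
    # so averaging over axis 0 handles curves and scalar metrics alike.
    return {k: np.mean(np.vstack(v), axis=0) for k, v in recorder.items()}


# Hypothetical toy data: two groups (e.g. two video clips), each with a 3-point F-measure curve.
groups = {
    "clip_a": {"fm": np.array([0.2, 0.8, 0.4]), "MAE": np.float64(0.10)},
    "clip_b": {"fm": np.array([0.6, 0.4, 0.9]), "MAE": np.float64(0.30)},
}
avg = average_group_metrics(groups)

# maxF recomputed from the averaged curve [0.4, 0.6, 0.65] -> 0.65
max_of_mean = avg["fm"].max()
# Averaging the per-group maxima instead would give (0.8 + 0.9) / 2 -> 0.85
mean_of_max = np.mean([g["fm"].max() for g in groups.values()])

print("averaged fm curve:", avg["fm"])
print("averaged MAE:", avg["MAE"])  # shape (1,); the toolbox later unwraps it via ndarray_to_basetype
print("max of mean:", max_of_mean, "| mean of max:", mean_of_max)
```

The two aggregation orders generally disagree. This is presumably why the recorder drops these four entries from each group's numerical results before averaging and derives them from the averaged curves afterwards.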