├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── benchmark_camo.png
│   ├── benchmark_cod10k.png
│   ├── benchmark_nc4k.png
│   ├── cds2k-benchmark.png
│   ├── cds2k-statistics.png
│   ├── cds2k.png
│   ├── cos_quali_viz.png
│   ├── csu-logo.png
│   ├── dataset_sample_gallery.png
│   ├── reviewed_datasets.png
│   ├── reviewed_image_methods.png
│   ├── reviewed_video_methods.png
│   └── task_definition.png
└── cos_eval_toolbox
    ├── .editorconfig
    ├── .github
    │   └── workflows
    │       └── publish.yml
    ├── .pre-commit-config.yaml
    ├── LICENSE
    ├── README.md
    ├── eval.py
    ├── examples_COS
    │   ├── config_cos_dataset_py_example.json
    │   ├── config_cos_dataset_py_example.py
    │   ├── config_cos_method_py_example.json
    │   └── config_cos_method_py_example.py
    ├── metrics
    │   ├── __init__.py
    │   ├── cal_cosod_matrics.py
    │   ├── cal_sod_matrics.py
    │   ├── draw_curves.py
    │   └── extra_metrics.py
    ├── output_COS
    │   ├── cos_curves.npy
    │   ├── cos_metrics.npy
    │   └── cos_results.txt
    ├── plot.py
    ├── pyproject.toml
    ├── requirements.txt
    ├── tools
    │   ├── .backup
    │   │   ├── individual_metrics.py
    │   │   └── ranking_per_images.py
    │   ├── append_results.py
    │   ├── cal_avg_resolution.py
    │   ├── check_path.py
    │   ├── converter.py
    │   ├── generate_cos_config_files.py
    │   ├── generate_latex_code.py
    │   ├── info_py_to_json.py
    │   ├── markdown2html.py
    │   ├── readme.md
    │   └── rename.py
    └── utils
        ├── __init__.py
        ├── generate_info.py
        ├── misc.py
        ├── print_formatter.py
        └── recorders
            ├── __init__.py
            ├── curve_drawer.py
            ├── excel_recorder.py
            ├── metric_recorder.py
            └── txt_recorder.py
/.gitignore:
--------------------------------------------------------------------------------
1 | cos_eval_toolbox/benchmark/
2 | cos_eval_toolbox/.backup/
3 | cos_eval_toolbox/dataset/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Deng-Ping Fan
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Advances in Deep Concealed Scene Understanding
2 |
3 |
4 |
5 | This repository contains a collection of research papers, an evaluation toolbox, and benchmarking results for the task of concealed object segmentation (COS) in images. In addition, to evaluate the generalizability of COS approaches, we re-organize a concealed defect segmentation dataset named CDS2K.
6 |
7 | - Paper link: [arXiv](https://arxiv.org/abs/2304.11234)
8 | - This project is under construction. Contributions are welcome! If you would like to contribute to this repository, please submit a pull request.
9 |
10 | ## Table of Contents
11 |
12 | - [Advances in Deep Concealed Scene Understanding](#advances-in-deep-concealed-scene-understanding)
13 | - [Table of Contents](#table-of-contents)
14 | - [CSU Background](#csu-background)
15 | - [CSU Taxonomy](#csu-taxonomy)
16 | - [CSU Survey](#csu-survey)
17 | - [CSU Benchmark](#csu-benchmark)
18 | - [Defect Segmentation Dataset -- CDS2K](#defect-segmentation-dataset----cds2k)
19 | - [Citation](#citation)
20 |
21 | ## CSU Background
22 |
23 | Concealed scene understanding (CSU) is a popular computer vision topic that aims to perceive objects with camouflaged properties. The current boom in advanced techniques and novel applications makes it timely to provide an up-to-date survey that helps researchers understand the global picture of the CSU field, including both current achievements and major challenges.
24 |
25 |
26 |
27 |
28 | Figure 1: Sample gallery of concealed scenarios. (a-d) show natural animals. (e) depicts a concealed human in art. (f) features a synthesized "lion".
29 |
30 |
31 |
32 | This paper makes four contributions:
33 | - For the first time, we present **a comprehensive survey** of deep learning techniques for CSU, covering its background and taxonomy, task-unique challenges, and a review of developments in the deep learning era across existing datasets and deep techniques.
34 | - For a quantitative comparison of the state-of-the-art, we contribute **the largest and latest benchmark** for Concealed Object Segmentation (COS).
35 | - To evaluate the transferability of deep CSU in practical scenarios,
36 | we re-organize **the largest concealed defect segmentation dataset**, termed CDS2K, with hard cases from diversified industrial scenarios, on which we construct a comprehensive benchmark.
37 | - We **discuss open problems and potential research directions** for this community.
38 |
39 | ## CSU Taxonomy
40 |
41 | We introduce a taxonomy of seven popular CSU tasks. Please refer to Section 2.1 of our paper for more details.
42 | - Five of these are image-level tasks: (a) concealed object segmentation (COS), (b) concealed object localization (COL), (c) concealed instance ranking (CIR), (d) concealed instance segmentation (CIS), and (e) concealed object counting (COC).
43 | - The remaining two are video-level tasks: (f) video concealed object segmentation (VCOS) and (g) video concealed object detection (VCOD).
44 |
45 | We illustrate each task with its corresponding annotation visualization.
46 |
47 |
48 |
49 |
50 | Figure 2: Illustration of representative CSU tasks.
51 |
52 |
53 |
54 | ## CSU Survey
55 |
56 | We recap the latest image-based research, covering 50 papers.
57 |
58 |
59 |
60 |
61 | Table 1: Essential characteristics of reviewed image-level CSU methods.
62 |
63 |
64 |
65 | We also review nine recent video-based works.
66 |
67 |
68 |
69 |
70 | Table 2: Essential characteristics of reviewed video-level CSU methods.
71 |
72 |
73 |
74 | The following are ten datasets collected for several CSU-related tasks.
75 |
76 |
77 |
78 |
79 | Table 3: Essential characteristics of reviewed CSU datasets.
80 |
81 |
82 |
83 |
84 | ## CSU Benchmark
85 |
86 | Our benchmarking is built on COS tasks since this topic is relatively well-established and offers a variety of competing approaches. **What do we provide here?**
87 |
88 | - First, we provide a one-key [evaluation toolbox](https://github.com/DengPingFan/CSU/tree/main/cos_eval_toolbox) for CSU. Follow its instructions to reproduce the results; a command sketch is given after this list.
89 | - Second, we run COS approaches on three popular benchmarks (CAMO, NC4K, and COD10K; [Google Drive, 1.16GB](https://drive.google.com/file/d/1v5AZ37YlSjKiBMrfYXhZ9wa9dJLyFuVD/view?usp=sharing)) and organize the predictions into the standard format (*.png). The collection of these prediction masks is publicly available [here (Google Drive, 4.82GB)](https://drive.google.com/file/d/1BPyE6KtQvi8f1gL0IVsILlkv5C00GZYg/view?usp=sharing) for convenient research.
90 | - The benchmark results on nine evaluation metrics are reported in the next three tables. You can find the text file [here](https://github.com/DengPingFan/CSU/tree/main/cos_eval_toolbox/output_COS).
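
Below is a sketch of the one-command evaluation, mirroring the toolbox README (run from within `cos_eval_toolbox/` after downloading the datasets and prediction masks):

```bash
python eval.py --dataset-json examples_COS/config_cos_dataset_py_example.json \
               --method-json examples_COS/config_cos_method_py_example.json \
               --metric-npy output_COS/cos_metrics.npy \
               --curves-npy output_COS/cos_curves.npy \
               --record-txt output_COS/cos_results.txt
```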
91 |
92 |
93 |
94 |
95 | Table 4: Quantitative comparison on the CAMO testing set.
96 |
97 |
98 |
99 |
100 |
101 |
102 | Table 5: Quantitative comparison on the NC4K testing set.
103 |
104 |
105 |
106 |
107 |
108 |
109 | Table 6: Quantitative comparison on the COD10K testing set.
110 |
111 |
112 |
113 | - Lastly, we provide attribute-based analyses on the COD10K dataset.
114 |
115 |
116 |
117 |
118 | Figure 3: Qualitative results of ten COS approaches. For more descriptions of the visual attributes in each column, refer to Section 5.6 of the paper.
119 |
120 |
121 |
122 | ## Defect Segmentation Dataset -- CDS2K
123 |
124 | We organize a concealed defect segmentation dataset ([Google Drive, 159MB](https://drive.google.com/file/d/1OGPR34qCNWHVYwyf9OY6IH-7WHzPkC7-/view?usp=sharing)) from five well-known defect segmentation databases. As shown in Figure 4, we present five sub-databases: (a-l) MVTecAD, (m-o) NEU, (p) CrackForest, (q) KolektorSDD, and (r) MagneticTile. The defective regions are highlighted with red rectangles. (Top-Right) Word cloud visualization of CDS2K. (Bottom) The number of positive/negative samples in each category of our CDS2K.
125 |
126 |
127 |
128 |
129 | Figure 4: Sample gallery of our CDS2K.
130 |
131 |
132 |
133 | The average ratio of defective regions for each category is presented in Table 7, which indicates that most of the defective regions are relatively small.
134 |
135 |
136 |
137 |
138 | Table 7: Average ratio of defective regions for each category in our CDS2K.
139 |
140 |
141 |
142 | Next, we report the quantitative comparison on the positive samples of CDS2K. Kindly download the result maps from [Google Drive (116.6MB)](https://drive.google.com/file/d/1GIP0hdppaBJxV1SSRSOIccn1UgFm0J-l/view?usp=sharing).
143 |
144 |
145 |
146 |
147 | Table 8: Quantitative comparison on the positive samples of CDS2K.
148 |
149 |
150 |
151 |
152 |
153 | ## Citation
154 |
155 | Please cite our paper if you find the work useful:
156 |
157 | @article{fan2023csu,
158 | title={Advances in Deep Concealed Scene Understanding},
159 | author={Fan, Deng-Ping and Ji, Ge-Peng and Xu, Peng and Cheng, Ming-Ming and Sakaridis, Christos and Van Gool, Luc},
160 | journal={Visual Intelligence (VI)},
161 | year={2023}
162 | }
163 |
--------------------------------------------------------------------------------
/assets/benchmark_camo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DengPingFan/CSU/892e7bf716e75dd1506a97be80b1d04b03b21965/assets/benchmark_camo.png
--------------------------------------------------------------------------------
/assets/benchmark_cod10k.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DengPingFan/CSU/892e7bf716e75dd1506a97be80b1d04b03b21965/assets/benchmark_cod10k.png
--------------------------------------------------------------------------------
/assets/benchmark_nc4k.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DengPingFan/CSU/892e7bf716e75dd1506a97be80b1d04b03b21965/assets/benchmark_nc4k.png
--------------------------------------------------------------------------------
/assets/cds2k-benchmark.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DengPingFan/CSU/892e7bf716e75dd1506a97be80b1d04b03b21965/assets/cds2k-benchmark.png
--------------------------------------------------------------------------------
/assets/cds2k-statistics.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DengPingFan/CSU/892e7bf716e75dd1506a97be80b1d04b03b21965/assets/cds2k-statistics.png
--------------------------------------------------------------------------------
/assets/cds2k.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DengPingFan/CSU/892e7bf716e75dd1506a97be80b1d04b03b21965/assets/cds2k.png
--------------------------------------------------------------------------------
/assets/cos_quali_viz.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DengPingFan/CSU/892e7bf716e75dd1506a97be80b1d04b03b21965/assets/cos_quali_viz.png
--------------------------------------------------------------------------------
/assets/csu-logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DengPingFan/CSU/892e7bf716e75dd1506a97be80b1d04b03b21965/assets/csu-logo.png
--------------------------------------------------------------------------------
/assets/dataset_sample_gallery.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DengPingFan/CSU/892e7bf716e75dd1506a97be80b1d04b03b21965/assets/dataset_sample_gallery.png
--------------------------------------------------------------------------------
/assets/reviewed_datasets.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DengPingFan/CSU/892e7bf716e75dd1506a97be80b1d04b03b21965/assets/reviewed_datasets.png
--------------------------------------------------------------------------------
/assets/reviewed_image_methods.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DengPingFan/CSU/892e7bf716e75dd1506a97be80b1d04b03b21965/assets/reviewed_image_methods.png
--------------------------------------------------------------------------------
/assets/reviewed_video_methods.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DengPingFan/CSU/892e7bf716e75dd1506a97be80b1d04b03b21965/assets/reviewed_video_methods.png
--------------------------------------------------------------------------------
/assets/task_definition.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DengPingFan/CSU/892e7bf716e75dd1506a97be80b1d04b03b21965/assets/task_definition.png
--------------------------------------------------------------------------------
/cos_eval_toolbox/.editorconfig:
--------------------------------------------------------------------------------
1 | # https://editorconfig.org/
2 |
3 | root = true
4 |
5 | [*]
6 | indent_style = space
7 | indent_size = 4
8 | insert_final_newline = true
9 | trim_trailing_whitespace = true
10 | end_of_line = lf
11 | charset = utf-8
12 |
13 | # Docstrings and comments use max_line_length = 79
14 | [*.py]
15 | max_line_length = 99
16 |
17 | # Use 2 spaces for the HTML files
18 | [*.html]
19 | indent_size = 2
20 |
21 | # The JSON files contain newlines inconsistently
22 | [*.json]
23 | indent_size = 2
24 | insert_final_newline = ignore
25 |
26 | [**/admin/js/vendor/**]
27 | indent_style = ignore
28 | indent_size = ignore
29 |
30 | # Minified JavaScript files shouldn't be changed
31 | [**.min.js]
32 | indent_style = ignore
33 | insert_final_newline = ignore
34 |
35 | # Makefiles always use tabs for indentation
36 | [Makefile]
37 | indent_style = tab
38 |
39 | # Batch files use tabs for indentation
40 | [*.bat]
41 | indent_style = tab
42 |
43 | [docs/**.txt]
44 | max_line_length = 119
45 |
--------------------------------------------------------------------------------
/cos_eval_toolbox/.github/workflows/publish.yml:
--------------------------------------------------------------------------------
1 | name: auto-publish
2 |
3 | on:
4 | push:
5 | branches:
6 | - master
7 |
8 | jobs:
9 | auto-update: # job_id
10 | name: Markdown2HTML # The name of the job displayed on GitHub.
11 | runs-on: ubuntu-latest
12 | steps:
13 | - name: Check out repository
14 | uses: actions/checkout@v2
15 | - name: Setup Python
16 | uses: actions/setup-python@v2
17 | with:
18 | python-version: 3.8 # optional, default is 3.x
19 | - name: Convert markdown to html
20 | run: |
21 | pip install markdown2
22 | cd tools
23 | python markdown2html.py
24 | cd ..
25 | - name: Deploy
26 | uses: peaceiris/actions-gh-pages@v3
27 | with:
28 | github_token: ${{ secrets.GITHUB_TOKEN }}
29 | publish_dir: ./results/htmls
30 |
--------------------------------------------------------------------------------
/cos_eval_toolbox/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | # See https://pre-commit.com for more information
2 | # See https://pre-commit.com/hooks.html for more hooks
3 | repos:
4 | - repo: https://github.com/pre-commit/pre-commit-hooks
5 | rev: v3.2.0
6 | hooks:
7 | - id: trailing-whitespace
8 | - id: end-of-file-fixer
9 | - id: check-yaml
10 | - id: check-toml
11 | - id: check-added-large-files
12 | - id: fix-encoding-pragma
13 | - id: mixed-line-ending
14 | - repo: https://github.com/pycqa/isort
15 | rev: 5.6.4
16 | hooks:
17 | - id: isort
18 | - repo: https://github.com/psf/black
19 | rev: 20.8b1
20 | # Replace by any tag/version: https://github.com/psf/black/tags
21 | hooks:
22 | - id: black
23 | language_version: python3
24 | # Should be a command that runs python3.6+
25 |
--------------------------------------------------------------------------------
/cos_eval_toolbox/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 MY_
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/cos_eval_toolbox/README.md:
--------------------------------------------------------------------------------
1 | # COS Evaluation Toolbox
2 |
3 | A Python-based concealed object segmentation (COS) evaluation toolbox.
4 |
5 | # Features
6 |
7 | This repo provides one-key processing for nine evaluation metrics (a metric-selection sketch follows this list):
8 |
9 | - MAE
10 | - weighted F-measure
11 | - S-measure
12 | - max/average/adaptive F-measure
13 | - max/average/adaptive E-measure
14 |
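
The metric groups above map to the `--metric-names` choices in `eval.py` (default: `mae fm em sm wfm`, where `fm` and `em` cover the max/average/adaptive variants), so a subset can be evaluated, e.g.:

```bash
# Only S-measure, weighted F-measure, and MAE; other flags as in the
# one-command example below.
python eval.py --dataset-json examples_COS/config_cos_dataset_py_example.json \
               --method-json examples_COS/config_cos_method_py_example.json \
               --record-txt output_COS/cos_results.txt \
               --metric-names sm wfm mae
```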
15 | # One-command evaluation
16 |
17 | To evaluate concealed object segmentation (COS) approaches, first install the required libraries with `pip install -r requirements.txt`. Then, download the benchmark datasets ([OneDrive, 1.16GB](https://anu365-my.sharepoint.com/:u:/g/personal/u7248002_anu_edu_au/EWQU8s3I1cxLvQuYEt2g6gkBO4uwJ2bZq6Vuf9V1Hum7Lg?e=xMMcAr)) and prediction masks ([OneDrive, 4.82GB](https://anu365-my.sharepoint.com/:u:/g/personal/u7248002_anu_edu_au/Edk5mzHO5JNMv0LHDFBdTq4Bgrg_wmsmYg9hjOzh6-nAjw?e=xdVrT4)), and run this command:
18 |
19 | ```bash
20 | python eval.py --dataset-json examples_COS/config_cos_dataset_py_example.json \
21 | --method-json examples_COS/config_cos_method_py_example.json \
22 | --metric-npy output_COS/cos_metrics.npy \
23 | --curves-npy output_COS/cos_curves.npy \
24 | --record-txt output_COS/cos_results.txt
25 | ```
26 |
27 | Your results will be stored at `./cos_eval_toolbox/output_COS/cos_results.txt`.
28 |
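
The `*.npy` outputs store plain Python dictionaries saved with NumPy (see the docstrings in `metrics/`), so a minimal sketch for inspecting them (paths as produced by the command above):

```python
import numpy as np

# Dict-typed npy files need allow_pickle=True and .item() to unwrap.
metrics = np.load("output_COS/cos_metrics.npy", allow_pickle=True).item()
for dataset_name, methods in metrics.items():
    for method_name, values in methods.items():
        print(dataset_name, method_name, values)
```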
29 |
30 | # Customize your evaluation
31 |
32 | 1. Put your prediction masks into a custom path such as `./benchmark/COS-Benchmarking` and prepare your dataset like `./cos_eval_toolbox/dataset/COD10K/` (an expected layout is sketched after the command below). Then, generate the Python-style configs via
33 |
34 | ```bash
35 | python tools/generate_cos_config_files.py
36 | ```
37 |
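For reference, a plausible layout matching the example configs (the method folder name `2021-CVPR-PFNet` is taken from `examples_COS` and is only illustrative):

```text
cos_eval_toolbox/
├── benchmark/COS-Benchmarking/
│   └── 2021-CVPR-PFNet/
│       ├── CAMO/      # prediction masks, *.png
│       ├── COD10K/
│       └── NC4K/
└── dataset/
    └── COD10K/
        ├── Imgs/      # images, *.jpg
        └── GT/        # ground-truth masks, *.png
```
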
38 | 2. Generate the JSON-style files via the command below (an excerpt of the generated JSON follows it)
39 |
40 | ```bash
41 | python tools/info_py_to_json.py -i ./examples_COS -o ./examples_COS
42 | ```
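
The JSON mirrors the Python dicts one-to-one; for instance, the generated dataset config contains entries like this excerpt from `examples_COS/config_cos_dataset_py_example.json`:

```json
{
  "CAMO": {
    "root": "./dataset/CAMO",
    "image": {"path": "./dataset/CAMO/Imgs", "suffix": ".jpg"},
    "mask": {"path": "./dataset/CAMO/GT", "suffix": ".png"}
  }
}
```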
43 |
44 | 3. Check the file paths via
45 |
46 | ```bash
47 | python tools/check_path.py -m examples_COS/config_cos_method_py_example.json -d examples_COS/config_cos_dataset_py_example.json
48 | ```
49 |
50 | 4. Start the evaluation via
51 |
52 | ```bash
53 | python eval.py --dataset-json examples_COS/config_cos_dataset_py_example.json \
54 | --method-json examples_COS/config_cos_method_py_example.json \
55 | --metric-npy output_COS/cos_metrics.npy \
56 | --curves-npy output_COS/cos_curves.npy \
57 | --record-txt output_COS/cos_results.txt
58 | ```
59 |
60 | # Citations
61 |
62 | ```text
63 | @inproceedings{Fmeasure,
64 | title={Frequency-tuned salient region detection},
65 | author={Achanta, Radhakrishna and Hemami, Sheila and Estrada, Francisco and S{\"u}sstrunk, Sabine},
66 | booktitle=CVPR,
67 | number={CONF},
68 | pages={1597--1604},
69 | year={2009}
70 | }
71 |
72 | @inproceedings{MAE,
73 | title={Saliency filters: Contrast based filtering for salient region detection},
74 | author={Perazzi, Federico and Kr{\"a}henb{\"u}hl, Philipp and Pritch, Yael and Hornung, Alexander},
75 | booktitle=CVPR,
76 | pages={733--740},
77 | year={2012}
78 | }
79 |
80 | @inproceedings{Smeasure,
81 | title={Structure-measure: A new way to evaluate foreground maps},
82 | author={Fan, Deng-Ping and Cheng, Ming-Ming and Liu, Yun and Li, Tao and Borji, Ali},
83 | booktitle=ICCV,
84 | pages={4548--4557},
85 | year={2017}
86 | }
87 |
88 | @inproceedings{Emeasure,
89 | title="Enhanced-alignment Measure for Binary Foreground Map Evaluation",
90 | author="Deng-Ping {Fan} and Cheng {Gong} and Yang {Cao} and Bo {Ren} and Ming-Ming {Cheng} and Ali {Borji}",
91 | booktitle=IJCAI,
92 | pages="698--704",
93 | year={2018}
94 | }
95 |
96 | @inproceedings{wFmeasure,
97 | title={How to evaluate foreground maps?},
98 | author={Margolin, Ran and Zelnik-Manor, Lihi and Tal, Ayellet},
99 | booktitle=CVPR,
100 | pages={248--255},
101 | year={2014}
102 | }
103 | ```
104 |
105 | # Acknowledgements
106 |
107 | This repo is built on [PySODEvalToolkit](https://github.com/lartpang/PySODEvalToolkit). We appreciate Dr. Pang's excellent work; please refer to [README.md](https://github.com/lartpang/PySODEvalToolkit/blob/master/readme.md) for more features.
--------------------------------------------------------------------------------
/cos_eval_toolbox/eval.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import argparse
3 | import os
4 | import textwrap
5 | import warnings
6 |
7 | from metrics import cal_sod_matrics
8 | from utils.generate_info import get_datasets_info, get_methods_info
9 | from utils.misc import make_dir
10 | from utils.recorders import METRIC_MAPPING
11 |
12 |
13 | def get_args():
14 | parser = argparse.ArgumentParser(
15 | description=textwrap.dedent(
16 | r"""
17 | INCLUDE:
18 |
19 | - F-measure-Threshold Curve
20 | - Precision-Recall Curve
21 | - MAE
22 | - weighted F-measure
23 | - S-measure
24 | - max/average/adaptive F-measure
25 | - max/average/adaptive E-measure
26 | - max/average Precision
27 | - max/average Sensitivity
28 | - max/average Specificity
29 | - max/average F-measure
30 | - max/average Dice
31 | - max/average IoU
32 |
33 | NOTE:
34 |
35 | - Our method automatically calculates the intersection of `pre` and `gt`.
36 | - Currently supported pre naming rules: `prefix + gt_name_wo_ext + suffix_w_ext`
37 |
38 | EXAMPLES:
39 |
40 | python eval.py \
41 | --dataset-json configs/datasets/json/rgbd_sod.json \
42 | --method-json configs/methods/json/rgbd_other_methods.json configs/methods/json/rgbd_our_method.json --metric-npy output/rgbd_metrics.npy \
43 | --curves-npy output/rgbd_curves.npy \
44 | --record-txt output/rgbd_results.txt
45 | """
46 | ),
47 | formatter_class=argparse.RawTextHelpFormatter,
48 | )
49 | parser.add_argument("--dataset-json", required=True, type=str, help="Json file for datasets.")
50 | parser.add_argument(
51 | "--method-json", required=True, nargs="+", type=str, help="Json file for methods."
52 | )
53 | parser.add_argument("--metric-npy", type=str, help="Npy file for saving metric results.")
54 | parser.add_argument("--curves-npy", type=str, help="Npy file for saving curve results.")
55 | parser.add_argument("--record-txt", type=str, help="Txt file for saving metric results.")
56 | parser.add_argument("--to-overwrite", action="store_true", help="To overwrite the txt file.")
57 | parser.add_argument("--record-xlsx", type=str, help="Xlsx file for saving metric results.")
58 | parser.add_argument(
59 | "--include-methods",
60 | type=str,
61 | nargs="+",
62 | help="Names of only specific methods you want to evaluate.",
63 | )
64 | parser.add_argument(
65 | "--exclude-methods",
66 | type=str,
67 | nargs="+",
68 | help="Names of some specific methods you do not want to evaluate.",
69 | )
70 | parser.add_argument(
71 | "--include-datasets",
72 | type=str,
73 | nargs="+",
74 | help="Names of only specific datasets you want to evaluate.",
75 | )
76 | parser.add_argument(
77 | "--exclude-datasets",
78 | type=str,
79 | nargs="+",
80 | help="Names of some specific datasets you do not want to evaluate.",
81 | )
82 | parser.add_argument(
83 | "--num-workers",
84 | type=int,
85 | default=8,
86 | help="Number of workers for multi-threading or multi-processing. Default: 8",
87 | )
88 | parser.add_argument(
89 | "--num-bits",
90 | type=int,
91 | default=3,
92 | help="Number of decimal places for showing results. Default: 3",
93 | )
94 | parser.add_argument(
95 | "--metric-names",
96 | type=str,
97 | nargs="+",
98 | default=["mae", "fm", "em", "sm", "wfm"],
99 | choices=METRIC_MAPPING.keys(),
100 | help="Names of metrics",
101 | )
102 | args = parser.parse_args()
103 |
104 | if args.metric_npy is not None:
105 | make_dir(os.path.dirname(args.metric_npy))
106 | if args.curves_npy is not None:
107 | make_dir(os.path.dirname(args.curves_npy))
108 | if args.record_txt is not None:
109 | make_dir(os.path.dirname(args.record_txt))
110 | if args.record_xlsx is not None:
111 | make_dir(os.path.dirname(args.record_xlsx))
112 | if args.to_overwrite and not args.record_txt:
113 | warnings.warn("--to-overwrite only works with a valid --record-txt")
114 | return args
115 |
116 |
117 | def main():
118 | args = get_args()
119 |
120 | # A dict containing the information of all datasets
121 | datasets_info = get_datasets_info(
122 | datastes_info_json=args.dataset_json,
123 | include_datasets=args.include_datasets,
124 | exclude_datasets=args.exclude_datasets,
125 | )
126 | # A dict containing the information of all methods' results to be compared
127 | methods_info = get_methods_info(
128 | methods_info_jsons=args.method_json,
129 | for_drawing=True,
130 | include_methods=args.include_methods,
131 | exclude_methods=args.exclude_methods,
132 | )
133 |
134 | # Ensure that multiprocessing also works properly on Windows
135 | cal_sod_matrics.cal_sod_matrics(
136 | sheet_name="Results",
137 | to_append=not args.to_overwrite,
138 | txt_path=args.record_txt,
139 | xlsx_path=args.record_xlsx,
140 | methods_info=methods_info,
141 | datasets_info=datasets_info,
142 | curves_npy_path=args.curves_npy,
143 | metrics_npy_path=args.metric_npy,
144 | num_bits=args.num_bits,
145 | num_workers=args.num_workers,
146 | use_mp=False,
147 | metric_names=args.metric_names,
148 | ncols_tqdm=119,
149 | )
150 |
151 |
152 | if __name__ == "__main__":
153 | main()
154 |
--------------------------------------------------------------------------------
/cos_eval_toolbox/examples_COS/config_cos_dataset_py_example.json:
--------------------------------------------------------------------------------
1 | {
2 | "CAMO": {
3 | "root": "./dataset/CAMO",
4 | "image": {
5 | "path": "./dataset/CAMO/Imgs",
6 | "suffix": ".jpg"
7 | },
8 | "mask": {
9 | "path": "./dataset/CAMO/GT",
10 | "suffix": ".png"
11 | }
12 | },
13 | "COD10K": {
14 | "root": "./dataset/COD10K",
15 | "image": {
16 | "path": "./dataset/COD10K/Imgs",
17 | "suffix": ".jpg"
18 | },
19 | "mask": {
20 | "path": "./dataset/COD10K/GT",
21 | "suffix": ".png"
22 | }
23 | },
24 | "NC4K": {
25 | "root": "./dataset/NC4K",
26 | "image": {
27 | "path": "./dataset/NC4K/Imgs",
28 | "suffix": ".jpg"
29 | },
30 | "mask": {
31 | "path": "./dataset/NC4K/GT",
32 | "suffix": ".png"
33 | }
34 | }
35 | }
--------------------------------------------------------------------------------
/cos_eval_toolbox/examples_COS/config_cos_dataset_py_example.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 |
4 | _COD_DATA_ROOT = "./dataset"
5 |
6 | CAMO = dict(
7 | root=os.path.join(_COD_DATA_ROOT, "CAMO"),
8 | image=dict(path=os.path.join(_COD_DATA_ROOT, "CAMO", "Imgs"), suffix=".jpg"),
9 | mask=dict(path=os.path.join(_COD_DATA_ROOT, "CAMO", "GT"), suffix=".png"),
10 | )
11 |
12 | COD10K = dict(
13 | root=os.path.join(_COD_DATA_ROOT, "COD10K"),
14 | image=dict(path=os.path.join(_COD_DATA_ROOT, "COD10K", "Imgs"), suffix=".jpg"),
15 | mask=dict(path=os.path.join(_COD_DATA_ROOT, "COD10K", "GT"), suffix=".png"),
16 | )
17 |
18 | NC4K = dict(
19 | root=os.path.join(_COD_DATA_ROOT, "NC4K"),
20 | image=dict(path=os.path.join(_COD_DATA_ROOT, "NC4K", "Imgs"), suffix=".jpg"),
21 | mask=dict(path=os.path.join(_COD_DATA_ROOT, "NC4K", "GT"), suffix=".png"),
22 | )
--------------------------------------------------------------------------------
/cos_eval_toolbox/examples_COS/config_cos_method_py_example.json:
--------------------------------------------------------------------------------
1 | {
2 | "PFNet": {
3 | "CAMO": {
4 | "path": "./benchmark/COS-Benchmarking/2021-CVPR-PFNet/CAMO",
5 | "suffix": ".png"
6 | },
7 | "NC4K": {
8 | "path": "./benchmark/COS-Benchmarking/2021-CVPR-PFNet/NC4K",
9 | "suffix": ".png"
10 | },
11 | "COD10K": {
12 | "path": "./benchmark/COS-Benchmarking/2021-CVPR-PFNet/COD10K",
13 | "suffix": ".png"
14 | }
15 | },
16 | "C2FNetV2": {
17 | "CAMO": {
18 | "path": "./benchmark/COS-Benchmarking/2022-TCSVT-C2FNetV2/CAMO",
19 | "suffix": ".png"
20 | },
21 | "NC4K": {
22 | "path": "./benchmark/COS-Benchmarking/2022-TCSVT-C2FNetV2/NC4K",
23 | "suffix": ".png"
24 | },
25 | "COD10K": {
26 | "path": "./benchmark/COS-Benchmarking/2022-TCSVT-C2FNetV2/COD10K",
27 | "suffix": ".png"
28 | }
29 | },
30 | "CamoFormerS": {
31 | "CAMO": {
32 | "path": "./benchmark/COS-Benchmarking/2023-arXiv-CamoFormerS/CAMO",
33 | "suffix": ".png"
34 | },
35 | "NC4K": {
36 | "path": "./benchmark/COS-Benchmarking/2023-arXiv-CamoFormerS/NC4K",
37 | "suffix": ".png"
38 | },
39 | "COD10K": {
40 | "path": "./benchmark/COS-Benchmarking/2023-arXiv-CamoFormerS/COD10K",
41 | "suffix": ".png"
42 | }
43 | },
44 | "SINet": {
45 | "CAMO": {
46 | "path": "./benchmark/COS-Benchmarking/2020-CVPR-SINet/CAMO",
47 | "suffix": ".png"
48 | },
49 | "NC4K": {
50 | "path": "./benchmark/COS-Benchmarking/2020-CVPR-SINet/NC4K",
51 | "suffix": ".png"
52 | },
53 | "COD10K": {
54 | "path": "./benchmark/COS-Benchmarking/2020-CVPR-SINet/COD10K",
55 | "suffix": ".png"
56 | }
57 | },
58 | "C2FNet": {
59 | "CAMO": {
60 | "path": "./benchmark/COS-Benchmarking/2021-IJCAI-C2FNet/CAMO",
61 | "suffix": ".png"
62 | },
63 | "NC4K": {
64 | "path": "./benchmark/COS-Benchmarking/2021-IJCAI-C2FNet/NC4K",
65 | "suffix": ".png"
66 | },
67 | "COD10K": {
68 | "path": "./benchmark/COS-Benchmarking/2021-IJCAI-C2FNet/COD10K",
69 | "suffix": ".png"
70 | }
71 | },
72 | "ZoomNet": {
73 | "CAMO": {
74 | "path": "./benchmark/COS-Benchmarking/2022-CVPR-ZoomNet/CAMO",
75 | "suffix": ".png"
76 | },
77 | "NC4K": {
78 | "path": "./benchmark/COS-Benchmarking/2022-CVPR-ZoomNet/NC4K",
79 | "suffix": ".png"
80 | },
81 | "COD10K": {
82 | "path": "./benchmark/COS-Benchmarking/2022-CVPR-ZoomNet/COD10K",
83 | "suffix": ".png"
84 | }
85 | },
86 | "CamoFormerR": {
87 | "CAMO": {
88 | "path": "./benchmark/COS-Benchmarking/2023-arXiv-CamoFormerR/CAMO",
89 | "suffix": ".png"
90 | },
91 | "NC4K": {
92 | "path": "./benchmark/COS-Benchmarking/2023-arXiv-CamoFormerR/NC4K",
93 | "suffix": ".png"
94 | },
95 | "COD10K": {
96 | "path": "./benchmark/COS-Benchmarking/2023-arXiv-CamoFormerR/COD10K",
97 | "suffix": ".png"
98 | }
99 | },
100 | "SMGL": {
101 | "CAMO": {
102 | "path": "./benchmark/COS-Benchmarking/2021-CVPR-SMGL/CAMO",
103 | "suffix": ".png"
104 | },
105 | "NC4K": {
106 | "path": "./benchmark/COS-Benchmarking/2021-CVPR-SMGL/NC4K",
107 | "suffix": ".png"
108 | },
109 | "COD10K": {
110 | "path": "./benchmark/COS-Benchmarking/2021-CVPR-SMGL/COD10K",
111 | "suffix": ".png"
112 | }
113 | },
114 | "OCENet": {
115 | "CAMO": {
116 | "path": "./benchmark/COS-Benchmarking/2022-WACV-OCENet/CAMO",
117 | "suffix": ".png"
118 | },
119 | "NC4K": {
120 | "path": "./benchmark/COS-Benchmarking/2022-WACV-OCENet/NC4K",
121 | "suffix": ".png"
122 | },
123 | "COD10K": {
124 | "path": "./benchmark/COS-Benchmarking/2022-WACV-OCENet/COD10K",
125 | "suffix": ".png"
126 | }
127 | },
128 | "FAPNet": {
129 | "CAMO": {
130 | "path": "./benchmark/COS-Benchmarking/2022-TIP-FAPNet/CAMO",
131 | "suffix": ".png"
132 | },
133 | "NC4K": {
134 | "path": "./benchmark/COS-Benchmarking/2022-TIP-FAPNet/NC4K",
135 | "suffix": ".png"
136 | },
137 | "COD10K": {
138 | "path": "./benchmark/COS-Benchmarking/2022-TIP-FAPNet/COD10K",
139 | "suffix": ".png"
140 | }
141 | },
142 | "DTINet": {
143 | "CAMO": {
144 | "path": "./benchmark/COS-Benchmarking/2022-ICPR-DTINet/CAMO",
145 | "suffix": ".png"
146 | },
147 | "NC4K": {
148 | "path": "./benchmark/COS-Benchmarking/2022-ICPR-DTINet/NC4K",
149 | "suffix": ".png"
150 | },
151 | "COD10K": {
152 | "path": "./benchmark/COS-Benchmarking/2022-ICPR-DTINet/COD10K",
153 | "suffix": ".png"
154 | }
155 | },
156 | "RMGL": {
157 | "CAMO": {
158 | "path": "./benchmark/COS-Benchmarking/2021-CVPR-RMGL/CAMO",
159 | "suffix": ".png"
160 | },
161 | "NC4K": {
162 | "path": "./benchmark/COS-Benchmarking/2021-CVPR-RMGL/NC4K",
163 | "suffix": ".png"
164 | },
165 | "COD10K": {
166 | "path": "./benchmark/COS-Benchmarking/2021-CVPR-RMGL/COD10K",
167 | "suffix": ".png"
168 | }
169 | },
170 | "TPRNet": {
171 | "CAMO": {
172 | "path": "./benchmark/COS-Benchmarking/2022-TVCJ-TPRNet/CAMO",
173 | "suffix": ".png"
174 | },
175 | "NC4K": {
176 | "path": "./benchmark/COS-Benchmarking/2022-TVCJ-TPRNet/NC4K",
177 | "suffix": ".png"
178 | },
179 | "COD10K": {
180 | "path": "./benchmark/COS-Benchmarking/2022-TVCJ-TPRNet/COD10K",
181 | "suffix": ".png"
182 | }
183 | },
184 | "BSANet": {
185 | "CAMO": {
186 | "path": "./benchmark/COS-Benchmarking/2022-AAAI-BSANet/CAMO",
187 | "suffix": ".png"
188 | },
189 | "NC4K": {
190 | "path": "./benchmark/COS-Benchmarking/2022-AAAI-BSANet/NC4K",
191 | "suffix": ".png"
192 | },
193 | "COD10K": {
194 | "path": "./benchmark/COS-Benchmarking/2022-AAAI-BSANet/COD10K",
195 | "suffix": ".png"
196 | }
197 | },
198 | "PopNet": {
199 | "CAMO": {
200 | "path": "./benchmark/COS-Benchmarking/2023-arXiv-PopNet/CAMO",
201 | "suffix": ".png"
202 | },
203 | "NC4K": {
204 | "path": "./benchmark/COS-Benchmarking/2023-arXiv-PopNet/NC4K",
205 | "suffix": ".png"
206 | },
207 | "COD10K": {
208 | "path": "./benchmark/COS-Benchmarking/2023-arXiv-PopNet/COD10K",
209 | "suffix": ".png"
210 | }
211 | },
212 | "LSR": {
213 | "CAMO": {
214 | "path": "./benchmark/COS-Benchmarking/2021-CVPR-LSR/CAMO",
215 | "suffix": ".png"
216 | },
217 | "NC4K": {
218 | "path": "./benchmark/COS-Benchmarking/2021-CVPR-LSR/NC4K",
219 | "suffix": ".png"
220 | },
221 | "COD10K": {
222 | "path": "./benchmark/COS-Benchmarking/2021-CVPR-LSR/COD10K",
223 | "suffix": ".png"
224 | }
225 | },
226 | "CubeNet": {
227 | "CAMO": {
228 | "path": "./benchmark/COS-Benchmarking/2022-PR-CubeNet/CAMO",
229 | "suffix": ".png"
230 | },
231 | "NC4K": null,
232 | "COD10K": {
233 | "path": "./benchmark/COS-Benchmarking/2022-PR-CubeNet/COD10K",
234 | "suffix": ".png"
235 | }
236 | },
237 | "JSCOD": {
238 | "CAMO": {
239 | "path": "./benchmark/COS-Benchmarking/2021-CVPR-JSCOD/CAMO",
240 | "suffix": ".png"
241 | },
242 | "NC4K": {
243 | "path": "./benchmark/COS-Benchmarking/2021-CVPR-JSCOD/NC4K",
244 | "suffix": ".png"
245 | },
246 | "COD10K": {
247 | "path": "./benchmark/COS-Benchmarking/2021-CVPR-JSCOD/COD10K",
248 | "suffix": ".png"
249 | }
250 | },
251 | "PFNetPlus": {
252 | "CAMO": {
253 | "path": "./benchmark/COS-Benchmarking/2023-SSCI-PFNetPlus/CAMO",
254 | "suffix": ".png"
255 | },
256 | "NC4K": null,
257 | "COD10K": {
258 | "path": "./benchmark/COS-Benchmarking/2023-SSCI-PFNetPlus/COD10K",
259 | "suffix": ".png"
260 | }
261 | },
262 | "BAS": {
263 | "CAMO": {
264 | "path": "./benchmark/COS-Benchmarking/2022-arXiv-BAS/CAMO",
265 | "suffix": ".png"
266 | },
267 | "NC4K": {
268 | "path": "./benchmark/COS-Benchmarking/2022-arXiv-BAS/NC4K",
269 | "suffix": ".png"
270 | },
271 | "COD10K": {
272 | "path": "./benchmark/COS-Benchmarking/2022-arXiv-BAS/COD10K",
273 | "suffix": ".png"
274 | }
275 | },
276 | "D2CNet": {
277 | "CAMO": {
278 | "path": "./benchmark/COS-Benchmarking/2021-TIE-D2CNet/CAMO",
279 | "suffix": ".png"
280 | },
281 | "NC4K": null,
282 | "COD10K": {
283 | "path": "./benchmark/COS-Benchmarking/2021-TIE-D2CNet/COD10K",
284 | "suffix": ".png"
285 | }
286 | },
287 | "PreyNet": {
288 | "CAMO": {
289 | "path": "./benchmark/COS-Benchmarking/2022-MM-PreyNet/CAMO",
290 | "suffix": ".png"
291 | },
292 | "NC4K": {
293 | "path": "./benchmark/COS-Benchmarking/2022-MM-PreyNet/NC4K",
294 | "suffix": ".png"
295 | },
296 | "COD10K": {
297 | "path": "./benchmark/COS-Benchmarking/2022-MM-PreyNet/COD10K",
298 | "suffix": ".png"
299 | }
300 | },
301 | "ERRNet": {
302 | "CAMO": {
303 | "path": "./benchmark/COS-Benchmarking/2022-PR-ERRNet/CAMO",
304 | "suffix": ".png"
305 | },
306 | "NC4K": {
307 | "path": "./benchmark/COS-Benchmarking/2022-PR-ERRNet/NC4K",
308 | "suffix": ".png"
309 | },
310 | "COD10K": {
311 | "path": "./benchmark/COS-Benchmarking/2022-PR-ERRNet/COD10K",
312 | "suffix": ".png"
313 | }
314 | },
315 | "SINetV2": {
316 | "CAMO": {
317 | "path": "./benchmark/COS-Benchmarking/2022-TPAMI-SINetV2/CAMO",
318 | "suffix": ".png"
319 | },
320 | "NC4K": {
321 | "path": "./benchmark/COS-Benchmarking/2022-TPAMI-SINetV2/NC4K",
322 | "suffix": ".png"
323 | },
324 | "COD10K": {
325 | "path": "./benchmark/COS-Benchmarking/2022-TPAMI-SINetV2/COD10K",
326 | "suffix": ".png"
327 | }
328 | },
329 | "SegMaR": {
330 | "CAMO": {
331 | "path": "./benchmark/COS-Benchmarking/2022-CVPR-SegMaR/CAMO",
332 | "suffix": ".png"
333 | },
334 | "NC4K": {
335 | "path": "./benchmark/COS-Benchmarking/2022-CVPR-SegMaR/NC4K",
336 | "suffix": ".png"
337 | },
338 | "COD10K": {
339 | "path": "./benchmark/COS-Benchmarking/2022-CVPR-SegMaR/COD10K",
340 | "suffix": ".png"
341 | }
342 | },
343 | "HitNet": {
344 | "CAMO": {
345 | "path": "./benchmark/COS-Benchmarking/2023-AAAI-HitNet/CAMO",
346 | "suffix": ".png"
347 | },
348 | "NC4K": {
349 | "path": "./benchmark/COS-Benchmarking/2023-AAAI-HitNet/NC4K",
350 | "suffix": ".png"
351 | },
352 | "COD10K": {
353 | "path": "./benchmark/COS-Benchmarking/2023-AAAI-HitNet/COD10K",
354 | "suffix": ".png"
355 | }
356 | },
357 | "TINet": {
358 | "CAMO": {
359 | "path": "./benchmark/COS-Benchmarking/2021-AAAI-TINet/CAMO",
360 | "suffix": ".png"
361 | },
362 | "NC4K": {
363 | "path": "./benchmark/COS-Benchmarking/2021-AAAI-TINet/NC4K",
364 | "suffix": ".png"
365 | },
366 | "COD10K": {
367 | "path": "./benchmark/COS-Benchmarking/2021-AAAI-TINet/COD10K",
368 | "suffix": ".png"
369 | }
370 | },
371 | "CRNet": {
372 | "CAMO": {
373 | "path": "./benchmark/COS-Benchmarking/2023-AAAI-CRNet/CAMO",
374 | "suffix": ".png"
375 | },
376 | "NC4K": null,
377 | "COD10K": {
378 | "path": "./benchmark/COS-Benchmarking/2023-AAAI-CRNet/COD10K",
379 | "suffix": ".png"
380 | }
381 | },
382 | "BGNet": {
383 | "CAMO": {
384 | "path": "./benchmark/COS-Benchmarking/2022-IJCAI-BGNet/CAMO",
385 | "suffix": ".png"
386 | },
387 | "NC4K": {
388 | "path": "./benchmark/COS-Benchmarking/2022-IJCAI-BGNet/NC4K",
389 | "suffix": ".png"
390 | },
391 | "COD10K": {
392 | "path": "./benchmark/COS-Benchmarking/2022-IJCAI-BGNet/COD10K",
393 | "suffix": ".png"
394 | }
395 | },
396 | "UGTR": {
397 | "CAMO": {
398 | "path": "./benchmark/COS-Benchmarking/2021-ICCV-UGTR/CAMO",
399 | "suffix": ".png"
400 | },
401 | "NC4K": {
402 | "path": "./benchmark/COS-Benchmarking/2021-ICCV-UGTR/NC4K",
403 | "suffix": ".png"
404 | },
405 | "COD10K": {
406 | "path": "./benchmark/COS-Benchmarking/2021-ICCV-UGTR/COD10K",
407 | "suffix": ".png"
408 | }
409 | },
410 | "DGNet": {
411 | "CAMO": {
412 | "path": "./benchmark/COS-Benchmarking/2023-MIR-DGNet/CAMO",
413 | "suffix": ".png"
414 | },
415 | "NC4K": {
416 | "path": "./benchmark/COS-Benchmarking/2023-MIR-DGNet/NC4K",
417 | "suffix": ".png"
418 | },
419 | "COD10K": {
420 | "path": "./benchmark/COS-Benchmarking/2023-MIR-DGNet/COD10K",
421 | "suffix": ".png"
422 | }
423 | },
424 | "CamoFormerP": {
425 | "CAMO": {
426 | "path": "./benchmark/COS-Benchmarking/2023-arXiv-CamoFormerP/CAMO",
427 | "suffix": ".png"
428 | },
429 | "NC4K": {
430 | "path": "./benchmark/COS-Benchmarking/2023-arXiv-CamoFormerP/NC4K",
431 | "suffix": ".png"
432 | },
433 | "COD10K": {
434 | "path": "./benchmark/COS-Benchmarking/2023-arXiv-CamoFormerP/COD10K",
435 | "suffix": ".png"
436 | }
437 | },
438 | "NCHIT": {
439 | "CAMO": {
440 | "path": "./benchmark/COS-Benchmarking/2022-CVIU-NCHIT/CAMO",
441 | "suffix": ".png"
442 | },
443 | "NC4K": {
444 | "path": "./benchmark/COS-Benchmarking/2022-CVIU-NCHIT/NC4K",
445 | "suffix": ".png"
446 | },
447 | "COD10K": {
448 | "path": "./benchmark/COS-Benchmarking/2022-CVIU-NCHIT/COD10K",
449 | "suffix": ".png"
450 | }
451 | },
452 | "FDNet": {
453 | "CAMO": {
454 | "path": "./benchmark/COS-Benchmarking/2022-CVPR-FDNet/CAMO",
455 | "suffix": ".png"
456 | },
457 | "NC4K": {
458 | "path": "./benchmark/COS-Benchmarking/2022-CVPR-FDNet/NC4K",
459 | "suffix": ".png"
460 | },
461 | "COD10K": {
462 | "path": "./benchmark/COS-Benchmarking/2022-CVPR-FDNet/COD10K",
463 | "suffix": ".png"
464 | }
465 | },
466 | "DGNetS": {
467 | "CAMO": {
468 | "path": "./benchmark/COS-Benchmarking/2023-MIR-DGNetS/CAMO",
469 | "suffix": ".png"
470 | },
471 | "NC4K": {
472 | "path": "./benchmark/COS-Benchmarking/2023-MIR-DGNetS/NC4K",
473 | "suffix": ".png"
474 | },
475 | "COD10K": {
476 | "path": "./benchmark/COS-Benchmarking/2023-MIR-DGNetS/COD10K",
477 | "suffix": ".png"
478 | }
479 | },
480 | "CamoFormerC": {
481 | "CAMO": {
482 | "path": "./benchmark/COS-Benchmarking/2023-arXiv-CamoFormerC/CAMO",
483 | "suffix": ".png"
484 | },
485 | "NC4K": {
486 | "path": "./benchmark/COS-Benchmarking/2023-arXiv-CamoFormerC/NC4K",
487 | "suffix": ".png"
488 | },
489 | "COD10K": {
490 | "path": "./benchmark/COS-Benchmarking/2023-arXiv-CamoFormerC/COD10K",
491 | "suffix": ".png"
492 | }
493 | }
494 | }
--------------------------------------------------------------------------------
/cos_eval_toolbox/examples_COS/config_cos_method_py_example.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 |
4 | PFNet_root = "./benchmark/COS-Benchmarking/2021-CVPR-PFNet"
5 | PFNet = {
6 | "CAMO": dict(path=os.path.join(PFNet_root, "CAMO"), suffix=".png"),
7 | "NC4K": dict(path=os.path.join(PFNet_root, "NC4K"), suffix=".png"),
8 | "COD10K": dict(path=os.path.join(PFNet_root, "COD10K"), suffix=".png"),
9 | }
10 |
11 | C2FNetV2_root = "./benchmark/COS-Benchmarking/2022-TCSVT-C2FNetV2"
12 | C2FNetV2 = {
13 | "CAMO": dict(path=os.path.join(C2FNetV2_root, "CAMO"), suffix=".png"),
14 | "NC4K": dict(path=os.path.join(C2FNetV2_root, "NC4K"), suffix=".png"),
15 | "COD10K": dict(path=os.path.join(C2FNetV2_root, "COD10K"), suffix=".png"),
16 | }
17 |
18 | CamoFormerS_root = "./benchmark/COS-Benchmarking/2023-arXiv-CamoFormerS"
19 | CamoFormerS = {
20 | "CAMO": dict(path=os.path.join(CamoFormerS_root, "CAMO"), suffix=".png"),
21 | "NC4K": dict(path=os.path.join(CamoFormerS_root, "NC4K"), suffix=".png"),
22 | "COD10K": dict(path=os.path.join(CamoFormerS_root, "COD10K"), suffix=".png"),
23 | }
24 |
25 | SINet_root = "./benchmark/COS-Benchmarking/2020-CVPR-SINet"
26 | SINet = {
27 | "CAMO": dict(path=os.path.join(SINet_root, "CAMO"), suffix=".png"),
28 | "NC4K": dict(path=os.path.join(SINet_root, "NC4K"), suffix=".png"),
29 | "COD10K": dict(path=os.path.join(SINet_root, "COD10K"), suffix=".png"),
30 | }
31 |
32 | C2FNet_root = "./benchmark/COS-Benchmarking/2021-IJCAI-C2FNet"
33 | C2FNet = {
34 | "CAMO": dict(path=os.path.join(C2FNet_root, "CAMO"), suffix=".png"),
35 | "NC4K": dict(path=os.path.join(C2FNet_root, "NC4K"), suffix=".png"),
36 | "COD10K": dict(path=os.path.join(C2FNet_root, "COD10K"), suffix=".png"),
37 | }
38 |
39 | ZoomNet_root = "./benchmark/COS-Benchmarking/2022-CVPR-ZoomNet"
40 | ZoomNet = {
41 | "CAMO": dict(path=os.path.join(ZoomNet_root, "CAMO"), suffix=".png"),
42 | "NC4K": dict(path=os.path.join(ZoomNet_root, "NC4K"), suffix=".png"),
43 | "COD10K": dict(path=os.path.join(ZoomNet_root, "COD10K"), suffix=".png"),
44 | }
45 |
46 | CamoFormerR_root = "./benchmark/COS-Benchmarking/2023-arXiv-CamoFormerR"
47 | CamoFormerR = {
48 | "CAMO": dict(path=os.path.join(CamoFormerR_root, "CAMO"), suffix=".png"),
49 | "NC4K": dict(path=os.path.join(CamoFormerR_root, "NC4K"), suffix=".png"),
50 | "COD10K": dict(path=os.path.join(CamoFormerR_root, "COD10K"), suffix=".png"),
51 | }
52 |
53 | SMGL_root = "./benchmark/COS-Benchmarking/2021-CVPR-SMGL"
54 | SMGL = {
55 | "CAMO": dict(path=os.path.join(SMGL_root, "CAMO"), suffix=".png"),
56 | "NC4K": dict(path=os.path.join(SMGL_root, "NC4K"), suffix=".png"),
57 | "COD10K": dict(path=os.path.join(SMGL_root, "COD10K"), suffix=".png"),
58 | }
59 |
60 | OCENet_root = "./benchmark/COS-Benchmarking/2022-WACV-OCENet"
61 | OCENet = {
62 | "CAMO": dict(path=os.path.join(OCENet_root, "CAMO"), suffix=".png"),
63 | "NC4K": dict(path=os.path.join(OCENet_root, "NC4K"), suffix=".png"),
64 | "COD10K": dict(path=os.path.join(OCENet_root, "COD10K"), suffix=".png"),
65 | }
66 |
67 | FAPNet_root = "./benchmark/COS-Benchmarking/2022-TIP-FAPNet"
68 | FAPNet = {
69 | "CAMO": dict(path=os.path.join(FAPNet_root, "CAMO"), suffix=".png"),
70 | "NC4K": dict(path=os.path.join(FAPNet_root, "NC4K"), suffix=".png"),
71 | "COD10K": dict(path=os.path.join(FAPNet_root, "COD10K"), suffix=".png"),
72 | }
73 |
74 | DTINet_root = "./benchmark/COS-Benchmarking/2022-ICPR-DTINet"
75 | DTINet = {
76 | "CAMO": dict(path=os.path.join(DTINet_root, "CAMO"), suffix=".png"),
77 | "NC4K": dict(path=os.path.join(DTINet_root, "NC4K"), suffix=".png"),
78 | "COD10K": dict(path=os.path.join(DTINet_root, "COD10K"), suffix=".png"),
79 | }
80 |
81 | RMGL_root = "./benchmark/COS-Benchmarking/2021-CVPR-RMGL"
82 | RMGL = {
83 | "CAMO": dict(path=os.path.join(RMGL_root, "CAMO"), suffix=".png"),
84 | "NC4K": dict(path=os.path.join(RMGL_root, "NC4K"), suffix=".png"),
85 | "COD10K": dict(path=os.path.join(RMGL_root, "COD10K"), suffix=".png"),
86 | }
87 |
88 | TPRNet_root = "./benchmark/COS-Benchmarking/2022-TVCJ-TPRNet"
89 | TPRNet = {
90 | "CAMO": dict(path=os.path.join(TPRNet_root, "CAMO"), suffix=".png"),
91 | "NC4K": dict(path=os.path.join(TPRNet_root, "NC4K"), suffix=".png"),
92 | "COD10K": dict(path=os.path.join(TPRNet_root, "COD10K"), suffix=".png"),
93 | }
94 |
95 | BSANet_root = "./benchmark/COS-Benchmarking/2022-AAAI-BSANet"
96 | BSANet = {
97 | "CAMO": dict(path=os.path.join(BSANet_root, "CAMO"), suffix=".png"),
98 | "NC4K": dict(path=os.path.join(BSANet_root, "NC4K"), suffix=".png"),
99 | "COD10K": dict(path=os.path.join(BSANet_root, "COD10K"), suffix=".png"),
100 | }
101 |
102 | PopNet_root = "./benchmark/COS-Benchmarking/2023-arXiv-PopNet"
103 | PopNet = {
104 | "CAMO": dict(path=os.path.join(PopNet_root, "CAMO"), suffix=".png"),
105 | "NC4K": dict(path=os.path.join(PopNet_root, "NC4K"), suffix=".png"),
106 | "COD10K": dict(path=os.path.join(PopNet_root, "COD10K"), suffix=".png"),
107 | }
108 |
109 | LSR_root = "./benchmark/COS-Benchmarking/2021-CVPR-LSR"
110 | LSR = {
111 | "CAMO": dict(path=os.path.join(LSR_root, "CAMO"), suffix=".png"),
112 | "NC4K": dict(path=os.path.join(LSR_root, "NC4K"), suffix=".png"),
113 | "COD10K": dict(path=os.path.join(LSR_root, "COD10K"), suffix=".png"),
114 | }
115 |
116 | CubeNet_root = "./benchmark/COS-Benchmarking/2022-PR-CubeNet"
117 | CubeNet = {
118 | "CAMO": dict(path=os.path.join(CubeNet_root, "CAMO"), suffix=".png"),
119 | "NC4K": None,
120 | "COD10K": dict(path=os.path.join(CubeNet_root, "COD10K"), suffix=".png"),
121 | }
122 |
123 | JSCOD_root = "./benchmark/COS-Benchmarking/2021-CVPR-JSCOD"
124 | JSCOD = {
125 | "CAMO": dict(path=os.path.join(JSCOD_root, "CAMO"), suffix=".png"),
126 | "NC4K": dict(path=os.path.join(JSCOD_root, "NC4K"), suffix=".png"),
127 | "COD10K": dict(path=os.path.join(JSCOD_root, "COD10K"), suffix=".png"),
128 | }
129 |
130 | PFNetPlus_root = "./benchmark/COS-Benchmarking/2023-SSCI-PFNetPlus"
131 | PFNetPlus = {
132 | "CAMO": dict(path=os.path.join(PFNetPlus_root, "CAMO"), suffix=".png"),
133 | "NC4K": None,
134 | "COD10K": dict(path=os.path.join(PFNetPlus_root, "COD10K"), suffix=".png"),
135 | }
136 |
137 | BAS_root = "./benchmark/COS-Benchmarking/2022-arXiv-BAS"
138 | BAS = {
139 | "CAMO": dict(path=os.path.join(BAS_root, "CAMO"), suffix=".png"),
140 | "NC4K": dict(path=os.path.join(BAS_root, "NC4K"), suffix=".png"),
141 | "COD10K": dict(path=os.path.join(BAS_root, "COD10K"), suffix=".png"),
142 | }
143 |
144 | D2CNet_root = "./benchmark/COS-Benchmarking/2021-TIE-D2CNet"
145 | D2CNet = {
146 | "CAMO": dict(path=os.path.join(D2CNet_root, "CAMO"), suffix=".png"),
147 | "NC4K": None,
148 | "COD10K": dict(path=os.path.join(D2CNet_root, "COD10K"), suffix=".png"),
149 | }
150 |
151 | PreyNet_root = "./benchmark/COS-Benchmarking/2022-MM-PreyNet"
152 | PreyNet = {
153 | "CAMO": dict(path=os.path.join(PreyNet_root, "CAMO"), suffix=".png"),
154 | "NC4K": dict(path=os.path.join(PreyNet_root, "NC4K"), suffix=".png"),
155 | "COD10K": dict(path=os.path.join(PreyNet_root, "COD10K"), suffix=".png"),
156 | }
157 |
158 | ERRNet_root = "./benchmark/COS-Benchmarking/2022-PR-ERRNet"
159 | ERRNet = {
160 | "CAMO": dict(path=os.path.join(ERRNet_root, "CAMO"), suffix=".png"),
161 | "NC4K": dict(path=os.path.join(ERRNet_root, "NC4K"), suffix=".png"),
162 | "COD10K": dict(path=os.path.join(ERRNet_root, "COD10K"), suffix=".png"),
163 | }
164 |
165 | SINetV2_root = "./benchmark/COS-Benchmarking/2022-TPAMI-SINetV2"
166 | SINetV2 = {
167 | "CAMO": dict(path=os.path.join(SINetV2_root, "CAMO"), suffix=".png"),
168 | "NC4K": dict(path=os.path.join(SINetV2_root, "NC4K"), suffix=".png"),
169 | "COD10K": dict(path=os.path.join(SINetV2_root, "COD10K"), suffix=".png"),
170 | }
171 |
172 | SegMaR_root = "./benchmark/COS-Benchmarking/2022-CVPR-SegMaR"
173 | SegMaR = {
174 | "CAMO": dict(path=os.path.join(SegMaR_root, "CAMO"), suffix=".png"),
175 | "NC4K": dict(path=os.path.join(SegMaR_root, "NC4K"), suffix=".png"),
176 | "COD10K": dict(path=os.path.join(SegMaR_root, "COD10K"), suffix=".png"),
177 | }
178 |
179 | HitNet_root = "./benchmark/COS-Benchmarking/2023-AAAI-HitNet"
180 | HitNet = {
181 | "CAMO": dict(path=os.path.join(HitNet_root, "CAMO"), suffix=".png"),
182 | "NC4K": dict(path=os.path.join(HitNet_root, "NC4K"), suffix=".png"),
183 | "COD10K": dict(path=os.path.join(HitNet_root, "COD10K"), suffix=".png"),
184 | }
185 |
186 | TINet_root = "./benchmark/COS-Benchmarking/2021-AAAI-TINet"
187 | TINet = {
188 | "CAMO": dict(path=os.path.join(TINet_root, "CAMO"), suffix=".png"),
189 | "NC4K": dict(path=os.path.join(TINet_root, "NC4K"), suffix=".png"),
190 | "COD10K": dict(path=os.path.join(TINet_root, "COD10K"), suffix=".png"),
191 | }
192 |
193 | CRNet_root = "./benchmark/COS-Benchmarking/2023-AAAI-CRNet"
194 | CRNet = {
195 | "CAMO": dict(path=os.path.join(CRNet_root, "CAMO"), suffix=".png"),
196 | "NC4K": None,
197 | "COD10K": dict(path=os.path.join(CRNet_root, "COD10K"), suffix=".png"),
198 | }
199 |
200 | BGNet_root = "./benchmark/COS-Benchmarking/2022-IJCAI-BGNet"
201 | BGNet = {
202 | "CAMO": dict(path=os.path.join(BGNet_root, "CAMO"), suffix=".png"),
203 | "NC4K": dict(path=os.path.join(BGNet_root, "NC4K"), suffix=".png"),
204 | "COD10K": dict(path=os.path.join(BGNet_root, "COD10K"), suffix=".png"),
205 | }
206 |
207 | UGTR_root = "./benchmark/COS-Benchmarking/2021-ICCV-UGTR"
208 | UGTR = {
209 | "CAMO": dict(path=os.path.join(UGTR_root, "CAMO"), suffix=".png"),
210 | "NC4K": dict(path=os.path.join(UGTR_root, "NC4K"), suffix=".png"),
211 | "COD10K": dict(path=os.path.join(UGTR_root, "COD10K"), suffix=".png"),
212 | }
213 |
214 | DGNet_root = "./benchmark/COS-Benchmarking/2023-MIR-DGNet"
215 | DGNet = {
216 | "CAMO": dict(path=os.path.join(DGNet_root, "CAMO"), suffix=".png"),
217 | "NC4K": dict(path=os.path.join(DGNet_root, "NC4K"), suffix=".png"),
218 | "COD10K": dict(path=os.path.join(DGNet_root, "COD10K"), suffix=".png"),
219 | }
220 |
221 | CamoFormerP_root = "./benchmark/COS-Benchmarking/2023-arXiv-CamoFormerP"
222 | CamoFormerP = {
223 | "CAMO": dict(path=os.path.join(CamoFormerP_root, "CAMO"), suffix=".png"),
224 | "NC4K": dict(path=os.path.join(CamoFormerP_root, "NC4K"), suffix=".png"),
225 | "COD10K": dict(path=os.path.join(CamoFormerP_root, "COD10K"), suffix=".png"),
226 | }
227 |
228 | NCHIT_root = "./benchmark/COS-Benchmarking/2022-CVIU-NCHIT"
229 | NCHIT = {
230 | "CAMO": dict(path=os.path.join(NCHIT_root, "CAMO"), suffix=".png"),
231 | "NC4K": dict(path=os.path.join(NCHIT_root, "NC4K"), suffix=".png"),
232 | "COD10K": dict(path=os.path.join(NCHIT_root, "COD10K"), suffix=".png"),
233 | }
234 |
235 | FDNet_root = "./benchmark/COS-Benchmarking/2022-CVPR-FDNet"
236 | FDNet = {
237 | "CAMO": dict(path=os.path.join(FDNet_root, "CAMO"), suffix=".png"),
238 | "NC4K": dict(path=os.path.join(FDNet_root, "NC4K"), suffix=".png"),
239 | "COD10K": dict(path=os.path.join(FDNet_root, "COD10K"), suffix=".png"),
240 | }
241 |
242 | DGNetS_root = "./benchmark/COS-Benchmarking/2023-MIR-DGNetS"
243 | DGNetS = {
244 | "CAMO": dict(path=os.path.join(DGNetS_root, "CAMO"), suffix=".png"),
245 | "NC4K": dict(path=os.path.join(DGNetS_root, "NC4K"), suffix=".png"),
246 | "COD10K": dict(path=os.path.join(DGNetS_root, "COD10K"), suffix=".png"),
247 | }
248 |
249 | CamoFormerC_root = "./benchmark/COS-Benchmarking/2023-arXiv-CamoFormerC"
250 | CamoFormerC = {
251 | "CAMO": dict(path=os.path.join(CamoFormerC_root, "CAMO"), suffix=".png"),
252 | "NC4K": dict(path=os.path.join(CamoFormerC_root, "NC4K"), suffix=".png"),
253 | "COD10K": dict(path=os.path.join(CamoFormerC_root, "COD10K"), suffix=".png"),
254 | }
255 |
256 |
--------------------------------------------------------------------------------
/cos_eval_toolbox/metrics/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DengPingFan/CSU/892e7bf716e75dd1506a97be80b1d04b03b21965/cos_eval_toolbox/metrics/__init__.py
--------------------------------------------------------------------------------
/cos_eval_toolbox/metrics/cal_cosod_matrics.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import os
4 | from collections import defaultdict
5 |
6 | import numpy as np
7 | from tqdm import tqdm
8 |
9 | from utils.misc import (
10 | colored_print,
11 | get_gt_pre_with_name,
12 | get_name_with_group_list,
13 | make_dir,
14 | )
15 | from utils.print_formatter import formatter_for_tabulate
16 | from utils.recorders import GroupedMetricRecorder, MetricExcelRecorder, TxtRecorder
17 |
18 |
19 | def group_names(names: list) -> dict:
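    # A sketch of the mapping this performs, assuming names of the form
    # "group/file" (as produced by get_name_with_group_list):
    #   ["A/1.png", "A/2.png", "B/1.png"] -> {"A": ["1.png", "2.png"], "B": ["1.png"]}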
20 | grouped_data = defaultdict(list)
21 | for name in names:
22 | group_name, file_name = name.split("/")
23 | grouped_data[group_name].append(file_name)
24 | return grouped_data
25 |
26 |
27 | def cal_cosod_matrics(
28 | data_type: str = "rgb_sod",
29 | txt_path: str = "",
30 | to_append: bool = True,
31 | xlsx_path: str = "",
32 | drawing_info: dict = None,
33 | dataset_info: dict = None,
34 | save_npy: bool = True,
35 | curves_npy_path: str = "./curves.npy",
36 | metrics_npy_path: str = "./metrics.npy",
37 | num_bits: int = 3,
38 | ):
39 | """
40 | Save the results of all models on different datasets in a `npy` file in the form of a
41 | dictionary.
42 |
43 | ::
44 |
45 | {
46 | dataset1:{
47 | method1:[fm, em, p, r],
48 | method2:[fm, em, p, r],
49 | .....
50 | },
51 | dataset2:{
52 | method1:[fm, em, p, r],
53 | method2:[fm, em, p, r],
54 | .....
55 | },
56 | ....
57 | }
58 |
59 | :param data_type: the type of data
60 | :param txt_path: the path of the txt for saving results
61 | :param to_append: whether to append results to the original record
62 | :param xlsx_path: the path of the xlsx file for saving results
63 | :param drawing_info: the method information for plotting figures
64 | :param dataset_info: the dataset information
65 | :param save_npy: whether to save results into npy files
66 | :param curves_npy_path: the npy file path for saving curve data
67 | :param metrics_npy_path: the npy file path for saving metric values
68 | :param num_bits: the number of bits used to format results
69 | """
70 | curves = defaultdict(dict) # Two curve metrics
71 | metrics = defaultdict(dict) # Six numerical metrics
72 |
 73 |     txt_recorder = TxtRecorder(
 74 |         txt_path=txt_path,
 75 |         to_append=to_append,
 76 |         max_method_name_width=max([len(x) for x in drawing_info.keys()]),  # show the full method name
77 | )
78 | excel_recorder = MetricExcelRecorder(
79 | xlsx_path=xlsx_path,
80 | sheet_name=data_type,
81 | row_header=["methods"],
82 | dataset_names=sorted(list(dataset_info.keys())),
83 | metric_names=["sm", "wfm", "mae", "adpf", "avgf", "maxf", "adpe", "avge", "maxe"],
84 | )
85 |
86 | for dataset_name, dataset_path in dataset_info.items():
 87 |         txt_recorder.add_row(row_name="Dataset", row_data=dataset_name, row_start_str="\n")
88 |
 89 |         # Get the ground-truth image information.
90 | gt_info = dataset_path["mask"]
91 | gt_root = gt_info["path"]
92 | gt_ext = gt_info["suffix"]
 93 |         # List of ground-truth names.
94 | gt_index_file = dataset_path.get("index_file")
95 | if gt_index_file:
96 | gt_name_list = get_name_with_group_list(data_path=gt_index_file, file_ext=gt_ext)
97 | else:
98 | gt_name_list = get_name_with_group_list(data_path=gt_root, file_ext=gt_ext)
 99 |         assert len(gt_name_list) > 0, "there is no ground truth."
100 |
101 | # ==>> test the intersection between pre and gt for each method <<==
102 | for method_name, method_info in drawing_info.items():
103 | method_root = method_info["path_dict"]
104 | method_dataset_info = method_root.get(dataset_name, None)
105 | if method_dataset_info is None:
106 | colored_print(
107 | msg=f"{method_name} does not have results on {dataset_name}", mode="warning"
108 | )
109 | continue
110 |
111 |             # File name list and extension of the images under the prediction directory.
112 | pre_ext = method_dataset_info["suffix"]
113 | pre_root = method_dataset_info["path"]
114 | pre_name_list = get_name_with_group_list(data_path=pre_root, file_ext=pre_ext)
115 |
116 | # get the intersection
117 | eval_name_list = sorted(list(set(gt_name_list).intersection(set(pre_name_list))))
118 | num_names = len(eval_name_list)
119 |
120 | if num_names == 0:
121 | colored_print(
122 |                     msg=f"{method_name} has no results overlapping the ground truth on {dataset_name}", mode="warning"
123 | )
124 | continue
125 |
126 | grouped_data = group_names(names=eval_name_list)
127 | num_groups = len(grouped_data)
128 |
129 | colored_print(
130 | f"Evaluating {method_name} with {num_names} images and {num_groups} groups"
131 |                 f" (G:{len(gt_name_list)}, P:{len(pre_name_list)}) on dataset {dataset_name}"
132 | )
133 |
134 | group_metric_recorder = GroupedMetricRecorder()
135 | inter_group_bar = tqdm(
136 | grouped_data.items(),
137 | total=num_groups,
138 | leave=False,
139 | ncols=79,
140 | desc=f"[{dataset_name}]",
141 | )
142 | for group_name, names_in_group in inter_group_bar:
143 | intra_group_bar = tqdm(
144 | names_in_group,
145 | total=len(names_in_group),
146 | leave=False,
147 | ncols=79,
148 | desc=f"({group_name})",
149 | )
150 | for img_name in intra_group_bar:
151 | img_name_with_group = os.path.join(group_name, img_name)
152 | gt, pre = get_gt_pre_with_name(
153 | gt_root=gt_root,
154 | pre_root=pre_root,
155 | img_name=img_name_with_group,
156 | pre_ext=pre_ext,
157 | gt_ext=gt_ext,
158 | to_normalize=False,
159 | )
160 | group_metric_recorder.update(group_name=group_name, pre=pre, gt=gt)
161 | method_results = group_metric_recorder.show(num_bits=num_bits, return_ndarray=False)
162 | method_curves = method_results["sequential"]
163 | method_metrics = method_results["numerical"]
164 | curves[dataset_name][method_name] = method_curves
165 | metrics[dataset_name][method_name] = method_metrics
166 |
167 | excel_recorder(
168 | row_data=method_metrics, dataset_name=dataset_name, method_name=method_name
169 | )
170 |                 txt_recorder(method_results=method_metrics, method_name=method_name)
171 |
172 | if save_npy:
173 |         make_dir(os.path.dirname(curves_npy_path))
174 |         np.save(curves_npy_path, curves)
175 |         np.save(metrics_npy_path, metrics)
176 |         colored_print(f"All results have been saved in {curves_npy_path} and {metrics_npy_path}")
177 |     formatted_string = formatter_for_tabulate(metrics)
178 |     colored_print(f"All methods have been tested:\n{formatted_string}")
179 |
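> Because the results are stored with `np.save` as a plain Python dictionary, reading them back needs `allow_pickle=True` plus `.item()` (this mirrors how `draw_curves.py` loads curve files); a minimal sketch with an assumed path:

```python
import numpy as np

# Reload the saved payload: {dataset: {method: {metric: value}}}.
metrics = np.load("./metrics.npy", allow_pickle=True).item()
for dataset_name, methods in metrics.items():
    for method_name, values in methods.items():
        print(dataset_name, method_name, values)
```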
--------------------------------------------------------------------------------
/cos_eval_toolbox/metrics/cal_sod_matrics.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import os
4 | from collections import defaultdict
5 | from functools import partial
6 | from multiprocessing import RLock, pool
7 |
8 | import numpy as np
9 | from tqdm import tqdm
10 |
11 | from utils.misc import get_gt_pre_with_name, get_name_list, make_dir
12 | from utils.print_formatter import formatter_for_tabulate
13 | from utils.recorders import MetricExcelRecorder, MetricRecorder_V2, TxtRecorder
14 |
15 |
16 | class Recorder:
17 | def __init__(
18 | self,
19 | txt_path,
20 | to_append,
21 | max_method_name_width,
22 | xlsx_path,
23 | sheet_name,
24 | dataset_names,
25 | metric_names,
26 | ):
27 | self.curves = defaultdict(dict) # Two curve metrics
28 | self.metrics = defaultdict(dict) # Six numerical metrics
29 | self.dataset_name = None
30 |
31 | self.txt_recorder = None
32 | if txt_path:
33 | self.txt_recorder = TxtRecorder(
34 | txt_path=txt_path,
35 | to_append=to_append,
36 | max_method_name_width=max_method_name_width,
37 | )
38 |
39 | self.excel_recorder = None
40 | if xlsx_path:
41 | self.excel_recorder = MetricExcelRecorder(
42 | xlsx_path=xlsx_path,
43 | sheet_name=sheet_name,
44 | row_header=["methods"],
45 | dataset_names=dataset_names,
46 | metric_names=metric_names,
47 | )
48 |
49 | def record_dataset_name(self, dataset_name):
50 | self.dataset_name = dataset_name
51 | if self.txt_recorder:
52 | self.txt_recorder.add_row(
53 | row_name="Dataset", row_data=dataset_name, row_start_str="\n"
54 | )
55 |
56 | def record(self, method_results, method_name):
57 | method_curves = method_results["sequential"]
58 | method_metrics = method_results["numerical"]
59 |
60 | self.curves[self.dataset_name][method_name] = method_curves
61 | self.metrics[self.dataset_name][method_name] = method_metrics
62 |
63 | if self.excel_recorder:
64 | self.excel_recorder(
65 | row_data=method_metrics, dataset_name=self.dataset_name, method_name=method_name
66 | )
67 | if self.txt_recorder:
68 | self.txt_recorder(method_results=method_metrics, method_name=method_name)
69 |
70 |
71 | def cal_sod_matrics(
72 | sheet_name: str = "results",
73 | txt_path: str = "",
74 | to_append: bool = True,
75 | xlsx_path: str = "",
76 | methods_info: dict = None,
77 | datasets_info: dict = None,
78 | curves_npy_path: str = "./curves.npy",
79 | metrics_npy_path: str = "./metrics.npy",
80 | num_bits: int = 3,
81 | num_workers: int = 2,
82 | use_mp: bool = False,
83 | metric_names: tuple = ("mae", "fm", "em", "sm", "wfm"),
84 | ncols_tqdm: int = 79,
85 | ):
86 | """
87 | Save the results of all models on different datasets in a `npy` file in the form of a
88 | dictionary.
89 |
90 | ::
91 |
92 | {
93 | dataset1:{
94 | method1:[fm, em, p, r],
95 | method2:[fm, em, p, r],
96 | .....
97 | },
98 | dataset2:{
99 | method1:[fm, em, p, r],
100 | method2:[fm, em, p, r],
101 | .....
102 | },
103 | ....
104 | }
105 |
106 | :param sheet_name: the type of the sheet in xlsx file
107 | :param txt_path: the path of the txt for saving results
108 | :param to_append: whether to append results to the original record
109 | :param xlsx_path: the path of the xlsx file for saving results
110 | :param methods_info: the method information
111 | :param datasets_info: the dataset information
112 | :param curves_npy_path: the npy file path for saving curve data
113 | :param metrics_npy_path: the npy file path for saving metric values
114 | :param num_bits: the number of bits used to format results
115 | :param num_workers: the number of workers of multiprocessing or multithreading
116 | :param use_mp: using multiprocessing or multithreading
117 | :param metric_names: names of metrics
118 | :param ncols_tqdm: number of columns for tqdm
119 | """
120 | recorder = Recorder(
121 | txt_path=txt_path,
122 | to_append=to_append,
123 |         max_method_name_width=max([len(x) for x in methods_info.keys()]),  # show the full method name
124 | xlsx_path=xlsx_path,
125 | sheet_name=sheet_name,
126 | dataset_names=sorted(datasets_info.keys()),
127 | metric_names=["sm", "wfm", "mae", "adpf", "avgf", "maxf", "adpe", "avge", "maxe"],
128 | )
129 |
130 | for dataset_name, dataset_path in datasets_info.items():
131 | recorder.record_dataset_name(dataset_name)
132 |
133 |         # Get the ground-truth image information.
134 | gt_info = dataset_path["mask"]
135 | gt_root = gt_info["path"]
136 | gt_ext = gt_info["suffix"]
137 |         # List of ground-truth names.
138 | gt_index_file = dataset_path.get("index_file")
139 | if gt_index_file:
140 | gt_name_list = get_name_list(data_path=gt_index_file, name_suffix=gt_ext)
141 | else:
142 | gt_name_list = get_name_list(data_path=gt_root, name_suffix=gt_ext)
143 |         assert len(gt_name_list) > 0, "there is no ground truth."
144 |
145 | # ==>> test the intersection between pre and gt for each method <<==
146 | tqdm.set_lock(RLock())
147 | pool_cls = pool.Pool if use_mp else pool.ThreadPool
148 | procs = pool_cls(
149 | processes=num_workers, initializer=tqdm.set_lock, initargs=(tqdm.get_lock(),)
150 | )
151 | procs_idx = 0
152 | for method_name, method_info in methods_info.items():
153 | method_root = method_info["path_dict"]
154 | method_dataset_info = method_root.get(dataset_name, None)
155 | if method_dataset_info is None:
156 | tqdm.write(f"{method_name} does not have results on {dataset_name}")
157 | continue
158 |
159 |             # File name list and extension of the images under the prediction directory.
160 | pre_prefix = method_dataset_info.get("prefix", "")
161 | pre_suffix = method_dataset_info["suffix"]
162 | pre_root = method_dataset_info["path"]
163 | pre_name_list = get_name_list(
164 | data_path=pre_root,
165 | name_prefix=pre_prefix,
166 | name_suffix=pre_suffix,
167 | )
168 |
169 | # get the intersection
170 | eval_name_list = sorted(list(set(gt_name_list).intersection(pre_name_list)))
171 | if len(eval_name_list) == 0:
172 |                 tqdm.write(f"{method_name} has no results overlapping the ground truth on {dataset_name}")
173 | continue
174 |
175 | procs.apply_async(
176 | func=evaluate_data,
177 | kwds=dict(
178 | names=eval_name_list,
179 | num_bits=num_bits,
180 | pre_root=pre_root,
181 | pre_prefix=pre_prefix,
182 | pre_suffix=pre_suffix,
183 | gt_root=gt_root,
184 | gt_ext=gt_ext,
185 | desc=f"Dataset: {dataset_name} ({len(gt_name_list)}) | Method: {method_name} ({len(pre_name_list)})",
186 | proc_idx=procs_idx,
187 | blocking=use_mp,
188 | metric_names=metric_names,
189 | ncols_tqdm=ncols_tqdm,
190 | ),
191 | callback=partial(recorder.record, method_name=method_name),
192 | )
193 | procs_idx += 1
194 | procs.close()
195 | procs.join()
196 |
197 | if curves_npy_path:
198 | make_dir(os.path.dirname(curves_npy_path))
199 | np.save(curves_npy_path, recorder.curves)
200 |         print(f"All curves have been saved in {curves_npy_path}")
201 | if metrics_npy_path:
202 | make_dir(os.path.dirname(metrics_npy_path))
203 | np.save(metrics_npy_path, recorder.metrics)
204 |         print(f"All metrics have been saved in {metrics_npy_path}")
205 | formatted_string = formatter_for_tabulate(recorder.metrics)
206 | print(f"All methods have been evaluated:\n{formatted_string}")
207 |
208 |
209 | def evaluate_data(
210 | names,
211 | num_bits,
212 | gt_root,
213 | gt_ext,
214 | pre_root,
215 | pre_prefix,
216 | pre_suffix,
217 | desc="",
218 | proc_idx=None,
219 | blocking=True,
220 | metric_names=None,
221 | ncols_tqdm=79,
222 | ):
223 |     metric_recorder = MetricRecorder_V2(metric_names=metric_names)
224 | # https://github.com/tqdm/tqdm#parameters
225 | # https://github.com/tqdm/tqdm/blob/master/examples/parallel_bars.py
226 | tqdm_bar = tqdm(
227 | names,
228 | total=len(names),
229 | desc=desc,
230 | position=proc_idx,
231 | ncols=ncols_tqdm,
232 | lock_args=None if blocking else (False,),
233 | )
234 | for name in tqdm_bar:
235 | gt, pre = get_gt_pre_with_name(
236 | gt_root=gt_root,
237 | pre_root=pre_root,
238 | img_name=name,
239 | pre_prefix=pre_prefix,
240 | pre_suffix=pre_suffix,
241 | gt_ext=gt_ext,
242 | to_normalize=False,
243 | )
244 |         metric_recorder.update(pre=pre, gt=gt)
245 |     method_results = metric_recorder.show(num_bits=num_bits, return_ndarray=False)
246 | return method_results
247 |
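> A minimal invocation sketch of `cal_sod_matrics` (all paths and the method/dataset names here are assumptions, not part of the repository):

```python
# Hypothetical example: evaluate one method on one dataset.
cal_sod_matrics(
    sheet_name="results",
    txt_path="./output/results.txt",
    xlsx_path="./output/results.xlsx",
    methods_info={
        "MyMethod": {"path_dict": {"CAMO": dict(path="./pred/CAMO", suffix=".png")}},
    },
    datasets_info={"CAMO": {"mask": dict(path="./dataset/CAMO/GT", suffix=".png")}},
    curves_npy_path="./output/curves.npy",
    metrics_npy_path="./output/metrics.npy",
)
```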
--------------------------------------------------------------------------------
/cos_eval_toolbox/metrics/draw_curves.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from collections import OrderedDict
4 |
5 | import numpy as np
6 | from matplotlib import colors
7 |
8 | from utils.recorders import CurveDrawer
9 |
10 |
11 | def draw_curves(
12 | for_pr: bool = True,
13 | axes_setting: dict = None,
14 | curves_npy_path: list = None,
15 | row_num: int = 1,
16 | our_methods: list = None,
17 | method_aliases: OrderedDict = None,
18 | dataset_aliases: OrderedDict = None,
19 | style_cfg: dict = None,
20 | ncol_of_legend: int = 1,
21 | separated_legend: bool = False,
22 | sharey: bool = False,
23 | line_styles=("-", "--"),
24 | line_width=3,
25 | save_name=None,
26 | ):
27 | """A better curve painter!
28 |
29 | Args:
30 |         for_pr (bool, optional): Plot PR curves (True) or FM curves (False). Defaults to True.
31 | axes_setting (dict, optional): Setting for axes. Defaults to None.
32 | curves_npy_path (list, optional): Paths of curve npy files. Defaults to None.
33 | row_num (int, optional): Number of rows. Defaults to 1.
34 | our_methods (list, optional): Names of our methods. Defaults to None.
35 | method_aliases (OrderedDict, optional): Aliases of methods. Defaults to None.
36 | dataset_aliases (OrderedDict, optional): Aliases of datasets. Defaults to None.
37 | style_cfg (dict, optional): Config file for the style of matplotlib. Defaults to None.
38 | ncol_of_legend (int, optional): Number of columns for the legend. Defaults to 1.
39 | separated_legend (bool, optional): Use the separated legend. Defaults to False.
40 | sharey (bool, optional): Use a shared y-axis. Defaults to False.
41 | line_styles (tuple, optional): Styles of lines. Defaults to ("-", "--").
42 | line_width (int, optional): Width of lines. Defaults to 3.
43 |         save_name (str, optional): Name or path of the saved figure (without the file extension). Defaults to None.
44 | """
45 | mode = "pr" if for_pr else "fm"
46 | save_name = save_name or mode
47 | mode_axes_setting = axes_setting[mode]
48 |
49 | x_label, y_label = mode_axes_setting["x_label"], mode_axes_setting["y_label"]
50 | x_ticks, y_ticks = mode_axes_setting["x_ticks"], mode_axes_setting["y_ticks"]
51 |
52 | assert curves_npy_path
53 | if not isinstance(curves_npy_path, (list, tuple)):
54 | curves_npy_path = [curves_npy_path]
55 |
56 | curves = {}
57 | unique_method_names_from_npy = []
58 | for p in curves_npy_path:
59 | single_curves = np.load(p, allow_pickle=True).item()
60 | for dataset_name, method_infos in single_curves.items():
61 | curves.setdefault(dataset_name, {})
62 | for method_name, method_info in method_infos.items():
63 | curves[dataset_name][method_name] = method_info
64 | if method_name not in unique_method_names_from_npy:
65 | unique_method_names_from_npy.append(method_name)
66 | dataset_names_from_npy = list(curves.keys())
67 |
68 | if dataset_aliases is None:
69 | dataset_aliases = {k: k for k in dataset_names_from_npy}
70 | else:
71 | for x in dataset_aliases.keys():
72 | if x not in dataset_names_from_npy:
73 | raise ValueError(f"{x} must be contained in\n{dataset_names_from_npy}")
74 |
75 | if method_aliases is not None:
76 | target_unique_method_names = list(method_aliases.keys())
77 | for x in target_unique_method_names:
78 | if x not in unique_method_names_from_npy:
79 | raise ValueError(
80 | f"{x} must be contained in\n{sorted(unique_method_names_from_npy)}"
81 | )
82 | else:
83 | target_unique_method_names = unique_method_names_from_npy
84 |
85 | if our_methods is not None:
86 | for x in our_methods:
87 | if x not in target_unique_method_names:
88 | raise ValueError(f"{x} must be contained in\n{target_unique_method_names}")
89 | assert len(our_methods) <= len(line_styles)
90 | else:
91 | our_methods = []
92 |
 93 |     # Give each method a unique color (skip red/white, light colors, and grays).
 94 |     color_table = sorted(
 95 |         [
 96 |             color
 97 |             for name, color in colors.cnames.items()
 98 |             if name not in ["red", "white"] and not name.startswith("light") and "gray" not in name
 99 |         ]
100 |     )
101 | unique_method_settings = {}
102 | for i, method_name in enumerate(target_unique_method_names):
103 | if method_name in our_methods:
104 | line_color = "red"
105 | line_style = line_styles[our_methods.index(method_name)]
106 | else:
107 | line_color = color_table[i]
108 | line_style = line_styles[i % 2]
109 | line_label = (
110 | method_name if method_aliases is None else method_aliases.get(method_name, method_name)
111 | )
112 |
113 | unique_method_settings[method_name] = {
114 | "line_color": line_color,
115 | "line_style": line_style,
116 | "line_label": line_label,
117 | "line_width": line_width,
118 | }
119 |
120 | curve_drawer = CurveDrawer(
121 | row_num=row_num,
122 | num_subplots=len(dataset_aliases),
123 | style_cfg=style_cfg,
124 | ncol_of_legend=ncol_of_legend,
125 | separated_legend=separated_legend,
126 | sharey=sharey,
127 | )
128 |
129 | for idx, (dataset_name, dataset_alias) in enumerate(dataset_aliases.items()):
130 | dataset_results = curves[dataset_name]
131 | curve_drawer.set_axis_property(
132 | idx=idx,
133 | title=dataset_alias.upper(),
134 | x_label=x_label,
135 | y_label=y_label,
136 | x_ticks=x_ticks,
137 | y_ticks=y_ticks,
138 | )
139 |
140 | for method_name, method_setting in unique_method_settings.items():
141 | if method_name not in dataset_results:
142 | raise KeyError(f"{method_name} not in {sorted(dataset_results.keys())}")
143 | method_results = dataset_results[method_name]
144 | if mode == "pr":
145 | assert isinstance(method_results["p"], (list, tuple))
146 | assert isinstance(method_results["r"], (list, tuple))
147 | y_data = method_results["p"]
148 | x_data = method_results["r"]
149 | else:
150 | assert isinstance(method_results["fm"], (list, tuple))
151 | y_data = method_results["fm"]
152 | x_data = np.linspace(0, 1, 256)
153 |
154 | curve_drawer.plot_at_axis(
155 | idx=idx, method_curve_setting=method_setting, x_data=x_data, y_data=y_data
156 | )
157 | curve_drawer.save(path=save_name)
158 |
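> For reference, `draw_curves` expects each `curves[dataset][method]` entry to carry 256-point sequences under the keys `p`, `r`, and `fm`; a hand-built stand-in (placeholder values, assumed names) that could be saved and then passed via `curves_npy_path`:

```python
import numpy as np

curves = {
    "CAMO": {
        "MyMethod": {
            "p": np.linspace(0.95, 0.60, 256).tolist(),   # precision per threshold
            "r": np.linspace(0.10, 0.90, 256).tolist(),   # recall per threshold
            "fm": np.linspace(0.40, 0.80, 256).tolist(),  # F-measure per threshold
        }
    }
}
np.save("toy_curves.npy", curves)
```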
--------------------------------------------------------------------------------
/cos_eval_toolbox/metrics/extra_metrics.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import numpy as np
3 | from py_sod_metrics.sod_metrics import _TYPE, _prepare_data
4 |
5 |
6 | class ExtraSegMeasure(object):
7 | def __init__(self):
8 | self.precisions = []
9 | self.recalls = []
10 | self.specificities = []
11 | self.dices = []
12 | self.fmeasures = []
13 | self.ious = []
14 |
15 | def step(self, pred: np.ndarray, gt: np.ndarray):
16 | pred, gt = _prepare_data(pred, gt)
17 |
18 | precisions, recalls, specificities, dices, fmeasures, ious = self.cal_metrics(
19 | pred=pred, gt=gt
20 | )
21 |
22 | self.precisions.append(precisions)
23 | self.recalls.append(recalls)
24 | self.specificities.append(specificities)
25 | self.dices.append(dices)
26 | self.fmeasures.append(fmeasures)
27 | self.ious.append(ious)
28 |
29 | def cal_metrics(self, pred: np.ndarray, gt: np.ndarray) -> tuple:
30 | """
31 |         Calculate precision, recall, specificity, Dice, F-measure, and IoU when the threshold
32 |         changes from 0 to 255.
33 |
34 |         These per-threshold values can be used to obtain the mean and maximum scores as well
35 |         as the precision-recall and F-measure-threshold curves. In the definitions below,
36 |         NumAnd is the intersection count, NumRec the predicted foreground count, and num_obj
37 |         the ground-truth foreground count:
38 |
39 | IoU = NumAnd / (FN + NumRec)
40 | PreFtem = NumAnd / NumRec
41 | RecallFtem = NumAnd / num_obj
42 | SpecifTem = TN / (TN + FP)
43 | Dice = 2 * NumAnd / (num_obj + num_pred)
44 | FmeasureF = (2.0 * PreFtem * RecallFtem) / (PreFtem + RecallFtem)
45 |
46 | :return: precisions, recalls, specificities, dices, fmeasures, ious
47 | """
48 |         # 1. Histogram the predictions over the GT foreground and background regions.
49 | pred: np.ndarray = (pred * 255).astype(np.uint8)
50 | bins: np.ndarray = np.linspace(0, 256, 257)
51 |         tp_hist, _ = np.histogram(pred[gt], bins=bins)  # the last bin is [255, 256]
52 | fp_hist, _ = np.histogram(pred[~gt], bins=bins)
53 |         # 2. Use the cumulative histogram to count, for every threshold at once, the pixels in
54 |         # the GT foreground/background whose prediction is >= that threshold (via cumsum).
55 | tp_w_thrs = np.cumsum(np.flip(tp_hist), axis=0)
56 | fp_w_thrs = np.cumsum(np.flip(fp_hist), axis=0)
57 |         # 3. Compute the corresponding precision and recall for each threshold. Both share the
58 |         # same numerator (pixels with pred >= threshold and gt == 1); only the denominators
59 |         # differ: pred >= threshold for precision, gt == 1 for recall (counts from the cumsum above).
60 | TPs = tp_w_thrs
61 | FPs = fp_w_thrs
62 | T = np.count_nonzero(gt) # T=TPs+FNs
63 | FNs = T - TPs
64 | Ps = TPs + FPs # p_w_thrs
65 | Ns = pred.size - Ps
66 | TNs = Ns - FNs
67 |
68 | ious = np.where(Ps + FNs == 0, 0, TPs / (Ps + FNs))
69 | specificities = np.where(TNs + FPs == 0, 0, TNs / (TNs + FPs))
70 | dices = np.where(TPs + FPs == 0, 0, 2 * TPs / (T + Ps))
71 | precisions = np.where(Ps == 0, 0, TPs / Ps)
72 | recalls = np.where(TPs == 0, 0, TPs / T)
73 | fmeasures = np.where(
74 | precisions + recalls == 0, 0, (2 * precisions * recalls) / (precisions + recalls)
75 | )
76 | return precisions, recalls, specificities, dices, fmeasures, ious
77 |
78 | def get_results(self) -> dict:
79 | """
80 |         Return the averaged per-threshold metrics.
81 |
82 |         :return: dict(pre=precision, sen=recall, spec=specificity, fm=fmeasure, dice=dice, iou=iou)
83 | """
84 | precision = np.mean(np.array(self.precisions, dtype=_TYPE), axis=0)
85 | recall = np.mean(np.array(self.recalls, dtype=_TYPE), axis=0)
86 |         specificity = np.mean(np.array(self.specificities, dtype=_TYPE), axis=0)
87 | fmeasure = np.mean(np.array(self.fmeasures, dtype=_TYPE), axis=0)
88 | dice = np.mean(np.array(self.dices, dtype=_TYPE), axis=0)
89 | iou = np.mean(np.array(self.ious, dtype=_TYPE), axis=0)
90 |         return dict(pre=precision, sen=recall, spec=specificity, fm=fmeasure, dice=dice, iou=iou)
91 |
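> A toy demonstration (not from the repository) of the flip-and-cumsum trick used in `cal_metrics` above: after flipping the 256-bin histogram, the cumulative sum at index `i` counts the pixels whose value is `>= 255 - i`:

```python
import numpy as np

pred = np.array([0, 64, 128, 192, 255], dtype=np.uint8)
hist, _ = np.histogram(pred, bins=np.linspace(0, 256, 257))
ge_thresh = np.cumsum(np.flip(hist))  # ge_thresh[i] = #pixels with pred >= 255 - i
assert ge_thresh[0] == 1   # only the value 255 satisfies pred >= 255
assert ge_thresh[-1] == 5  # every pixel satisfies pred >= 0
```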
--------------------------------------------------------------------------------
/cos_eval_toolbox/output_COS/cos_curves.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DengPingFan/CSU/892e7bf716e75dd1506a97be80b1d04b03b21965/cos_eval_toolbox/output_COS/cos_curves.npy
--------------------------------------------------------------------------------
/cos_eval_toolbox/output_COS/cos_metrics.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DengPingFan/CSU/892e7bf716e75dd1506a97be80b1d04b03b21965/cos_eval_toolbox/output_COS/cos_metrics.npy
--------------------------------------------------------------------------------
/cos_eval_toolbox/output_COS/cos_results.txt:
--------------------------------------------------------------------------------
1 |
2 | ========>> Date: 2023-04-22 08:07:14.980489 <<========
3 |
4 | ========>> Dataset: CAMO <<========
5 | [TINet ] mae: 0.087 maxf: 0.745 avgf: 0.728 adpf: 0.729 maxe: 0.848 avge: 0.836 adpe: 0.847 sm: 0.781 wfm: 0.678
6 | [PreyNet ] mae: 0.077 maxf: 0.765 avgf: 0.757 adpf: 0.763 maxe: 0.857 avge: 0.842 adpe: 0.856 sm: 0.79 wfm: 0.708
7 | [TPRNet ] mae: 0.074 maxf: 0.785 avgf: 0.772 adpf: 0.777 maxe: 0.883 avge: 0.861 adpe: 0.88 sm: 0.807 wfm: 0.725
8 | [ERRNet ] mae: 0.085 maxf: 0.742 avgf: 0.729 adpf: 0.731 maxe: 0.858 avge: 0.842 adpe: 0.855 sm: 0.779 wfm: 0.679
9 | [PopNet ] mae: 0.077 maxf: 0.792 avgf: 0.784 adpf: 0.79 maxe: 0.874 avge: 0.859 adpe: 0.871 sm: 0.808 wfm: 0.744
10 | [FAPNet ] mae: 0.076 maxf: 0.792 avgf: 0.776 adpf: 0.776 maxe: 0.88 avge: 0.865 adpe: 0.877 sm: 0.815 wfm: 0.734
11 | [SMGL ] mae: 0.089 maxf: 0.739 avgf: 0.721 adpf: 0.733 maxe: 0.842 avge: 0.807 adpe: 0.85 sm: 0.772 wfm: 0.664
12 | [UGTR ] mae: 0.086 maxf: 0.754 avgf: 0.738 adpf: 0.749 maxe: 0.854 avge: 0.823 adpe: 0.861 sm: 0.785 wfm: 0.686
13 | [SINet ] mae: 0.092 maxf: 0.708 avgf: 0.702 adpf: 0.712 maxe: 0.829 avge: 0.804 adpe: 0.825 sm: 0.745 wfm: 0.644
14 | [RMGL ] mae: 0.088 maxf: 0.74 avgf: 0.726 adpf: 0.738 maxe: 0.842 avge: 0.812 adpe: 0.848 sm: 0.775 wfm: 0.673
15 | [PFNetPlus ] mae: 0.08 maxf: 0.77 avgf: 0.761 adpf: 0.764 maxe: 0.865 avge: 0.85 adpe: 0.862 sm: 0.791 wfm: 0.713
16 | [CRNet ] mae: 0.092 maxf: 0.707 avgf: 0.701 adpf: 0.709 maxe: 0.83 avge: 0.815 adpe: 0.829 sm: 0.735 wfm: 0.641
17 | [C2FNet ] mae: 0.08 maxf: 0.771 avgf: 0.762 adpf: 0.764 maxe: 0.864 avge: 0.854 adpe: 0.865 sm: 0.796 wfm: 0.719
18 | [NCHIT ] mae: 0.088 maxf: 0.739 avgf: 0.707 adpf: 0.723 maxe: 0.84 avge: 0.805 adpe: 0.841 sm: 0.784 wfm: 0.652
19 | [CamoFormerR] mae: 0.076 maxf: 0.813 avgf: 0.745 adpf: 0.735 maxe: 0.916 avge: 0.874 adpe: 0.863 sm: 0.816 wfm: 0.712
20 | [SegMaR ] mae: 0.071 maxf: 0.803 avgf: 0.795 adpf: 0.795 maxe: 0.884 avge: 0.874 adpe: 0.881 sm: 0.815 wfm: 0.753
21 | [DTINet ] mae: 0.05 maxf: 0.843 avgf: 0.823 adpf: 0.821 maxe: 0.927 avge: 0.916 adpe: 0.918 sm: 0.856 wfm: 0.796
22 | [CamoFormerC] mae: 0.05 maxf: 0.855 avgf: 0.842 adpf: 0.842 maxe: 0.92 avge: 0.913 adpe: 0.919 sm: 0.859 wfm: 0.812
23 | [CamoFormerP] mae: 0.046 maxf: 0.868 avgf: 0.854 adpf: 0.853 maxe: 0.938 avge: 0.929 adpe: 0.931 sm: 0.872 wfm: 0.831
24 | [CubeNet ] mae: 0.085 maxf: 0.75 avgf: 0.732 adpf: 0.734 maxe: 0.86 avge: 0.838 adpe: 0.852 sm: 0.788 wfm: 0.682
25 | [DGNetS ] mae: 0.063 maxf: 0.81 avgf: 0.792 adpf: 0.786 maxe: 0.907 avge: 0.893 adpe: 0.896 sm: 0.826 wfm: 0.754
26 | [HitNet ] mae: 0.055 maxf: 0.838 avgf: 0.831 adpf: 0.833 maxe: 0.91 avge: 0.906 adpe: 0.91 sm: 0.849 wfm: 0.809
27 | [LSR ] mae: 0.08 maxf: 0.753 avgf: 0.744 adpf: 0.756 maxe: 0.854 avge: 0.838 adpe: 0.859 sm: 0.787 wfm: 0.696
28 | [DGNet ] mae: 0.057 maxf: 0.822 avgf: 0.806 adpf: 0.804 maxe: 0.915 avge: 0.901 adpe: 0.906 sm: 0.839 wfm: 0.769
29 | [BAS ] mae: 0.096 maxf: 0.703 avgf: 0.692 adpf: 0.696 maxe: 0.808 avge: 0.796 adpe: 0.808 sm: 0.749 wfm: 0.646
30 | [FDNet ] mae: 0.063 maxf: 0.826 avgf: 0.807 adpf: 0.803 maxe: 0.908 avge: 0.895 adpe: 0.901 sm: 0.841 wfm: 0.775
31 | [SINetV2 ] mae: 0.07 maxf: 0.801 avgf: 0.782 adpf: 0.779 maxe: 0.895 avge: 0.882 adpe: 0.884 sm: 0.82 wfm: 0.743
32 | [PFNet ] mae: 0.085 maxf: 0.758 avgf: 0.746 adpf: 0.751 maxe: 0.855 avge: 0.841 adpe: 0.855 sm: 0.782 wfm: 0.695
33 | [JSCOD ] mae: 0.073 maxf: 0.779 avgf: 0.772 adpf: 0.779 maxe: 0.873 avge: 0.859 adpe: 0.872 sm: 0.8 wfm: 0.728
34 | [ZoomNet ] mae: 0.066 maxf: 0.805 avgf: 0.794 adpf: 0.792 maxe: 0.892 avge: 0.877 adpe: 0.883 sm: 0.82 wfm: 0.752
35 | [C2FNetV2 ] mae: 0.077 maxf: 0.779 avgf: 0.77 adpf: 0.777 maxe: 0.869 avge: 0.859 adpe: 0.869 sm: 0.799 wfm: 0.73
36 | [BSANet ] mae: 0.079 maxf: 0.77 avgf: 0.763 adpf: 0.768 maxe: 0.867 avge: 0.851 adpe: 0.866 sm: 0.794 wfm: 0.717
37 | [OCENet ] mae: 0.08 maxf: 0.777 avgf: 0.766 adpf: 0.776 maxe: 0.865 avge: 0.852 adpe: 0.866 sm: 0.802 wfm: 0.723
38 | [BGNet ] mae: 0.073 maxf: 0.799 avgf: 0.789 adpf: 0.786 maxe: 0.882 avge: 0.87 adpe: 0.876 sm: 0.812 wfm: 0.749
39 | [D2CNet ] mae: 0.087 maxf: 0.743 avgf: 0.735 adpf: 0.747 maxe: 0.838 avge: 0.818 adpe: 0.844 sm: 0.774 wfm: 0.683
40 | [CamoFormerS] mae: 0.043 maxf: 0.871 avgf: 0.856 adpf: 0.856 maxe: 0.938 avge: 0.93 adpe: 0.935 sm: 0.876 wfm: 0.832
41 |
42 | ========>> Dataset: COD10K <<========
43 | [PopNet ] mae: 0.028 maxf: 0.802 avgf: 0.786 adpf: 0.771 maxe: 0.919 avge: 0.91 adpe: 0.91 sm: 0.851 wfm: 0.757
44 | [PreyNet ] mae: 0.034 maxf: 0.747 avgf: 0.736 adpf: 0.731 maxe: 0.891 avge: 0.881 adpe: 0.894 sm: 0.813 wfm: 0.697
45 | [TINet ] mae: 0.042 maxf: 0.712 avgf: 0.679 adpf: 0.652 maxe: 0.878 avge: 0.861 adpe: 0.848 sm: 0.793 wfm: 0.635
46 | [ERRNet ] mae: 0.043 maxf: 0.702 avgf: 0.675 adpf: 0.646 maxe: 0.886 avge: 0.867 adpe: 0.845 sm: 0.786 wfm: 0.63
47 | [UGTR ] mae: 0.035 maxf: 0.742 avgf: 0.712 adpf: 0.671 maxe: 0.891 avge: 0.853 adpe: 0.85 sm: 0.818 wfm: 0.667
48 | [FAPNet ] mae: 0.036 maxf: 0.758 avgf: 0.731 adpf: 0.707 maxe: 0.902 avge: 0.888 adpe: 0.875 sm: 0.822 wfm: 0.694
49 | [SMGL ] mae: 0.037 maxf: 0.733 avgf: 0.702 adpf: 0.667 maxe: 0.889 avge: 0.845 adpe: 0.851 sm: 0.811 wfm: 0.655
50 | [TPRNet ] mae: 0.036 maxf: 0.748 avgf: 0.724 adpf: 0.694 maxe: 0.903 avge: 0.887 adpe: 0.869 sm: 0.817 wfm: 0.683
51 | [CRNet ] mae: 0.049 maxf: 0.636 avgf: 0.633 adpf: 0.637 maxe: 0.845 avge: 0.832 adpe: 0.845 sm: 0.733 wfm: 0.576
52 | [SINet ] mae: 0.043 maxf: 0.691 avgf: 0.679 adpf: 0.667 maxe: 0.874 avge: 0.864 adpe: 0.867 sm: 0.776 wfm: 0.631
53 | [C2FNet ] mae: 0.036 maxf: 0.743 avgf: 0.723 adpf: 0.703 maxe: 0.9 avge: 0.89 adpe: 0.886 sm: 0.813 wfm: 0.686
54 | [RMGL ] mae: 0.035 maxf: 0.738 avgf: 0.711 adpf: 0.681 maxe: 0.89 avge: 0.852 adpe: 0.865 sm: 0.814 wfm: 0.666
55 | [PFNetPlus ] mae: 0.037 maxf: 0.734 avgf: 0.716 adpf: 0.698 maxe: 0.895 avge: 0.884 adpe: 0.88 sm: 0.806 wfm: 0.677
56 | [CamoFormerR] mae: 0.029 maxf: 0.786 avgf: 0.753 adpf: 0.721 maxe: 0.93 avge: 0.916 adpe: 0.9 sm: 0.838 wfm: 0.724
57 | [NCHIT ] mae: 0.046 maxf: 0.698 avgf: 0.649 adpf: 0.596 maxe: 0.879 avge: 0.819 adpe: 0.794 sm: 0.792 wfm: 0.591
58 | [SegMaR ] mae: 0.034 maxf: 0.774 avgf: 0.757 adpf: 0.739 maxe: 0.906 avge: 0.899 adpe: 0.893 sm: 0.833 wfm: 0.724
59 | [CamoFormerP] mae: 0.023 maxf: 0.829 avgf: 0.811 adpf: 0.794 maxe: 0.939 avge: 0.932 adpe: 0.931 sm: 0.869 wfm: 0.786
60 | [CamoFormerC] mae: 0.024 maxf: 0.818 avgf: 0.798 adpf: 0.778 maxe: 0.935 avge: 0.926 adpe: 0.926 sm: 0.86 wfm: 0.77
61 | [HitNet ] mae: 0.023 maxf: 0.838 avgf: 0.823 adpf: 0.818 maxe: 0.938 avge: 0.935 adpe: 0.936 sm: 0.871 wfm: 0.806
62 | [DGNetS ] mae: 0.036 maxf: 0.743 avgf: 0.71 adpf: 0.68 maxe: 0.905 avge: 0.888 adpe: 0.869 sm: 0.81 wfm: 0.672
63 | [LSR ] mae: 0.037 maxf: 0.732 avgf: 0.715 adpf: 0.699 maxe: 0.892 avge: 0.88 adpe: 0.883 sm: 0.804 wfm: 0.673
64 | [DGNet ] mae: 0.033 maxf: 0.759 avgf: 0.728 adpf: 0.698 maxe: 0.911 avge: 0.896 adpe: 0.879 sm: 0.822 wfm: 0.693
65 | [DTINet ] mae: 0.034 maxf: 0.754 avgf: 0.726 adpf: 0.702 maxe: 0.911 avge: 0.896 adpe: 0.881 sm: 0.824 wfm: 0.695
66 | [CubeNet ] mae: 0.041 maxf: 0.715 avgf: 0.692 adpf: 0.669 maxe: 0.883 avge: 0.865 adpe: 0.862 sm: 0.795 wfm: 0.643
67 | [FDNet ] mae: 0.03 maxf: 0.788 avgf: 0.757 adpf: 0.728 maxe: 0.935 avge: 0.919 adpe: 0.906 sm: 0.84 wfm: 0.729
68 | [JSCOD ] mae: 0.035 maxf: 0.738 avgf: 0.721 adpf: 0.705 maxe: 0.891 avge: 0.884 adpe: 0.882 sm: 0.809 wfm: 0.684
69 | [SINetV2 ] mae: 0.037 maxf: 0.752 avgf: 0.718 adpf: 0.682 maxe: 0.906 avge: 0.887 adpe: 0.864 sm: 0.815 wfm: 0.68
70 | [BAS ] mae: 0.038 maxf: 0.729 avgf: 0.715 adpf: 0.707 maxe: 0.87 avge: 0.855 adpe: 0.869 sm: 0.802 wfm: 0.677
71 | [PFNet ] mae: 0.04 maxf: 0.725 avgf: 0.701 adpf: 0.676 maxe: 0.89 avge: 0.877 adpe: 0.868 sm: 0.8 wfm: 0.66
72 | [ZoomNet ] mae: 0.029 maxf: 0.78 avgf: 0.766 adpf: 0.741 maxe: 0.911 avge: 0.888 adpe: 0.893 sm: 0.838 wfm: 0.729
73 | [C2FNetV2 ] mae: 0.036 maxf: 0.742 avgf: 0.725 adpf: 0.718 maxe: 0.896 avge: 0.887 adpe: 0.89 sm: 0.811 wfm: 0.691
74 | [BSANet ] mae: 0.034 maxf: 0.753 avgf: 0.738 adpf: 0.723 maxe: 0.901 avge: 0.891 adpe: 0.894 sm: 0.818 wfm: 0.699
75 | [BGNet ] mae: 0.033 maxf: 0.774 avgf: 0.753 adpf: 0.739 maxe: 0.911 avge: 0.901 adpe: 0.902 sm: 0.831 wfm: 0.722
76 | [OCENet ] mae: 0.033 maxf: 0.764 avgf: 0.741 adpf: 0.718 maxe: 0.905 avge: 0.894 adpe: 0.885 sm: 0.827 wfm: 0.707
77 | [CamoFormerS] mae: 0.024 maxf: 0.818 avgf: 0.799 adpf: 0.78 maxe: 0.941 avge: 0.931 adpe: 0.932 sm: 0.862 wfm: 0.772
78 | [D2CNet ] mae: 0.037 maxf: 0.736 avgf: 0.72 adpf: 0.702 maxe: 0.887 avge: 0.876 adpe: 0.879 sm: 0.807 wfm: 0.68
79 |
80 | ========>> Dataset: NC4K <<========
81 | [TPRNet ] mae: 0.048 maxf: 0.82 avgf: 0.805 adpf: 0.798 maxe: 0.911 avge: 0.898 adpe: 0.901 sm: 0.846 wfm: 0.768
82 | [SMGL ] mae: 0.055 maxf: 0.797 avgf: 0.777 adpf: 0.771 maxe: 0.893 avge: 0.863 adpe: 0.885 sm: 0.829 wfm: 0.731
83 | [ERRNet ] mae: 0.054 maxf: 0.794 avgf: 0.778 adpf: 0.769 maxe: 0.901 avge: 0.887 adpe: 0.892 sm: 0.827 wfm: 0.737
84 | [PreyNet ] mae: 0.05 maxf: 0.811 avgf: 0.803 adpf: 0.805 maxe: 0.899 avge: 0.887 adpe: 0.899 sm: 0.834 wfm: 0.763
85 | [FAPNet ] mae: 0.047 maxf: 0.826 avgf: 0.81 adpf: 0.804 maxe: 0.91 avge: 0.899 adpe: 0.903 sm: 0.851 wfm: 0.775
86 | [UGTR ] mae: 0.052 maxf: 0.807 avgf: 0.787 adpf: 0.779 maxe: 0.899 avge: 0.874 adpe: 0.889 sm: 0.839 wfm: 0.747
87 | [PopNet ] mae: 0.042 maxf: 0.843 avgf: 0.833 adpf: 0.83 maxe: 0.919 avge: 0.909 adpe: 0.915 sm: 0.861 wfm: 0.802
88 | [TINet ] mae: 0.055 maxf: 0.793 avgf: 0.773 adpf: 0.766 maxe: 0.89 avge: 0.879 adpe: 0.882 sm: 0.829 wfm: 0.734
89 | [SINet ] mae: 0.058 maxf: 0.775 avgf: 0.769 adpf: 0.768 maxe: 0.883 avge: 0.871 adpe: 0.883 sm: 0.808 wfm: 0.723
90 | [SegMaR ] mae: 0.046 maxf: 0.826 avgf: 0.821 adpf: 0.821 maxe: 0.907 avge: 0.896 adpe: 0.905 sm: 0.841 wfm: 0.781
91 | [RMGL ] mae: 0.052 maxf: 0.8 avgf: 0.782 adpf: 0.778 maxe: 0.893 avge: 0.867 adpe: 0.89 sm: 0.833 wfm: 0.74
92 | [CamoFormerR] mae: 0.042 maxf: 0.83 avgf: 0.821 adpf: 0.82 maxe: 0.914 avge: 0.9 adpe: 0.913 sm: 0.855 wfm: 0.788
93 | [C2FNet ] mae: 0.049 maxf: 0.81 avgf: 0.795 adpf: 0.788 maxe: 0.904 avge: 0.897 adpe: 0.901 sm: 0.838 wfm: 0.762
94 | [CamoFormerP] mae: 0.03 maxf: 0.88 avgf: 0.868 adpf: 0.863 maxe: 0.946 avge: 0.939 adpe: 0.941 sm: 0.892 wfm: 0.847
95 | [NCHIT ] mae: 0.058 maxf: 0.792 avgf: 0.758 adpf: 0.751 maxe: 0.894 avge: 0.851 adpe: 0.872 sm: 0.83 wfm: 0.71
96 | [CamoFormerC] mae: 0.032 maxf: 0.87 avgf: 0.857 adpf: 0.851 maxe: 0.94 avge: 0.933 adpe: 0.937 sm: 0.883 wfm: 0.834
97 | [HitNet ] mae: 0.037 maxf: 0.863 avgf: 0.853 adpf: 0.854 maxe: 0.929 avge: 0.926 adpe: 0.928 sm: 0.875 wfm: 0.834
98 | [DTINet ] mae: 0.041 maxf: 0.836 avgf: 0.818 adpf: 0.809 maxe: 0.926 avge: 0.917 adpe: 0.914 sm: 0.863 wfm: 0.792
99 | [DGNet ] mae: 0.042 maxf: 0.833 avgf: 0.814 adpf: 0.803 maxe: 0.922 avge: 0.911 adpe: 0.91 sm: 0.857 wfm: 0.784
100 | [DGNetS ] mae: 0.047 maxf: 0.819 avgf: 0.799 adpf: 0.789 maxe: 0.913 avge: 0.902 adpe: 0.902 sm: 0.845 wfm: 0.764
101 | [LSR ] mae: 0.048 maxf: 0.815 avgf: 0.804 adpf: 0.802 maxe: 0.907 avge: 0.895 adpe: 0.904 sm: 0.84 wfm: 0.766
102 | [JSCOD ] mae: 0.047 maxf: 0.816 avgf: 0.806 adpf: 0.803 maxe: 0.907 avge: 0.898 adpe: 0.906 sm: 0.842 wfm: 0.771
103 | [SINetV2 ] mae: 0.048 maxf: 0.823 avgf: 0.805 adpf: 0.792 maxe: 0.914 avge: 0.903 adpe: 0.901 sm: 0.847 wfm: 0.77
104 | [BAS ] mae: 0.058 maxf: 0.782 avgf: 0.772 adpf: 0.767 maxe: 0.872 avge: 0.859 adpe: 0.868 sm: 0.817 wfm: 0.732
105 | [FDNet ] mae: 0.052 maxf: 0.804 avgf: 0.784 adpf: 0.774 maxe: 0.905 avge: 0.893 adpe: 0.895 sm: 0.834 wfm: 0.75
106 | [PFNet ] mae: 0.053 maxf: 0.799 avgf: 0.784 adpf: 0.779 maxe: 0.898 avge: 0.887 adpe: 0.894 sm: 0.829 wfm: 0.745
107 | [ZoomNet ] mae: 0.043 maxf: 0.828 avgf: 0.818 adpf: 0.814 maxe: 0.912 avge: 0.896 adpe: 0.907 sm: 0.853 wfm: 0.784
108 | [C2FNetV2 ] mae: 0.048 maxf: 0.814 avgf: 0.802 adpf: 0.799 maxe: 0.904 avge: 0.896 adpe: 0.9 sm: 0.84 wfm: 0.77
109 | [BSANet ] mae: 0.048 maxf: 0.817 avgf: 0.808 adpf: 0.805 maxe: 0.907 avge: 0.897 adpe: 0.906 sm: 0.841 wfm: 0.771
110 | [BGNet ] mae: 0.044 maxf: 0.833 avgf: 0.82 adpf: 0.813 maxe: 0.916 avge: 0.907 adpe: 0.911 sm: 0.851 wfm: 0.788
111 | [OCENet ] mae: 0.045 maxf: 0.831 avgf: 0.818 adpf: 0.812 maxe: 0.913 avge: 0.902 adpe: 0.908 sm: 0.853 wfm: 0.785
112 | [CamoFormerS] mae: 0.031 maxf: 0.877 avgf: 0.863 adpf: 0.857 maxe: 0.946 avge: 0.937 adpe: 0.941 sm: 0.888 wfm: 0.84
113 |
--------------------------------------------------------------------------------
/cos_eval_toolbox/plot.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import argparse
3 | import textwrap
4 |
5 | import numpy as np
6 | import yaml
7 |
8 | from metrics import draw_curves
9 |
10 |
11 | def get_args():
12 | parser = argparse.ArgumentParser(
13 | description=textwrap.dedent(
14 | r"""
15 | INCLUDE:
16 |
17 | - Fm Curve
18 | - PR Curves
19 |
20 | NOTE:
21 |
22 |               - The toolbox automatically computes the intersection of `pre` and `gt`.
23 | - Currently supported pre naming rules: `prefix + gt_name_wo_ext + suffix_w_ext`
24 |
25 | EXAMPLES:
26 |
27 | python plot.py \
28 |                     --curves-npys output/rgbd-othermethods-curves.npy output/rgbd-ours-curves.npy \  # draw curves from the information in these npy files
29 |                     --num-rows 1 \  # set the number of rows of the figure to 1
30 |                     --style-cfg configs/single_row_style.yml \  # specify the configuration file for the matplotlib style
31 |                     --num-col-legend 1 \  # set the number of columns of the legend in the figure to 1
32 |                     --mode pr \  # draw `pr` curves
33 |                     --our-methods Ours \  # specify the names of our own methods; they must be contained in the npy file
34 |                     --save-name ./output/rgbd-pr-curves  # save the figure as `./output/rgbd-pr-curves.<ext>`, where the extension is taken from the `savefig.format` item in the `--style-cfg`
35 |
36 | python plot.py \
37 |                     --curves-npys output/rgbd-othermethods-curves.npy output/rgbd-ours-curves.npy \
38 | --num-rows 2 \
39 | --style-cfg configs/two_row_style.yml \
40 | --num-col-legend 2 \
41 | --separated-legend \ # use a separated legend
42 | --mode fm \ # draw `fm` curves
43 |                     --our-methods OursV0 OursV1 \  # specify the names of our own methods; they must be contained in the npy file
44 | --save-name output/rgbd-fm \
45 | --alias-yaml configs/rgbd_aliases.yaml # aliases corresponding to methods and datasets you want to use
46 | """
47 | ),
48 | formatter_class=argparse.RawTextHelpFormatter,
49 | )
50 |     parser.add_argument("--alias-yaml", type=str, help="YAML file with dataset and method aliases.")
51 | parser.add_argument(
52 | "--style-cfg",
53 | type=str,
54 | required=True,
55 |         help="YAML file configuring the style of the plotted curves.",
56 | )
57 | parser.add_argument(
58 | "--curves-npys",
59 | required=True,
60 | type=str,
61 | nargs="+",
62 |         help="Npy files storing the curve results to plot.",
63 | )
64 | parser.add_argument(
65 |         "--our-methods", type=str, nargs="+", help="Names of our methods for highlighting them."
66 | )
67 | parser.add_argument(
68 | "--num-rows", type=int, default=1, help="Number of rows for subplots. Default: 1"
69 | )
70 | parser.add_argument(
71 | "--num-col-legend", type=int, default=1, help="Number of columns in the legend. Default: 1"
72 | )
73 | parser.add_argument(
74 | "--mode",
75 | type=str,
76 | choices=["pr", "fm"],
77 | default="pr",
78 | help="Mode for plotting. Default: pr",
79 | )
80 | parser.add_argument(
81 | "--separated-legend", action="store_true", help="Use the separated legend."
82 | )
83 | parser.add_argument("--sharey", action="store_true", help="Use the shared y-axis.")
84 |     parser.add_argument("--save-name", type=str, help="The exported file path (without the extension).")
85 | args = parser.parse_args()
86 |
87 | return args
88 |
89 |
90 | def main(args):
91 | method_aliases = dataset_aliases = None
92 | if args.alias_yaml:
93 | with open(args.alias_yaml, mode="r", encoding="utf-8") as f:
94 | aliases = yaml.safe_load(f)
95 | method_aliases = aliases.get("method")
96 | dataset_aliases = aliases.get("dataset")
97 |
98 | draw_curves.draw_curves(
99 | for_pr=args.mode == "pr",
100 |         # Plot settings for the different curves.
101 | axes_setting={
102 |             # Settings for the PR curve.
103 | "pr": {
104 | "x_label": "Recall",
105 | "y_label": "Precision",
106 | "x_ticks": np.linspace(0.5, 1, 6),
107 | "y_ticks": np.linspace(0.7, 1, 6),
108 | },
109 |             # Settings for the FM curve.
110 | "fm": {
111 | "x_label": "Threshold",
112 | "y_label": r"F$_{\beta}$",
113 | "x_ticks": np.linspace(0, 1, 6),
114 | "y_ticks": np.linspace(0.6, 1, 6),
115 | },
116 | },
117 | curves_npy_path=args.curves_npys,
118 | row_num=args.num_rows,
119 | method_aliases=method_aliases,
120 | dataset_aliases=dataset_aliases,
121 | style_cfg=args.style_cfg,
122 | ncol_of_legend=args.num_col_legend,
123 | separated_legend=args.separated_legend,
124 | sharey=args.sharey,
125 | our_methods=args.our_methods,
126 | save_name=args.save_name,
127 | )
128 |
129 |
130 | if __name__ == "__main__":
131 | args = get_args()
132 | main(args)
133 |
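> The `--alias-yaml` file read by `main()` holds two optional top-level keys, `method` and `dataset`, mapping the names stored in the npy files to display aliases; a hypothetical example (all names are assumptions):

```yaml
method:
  CamoFormerP: CamoFormer-P
  SINetV2: SINet-v2
dataset:
  CAMO: CAMO-Test
  COD10K: COD10K-Test
```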
--------------------------------------------------------------------------------
/cos_eval_toolbox/pyproject.toml:
--------------------------------------------------------------------------------
1 | # https://github.com/LongTengDao/TOML/
2 |
3 | [tool.isort]
4 | # https://pycqa.github.io/isort/docs/configuration/options/
5 | profile = "black"
6 | multi_line_output = 3
7 | filter_files = true
8 | supported_extensions = "py"
9 |
10 | [tool.black]
11 | line-length = 99
12 | include = '\.pyi?$'
13 | exclude = '''
14 | /(
15 | \.eggs
16 | | \.git
17 | | \.idea
18 | | \.vscode
19 | | \.hg
20 | | \.mypy_cache
21 | | \.tox
22 | | \.venv
23 | | _build
24 | | buck-out
25 | | build
26 | | dist
27 | | output
28 | )/
29 | '''
30 |
--------------------------------------------------------------------------------
/cos_eval_toolbox/requirements.txt:
--------------------------------------------------------------------------------
1 | # Automatically generated by https://github.com/damnever/pigar.
2 |
3 | # PySODEvalToolkit/utils/misc.py: 6
4 | Pillow == 8.1.2
5 |
6 | # PySODEvalToolkit/plot.py: 6
7 | # PySODEvalToolkit/tools/converter.py: 10
8 | PyYAML == 5.4.1
9 |
10 | # PySODEvalToolkit/tools/markdown2html.py: 4
11 | markdown2 == 2.4.1
12 |
13 | # PySODEvalToolkit/metrics/draw_curves.py: 6
14 | # PySODEvalToolkit/utils/generate_info.py: 7
15 | # PySODEvalToolkit/utils/recorders/curve_drawer.py: 8
16 | matplotlib == 3.4.2
17 |
18 | # PySODEvalToolkit/metrics/cal_cosod_matrics.py: 6
19 | # PySODEvalToolkit/metrics/cal_sod_matrics.py: 8
20 | # PySODEvalToolkit/metrics/draw_curves.py: 5
21 | # PySODEvalToolkit/metrics/extra_metrics.py: 2
22 | # PySODEvalToolkit/plot.py: 5
23 | # PySODEvalToolkit/tools/append_results.py: 4
24 | # PySODEvalToolkit/tools/converter.py: 9
25 | # PySODEvalToolkit/untracked/collect_results.py: 8
26 | # PySODEvalToolkit/untracked/find_bad_prediction.py: 4
27 | # PySODEvalToolkit/untracked/modify_npy.py: 1
28 | # PySODEvalToolkit/untracked/plot_results.py: 4
29 | # PySODEvalToolkit/untracked/plot_sota_cmp.py: 4
30 | # PySODEvalToolkit/utils/misc.py: 5
31 | # PySODEvalToolkit/utils/recorders/metric_recorder.py: 6
32 | numpy == 1.19.2
33 |
34 | # PySODEvalToolkit/untracked/collect_results.py: 7
35 | # PySODEvalToolkit/untracked/plot_sota_cmp.py: 3
36 | # PySODEvalToolkit/utils/misc.py: 4
37 | opencv_python_headless == 4.5.1.48
38 |
39 | # PySODEvalToolkit/utils/recorders/excel_recorder.py: 9,10,11
40 | openpyxl == 3.0.7
41 |
42 | # PySODEvalToolkit/metrics/extra_metrics.py: 3
43 | # PySODEvalToolkit/untracked/collect_results.py: 9
44 | # PySODEvalToolkit/utils/recorders/metric_recorder.py: 7
45 | pysodmetrics == 1.3.0
46 |
47 | # PySODEvalToolkit/utils/print_formatter.py: 3
48 | tabulate == 0.8.9
49 |
50 | # PySODEvalToolkit/metrics/cal_cosod_matrics.py: 7
51 | # PySODEvalToolkit/metrics/cal_sod_matrics.py: 9
52 | # PySODEvalToolkit/untracked/collect_results.py: 10
53 | # PySODEvalToolkit/untracked/find_bad_prediction.py: 5
54 | tqdm == 4.59.0
55 |
--------------------------------------------------------------------------------
/cos_eval_toolbox/tools/.backup/individual_metrics.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2021/09/13
3 | # @Author : Johnson-Chou
4 | # @Email : johnson111788@gmail.com
5 | # @FileName : metrics.py
6 | # @Reference: https://github.com/lartpang/PySODMetrics and https://github.com/mczhuge/SOCToolbox
7 |
8 | import torch
9 | import numpy as np
10 | import torch.nn as nn
11 | from sklearn import metrics
12 | from torch.autograd import Function
13 | from scipy.ndimage import convolve, distance_transform_edt as bwdist
14 |
15 |
16 | _EPS = np.spacing(1)
17 | _TYPE = np.float64
18 |
19 |
20 | def _prepare_data(pred: np.ndarray, gt: np.ndarray) -> tuple:
21 | gt = gt > 128
22 | pred = pred / 255
23 | if pred.max() != pred.min():
24 | pred = (pred - pred.min()) / (pred.max() - pred.min())
25 | return pred, gt
26 |
27 |
28 | def _get_adaptive_threshold(matrix: np.ndarray, max_value: float = 1) -> float:
29 | return min(2 * matrix.mean(), max_value)
30 |
31 |
32 | class Fmeasure(object):
33 | def __init__(self, beta: float = 0.3):
34 | self.beta = beta
35 | self.precisions = []
36 | self.recalls = []
37 | self.adaptive_fms = []
38 | self.changeable_fms = []
39 |
40 | def step(self, pred: np.ndarray, gt: np.ndarray):
41 | pred, gt = _prepare_data(pred, gt)
42 |
43 | adaptive_fm = self.cal_adaptive_fm(pred=pred, gt=gt)
44 | self.adaptive_fms.append(adaptive_fm)
45 |
46 | precisions, recalls, changeable_fms = self.cal_pr(pred=pred, gt=gt)
47 | self.precisions.append(precisions)
48 | self.recalls.append(recalls)
49 | self.changeable_fms.append(changeable_fms)
50 |
51 | def cal_adaptive_fm(self, pred: np.ndarray, gt: np.ndarray) -> float:
52 | adaptive_threshold = _get_adaptive_threshold(pred, max_value=1)
53 |         binary_prediction = pred >= adaptive_threshold
54 |         area_intersection = binary_prediction[gt].sum()
55 | if area_intersection == 0:
56 | adaptive_fm = 0
57 | else:
58 |             pre = area_intersection / np.count_nonzero(binary_prediction)
59 | rec = area_intersection / np.count_nonzero(gt)
60 | adaptive_fm = (1 + self.beta) * pre * rec / (self.beta * pre + rec)
61 | return adaptive_fm
62 |
63 | def cal_pr(self, pred: np.ndarray, gt: np.ndarray) -> tuple:
64 | pred = (pred * 255).astype(np.uint8)
65 | bins = np.linspace(0, 256, 257)
66 | fg_hist, _ = np.histogram(pred[gt], bins=bins)
67 | bg_hist, _ = np.histogram(pred[~gt], bins=bins)
68 | fg_w_thrs = np.cumsum(np.flip(fg_hist), axis=0)
69 | bg_w_thrs = np.cumsum(np.flip(bg_hist), axis=0)
70 | TPs = fg_w_thrs
71 | Ps = fg_w_thrs + bg_w_thrs
72 | Ps[Ps == 0] = 1
73 | T = max(np.count_nonzero(gt), 1)
74 | precisions = TPs / Ps
75 | recalls = TPs / T
76 | numerator = (1 + self.beta) * precisions * recalls
77 | denominator = np.where(numerator == 0, 1, self.beta * precisions + recalls)
78 | changeable_fms = numerator / denominator
79 | return precisions, recalls, changeable_fms
80 |
81 | def get_results(self) -> dict:
82 | adaptive_fm = np.mean(np.array(self.adaptive_fms, _TYPE))
83 | changeable_fm = np.mean(np.array(self.changeable_fms, dtype=_TYPE), axis=0)
84 | precision = np.mean(np.array(self.precisions, dtype=_TYPE), axis=0) # N, 256
85 | recall = np.mean(np.array(self.recalls, dtype=_TYPE), axis=0) # N, 256
86 | return dict(fm=dict(adp=adaptive_fm, curve=changeable_fm),
87 | pr=dict(p=precision, r=recall))
88 |
89 |
90 | class MAE(object):
91 | def __init__(self):
92 | self.maes = []
93 |
94 | def step(self, pred: np.ndarray, gt: np.ndarray):
95 | pred, gt = _prepare_data(pred, gt)
96 |
97 | mae = self.cal_mae(pred, gt)
98 | self.maes.append(mae)
99 |
100 | def cal_mae(self, pred: np.ndarray, gt: np.ndarray) -> float:
101 | mae = np.mean(np.abs(pred - gt))
102 | return mae
103 |
104 | def get_results(self) -> dict:
105 | mae = np.mean(np.array(self.maes, _TYPE))
106 | return dict(mae=mae)
107 |
108 |
109 | class Smeasure(object):
110 | def __init__(self, alpha: float = 0.5):
111 | self.sms = []
112 | self.alpha = alpha
113 |
114 | def step(self, pred: np.ndarray, gt: np.ndarray):
115 | pred, gt = _prepare_data(pred=pred, gt=gt)
116 |
117 | sm = self.cal_sm(pred, gt)
118 | self.sms.append(sm)
119 |
120 | def cal_sm(self, pred: np.ndarray, gt: np.ndarray) -> float:
121 | y = np.mean(gt)
122 | if y == 0:
123 | sm = 1 - np.mean(pred)
124 | elif y == 1:
125 | sm = np.mean(pred)
126 | else:
127 | sm = self.alpha * self.object(pred, gt) + (1 - self.alpha) * self.region(pred, gt)
128 | sm = max(0, sm)
129 | return sm
130 |
131 | def object(self, pred: np.ndarray, gt: np.ndarray) -> float:
132 | fg = pred * gt
133 | bg = (1 - pred) * (1 - gt)
134 | u = np.mean(gt)
135 | object_score = u * self.s_object(fg, gt) + (1 - u) * self.s_object(bg, 1 - gt)
136 | return object_score
137 |
138 | def s_object(self, pred: np.ndarray, gt: np.ndarray) -> float:
139 | x = np.mean(pred[gt == 1])
140 | sigma_x = np.std(pred[gt == 1], ddof=1)
141 | score = 2 * x / (np.power(x, 2) + 1 + sigma_x + _EPS)
142 | return score
143 |
144 | def region(self, pred: np.ndarray, gt: np.ndarray) -> float:
145 | x, y = self.centroid(gt)
146 | part_info = self.divide_with_xy(pred, gt, x, y)
147 | w1, w2, w3, w4 = part_info['weight']
148 | pred1, pred2, pred3, pred4 = part_info['pred']
149 | gt1, gt2, gt3, gt4 = part_info['gt']
150 | score1 = self.ssim(pred1, gt1)
151 | score2 = self.ssim(pred2, gt2)
152 | score3 = self.ssim(pred3, gt3)
153 | score4 = self.ssim(pred4, gt4)
154 |
155 | return w1 * score1 + w2 * score2 + w3 * score3 + w4 * score4
156 |
157 | def centroid(self, matrix: np.ndarray) -> tuple:
158 | h, w = matrix.shape
159 | if matrix.sum() == 0:
160 | x = np.round(w / 2)
161 | y = np.round(h / 2)
162 | else:
163 | area_object = np.sum(matrix)
164 | row_ids = np.arange(h)
165 | col_ids = np.arange(w)
166 | x = np.round(np.sum(np.sum(matrix, axis=0) * col_ids) / area_object)
167 | y = np.round(np.sum(np.sum(matrix, axis=1) * row_ids) / area_object)
168 | return int(x) + 1, int(y) + 1
169 |
170 | def divide_with_xy(self, pred: np.ndarray, gt: np.ndarray, x, y) -> dict:
171 | h, w = gt.shape
172 | area = h * w
173 |
174 | gt_LT = gt[0:y, 0:x]
175 | gt_RT = gt[0:y, x:w]
176 | gt_LB = gt[y:h, 0:x]
177 | gt_RB = gt[y:h, x:w]
178 |
179 | pred_LT = pred[0:y, 0:x]
180 | pred_RT = pred[0:y, x:w]
181 | pred_LB = pred[y:h, 0:x]
182 | pred_RB = pred[y:h, x:w]
183 |
184 | w1 = x * y / area
185 | w2 = y * (w - x) / area
186 | w3 = (h - y) * x / area
187 | w4 = 1 - w1 - w2 - w3
188 |
189 | return dict(gt=(gt_LT, gt_RT, gt_LB, gt_RB),
190 | pred=(pred_LT, pred_RT, pred_LB, pred_RB),
191 | weight=(w1, w2, w3, w4))
192 |
193 | def ssim(self, pred: np.ndarray, gt: np.ndarray) -> float:
194 | h, w = pred.shape
195 | N = h * w
196 |
197 | x = np.mean(pred)
198 | y = np.mean(gt)
199 |
200 | sigma_x = np.sum((pred - x) ** 2) / (N - 1)
201 | sigma_y = np.sum((gt - y) ** 2) / (N - 1)
202 | sigma_xy = np.sum((pred - x) * (gt - y)) / (N - 1)
203 |
204 | alpha = 4 * x * y * sigma_xy
205 | beta = (x ** 2 + y ** 2) * (sigma_x + sigma_y)
206 |
207 | if alpha != 0:
208 | score = alpha / (beta + _EPS)
209 | elif alpha == 0 and beta == 0:
210 | score = 1
211 | else:
212 | score = 0
213 | return score
214 |
215 | def get_results(self) -> dict:
216 | sm = np.mean(np.array(self.sms, dtype=_TYPE))
217 | return dict(sm=sm)
218 |
219 |
220 | class Emeasure(object):
221 | def __init__(self):
222 | self.adaptive_ems = []
223 | self.changeable_ems = []
224 |
225 | def step(self, pred: np.ndarray, gt: np.ndarray):
226 | pred, gt = _prepare_data(pred=pred, gt=gt)
227 | self.gt_fg_numel = np.count_nonzero(gt)
228 | self.gt_size = gt.shape[0] * gt.shape[1]
229 |
230 | changeable_ems = self.cal_changeable_em(pred, gt)
231 | self.changeable_ems.append(changeable_ems)
232 | adaptive_em = self.cal_adaptive_em(pred, gt)
233 | self.adaptive_ems.append(adaptive_em)
234 |
235 | def cal_adaptive_em(self, pred: np.ndarray, gt: np.ndarray) -> float:
236 | adaptive_threshold = _get_adaptive_threshold(pred, max_value=1)
237 | adaptive_em = self.cal_em_with_threshold(pred, gt, threshold=adaptive_threshold)
238 | return adaptive_em
239 |
240 | def cal_changeable_em(self, pred: np.ndarray, gt: np.ndarray) -> np.ndarray:
241 | changeable_ems = self.cal_em_with_cumsumhistogram(pred, gt)
242 | return changeable_ems
243 |
244 | def cal_em_with_threshold(self, pred: np.ndarray, gt: np.ndarray, threshold: float) -> float:
245 | binarized_pred = pred >= threshold
246 | fg_fg_numel = np.count_nonzero(binarized_pred & gt)
247 | fg_bg_numel = np.count_nonzero(binarized_pred & ~gt)
248 |
249 | fg___numel = fg_fg_numel + fg_bg_numel
250 | bg___numel = self.gt_size - fg___numel
251 |
252 | if self.gt_fg_numel == 0:
253 | enhanced_matrix_sum = bg___numel
254 | elif self.gt_fg_numel == self.gt_size:
255 | enhanced_matrix_sum = fg___numel
256 | else:
257 | parts_numel, combinations = self.generate_parts_numel_combinations(
258 | fg_fg_numel=fg_fg_numel, fg_bg_numel=fg_bg_numel,
259 | pred_fg_numel=fg___numel, pred_bg_numel=bg___numel,
260 | )
261 |
262 | results_parts = []
263 | for i, (part_numel, combination) in enumerate(zip(parts_numel, combinations)):
264 | align_matrix_value = 2 * (combination[0] * combination[1]) / \
265 | (combination[0] ** 2 + combination[1] ** 2 + _EPS)
266 | enhanced_matrix_value = (align_matrix_value + 1) ** 2 / 4
267 | results_parts.append(enhanced_matrix_value * part_numel)
268 | enhanced_matrix_sum = sum(results_parts)
269 |
270 | em = enhanced_matrix_sum / (self.gt_size - 1 + _EPS)
271 | return em
272 |
273 | def cal_em_with_cumsumhistogram(self, pred: np.ndarray, gt: np.ndarray) -> np.ndarray:
274 | pred = (pred * 255).astype(np.uint8)
275 | bins = np.linspace(0, 256, 257)
276 | fg_fg_hist, _ = np.histogram(pred[gt], bins=bins)
277 | fg_bg_hist, _ = np.histogram(pred[~gt], bins=bins)
278 | fg_fg_numel_w_thrs = np.cumsum(np.flip(fg_fg_hist), axis=0)
279 | fg_bg_numel_w_thrs = np.cumsum(np.flip(fg_bg_hist), axis=0)
280 |
281 | fg___numel_w_thrs = fg_fg_numel_w_thrs + fg_bg_numel_w_thrs
282 | bg___numel_w_thrs = self.gt_size - fg___numel_w_thrs
283 |
284 | if self.gt_fg_numel == 0:
285 | enhanced_matrix_sum = bg___numel_w_thrs
286 | elif self.gt_fg_numel == self.gt_size:
287 | enhanced_matrix_sum = fg___numel_w_thrs
288 | else:
289 | parts_numel_w_thrs, combinations = self.generate_parts_numel_combinations(
290 | fg_fg_numel=fg_fg_numel_w_thrs, fg_bg_numel=fg_bg_numel_w_thrs,
291 | pred_fg_numel=fg___numel_w_thrs, pred_bg_numel=bg___numel_w_thrs,
292 | )
293 |
294 | results_parts = np.empty(shape=(4, 256), dtype=np.float64)
295 | for i, (part_numel, combination) in enumerate(zip(parts_numel_w_thrs, combinations)):
296 | align_matrix_value = 2 * (combination[0] * combination[1]) / \
297 | (combination[0] ** 2 + combination[1] ** 2 + _EPS)
298 | enhanced_matrix_value = (align_matrix_value + 1) ** 2 / 4
299 | results_parts[i] = enhanced_matrix_value * part_numel
300 | enhanced_matrix_sum = results_parts.sum(axis=0)
301 |
302 | em = enhanced_matrix_sum / (self.gt_size - 1 + _EPS)
303 | return em
304 |
305 | def generate_parts_numel_combinations(self, fg_fg_numel, fg_bg_numel, pred_fg_numel, pred_bg_numel):
306 | bg_fg_numel = self.gt_fg_numel - fg_fg_numel
307 | bg_bg_numel = pred_bg_numel - bg_fg_numel
308 |
309 | parts_numel = [fg_fg_numel, fg_bg_numel, bg_fg_numel, bg_bg_numel]
310 |
311 | mean_pred_value = pred_fg_numel / self.gt_size
312 | mean_gt_value = self.gt_fg_numel / self.gt_size
313 |
314 | demeaned_pred_fg_value = 1 - mean_pred_value
315 | demeaned_pred_bg_value = 0 - mean_pred_value
316 | demeaned_gt_fg_value = 1 - mean_gt_value
317 | demeaned_gt_bg_value = 0 - mean_gt_value
318 |
319 | combinations = [
320 | (demeaned_pred_fg_value, demeaned_gt_fg_value),
321 | (demeaned_pred_fg_value, demeaned_gt_bg_value),
322 | (demeaned_pred_bg_value, demeaned_gt_fg_value),
323 | (demeaned_pred_bg_value, demeaned_gt_bg_value)
324 | ]
325 | return parts_numel, combinations
326 |
327 | def get_results(self) -> dict:
328 | adaptive_em = np.mean(np.array(self.adaptive_ems, dtype=_TYPE))
329 | changeable_em = np.mean(np.array(self.changeable_ems, dtype=_TYPE), axis=0)
330 | return dict(em=dict(adp=adaptive_em, curve=changeable_em))
331 |
332 |
333 | class WeightedFmeasure(object):
334 | def __init__(self, beta: float = 1):
335 | self.beta = beta
336 | self.weighted_fms = []
337 |
338 | def step(self, pred: np.ndarray, gt: np.ndarray):
339 | pred, gt = _prepare_data(pred=pred, gt=gt)
340 |
341 | if np.all(~gt):
342 | wfm = 0
343 | else:
344 | wfm = self.cal_wfm(pred, gt)
345 | self.weighted_fms.append(wfm)
346 |
347 | def cal_wfm(self, pred: np.ndarray, gt: np.ndarray) -> float:
348 | # [Dst,IDXT] = bwdist(dGT);
349 | Dst, Idxt = bwdist(gt == 0, return_indices=True)
350 |
351 | # %Pixel dependency
352 | # E = abs(FG-dGT);
353 | E = np.abs(pred - gt)
354 | Et = np.copy(E)
355 | Et[gt == 0] = Et[Idxt[0][gt == 0], Idxt[1][gt == 0]]
356 |
357 | # K = fspecial('gaussian',7,5);
358 | # EA = imfilter(Et,K);
359 | K = self.matlab_style_gauss2D((7, 7), sigma=5)
360 | EA = convolve(Et, weights=K, mode="constant", cval=0)
361 | # MIN_E_EA = E;
362 |         # MIN_E_EA(GT & EA<E) = EA(GT & EA<E);
363 |         MIN_E_EA = np.where(gt & (EA < E), EA, E)
364 |
365 |         # %Pixel importance
366 |         # B(~GT) = 2-1*exp(log(1-0.5)/5.*Dst(~GT)); Ew = MIN_E_EA.*B;
367 |         B = np.where(gt == 0, 2 - np.exp(np.log(0.5) / 5 * Dst), np.ones_like(gt))
368 |         Ew = MIN_E_EA * B
369 |
370 |         # TPw = sum(dGT(:)) - sum(sum(Ew(GT))); FPw = sum(sum(Ew(~GT)));
371 |         TPw = np.sum(gt) - np.sum(Ew[gt == 1])
372 |         FPw = np.sum(Ew[gt == 0])
373 |
374 |         # R = 1 - mean2(Ew(GT)); P = TPw./(eps+TPw+FPw);
375 |         R = 1 - np.mean(Ew[gt == 1])
376 |         P = TPw / (TPw + FPw + _EPS)
377 |         # Q = (1+Beta^2)*(R*P)./(eps+R+(Beta.*P));
378 |         Q = (1 + self.beta) * R * P / (R + self.beta * P + _EPS)
379 |         return Q
380 |
381 |     def matlab_style_gauss2D(self, shape: tuple = (7, 7), sigma: int = 5) -> np.ndarray:
382 | """
383 | 2D gaussian mask - should give the same result as MATLAB's
384 | fspecial('gaussian',[shape],[sigma])
385 | """
386 | m, n = [(ss - 1) / 2 for ss in shape]
387 | y, x = np.ogrid[-m: m + 1, -n: n + 1]
388 | h = np.exp(-(x * x + y * y) / (2 * sigma * sigma))
389 | h[h < np.finfo(h.dtype).eps * h.max()] = 0
390 | sumh = h.sum()
391 | if sumh != 0:
392 | h /= sumh
393 | return h
394 |
395 | def get_results(self) -> dict:
396 | weighted_fm = np.mean(np.array(self.weighted_fms, dtype=_TYPE))
397 | return dict(wfm=weighted_fm)
398 |
399 |
400 | class DICE(object):
401 | def __init__(self):
402 | self.dice = []
403 |
404 | def step(self, pred: np.ndarray, gt: np.ndarray):
405 | # pred, gt = _prepare_data(pred=pred, gt=gt)
406 | dice = self.cal_dice(pred, gt)
407 | self.dice.append(dice)
408 | return dice
409 |
410 | def cal_dice(self, pred: np.ndarray, gt: np.ndarray):
411 |         # smoothing term from the common PyTorch Dice-loss formulation
412 |         smooth = 1
413 |
414 | pred_flat = pred.reshape(-1)
415 | gt_flat = gt.reshape(-1)
416 |
417 | intersection = pred_flat * gt_flat
418 |
419 |         dice = 2 * (intersection.sum() + smooth) / (pred_flat.sum() + gt_flat.sum() + smooth)
420 |         dice = 1 - dice  # NOTE: this is the Dice distance (1 - Dice coefficient)
421 |
422 | return dice
423 |
424 | def get_results(self):
425 | dice = np.mean(np.array(self.dice, dtype=_TYPE))
426 | return dice
427 |
428 |
429 | class BinarizedF(Function):
430 | '''
431 | @ Reference: https://blog.csdn.net/weixin_42696356/article/details/100899711
432 | '''
433 | @staticmethod
434 | def forward(ctx, input):
435 | ctx.save_for_backward(input)
436 | a = torch.ones_like(input)
437 | b = torch.zeros_like(input)
438 |         output = torch.where(input >= 0.5, a, b)
439 | return output
440 |
441 | @staticmethod
442 | def backward(ctx, output_grad):
443 | input, = ctx.saved_tensors
444 | input_abs = torch.abs(input)
445 | ones = torch.ones_like(input)
446 | zeros = torch.zeros_like(input)
447 |         input_grad = output_grad * torch.where(input_abs <= 1, ones, zeros)  # straight-through estimator
448 | return input_grad
449 |
450 |
451 | class BinarizedModule(nn.Module):
452 | '''
453 | @ Reference: https://www.flyai.com/article/art7714fcddbf30a9ff5a35633f?type=e
454 | '''
455 | def __init__(self):
456 | super(BinarizedModule, self).__init__()
457 | self.BF = BinarizedF()
458 |
459 | def forward(self, input):
460 |         output = self.BF.apply(torch.Tensor(input))
461 | return output
462 |
463 |
464 | class IoU(object):
465 | def __init__(self):
466 | self.iou = []
467 | self.n_classes = 2
468 | self.bin = BinarizedModule()
469 |
470 | def step(self, pred: np.ndarray, gt: np.ndarray):
471 | iou = self.cal_iou(pred, gt)
472 | self.iou.append(iou)
473 | return iou
474 |
475 | def _cal_iou(self, pred: np.ndarray, gt: np.ndarray):
476 | def cal_cm(y_true, y_pred):
477 | y_true = y_true.reshape(1, -1).squeeze()
478 | y_pred = y_pred.reshape(1, -1).squeeze()
479 | cm = metrics.confusion_matrix(y_true, y_pred)
480 | return cm
481 | pred = self.bin(pred)
482 | confusion_matrix = cal_cm(pred, gt)
483 |         intersection = np.diag(confusion_matrix)  # intersection
484 |         union = np.sum(confusion_matrix, axis=1) + np.sum(confusion_matrix, axis=0) - np.diag(confusion_matrix)  # union
485 |         IoU = intersection / union  # intersection over union
486 | return IoU
487 |
488 | def cal_iou(self, pred: np.ndarray, target: np.ndarray):
489 | Iand1 = np.sum(target * pred)
490 | Ior1 = np.sum(target) + np.sum(pred) - Iand1
491 |         IoU1 = Iand1 / (Ior1 + _EPS)  # _EPS guards the empty pred & gt case
492 | return IoU1
493 |
494 | def get_results(self):
495 | iou = np.mean(np.array(self.iou, dtype=_TYPE))
496 | return iou
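497 |
498 |
499 | # A minimal usage sketch (illustrative assumption, mirroring how the eval scripts
500 | # in this toolbox drive the metrics): each class above follows the same
501 | # step()/get_results() protocol over prediction/ground-truth pairs.
502 | #
503 | #     wfm = WeightedFmeasure()
504 | #     for pred, gt in pred_gt_pairs:  # grayscale arrays; step() normalizes them
505 | #         wfm.step(pred=pred, gt=gt)
506 | #     print(wfm.get_results()["wfm"])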
--------------------------------------------------------------------------------
/cos_eval_toolbox/tools/.backup/ranking_per_images.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import individual_metrics as Measure
4 |
5 | model_list = ['2023-arXiv-CamoFormerS', '2023-arXiv-CamoFormerP', '2023-AAAI-HitNet', '2023-arXiv-PopNet', '2022-CVPR-FDNet']
6 |
7 | pred_root = '/home/users/u7248002/project/PySODEvalToolkit/cos_benchmark/COS-Benchmarking'
8 | gt_root = '/home/users/u7248002/project/PySODEvalToolkit/cos_benchmark/TestDataset/COD10K/GT'
9 |
10 | # NOTE: fresh metric objects are created per (image, model) pair inside the loop
11 | # below; sharing a single instance would accumulate scores across models and
12 | # images, turning the per-image scores into running averages.
13 |
14 | for file_name in os.listdir(gt_root):
15 | s_list, e_list, r_mae_list, total_list = [], [], [], []
16 | t_counter = 0
17 | for model_name in model_list:
18 |         pred_pth = os.path.join(pred_root, model_name, 'COD10K', file_name)
19 |         gt_pth = os.path.join(gt_root, file_name)
20 |         SM, EM, MAE = Measure.Smeasure(), Measure.Emeasure(), Measure.MAE()
21 | assert os.path.isfile(gt_pth) and os.path.isfile(pred_pth)
22 | pred_ary = cv2.imread(pred_pth, cv2.IMREAD_GRAYSCALE)
23 | gt_ary = cv2.imread(gt_pth, cv2.IMREAD_GRAYSCALE)
24 |
25 | assert len(pred_ary.shape) == 2 and len(gt_ary.shape) == 2
26 | if pred_ary.shape != gt_ary.shape:
27 |             pred_ary = cv2.resize(pred_ary, (gt_ary.shape[1], gt_ary.shape[0]), interpolation=cv2.INTER_NEAREST)
28 |
29 | SM.step(pred=pred_ary, gt=gt_ary)
30 | EM.step(pred=pred_ary, gt=gt_ary)
31 | MAE.step(pred=pred_ary, gt=gt_ary)
32 |
33 | sm = SM.get_results()['sm']
34 | em = EM.get_results()['em']['curve'].max()
35 | r_mae = 1 - MAE.get_results()['mae']
36 |
37 |         total_score = sm + em + r_mae
38 |
39 | s_list.append(sm)
40 | e_list.append(em)
41 | r_mae_list.append(r_mae)
42 |         total_list.append(total_score)
43 |         t_counter += total_score
44 |
45 | print(f'{file_name} | {total_list[0]} | {total_list[1]} | {total_list[2]} | {total_list[3]} | {total_list[4]} | {t_counter}')
--------------------------------------------------------------------------------
/cos_eval_toolbox/tools/append_results.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import argparse
3 |
4 | import numpy as np
5 |
6 |
7 | def get_args():
8 | parser = argparse.ArgumentParser(
9 |         description="""A simple tool for merging two npy files.
10 | Patch the method items corresponding to the `--method-names` and `--dataset-names` of `--new-npy` into `--old-npy`,
11 | and output the whole container to `--out-npy`.
12 | """
13 | )
14 | parser.add_argument("--old-npy", type=str, required=True)
15 | parser.add_argument("--new-npy", type=str, required=True)
16 | parser.add_argument("--method-names", type=str, nargs="+")
17 | parser.add_argument("--dataset-names", type=str, nargs="+")
18 | parser.add_argument("--out-npy", type=str, required=True)
19 | args = parser.parse_args()
20 | return args
21 |
22 |
23 | def main():
24 | args = get_args()
25 | new_npy: dict = np.load(args.new_npy, allow_pickle=True).item()
26 | old_npy: dict = np.load(args.old_npy, allow_pickle=True).item()
27 |
28 | for dataset_name, methods_info in new_npy.items():
29 | if args.dataset_names and dataset_name not in args.dataset_names:
30 | continue
31 |
32 | print(f"[PROCESSING INFORMATION ABOUT DATASET {dataset_name}...]")
33 | old_methods_info = old_npy.get(dataset_name)
34 | if not old_methods_info:
35 |             raise KeyError(f"{args.old_npy} doesn't contain the information about {dataset_name}.")
36 |
37 | print(f"OLD_NPY: {list(old_methods_info.keys())}")
38 | print(f"NEW_NPY: {list(methods_info.keys())}")
39 |
40 | for method_name, method_info in methods_info.items():
41 | if args.method_names and method_name not in args.method_names:
42 | continue
43 |
44 | if method_name not in old_npy[dataset_name]:
45 | old_methods_info[method_name] = method_info
46 | print(f"MERGED_NPY: {list(old_methods_info.keys())}")
47 |
48 | np.save(file=args.out_npy, arr=old_npy)
49 | print(f"THE MERGED_NPY IS SAVED INTO {args.out_npy}")
50 |
51 |
52 | if __name__ == "__main__":
53 | main()
54 |
--------------------------------------------------------------------------------
/cos_eval_toolbox/tools/cal_avg_resolution.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2023/4/22
3 | # @Author : Daniel Ji
4 |
5 | import os
6 | from PIL import Image
7 |
8 |
9 | def cal_avg_res():
10 | """
11 |     This function calculates the average resolution (H & W) of the whole dataset.
12 | """
13 | root = './cos_benchmark/TestDataset'
14 |
15 | for data_name in os.listdir(root):
16 | data_root = os.path.join(root, data_name, 'GT')
17 | H_avg, W_avg, count = 0, 0, 0
18 |
19 | for file_name in os.listdir(data_root):
20 | img_path = os.path.join(data_root, file_name)
21 | img = Image.open(img_path)
22 | size = img.size
23 |             W_avg += size[0]  # PIL's Image.size is (width, height)
24 |             H_avg += size[1]
25 | count += 1
26 |
27 | print(f'{data_name}, H-average is {H_avg/count} and W-average is {W_avg/count}')
28 |
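29 | # Invoke cal_avg_res() when the script is run directly (it is only defined above;
30 | # this assumes the hard-coded `root` folder exists).
31 | if __name__ == "__main__":
32 |     cal_avg_res()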
--------------------------------------------------------------------------------
/cos_eval_toolbox/tools/check_path.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import argparse
4 | import json
5 | import os
6 | from collections import OrderedDict
7 |
8 | parser = argparse.ArgumentParser(description="A simple tool for checking your json config file.")
9 | parser.add_argument(
10 | "-m", "--method-jsons", nargs="+", required=True, help="The json file about all methods."
11 | )
12 | parser.add_argument(
13 | "-d", "--dataset-jsons", nargs="+", required=True, help="The json file about all datasets."
14 | )
15 | args = parser.parse_args()
16 |
17 | for method_json, dataset_json in zip(args.method_jsons, args.dataset_jsons):
18 | with open(method_json, encoding="utf-8", mode="r") as f:
19 |         methods_info = json.load(f, object_hook=OrderedDict)  # ordered load
20 | with open(dataset_json, encoding="utf-8", mode="r") as f:
21 |         datasets_info = json.load(f, object_hook=OrderedDict)  # ordered load
22 |
23 | total_msgs = []
24 | for method_name, method_info in methods_info.items():
25 | print(f"Checking for {method_name} ...")
26 | for dataset_name, results_info in method_info.items():
27 | if results_info is None:
28 | continue
29 |
30 | dataset_mask_info = datasets_info[dataset_name]["mask"]
31 | mask_path = dataset_mask_info["path"]
32 | mask_suffix = dataset_mask_info["suffix"]
33 |
34 | dir_path = results_info["path"]
35 | file_prefix = results_info.get("prefix", "")
36 | file_suffix = results_info["suffix"]
37 |
38 | if not os.path.exists(dir_path):
39 |                 total_msgs.append(f"{dir_path} does not exist")
40 | continue
41 | elif not os.path.isdir(dir_path):
42 |                 total_msgs.append(f"{dir_path} is not a valid directory")
43 | continue
44 | else:
45 | pred_names = [
46 | name[len(file_prefix) : -len(file_suffix)]
47 | for name in os.listdir(dir_path)
48 | if name.startswith(file_prefix) and name.endswith(file_suffix)
49 | ]
50 | if len(pred_names) == 0:
51 |                     total_msgs.append(f"{dir_path} contains no file with prefix {file_prefix} and suffix {file_suffix}")
52 | continue
53 |
54 | mask_names = [
55 | name[: -len(mask_suffix)]
56 | for name in os.listdir(mask_path)
57 | if name.endswith(mask_suffix)
58 | ]
59 | intersection_names = set(mask_names).intersection(set(pred_names))
60 | if len(intersection_names) == 0:
61 |                 total_msgs.append(f"the file names in {dir_path} do not match the ground truth in {mask_path}")
62 |             elif len(intersection_names) != len(mask_names):
63 |                 difference_names = set(mask_names).difference(pred_names)
64 |                 total_msgs.append(
65 |                     f"{dir_path} has {len(pred_names)} predictions but {len(mask_names)} ground-truth masks; missing: {difference_names}"
66 |                 )
67 |
68 | if total_msgs:
69 | print(*total_msgs, sep="\n")
70 | else:
71 |         print(f"{method_json} & {dataset_json} look fine")
72 |
--------------------------------------------------------------------------------
/cos_eval_toolbox/tools/converter.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2021/8/25
3 | # @Author : Lart Pang
4 | # @GitHub : https://github.com/lartpang
5 |
6 | import argparse
7 | from itertools import chain
8 |
9 | import numpy as np
10 | import yaml
11 |
12 | parser = argparse.ArgumentParser(
13 |     description="A useful and convenient tool to convert your .npy results into LaTeX table code."
14 | )
15 | parser.add_argument(
16 | "-i",
17 | "--result-file",
18 | required=True,
19 | nargs="+",
20 | action="extend",
21 | help="The path of the *_metrics.npy file.",
22 | )
23 | parser.add_argument(
24 | "-o", "--tex-file", required=True, type=str, help="The path of the exported tex file."
25 | )
26 | parser.add_argument(
27 | "-c", "--config-file", type=str, help="The path of the customized config yaml file."
28 | )
29 | parser.add_argument(
30 | "--contain-table-env",
31 | action="store_true",
32 |     help="Whether to contain the table env in the exported code.",
33 | )
34 | parser.add_argument(
35 | "--transpose",
36 | action="store_true",
37 | help="Whether to transpose the table.",
38 | )
39 | args = parser.parse_args()
40 |
41 |
42 | def update_dict(parent_dict, sub_dict):
43 | for sub_k, sub_v in sub_dict.items():
44 | if sub_k in parent_dict:
45 | if sub_v is not None and isinstance(sub_v, dict):
46 | update_dict(parent_dict=parent_dict[sub_k], sub_dict=sub_v)
47 | continue
48 | parent_dict.update(sub_dict)
49 |
50 |
51 | results = {}
52 | for result_file in args.result_file:
53 | result = np.load(file=result_file, allow_pickle=True).item()
54 | for dataset_name, method_infos in result.items():
55 | results.setdefault(dataset_name, {})
56 | for method_name, method_info in method_infos.items():
57 | results[dataset_name][method_name] = method_info
58 |
59 | IMPOSSIBLE_UP_BOUND = 1
60 | IMPOSSIBLE_DOWN_BOUND = 0
61 |
62 | # Load data
63 | dataset_names = sorted(list(results.keys()))
64 | metric_names = ["SM", "wFm", "MAE", "adpE", "avgE", "maxE", "adpF", "avgF", "maxF"]
65 | method_names = sorted(list(set(chain(*[list(results[n].keys()) for n in dataset_names]))))
66 |
67 | if args.config_file is not None:
68 | assert args.config_file.endswith(".yaml") or args.config_file.endswith("yml")
69 | with open(args.config_file, mode="r", encoding="utf-8") as f:
70 | cfg = yaml.safe_load(f)
71 |
72 | if "dataset_names" not in cfg:
73 | print(
74 |             "`dataset_names` is not contained in your config file, so the default config is used."
75 | )
76 | else:
77 | dataset_names = cfg["dataset_names"]
78 | if "metric_names" not in cfg:
79 | print(
80 |             "`metric_names` is not contained in your config file, so the default config is used."
81 | )
82 | else:
83 | metric_names = cfg["metric_names"]
84 | if "method_names" not in cfg:
85 | print(
86 |             "`method_names` is not contained in your config file, so the default config is used."
87 | )
88 | else:
89 | method_names = cfg["method_names"]
90 |
91 | print(
92 | f"CONFIG INFORMATION:"
93 |     f"\n- DATASETS ({len(dataset_names)}): {dataset_names}"
94 | f"\n- METRICS ({len(metric_names)}): {metric_names}"
95 | f"\n- METHODS ({len(method_names)}): {method_names}"
96 | )
97 |
98 | if isinstance(metric_names, (list, tuple)):
99 | ori_metric_names = metric_names
100 | elif isinstance(metric_names, dict):
101 | ori_metric_names, metric_names = list(zip(*list(metric_names.items())))
102 | else:
103 | raise NotImplementedError
104 |
105 | if isinstance(method_names, (list, tuple)):
106 | ori_method_names = method_names
107 | elif isinstance(method_names, dict):
108 | ori_method_names, method_names = list(zip(*list(method_names.items())))
109 | else:
110 | raise NotImplementedError
111 |
112 | # Assemble the table
113 | ori_columns = []
114 | column_for_index = []
115 | for dataset_idx, dataset_name in enumerate(dataset_names):
116 | for metric_idx, ori_metric_name in enumerate(ori_metric_names):
117 |         filled_value = (
118 |             IMPOSSIBLE_UP_BOUND if ori_metric_name.lower() == "mae" else IMPOSSIBLE_DOWN_BOUND
119 |         )
120 |         filled_dict = {k: filled_value for k in ori_metric_names}
121 |         ori_column = []
122 |         for method_name in ori_method_names:
123 |             method_result = results[dataset_name].get(method_name, filled_dict)
124 | if ori_metric_name not in method_result:
125 | raise KeyError(
126 | f"{ori_metric_name} must be contained in {list(method_result.keys())}"
127 | )
128 | ori_column.append(method_result[ori_metric_name])
129 |
130 | column_for_index.append([x * round(1 - fiiled_value * 2) for x in ori_column])
131 | ori_columns.append(ori_column)
132 |
133 | style_templates = dict(
134 | method_row_body="& {method_name}",
135 | method_column_body=" {method_name}",
136 |     dataset_row_body="& \\multicolumn{{{num_metrics}}}{{c}}{{\\textbf{{{dataset_name}}}}}",
137 |     dataset_column_body="\\multirow{{-{num_metrics}}}{{*}}{{\\rotatebox{{90}}{{\\textbf{{{dataset_name}}}}}}}",
138 | dataset_head=" ",
139 | metric_body="& {metric_name}",
140 | metric_row_head=" ",
141 | metric_column_head="& ",
142 | body=[
143 | "& {{\color{{reda}} \\textbf{{{txt:.03f}}}}}", # top1
144 | "& {{\color{{mygreen}} \\textbf{{{txt:.03f}}}}}", # top2
145 | "& {{\color{{myblue}} \\textbf{{{txt:.03f}}}}}", # top3
146 | "& {txt:.03f}", # other
147 | ],
148 | )
149 |
150 |
151 | # Sort and apply styles
152 | def replace_cell(ori_value, k):
153 | if ori_value == IMPOSSIBLE_UP_BOUND or ori_value == IMPOSSIBLE_DOWN_BOUND:
154 | new_value = "& "
155 | else:
156 | new_value = style_templates["body"][k].format(txt=ori_value)
157 | return new_value
158 |
159 |
160 | for col, ori_col in zip(column_for_index, ori_columns):
161 | col_array = np.array(col).reshape(-1)
162 | sorted_col_array = np.sort(np.unique(col_array), axis=-1)[-3:][::-1]
163 | # [top1_idxes, top2_idxes, top3_idxes]
164 | top_k_idxes = [np.argwhere(col_array == x).tolist() for x in sorted_col_array]
165 | for k, idxes in enumerate(top_k_idxes):
166 | for row_idx in idxes:
167 | ori_col[row_idx[0]] = replace_cell(ori_col[row_idx[0]], k)
168 |
169 | for idx, x in enumerate(ori_col):
170 | if not isinstance(x, str):
171 | ori_col[idx] = replace_cell(x, -1)
172 |
173 | # Build the table header
174 | num_datasets = len(dataset_names)
175 | num_metrics = len(metric_names)
176 | num_methods = len(method_names)
177 |
178 | # Build the leading columns first, then assemble the leading rows as a whole
179 | latex_table_head = []
180 | latex_table_tail = []
181 |
182 | if not args.transpose:
183 | dataset_row = (
184 | [style_templates["dataset_head"]]
185 | + [
186 | style_templates["dataset_row_body"].format(num_metrics=num_metrics, dataset_name=x)
187 | for x in dataset_names
188 | ]
189 | + [r"\\"]
190 | )
191 | metric_row = (
192 | [style_templates["metric_row_head"]]
193 | + [style_templates["metric_body"].format(metric_name=x) for x in metric_names]
194 | * num_datasets
195 | + [r"\\"]
196 | )
197 | additional_rows = [dataset_row, metric_row]
198 |
199 |     # Build the first column
200 | method_column = [
201 | style_templates["method_column_body"].format(method_name=x) for x in method_names
202 | ]
203 | additional_columns = [method_column]
204 |
205 | columns = additional_columns + ori_columns
206 | rows = [list(row) + [r"\\"] for row in zip(*columns)]
207 | rows = additional_rows + rows
208 |
209 | if args.contain_table_env:
210 | column_style = "|".join([f"*{num_metrics}{{c}}"] * len(dataset_names))
211 | latex_table_head = [
212 | f"\\begin{{tabular}}{{l|{column_style}}}\n",
213 | "\\toprule[2pt]",
214 | ]
215 | else:
216 | dataset_column = []
217 | for x in dataset_names:
218 | blank_cells = [" "] * (num_metrics - 1)
219 | dataset_cell = [
220 | style_templates["dataset_column_body"].format(num_metrics=num_metrics, dataset_name=x)
221 | ]
222 | dataset_column.extend(blank_cells + dataset_cell)
223 | metric_column = [
224 | style_templates["metric_body"].format(metric_name=x) for x in metric_names
225 | ] * num_datasets
226 | additional_columns = [dataset_column, metric_column]
227 |
228 | method_row = (
229 | [style_templates["dataset_head"], style_templates["metric_column_head"]]
230 | + [style_templates["method_row_body"].format(method_name=x) for x in method_names]
231 | + [r"\\"]
232 | )
233 | additional_rows = [method_row]
234 |
235 | additional_columns = [list(x) for x in zip(*additional_columns)]
236 | rows = [cells + row + [r"\\"] for cells, row in zip(additional_columns, ori_columns)]
237 | rows = additional_rows + rows
238 |
239 | if args.contain_table_env:
240 | column_style = "".join([f"*{{{num_methods}}}{{c}}"])
241 | latex_table_head = [
242 | f"\\begin{{tabular}}{{cc|{column_style}}}\n",
243 | "\\toprule[2pt]",
244 | ]
245 |
246 | if args.contain_table_env:
247 | latex_table_tail = [
248 | "\\bottomrule[2pt]\n",
249 | "\\end{tabular}",
250 | ]
251 |
252 | rows = [latex_table_head] + rows + [latex_table_tail]
253 |
254 | with open(args.tex_file, mode="w", encoding="utf-8") as f:
255 | for row in rows:
256 | f.write("".join(row) + "\n")
257 |
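258 |
259 | # An illustrative converter_config.yaml (hypothetical values; only the keys are
260 | # implied by the cfg["dataset_names"] / cfg["metric_names"] / cfg["method_names"]
261 | # lookups above, where metric_names and method_names may also be dicts mapping
262 | # original names to display names):
263 | #
264 | #     dataset_names: [CAMO, COD10K, NC4K]
265 | #     metric_names: [SM, wFm, MAE, maxE, maxF]
266 | #     method_names: [2022-CVPR-FDNet, 2023-AAAI-HitNet]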
--------------------------------------------------------------------------------
/cos_eval_toolbox/tools/generate_cos_config_files.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2023/4/22
3 | # @Author : Daniel Ji
4 |
5 | import os
6 |
7 | def check_integrity():
8 | """
9 | Check the matching status of prediction and ground-truth masks
10 | """
11 | pred_path = './benchmark/COS-Benchmarking'
12 | gt_path = './dataset'
13 |
14 | def returnNotMatches(a, b):
15 | return [[x for x in a if x not in b], [x for x in b if x not in a]]
16 |
17 | for model_name in os.listdir(pred_path):
18 | for data_name in ['CAMO', 'COD10K', 'NC4K']:
19 | model_data_path = os.path.join(pred_path, model_name, data_name)
20 | gt_data_path = os.path.join(gt_path, data_name, 'GT')
21 |
22 | if os.path.exists(model_data_path):
23 | if not os.listdir(model_data_path) == os.listdir(gt_data_path):
24 | info = returnNotMatches(os.listdir(model_data_path), os.listdir(gt_data_path))
25 | print('not match', model_name, data_name)
26 | print(info)
27 | else:
28 | print('not exist', model_name, data_name)
29 |
30 |
31 | def generate_model_config_py():
32 | root = './benchmark/COS-Benchmarking'
33 | write_txt_path = './examples_COS'
34 | os.makedirs(write_txt_path, exist_ok=True)
35 |
36 | # Open the file in write mode
37 | with open(f'{write_txt_path}/config_cos_method_py_example_all.py', 'w') as f:
38 | # Write some data to the file
39 | f.write('# -*- coding: utf-8 -*-\n')
40 | f.write('import os\n\n')
41 |
42 | for model_name in os.listdir(root):
43 | model_path = os.path.join(root, model_name)
44 | print('{}_root = \"{}\"'.format(model_name.split('-')[-1], model_path))
45 | f.write('{}_root = \"{}\" \n'.format(model_name.split('-')[-1], model_path))
46 | print('{} = {{'.format(model_name.split('-')[-1]))
47 | f.write('{} = {{ \n'.format(model_name.split('-')[-1]))
48 |
49 | for data_name in ['CAMO', 'NC4K', 'COD10K']:
50 | model_data_path = os.path.join(model_path, data_name)
51 | if os.path.exists(model_data_path):
52 | print('\t\"{}\": dict(path=os.path.join({}_root, \"{}\"), suffix=".png"),'.format(data_name, model_name.split('-')[-1], data_name))
53 | f.write('\t\"{}\": dict(path=os.path.join({}_root, \"{}\"), suffix=".png"), \n'.format(data_name, model_name.split('-')[-1], data_name))
54 | else:
55 | print('\t\"{}\": None,'.format(data_name))
56 | f.write('\t\"{}\": None, \n'.format(data_name))
57 | print('}\n')
58 | f.write('}\n\n')
59 |
60 | generate_model_config_py()
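61 |
62 | # Note: check_integrity() above is defined but never invoked; run it manually
63 | # beforehand to verify that predictions and ground-truth masks match.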
--------------------------------------------------------------------------------
/cos_eval_toolbox/tools/generate_latex_code.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2023/4/22
3 | # @Author : Daniel Ji
4 |
5 | eval_txt_path = './output_CDS2K/cos_results.txt'
6 |
7 | with open(eval_txt_path) as f:
8 |
9 | model_list, mae_list, maxf_list, avgf_list, adpf_list, maxe_list, avge_list, adpe_list, sm_list, wfm_list, sum_list = list(), list(), list(), list(), list(), list(), list(), list(), list(), list(), list()
10 |
11 | for line in f.readlines():
12 | if line.startswith('['):
13 | model = line.split('[')[1].split(']')[0]
14 | mae = float(line.split('mae: ')[1].split(' ')[0])
15 | maxf = float(line.split('maxf: ')[1].split(' ')[0])
16 | avgf = float(line.split('avgf: ')[1].split(' ')[0])
17 | adpf = float(line.split('adpf: ')[1].split(' ')[0])
18 | maxe = float(line.split('maxe: ')[1].split(' ')[0])
19 | avge = float(line.split('avge: ')[1].split(' ')[0])
20 | adpe = float(line.split('adpe: ')[1].split(' ')[0])
21 | sm = float(line.split('sm: ')[1].split(' ')[0])
22 | wfm = float(line.split('wfm: ')[1].split(' ')[0])
23 |             total_score = (1 - mae) + maxf + avgf + adpf + maxe + avge + adpe + sm + wfm # optional
24 |
25 | print(f'{model}\n&{sm:.3f} &{wfm:.3f} &{mae:.3f} &{adpe:.3f} &{avge:.3f} &{maxe:.3f} &{adpf:.3f} &{avgf:.3f} &{maxf:.3f}')
26 |
27 | model_list.append(model)
28 | mae_list.append(mae)
29 | maxf_list.append(maxf)
30 | avgf_list.append(avgf)
31 | adpf_list.append(adpf)
32 | maxe_list.append(maxe)
33 | avge_list.append(avge)
34 | adpe_list.append(adpe)
35 | sm_list.append(sm)
36 | wfm_list.append(wfm)
37 |             sum_list.append(total_score)
38 |
39 | # print('\n', max(sm_list), max(wfm_list), min(mae_list), max(adpe_list), max(avge_list), max(maxe_list), max(adpf_list), max(avgf_list), max(maxf_list), max(sum_list))
40 |
41 | # sm_list.sort()
42 | # wfm_list.sort()
43 | # mae_list.sort(reverse=True)
44 | # adpe_list.sort()
45 | # avge_list.sort()
46 | # maxe_list.sort()
47 | # adpf_list.sort()
48 | # avgf_list.sort()
49 | # maxf_list.sort()
50 | # sum_list.sort()
51 |
52 | # print('\n', sm_list[-1], wfm_list[-1], mae_list[-1], adpe_list[-1], avge_list[-1], maxe_list[-1], adpf_list[-1], avgf_list[-1], maxf_list[-1])
53 | # print('\n', sm_list[-2], wfm_list[-2], mae_list[-2], adpe_list[-2], avge_list[-2], maxe_list[-2], adpf_list[-2], avgf_list[-2], maxf_list[-2])
54 | # print('\n', sm_list[-3], wfm_list[-3], mae_list[-3], adpe_list[-3], avge_list[-3], maxe_list[-3], adpf_list[-3], avgf_list[-3], maxf_list[-3])
55 |
--------------------------------------------------------------------------------
/cos_eval_toolbox/tools/info_py_to_json.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2021/3/14
3 | # @Author : Lart Pang
4 | # @GitHub : https://github.com/lartpang
5 | import argparse
6 | import ast
7 | import json
8 | import os
9 | import sys
10 | from importlib import import_module
11 |
12 |
13 | def validate_py_syntax(filename):
14 | with open(filename, "r") as f:
15 | content = f.read()
16 | try:
17 | ast.parse(content)
18 | except SyntaxError as e:
19 | raise SyntaxError("There are syntax errors in config " f"file {filename}: {e}")
20 |
21 |
22 | def convert_py_to_json(source_config_root, target_config_root):
23 | if not os.path.isdir(source_config_root):
24 | raise NotADirectoryError(source_config_root)
25 | if not os.path.exists(target_config_root):
26 | os.makedirs(target_config_root)
27 | else:
28 | if not os.path.isdir(target_config_root):
29 | raise NotADirectoryError(target_config_root)
30 |
31 | sys.path.insert(0, source_config_root)
32 | source_config_files = os.listdir(source_config_root)
33 | for source_config_file in source_config_files:
34 | source_config_path = os.path.join(source_config_root, source_config_file)
35 | if not (os.path.isfile(source_config_path) and source_config_path.endswith(".py")):
36 | continue
37 | validate_py_syntax(source_config_path)
38 | print(source_config_path)
39 |
40 | temp_module_name = os.path.splitext(source_config_file)[0]
41 | mod = import_module(temp_module_name)
42 |
43 | total_dict = {}
44 | for name, value in mod.__dict__.items():
45 | if not name.startswith("_") and isinstance(value, dict):
46 | total_dict[name] = value
47 |
48 | # delete imported module
49 | del sys.modules[temp_module_name]
50 |
51 | with open(
52 | os.path.join(target_config_root, os.path.basename(temp_module_name) + ".json"),
53 | encoding="utf-8",
54 | mode="w",
55 | ) as f:
56 | json.dump(total_dict, f, indent=2)
57 |
58 |
59 | def get_args():
60 | parser = argparse.ArgumentParser()
61 | parser.add_argument("-i", "--source-py-root", required=True, type=str)
62 | parser.add_argument("-o", "--target-json-root", required=True, type=str)
63 | args = parser.parse_args()
64 | return args
65 |
66 |
67 | if __name__ == "__main__":
68 | args = get_args()
69 | convert_py_to_json(
70 | source_config_root=args.source_py_root, target_config_root=args.target_json_root
71 | )
72 |
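73 | # For example (an assumption; see examples_COS/config_cos_dataset_py_example.py
74 | # and its .json counterpart for a concrete pair): a source .py config defining
75 | # `COD10K = {"mask": dict(path=..., suffix=".png")}` is exported under the
76 | # top-level key "COD10K" in the generated .json file.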
--------------------------------------------------------------------------------
/cos_eval_toolbox/tools/markdown2html.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 |
4 | import markdown2
5 |
6 | # NOTE: the HTML in the two templates below is a reconstruction: only the
7 | # {placeholders} survive from the original file. The sorttable script tag is an
8 | # assumption based on the "sortable" table class applied further down; drop it
9 | # if you do not need sortable tables.
10 | index_template = """
11 | <!DOCTYPE html>
12 | <html>
13 | <head>
14 | <meta charset="utf-8">
15 | <title>{html_title}</title>
16 | </head>
17 | <body>
18 | {html_urls}
19 | {html_body}
20 | </body>
21 | </html>
22 | """
23 |
24 | html_template = """
25 | <!DOCTYPE html>
26 | <html>
27 | <head>
28 | <meta charset="utf-8">
29 | <title>{html_title}</title>
30 | <script src="https://www.kryogenix.org/code/browser/sorttable/sorttable.js"></script>
31 | </head>
32 | <body>
33 | {html_urls}
34 | {html_body}
35 | </body>
36 | </html>
37 | """
38 |
39 |
40 | list_item_template = '<li><a href="{html_name}.html">{item_name}</a></li>'
41 |
42 |
43 | def save_html(html_text, html_path):
44 | with open(html_path, encoding="utf-8", mode="w") as f:
45 | f.write(html_text)
46 |
47 |
48 | md_root = "../results"
49 | html_root = "../results/htmls"
50 |
51 | md_names = [
52 | name[:-3]
53 | for name in os.listdir(md_root)
54 | if os.path.isfile(os.path.join(md_root, name)) and name.endswith(".md")
55 | ]
56 |
57 | url_list = "\n".join(
58 |     ["<ul>"]
59 |     + [list_item_template.format(html_name=name, item_name=name.upper()) for name in md_names]
60 |     + ["</ul>"]
61 | )
62 |
63 | index_html = index_template.format(html_title="Index", html_urls=url_list, html_body="Updating...")
64 | save_html(html_text=index_html, html_path=os.path.join(html_root, "index.html"))
65 |
66 | for md_name in md_names:
67 | html_body = markdown2.markdown_path(
68 | os.path.join(md_root, md_name + ".md"),
69 | extras={"tables": True, "html-classes": {"table": "sortable"}},
70 | )
71 | html = html_template.format(html_title=md_name, html_urls=url_list, html_body=html_body)
72 | save_html(html_text=html, html_path=os.path.join(html_root, md_name + ".html"))
73 |
--------------------------------------------------------------------------------
/cos_eval_toolbox/tools/readme.md:
--------------------------------------------------------------------------------
1 | # Useful tools
2 |
3 | ## `append_results.py`
4 |
5 | Merge a newly generated npy file with an old one into a new npy file.
6 |
7 | ```shell
8 | $ python append_results.py --help
9 | usage: append_results.py [-h] --old-npy OLD_NPY --new-npy NEW_NPY [--method-names METHOD_NAMES [METHOD_NAMES ...]]
10 | [--dataset-names DATASET_NAMES [DATASET_NAMES ...]] --out-npy OUT_NPY
11 |
12 | A simple tool for merging two npy file. Patch the method items corresponding to the `--method-names` and `--dataset-names` of `--new-npy`
13 | into `--old-npy`, and output the whole container to `--out-npy`.
14 |
15 | optional arguments:
16 | -h, --help show this help message and exit
17 | --old-npy OLD_NPY
18 | --new-npy NEW_NPY
19 | --method-names METHOD_NAMES [METHOD_NAMES ...]
20 | --dataset-names DATASET_NAMES [DATASET_NAMES ...]
21 | --out-npy OUT_NPY
22 | ```
23 |
24 | Usage scenario:
25 |
26 | For RGB SOD data, I have already generated an npy file containing the results of a batch of methods: `old_rgb_sod_curves.npy`.
27 | After re-evaluating one method, I obtained another npy file (the names differ, so nothing is overwritten): `new_rgb_sod_curves.npy`.
28 | Now I want to merge the two into a single file: `finalnew_rgb_sod_curves.npy`.
29 | This can be done with the following command.
30 |
31 | ```shell
32 | python tools/append_results.py --old-npy output/old_rgb_sod_curves.npy \
33 | --new-npy output/new_rgb_sod_curves.npy \
34 | --out-npy output/finalnew_rgb_sod_curves.npy
35 | ```
36 |
37 |
38 | ## `converter.py`
39 |
40 | Export the information in the generated `*_metrics.npy` files as LaTeX table code.
41 |
42 | You can configure this manually following `examples/converter_config.yaml` in the examples folder to generate exactly the LaTeX table code you need.
43 |
44 | ```shell
45 | $ python converter.py --help
46 | usage: converter.py [-h] -i RESULT_FILE [RESULT_FILE ...] -o TEX_FILE [-c CONFIG_FILE] [--contain-table-env] [--transpose]
47 |
48 | A useful and convenient tool to convert your .npy results into the table code in latex.
49 |
50 | optional arguments:
51 | -h, --help show this help message and exit
52 | -i RESULT_FILE [RESULT_FILE ...], --result-file RESULT_FILE [RESULT_FILE ...]
53 | The path of the *_metrics.npy file.
54 | -o TEX_FILE, --tex-file TEX_FILE
55 | The path of the exported tex file.
56 | -c CONFIG_FILE, --config-file CONFIG_FILE
57 | The path of the customized config yaml file.
58 |   --contain-table-env  Whether to contain the table env in the exported code.
59 | --transpose Whether to transpose the table.
60 | ```
61 |
62 | An example is shown below.
63 |
64 | ```shell
65 | $ python tools/converter.py -i output/your_metrics_1.npy output/your_metrics_2.npy -o output/your_metrics.tex -c ./examples/converter_config.yaml --transpose --contain-table-env
66 | ```
67 |
68 | This command reads data from multiple npy files (if you only have one, just pass that single file), processes it, and exports it to the specified tex file. The datasets, metrics, and methods are selected by the given config file.
69 |
70 | In addition, the exported table code uses the transposed (vertical) layout and includes the table's `tabular` environment.
71 |
72 | ## `check_path.py`
73 |
74 | Check for problems by matching the information in the json files against the actual paths on your system.
75 |
76 | ```shell
77 | $ python check_path.py --help
78 | usage: check_path.py [-h] -m METHOD_JSONS [METHOD_JSONS ...] -d DATASET_JSONS [DATASET_JSONS ...]
79 |
80 | A simple tool for checking your json config file.
81 |
82 | optional arguments:
83 | -h, --help show this help message and exit
84 | -m METHOD_JSONS [METHOD_JSONS ...], --method-jsons METHOD_JSONS [METHOD_JSONS ...]
85 | The json file about all methods.
86 | -d DATASET_JSONS [DATASET_JSONS ...], --dataset-jsons DATASET_JSONS [DATASET_JSONS ...]
87 | ```
88 |
89 | An example is shown below. Make sure the file paths after `-m` and `-d` correspond one-to-one; the code uses `zip` to iterate both lists in lockstep and pair method files with dataset files.
90 |
91 | ```shell
92 | $ python check_path.py -m ../configs/methods/json/rgb_sod_methods.json ../configs/methods/json/rgbd_sod_methods.json \
93 | -d ../configs/datasets/json/rgb_sod.json ../configs/datasets/json/rgbd_sod.json
94 | ```
95 |
96 | ## `info_py_to_json.py`
97 |
98 | Convert Python-based config files into more portable json files.
99 |
100 | ```shell
101 | $ python info_py_to_json.py --help
102 | usage: info_py_to_json.py [-h] -i SOURCE_PY_ROOT -o TARGET_JSON_ROOT
103 |
104 | optional arguments:
105 | -h, --help show this help message and exit
106 | -i SOURCE_PY_ROOT, --source-py-root SOURCE_PY_ROOT
107 | -o TARGET_JSON_ROOT, --target-json-root TARGET_JSON_ROOT
108 | ```
109 |
110 | Two options are required: the input directory holding the Python config files (`-i`) and the output directory for the generated json files (`-o`).
111 |
112 | Each Python file in the input directory is imported, and by default every dict object whose name does not start with `_` is collected as the config of a dataset or method.
113 |
114 | Finally, all of this information is gathered into one dict and dumped directly into a json file in the output directory.
115 |
116 | ## `markdown2html.py`
117 |
118 | Converts the markdown files holding the results under `results` into html files, so they can be displayed on a GitHub-Pages static site.
119 |
120 | ## `rename.py`
121 |
122 | Batch renaming.
123 |
124 | Read the code before use and be careful to avoid losing files through overwriting; a hypothetical invocation is sketched below.
125 |
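126 | A minimal sketch (the patterns and paths are illustrative, not from the repository; note that `dst_dir` must already exist):
127 |
128 | ```python
129 | from rename import rename_all_files
130 |
131 | # copy every `xxx_sal_fuse.png` found under src_dir into dst_dir as `xxx.png`
132 | rename_all_files(
133 |     src_pattern=r"(.*)_sal_fuse\.png",
134 |     dst_pattern=r"\1.png",
135 |     src_name="*/*.png",
136 |     src_dir="./results_raw",
137 |     dst_dir="./results_renamed",
138 | )
139 | ```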
--------------------------------------------------------------------------------
/cos_eval_toolbox/tools/rename.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2021/4/24
3 | # @Author : Lart Pang
4 | # @GitHub : https://github.com/lartpang
5 |
6 | import glob
7 | import os
8 | import re
9 | import shutil
10 |
11 |
12 | def path_join(base_path, sub_path):
13 | if sub_path.startswith(os.sep):
14 | sub_path = sub_path[len(os.sep) :]
15 | return os.path.join(base_path, sub_path)
16 |
17 |
18 | def rename_all_files(src_pattern, dst_pattern, src_name, src_dir, dst_dir=None):
19 |     """
20 |     :param src_pattern: regular expression matching the original file names
21 |     :param dst_pattern: the corresponding replacement string
22 |     :param src_name: a glob-style pattern
23 |     :param src_dir: folder holding the original data; combined with src_name to build the glob pattern used to search for files
24 |     :param dst_dir: folder for the modified data; defaults to None, which means the original data is modified in place
25 |     """
26 | assert os.path.isdir(src_dir)
27 |
28 | if dst_dir is None:
29 | dst_dir = src_dir
30 | rename_func = os.replace
31 | else:
32 | assert os.path.isdir(dst_dir)
33 | if dst_dir == src_dir:
34 | rename_func = os.replace
35 | else:
36 | rename_func = shutil.copy
37 |     print(f"{rename_func.__name__} will be used to modify the data")
38 |
39 | src_dir = os.path.abspath(src_dir)
40 | dst_dir = os.path.abspath(dst_dir)
41 | src_data_paths = glob.glob(path_join(src_dir, src_name))
42 |
43 |     print(f"Start replacing the data in {src_dir}")
44 | for idx, src_data_path in enumerate(src_data_paths, start=1):
45 | src_name_w_dir_name = src_data_path[len(src_dir) + 1 :]
46 | dst_name_w_dir_name = re.sub(src_pattern, repl=dst_pattern, string=src_name_w_dir_name)
47 | if dst_name_w_dir_name == src_name_w_dir_name:
48 | continue
49 | dst_data_path = path_join(dst_dir, dst_name_w_dir_name)
50 |
51 | dst_data_dir = os.path.dirname(dst_data_path)
52 | if not os.path.exists(dst_data_dir):
53 |             print(f"{idx}: {dst_data_dir} does not exist; creating it")
54 | os.makedirs(dst_data_dir)
55 | rename_func(src=src_data_path, dst=dst_data_path)
56 | print(f"{src_data_path} -> {dst_data_path}")
57 |
58 | print("OK...")
59 |
60 |
61 | if __name__ == "__main__":
62 | rename_all_files(
63 | src_pattern=r"",
64 | dst_pattern="",
65 | src_name="*/*.png",
66 | src_dir="",
67 | dst_dir="",
68 | )
69 |
--------------------------------------------------------------------------------
/cos_eval_toolbox/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DengPingFan/CSU/892e7bf716e75dd1506a97be80b1d04b03b21965/cos_eval_toolbox/utils/__init__.py
--------------------------------------------------------------------------------
/cos_eval_toolbox/utils/generate_info.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import json
4 | import os
5 | from collections import OrderedDict
6 |
7 | from matplotlib import colors
8 |
9 | # max = 148
10 | _COLOR_Generator = iter(
11 | sorted(
12 | [
13 | color
14 | for name, color in colors.cnames.items()
15 | if name not in ["red", "white"] or not name.startswith("light") or "gray" in name
16 | ]
17 | )
18 | )
19 |
20 |
21 | def curve_info_generator():
22 |     # TODO: line styles `-` and `--` are currently assigned to methods alternately;
23 |     # when a method only has results on some datasets, the two methods adjacent to it can end up with the same style on the datasets where its results are missing
24 | line_style_flag = True
25 |
26 | def _template_generator(
27 | method_info: dict, method_name: str, line_color: str = None, line_width: int = 2
28 | ) -> dict:
29 | nonlocal line_style_flag
30 | template_info = dict(
31 | path_dict=method_info,
32 | curve_setting=dict(
33 | line_style="-" if line_style_flag else "--",
34 | line_label=method_name,
35 | line_width=line_width,
36 | ),
37 | )
38 | if line_color is not None:
39 | template_info["curve_setting"]["line_color"] = line_color
40 | else:
41 |             template_info["curve_setting"]["line_color"] = next(_COLOR_Generator)
42 |
43 | line_style_flag = not line_style_flag
44 | return template_info
45 |
46 | return _template_generator
47 |
48 |
49 | def simple_info_generator():
50 | def _template_generator(method_info: dict, method_name: str) -> dict:
51 | template_info = dict(path_dict=method_info, label=method_name)
52 | return template_info
53 |
54 | return _template_generator
55 |
56 |
57 | def get_valid_elements(
58 | source: list,
59 | include_elements: list = None,
60 | exclude_elements: list = None,
61 | ):
62 | if include_elements is None:
63 | include_elements = []
64 | if exclude_elements is None:
65 | exclude_elements = []
66 | assert not set(include_elements).intersection(
67 | exclude_elements
68 | ), "`include_elements` and `exclude_elements` must have no intersection."
69 |
70 | targeted = set(source).difference(exclude_elements)
71 |     assert targeted, "`exclude_elements` must not exclude all elements."
72 |
73 | if include_elements:
74 | # include_elements: [] or [dataset1_name, dataset2_name, ...]
75 | # only latter will be used to select datasets from `targeted`
76 | targeted = targeted.intersection(include_elements)
77 |
78 | return list(targeted)
79 |
80 |
81 | def get_methods_info(
82 | methods_info_jsons: list,
83 | for_drawing: bool = False,
84 | our_name: str = None,
85 | include_methods: list = None,
86 | exclude_methods: list = None,
87 | ) -> OrderedDict:
88 | """
89 |     The keys of each method dict stored in the json files are used directly for plotting.
90 |
91 |     :param methods_info_jsons: json files holding the method information; multiple files are read in the given order
92 |     :param for_drawing: whether the info is used for drawing curves; True adds extra plotting settings
93 |     :param our_name: when drawing, our_name can be set to emphasize that method's curve with a thick red solid line
94 |     :param include_methods: only return the methods listed here; None returns all
95 |     :param exclude_methods: exclude the methods listed here; None excludes none; at most one of include_methods/exclude_methods may be non-None
96 | :return: methods_full_info
97 | """
98 | if not isinstance(methods_info_jsons, (list, tuple)):
99 | methods_info_jsons = [methods_info_jsons]
100 |
101 | methods_info = {}
102 | for f in methods_info_jsons:
103 | if not os.path.isfile(f):
104 |             raise FileNotFoundError(f"{f} was not found!")
105 |
106 | with open(f, encoding="utf-8", mode="r") as f:
107 |             methods_info.update(json.load(f, object_hook=OrderedDict))  # ordered load
108 |
109 | if our_name:
110 | assert our_name in methods_info, f"{our_name} is not in json file."
111 |
112 | targeted_methods = get_valid_elements(
113 | source=list(methods_info.keys()),
114 | include_elements=include_methods,
115 | exclude_elements=exclude_methods,
116 | )
117 | if our_name and our_name in targeted_methods:
118 | targeted_methods.pop(targeted_methods.index(our_name))
119 | targeted_methods.sort()
120 | targeted_methods.insert(0, our_name)
121 |
122 | if for_drawing:
123 | info_generator = curve_info_generator()
124 | else:
125 | info_generator = simple_info_generator()
126 |
127 | methods_full_info = []
128 | for method_name in targeted_methods:
129 | method_path = methods_info[method_name]
130 |
131 | if for_drawing and our_name and our_name == method_name:
132 | method_info = info_generator(method_path, method_name, line_color="red", line_width=3)
133 | else:
134 | method_info = info_generator(method_path, method_name)
135 | methods_full_info.append((method_name, method_info))
136 | return OrderedDict(methods_full_info)
137 |
138 |
139 | def get_datasets_info(
140 | datastes_info_json: str,
141 | include_datasets: list = None,
142 | exclude_datasets: list = None,
143 | ) -> OrderedDict:
144 | """
145 |     All dataset information stored in the json file is exported directly into one dict.
146 |
147 |     :param datastes_info_json: the json file holding the dataset information
148 |     :param include_datasets: only read the datasets listed here; None reads all
149 |     :param exclude_datasets: exclude the datasets listed here; None excludes none; at most one of include_datasets/exclude_datasets may be non-None
150 | :return: datastes_full_info
151 | """
152 |
153 | assert os.path.isfile(datastes_info_json), datastes_info_json
154 | with open(datastes_info_json, encoding="utf-8", mode="r") as f:
155 |         datasets_info = json.load(f, object_hook=OrderedDict)  # ordered load
156 |
157 | targeted_datasets = get_valid_elements(
158 | source=list(datasets_info.keys()),
159 | include_elements=include_datasets,
160 | exclude_elements=exclude_datasets,
161 | )
162 | targeted_datasets.sort()
163 |
164 | datasets_full_info = []
165 | for dataset_name in targeted_datasets:
166 | data_path = datasets_info[dataset_name]
167 |
168 | datasets_full_info.append((dataset_name, data_path))
169 | return OrderedDict(datasets_full_info)
170 |
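171 |
172 | # A usage sketch (an assumption; eval.py is the actual consumer of these helpers,
173 | # and the json files below ship with the toolbox under examples_COS/):
174 | #
175 | #     methods = get_methods_info(
176 | #         ["examples_COS/config_cos_method_py_example.json"], for_drawing=True)
177 | #     datasets = get_datasets_info(
178 | #         "examples_COS/config_cos_dataset_py_example.json", include_datasets=["COD10K"])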
--------------------------------------------------------------------------------
/cos_eval_toolbox/utils/misc.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 |
4 | import cv2
5 | import numpy as np
6 | from PIL import Image
7 |
8 |
9 | def get_ext(path_list):
10 | ext_list = list(set([os.path.splitext(p)[1] for p in path_list]))
11 | if len(ext_list) != 1:
12 | if ".png" in ext_list:
13 | ext = ".png"
14 | elif ".jpg" in ext_list:
15 | ext = ".jpg"
16 | elif ".bmp" in ext_list:
17 | ext = ".bmp"
18 | else:
19 | raise NotImplementedError
20 |         print(f"Multiple extensions exist in the prediction folder; only {ext} is used here")
21 | else:
22 | ext = ext_list[0]
23 | return ext
24 |
25 |
26 | def get_name_list_and_suffix(data_path: str) -> tuple:
27 | name_list = []
28 | if os.path.isfile(data_path):
29 | print(f" ++>> {data_path} is a file. <<++ ")
30 | with open(data_path, mode="r", encoding="utf-8") as file:
31 | line = file.readline()
32 | while line:
33 | img_name = os.path.basename(line.split()[0])
34 | file_ext = os.path.splitext(img_name)[1]
35 | name_list.append(os.path.splitext(img_name)[0])
36 | line = file.readline()
37 | if file_ext == "":
38 |                     # default to .png
39 | file_ext = ".png"
40 | else:
41 | print(f" ++>> {data_path} is a folder. <<++ ")
42 | data_list = os.listdir(data_path)
43 | file_ext = get_ext(data_list)
44 | name_list = [os.path.splitext(f)[0] for f in data_list if f.endswith(file_ext)]
45 | name_list = list(set(name_list))
46 | return name_list, file_ext
47 |
48 |
49 | def get_name_list(data_path: str, name_prefix: str = "", name_suffix: str = "") -> list:
50 | if os.path.isfile(data_path):
51 | assert data_path.endswith((".txt", ".lst"))
52 | data_list = []
53 | with open(data_path, encoding="utf-8", mode="r") as f:
54 | line = f.readline().strip()
55 | while line:
56 | data_list.append(line)
57 | line = f.readline().strip()
58 | else:
59 | data_list = os.listdir(data_path)
60 |
61 | name_list = data_list
62 | if not name_prefix and not name_suffix:
63 | name_list = [os.path.splitext(f)[0] for f in name_list]
64 | else:
65 | name_list = [
66 |             f[len(name_prefix) : len(f) - len(name_suffix)]
67 | for f in name_list
68 | if f.startswith(name_prefix) and f.endswith(name_suffix)
69 | ]
70 |
71 | name_list = list(set(name_list))
72 | return name_list
73 |
74 |
75 | def get_name_with_group_list(data_path: str, file_ext: str = None) -> list:
76 | name_list = []
77 | if os.path.isfile(data_path):
78 | print(f" ++>> {data_path} is a file. <<++ ")
79 | with open(data_path, mode="r", encoding="utf-8") as file:
80 | line = file.readline()
81 | while line:
82 |                 img_name_with_group = line.split()[0]
83 | name_list.append(os.path.splitext(img_name_with_group)[0])
84 | line = file.readline()
85 | else:
86 | print(f" ++>> {data_path} is a folder. <<++ ")
87 | group_names = sorted(os.listdir(data_path))
88 | for group_name in group_names:
89 | image_names = [
90 | "/".join([group_name, x])
91 | for x in sorted(os.listdir(os.path.join(data_path, group_name)))
92 | ]
93 | if file_ext is not None:
94 | name_list += [os.path.splitext(f)[0] for f in image_names if f.endswith(file_ext)]
95 | else:
96 | name_list += [os.path.splitext(f)[0] for f in image_names]
97 | # group_name/file_name.ext
98 |     name_list = list(set(name_list))  # deduplicate
99 | return name_list
100 |
101 |
102 | def get_list_with_postfix(dataset_path: str, postfix: str):
103 | name_list = []
104 | if os.path.isfile(dataset_path):
105 | print(f" ++>> {dataset_path} is a file. <<++ ")
106 | with open(dataset_path, mode="r", encoding="utf-8") as file:
107 | line = file.readline()
108 | while line:
109 | img_name = os.path.basename(line.split()[0])
110 | name_list.append(os.path.splitext(img_name)[0])
111 | line = file.readline()
112 | else:
113 | print(f" ++>> {dataset_path} is a folder. <<++ ")
114 | name_list = [
115 | os.path.splitext(f)[0] for f in os.listdir(dataset_path) if f.endswith(postfix)
116 | ]
117 | name_list = list(set(name_list))
118 | return name_list
119 |
120 |
121 | def rgb_loader(path):
122 | with open(path, "rb") as f:
123 | img = Image.open(f)
124 |         return img.convert("L")  # NOTE: converts to grayscale despite the function name
125 |
126 |
127 | def binary_loader(path):
128 | assert os.path.exists(path), f"`{path}` does not exist."
129 | with open(path, "rb") as f:
130 | img = Image.open(f)
131 | return img.convert("L")
132 |
133 |
134 | def load_data(pre_root, gt_root, name, postfixs):
135 | pre = binary_loader(os.path.join(pre_root, name + postfixs[0]))
136 | gt = binary_loader(os.path.join(gt_root, name + postfixs[1]))
137 | return pre, gt
138 |
139 |
140 | def normalize_pil(pre, gt):
141 | gt = np.asarray(gt)
142 | pre = np.asarray(pre)
143 | gt = gt / (gt.max() + 1e-8)
144 | gt = np.where(gt > 0.5, 1, 0)
145 | max_pre = pre.max()
146 | min_pre = pre.min()
147 | if max_pre == min_pre:
148 | pre = pre / 255
149 | else:
150 | pre = (pre - min_pre) / (max_pre - min_pre)
151 | return pre, gt
152 |
153 |
154 | def make_dir(path):
155 | if not os.path.exists(path):
156 |         print(f"`{path}` does not exist, so it will be created.")
157 | os.makedirs(path)
158 | else:
159 | assert os.path.isdir(path), f"`{path}` should be a folder"
160 |         print(f"`{path}` already exists")
161 |
162 |
163 | def imread_wich_checking(path, for_color: bool = True, with_cv2: bool = True) -> np.ndarray:
164 | assert os.path.exists(path=path) and os.path.isfile(path=path), path
165 | if with_cv2:
166 | if for_color:
167 | data = cv2.imread(path, flags=cv2.IMREAD_COLOR)
168 | data = cv2.cvtColor(data, cv2.COLOR_BGR2RGB)
169 | else:
170 | data = cv2.imread(path, flags=cv2.IMREAD_GRAYSCALE)
171 | else:
172 | data = np.array(Image.open(path).convert("RGB" if for_color else "L"))
173 | return data
174 |
175 |
176 | def get_gt_pre_with_name(
177 | gt_root: str,
178 | pre_root: str,
179 | img_name: str,
180 | pre_prefix: str,
181 | pre_suffix: str,
182 | gt_ext: str = ".png",
183 | to_normalize: bool = False,
184 | ):
185 | img_path = os.path.join(pre_root, pre_prefix + img_name + pre_suffix)
186 | gt_path = os.path.join(gt_root, img_name + gt_ext)
187 |
188 | pre = imread_wich_checking(img_path, for_color=False)
189 | gt = imread_wich_checking(gt_path, for_color=False)
190 |
191 | if pre.shape != gt.shape:
192 | pre = cv2.resize(pre, dsize=gt.shape[::-1], interpolation=cv2.INTER_LINEAR).astype(
193 | np.uint8
194 | )
195 |
196 | if to_normalize:
197 | gt = normalize_array(gt, to_binary=True, max_eq_255=True)
198 | pre = normalize_array(pre, to_binary=False, max_eq_255=True)
199 | return gt, pre
200 |
201 |
202 | def normalize_array(
203 | data: np.ndarray, to_binary: bool = False, max_eq_255: bool = True
204 | ) -> np.ndarray:
205 | if max_eq_255:
206 | data = data / 255
207 | # else: data is in [0, 1]
208 | if to_binary:
209 | data = (data > 0.5).astype(np.uint8)
210 | else:
211 | if data.max() != data.min():
212 | data = (data - data.min()) / (data.max() - data.min())
213 | data = data.astype(np.float32)
214 | return data
215 |
216 |
217 | def get_valid_key_name(data_dict: dict, key_name: str) -> str:
218 | if data_dict.get(key_name.lower(), "keyerror") == "keyerror":
219 | key_name = key_name.upper()
220 | else:
221 | key_name = key_name.lower()
222 | return key_name
223 |
224 |
225 | def get_target_key(target_dict: dict, key: str) -> str:
226 | """
227 | from the keys of the target_dict, get the valid key name corresponding to the `key`
228 | if there is not a valid name, return None
229 | """
230 | target_keys = {k.lower(): k for k in target_dict.keys()}
231 | return target_keys.get(key.lower(), None)
232 |
233 |
234 | def colored_print(msg: str, mode: str = "general"):
235 | """
236 |     Provide some display formatting for printing different kinds of string messages.
237 |
238 |     :param msg: the string message to print
239 |     :param mode: the print mode; currently supports general/warning/error
240 | :return:
241 | """
242 | if mode == "general":
243 | msg = msg
244 | elif mode == "warning":
245 | msg = f"\033[5;31m{msg}\033[0m"
246 | elif mode == "error":
247 | msg = f"\033[1;31m{msg}\033[0m"
248 | else:
249 | raise ValueError(f"{mode} is invalid mode.")
250 | print(msg)
251 |
252 |
253 | class ColoredPrinter:
254 | """
255 |     Provide some display formatting for printing different kinds of string messages.
256 | """
257 |
258 | @staticmethod
259 | def info(msg):
260 | print(msg)
261 |
262 | @staticmethod
263 | def warn(msg):
264 | msg = f"\033[5;31m{msg}\033[0m"
265 | print(msg)
266 |
267 | @staticmethod
268 | def error(msg):
269 | msg = f"\033[1;31m{msg}\033[0m"
270 | print(msg)
271 |
272 |
273 | def update_info(source_info: dict, new_info: dict):
274 | for name, info in source_info.items():
275 | if name in new_info:
276 | if isinstance(info, dict):
277 | update_info(source_info=info, new_info=new_info[name])
278 | else: # int, float, list, tuple
279 | info = new_info[name]
280 | source_info[name] = info
281 | return source_info
282 |
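283 | # update_info() recursively overwrites matching keys; a worked example:
284 | #     update_info({"a": {"b": 1}, "c": 2}, {"a": {"b": 10}})
285 | #     -> {"a": {"b": 10}, "c": 2}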
--------------------------------------------------------------------------------
/cos_eval_toolbox/utils/print_formatter.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from tabulate import tabulate
4 |
5 |
6 | def print_formatter(
7 | results: dict, method_name_length=10, metric_name_length=5, metric_value_length=5
8 | ):
9 | dataset_regions = []
10 | for dataset_name, dataset_metrics in results.items():
11 | dataset_head_row = f" Dataset: {dataset_name} "
12 | dataset_region = [dataset_head_row]
13 | for method_name, metric_info in dataset_metrics.items():
14 | showed_method_name = clip_string(
15 | method_name, max_length=method_name_length, mode="left"
16 | )
17 | method_row_head = f"{showed_method_name} "
18 | method_row_body = []
19 | for metric_name, metric_value in metric_info.items():
20 | showed_metric_name = clip_string(
21 | metric_name, max_length=metric_name_length, mode="right"
22 | )
23 | showed_value_string = clip_string(
24 | str(metric_value), max_length=metric_value_length, mode="left"
25 | )
26 | method_row_body.append(f"{showed_metric_name}: {showed_value_string}")
27 | method_row = method_row_head + ", ".join(method_row_body)
28 | dataset_region.append(method_row)
29 | dataset_region_string = "\n".join(dataset_region)
30 | dataset_regions.append(dataset_region_string)
31 |         dividing_line = "\n" + "-" * len(dataset_region[-1]) + "\n"  # use the length of the last row of the last dataset region as the divider length
32 | formatted_string = dividing_line.join(dataset_regions)
33 | return formatted_string
34 |
35 |
36 | def clip_string(string: str, max_length: int, padding_char: str = " ", mode: str = "left"):
37 | assert isinstance(max_length, int), f"{max_length} must be `int` type"
38 |
39 | real_length = len(string)
40 | if real_length <= max_length:
41 | padding_length = max_length - real_length
42 | if mode == "left":
43 | clipped_string = string + f"{padding_char}" * padding_length
44 | elif mode == "center":
45 | left_padding_str = f"{padding_char}" * (padding_length // 2)
46 | right_padding_str = f"{padding_char}" * (padding_length - padding_length // 2)
47 | clipped_string = left_padding_str + string + right_padding_str
48 | elif mode == "right":
49 | clipped_string = f"{padding_char}" * padding_length + string
50 | else:
51 | raise NotImplementedError
52 | else:
53 | clipped_string = string[:max_length]
54 |
55 | return clipped_string
56 |
57 |
58 | def formatter_for_tabulate(
59 | results: dict,
60 | dataset_titlefmt: str = "Dataset: {}",
61 | method_name_length=None,
62 | metric_value_length=None,
63 | tablefmt="github",
64 | ):
65 | """
66 | tabulate format:
67 |
68 | ::
69 |
70 | table = [["spam",42],["eggs",451],["bacon",0]]
71 | headers = ["item", "qty"]
72 | print(tabulate(table, headers, tablefmt="github"))
73 |
74 | | item | qty |
75 | |--------|-------|
76 | | spam | 42 |
77 | | eggs | 451 |
78 | | bacon | 0 |
79 |
80 |     What this function does:
81 |     build a tabulate-format table for each dataset and join them with newlines before returning.
82 | """
83 | all_tables = []
84 | for dataset_name, dataset_metrics in results.items():
85 | all_tables.append(dataset_titlefmt.format(dataset_name))
86 |
87 | table = []
88 | headers = ["methods"]
89 | for method_name, metric_info in dataset_metrics.items():
90 | if method_name_length:
91 | method_name = clip_string(method_name, max_length=method_name_length, mode="left")
92 | method_row = [method_name]
93 |             # keep the ordering consistent: python3 dicts preserve insertion order, but re-sort anyway to be safe (results may be exported and re-imported)
94 | for metric_name, metric_value in sorted(metric_info.items(), key=lambda item: item[0]):
95 | if metric_value_length:
96 | metric_value = clip_string(
97 | str(metric_value), max_length=metric_value_length, mode="center"
98 | )
99 | if metric_name not in headers:
100 | headers.append(metric_name)
101 | method_row.append(metric_value)
102 | table.append(method_row)
103 | all_tables.append(tabulate(table, headers, tablefmt=tablefmt))
104 |
105 | formatted_string = "\n".join(all_tables)
106 | return formatted_string
107 |
--------------------------------------------------------------------------------
/cos_eval_toolbox/utils/recorders/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from .curve_drawer import CurveDrawer
3 | from .excel_recorder import MetricExcelRecorder
4 | from .metric_recorder import (
5 | METRIC_MAPPING,
6 | GroupedMetricRecorder,
7 | MetricRecorder,
8 | MetricRecorder_V2,
9 | )
10 | from .txt_recorder import TxtRecorder
11 |
--------------------------------------------------------------------------------
/cos_eval_toolbox/utils/recorders/curve_drawer.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2021/1/4
3 | # @Author : Lart Pang
4 | # @GitHub : https://github.com/lartpang
5 | import math
6 | import os
7 |
8 | import matplotlib.pyplot as plt
9 |
10 |
11 | class CurveDrawer(object):
12 | def __init__(
13 | self,
14 | row_num,
15 | num_subplots,
16 | style_cfg=None,
17 | ncol_of_legend=1,
18 | separated_legend=False,
19 | sharey=False,
20 | ):
21 | """A better wrapper of matplotlib for me.
22 |
23 | Args:
24 | row_num (int): Number of rows.
25 | num_subplots (int): Number of subplots.
26 | style_cfg (str, optional): Style yaml file path for matplotlib. Defaults to None.
27 | ncol_of_legend (int, optional): Number of columns of the legend. Defaults to 1.
28 | separated_legend (bool, optional): Use the separated legend. Defaults to False.
29 | sharey (bool, optional): Use a shared y-axis. Defaults to False.
30 | """
31 | if style_cfg is not None:
32 | assert os.path.isfile(style_cfg)
33 | plt.style.use(style_cfg)
34 |
35 | self.ncol_of_legend = ncol_of_legend
36 | self.separated_legend = separated_legend
37 | if self.separated_legend:
38 | num_subplots += 1
39 | self.num_subplots = num_subplots
40 | self.sharey = sharey
41 |
42 | fig, axes = plt.subplots(
43 | nrows=row_num, ncols=math.ceil(self.num_subplots / row_num), sharey=self.sharey
44 | )
45 | self.fig = fig
46 | self.axes = axes.flatten()
47 |
48 | self.init_subplots()
49 | self.dummy_data = {}
50 |
51 | def init_subplots(self):
52 | for ax in self.axes:
53 | ax.set_axis_off()
54 |
55 | def plot_at_axis(self, idx, method_curve_setting, x_data, y_data):
56 | """
57 | :param method_curve_setting: {
58 | "line_color": "color"(str),
59 | "line_style": "style"(str),
60 | "line_label": "label"(str),
61 | "line_width": width(int),
62 | }
63 | """
64 | assert isinstance(idx, int) and 0 <= idx < self.num_subplots
65 | self.axes[idx].plot(
66 | x_data,
67 | y_data,
68 | linewidth=method_curve_setting["line_width"],
69 | label=method_curve_setting["line_label"],
70 | color=method_curve_setting["line_color"],
71 | linestyle=method_curve_setting["line_style"],
72 | )
73 |
74 | if self.separated_legend:
75 | self.dummy_data[method_curve_setting["line_label"]] = method_curve_setting
76 |
77 | def set_axis_property(
78 | self, idx, title=None, x_label=None, y_label=None, x_ticks=None, y_ticks=None
79 | ):
80 | ax = self.axes[idx]
81 |
82 | ax.set_axis_on()
83 |
84 | # give plot a title
85 | ax.set_title(title)
86 |
87 | # make axis labels
88 | ax.set_xlabel(x_label)
89 | ax.set_ylabel(y_label)
90 |
91 |         # Configure the axis ticks and limits (note: the y-limit must use y_ticks).
92 |         x_ticks = [] if x_ticks is None else x_ticks
93 |         y_ticks = [] if y_ticks is None else y_ticks
94 |         ax.set_xlim((min(x_ticks), max(x_ticks)) if x_ticks else None)
95 |         ax.set_ylim((min(y_ticks), max(y_ticks)) if y_ticks else None)
96 |         ax.set_xticks(x_ticks)
97 |         ax.set_yticks(y_ticks)
98 |         ax.set_xticklabels(labels=[f"{x:.2f}" for x in x_ticks])
99 |         ax.set_yticklabels(labels=[f"{y:.2f}" for y in y_ticks])
100 |
101 | def _plot(self):
102 | if self.sharey:
103 | for ax in self.axes[1:]:
104 | ax.set_ylabel(None)
105 | ax.tick_params(bottom=True, top=False, left=False, right=False)
106 |
107 | if self.separated_legend:
108 | # settings for the legend axis
109 | for method_label, method_info in self.dummy_data.items():
110 | self.plot_at_axis(
111 | idx=self.num_subplots - 1, method_curve_setting=method_info, x_data=0, y_data=0
112 | )
113 | ax = self.axes[self.num_subplots - 1]
114 | ax.set_axis_off()
115 | ax.legend(loc=10, ncol=self.ncol_of_legend, facecolor="white", edgecolor="white")
116 | else:
117 |             # Settings for the legends of all regular subplots.
118 |             for ax in self.axes:
119 |                 # loc=0 would let matplotlib pick the best position; loc=3 pins the legend to the lower left.
120 |                 ax.legend(loc=3, ncol=self.ncol_of_legend, facecolor="white", edgecolor="white")
121 |
122 | def show(self):
123 | self._plot()
124 | plt.tight_layout()
125 | plt.show()
126 |
127 | def save(self, path):
128 | self._plot()
129 | plt.tight_layout()
130 | plt.savefig(path)
131 |
--------------------------------------------------------------------------------
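A minimal sketch of how `CurveDrawer` might be driven (the curve data here is synthetic; in the toolbox the real inputs come from the saved curves via `plot.py`):

```python
import numpy as np
from utils.recorders import CurveDrawer

# One row with two subplots, one per dataset.
drawer = CurveDrawer(row_num=1, num_subplots=2)

x = np.linspace(0, 1, 256)  # synthetic threshold axis
ticks = np.linspace(0, 1, 6).tolist()
curve_style = {
    "line_color": "red",
    "line_style": "-",
    "line_label": "SINet",
    "line_width": 2,
}
for idx, dataset in enumerate(["CAMO", "COD10K"]):
    drawer.plot_at_axis(idx, curve_style, x_data=x, y_data=1 - x**2)
    drawer.set_axis_property(
        idx, title=dataset, x_label="Recall", y_label="Precision",
        x_ticks=ticks, y_ticks=ticks,
    )
drawer.save("pr_curves.png")
```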
/cos_eval_toolbox/utils/recorders/excel_recorder.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2021/1/3
3 | # @Author : Lart Pang
4 | # @GitHub : https://github.com/lartpang
5 |
6 | import os
7 | import re
8 |
9 | from openpyxl import Workbook, load_workbook
10 | from openpyxl.utils import get_column_letter
11 | from openpyxl.worksheet.worksheet import Worksheet
12 |
13 |
14 | # Thanks:
15 | # - Python_Openpyxl: https://www.cnblogs.com/programmer-tlh/p/10461353.html
16 | # - Python之re模块: https://www.cnblogs.com/shenjianping/p/11647473.html
17 | class _BaseExcelRecorder(object):
18 | def __init__(self, xlsx_path: str):
19 |         """
20 |         Base class for writing xlsx documents: a thin convenience wrapper around openpyxl.
21 |
22 |         :param xlsx_path: path of the xlsx document.
23 |         """
24 |         self.xlsx_path = xlsx_path
25 |         if not os.path.exists(self.xlsx_path):
26 |             print("Creating a new excel file!")
27 |             self._initial_xlsx()
28 |         else:
29 |             print("Excel file already exists!")
30 |
31 | def _initial_xlsx(self):
32 | Workbook().save(self.xlsx_path)
33 |
34 | def load_sheet(self, sheet_name: str):
35 | wb = load_workbook(self.xlsx_path)
36 | if sheet_name not in wb.sheetnames:
37 | wb.create_sheet(title=sheet_name, index=0)
38 | sheet = wb[sheet_name]
39 |
40 | return wb, sheet
41 |
42 | def append_row(self, sheet: Worksheet, row_data):
43 | assert isinstance(row_data, (tuple, list))
44 | sheet.append(row_data)
45 |
46 | def insert_row(self, sheet: Worksheet, row_data, row_id, min_col=1, interval=0):
47 | assert isinstance(row_id, int) and isinstance(min_col, int) and row_id > 0 and min_col > 0
48 | assert isinstance(row_data, (tuple, list)), row_data
49 |
50 | num_elements = len(row_data)
51 | row_data = iter(row_data)
52 | for row in sheet.iter_rows(
53 | min_row=row_id,
54 | max_row=row_id,
55 | min_col=min_col,
56 | max_col=min_col + (interval + 1) * (num_elements - 1),
57 | ):
58 | for i, cell in enumerate(row):
59 | if i % (interval + 1) == 0:
60 | sheet.cell(row=row_id, column=cell.column, value=next(row_data))
61 |
62 | @staticmethod
63 | def merge_region(sheet: Worksheet, min_row, max_row, min_col, max_col):
64 | assert max_row >= min_row > 0 and max_col >= min_col > 0
65 |
66 | merged_region = (
67 | f"{get_column_letter(min_col)}{min_row}:{get_column_letter(max_col)}{max_row}"
68 | )
69 | sheet.merge_cells(merged_region)
70 |
71 | @staticmethod
72 | def get_col_id_with_row_id(sheet: Worksheet, col_name: str, row_id):
73 |         """
74 |         Search the given row for the specified column name and return its column index.
75 |         """
76 | assert isinstance(row_id, int) and row_id > 0
77 |
78 | for cell in sheet[row_id]:
79 | if cell.value == col_name:
80 | return cell.column
81 |         raise ValueError(f"Row {row_id} does not contain the column {col_name}!")
82 |
83 | def get_row_id_with_col_name(self, sheet: Worksheet, row_name: str, col_name: str):
84 |         """
85 |         Search the column with the given name for the specified row; return (row_id, col_id) and is_new_row.
86 |         """
87 | is_new_row = True
88 | col_id = self.get_col_id_with_row_id(sheet=sheet, col_name=col_name, row_id=1)
89 |
90 | row_id = 0
91 | for cell in sheet[get_column_letter(col_id)]:
92 | row_id = cell.row
93 | if cell.value == row_name:
94 | return (row_id, col_id), not is_new_row
95 | return (row_id + 1, col_id), is_new_row
96 |
97 | @staticmethod
98 | def get_row_id_with_col_id(sheet: Worksheet, row_name: str, col_id: int):
99 |         """
100 |         Search the column with the given index for the specified row.
101 |         """
102 | assert isinstance(col_id, int) and col_id > 0
103 |
104 | is_new_row = True
105 | row_id = 0
106 | for cell in sheet[get_column_letter(col_id)]:
107 | row_id = cell.row
108 | if cell.value == row_name:
109 | return row_id, not is_new_row
110 | return row_id + 1, is_new_row
111 |
112 | @staticmethod
113 | def format_string_with_config(string: str, repalce_config: dict = None):
114 | assert repalce_config is not None
115 | if repalce_config.get("lower"):
116 | string = string.lower()
117 | elif repalce_config.get("upper"):
118 | string = string.upper()
119 | elif repalce_config.get("title"):
120 | string = string.title()
121 |
122 | sub_rule = repalce_config.get("replace")
123 | if sub_rule:
124 | string = re.sub(pattern=sub_rule[0], repl=sub_rule[1], string=string)
125 | return string
126 |
127 |
128 | class MetricExcelRecorder(_BaseExcelRecorder):
129 | def __init__(
130 | self,
131 | xlsx_path: str,
132 | sheet_name: str = None,
133 | repalce_config=None,
134 | row_header=None,
135 | dataset_names=None,
136 | metric_names=None,
137 | ):
138 |         """
139 |         A class for writing metric data into an xlsx document.
140 |
141 |         :param xlsx_path: path of the corresponding xlsx document
142 |         :param sheet_name: name of the sheet the data is written into,
143 |             defaults to `results`
144 |         :param repalce_config: rules for normalizing the keys of the data dict, applied via re.sub,
145 |             defaults to dict(lower=True, replace=(r"[_-]", ""))
146 |         :param row_header: content of the top-left header cells, defaults to `["methods", "num_data"]`
147 |         :param dataset_names: list of dataset names,
148 |             defaults to the RGB SOD datasets ["pascals", "ecssd", "hkuis", "dutste", "dutomron"]
149 |         :param metric_names: list of metric names,
150 |             defaults to ["smeasure","wfmeasure","mae","adpfm","meanfm","maxfm","adpem","meanem","maxem"]
151 |         """
152 | super().__init__(xlsx_path=xlsx_path)
153 | if sheet_name is None:
154 | sheet_name = "results"
155 |
156 | if repalce_config is None:
157 | self.repalce_config = dict(lower=True, replace=(r"[_-]", ""))
158 | else:
159 | self.repalce_config = repalce_config
160 |
161 | if row_header is None:
162 | row_header = ["methods", "num_data"]
163 | self.row_header = [
164 | self.format_string_with_config(s, self.repalce_config) for s in row_header
165 | ]
166 | if dataset_names is None:
167 | dataset_names = ["pascals", "ecssd", "hkuis", "dutste", "dutomron"]
168 | self.dataset_names = [
169 | self.format_string_with_config(s, self.repalce_config) for s in dataset_names
170 | ]
171 | if metric_names is None:
172 | metric_names = [
173 | "smeasure",
174 | "wfmeasure",
175 | "mae",
176 | "adpfm",
177 | "meanfm",
178 | "maxfm",
179 | "adpem",
180 | "meanem",
181 | "maxem",
182 | ]
183 | self.metric_names = [
184 | self.format_string_with_config(s, self.repalce_config) for s in metric_names
185 | ]
186 |
187 | self.sheet_name = sheet_name
188 | self._initial_table()
189 |
190 | def _initial_table(self):
191 | wb, sheet = self.load_sheet(sheet_name=self.sheet_name)
192 |
193 |         # Insert the row header.
194 |         self.insert_row(sheet=sheet, row_data=self.row_header, row_id=1, min_col=1)
195 |         # Merge the row-header cells vertically.
196 | for col_id in range(len(self.row_header)):
197 | self.merge_region(
198 | sheet=sheet, min_row=1, max_row=2, min_col=col_id + 1, max_col=col_id + 1
199 | )
200 |
201 |         # Insert the dataset names.
202 | self.insert_row(
203 | sheet=sheet,
204 | row_data=self.dataset_names,
205 | row_id=1,
206 | min_col=len(self.row_header) + 1,
207 | interval=len(self.metric_names) - 1,
208 | )
209 |
210 |         # Insert the metric names under each dataset.
211 | start_col = len(self.row_header) + 1
212 | for i in range(len(self.dataset_names)):
213 | self.insert_row(
214 | sheet=sheet,
215 | row_data=self.metric_names,
216 | row_id=2,
217 | min_col=start_col + i * len(self.metric_names),
218 | )
219 | wb.save(self.xlsx_path)
220 |
221 | def _format_row_data(self, row_data: dict) -> list:
222 | row_data = {
223 | self.format_string_with_config(k, self.repalce_config): v for k, v in row_data.items()
224 | }
225 | return [row_data[n] for n in self.metric_names]
226 |
227 | def __call__(self, row_data: dict, dataset_name: str, method_name: str):
228 | dataset_name = self.format_string_with_config(dataset_name, self.repalce_config)
229 | assert (
230 | dataset_name in self.dataset_names
231 | ), f"{dataset_name} is not contained in {self.dataset_names}"
232 |
233 |         # 1. Load the worksheet.
234 |         wb, sheet = self.load_sheet(sheet_name=self.sheet_name)
235 |         # 2. Look up method_name: reuse its row/column coordinates if it exists, otherwise use a new row.
236 | dataset_col_start_id = self.get_col_id_with_row_id(
237 | sheet=sheet, col_name=dataset_name, row_id=1
238 | )
239 | (method_row_id, method_col_id), is_new_row = self.get_row_id_with_col_name(
240 | sheet=sheet, row_name=method_name, col_name="methods"
241 | )
242 |         # 3. Insert the method name at its position.
243 | if is_new_row:
244 | sheet.cell(row=method_row_id, column=method_col_id, value=method_name)
245 |         # 4. Format the metric values and insert them into the sheet.
246 | row_data = self._format_row_data(row_data=row_data)
247 | self.insert_row(
248 | sheet=sheet, row_data=row_data, row_id=method_row_id, min_col=dataset_col_start_id
249 | )
250 |         # 5. Save the workbook.
251 | wb.save(self.xlsx_path)
252 |
--------------------------------------------------------------------------------
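A minimal sketch of recording one method's scores with `MetricExcelRecorder` (the dataset and metric names are illustrative; note that keys in `row_data` are normalized by `repalce_config`, so `"S_measure"` and `"smeasure"` map to the same column):

```python
from utils.recorders import MetricExcelRecorder

recorder = MetricExcelRecorder(
    xlsx_path="results.xlsx",
    sheet_name="results",
    dataset_names=["camo", "cod10k"],
    metric_names=["smeasure", "mae"],
)
# Keys are lowercased and "_"/"-" are stripped, so "S_measure" matches "smeasure";
# every configured metric name must appear in row_data.
recorder(
    row_data={"S_measure": 0.745, "MAE": 0.092},
    dataset_name="CAMO",
    method_name="SINet",
)
```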
/cos_eval_toolbox/utils/recorders/metric_recorder.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2021/1/4
3 | # @Author : Lart Pang
4 | # @GitHub : https://github.com/lartpang
5 | from collections import defaultdict
6 | import numpy as np
7 | from py_sod_metrics.sod_metrics import (
8 | MAE,
9 | Emeasure,
10 | Fmeasure,
11 | Smeasure,
12 | WeightedFmeasure,
13 | )
14 |
15 | from metrics.extra_metrics import ExtraSegMeasure
16 |
17 |
18 | def ndarray_to_basetype(data):
19 |     """
20 |     Convert a standalone ndarray, or the ndarrays inside a tuple, list, or dict,
21 |     into basic Python types, i.e. lists (via .tolist()) and Python scalars.
22 |     """
23 |
24 | def _to_list_or_scalar(item):
25 | listed_item = item.tolist()
26 | if isinstance(listed_item, list) and len(listed_item) == 1:
27 | listed_item = listed_item[0]
28 | return listed_item
29 |
30 | if isinstance(data, (tuple, list)):
31 | results = [_to_list_or_scalar(item) for item in data]
32 | elif isinstance(data, dict):
33 | results = {k: _to_list_or_scalar(item) for k, item in data.items()}
34 | else:
35 | assert isinstance(data, np.ndarray)
36 | results = _to_list_or_scalar(data)
37 | return results
38 |
39 |
40 | METRIC_MAPPING = {
41 | "mae": MAE,
42 | "fm": Fmeasure,
43 | "em": Emeasure,
44 | "sm": Smeasure,
45 | "wfm": WeightedFmeasure,
46 | "extra": ExtraSegMeasure,
47 | }
48 |
49 |
50 | class MetricRecorder_V2(object):
51 | def __init__(self, metric_names=None):
52 |         """
53 |         A recorder that accumulates the metrics named in `metric_names`.
54 |         """
55 | if metric_names is None:
56 | metric_names = ("mae", "fm", "em", "sm", "wfm")
57 | self.metric_objs = {}
58 | for metric_name in metric_names:
59 | self.metric_objs[metric_name] = METRIC_MAPPING[metric_name]()
60 |
61 | def update(self, pre: np.ndarray, gt: np.ndarray):
62 | assert pre.shape == gt.shape
63 | assert pre.dtype == np.uint8
64 | assert gt.dtype == np.uint8
65 | for m_name, m_obj in self.metric_objs.items():
66 | m_obj.step(pre, gt)
67 |
68 | def show(self, num_bits: int = 3, return_ndarray: bool = False) -> dict:
69 |         """
70 |         Return the computed metrics:
71 |
72 |         - curve data (sequential)
73 |         - numerical metrics (numerical)
74 |         """
75 | sequential_results = {}
76 | numerical_results = {}
77 | for m_name, m_obj in self.metric_objs.items():
78 | info = m_obj.get_results()
79 | if m_name == "fm":
80 | fm = info["fm"]
81 | pr = info["pr"]
82 | sequential_results.update(
83 | {
84 | "fm": np.flip(fm["curve"]),
85 | "p": np.flip(pr["p"]),
86 | "r": np.flip(pr["r"]),
87 | }
88 | )
89 | numerical_results.update(
90 | {"maxf": fm["curve"].max(), "avgf": fm["curve"].mean(), "adpf": fm["adp"]}
91 | )
92 | elif m_name == "wfm":
93 | wfm = info["wfm"]
94 | numerical_results["wfm"] = wfm
95 | elif m_name == "sm":
96 | sm = info["sm"]
97 | numerical_results["sm"] = sm
98 | elif m_name == "em":
99 | em = info["em"]
100 | sequential_results["em"] = np.flip(em["curve"])
101 | numerical_results.update(
102 | {"maxe": em["curve"].max(), "avge": em["curve"].mean(), "adpe": em["adp"]}
103 | )
104 | elif m_name == "mae":
105 | mae = info["mae"]
106 | numerical_results["mae"] = mae
107 | elif m_name == "extra":
108 | pre = info["pre"]
109 | sen = info["sen"]
110 | spec = info["spec"]
111 | fm_std = info["fm"]
112 | dice = info["dice"]
113 | iou = info["iou"]
114 | numerical_results.update(
115 | {
116 | "maxpre": pre.max(),
117 | "avgpre": pre.mean(),
118 | "maxsen": sen.max(),
119 | "avgsen": sen.mean(),
120 | "maxspec": spec.max(),
121 | "avgspec": spec.mean(),
122 | "maxfm_std": fm_std.max(),
123 | "avgfm_std": fm_std.mean(),
124 | "maxdice": dice.max(),
125 | "avgdice": dice.mean(),
126 | "maxiou": iou.max(),
127 | "avgiou": iou.mean(),
128 | }
129 | )
130 | else:
131 | raise NotImplementedError
132 |
133 | if num_bits is not None and isinstance(num_bits, int):
134 | numerical_results = {k: v.round(num_bits) for k, v in numerical_results.items()}
135 | if not return_ndarray:
136 | sequential_results = ndarray_to_basetype(sequential_results)
137 | numerical_results = ndarray_to_basetype(numerical_results)
138 | return {"sequential": sequential_results, "numerical": numerical_results}
139 |
140 |
141 | class MetricRecorder(object):
142 | def __init__(self):
143 |         """
144 |         A recorder that accumulates MAE/Fm/Sm/Em/wFm.
145 |         """
146 | self.mae = MAE()
147 | self.fm = Fmeasure()
148 | self.sm = Smeasure()
149 | self.em = Emeasure()
150 | self.wfm = WeightedFmeasure()
151 |
152 | def update(self, pre: np.ndarray, gt: np.ndarray):
153 | assert pre.shape == gt.shape
154 | assert pre.dtype == np.uint8
155 | assert gt.dtype == np.uint8
156 |
157 | self.mae.step(pre, gt)
158 | self.sm.step(pre, gt)
159 | self.fm.step(pre, gt)
160 | self.em.step(pre, gt)
161 | self.wfm.step(pre, gt)
162 |
163 | def show(self, num_bits: int = 3, return_ndarray: bool = False) -> dict:
164 |         """
165 |         Return the computed metrics:
166 |
167 |         - curve data (sequential): fm/em/p/r
168 |         - numerical metrics (numerical): SM/MAE/maxE/avgE/adpE/maxF/avgF/adpF/wFm
169 |         """
170 | fm_info = self.fm.get_results()
171 | fm = fm_info["fm"]
172 | pr = fm_info["pr"]
173 | wfm = self.wfm.get_results()["wfm"]
174 | sm = self.sm.get_results()["sm"]
175 | em = self.em.get_results()["em"]
176 | mae = self.mae.get_results()["mae"]
177 |
178 | sequential_results = {
179 | "fm": np.flip(fm["curve"]),
180 | "em": np.flip(em["curve"]),
181 | "p": np.flip(pr["p"]),
182 | "r": np.flip(pr["r"]),
183 | }
184 | numerical_results = {
185 | "SM": sm,
186 | "MAE": mae,
187 | "maxE": em["curve"].max(),
188 | "avgE": em["curve"].mean(),
189 | "adpE": em["adp"],
190 | "maxF": fm["curve"].max(),
191 | "avgF": fm["curve"].mean(),
192 | "adpF": fm["adp"],
193 | "wFm": wfm,
194 | }
195 | if num_bits is not None and isinstance(num_bits, int):
196 | numerical_results = {k: v.round(num_bits) for k, v in numerical_results.items()}
197 | if not return_ndarray:
198 | sequential_results = ndarray_to_basetype(sequential_results)
199 | numerical_results = ndarray_to_basetype(numerical_results)
200 | return {"sequential": sequential_results, "numerical": numerical_results}
201 |
202 |
203 | class GroupedMetricRecorder(object):
204 | def __init__(self):
205 | self.metric_recorders = {}
206 |         # These metrics are recomputed from the curves averaged over all groups.
207 | self.re_cal_metrics = ["maxE", "avgE", "maxF", "avgF"]
208 |
209 | def update(self, group_name: str, pre: np.ndarray, gt: np.ndarray):
210 | if group_name not in self.metric_recorders:
211 | self.metric_recorders[group_name] = MetricRecorder()
212 | self.metric_recorders[group_name].update(pre, gt)
213 |
214 | def show(self, num_bits: int = 3, return_ndarray: bool = False) -> dict:
215 |         """
216 |         Return the computed metrics:
217 |
218 |         - curve data (sequential): fm/em/p/r
219 |         - numerical metrics (numerical): SM/MAE/maxE/avgE/adpE/maxF/avgF/adpF/wFm
220 |         """
221 | group_metrics = {}
222 | for k, v in self.metric_recorders.items():
223 | group_metric = v.show(num_bits=None, return_ndarray=True)
224 | group_metrics[k] = {
225 | **group_metric["sequential"],
226 | **{
227 | metric_name: metric_value
228 | for metric_name, metric_value in group_metric["numerical"].items()
229 | if metric_name not in self.re_cal_metrics
230 | },
231 | }
232 | avg_results = self.average_group_metrics(group_metrics=group_metrics)
233 |
234 | sequential_results = {
235 | "fm": avg_results["fm"],
236 | "em": avg_results["em"],
237 | "p": avg_results["p"],
238 | "r": avg_results["r"],
239 | }
240 | numerical_results = {
241 | "SM": avg_results["SM"],
242 | "MAE": avg_results["MAE"],
243 | "maxE": avg_results["em"].max(),
244 | "avgE": avg_results["em"].mean(),
245 | "adpE": avg_results["adpE"],
246 | "maxF": avg_results["fm"].max(),
247 | "avgF": avg_results["fm"].mean(),
248 | "adpF": avg_results["adpF"],
249 | "wFm": avg_results["wFm"],
250 | }
251 | if num_bits is not None and isinstance(num_bits, int):
252 | numerical_results = {k: v.round(num_bits) for k, v in numerical_results.items()}
253 | if not return_ndarray:
254 | sequential_results = ndarray_to_basetype(sequential_results)
255 | numerical_results = ndarray_to_basetype(numerical_results)
256 | return {"sequential": sequential_results, "numerical": numerical_results}
257 |
258 | @staticmethod
259 | def average_group_metrics(group_metrics: dict) -> dict:
260 | recorder = defaultdict(list)
261 | for group_name, metrics in group_metrics.items():
262 | for metric_name, metric_array in metrics.items():
263 | recorder[metric_name].append(metric_array)
264 | results = {k: np.mean(np.vstack(v), axis=0) for k, v in recorder.items()}
265 | return results
266 |
--------------------------------------------------------------------------------
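A minimal sketch of feeding one prediction/ground-truth pair through `MetricRecorder` (synthetic uint8 masks stand in for real prediction and GT maps; `py_sod_metrics` must be installed):

```python
import numpy as np
from utils.recorders import MetricRecorder

recorder = MetricRecorder()

# Synthetic stand-ins for a grayscale prediction and a binary GT mask;
# update() requires matching shapes and uint8 dtype.
pred = (np.random.rand(64, 64) * 255).astype(np.uint8)
gt = (np.random.rand(64, 64) > 0.5).astype(np.uint8) * 255

recorder.update(pre=pred, gt=gt)

results = recorder.show(num_bits=3)
print(results["numerical"])               # {'SM': ..., 'MAE': ..., 'maxE': ..., ...}
print(len(results["sequential"]["fm"]))   # F-measure curve, one value per threshold
```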
/cos_eval_toolbox/utils/recorders/txt_recorder.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2021/1/4
3 | # @Author : Lart Pang
4 | # @GitHub : https://github.com/lartpang
5 |
6 | from datetime import datetime
7 |
8 |
9 | class TxtRecorder:
10 | def __init__(self, txt_path, to_append=True, max_method_name_width=10):
11 |         """
12 |         A class for writing results into a txt file.
13 |
14 |         :param txt_path: path of the txt file
15 |         :param to_append: whether to append to the existing file; otherwise it is recreated
16 |         :param max_method_name_width: maximum width of the method-name string
17 |         """
18 | self.txt_path = txt_path
19 | self.max_method_name_width = max_method_name_width
20 | mode = "a" if to_append else "w"
21 | with open(txt_path, mode=mode, encoding="utf-8") as f:
22 | f.write(f"\n ========>> Date: {datetime.now()} <<======== \n")
23 |
24 | def add_row(self, row_name, row_data, row_start_str="", row_end_str="\n"):
25 | with open(self.txt_path, mode="a", encoding="utf-8") as f:
26 | f.write(f"{row_start_str} ========>> {row_name}: {row_data} <<======== {row_end_str}")
27 |
28 | def __call__(
29 | self,
30 | method_results: dict,
31 | method_name: str = "",
32 | row_start_str="",
33 | row_end_str="\n",
34 | value_width=6,
35 | ):
36 | msg = row_start_str
37 | if len(method_name) > self.max_method_name_width:
38 | method_name = method_name[: self.max_method_name_width - 3] + "..."
39 | else:
40 | method_name += " " * (self.max_method_name_width - len(method_name))
41 | msg += f"[{method_name}] "
42 | for metric_name, metric_value in method_results.items():
43 | assert isinstance(metric_value, float)
44 | msg += f"{metric_name}: "
45 | real_width = len(str(metric_value))
46 | if value_width > real_width:
47 |                 # Pad with trailing spaces.
48 | msg += f"{metric_value}" + " " * (value_width - real_width)
49 | else:
50 |                 # The value is wider than the limit, so round it to fit.
51 |                 # Since the values lie in [0, 1], the leading `0.` takes two characters, hence ndigits = value_width - 2.
52 |                 msg += f"{round(metric_value, ndigits=value_width - 2)}"
53 | msg += " "
54 | msg += row_end_str
55 | with open(self.txt_path, mode="a", encoding="utf-8") as f:
56 | f.write(msg)
57 |
--------------------------------------------------------------------------------
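A minimal sketch of logging one method's scores with `TxtRecorder` (the values must be floats, per the assertion in `__call__`; the names here are illustrative):

```python
from utils.recorders import TxtRecorder

recorder = TxtRecorder(txt_path="cos_results.txt", to_append=True)
recorder.add_row(row_name="Dataset", row_data="CAMO")
# Each call appends one aligned line, such as:
# [SINet     ] SM: 0.745  MAE: 0.092
recorder(
    method_results={"SM": 0.745, "MAE": 0.092},
    method_name="SINet",
)
```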