├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── localization_bottle_broken_large_000_p0100_k01_s067.png
│   ├── localization_cable_bent_wire_000_p0100_k01_s073.png
│   ├── localization_capsule_crack_000_p0100_k01_s020.png
│   ├── localization_carpet_color_000_p0100_k01_s070.png
│   ├── localization_grid_bent_000_p0100_k01_s048.png
│   ├── localization_hazelnut_crack_000_p0100_k01_s047.png
│   ├── localization_leather_color_000_p0100_k01_s058.png
│   ├── localization_metal_nut_bent_000_p0100_k01_s082.png
│   ├── localization_pill_color_000_p0100_k01_s036.png
│   ├── localization_screw_manipulated_front_000_p0100_k01_s047.png
│   ├── localization_tile_crack_000_p0100_k01_s059.png
│   ├── localization_toothbrush_defective_000_p0100_k01_s079.png
│   ├── localization_transistor_bent_lead_000_p0100_k01_s043.png
│   ├── localization_wood_color_000_p0100_k01_s063.png
│   ├── localization_zipper_broken_teeth_000_p0100_k01_s029.png
│   ├── pred-dist_bottle_p0100_k01_r1000.png
│   ├── pred-dist_cable_p0100_k01_r0999.png
│   ├── pred-dist_capsule_p0100_k01_r0978.png
│   ├── pred-dist_carpet_p0100_k01_r0987.png
│   ├── pred-dist_grid_p0100_k01_r0985.png
│   ├── pred-dist_hazelnut_p0100_k01_r1000.png
│   ├── pred-dist_leather_p0100_k01_r1000.png
│   ├── pred-dist_metal_nut_p0100_k01_r1000.png
│   ├── pred-dist_pill_p0100_k01_r0962.png
│   ├── pred-dist_screw_p0100_k01_r0985.png
│   ├── pred-dist_tile_p0100_k01_r0994.png
│   ├── pred-dist_toothbrush_p0100_k01_r0994.png
│   ├── pred-dist_transistor_p0100_k01_r1000.png
│   ├── pred-dist_wood_p0100_k01_r0990.png
│   ├── pred-dist_zipper_p0100_k01_r0996.png
│   ├── roc-curve_p0010_k01_rim0990_rpm0979.png
│   ├── roc-curve_p0100_k01_rim0991_rpm0982.png
│   ├── roc-curve_p0250_k01_rim0991_rpm0981.png
│   └── total_recall.jpg
├── datasets
│   ├── mvtec_dataset.py
│   └── shared_memory.py
├── infer.py
├── models
│   ├── feat_extract.py
│   └── patchcore.py
├── requirements.txt
├── traintest.py
└── utils
    ├── config.py
    ├── metrics.py
    ├── tictoc.py
    └── visualize.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 | .idea
131 | result
132 | trained
133 | memo.txt
134 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PatchCore with some "ex"
2 | This is an unofficial implementation of the paper [Towards Total Recall in Industrial Anomaly Detection](https://arxiv.org/pdf/2106.08265.pdf).
3 |
4 |
5 | We measured accuracy and speed for percentage_coreset=0.01, 0.1 and 0.25.
6 |
7 | This code was implemented with reference to [patchcore-inspection](https://github.com/amazon-science/patchcore-inspection); thanks to its authors.
8 |
9 | Some "ex" are **ex**plainability, **ex**press delivery, fl**ex**ibility, **ex**tra algorithms, and so on.
10 |
11 |
12 |
13 | ## Prerequisites
14 |
15 | - faiss-gpu (easy to install with conda : [ref](https://github.com/facebookresearch/faiss/blob/main/INSTALL.md))
16 | - torch
17 | - torchvision
18 | - numpy
19 | - opencv-python
20 | - scipy
21 | - argparse
22 | - matplotlib
23 | - scikit-learn
24 | - torchinfo
25 | - tqdm
26 |
27 |
28 | Install prerequisites with:
29 | ```
30 | conda install --file requirements.txt
31 | ```
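
If you prefer pip instead, the same list should generally install with the command below, though `faiss-gpu` itself is easiest to install via conda as linked above:
```
pip install -r requirements.txt
```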
32 |
33 |
34 |
35 | Please download the [`MVTec AD`](https://www.mvtec.com/company/research/datasets/mvtec-ad/) dataset.
36 |
37 | After downloading, place the data as follows:
38 | ```
39 | ./
40 | ├── main.py
41 | └── mvtec_anomaly_detection
42 | ├── bottle
43 | │ ├── ground_truth # pixel annotation of test data other than good
44 | │ │ ├── broken_large
45 | │ │ ├── broken_small
46 | │ │ └── contamination
47 | │ ├── test # other than good is anomaly data
48 | │ │ ├── broken_large
49 | │ │ ├── broken_small
50 | │ │ ├── contamination
51 | │ │ └── good
52 | │ └── train # good only
53 | │ └── good
54 |     ├── cable
56 | ├── capsule
57 | ├── carpet
58 | ├── grid
59 | ├── hazelnut
60 | ├── leather
61 | ├── metal_nut
62 | ├── pill
63 | ├── screw
64 | ├── tile
65 | ├── toothbrush
66 | ├── transistor
67 | ├── wood
68 | └── zipper
69 | ```
70 |
71 |
72 |
73 | When using a custom dataset, place the data as follows:
74 | ```
75 | ./
76 | ├── main.py
77 | └── custom_data
78 | └── theme
79 | ├── ground_truth # pixel annotation of test data other than good
80 | │ ├── anomaly_type_a
81 | │ └── anomaly_type_b
82 | ├── test # other than good is anomaly data
83 | │ ├── anomaly_type_a
84 | │ ├── anomaly_type_b
85 | │ └── good
86 | └── train # good only
87 | └── good
88 | ```
89 |
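
Ground-truth masks are matched to test images by path substitution (see `datasets/shared_memory.py`): the mask for a test image is expected at the same relative location under `ground_truth/`, with `_mask` appended to the file name. For example (file names are illustrative):
```
custom_data/theme/test/anomaly_type_a/000.png
custom_data/theme/ground_truth/anomaly_type_a/000_mask.png
```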
90 |
91 |
92 | ## Usage
93 |
94 | To test **PatchCore** on the `MVTec AD` dataset:
95 | ```
96 | python main.py
97 | ```
98 |
99 | After running the code above, you can see the ROCAUC results in `result/roc_curve.png`.
100 |
101 |
102 |
103 | To test **PatchCore** on a custom dataset:
104 | ```
105 | python main.py --path_data ./custom_data
106 | ```
107 |
108 |
109 |
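For inference-only runs, `infer.py` restores the trained artifacts (`args.json`, the faiss index, and the `<type>_thr.txt` threshold) from `--path_trained` and writes per-image judgements to `<path_result>/<type>/<type>_result.txt`. A typical invocation (flags as defined in `infer.py`; the paths shown are the defaults):
```
python infer.py --path_data ./mvtec_anomaly_detection --path_trained ./trained --path_result ./result --verbose
```

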
110 | ## Results
111 |
112 | Below are the test-set ROCAUC results of this implementation on the `MVTec AD` dataset.
113 |
114 | ### 1. Image-level anomaly detection accuracy (ROCAUC %)
115 |
116 | | | Paper<br>$\\%_{core}$=0.1 | This Repo<br>$\\%_{core}$=0.01 | This Repo<br>$\\%_{core}$=0.1 | This Repo<br>$\\%_{core}$=0.25 |
117 | | - | - | - | - | - |
118 | | bottle | 100.0 | 100.0 | 100.0 | 100.0 |
119 | | cable | 99.4 | 99.6 | 99.9 | 99.6 |
120 | | capsule | 97.8 | 97.6 | 97.8 | 97.8 |
121 | | carpet | 98.7 | 98.3 | 98.7 | 98.5 |
122 | | grid | 97.9 | 97.8 | 98.5 | 98.4 |
123 | | hazelnut | 100.0 | 100.0 | 100.0 | 100.0 |
124 | | leather | 100.0 | 100.0 | 100.0 | 100.0 |
125 | | metal_nut | 100.0 | 100.0 | 100.0 | 99.9 |
126 | | pill | 96.0 | 96.3 | 96.2 | 96.4 |
127 | | screw | 97.0 | 97.9 | 98.5 | 98.1 |
128 | | tile | 98.9 | 99.0 | 99.4 | 99.2 |
129 | | toothbrush | 99.7 | 99.4 | 99.4 | 100.0 |
130 | | transistor | 100.0 | 99.8 | 100.0 | 100.0 |
131 | | wood | 99.0 | 99.1 | 99.0 | 98.9 |
132 | | zipper | 99.5 | 99.7 | 99.6 | 99.5 |
133 | | Average | 99.0 | 99.0 | 99.1 | 99.1 |
134 |
135 |
136 |
137 | ### 2. Pixel-level anomaly detection accuracy (ROCAUC %)
138 |
139 | | | Paper<br>$\\%_{core}$=0.1 | This Repo<br>$\\%_{core}$=0.01 | This Repo<br>$\\%_{core}$=0.1 | This Repo<br>$\\%_{core}$=0.25 |
140 | | - | - | - | - | - |
141 | | bottle | 98.6 | 98.5 | 98.6 | 98.6 |
142 | | cable | 98.5 | 98.2 | 98.4 | 98.4 |
143 | | capsule | 98.9 | 98.8 | 98.9 | 98.9 |
144 | | carpet | 99.1 | 99.0 | 99.1 | 99.0 |
145 | | grid | 98.7 | 98.2 | 98.7 | 98.7 |
146 | | hazelnut | 98.7 | 98.6 | 98.7 | 98.6 |
147 | | leather | 99.3 | 99.3 | 99.3 | 99.3 |
148 | | metal_nut | 98.4 | 98.5 | 98.7 | 98.7 |
149 | | pill | 97.6 | 97.4 | 97.6 | 97.3 |
150 | | screw | 99.4 | 98.8 | 99.4 | 99.4 |
151 | | tile | 95.9 | 96.2 | 96.1 | 95.9 |
152 | | toothbrush | 98.7 | 98.6 | 98.7 | 98.7 |
153 | | transistor | 96.4 | 94.2 | 96.0 | 95.8 |
154 | | wood | 95.1 | 95.7 | 95.5 | 95.4 |
155 | | zipper | 98.9 | 98.9 | 98.9 | 98.9 |
156 | | Average | 98.1 | 97.9 | 98.2 | 98.1 |
157 |
158 |
159 |
160 | ### 3. Processing time (sec)
161 |
162 | | | This Repo<br>$\\%_{core}$=0.01 | This Repo<br>$\\%_{core}$=0.1 | This Repo<br>$\\%_{core}$=0.25 |
163 | | - | - | - | - |
164 | | bottle | 14.2 | 21.3 | 33.0 |
165 | | cable | 22.9 | 30.3 | 44.6 |
166 | | capsule | 20.4 | 27.3 | 39.8 |
167 | | carpet | 22.4 | 32.6 | 52.7 |
168 | | grid | 16.3 | 26.3 | 44.3 |
169 | | hazelnut | 25.9 | 46.4 | 82.1 |
170 | | leather | 20.8 | 28.5 | 44.0 |
171 | | metal_nut | 16.8 | 23.8 | 36.3 |
172 | | pill | 23.3 | 34.2 | 52.6 |
173 | | screw | 22.4 | 37.7 | 63.9 |
174 | | tile | 18.4 | 26.2 | 40.7 |
175 | | toothbrush | 5.8 | 6.7 | 8.0 |
176 | | transistor | 17.2 | 23.9 | 35.9 |
177 | | wood | 18.0 | 26.5 | 42.1 |
178 | | zipper | 20.5 | 29.0 | 44.2 |
179 |
180 | ```
181 | CPU : Intel(R) Xeon(R) Gold 6148 CPU @ 2.40GHz
182 | GPU : Tesla V100 SXM2
183 | ```
184 |
185 |
186 |
187 | ### ROC Curve
188 |
189 | - percentage_coreset = 0.01
190 | 
191 |
192 |
193 |
194 | - percentage_coreset = 0.1
195 | 
196 |
197 |
198 |
199 | - percentage_coreset = 0.25
200 | 
201 |
202 |
203 |
204 | ### Prediction Distribution (percentage_coreset = 0.1)
205 |
206 | - bottle
207 | 
208 |
209 | - cable
210 | 
211 |
212 | - capsule
213 | 
214 |
215 | - carpet
216 | 
217 |
218 | - grid
219 | 
220 |
221 | - hazelnut
222 | 
223 |
224 | - leather
225 | 
226 |
227 | - metal_nut
228 | 
229 |
230 | - pill
231 | 
232 |
233 | - screw
234 | 
235 |
236 | - tile
237 | 
238 |
239 | - toothbrush
240 | 
241 |
242 | - transistor
243 | 
244 |
245 | - wood
246 | 
247 |
248 | - zipper
249 | 
250 |
251 |
252 |
253 | ### Localization (percentage_coreset = 0.1)
254 |
255 | - bottle (test case : broken_large)
256 | 
257 |
258 | - cable (test case : bent_wire)
259 | 
260 |
261 | - capsule (test case : crack)
262 | 
263 |
264 | - carpet (test case : color)
265 | 
266 |
267 | - grid (test case : bent)
268 | 
269 |
270 | - hazelnut (test case : crack)
271 | 
272 |
273 | - leather (test case : color)
274 | 
275 |
276 | - metal_nut (test case : bent)
277 | 
278 |
279 | - pill (test case : color)
280 | 
281 |
282 | - screw (test case : manipulated_front)
283 | 
284 |
285 | - tile (test case : crack)
286 | 
287 |
288 | - toothbrush (test case : defective)
289 | 
290 |
291 | - transistor (test case : bent_lead)
292 | 
293 |
294 | - wood (test case : color)
295 | 
296 |
297 | - zipper (test case : broken_teeth)
298 | 
299 |
300 |
301 |
302 | ### For your information
303 |
304 | We have also implemented a similar algorithm, SPADE:
305 | https://github.com/any-tech/SPADE-fast/tree/main
306 |
307 | There is also an explanatory article:
308 | https://tech.anytech.co.jp/entry/2023/03/24/100000
309 |
--------------------------------------------------------------------------------
/assets/localization_bottle_broken_large_000_p0100_k01_s067.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/any-tech/PatchCore-ex/15a46f13f38a6ddd1e615c8bc7a19433c03fee1f/assets/localization_bottle_broken_large_000_p0100_k01_s067.png
--------------------------------------------------------------------------------
/assets/localization_cable_bent_wire_000_p0100_k01_s073.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/any-tech/PatchCore-ex/15a46f13f38a6ddd1e615c8bc7a19433c03fee1f/assets/localization_cable_bent_wire_000_p0100_k01_s073.png
--------------------------------------------------------------------------------
/assets/localization_capsule_crack_000_p0100_k01_s020.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/any-tech/PatchCore-ex/15a46f13f38a6ddd1e615c8bc7a19433c03fee1f/assets/localization_capsule_crack_000_p0100_k01_s020.png
--------------------------------------------------------------------------------
/assets/localization_carpet_color_000_p0100_k01_s070.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/any-tech/PatchCore-ex/15a46f13f38a6ddd1e615c8bc7a19433c03fee1f/assets/localization_carpet_color_000_p0100_k01_s070.png
--------------------------------------------------------------------------------
/assets/localization_grid_bent_000_p0100_k01_s048.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/any-tech/PatchCore-ex/15a46f13f38a6ddd1e615c8bc7a19433c03fee1f/assets/localization_grid_bent_000_p0100_k01_s048.png
--------------------------------------------------------------------------------
/assets/localization_hazelnut_crack_000_p0100_k01_s047.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/any-tech/PatchCore-ex/15a46f13f38a6ddd1e615c8bc7a19433c03fee1f/assets/localization_hazelnut_crack_000_p0100_k01_s047.png
--------------------------------------------------------------------------------
/assets/localization_leather_color_000_p0100_k01_s058.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/any-tech/PatchCore-ex/15a46f13f38a6ddd1e615c8bc7a19433c03fee1f/assets/localization_leather_color_000_p0100_k01_s058.png
--------------------------------------------------------------------------------
/assets/localization_metal_nut_bent_000_p0100_k01_s082.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/any-tech/PatchCore-ex/15a46f13f38a6ddd1e615c8bc7a19433c03fee1f/assets/localization_metal_nut_bent_000_p0100_k01_s082.png
--------------------------------------------------------------------------------
/assets/localization_pill_color_000_p0100_k01_s036.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/any-tech/PatchCore-ex/15a46f13f38a6ddd1e615c8bc7a19433c03fee1f/assets/localization_pill_color_000_p0100_k01_s036.png
--------------------------------------------------------------------------------
/assets/localization_screw_manipulated_front_000_p0100_k01_s047.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/any-tech/PatchCore-ex/15a46f13f38a6ddd1e615c8bc7a19433c03fee1f/assets/localization_screw_manipulated_front_000_p0100_k01_s047.png
--------------------------------------------------------------------------------
/assets/localization_tile_crack_000_p0100_k01_s059.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/any-tech/PatchCore-ex/15a46f13f38a6ddd1e615c8bc7a19433c03fee1f/assets/localization_tile_crack_000_p0100_k01_s059.png
--------------------------------------------------------------------------------
/assets/localization_toothbrush_defective_000_p0100_k01_s079.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/any-tech/PatchCore-ex/15a46f13f38a6ddd1e615c8bc7a19433c03fee1f/assets/localization_toothbrush_defective_000_p0100_k01_s079.png
--------------------------------------------------------------------------------
/assets/localization_transistor_bent_lead_000_p0100_k01_s043.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/any-tech/PatchCore-ex/15a46f13f38a6ddd1e615c8bc7a19433c03fee1f/assets/localization_transistor_bent_lead_000_p0100_k01_s043.png
--------------------------------------------------------------------------------
/assets/localization_wood_color_000_p0100_k01_s063.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/any-tech/PatchCore-ex/15a46f13f38a6ddd1e615c8bc7a19433c03fee1f/assets/localization_wood_color_000_p0100_k01_s063.png
--------------------------------------------------------------------------------
/assets/localization_zipper_broken_teeth_000_p0100_k01_s029.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/any-tech/PatchCore-ex/15a46f13f38a6ddd1e615c8bc7a19433c03fee1f/assets/localization_zipper_broken_teeth_000_p0100_k01_s029.png
--------------------------------------------------------------------------------
/assets/pred-dist_bottle_p0100_k01_r1000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/any-tech/PatchCore-ex/15a46f13f38a6ddd1e615c8bc7a19433c03fee1f/assets/pred-dist_bottle_p0100_k01_r1000.png
--------------------------------------------------------------------------------
/assets/pred-dist_cable_p0100_k01_r0999.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/any-tech/PatchCore-ex/15a46f13f38a6ddd1e615c8bc7a19433c03fee1f/assets/pred-dist_cable_p0100_k01_r0999.png
--------------------------------------------------------------------------------
/assets/pred-dist_capsule_p0100_k01_r0978.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/any-tech/PatchCore-ex/15a46f13f38a6ddd1e615c8bc7a19433c03fee1f/assets/pred-dist_capsule_p0100_k01_r0978.png
--------------------------------------------------------------------------------
/assets/pred-dist_carpet_p0100_k01_r0987.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/any-tech/PatchCore-ex/15a46f13f38a6ddd1e615c8bc7a19433c03fee1f/assets/pred-dist_carpet_p0100_k01_r0987.png
--------------------------------------------------------------------------------
/assets/pred-dist_grid_p0100_k01_r0985.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/any-tech/PatchCore-ex/15a46f13f38a6ddd1e615c8bc7a19433c03fee1f/assets/pred-dist_grid_p0100_k01_r0985.png
--------------------------------------------------------------------------------
/assets/pred-dist_hazelnut_p0100_k01_r1000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/any-tech/PatchCore-ex/15a46f13f38a6ddd1e615c8bc7a19433c03fee1f/assets/pred-dist_hazelnut_p0100_k01_r1000.png
--------------------------------------------------------------------------------
/assets/pred-dist_leather_p0100_k01_r1000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/any-tech/PatchCore-ex/15a46f13f38a6ddd1e615c8bc7a19433c03fee1f/assets/pred-dist_leather_p0100_k01_r1000.png
--------------------------------------------------------------------------------
/assets/pred-dist_metal_nut_p0100_k01_r1000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/any-tech/PatchCore-ex/15a46f13f38a6ddd1e615c8bc7a19433c03fee1f/assets/pred-dist_metal_nut_p0100_k01_r1000.png
--------------------------------------------------------------------------------
/assets/pred-dist_pill_p0100_k01_r0962.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/any-tech/PatchCore-ex/15a46f13f38a6ddd1e615c8bc7a19433c03fee1f/assets/pred-dist_pill_p0100_k01_r0962.png
--------------------------------------------------------------------------------
/assets/pred-dist_screw_p0100_k01_r0985.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/any-tech/PatchCore-ex/15a46f13f38a6ddd1e615c8bc7a19433c03fee1f/assets/pred-dist_screw_p0100_k01_r0985.png
--------------------------------------------------------------------------------
/assets/pred-dist_tile_p0100_k01_r0994.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/any-tech/PatchCore-ex/15a46f13f38a6ddd1e615c8bc7a19433c03fee1f/assets/pred-dist_tile_p0100_k01_r0994.png
--------------------------------------------------------------------------------
/assets/pred-dist_toothbrush_p0100_k01_r0994.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/any-tech/PatchCore-ex/15a46f13f38a6ddd1e615c8bc7a19433c03fee1f/assets/pred-dist_toothbrush_p0100_k01_r0994.png
--------------------------------------------------------------------------------
/assets/pred-dist_transistor_p0100_k01_r1000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/any-tech/PatchCore-ex/15a46f13f38a6ddd1e615c8bc7a19433c03fee1f/assets/pred-dist_transistor_p0100_k01_r1000.png
--------------------------------------------------------------------------------
/assets/pred-dist_wood_p0100_k01_r0990.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/any-tech/PatchCore-ex/15a46f13f38a6ddd1e615c8bc7a19433c03fee1f/assets/pred-dist_wood_p0100_k01_r0990.png
--------------------------------------------------------------------------------
/assets/pred-dist_zipper_p0100_k01_r0996.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/any-tech/PatchCore-ex/15a46f13f38a6ddd1e615c8bc7a19433c03fee1f/assets/pred-dist_zipper_p0100_k01_r0996.png
--------------------------------------------------------------------------------
/assets/roc-curve_p0010_k01_rim0990_rpm0979.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/any-tech/PatchCore-ex/15a46f13f38a6ddd1e615c8bc7a19433c03fee1f/assets/roc-curve_p0010_k01_rim0990_rpm0979.png
--------------------------------------------------------------------------------
/assets/roc-curve_p0100_k01_rim0991_rpm0982.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/any-tech/PatchCore-ex/15a46f13f38a6ddd1e615c8bc7a19433c03fee1f/assets/roc-curve_p0100_k01_rim0991_rpm0982.png
--------------------------------------------------------------------------------
/assets/roc-curve_p0250_k01_rim0991_rpm0981.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/any-tech/PatchCore-ex/15a46f13f38a6ddd1e615c8bc7a19433c03fee1f/assets/roc-curve_p0250_k01_rim0991_rpm0981.png
--------------------------------------------------------------------------------
/assets/total_recall.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/any-tech/PatchCore-ex/15a46f13f38a6ddd1e615c8bc7a19433c03fee1f/assets/total_recall.jpg
--------------------------------------------------------------------------------
/datasets/mvtec_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import cv2
4 | from tqdm import tqdm
5 | from utils.config import ConfigData
6 | from datasets.shared_memory import SharedMemory
7 |
8 |
9 | class MVTecDataset:
10 | @classmethod
11 | def __init__(cls, type_data):
12 | cls.type_data = type_data
13 | cls.SHAPE_INPUT = ConfigData.SHAPE_INPUT
14 |
15 | cls.imgs_train = None
16 | cls.imgs_test = {}
17 | cls.gts_test = {}
18 |
19 | # read train data
20 | desc = 'read images for train (case:good)'
21 | path = '%s/%s/train/good' % (ConfigData.path_data, type_data)
22 | files = ['%s/%s' % (path, f) for f in os.listdir(path)
23 | if (os.path.isfile('%s/%s' % (path, f)) & (('.png' in f) | ('.jpg' in f)))]
24 | cls.files_train = np.sort(np.array(files))
25 | cls.imgs_train = SharedMemory.read_img_parallel(files=cls.files_train,
26 | imgs=cls.imgs_train,
27 | desc=desc)
28 | if ConfigData.shuffle:
29 | cls.imgs_train = np.random.permutation(cls.imgs_train)
30 | if ConfigData.flip_horz:
31 | cls.imgs_train = np.concatenate([cls.imgs_train,
32 | cls.imgs_train[:, :, ::-1]], axis=0)
33 | if ConfigData.flip_vert:
34 | cls.imgs_train = np.concatenate([cls.imgs_train,
35 | cls.imgs_train[:, ::-1]], axis=0)
36 |
37 | # read test data
38 | cls.files_test = {}
39 | cls.types_test = os.listdir('%s/%s/test' % (ConfigData.path_data, type_data))
40 | cls.types_test = np.array(sorted(cls.types_test))
41 | for type_test in cls.types_test:
42 | desc = 'read images for test (case:%s)' % type_test
43 | path = '%s/%s/test/%s' % (ConfigData.path_data, type_data, type_test)
44 | files = [('%s/%s' % (path, f)) for f in os.listdir(path)
45 | if (os.path.isfile('%s/%s' % (path, f)) & (('.png' in f) | ('.jpg' in f)))]
46 | cls.files_test[type_test] = np.sort(np.array(files))
47 | cls.imgs_test[type_test] = None
48 | cls.imgs_test[type_test] = SharedMemory.read_img_parallel(files=cls.files_test[type_test],
49 | imgs=cls.imgs_test[type_test],
50 | desc=desc)
51 |
52 | # read ground truth of test data
53 | for type_test in cls.types_test:
54 | # create memory shared variable
55 | if type_test == 'good':
56 | cls.gts_test[type_test] = np.zeros([len(cls.files_test[type_test]),
57 | ConfigData.SHAPE_INPUT[0],
58 | ConfigData.SHAPE_INPUT[1]], dtype=np.uint8)
59 | else:
60 | desc = 'read ground-truths for test (case:%s)' % type_test
61 | cls.gts_test[type_test] = None
62 | cls.gts_test[type_test] = SharedMemory.read_img_parallel(files=cls.files_test[type_test],
63 | imgs=cls.gts_test[type_test],
64 | is_gt=True,
65 | desc=desc)
66 |
67 |
68 | class MVTecDatasetOnlyTest:
69 | @classmethod
70 | def __init__(cls, type_data):
71 | if ConfigData.path_data is not None:
72 | cls.type_data = type_data
73 | cls.SHAPE_INPUT = ConfigData.SHAPE_INPUT
74 |
75 | # read test data
76 | cls.imgs_test = {}
77 | cls.files_test = {}
78 | cls.types_test = os.listdir('%s/%s/test' % (ConfigData.path_data, type_data))
79 | cls.types_test = np.array(sorted(cls.types_test))
80 | for type_test in cls.types_test:
81 | desc = 'read images for test (case:%s)' % type_test
82 | path = '%s/%s/test/%s' % (ConfigData.path_data, type_data, type_test)
83 | files = [('%s/%s' % (path, f)) for f in os.listdir(path)
84 | if (os.path.isfile('%s/%s' % (path, f)) & (('.png' in f) | ('.jpg' in f)))]
85 | cls.files_test[type_test] = np.sort(np.array(files))
86 | cls.imgs_test[type_test] = None
87 | cls.imgs_test[type_test] = SharedMemory.read_img_parallel(files=cls.files_test[type_test],
88 | imgs=cls.imgs_test[type_test],
89 | desc=desc)
90 | else:
91 | cls.type_data = type_data
92 | cls.SHAPE_INPUT = ConfigData.SHAPE_INPUT
93 |
94 | # read test data
95 | cls.imgs_test = {}
96 | cls.files_test = {}
97 | cls.types_test = np.array(['video'])
98 | type_test = cls.types_test[0]
99 | cls.imgs_test[type_test] = []
100 | cls.files_test[type_test] = []
101 | capture = cv2.VideoCapture(ConfigData.path_video)
102 | num_frame = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
103 | desc = 'read images from video for test'
104 | for i_frame in tqdm(range(num_frame), desc=desc):
105 | # frame read
106 | capture.grab()
107 | ret, frame = capture.retrieve()
108 | if ret:
109 | frame = frame[..., ::-1] # BGR2RGB
110 | frame = cv2.resize(frame, (ConfigData.SHAPE_MIDDLE[1],
111 | ConfigData.SHAPE_MIDDLE[0]),
112 | interpolation=cv2.INTER_AREA)
113 | frame = frame[ConfigData.pixel_cut[0]:(ConfigData.SHAPE_INPUT[0] +
114 | ConfigData.pixel_cut[0]),
115 | ConfigData.pixel_cut[1]:(ConfigData.SHAPE_INPUT[1] +
116 | ConfigData.pixel_cut[1])]
117 | cls.imgs_test[type_test].append(frame)
118 | cls.files_test[type_test].append('%s/frame/%05d' %
119 | (os.path.basename(ConfigData.path_video),
120 | i_frame))
121 | cls.imgs_test[type_test] = np.array(cls.imgs_test[type_test])
122 | cls.files_test[type_test] = np.array(cls.files_test[type_test])
123 |
--------------------------------------------------------------------------------
/datasets/shared_memory.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | from tqdm import tqdm
4 | import multiprocessing as mp
5 | from multiprocessing.sharedctypes import RawArray
6 | from utils.config import ConfigData
7 |
8 |
9 | def read_and_resize(file):
10 | img = cv2.imread(file)[..., ::-1] # BGR2RGB
11 | img = cv2.resize(img, (ConfigData.SHAPE_MIDDLE[1], ConfigData.SHAPE_MIDDLE[0]),
12 | interpolation=cv2.INTER_AREA)
13 | img = img[ConfigData.pixel_cut[0]:(ConfigData.SHAPE_INPUT[0] + ConfigData.pixel_cut[0]),
14 | ConfigData.pixel_cut[1]:(ConfigData.SHAPE_INPUT[1] + ConfigData.pixel_cut[1])]
15 | SharedMemory.shared_array[np.where(SharedMemory.files == file)[0]] = img
16 |
17 |
18 | def read_and_resize_ground_truth(file):
19 | file_gt = file.replace('/test/', '/ground_truth/')
20 | file_gt = file_gt.replace('.png', '_mask.png')
21 |
22 | gt = cv2.imread(file_gt, cv2.IMREAD_GRAYSCALE)
23 | gt = cv2.resize(gt, (ConfigData.SHAPE_MIDDLE[1], ConfigData.SHAPE_MIDDLE[0]),
24 | interpolation=cv2.INTER_NEAREST)
25 | gt = gt[ConfigData.pixel_cut[0]:(ConfigData.SHAPE_INPUT[0] + ConfigData.pixel_cut[0]),
26 | ConfigData.pixel_cut[1]:(ConfigData.SHAPE_INPUT[1] + ConfigData.pixel_cut[1])]
27 |
28 | if np.max(gt) != 0:
29 | gt = (gt / np.max(gt)).astype(np.uint8)
30 |
31 | SharedMemory.shared_array[np.where(SharedMemory.files == file)[0]] = gt
32 |
33 |
34 | class SharedMemory:
35 | @classmethod
36 | def read_img_parallel(cls, files, imgs, is_gt=False, desc='read images'):
37 | cls.files = files
38 | cls.shared_array = imgs
39 |
40 | if not is_gt:
41 | shape = (len(cls.files), ConfigData.SHAPE_INPUT[0], ConfigData.SHAPE_INPUT[1], 3)
42 | num_elm = shape[0] * shape[1] * shape[2] * shape[3]
43 | else:
44 | shape = (len(cls.files), ConfigData.SHAPE_INPUT[0], ConfigData.SHAPE_INPUT[1])
45 | num_elm = shape[0] * shape[1] * shape[2]
46 |
47 | ctype = np.ctypeslib.as_ctypes_type(np.dtype(np.uint8))
48 | data = np.ctypeslib.as_array(RawArray(ctype, num_elm))
49 | data.shape = shape
50 | cls.shared_array = data.view(np.uint8)
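        # RawArray allocates untyped shared memory; wrapping it in a numpy array
        # lets the forked workers decode images directly into it in place, so the
        # results come back without any per-image pickling or copying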
51 |
52 |         # run imread and resize in parallel across processes
53 |         mp.set_start_method('fork', force=True)  # fork so workers inherit the shared array
54 | p = mp.Pool(min(mp.cpu_count(), ConfigData.num_cpu_max))
55 |
56 | func = read_and_resize
57 | if is_gt:
58 | func = read_and_resize_ground_truth
59 |
60 | for _ in tqdm(p.imap_unordered(func, cls.files), total=len(cls.files), desc=desc):
61 | pass
62 |
63 | p.close()
64 |
65 | return cls.shared_array
66 |
--------------------------------------------------------------------------------
/infer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from argparse import ArgumentParser
4 | import json
5 |
6 | from utils.config import ConfigData, ConfigFeat, ConfigPatchCore, ConfigDraw
7 | from utils.tictoc import tic, toc
8 | from utils.visualize import draw_heatmap
9 |
10 | from datasets.mvtec_dataset import MVTecDatasetOnlyTest
11 |
12 | from models.feat_extract import FeatExtract
13 | from models.patchcore import PatchCore
14 |
15 |
16 | def arg_parser():
17 | parser = ArgumentParser()
18 | # environment related
19 | parser.add_argument('-n', '--num_cpu_max', default=4, type=int,
20 | help='number of CPUs for parallel reading input images')
21 | parser.add_argument('-c', '--cpu', action='store_true', help='use cpu')
22 |
23 | # I/O and visualization related
24 | parser.add_argument('-pd', '--path_data', type=str, default='./mvtec_anomaly_detection',
25 |                         help='parent directory of the input data')
26 | parser.add_argument('-pv', '--path_video', type=str, default=None,
27 |                         help='path of input video (.mp4, .avi, or .mov)')
28 | parser.add_argument('-pt', '--path_trained', type=str, default='./trained',
29 | help='output path of trained products')
30 | parser.add_argument('-pr', '--path_result', type=str, default='./result',
31 | help='output path of figure image as the evaluation result')
32 | parser.add_argument('-v', '--verbose', action='store_true',
33 | help='save visualization of localization')
34 | parser.add_argument('-sm', '--score_max', type=float, default=None,
35 | help='value for normalization of visualizing')
36 |
37 | # data loader related
38 | parser.add_argument('-bs', '--batch_size', type=int, default=16,
39 | help='batch-size for feature extraction by ImageNet model')
40 | parser.add_argument('-tt', '--types_data', nargs='*', type=str, default=None)
41 |
42 | # Nearest-Neighbor related
43 | parser.add_argument('-k', '--k', type=int, default=5,
44 | help='nearest neighbor\'s k for coreset searching')
45 |     # post-processing related
46 | parser.add_argument('-pod', '--pixel_outer_decay', type=int, default=0,
47 | help='number of outer pixels to decay anomaly score')
48 |
49 | args = parser.parse_args()
50 |
51 |     # adjust: when a video is given, the directory input is disabled
52 | if args.path_video is not None:
53 | args.path_data = None
54 |
55 | print('args =\n', args)
56 | return args
57 |
58 |
59 | def check_args(args):
60 | assert 0 < args.num_cpu_max < os.cpu_count()
61 | if args.path_video is None:
62 | assert args.path_data is not None
63 | assert os.path.isdir(args.path_data)
64 | else:
65 | assert os.path.isfile(args.path_video)
66 | assert ((args.path_video.split('.')[-1].lower() == 'mp4') |
67 | (args.path_video.split('.')[-1].lower() == 'avi') |
68 | (args.path_video.split('.')[-1].lower() == 'mov'))
69 | assert len(args.types_data) == 1
70 | if args.score_max is not None:
71 | assert args.score_max > 0
72 | assert args.batch_size > 0
73 | if args.types_data is not None:
74 | if args.path_video is None:
75 | for type_data in args.types_data:
76 | assert os.path.exists('%s/%s' % (args.path_data, type_data))
77 | assert args.k > 0
78 | assert args.pixel_outer_decay >= 0
79 |
80 |
81 | def summary_result(args, type_data, D, files_test, thresh):
82 | result = ['data-type filename anomaly-score threshold abnormal-judgement']
83 | for type_test in D.keys():
84 | for i_file in range(len(D[type_test])):
85 | D_max = np.max(D[type_test][i_file])
86 | flg_abnormal = int(thresh <= D_max)
87 |
88 | result.append('%s %s %.3f %.3f %d' %
89 | (type_data, files_test[type_test][i_file],
90 | D_max, thresh, flg_abnormal))
91 |
92 | filename_txt = '%s/%s/%s_result.txt' % (args.path_result, type_data, type_data)
93 | np.savetxt(filename_txt, result, fmt='%s')
94 |
95 |
96 | def apply_patchcore_inference(args, type_data, feat_ext, patchcore, cfg_draw):
97 | print('\n----> inference-only PatchCore processing in %s start' % type_data)
98 | tic()
99 |
100 | # read images
101 | MVTecDatasetOnlyTest(type_data)
102 |
103 | # load neighbor
104 | patchcore.reset_faiss_index()
105 | patchcore.load_faiss_index(type_data)
106 |
107 | # extract features
108 | feat_test = {}
109 | for type_test in MVTecDatasetOnlyTest.imgs_test.keys():
110 | feat_test[type_test] = feat_ext.extract(MVTecDatasetOnlyTest.imgs_test[type_test],
111 | case='test (case:%s)' % type_test)
112 |
113 |     # compute pixel-wise anomaly scores (localization)
114 | D, D_max, I = patchcore.localization(feat_test, feat_ext.HW_map())
115 |
116 | toc(tag=('----> inference-only PatchCore processing in %s end, elapsed time' % type_data))
117 |
118 | if args.verbose:
119 | imgs_coreset = patchcore.load_coreset(type_data)
120 |
121 | draw_heatmap(type_data, cfg_draw, D, None, D_max, I,
122 | MVTecDatasetOnlyTest.imgs_test, MVTecDatasetOnlyTest.files_test,
123 | imgs_coreset, feat_ext.HW_map())
124 |
125 | # load optimal threshold
126 | thresh = np.loadtxt('%s/%s_thr.txt' % (args.path_trained, type_data))
127 |
128 | # summary test result
129 |     summary_result(args, type_data, D, MVTecDatasetOnlyTest.files_test, thresh)
130 |
131 |
132 | def main(args):
133 | ConfigData(args, mode_train=False) # static define for speed-up
134 | cfg_feat = ConfigFeat(args, mode_train=False)
135 | cfg_patchcore = ConfigPatchCore(args, mode_train=False)
136 | cfg_draw = ConfigDraw(args, mode_train=False)
137 |
138 | with open('%s/args.json' % args.path_trained, mode='r') as f:
139 | args_trained = json.load(f)
140 | ConfigData.follow(args_trained)
141 | cfg_feat.follow(args_trained)
142 | cfg_patchcore.follow(args_trained)
143 | cfg_draw.follow(args_trained)
144 |
145 | feat_ext = FeatExtract(cfg_feat)
146 | patchcore = PatchCore(cfg_patchcore)
147 |
148 | os.makedirs(args.path_result, exist_ok=True)
149 | for type_data in ConfigData.types_data:
150 | os.makedirs('%s/%s' % (args.path_result, type_data), exist_ok=True)
151 |
152 | # loop for types of data
153 | for type_data in ConfigData.types_data:
154 | apply_patchcore_inference(args, type_data, feat_ext, patchcore, cfg_draw)
155 |
156 |
157 | if __name__ == '__main__':
158 | args = arg_parser()
159 | main(args)
160 |
--------------------------------------------------------------------------------
/models/feat_extract.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from tqdm import tqdm
4 | import torch
5 | import torch.nn.functional as F
6 | import torchvision
7 | from torchinfo import summary
8 |
9 |
10 | def print_weight_device(backbone):
11 | for name, module in backbone.named_modules():
12 | if hasattr(module, 'weight') and getattr(module, 'weight') is not None:
13 | weight_device = module.weight.device
14 | weight_shape = module.weight.shape
15 | print(f"{name}.weight - Device: {weight_device}, Shape: {weight_shape}")
16 |
17 | if hasattr(module, 'bias') and getattr(module, 'bias') is not None:
18 | bias_device = module.bias.device
19 | print(f"{name}.bias - Device: {bias_device}")
20 |
21 |
22 | class FeatExtract:
23 | def __init__(self, cfg_feat):
24 | self.device = cfg_feat.device
25 | self.batch_size = cfg_feat.batch_size
26 | self.shape_input = cfg_feat.SHAPE_INPUT
27 | self.layer_map = cfg_feat.layer_map
28 | self.layer_weight = cfg_feat.layer_weight
29 | self.size_patch = cfg_feat.size_patch
30 | self.dim_each_feat = cfg_feat.dim_each_feat
31 | self.dim_merge_feat = cfg_feat.dim_merge_feat
32 | self.MEAN = cfg_feat.MEAN
33 | self.STD = cfg_feat.STD
34 | self.merge_dst_index = self.layer_map.index(cfg_feat.layer_merge_ref)
35 |
36 | self.padding = int((self.size_patch - 1) / 2)
37 | self.stride = 1 # fixed temporarily...
38 |
39 | code = 'self.backbone = %s(weights=%s)' % (cfg_feat.backbone, cfg_feat.weight)
40 | exec(code)
41 | summary(self.backbone, input_size=(1, 3, *self.shape_input))
42 | self.backbone.eval()
43 |
44 |         # summary() can move the backbone onto CUDA, so call to(device) only after running it
45 | self.backbone.to(self.device)
46 |
47 | self.feat = []
48 | for layer_map in self.layer_map:
49 | code = 'self.backbone.%s.register_forward_hook(self.hook)' % layer_map
50 | exec(code)
51 |
52 | # dummy forward
53 | x = torch.zeros(1, 3, self.shape_input[0], self.shape_input[1]) # RGB
54 | x = x.to(self.device)
55 | self.feat = []
56 | with torch.no_grad():
57 | _ = self.backbone(x)
58 |
59 | # https://github.com/amazon-science/patchcore-inspection/blob/main/src/patchcore/patchcore.py#L295
60 | self.unfolder = torch.nn.Unfold(kernel_size=self.size_patch, stride=self.stride,
61 | padding=self.padding, dilation=1)
62 | self.patch_shapes = []
63 | for i in range(len(self.feat)):
64 | number_of_total_patches = []
65 | for s in self.feat[i].shape[-2:]:
66 | n_patches = (
67 | s + 2 * self.padding - 1 * (self.size_patch - 1) - 1
68 | ) / self.stride + 1
69 | number_of_total_patches.append(int(n_patches))
70 | self.patch_shapes.append(number_of_total_patches)
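            # n_patches above is the standard convolution output-size formula:
            # out = (in + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1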
71 |
72 | def hook(self, module, input, output):
73 | self.feat.append(output.detach().cpu())
74 |
75 | def HW_map(self):
76 | return self.patch_shapes[self.merge_dst_index]
77 |
78 | def normalize(self, input):
79 | x = torch.from_numpy(input.astype(np.float32))
80 | x = x.to(self.device)
81 | x = x / 255
82 | x = x - self.MEAN
83 | x = x / self.STD
84 | x = x.unsqueeze(0).permute(0, 3, 1, 2)
85 | return x
86 |
87 |     # returns the patchified features as a torch.Tensor of shape (len(imgs) * H * W, dim_merge_feat)
88 | def extract(self, imgs, case='train (case:good)', batch_size_patchfy=50, show_progress=True):
89 |         # extract features batch by batch via the registered forward hooks
90 | x_batch = []
91 | self.feat = []
92 | for i_img in tqdm(range(len(imgs)), desc='extract feature for %s' % case, disable=not show_progress):
93 | img = imgs[i_img]
94 | x = self.normalize(img)
95 | x_batch.append(x)
96 |
97 | if (len(x_batch) == self.batch_size) | (i_img == (len(imgs) - 1)):
98 | with torch.no_grad():
99 | _ = self.backbone(torch.vstack(x_batch))
100 | x_batch = []
101 |
102 | # adjust
103 | feat = []
104 | num_layer = len(self.layer_map)
105 | for i_layer_map in range(num_layer):
106 | feat.append(torch.vstack(self.feat[i_layer_map::num_layer]))
107 |
108 |         # patchify in batches (to avoid running out of memory)
109 | num_patch_per_image = self.HW_map()[0] * self.HW_map()[1]
110 | num_patch = len(imgs) * num_patch_per_image
111 | feat_patchfy = torch.zeros(num_patch, self.dim_merge_feat)
112 |
113 | num_patchfy_process = (len(feat) * 3) + 1
114 | num_iter = np.ceil(len(imgs) / batch_size_patchfy)
115 | pbar = tqdm(total=int(num_patchfy_process * num_iter), desc='patchfy feature', disable=not show_progress)
116 | for i_batch in range(0, len(imgs), batch_size_patchfy):
117 | feat_tmp = []
118 | for feat_layer in feat:
119 | feat_tmp.append(feat_layer[i_batch:(i_batch + batch_size_patchfy)])
120 | i_from = i_batch * num_patch_per_image
121 | i_to = (i_batch + batch_size_patchfy) * num_patch_per_image
122 | feat_patchfy[i_from:i_to] = self.patchfy(feat_tmp, pbar)
123 | pbar.close()
124 |
125 | return feat_patchfy
126 |
127 | def patchfy(self, feat, pbar, batch_size_interp=2000):
128 |
129 | with torch.no_grad():
130 | # unfold
131 | for i in range(len(feat)):
132 | _feat = feat[i]
133 | BC_before_unfold = _feat.shape[:2]
134 | # (B, C, H, W) -> (B, CPHPW, HW)
135 | _feat = self.unfolder(_feat)
136 | # (B, CPHPW, HW) -> (B, C, PH, PW, HW)
137 | _feat = _feat.reshape(*BC_before_unfold,
138 | self.size_patch, self.size_patch, -1)
139 |             # (B, C, PH, PW, HW) -> (B, HW, C, PH, PW)
140 | _feat = _feat.permute(0, 4, 1, 2, 3)
141 | feat[i] = _feat
142 | pbar.update(1)
143 |
144 | # expand small feat to fit large features
145 | for i in range(0, len(feat)):
146 | if i == self.merge_dst_index:
147 | continue
148 |
149 | _feat = feat[i]
150 | patch_dims = self.patch_shapes[i]
151 |             # (B, HW, C, PH, PW) -> (B, H, W, C, PH, PW)
152 | _feat = _feat.reshape(_feat.shape[0], patch_dims[0],
153 | patch_dims[1], *_feat.shape[2:])
154 | # (B, H, W, C, PH, PW) -> (B, C, PH, PW, H, W)
155 | _feat = _feat.permute(0, -3, -2, -1, 1, 2)
156 | perm_base_shape = _feat.shape
157 | # (B, C, PH, PW, H, W) -> (BCPHPW, H, W)
158 | _feat = _feat.reshape(-1, *_feat.shape[-2:])
159 | # (BCPHPW, H, W) -> (BCPHPW, H_max, W_max)
160 | feat_dst = torch.zeros([len(_feat), self.HW_map()[0], self.HW_map()[1]])
161 | for i_batch in range(0, len(_feat), batch_size_interp):
162 | feat_tmp = _feat[i_batch:(i_batch + batch_size_interp)]
163 | feat_tmp = feat_tmp.unsqueeze(1).to(self.device)
164 | feat_tmp = F.interpolate(feat_tmp,
165 | size=(self.HW_map()[0], self.HW_map()[1]),
166 | mode="bilinear", align_corners=False)
167 | feat_dst[i_batch:(i_batch + batch_size_interp)] = feat_tmp.squeeze(1).cpu()
168 | _feat = feat_dst
169 |             # equivalent single-pass interpolation, kept for reference:
170 |             # _feat = F.interpolate(_feat.unsqueeze(1),
171 |             #                       size=(self.HW_map()[0], self.HW_map()[1]),
172 |             #                       mode="bilinear", align_corners=False)
173 |             # _feat = _feat.squeeze(1)
175 | # (BCPHPW, H_max, W_max) -> (B, C, PH, PW, H_max, W_max)
176 | _feat = _feat.reshape(*perm_base_shape[:-2],
177 | self.HW_map()[0], self.HW_map()[1])
178 | # (B, C, PH, PW, H_max, W_max) -> (B, H_max, W_max, C, PH, PW)
179 | _feat = _feat.permute(0, -2, -1, 1, 2, 3)
180 | # (B, H_max, W_max, C, PH, PW) -> (B, H_maxW_max, C, PH, PW)
181 | _feat = _feat.reshape(len(_feat), -1, *_feat.shape[-3:])
182 | feat[i] = _feat
183 | pbar.update(1)
184 |
185 | # aggregate feature vectors
186 | # (B, HW, C, PH, PW) -> (BHW, C, PH, PW)
187 | feat = [x.reshape(-1, *x.shape[-3:]) for x in feat]
188 | pbar.update(1)
189 |
190 | # weight each layer, then apply the 1st adaptive average pooling per feature vector
191 | for i in range(len(feat)):
192 | _feat = feat[i] * self.layer_weight[i]
193 |
194 | # (BHW, C, PH, PW) -> (BHW, 1, CPHPW)
195 | _feat = _feat.reshape(len(_feat), 1, -1)
196 |
197 | # (BHW, 1, CPHPW) -> (BHW, D_e)
198 | _feat = F.adaptive_avg_pool1d(_feat, self.dim_each_feat).squeeze(1)
199 | feat[i] = _feat
200 | pbar.update(1)
201 |
202 | # stack the per-layer feature vectors and merge by the 2nd adaptive average pooling
203 | # list of (BHW, D_e) -> (BHW, L, D_e), where L = number of layers
204 | feat = torch.stack(feat, dim=1)
205 | # reshape and average-pool down to the merged feature dimension
206 | # batchsize x number_of_layers x input_dim -> batchsize x target_dim
207 | # (BHW, L*D_e) -> (BHW, D_m)
208 | feat = feat.reshape(len(feat), 1, -1)
209 | feat = F.adaptive_avg_pool1d(feat, self.dim_merge_feat)
210 | feat = feat.reshape(len(feat), -1)
211 | pbar.update(1)
212 |
213 | return feat
214 |
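# A minimal self-contained sketch (illustration only, toy shapes, two assumed
# layers) of the two-stage adaptive average pooling performed in patchfy():
# each (C, PH, PW) patch is flattened and pooled to dim_each_feat, then the
# per-layer vectors are stacked and pooled again to dim_merge_feat.
if __name__ == '__main__':
    import torch
    import torch.nn.functional as F
    feats = [torch.randn(10, 512, 3, 3), torch.randn(10, 1024, 3, 3)]
    dim_each, dim_merge = 1024, 1024
    # 1st pooling: flatten each patch and pool C*PH*PW down to dim_each
    pooled = [F.adaptive_avg_pool1d(f.reshape(len(f), 1, -1), dim_each).squeeze(1)
              for f in feats]
    # 2nd pooling: stack the layer vectors and pool down to dim_merge
    merged = torch.stack(pooled, dim=1).reshape(len(feats[0]), 1, -1)
    merged = F.adaptive_avg_pool1d(merged, dim_merge).reshape(len(feats[0]), -1)
    print(merged.shape)  # torch.Size([10, 1024])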
--------------------------------------------------------------------------------
/models/patchcore.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | import cv2
5 | import torch
6 | from tqdm import tqdm
7 | import faiss
8 | from scipy.ndimage import gaussian_filter
9 |
10 |
11 | class PatchCore:
12 | def __init__(self, cfg_patchcore):
13 | self.device = cfg_patchcore.device
14 | self.k = cfg_patchcore.k
15 | self.dim_coreset_feat = cfg_patchcore.dim_coreset_feat
16 | self.num_split_seq = cfg_patchcore.num_split_seq
17 | self.percentage_coreset = cfg_patchcore.percentage_coreset
18 | self.dim_sampling = cfg_patchcore.dim_sampling
19 | self.num_initial_coreset = cfg_patchcore.num_initial_coreset
20 | self.shape_stretch = cfg_patchcore.shape_stretch
21 | self.pixel_outer_decay = cfg_patchcore.pixel_outer_decay
22 |
23 | # prep knn idx
24 | if self.device.type != 'cuda':
25 | self.idx_feat = faiss.IndexFlatL2(self.dim_coreset_feat)
26 | else:
27 | self.idx_feat = faiss.GpuIndexFlatL2(faiss.StandardGpuResources(),
28 | self.dim_coreset_feat,
29 | faiss.GpuIndexFlatConfig())
30 |
31 | if self.dim_sampling is not None:
32 | # prep mapper
33 | self.mapper = torch.nn.Linear(self.dim_coreset_feat, self.dim_sampling,
34 | bias=False).to(self.device)
35 |
36 | self.path_trained = cfg_patchcore.path_trained
37 | os.makedirs(self.path_trained, exist_ok=True)
38 |
42 | def compute_greedy_coreset_idx(self, feat):
43 | feat = feat.to(self.device)
44 | with torch.no_grad():
45 | feat_proj = self.mapper(feat) if self.dim_sampling is not None else feat
46 |
47 | _num_initial_coreset = np.clip(self.num_initial_coreset,
48 | None, len(feat_proj))
49 | start_points = np.random.choice(len(feat_proj), _num_initial_coreset,
50 | replace=False).tolist()
51 |
52 | # computes batchwise Euclidean distances using PyTorch
53 | mat_A = feat_proj
54 | mat_B = feat_proj[start_points]
55 | A_x_A = mat_A.unsqueeze(1).bmm(mat_A.unsqueeze(2)).reshape(-1, 1)
56 | B_x_B = mat_B.unsqueeze(1).bmm(mat_B.unsqueeze(2)).reshape(1, -1)
57 | A_x_B = mat_A.mm(mat_B.T)
58 | # sqrt is unnecessary: the greedy argmax is unchanged by squaring
59 | mat_dist = (-2 * A_x_B + A_x_A + B_x_B).clamp(0, None)
60 |
61 | dist_coreset_anchor = torch.mean(mat_dist, dim=-1, keepdim=True)
62 |
63 | idx_coreset = []
64 | num_coreset_samples = int(len(feat_proj) * self.percentage_coreset)
65 |
66 | with torch.no_grad():
67 | for _ in tqdm(range(num_coreset_samples), desc="sampling"):
68 | idx_select = torch.argmax(dist_coreset_anchor).item()
69 | idx_coreset.append(idx_select)
70 |
71 | mat_A = feat_proj
72 | mat_B = feat_proj[[idx_select]]
73 | # computes batchwise Euclidean distances using PyTorch
74 | A_x_A = mat_A.unsqueeze(1).bmm(mat_A.unsqueeze(2)).reshape(-1, 1)
75 | B_x_B = mat_B.unsqueeze(1).bmm(mat_B.unsqueeze(2)).reshape(1, -1)
76 | A_x_B = mat_A.mm(mat_B.T)
78 | # sqrt is unnecessary (monotonic in the squared distance)
78 | mat_select_dist = (-2 * A_x_B + A_x_A + B_x_B).clamp(0, None)
79 |
80 | dist_coreset_anchor = torch.cat([dist_coreset_anchor,
81 | mat_select_dist], dim=-1)
82 | dist_coreset_anchor = torch.min(dist_coreset_anchor,
83 | dim=1).values.reshape(-1, 1)
84 |
85 | idx_coreset = np.array(idx_coreset)
86 | return idx_coreset
87 |
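    # Sanity-check sketch (illustration only): the batched squared distances
    # above rely on ||a - b||^2 = a.a + b.b - 2 a.b, which is why no sqrt is
    # needed for the greedy argmax/argmin. For small toy tensors,
    #   A, B = torch.randn(5, 4), torch.randn(3, 4)
    #   AA = A.unsqueeze(1).bmm(A.unsqueeze(2)).reshape(-1, 1)
    #   BB = B.unsqueeze(1).bmm(B.unsqueeze(2)).reshape(1, -1)
    #   dist2 = (-2 * A.mm(B.T) + AA + BB).clamp(0, None)
    # matches torch.cdist(A, B).pow(2) up to floating-point error.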
88 | def reset_faiss_index(self):
89 | self.idx_feat.reset()
90 |
91 | def add_neighbor(self, feat_train):
92 | self.idx_feat.add(feat_train.numpy())
93 |
94 | def localization(self, feat_test, HW_map, show_progress=True):
95 | D = {}
96 | D_max = -np.inf
97 | I = {}
98 | # loop for test cases
99 | for type_test in feat_test.keys():
100 | D[type_test] = []
101 | I[type_test] = []
102 |
103 | # loop for test data
104 | _feat_test = feat_test[type_test]
105 | _feat_test = _feat_test.reshape(-1, (HW_map[0] * HW_map[1]),
106 | self.dim_coreset_feat)
107 | _feat_test = _feat_test.numpy()
108 | num_data = len(_feat_test)
109 | for i in tqdm(range(num_data), desc='localization (case:%s)' % type_test,
110 | disable=not show_progress):
111 | # measure distance pixelwise
112 | score_map, I_tmp = self.measure_dist_pixelwise(_feat_test[i], HW_map)
113 | # decay the scores of outer pixels (provisional heuristic)
114 | if self.pixel_outer_decay > 0:
115 | score_map[:self.pixel_outer_decay, :] *= 0.6
116 | score_map[-self.pixel_outer_decay:, :] *= 0.6
117 | score_map[:, :self.pixel_outer_decay] *= 0.6
118 | score_map[:, -self.pixel_outer_decay:] *= 0.6
119 | # store score map
120 | D[type_test].append(score_map)
121 | D_max = max(D_max, np.max(score_map))
122 | I[type_test].append(I_tmp)
123 |
124 | # cast list to numpy array
125 | D[type_test] = np.array(D[type_test])
126 | I[type_test] = np.array(I[type_test])
127 |
128 | return D, D_max, I
129 |
130 | def measure_dist_pixelwise(self, feat_test, HW_map):
131 | # k nearest neighbor
132 | D, I = self.idx_feat.search(feat_test, self.k)
133 | D = np.mean(D, axis=-1)
134 |
135 | # transform to scoremap
136 | score_map = D.reshape(*HW_map)
137 | score_map = cv2.resize(score_map, (self.shape_stretch[1], self.shape_stretch[0]))
138 |
139 | # apply gaussian smoothing on the score map
140 | score_map_smooth = gaussian_filter(score_map, sigma=4)
141 |
142 | return score_map_smooth, I
143 |
144 | def save_faiss_index(self, type_data):
145 | path_faiss_idx = '%s/%s.idx' % (self.path_trained, type_data)
146 | idx_feat_cpu = faiss.index_gpu_to_cpu(self.idx_feat) if self.device.type == 'cuda' else self.idx_feat
147 | faiss.write_index(idx_feat_cpu, path_faiss_idx)
148 |
149 | def load_faiss_index(self, type_data):
150 | path_faiss_idx = '%s/%s.idx' % (self.path_trained, type_data)
151 | self.idx_feat = faiss.read_index(path_faiss_idx)
152 |
153 | def load_faiss_index_direct(self, path_faiss_idx):
154 | self.idx_feat = faiss.read_index(path_faiss_idx)
155 |
156 | def pickup_patch(self, idx_patch, imgs, HW_map, size_receptive):
157 | h = imgs.shape[-3]
158 | w = imgs.shape[-2]
159 |
160 | # calculate half size for split
161 | h_half = int((size_receptive - 1) / 2)
162 | w_half = int((size_receptive - 1) / 2)
163 |
164 | # calculate center-coordinates of split-image
165 | y_pitch = np.arange(0, (h - 1 + 1e-10), ((h - 1) / (HW_map[0] - 1)))
166 | y_pitch = np.round(y_pitch).astype(np.int16)
167 | y_pitch = y_pitch + h_half
168 | x_pitch = np.arange(0, (w - 1 + 1e-10), ((w - 1) / (HW_map[1] - 1)))
169 | x_pitch = np.round(x_pitch).astype(np.int16)
170 | x_pitch = x_pitch + w_half
171 | # pad images so that border patches can be cropped
172 | imgs = np.pad(imgs, ((0, 0), (h_half, h_half), (w_half, w_half), (0, 0)))
173 |
174 | # collect patch piece images
175 | img_piece_list = []
176 | for i_patch in idx_patch:
177 | i_img = i_patch // (HW_map[0] * HW_map[1])
178 | i_HW = i_patch % (HW_map[0] * HW_map[1])
179 | i_H = i_HW // HW_map[1]
180 | i_W = i_HW % HW_map[1]
181 |
182 | img = imgs[i_img]
183 | y = y_pitch[i_H]
184 | x = x_pitch[i_W]
185 | img_piece = img[(y - h_half):(y + h_half + 1), (x - w_half):(x + w_half + 1)]
186 | img_piece_list.append(img_piece)
187 |
188 | img_piece_array = np.stack(img_piece_list)
189 |
190 | return img_piece_array
191 |
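    # Worked example (assumed values) of the flat patch-index decomposition in
    # pickup_patch() above: with HW_map = (28, 28), i_patch = 1650 maps to
    #   i_img = 1650 // 784 = 2, i_HW = 1650 % 784 = 82,
    #   i_H = 82 // 28 = 2, i_W = 82 % 28 = 26,
    # i.e. the 3rd training image, grid row 2, grid column 26.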
192 | def save_coreset(self, idx_coreset, type_data, imgs_train, HW_map, size_receptive):
193 | # save patch images of coreset
194 | imgs_coreset = self.pickup_patch(idx_coreset, imgs_train, HW_map, size_receptive)
195 | path_coreset_img = '%s/%s_img_coreset.npy' % (self.path_trained, type_data)
196 | np.save(path_coreset_img, imgs_coreset)
197 |
198 | return imgs_coreset
199 |
200 | def load_coreset(self, type_data):
201 | # load patch images of coreset
202 | path_coreset_img = '%s/%s_img_coreset.npy' % (self.path_trained, type_data)
203 | imgs_coreset = np.load(path_coreset_img)
204 |
205 | return imgs_coreset
206 |
207 |
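# Minimal k-NN usage sketch (toy data, illustration only) of the faiss
# memory-bank search performed in measure_dist_pixelwise(); note that
# IndexFlatL2 returns squared L2 distances.
if __name__ == '__main__':
    rng = np.random.default_rng(0)
    index = faiss.IndexFlatL2(8)                               # feature dim 8
    index.add(rng.random((100, 8), dtype=np.float32))          # memory bank
    D_demo, I_demo = index.search(rng.random((4, 8), dtype=np.float32), 5)
    print(D_demo.shape, I_demo.shape)  # (4, 5): distances and neighbor ids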
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | faiss-gpu
2 | matplotlib
3 | numpy
4 | numba
5 | opencv-python
6 | scikit-learn
7 | scipy
8 | torch
9 | torchinfo
10 | torchvision
11 | tqdm
12 |
--------------------------------------------------------------------------------
/traintest.py:
--------------------------------------------------------------------------------
1 | import os
2 | from argparse import ArgumentParser
3 | import json
4 |
5 | import numpy as np
6 | import torch
7 | import multiprocessing as mp
8 |
9 | from utils.config import ConfigData, ConfigFeat, ConfigPatchCore, ConfigDraw
10 | from utils.tictoc import tic, toc
11 | from utils.metrics import calc_imagewise_metrics, calc_pixelwise_metrics
12 | from utils.visualize import draw_curve, draw_distance_graph, draw_heatmap
13 | from datasets.mvtec_dataset import MVTecDataset
14 | from models.feat_extract import FeatExtract
15 | from models.patchcore import PatchCore
16 |
17 |
18 | def arg_parser():
19 | parser = ArgumentParser()
20 | # environment related
21 | parser.add_argument('-n', '--num_cpu_max', default=4, type=int,
22 | help='number of CPUs for parallel reading input images')
23 | parser.add_argument('-c', '--cpu', action='store_true', help='use cpu')
24 |
25 | # I/O and visualization related
26 | parser.add_argument('-pd', '--path_data', type=str, default='./mvtec_anomaly_detection',
27 | help='parent path of input data path')
28 | parser.add_argument('-pt', '--path_trained', type=str, default='./trained',
29 | help='output path of trained products')
30 | parser.add_argument('-pr', '--path_result', type=str, default='./result',
31 | help='output path of figure image as the evaluation result')
32 | parser.add_argument('-v', '--verbose', action='store_true',
33 | help='save visualization of localization')
34 | parser.add_argument('-srf', '--size_receptive', type=int, default=9,
35 | help='specify the estimated receptive field size (odd number)')
36 | parser.add_argument('-mv', '--mode_visualize', type=str, default='eval',
37 | choices=['eval', 'infer'], help='set mode, [eval] or [infer]')
38 | parser.add_argument('-sm', '--score_max', type=float, default=None,
39 | help='value for normalization of visualizing')
40 |
41 | # data loader related
42 | parser.add_argument('-bs', '--batch_size', type=int, default=16,
43 | help='batch-size for feature extraction by ImageNet model')
44 | parser.add_argument('-sr', '--size_resize', nargs=2, type=int, default=[256, 256],
45 | help='size of resizing input image')
46 | parser.add_argument('-sc', '--size_crop', nargs=2, type=int, default=[224, 224],
47 | help='size of cropping after resize')
48 | parser.add_argument('-fh', '--flip_horz', action='store_true', help='flip horizontal')
49 | parser.add_argument('-fv', '--flip_vert', action='store_true', help='flip vertical')
50 | parser.add_argument('-tt', '--types_data', nargs='*', type=str, default=None,
51 |                     help='data types to process (default: all found under path_data)')
51 | # feature extraction related
52 | parser.add_argument('-b', '--backbone', type=str,
53 | default='torchvision.models.wide_resnet50_2',
54 | help='specify torchvision model with the full path')
55 | parser.add_argument('-w', '--weight', type=str,
56 | default='torchvision.models.Wide_ResNet50_2_Weights.IMAGENET1K_V1',
57 | help='specify the trained weights of torchvision model with the full path')
58 | parser.add_argument('-lm', '--layer_map', nargs='+', type=str,
59 | default=['layer2[-1]', 'layer3[-1]'],
60 | help='specify layers to extract feature map')
61 | parser.add_argument('-lw', '--layer_weight', nargs='+', type=float,
62 | default=[1.0, 1.0],
63 | help='specify layers weights for merge of feature map')
64 | parser.add_argument('-lmr', '--layer_merge_ref', type=str, default='layer2[-1]',
65 | help='specify the layer to use as a reference for spatial size when merging feature maps')
66 |
67 | # patchification related
68 | parser.add_argument('-sp', '--size_patch', type=int, default=3,
69 | help='patch size on the feature map for increasing the receptive field')
70 | parser.add_argument('-de', '--dim_each_feat', type=int, default=1024,
71 | help='dimension of each extracted layer feature (at 1st adaptive average pooling)')
72 | parser.add_argument('-dm', '--dim_merge_feat', type=int, default=1024,
73 | help='dimension after layer feature merging (at 2nd adaptive average pooling)')
74 |
75 | # coreset related
76 | parser.add_argument('-s', '--seed', type=int, default=0,
77 | help='specify a random-seed for k-center-greedy')
78 | parser.add_argument('-ns', '--num_split_seq', type=int, default=1,
79 | help='number of sequential splits of training images for applying k-center-greedy')
80 | parser.add_argument('-pc', '--percentage_coreset', type=float, default=0.01,
81 | help='percentage of coreset to all patch features')
82 | parser.add_argument('-ds', '--dim_sampling', type=int, default=128,
83 | help='dimension to project features for sampling')
84 | parser.add_argument('-ni', '--num_initial_coreset', type=int, default=10,
85 | help='number of randomly selected initial samples for k-center-greedy')
86 | # Nearest-Neighbor related
87 | parser.add_argument('-k', '--k', type=int, default=5,
88 | help='nearest neighbor\'s k for coreset searching')
89 | # post-processing related
90 | parser.add_argument('-pod', '--pixel_outer_decay', type=int, default=0,
91 | help='number of outer pixels to decay anomaly score')
92 |
93 | args = parser.parse_args()
94 |
95 | print('args =\n', args)
96 | return args
97 |
98 |
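# Example invocation (hypothetical paths, for reference only):
#   python traintest.py -pd ./mvtec_anomaly_detection -pt ./trained -pr ./result -v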
99 | def check_args(args):
100 | assert 0 < args.num_cpu_max < os.cpu_count()
101 | assert os.path.isdir(args.path_data)
102 | assert (args.size_receptive % 2) == 1
103 | assert args.size_receptive > 0
104 | if args.score_max is not None:
105 | assert args.score_max > 0
106 | assert args.batch_size > 0
107 | assert args.size_resize[0] > 0
108 | assert args.size_resize[1] > 0
109 | assert args.size_crop[0] > 0
110 | assert args.size_crop[1] > 0
111 | if args.types_data is not None:
112 | for type_data in args.types_data:
113 | assert os.path.isdir('%s/%s' % (args.path_data, type_data))
114 | assert len(args.layer_map) == len(args.layer_weight)
115 | assert args.layer_merge_ref in args.layer_map
116 | assert args.size_patch > 0
117 | assert args.dim_each_feat > 0
118 | assert args.dim_merge_feat > 0
119 | assert args.num_split_seq > 0
120 | assert 0.0 < args.percentage_coreset <= 1.0
121 | assert args.dim_sampling > 0
122 | assert args.num_initial_coreset > 0
123 | assert args.k > 0
124 | assert args.pixel_outer_decay >= 0
125 |
126 |
127 | def set_seed(seed, gpu=True):
128 | np.random.seed(seed)
129 | torch.manual_seed(seed)
130 | if gpu:
131 | torch.cuda.manual_seed(seed)
132 | torch.cuda.manual_seed_all(seed)
133 |
134 |
135 | def apply_patchcore(type_data, feat_ext, patchcore, cfg_draw):
136 | print('\n----> PatchCore processing in %s start' % type_data)
137 | tic()
138 |
139 | # read images
140 | MVTecDataset(type_data)
141 |
142 | # reset neighbor
143 | patchcore.reset_faiss_index()
144 |
145 | # reset total index
146 | idx_coreset_total = []
147 |
148 | # loop of split-sequential to apply k-center-greedy
149 | num_pitch = int(np.ceil(len(MVTecDataset.imgs_train) / patchcore.num_split_seq))
150 | for i_split in range(patchcore.num_split_seq):
151 | # extract features
152 | i_from = i_split * num_pitch
153 | i_to = min(((i_split + 1) * num_pitch), len(MVTecDataset.imgs_train))
154 | if patchcore.num_split_seq > 1:
155 | print('[split%02d] image index range is %d~%d' % (i_split, i_from, (i_to - 1)))
156 | feat_train = feat_ext.extract(MVTecDataset.imgs_train[i_from:i_to])
157 |
158 | # coreset-reduced patch-feature memory bank
159 | idx_coreset = patchcore.compute_greedy_coreset_idx(feat_train)
160 | feat_train = feat_train[idx_coreset]
161 |
162 | # add feature as neighbor
163 | patchcore.add_neighbor(feat_train)
164 |
165 | # stock index
166 | offset_split = i_from * feat_ext.HW_map()[0] * feat_ext.HW_map()[1]
167 | idx_coreset_total.append(idx_coreset + offset_split)
168 |
169 | # save faiss index
170 | patchcore.save_faiss_index(type_data)
171 |
172 | # concat index
173 | idx_coreset_total = np.hstack(idx_coreset_total)
174 |
175 | # save and get images of coreset
176 | imgs_coreset = patchcore.save_coreset(idx_coreset_total, type_data,
177 | MVTecDataset.imgs_train, feat_ext.HW_map(),
178 | cfg_draw.size_receptive)
179 |
180 | # extract features
181 | feat_test = {}
182 | for type_test in MVTecDataset.imgs_test.keys():
183 | feat_test[type_test] = feat_ext.extract(MVTecDataset.imgs_test[type_test],
184 | case='test (case:%s)' % type_test)
185 |
186 | # pixelwise anomaly localization by nearest-neighbor search
187 | D, D_max, I = patchcore.localization(feat_test, feat_ext.HW_map())
188 |
189 | # measure per image
190 | fpr_img, tpr_img, rocauc_img, pre_img, rec_img, prauc_img = calc_imagewise_metrics(D)
191 | print('%s imagewise ROCAUC: %.3f' % (type_data, rocauc_img))
192 |
193 | # measure per pixel
194 | (fpr_pix, tpr_pix, rocauc_pix,
195 | pre_pix, rec_pix, prauc_pix, thresh_opt) = calc_pixelwise_metrics(D, MVTecDataset.gts_test)
196 | print('%s pixelwise ROCAUC: %.3f' % (type_data, rocauc_pix))
197 |
198 | # save optimal threshold
199 | np.savetxt('%s/%s_thr.txt' % (args.path_trained, type_data),
200 | np.array([thresh_opt]), fmt='%.3f')
201 |
202 | toc(tag=('----> PatchCore processing in %s end, elapsed time' % type_data))
203 |
204 | draw_distance_graph(type_data, cfg_draw, D, rocauc_img)
205 | if cfg_draw.verbose:
206 | draw_heatmap(type_data, cfg_draw, D, MVTecDataset.gts_test, D_max, I,
207 | MVTecDataset.imgs_test, MVTecDataset.files_test,
208 | imgs_coreset, feat_ext.HW_map())
209 |
210 | return [fpr_img, tpr_img, rocauc_img, pre_img, rec_img, prauc_img,
211 | fpr_pix, tpr_pix, rocauc_pix, pre_pix, rec_pix, prauc_pix]
212 |
213 |
214 | def main(args):
215 | ConfigData(args) # static define for speed-up
216 | cfg_feat = ConfigFeat(args)
217 | cfg_patchcore = ConfigPatchCore(args)
218 | cfg_draw = ConfigDraw(args)
219 |
220 | feat_ext = FeatExtract(cfg_feat)
221 | patchcore = PatchCore(cfg_patchcore)
222 |
223 | os.makedirs(args.path_result, exist_ok=True)
224 | for type_data in ConfigData.types_data:
225 | os.makedirs('%s/%s' % (args.path_result, type_data), exist_ok=True)
226 |
227 | fpr_img = {}
228 | tpr_img = {}
229 | rocauc_img = {}
230 | pre_img = {}
231 | rec_img = {}
232 | prauc_img = {}
233 | fpr_pix = {}
234 | tpr_pix = {}
235 | rocauc_pix = {}
236 | pre_pix = {}
237 | rec_pix = {}
238 | prauc_pix = {}
239 |
240 | # loop for types of data
241 | for type_data in ConfigData.types_data:
242 | set_seed(seed=args.seed, gpu=(not args.cpu))
243 |
244 | result = apply_patchcore(type_data, feat_ext, patchcore, cfg_draw)
245 |
246 | fpr_img[type_data] = result[0]
247 | tpr_img[type_data] = result[1]
248 | rocauc_img[type_data] = result[2]
249 | pre_img[type_data] = result[3]
250 | rec_img[type_data] = result[4]
251 | prauc_img[type_data] = result[5]
252 |
253 | fpr_pix[type_data] = result[6]
254 | tpr_pix[type_data] = result[7]
255 | rocauc_pix[type_data] = result[8]
256 | pre_pix[type_data] = result[9]
257 | rec_pix[type_data] = result[10]
258 | prauc_pix[type_data] = result[11]
259 |
260 | rocauc_img_mean = np.array([rocauc_img[type_data] for type_data in ConfigData.types_data])
261 | rocauc_img_mean = np.mean(rocauc_img_mean)
262 | prauc_img_mean = np.array([prauc_img[type_data] for type_data in ConfigData.types_data])
263 | prauc_img_mean = np.mean(prauc_img_mean)
264 | rocauc_pix_mean = np.array([rocauc_pix[type_data] for type_data in ConfigData.types_data])
265 | rocauc_pix_mean = np.mean(rocauc_pix_mean)
266 | prauc_pix_mean = np.array([prauc_pix[type_data] for type_data in ConfigData.types_data])
267 | prauc_pix_mean = np.mean(prauc_pix_mean)
268 |
269 | draw_curve(cfg_draw, fpr_img, tpr_img, rocauc_img, rocauc_img_mean,
270 | fpr_pix, tpr_pix, rocauc_pix, rocauc_pix_mean)
271 | draw_curve(cfg_draw, rec_img, pre_img, prauc_img, prauc_img_mean,
272 | rec_pix, pre_pix, prauc_pix, prauc_pix_mean, False)
273 |
274 | for type_data in ConfigData.types_data:
275 | print('rocauc_img[%s] = %.3f' % (type_data, rocauc_img[type_data]))
276 | print('rocauc_img[mean] = %.3f' % rocauc_img_mean)
277 | for type_data in ConfigData.types_data:
278 | print('prauc_img[%s] = %.3f' % (type_data, prauc_img[type_data]))
279 | print('prauc_img[mean] = %.3f' % prauc_img_mean)
280 | for type_data in ConfigData.types_data:
281 | print('rocauc_pix[%s] = %.3f' % (type_data, rocauc_pix[type_data]))
282 | print('rocauc_pix[mean] = %.3f' % rocauc_pix_mean)
283 | for type_data in ConfigData.types_data:
284 | print('prauc_pix[%s] = %.3f' % (type_data, prauc_pix[type_data]))
285 | print('prauc_pix[mean] = %.3f' % prauc_pix_mean)
286 |
287 | if __name__ == '__main__':
288 | args = arg_parser()
289 | check_args(args)
290 | main(args)
291 |
292 | with open('%s/args.json' % args.path_trained, mode='w') as f:
293 | json.dump(args.__dict__, f, indent=4)
294 |
295 |
--------------------------------------------------------------------------------
/utils/config.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import cv2
4 | import torch
5 |
6 | # https://pytorch.org/vision/main/models/generated/torchvision.models.wide_resnet50_2.html
7 | MEAN = torch.FloatTensor([[[0.485, 0.456, 0.406]]])
8 | STD = torch.FloatTensor([[[0.229, 0.224, 0.225]]])
9 |
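# Standard ImageNet normalization statistics; a sketch of the assumed usage on
# a channel-last float image in [0, 1]:
#   img_norm = (img - MEAN) / STD   # img: (H, W, 3) torch.FloatTensor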
10 |
11 | class ConfigData:
12 | @classmethod
13 | def __init__(cls, args, mode_train=True):
14 | # file reading related
15 | if mode_train:
16 | cls.path_data = args.path_data
17 | else:
18 | cls.path_data = args.path_data
19 | cls.path_video = args.path_video
20 | cls.num_cpu_max = args.num_cpu_max
21 | if mode_train:
22 | cls.shuffle = (args.num_split_seq > 1) # for k-center-greedy split
23 | else:
24 | cls.shuffle = None
25 |
26 | if mode_train:
27 | # input format related
28 | cls.SHAPE_MIDDLE = (args.size_resize[0], args.size_resize[1]) # (H, W)
29 | cls.SHAPE_INPUT = (args.size_crop[0], args.size_crop[1]) # (H, W)
30 | cls.pixel_cut = (int((cls.SHAPE_MIDDLE[0] - cls.SHAPE_INPUT[0]) / 2),
31 | int((cls.SHAPE_MIDDLE[1] - cls.SHAPE_INPUT[1]) / 2)) # (H, W)
32 | else:
33 | cls.SHAPE_MIDDLE = None
34 | cls.SHAPE_INPUT = None
35 | cls.pixel_cut = None
36 |
37 | if mode_train:
38 | # augmentation related
39 | cls.flip_horz = args.flip_horz
40 | cls.flip_vert = args.flip_vert
41 | else:
42 | cls.flip_horz = None
43 | cls.flip_vert = None
44 |
45 | # collect types of data
46 | if args.types_data is None:
47 | types_data = [d for d in os.listdir(args.path_data)
48 | if os.path.isdir('%s/%s' % (args.path_data, d))]
49 | cls.types_data = np.sort(np.array(types_data))
50 | else:
51 | cls.types_data = np.sort(np.array(args.types_data))
52 |
53 | @classmethod
54 | def follow(cls, args_trained):
55 | # input format related
56 | cls.SHAPE_MIDDLE = (args_trained['size_resize'][0],
57 | args_trained['size_resize'][1]) # (H, W)
58 | cls.SHAPE_INPUT = (args_trained['size_crop'][0],
59 | args_trained['size_crop'][1]) # (H, W)
60 | cls.pixel_cut = (int((cls.SHAPE_MIDDLE[0] - cls.SHAPE_INPUT[0]) / 2),
61 | int((cls.SHAPE_MIDDLE[1] - cls.SHAPE_INPUT[1]) / 2)) # (H, W)
62 |
63 |
64 | class ConfigFeat:
65 | def __init__(self, args, mode_train=True):
66 | # adjust to environment
67 | if args.cpu:
68 | self.device = torch.device('cpu')
69 | else:
70 | self.device = torch.device('cuda:0')
71 |
72 | # batch-size for feature extraction by ImageNet model
73 | self.batch_size = args.batch_size
74 |
75 | if mode_train:
76 | # input format related
77 | self.SHAPE_INPUT = (args.size_crop[0], args.size_crop[1]) # (H, W)
78 |
79 | # base network
80 | self.backbone = args.backbone
81 | self.weight = args.weight
82 |
83 | # layer specification
84 | self.layer_map = args.layer_map
85 | self.layer_weight = args.layer_weight
86 | self.layer_merge_ref = args.layer_merge_ref
87 |
88 | # patch size on the feature map for increasing the receptive field
89 | self.size_patch = args.size_patch
90 |
91 | # dimension of each layer feature (at 1st adaptive average pooling)
92 | self.dim_each_feat = args.dim_each_feat
93 | # dimension after layer feature merging (at 2nd adaptive average pooling)
94 | self.dim_merge_feat = args.dim_merge_feat
95 | else:
96 | self.SHAPE_INPUT = None
97 | self.backbone = None
98 | self.weight = None
99 | self.layer_map = None
100 | self.layer_weight = None
101 | self.layer_merge_ref = None
102 | self.size_patch = None
103 | self.dim_each_feat = None
104 | self.dim_merge_feat = None
105 |
106 | # adjust to the network's learning policy and the data conditions
107 | self.MEAN = MEAN.to(self.device)
108 | self.STD = STD.to(self.device)
109 |
110 | def follow(self, args_trained):
111 | # input format related
112 | self.SHAPE_INPUT = (args_trained['size_crop'][0],
113 | args_trained['size_crop'][1]) # (H, W)
114 | # base network
115 | self.backbone = args_trained['backbone']
116 | self.weight = args_trained['weight']
117 | # layer specification
118 | self.layer_map = args_trained['layer_map']
119 | self.layer_weight = args_trained['layer_weight']
120 | self.layer_merge_ref = args_trained['layer_merge_ref']
121 | # patch size on the feature map for increasing the receptive field
122 | self.size_patch = args_trained['size_patch']
123 | # dimension of each layer feature (at 1st adaptive average pooling)
124 | self.dim_each_feat = args_trained['dim_each_feat']
125 | # dimension after layer feature merging (at 2nd adaptive average pooling)
126 | self.dim_merge_feat = args_trained['dim_merge_feat']
127 |
128 |
129 | class ConfigPatchCore:
130 | def __init__(self, args, mode_train=True):
131 | # adjust to environment
132 | if args.cpu:
133 | self.device = torch.device('cpu')
134 | else:
135 | self.device = torch.device('cuda:0')
136 |
137 | if mode_train:
138 | # dimension after layer feature merging (at 2nd adaptive average pooling)
139 | self.dim_coreset_feat = args.dim_merge_feat
140 |
141 | # number split-sequential to apply k-center-greedy
142 | self.num_split_seq = args.num_split_seq
143 | # percentage of coreset to all patch features
144 | self.percentage_coreset = args.percentage_coreset
145 | # dimension to project features for sampling
146 | self.dim_sampling = args.dim_sampling
147 | # number of samples to initially randomly select coreset
148 | self.num_initial_coreset = args.num_initial_coreset
149 |
150 | # input format related
151 | self.shape_stretch = (args.size_crop[0], args.size_crop[1]) # (H, W)
152 | else:
153 | self.dim_coreset_feat = None
154 | self.num_split_seq = None
155 | self.percentage_coreset = None
156 | self.dim_sampling = None
157 | self.num_initial_coreset = None
158 | self.shape_stretch = None
159 |
160 | # number of nearest neighbor to get patch images
161 | self.k = args.k
162 |
163 | # consideration for the outer edge
164 | self.pixel_outer_decay = args.pixel_outer_decay
165 |
166 | # output path of trained artifacts
167 | self.path_trained = args.path_trained
168 |
169 | def follow(self, args_trained):
170 | # dimension after layer feature merging (at 2nd adaptive average pooling)
171 | self.dim_coreset_feat = args_trained['dim_merge_feat']
172 | # input format related
173 | self.shape_stretch = (args_trained['size_crop'][0],
174 | args_trained['size_crop'][1]) # (H, W)
175 |
176 |
177 | class ConfigDraw:
178 | def __init__(self, args, mode_train=True):
179 | # output detailed visualization or not (takes a long time)
180 | self.verbose = args.verbose
181 |
182 | # value for normalization of visualizing
183 | self.score_max = args.score_max
184 |
185 | # output filename related
186 | self.k = args.k
187 |
188 | if mode_train:
189 | # output filename related
190 | self.percentage_coreset = args.percentage_coreset
191 |
192 | # receptive field size
193 | self.size_receptive = args.size_receptive
194 |
195 | # aspect_ratio of output figure
196 | self.aspect_figure = args.size_crop[1] / args.size_crop[0] # W / H
197 | self.aspect_figure = np.round(self.aspect_figure, decimals=1)
198 |
199 | # visualize mode
200 | self.mode_visualize = args.mode_visualize
201 | self.mode_video = False
202 | else:
203 | self.percentage_coreset = None
204 | self.size_receptive = None
205 | self.aspect_figure = None
206 | self.mode_visualize = 'infer'
207 | if args.path_video is None:
208 | self.mode_video = False
209 | else:
210 | self.mode_video = True
211 | capture = cv2.VideoCapture(args.path_video)
212 | self.fps_video = capture.get(cv2.CAP_PROP_FPS)
213 | capture.release()
214 |
215 | # output path of figure
216 | self.path_result = args.path_result
217 |
218 | def follow(self, args_trained):
219 | # output filename related
220 | self.percentage_coreset = args_trained['percentage_coreset']
221 | # receptive field size
222 | self.size_receptive = args_trained['size_receptive']
223 | # aspect_ratio of output figure
224 | self.aspect_figure = args_trained['size_crop'][1] / args_trained['size_crop'][0] # W / H
225 | self.aspect_figure = np.round(self.aspect_figure, decimals=1)
226 |
--------------------------------------------------------------------------------
/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from tqdm import tqdm
3 | from sklearn.metrics import auc, roc_curve, precision_recall_curve
4 |
5 |
6 | def calc_imagewise_metrics(D, type_normal='good'):
7 | D_list = []
8 | y_list = []
9 |
10 | pbar = tqdm(total=int(len(D.keys()) + 5),
11 | desc='calculate imagewise metrics')
12 | for type_test in D.keys():
13 | for i in range(len(D[type_test])):
14 | D_tmp = np.max(D[type_test][i])
15 | y_tmp = int(type_test != type_normal)
16 |
17 | D_list.append(D_tmp)
18 | y_list.append(y_tmp)
19 | pbar.update(1)
20 |
21 | D_flat_list = np.array(D_list).reshape(-1)
22 | y_flat_list = np.array(y_list).reshape(-1)
23 | pbar.update(1)
24 |
25 | fpr, tpr, _ = roc_curve(y_flat_list, D_flat_list)
26 | pbar.update(1)
27 |
28 | rocauc = auc(fpr, tpr)
29 | pbar.update(1)
30 |
31 | pre, rec, _ = precision_recall_curve(y_flat_list, D_flat_list)
32 | pbar.update(1)
33 |
34 | # https://sinyi-chou.github.io/python-sklearn-precision-recall/
35 | prauc = auc(rec, pre)
36 | pbar.update(1)
37 | pbar.close()
38 |
39 | return fpr, tpr, rocauc, pre, rec, prauc
40 |
41 |
42 | def calc_pixelwise_metrics(D, y):
43 | D_list = []
44 | y_list = []
45 |
46 | pbar = tqdm(total=int(len(D.keys()) + 6),
47 | desc='calculate pixelwise metrics')
48 | for type_test in D.keys():
49 | for i in range(len(D[type_test])):
50 | D_tmp = D[type_test][i]
51 | y_tmp = y[type_test][i]
52 |
53 | D_list.append(D_tmp)
54 | y_list.append(y_tmp)
55 | pbar.update(1)
56 |
57 | D_flat_list = np.array(D_list).reshape(-1)
58 | y_flat_list = np.array(y_list).reshape(-1)
59 | pbar.update(1)
60 |
61 | fpr, tpr, _ = roc_curve(y_flat_list, D_flat_list)
62 | pbar.update(1)
63 |
64 | rocauc = auc(fpr, tpr)
65 | pbar.update(1)
66 |
67 | pre, rec, thresh = precision_recall_curve(y_flat_list, D_flat_list)
68 | pbar.update(1)
69 |
70 | prauc = auc(rec, pre)
71 | pbar.update(1)
72 |
73 | # https://github.com/xiahaifeng1995/PaDiM-Anomaly-Detection-Localization-master/blob/main/main.py#L193C1-L200C1
74 | # get optimal threshold
75 | a = 2 * pre * rec
76 | b = pre + rec
77 | f1 = np.divide(a, b, out=np.zeros_like(a), where=(b != 0))
78 | i_opt = np.argmax(f1)
79 | thresh_opt = thresh[i_opt]
80 | pbar.update(1)
81 | pbar.close()
82 |
83 | print('pixelwise optimal threshold:%.3f (precision:%.3f, recall:%.3f)' %
84 | (thresh_opt, pre[i_opt], rec[i_opt]))
85 |
86 | return fpr, tpr, rocauc, pre, rec, prauc, thresh_opt
87 |
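# Toy sketch (illustration only) of the F1-optimal threshold selection above.
# precision_recall_curve returns len(thresh) == len(pre) - 1, and the final
# (pre, rec) pair has rec == 0, so its F1 is 0 and argmax never selects it
# unless every F1 is 0, in which case index 0 is still a valid threshold.
if __name__ == '__main__':
    y_toy = np.array([0, 0, 1, 1])
    s_toy = np.array([0.1, 0.4, 0.35, 0.8])
    p, r, t = precision_recall_curve(y_toy, s_toy)
    f1_toy = np.divide(2 * p * r, p + r, out=np.zeros_like(p), where=((p + r) != 0))
    print(t[np.argmax(f1_toy)])  # 0.35 for these toy scores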
--------------------------------------------------------------------------------
/utils/tictoc.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 |
4 | def tic():
5 | # requires the time module
6 | global start_time_tictoc
7 | start_time_tictoc = time.time()
8 |
9 |
10 | def toc(tag="elapsed time"):
11 | if "start_time_tictoc" in globals():
12 | print("{}: {:.1f} [sec]".format(tag, time.time() - start_time_tictoc))
13 | else:
14 | print("tic has not been called")
15 |
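# Usage sketch:
if __name__ == '__main__':
    tic()
    time.sleep(0.3)
    toc(tag='demo')  # prints something like "demo: 0.3 [sec]"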
--------------------------------------------------------------------------------
/utils/visualize.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import cv2
4 | import matplotlib.pyplot as plt
5 | from tqdm import tqdm
6 |
7 |
8 | # https://github.com/gsurma/cnn_explainer/blob/main/utils.py
9 | def overlay_heatmap_on_image(img, heatmap, ratio_img=0.5):
10 | img = img.astype(np.float32)
11 |
12 | heatmap = 1 - np.clip(heatmap, 0, 1)
13 | heatmap = cv2.applyColorMap((heatmap * 255).astype(np.uint8), cv2.COLORMAP_JET)
14 | heatmap = heatmap.astype(np.float32)
15 |
16 | overlay = (img * ratio_img) + (heatmap * (1 - ratio_img))
17 | overlay = np.clip(overlay, 0, 255)
18 | overlay = overlay.astype(np.uint8)
19 | return overlay
20 |
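# Usage sketch (toy inputs, illustration only):
#   img = np.zeros((224, 224, 3), dtype=np.uint8)
#   heat = np.random.rand(224, 224).astype(np.float32)  # scores in [0, 1]
#   overlay = overlay_heatmap_on_image(img, heat)       # (224, 224, 3) uint8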
21 |
22 | def draw_distance_graph(type_data, cfg_draw, D, rocauc_img):
23 | D_list = {}
24 | for type_test in D.keys():
25 | D_list[type_test] = []
26 | for i_D in range(len(D[type_test])):
27 | D_tmp = np.max(D[type_test][i_D])
28 | D_list[type_test].append(D_tmp)
29 | D_list[type_test] = np.array(D_list[type_test])
30 |
31 | plt.figure(figsize=(10, 8), dpi=100, facecolor='white')
32 |
33 | # 'good' 1st
34 | N_test = 0
35 | type_test = 'good'
36 | plt.subplot(2, 1, 1)
37 | plt.scatter((np.arange(len(D_list[type_test])) + N_test), D_list[type_test],
38 | alpha=0.5, label=type_test)
39 | plt.subplot(2, 1, 2)
40 | plt.hist(D_list[type_test], alpha=0.5, label=type_test, bins=10)
41 |
42 | # other than 'good'
43 | N_test += len(D_list[type_test])
44 | types_test = np.array([k for k in D_list.keys() if k != 'good'])
45 | for type_test in types_test:
46 | plt.subplot(2, 1, 1)
47 | plt.scatter((np.arange(len(D_list[type_test])) + N_test), D_list[type_test], alpha=0.5, label=type_test)
48 | plt.subplot(2, 1, 2)
49 | plt.hist(D_list[type_test], alpha=0.5, label=type_test, bins=10)
50 | N_test += len(D_list[type_test])
51 |
52 | plt.subplot(2, 1, 1)
53 | plt.title('imagewise ROCAUC : %.3f' % rocauc_img)
54 | plt.grid()
55 | plt.legend(loc='upper left')
56 | plt.subplot(2, 1, 2)
57 | plt.grid()
58 | plt.legend(loc='upper right')
59 | plt.gcf().tight_layout()
60 | plt.gcf().savefig(('%s/%s/pred-dist_%s_p%04d_k%02d_rocaucimg%04d.png' %
61 | (cfg_draw.path_result, type_data,
62 | type_data, (cfg_draw.percentage_coreset * 1000),
63 | cfg_draw.k, round(rocauc_img * 1000))))
64 | plt.clf()
65 | plt.close()
66 |
67 |
68 | def draw_heatmap(type_data, cfg_draw, D, y, D_max, I, imgs_test, files_test,
69 | imgs_coreset, HW_map):
70 |
71 | if cfg_draw.mode_visualize == 'eval':
72 | fig_width = 10 * max(1, cfg_draw.aspect_figure)
73 | fig_height = 18
74 | pixel_cut = [160, 140, 60, 60]  # [top, bottom, left, right]
75 | else:
76 | fig_width = 20 * max(1, cfg_draw.aspect_figure)
77 | fig_height = 16
78 | pixel_cut = [140, 120, 180, 180]  # [top, bottom, left, right]
79 | dpi = 100
80 |
81 | for type_test in D.keys():
82 |
83 | if cfg_draw.mode_video:
84 | filename_out = ('%s/%s/localization_%s_%s_p%04d_k%02d.mp4' %
85 | (cfg_draw.path_result, type_data,
86 | type_data, files_test[type_test][0].split('.')[0],
87 | (cfg_draw.percentage_coreset * 1000), cfg_draw.k))
88 | # build writer
89 | codecs = 'mp4v'
90 | fourcc = cv2.VideoWriter_fourcc(*codecs)
91 | width = (fig_width * dpi) - pixel_cut[2] - pixel_cut[3]
92 | height = (fig_height * dpi) - pixel_cut[0] - pixel_cut[1]
93 | writer = cv2.VideoWriter(filename_out, fourcc, cfg_draw.fps_video,
94 | (width, height), True)
95 |
96 | desc = '[verbose mode] visualize localization (case:%s)' % type_test
97 | for i_D in tqdm(range(len(D[type_test])), desc=desc):
98 | file = files_test[type_test][i_D]
99 | img = imgs_test[type_test][i_D]
100 | score_map = D[type_test][i_D]
101 | if cfg_draw.score_max is None:
102 | score_max = D_max
103 | else:
104 | score_max = cfg_draw.score_max
105 | if y is not None:
106 | gt = y[type_test][i_D]
107 |
108 | I_tmp = I[type_test][i_D, :, 0]
109 | img_patch = assemble_patch(I_tmp, imgs_coreset, HW_map)
110 |
111 | plt.figure(figsize=(fig_width, fig_height), dpi=dpi, facecolor='white')
112 | plt.rcParams['font.size'] = 10
113 |
114 | score_map_reg = score_map / score_max
115 |
116 | if cfg_draw.mode_visualize == 'eval':
117 | plt.subplot2grid((7, 3), (0, 0), rowspan=1, colspan=1)
118 | plt.imshow(img)
119 | plt.title('%s : %s' % (file.split('/')[-2], file.split('/')[-1]))
120 |
121 | plt.subplot2grid((7, 3), (0, 1), rowspan=1, colspan=1)
122 | plt.imshow(gt)
123 |
124 | plt.subplot2grid((7, 3), (0, 2), rowspan=1, colspan=1)
125 | plt.imshow(score_map)
126 | plt.colorbar()
127 | plt.title('max score : %.2f' % score_max)
128 |
129 | plt.subplot2grid((42, 2), (7, 0), rowspan=10, colspan=1)
130 | plt.imshow(overlay_heatmap_on_image(img, score_map_reg))
131 |
132 | plt.subplot2grid((42, 2), (7, 1), rowspan=10, colspan=1)
133 | plt.imshow((img.astype(np.float32) * score_map_reg[..., None]).astype(np.uint8))
134 |
135 | plt.subplot2grid((21, 1), (10, 0), rowspan=11, colspan=1)
136 | plt.imshow(img_patch, interpolation='none')
137 | plt.title('patch images created with top1-NN')
138 |
139 | elif cfg_draw.mode_visualize == 'infer':
140 | plt.subplot(2, 2, 1)
141 | plt.imshow(img)
142 | plt.title('%s : %s' % (file.split('/')[-2], file.split('/')[-1]))
143 |
144 | plt.subplot(2, 2, 2)
145 | plt.imshow(score_map)
146 | plt.colorbar()
147 | plt.title('max score : %.2f' % score_max)
148 |
149 | plt.subplot(2, 2, 3)
150 | plt.imshow(overlay_heatmap_on_image(img, score_map_reg))
151 |
152 | plt.subplot(2, 2, 4)
153 | plt.imshow(img_patch, interpolation='none')
154 | plt.title('patch images created with top1-NN')
155 |
156 | score_tmp = np.max(score_map) / score_max * 100
157 | plt.gcf().canvas.draw()
158 | img_figure = np.frombuffer(plt.gcf().canvas.tostring_rgb(), dtype='uint8')
159 | img_figure = img_figure.reshape(fig_height * dpi, -1, 3)
160 | img_figure = img_figure[pixel_cut[0]:(img_figure.shape[0] - pixel_cut[1]),
161 | pixel_cut[2]:(img_figure.shape[1] - pixel_cut[3])]
162 |
163 | if not cfg_draw.mode_video:
164 | filename_out = ('%s/%s/localization_%s_%s_%s_p%04d_k%02d_s%03d.png' %
165 | (cfg_draw.path_result, type_data,
166 | type_data, type_test, os.path.basename(file).split('.')[0],
167 | (cfg_draw.percentage_coreset * 1000), cfg_draw.k,
168 | round(score_tmp)))
169 | cv2.imwrite(filename_out, img_figure[..., ::-1])
170 | else:
171 | writer.write(img_figure[..., ::-1])
172 |
173 | plt.clf()
174 | plt.close()
175 |
176 | if cfg_draw.mode_video:
177 | writer.release()
178 |
179 |
180 | def assemble_patch(idx_patch, imgs_coreset, HW_map):
181 | size_receptive = imgs_coreset.shape[1]
182 | img_patch = np.zeros([(HW_map[0] * size_receptive),
183 | (HW_map[1] * size_receptive), 3], dtype=np.uint8)
184 |
185 | # reset counter
186 | i_y = 0
187 | i_x = 0
188 |
189 | # loop of patch feature index
190 | for i_patch in idx_patch:
191 | # tile...
192 | img_piece = imgs_coreset[i_patch]
193 | y = i_y * size_receptive
194 | x = i_x * size_receptive
195 | img_patch[y:(y + size_receptive), x:(x + size_receptive)] = img_piece
196 |
197 | # count-up
198 | i_x += 1
199 | if i_x == HW_map[1]:
200 | i_x = 0
201 | i_y += 1
202 |
203 | return img_patch
204 |
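# Worked example (assumed values) for assemble_patch() above: with
# HW_map = (2, 3) and size_receptive = 9, the canvas is 18 x 27 pixels and
# idx_patch is tiled row-major into cells (0,0), (0,1), (0,2), (1,0), ...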
205 |
206 | def draw_curve(cfg_draw, x_img, y_img, auc_img, auc_img_mean,
207 | x_pix, y_pix, auc_pix, auc_pix_mean, flg_roc=True):
208 | if flg_roc:
209 | idx = 'ROC'
210 | lbl_x = 'False Positive Rate'
211 | lbl_y = 'True Positive Rate'
212 | else:
213 | idx = 'PR'
214 | lbl_x = 'Recall'
215 | lbl_y = 'Precision'
216 |
217 | plt.figure(figsize=(15, 6), dpi=100, facecolor='white')
218 | for type_data in x_img.keys():
219 | plt.subplot(1, 2, 1)
220 | plt.plot(x_img[type_data], y_img[type_data],
221 | label='%s %sAUC: %.3f' % (type_data, idx, auc_img[type_data]))
222 | plt.subplot(1, 2, 2)
223 | plt.plot(x_pix[type_data], y_pix[type_data],
224 | label='%s %sAUC: %.3f' % (type_data, idx, auc_pix[type_data]))
225 |
226 | plt.subplot(1, 2, 1)
227 | plt.title('imagewise %sAUC : mean %.3f' % (idx, auc_img_mean))
228 | plt.grid()
229 | plt.legend(loc='lower right')
230 | plt.xlabel(lbl_x)
231 | plt.ylabel(lbl_y)
232 | plt.subplot(1, 2, 2)
233 | plt.title('pixelwise %sAUC : mean %.3f' % (idx, auc_pix_mean))
234 | plt.grid()
235 | plt.legend(loc='lower right')
236 | plt.xlabel(lbl_x)
237 | plt.ylabel(lbl_y)
238 | plt.gcf().tight_layout()
239 | plt.gcf().savefig('%s/%s-curve_p%04d_k%02d_aucimg%04d_aucpix%04d.png' %
240 | (cfg_draw.path_result, idx.lower(),
241 | (cfg_draw.percentage_coreset * 1000), cfg_draw.k,
242 | round(auc_img_mean * 1000), round(auc_pix_mean * 1000)))
243 | plt.clf()
244 | plt.close()
245 |
--------------------------------------------------------------------------------