├── .gitignore ├── LICENSE ├── README.md ├── assets ├── localization_bottle_broken_large_000_p0100_k01_s067.png ├── localization_cable_bent_wire_000_p0100_k01_s073.png ├── localization_capsule_crack_000_p0100_k01_s020.png ├── localization_carpet_color_000_p0100_k01_s070.png ├── localization_grid_bent_000_p0100_k01_s048.png ├── localization_hazelnut_crack_000_p0100_k01_s047.png ├── localization_leather_color_000_p0100_k01_s058.png ├── localization_metal_nut_bent_000_p0100_k01_s082.png ├── localization_pill_color_000_p0100_k01_s036.png ├── localization_screw_manipulated_front_000_p0100_k01_s047.png ├── localization_tile_crack_000_p0100_k01_s059.png ├── localization_toothbrush_defective_000_p0100_k01_s079.png ├── localization_transistor_bent_lead_000_p0100_k01_s043.png ├── localization_wood_color_000_p0100_k01_s063.png ├── localization_zipper_broken_teeth_000_p0100_k01_s029.png ├── pred-dist_bottle_p0100_k01_r1000.png ├── pred-dist_cable_p0100_k01_r0999.png ├── pred-dist_capsule_p0100_k01_r0978.png ├── pred-dist_carpet_p0100_k01_r0987.png ├── pred-dist_grid_p0100_k01_r0985.png ├── pred-dist_hazelnut_p0100_k01_r1000.png ├── pred-dist_leather_p0100_k01_r1000.png ├── pred-dist_metal_nut_p0100_k01_r1000.png ├── pred-dist_pill_p0100_k01_r0962.png ├── pred-dist_screw_p0100_k01_r0985.png ├── pred-dist_tile_p0100_k01_r0994.png ├── pred-dist_toothbrush_p0100_k01_r0994.png ├── pred-dist_transistor_p0100_k01_r1000.png ├── pred-dist_wood_p0100_k01_r0990.png ├── pred-dist_zipper_p0100_k01_r0996.png ├── roc-curve_p0010_k01_rim0990_rpm0979.png ├── roc-curve_p0100_k01_rim0991_rpm0982.png ├── roc-curve_p0250_k01_rim0991_rpm0981.png └── total_recall.jpg ├── datasets ├── mvtec_dataset.py └── shared_memory.py ├── infer.py ├── models ├── feat_extract.py └── patchcore.py ├── requirements.txt ├── traintest.py └── utils ├── config.py ├── metrics.py ├── tictoc.py └── visualize.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | .idea 131 | result 132 | trained 133 | memo.txt 134 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PatchCore with some "ex" 2 | This is an unofficial implementation of the paper [Towards Total Recall in Industrial Anomaly Detection](https://arxiv.org/pdf/2106.08265.pdf). 3 | 4 | 5 | We measured accuracy and speed for percentage_coreset=0.01, 0.1, and 0.25. 6 | 7 | This code was implemented with reference to [patchcore-inspection](https://github.com/amazon-science/patchcore-inspection); thanks to its authors.
8 | 9 | Some of the "ex"s are **ex**plainability, **ex**press delivery, fl**ex**ibility, an **ex**tra algorithm, and so on. 10 | 11 |
12 | 13 | ## Prerequisites 14 | 15 | - faiss-gpu (easy to install with conda : [ref](https://github.com/facebookresearch/faiss/blob/main/INSTALL.md)) 16 | - torch 17 | - torchvision 18 | - numpy 19 | - opencv-python 20 | - scipy 21 | - argparse 22 | - matplotlib 23 | - scikit-learn 24 | - torchinfo 25 | - tqdm 26 | 27 | 28 | Install prerequisites with: 29 | ``` 30 | conda install --file requirements.txt 31 | ``` 32 | 33 |
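If conda is unavailable, most of the prerequisites can also be installed with pip. A minimal sketch, assuming the package names above are available on PyPI; faiss-gpu is the exception, since its official builds are distributed through conda channels, so it stays on conda here:
```
pip install torch torchvision numpy opencv-python scipy matplotlib scikit-learn torchinfo tqdm
conda install -c pytorch faiss-gpu
```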
34 | 35 | Please download the [`MVTec AD`](https://www.mvtec.com/company/research/datasets/mvtec-ad/) dataset. 36 | 37 | After downloading, place the data as follows: 38 | ``` 39 | ./ 40 | ├── traintest.py 41 | └── mvtec_anomaly_detection 42 | ├── bottle 43 | │   ├── ground_truth # pixel-level annotations for the non-good test data 44 | │   │   ├── broken_large 45 | │   │   ├── broken_small 46 | │   │   └── contamination 47 | │   ├── test # every class other than good is anomalous 48 | │   │   ├── broken_large 49 | │   │   ├── broken_small 50 | │   │   ├── contamination 51 | │   │   └── good 52 | │   └── train # good images only 53 | │   └── good 54 | ├── cable 55 | ├── capsule 56 | ├── carpet 57 | ├── grid 58 | ├── hazelnut 59 | ├── leather 60 | ├── metal_nut 61 | ├── pill 62 | ├── screw 63 | ├── tile 64 | ├── toothbrush 65 | ├── transistor 66 | ├── wood 67 | └── zipper 68 | ``` 69 | 70 | 71 |
72 | 73 | When using a custom dataset, place the data as follows: 74 | ``` 75 | ./ 76 | ├── traintest.py 77 | └── custom_data 78 | └── theme 79 |    ├── ground_truth # pixel-level annotations for the non-good test data 80 |    │   ├── anomaly_type_a 81 |    │   └── anomaly_type_b 82 |    ├── test # every class other than good is anomalous 83 |    │   ├── anomaly_type_a 84 |    │   ├── anomaly_type_b 85 |    │   └── good 86 |    └── train # good images only 87 |    └── good 88 | ``` 89 | 90 |
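Before training, this layout can be sanity-checked with a few lines of Python (a sketch; `theme` is the placeholder category name from the tree above):
```
import os

root = './custom_data/theme'  # placeholder category from the tree above

# train must contain only good images
assert os.path.isdir(os.path.join(root, 'train', 'good'))

# every test case except good should have matching pixel-level masks
for case in sorted(os.listdir(os.path.join(root, 'test'))):
    num_imgs = len(os.listdir(os.path.join(root, 'test', case)))
    print('test/%s : %d images' % (case, num_imgs))
    if case != 'good':
        assert os.path.isdir(os.path.join(root, 'ground_truth', case))
```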
91 | 92 | ## Usage 93 | 94 | To train and test **PatchCore** on the `MVTec AD` dataset: 95 | ``` 96 | python traintest.py 97 | ``` 98 | 99 | After running the command above, you can find the ROCAUC results in the `result` directory (e.g. `roc-curve_*.png`). 100 | 101 |
102 | 103 | To train and test **PatchCore** on a custom dataset: 104 | ``` 105 | python traintest.py --path_data ./custom_data 106 | ``` 107 | 108 |
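For inference-only runs that reuse previously trained artifacts, this repo also ships `infer.py`. Based on its argument parser, a run on the custom dataset above might look like the following (see `infer.py` for the full list of flags):
```
python infer.py --path_data ./custom_data --types_data theme --path_trained ./trained --path_result ./result -v
```
Here `-v` saves the localization visualizations and `--types_data` restricts processing to the listed categories.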
109 | 110 | ## Results 111 | 112 | Below are the test-set ROCAUC results of this implementation on the `MVTec AD` dataset. 113 | 114 | ### 1. Image-level anomaly detection accuracy (ROCAUC %) 115 | 116 | | | Paper<br>$\\%_{core}$=0.1 | This Repo<br>$\\%_{core}$=0.01 | This Repo<br>$\\%_{core}$=0.1 | This Repo<br>$\\%_{core}$=0.25 | 117 | | - | - | - | - | - | 118 | | bottle | 100.0 | 100.0 | 100.0 | 100.0 | 119 | | cable | 99.4 | 99.6 | 99.9 | 99.6 | 120 | | capsule | 97.8 | 97.6 | 97.8 | 97.8 | 121 | | carpet | 98.7 | 98.3 | 98.7 | 98.5 | 122 | | grid | 97.9 | 97.8 | 98.5 | 98.4 | 123 | | hazelnut | 100.0 | 100.0 | 100.0 | 100.0 | 124 | | leather | 100.0 | 100.0 | 100.0 | 100.0 | 125 | | metal_nut | 100.0 | 100.0 | 100.0 | 99.9 | 126 | | pill | 96.0 | 96.3 | 96.2 | 96.4 | 127 | | screw | 97.0 | 97.9 | 98.5 | 98.1 | 128 | | tile | 98.9 | 99.0 | 99.4 | 99.2 | 129 | | toothbrush | 99.7 | 99.4 | 99.4 | 100.0 | 130 | | transistor | 100.0 | 99.8 | 100.0 | 100.0 | 131 | | wood | 99.0 | 99.1 | 99.0 | 98.9 | 132 | | zipper | 99.5 | 99.7 | 99.6 | 99.5 | 133 | | Average | 99.0 | 99.0 | 99.1 | 99.1 | 134 | 135 |
136 | 137 | ### 2. Pixel-level anomaly detection accuracy (ROCAUC %) 138 | 139 | | | Paper<br>$\\%_{core}$=0.1 | This Repo<br>$\\%_{core}$=0.01 | This Repo<br>$\\%_{core}$=0.1 | This Repo<br>$\\%_{core}$=0.25 | 140 | | - | - | - | - | - | 141 | | bottle | 98.6 | 98.5 | 98.6 | 98.6 | 142 | | cable | 98.5 | 98.2 | 98.4 | 98.4 | 143 | | capsule | 98.9 | 98.8 | 98.9 | 98.9 | 144 | | carpet | 99.1 | 99.0 | 99.1 | 99.0 | 145 | | grid | 98.7 | 98.2 | 98.7 | 98.7 | 146 | | hazelnut | 98.7 | 98.6 | 98.7 | 98.6 | 147 | | leather | 99.3 | 99.3 | 99.3 | 99.3 | 148 | | metal_nut | 98.4 | 98.5 | 98.7 | 98.7 | 149 | | pill | 97.6 | 97.4 | 97.6 | 97.3 | 150 | | screw | 99.4 | 98.8 | 99.4 | 99.4 | 151 | | tile | 95.9 | 96.2 | 96.1 | 95.9 | 152 | | toothbrush | 98.7 | 98.6 | 98.7 | 98.7 | 153 | | transistor | 96.4 | 94.2 | 96.0 | 95.8 | 154 | | wood | 95.1 | 95.7 | 95.5 | 95.4 | 155 | | zipper | 98.9 | 98.9 | 98.9 | 98.9 | 156 | | Average | 98.1 | 97.9 | 98.2 | 98.1 | 157 | 158 |
159 | 160 | ### 3. Processing time (sec) 161 | 162 | | | This Repo<br>$\\%_{core}$=0.01 | This Repo<br>$\\%_{core}$=0.1 | This Repo<br>$\\%_{core}$=0.25 | 163 | | - | - | - | - | 164 | | bottle | 14.2 | 21.3 | 33.0 | 165 | | cable | 22.9 | 30.3 | 44.6 | 166 | | capsule | 20.4 | 27.3 | 39.8 | 167 | | carpet | 22.4 | 32.6 | 52.7 | 168 | | grid | 16.3 | 26.3 | 44.3 | 169 | | hazelnut | 25.9 | 46.4 | 82.1 | 170 | | leather | 20.8 | 28.5 | 44.0 | 171 | | metal_nut | 16.8 | 23.8 | 36.3 | 172 | | pill | 23.3 | 34.2 | 52.6 | 173 | | screw | 22.4 | 37.7 | 63.9 | 174 | | tile | 18.4 | 26.2 | 40.7 | 175 | | toothbrush | 5.8 | 6.7 | 8.0 | 176 | | transistor | 17.2 | 23.9 | 35.9 | 177 | | wood | 18.0 | 26.5 | 42.1 | 178 | | zipper | 20.5 | 29.0 | 44.2 | 179 | 180 | ``` 181 | CPU : Intel(R) Xeon(R) Gold 6148 CPU @ 2.40GHz 182 | GPU : Tesla V100 SXM2 183 | ``` 184 | 185 |
186 | 187 | ### ROC Curve 188 | 189 | - percentage_coreset = 0.01 190 | ![roc](./assets/roc-curve_p0010_k01_rim0990_rpm0979.png) 191 | 192 |
193 | 194 | - percentage_coreset = 0.1 195 | ![roc](./assets/roc-curve_p0100_k01_rim0991_rpm0982.png) 196 | 197 |
198 | 199 | - percentage_coreset = 0.25 200 | ![roc](./assets/roc-curve_p0250_k01_rim0991_rpm0981.png) 201 | 202 |
203 | 204 | ### Prediction Distribution (percentage_coreset = 0.1) 205 | 206 | - bottle 207 | ![bottle](./assets/pred-dist_bottle_p0100_k01_r1000.png) 208 | 209 | - cable 210 | ![cable](./assets/pred-dist_cable_p0100_k01_r0999.png) 211 | 212 | - capsule 213 | ![capsule](./assets/pred-dist_capsule_p0100_k01_r0978.png) 214 | 215 | - carpet 216 | ![carpet](./assets/pred-dist_carpet_p0100_k01_r0987.png) 217 | 218 | - grid 219 | ![grid](./assets/pred-dist_grid_p0100_k01_r0985.png) 220 | 221 | - hazelnut 222 | ![hazelnut](./assets/pred-dist_hazelnut_p0100_k01_r1000.png) 223 | 224 | - leather 225 | ![leather](./assets/pred-dist_leather_p0100_k01_r1000.png) 226 | 227 | - metal_nut 228 | ![metal_nut](./assets/pred-dist_metal_nut_p0100_k01_r1000.png) 229 | 230 | - pill 231 | ![pill](./assets/pred-dist_pill_p0100_k01_r0962.png) 232 | 233 | - screw 234 | ![screw](./assets/pred-dist_screw_p0100_k01_r0985.png) 235 | 236 | - tile 237 | ![tile](./assets/pred-dist_tile_p0100_k01_r0994.png) 238 | 239 | - toothbrush 240 | ![toothbrush](./assets/pred-dist_toothbrush_p0100_k01_r0994.png) 241 | 242 | - transistor 243 | ![transistor](./assets/pred-dist_transistor_p0100_k01_r1000.png) 244 | 245 | - wood 246 | ![wood](./assets/pred-dist_wood_p0100_k01_r0990.png) 247 | 248 | - zipper 249 | ![zipper](./assets/pred-dist_zipper_p0100_k01_r0996.png) 250 | 251 |
252 | 253 | ### Localization : percentage_coreset = 0.1 254 | 255 | - bottle (test case : broken_large) 256 | ![bottle](./assets/localization_bottle_broken_large_000_p0100_k01_s067.png) 257 | 258 | - cable (test case : bent_wire) 259 | ![cable](./assets/localization_cable_bent_wire_000_p0100_k01_s073.png) 260 | 261 | - capsule (test case : crack) 262 | ![capsule](./assets/localization_capsule_crack_000_p0100_k01_s020.png) 263 | 264 | - carpet (test case : color) 265 | ![carpet](./assets/localization_carpet_color_000_p0100_k01_s070.png) 266 | 267 | - grid (test case : bent) 268 | ![grid](./assets/localization_grid_bent_000_p0100_k01_s048.png) 269 | 270 | - hazelnut (test case : crack) 271 | ![hazelnut](./assets/localization_hazelnut_crack_000_p0100_k01_s047.png) 272 | 273 | - leather (test case : color) 274 | ![leather](./assets/localization_leather_color_000_p0100_k01_s058.png) 275 | 276 | - metal_nut (test case : bent) 277 | ![metal_nut](./assets/localization_metal_nut_bent_000_p0100_k01_s082.png) 278 | 279 | - pill (test case : color) 280 | ![pill](./assets/localization_pill_color_000_p0100_k01_s036.png) 281 | 282 | - screw (test case : manipulated_front) 283 | ![screw](./assets/localization_screw_manipulated_front_000_p0100_k01_s047.png) 284 | 285 | - tile (test case : crack) 286 | ![tile](./assets/localization_tile_crack_000_p0100_k01_s059.png) 287 | 288 | - toothbrush (test case : defective) 289 | ![toothbrush](./assets/localization_toothbrush_defective_000_p0100_k01_s079.png) 290 | 291 | - transistor (test case : bent_lead) 292 | ![transistor](./assets/localization_transistor_bent_lead_000_p0100_k01_s043.png) 293 | 294 | - wood (test case : color) 295 | ![wood](./assets/localization_wood_color_000_p0100_k01_s063.png) 296 | 297 | - zipper (test case : broken_teeth) 298 | ![zipper](./assets/localization_zipper_broken_teeth_000_p0100_k01_s029.png) 299 | 300 |
301 | 302 | ### For your information 303 | 304 | We have also implemented a similar algorithm, SPADE:
305 | https://github.com/any-tech/SPADE-fast/tree/main 306 | 307 | There is also an explanatory article (in Japanese):
308 | https://tech.anytech.co.jp/entry/2023/03/24/100000 309 | -------------------------------------------------------------------------------- /assets/localization_bottle_broken_large_000_p0100_k01_s067.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/any-tech/PatchCore-ex/15a46f13f38a6ddd1e615c8bc7a19433c03fee1f/assets/localization_bottle_broken_large_000_p0100_k01_s067.png -------------------------------------------------------------------------------- /assets/localization_cable_bent_wire_000_p0100_k01_s073.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/any-tech/PatchCore-ex/15a46f13f38a6ddd1e615c8bc7a19433c03fee1f/assets/localization_cable_bent_wire_000_p0100_k01_s073.png -------------------------------------------------------------------------------- /assets/localization_capsule_crack_000_p0100_k01_s020.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/any-tech/PatchCore-ex/15a46f13f38a6ddd1e615c8bc7a19433c03fee1f/assets/localization_capsule_crack_000_p0100_k01_s020.png -------------------------------------------------------------------------------- /assets/localization_carpet_color_000_p0100_k01_s070.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/any-tech/PatchCore-ex/15a46f13f38a6ddd1e615c8bc7a19433c03fee1f/assets/localization_carpet_color_000_p0100_k01_s070.png -------------------------------------------------------------------------------- /assets/localization_grid_bent_000_p0100_k01_s048.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/any-tech/PatchCore-ex/15a46f13f38a6ddd1e615c8bc7a19433c03fee1f/assets/localization_grid_bent_000_p0100_k01_s048.png -------------------------------------------------------------------------------- /assets/localization_hazelnut_crack_000_p0100_k01_s047.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/any-tech/PatchCore-ex/15a46f13f38a6ddd1e615c8bc7a19433c03fee1f/assets/localization_hazelnut_crack_000_p0100_k01_s047.png -------------------------------------------------------------------------------- /assets/localization_leather_color_000_p0100_k01_s058.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/any-tech/PatchCore-ex/15a46f13f38a6ddd1e615c8bc7a19433c03fee1f/assets/localization_leather_color_000_p0100_k01_s058.png -------------------------------------------------------------------------------- /assets/localization_metal_nut_bent_000_p0100_k01_s082.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/any-tech/PatchCore-ex/15a46f13f38a6ddd1e615c8bc7a19433c03fee1f/assets/localization_metal_nut_bent_000_p0100_k01_s082.png -------------------------------------------------------------------------------- /assets/localization_pill_color_000_p0100_k01_s036.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/any-tech/PatchCore-ex/15a46f13f38a6ddd1e615c8bc7a19433c03fee1f/assets/localization_pill_color_000_p0100_k01_s036.png -------------------------------------------------------------------------------- 
/assets/localization_screw_manipulated_front_000_p0100_k01_s047.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/any-tech/PatchCore-ex/15a46f13f38a6ddd1e615c8bc7a19433c03fee1f/assets/localization_screw_manipulated_front_000_p0100_k01_s047.png -------------------------------------------------------------------------------- /assets/localization_tile_crack_000_p0100_k01_s059.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/any-tech/PatchCore-ex/15a46f13f38a6ddd1e615c8bc7a19433c03fee1f/assets/localization_tile_crack_000_p0100_k01_s059.png -------------------------------------------------------------------------------- /assets/localization_toothbrush_defective_000_p0100_k01_s079.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/any-tech/PatchCore-ex/15a46f13f38a6ddd1e615c8bc7a19433c03fee1f/assets/localization_toothbrush_defective_000_p0100_k01_s079.png -------------------------------------------------------------------------------- /assets/localization_transistor_bent_lead_000_p0100_k01_s043.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/any-tech/PatchCore-ex/15a46f13f38a6ddd1e615c8bc7a19433c03fee1f/assets/localization_transistor_bent_lead_000_p0100_k01_s043.png -------------------------------------------------------------------------------- /assets/localization_wood_color_000_p0100_k01_s063.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/any-tech/PatchCore-ex/15a46f13f38a6ddd1e615c8bc7a19433c03fee1f/assets/localization_wood_color_000_p0100_k01_s063.png -------------------------------------------------------------------------------- /assets/localization_zipper_broken_teeth_000_p0100_k01_s029.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/any-tech/PatchCore-ex/15a46f13f38a6ddd1e615c8bc7a19433c03fee1f/assets/localization_zipper_broken_teeth_000_p0100_k01_s029.png -------------------------------------------------------------------------------- /assets/pred-dist_bottle_p0100_k01_r1000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/any-tech/PatchCore-ex/15a46f13f38a6ddd1e615c8bc7a19433c03fee1f/assets/pred-dist_bottle_p0100_k01_r1000.png -------------------------------------------------------------------------------- /assets/pred-dist_cable_p0100_k01_r0999.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/any-tech/PatchCore-ex/15a46f13f38a6ddd1e615c8bc7a19433c03fee1f/assets/pred-dist_cable_p0100_k01_r0999.png -------------------------------------------------------------------------------- /assets/pred-dist_capsule_p0100_k01_r0978.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/any-tech/PatchCore-ex/15a46f13f38a6ddd1e615c8bc7a19433c03fee1f/assets/pred-dist_capsule_p0100_k01_r0978.png -------------------------------------------------------------------------------- /assets/pred-dist_carpet_p0100_k01_r0987.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/any-tech/PatchCore-ex/15a46f13f38a6ddd1e615c8bc7a19433c03fee1f/assets/pred-dist_carpet_p0100_k01_r0987.png -------------------------------------------------------------------------------- /assets/pred-dist_grid_p0100_k01_r0985.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/any-tech/PatchCore-ex/15a46f13f38a6ddd1e615c8bc7a19433c03fee1f/assets/pred-dist_grid_p0100_k01_r0985.png -------------------------------------------------------------------------------- /assets/pred-dist_hazelnut_p0100_k01_r1000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/any-tech/PatchCore-ex/15a46f13f38a6ddd1e615c8bc7a19433c03fee1f/assets/pred-dist_hazelnut_p0100_k01_r1000.png -------------------------------------------------------------------------------- /assets/pred-dist_leather_p0100_k01_r1000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/any-tech/PatchCore-ex/15a46f13f38a6ddd1e615c8bc7a19433c03fee1f/assets/pred-dist_leather_p0100_k01_r1000.png -------------------------------------------------------------------------------- /assets/pred-dist_metal_nut_p0100_k01_r1000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/any-tech/PatchCore-ex/15a46f13f38a6ddd1e615c8bc7a19433c03fee1f/assets/pred-dist_metal_nut_p0100_k01_r1000.png -------------------------------------------------------------------------------- /assets/pred-dist_pill_p0100_k01_r0962.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/any-tech/PatchCore-ex/15a46f13f38a6ddd1e615c8bc7a19433c03fee1f/assets/pred-dist_pill_p0100_k01_r0962.png -------------------------------------------------------------------------------- /assets/pred-dist_screw_p0100_k01_r0985.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/any-tech/PatchCore-ex/15a46f13f38a6ddd1e615c8bc7a19433c03fee1f/assets/pred-dist_screw_p0100_k01_r0985.png -------------------------------------------------------------------------------- /assets/pred-dist_tile_p0100_k01_r0994.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/any-tech/PatchCore-ex/15a46f13f38a6ddd1e615c8bc7a19433c03fee1f/assets/pred-dist_tile_p0100_k01_r0994.png -------------------------------------------------------------------------------- /assets/pred-dist_toothbrush_p0100_k01_r0994.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/any-tech/PatchCore-ex/15a46f13f38a6ddd1e615c8bc7a19433c03fee1f/assets/pred-dist_toothbrush_p0100_k01_r0994.png -------------------------------------------------------------------------------- /assets/pred-dist_transistor_p0100_k01_r1000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/any-tech/PatchCore-ex/15a46f13f38a6ddd1e615c8bc7a19433c03fee1f/assets/pred-dist_transistor_p0100_k01_r1000.png -------------------------------------------------------------------------------- /assets/pred-dist_wood_p0100_k01_r0990.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/any-tech/PatchCore-ex/15a46f13f38a6ddd1e615c8bc7a19433c03fee1f/assets/pred-dist_wood_p0100_k01_r0990.png -------------------------------------------------------------------------------- /assets/pred-dist_zipper_p0100_k01_r0996.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/any-tech/PatchCore-ex/15a46f13f38a6ddd1e615c8bc7a19433c03fee1f/assets/pred-dist_zipper_p0100_k01_r0996.png -------------------------------------------------------------------------------- /assets/roc-curve_p0010_k01_rim0990_rpm0979.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/any-tech/PatchCore-ex/15a46f13f38a6ddd1e615c8bc7a19433c03fee1f/assets/roc-curve_p0010_k01_rim0990_rpm0979.png -------------------------------------------------------------------------------- /assets/roc-curve_p0100_k01_rim0991_rpm0982.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/any-tech/PatchCore-ex/15a46f13f38a6ddd1e615c8bc7a19433c03fee1f/assets/roc-curve_p0100_k01_rim0991_rpm0982.png -------------------------------------------------------------------------------- /assets/roc-curve_p0250_k01_rim0991_rpm0981.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/any-tech/PatchCore-ex/15a46f13f38a6ddd1e615c8bc7a19433c03fee1f/assets/roc-curve_p0250_k01_rim0991_rpm0981.png -------------------------------------------------------------------------------- /assets/total_recall.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/any-tech/PatchCore-ex/15a46f13f38a6ddd1e615c8bc7a19433c03fee1f/assets/total_recall.jpg -------------------------------------------------------------------------------- /datasets/mvtec_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import cv2 4 | from tqdm import tqdm 5 | from utils.config import ConfigData 6 | from datasets.shared_memory import SharedMemory 7 | 8 | 9 | class MVTecDataset: 10 | @classmethod 11 | def __init__(cls, type_data): 12 | cls.type_data = type_data 13 | cls.SHAPE_INPUT = ConfigData.SHAPE_INPUT 14 | 15 | cls.imgs_train = None 16 | cls.imgs_test = {} 17 | cls.gts_test = {} 18 | 19 | # read train data 20 | desc = 'read images for train (case:good)' 21 | path = '%s/%s/train/good' % (ConfigData.path_data, type_data) 22 | files = ['%s/%s' % (path, f) for f in os.listdir(path) 23 | if (os.path.isfile('%s/%s' % (path, f)) & (('.png' in f) | ('.jpg' in f)))] 24 | cls.files_train = np.sort(np.array(files)) 25 | cls.imgs_train = SharedMemory.read_img_parallel(files=cls.files_train, 26 | imgs=cls.imgs_train, 27 | desc=desc) 28 | if ConfigData.shuffle: 29 | cls.imgs_train = np.random.permutation(cls.imgs_train) 30 | if ConfigData.flip_horz: 31 | cls.imgs_train = np.concatenate([cls.imgs_train, 32 | cls.imgs_train[:, :, ::-1]], axis=0) 33 | if ConfigData.flip_vert: 34 | cls.imgs_train = np.concatenate([cls.imgs_train, 35 | cls.imgs_train[:, ::-1]], axis=0) 36 | 37 | # read test data 38 | cls.files_test = {} 39 | cls.types_test = os.listdir('%s/%s/test' % (ConfigData.path_data, type_data)) 40 | cls.types_test = np.array(sorted(cls.types_test)) 41 | for type_test in cls.types_test: 42 | desc = 'read images for test (case:%s)' % type_test 43 | path = 
'%s/%s/test/%s' % (ConfigData.path_data, type_data, type_test) 44 | files = [('%s/%s' % (path, f)) for f in os.listdir(path) 45 | if (os.path.isfile('%s/%s' % (path, f)) & (('.png' in f) | ('.jpg' in f)))] 46 | cls.files_test[type_test] = np.sort(np.array(files)) 47 | cls.imgs_test[type_test] = None 48 | cls.imgs_test[type_test] = SharedMemory.read_img_parallel(files=cls.files_test[type_test], 49 | imgs=cls.imgs_test[type_test], 50 | desc=desc) 51 | 52 | # read ground truth of test data 53 | for type_test in cls.types_test: 54 | # create memory shared variable 55 | if type_test == 'good': 56 | cls.gts_test[type_test] = np.zeros([len(cls.files_test[type_test]), 57 | ConfigData.SHAPE_INPUT[0], 58 | ConfigData.SHAPE_INPUT[1]], dtype=np.uint8) 59 | else: 60 | desc = 'read ground-truths for test (case:%s)' % type_test 61 | cls.gts_test[type_test] = None 62 | cls.gts_test[type_test] = SharedMemory.read_img_parallel(files=cls.files_test[type_test], 63 | imgs=cls.gts_test[type_test], 64 | is_gt=True, 65 | desc=desc) 66 | 67 | 68 | class MVTecDatasetOnlyTest: 69 | @classmethod 70 | def __init__(cls, type_data): 71 | if ConfigData.path_data is not None: 72 | cls.type_data = type_data 73 | cls.SHAPE_INPUT = ConfigData.SHAPE_INPUT 74 | 75 | # read test data 76 | cls.imgs_test = {} 77 | cls.files_test = {} 78 | cls.types_test = os.listdir('%s/%s/test' % (ConfigData.path_data, type_data)) 79 | cls.types_test = np.array(sorted(cls.types_test)) 80 | for type_test in cls.types_test: 81 | desc = 'read images for test (case:%s)' % type_test 82 | path = '%s/%s/test/%s' % (ConfigData.path_data, type_data, type_test) 83 | files = [('%s/%s' % (path, f)) for f in os.listdir(path) 84 | if (os.path.isfile('%s/%s' % (path, f)) & (('.png' in f) | ('.jpg' in f)))] 85 | cls.files_test[type_test] = np.sort(np.array(files)) 86 | cls.imgs_test[type_test] = None 87 | cls.imgs_test[type_test] = SharedMemory.read_img_parallel(files=cls.files_test[type_test], 88 | imgs=cls.imgs_test[type_test], 89 | desc=desc) 90 | else: 91 | cls.type_data = type_data 92 | cls.SHAPE_INPUT = ConfigData.SHAPE_INPUT 93 | 94 | # read test data 95 | cls.imgs_test = {} 96 | cls.files_test = {} 97 | cls.types_test = np.array(['video']) 98 | type_test = cls.types_test[0] 99 | cls.imgs_test[type_test] = [] 100 | cls.files_test[type_test] = [] 101 | capture = cv2.VideoCapture(ConfigData.path_video) 102 | num_frame = int(capture.get(cv2.CAP_PROP_FRAME_COUNT)) 103 | desc = 'read images from video for test' 104 | for i_frame in tqdm(range(num_frame), desc=desc): 105 | # frame read 106 | capture.grab() 107 | ret, frame = capture.retrieve() 108 | if ret: 109 | frame = frame[..., ::-1] # BGR2RGB 110 | frame = cv2.resize(frame, (ConfigData.SHAPE_MIDDLE[1], 111 | ConfigData.SHAPE_MIDDLE[0]), 112 | interpolation=cv2.INTER_AREA) 113 | frame = frame[ConfigData.pixel_cut[0]:(ConfigData.SHAPE_INPUT[0] + 114 | ConfigData.pixel_cut[0]), 115 | ConfigData.pixel_cut[1]:(ConfigData.SHAPE_INPUT[1] + 116 | ConfigData.pixel_cut[1])] 117 | cls.imgs_test[type_test].append(frame) 118 | cls.files_test[type_test].append('%s/frame/%05d' % 119 | (os.path.basename(ConfigData.path_video), 120 | i_frame)) 121 | cls.imgs_test[type_test] = np.array(cls.imgs_test[type_test]) 122 | cls.files_test[type_test] = np.array(cls.files_test[type_test]) 123 | -------------------------------------------------------------------------------- /datasets/shared_memory.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | from tqdm 
import tqdm 4 | import multiprocessing as mp 5 | from multiprocessing.sharedctypes import RawArray 6 | from utils.config import ConfigData 7 | 8 | 9 | def read_and_resize(file): 10 | img = cv2.imread(file)[..., ::-1] # BGR2RGB 11 | img = cv2.resize(img, (ConfigData.SHAPE_MIDDLE[1], ConfigData.SHAPE_MIDDLE[0]), 12 | interpolation=cv2.INTER_AREA) 13 | img = img[ConfigData.pixel_cut[0]:(ConfigData.SHAPE_INPUT[0] + ConfigData.pixel_cut[0]), 14 | ConfigData.pixel_cut[1]:(ConfigData.SHAPE_INPUT[1] + ConfigData.pixel_cut[1])] 15 | SharedMemory.shared_array[np.where(SharedMemory.files == file)[0]] = img 16 | 17 | 18 | def read_and_resize_ground_truth(file): 19 | file_gt = file.replace('/test/', '/ground_truth/') 20 | file_gt = file_gt.replace('.png', '_mask.png') 21 | 22 | gt = cv2.imread(file_gt, cv2.IMREAD_GRAYSCALE) 23 | gt = cv2.resize(gt, (ConfigData.SHAPE_MIDDLE[1], ConfigData.SHAPE_MIDDLE[0]), 24 | interpolation=cv2.INTER_NEAREST) 25 | gt = gt[ConfigData.pixel_cut[0]:(ConfigData.SHAPE_INPUT[0] + ConfigData.pixel_cut[0]), 26 | ConfigData.pixel_cut[1]:(ConfigData.SHAPE_INPUT[1] + ConfigData.pixel_cut[1])] 27 | 28 | if np.max(gt) != 0: 29 | gt = (gt / np.max(gt)).astype(np.uint8) 30 | 31 | SharedMemory.shared_array[np.where(SharedMemory.files == file)[0]] = gt 32 | 33 | 34 | class SharedMemory: 35 | @classmethod 36 | def read_img_parallel(cls, files, imgs, is_gt=False, desc='read images'): 37 | cls.files = files 38 | cls.shared_array = imgs 39 | 40 | if not is_gt: 41 | shape = (len(cls.files), ConfigData.SHAPE_INPUT[0], ConfigData.SHAPE_INPUT[1], 3) 42 | num_elm = shape[0] * shape[1] * shape[2] * shape[3] 43 | else: 44 | shape = (len(cls.files), ConfigData.SHAPE_INPUT[0], ConfigData.SHAPE_INPUT[1]) 45 | num_elm = shape[0] * shape[1] * shape[2] 46 | 47 | ctype = np.ctypeslib.as_ctypes_type(np.dtype(np.uint8)) 48 | data = np.ctypeslib.as_array(RawArray(ctype, num_elm)) 49 | data.shape = shape 50 | cls.shared_array = data.view(np.uint8) 51 | 52 | # exec imread and imresize on multiprocess 53 | mp.set_start_method('fork', force=True) 54 | p = mp.Pool(min(mp.cpu_count(), ConfigData.num_cpu_max)) 55 | 56 | func = read_and_resize 57 | if is_gt: 58 | func = read_and_resize_ground_truth 59 | 60 | for _ in tqdm(p.imap_unordered(func, cls.files), total=len(cls.files), desc=desc): 61 | pass 62 | 63 | p.close() 64 | 65 | return cls.shared_array 66 | -------------------------------------------------------------------------------- /infer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from argparse import ArgumentParser 4 | import json 5 | 6 | from utils.config import ConfigData, ConfigFeat, ConfigPatchCore, ConfigDraw 7 | from utils.tictoc import tic, toc 8 | from utils.visualize import draw_heatmap 9 | 10 | from datasets.mvtec_dataset import MVTecDatasetOnlyTest 11 | 12 | from models.feat_extract import FeatExtract 13 | from models.patchcore import PatchCore 14 | 15 | 16 | def arg_parser(): 17 | parser = ArgumentParser() 18 | # environment related 19 | parser.add_argument('-n', '--num_cpu_max', default=4, type=int, 20 | help='number of CPUs for parallel reading input images') 21 | parser.add_argument('-c', '--cpu', action='store_true', help='use cpu') 22 | 23 | # I/O and visualization related 24 | parser.add_argument('-pd', '--path_data', type=str, default='./mvtec_anomaly_detection', 25 | help='parent path of input data path') 26 | parser.add_argument('-pv', '--path_video', type=str, default=None, 27 | help='path of input video 
(.mp4, .avi or .mov)') 28 | parser.add_argument('-pt', '--path_trained', type=str, default='./trained', 29 | help='output path of trained products') 30 | parser.add_argument('-pr', '--path_result', type=str, default='./result', 31 | help='output path of figure image as the evaluation result') 32 | parser.add_argument('-v', '--verbose', action='store_true', 33 | help='save visualization of localization') 34 | parser.add_argument('-sm', '--score_max', type=float, default=None, 35 | help='value for normalization of visualizing') 36 | 37 | # data loader related 38 | parser.add_argument('-bs', '--batch_size', type=int, default=16, 39 | help='batch-size for feature extraction by ImageNet model') 40 | parser.add_argument('-tt', '--types_data', nargs='*', type=str, default=None) 41 | 42 | # Nearest-Neighbor related 43 | parser.add_argument('-k', '--k', type=int, default=5, 44 | help='nearest neighbor\'s k for coreset searching') 45 | # post-processing related 46 | parser.add_argument('-pod', '--pixel_outer_decay', type=int, default=0, 47 | help='number of outer pixels to decay anomaly score') 48 | 49 | args = parser.parse_args() 50 | 51 | # adjust... 52 | if args.path_video is not None: 53 | args.path_data = None 54 | 55 | print('args =\n', args) 56 | return args 57 | 58 | 59 | def check_args(args): 60 | assert 0 < args.num_cpu_max < os.cpu_count() 61 | if args.path_video is None: 62 | assert args.path_data is not None 63 | assert os.path.isdir(args.path_data) 64 | else: 65 | assert os.path.isfile(args.path_video) 66 | assert ((args.path_video.split('.')[-1].lower() == 'mp4') | 67 | (args.path_video.split('.')[-1].lower() == 'avi') | 68 | (args.path_video.split('.')[-1].lower() == 'mov')) 69 | assert len(args.types_data) == 1 70 | if args.score_max is not None: 71 | assert args.score_max > 0 72 | assert args.batch_size > 0 73 | if args.types_data is not None: 74 | if args.path_video is None: 75 | for type_data in args.types_data: 76 | assert os.path.exists('%s/%s' % (args.path_data, type_data)) 77 | assert args.k > 0 78 | assert args.pixel_outer_decay >= 0 79 | 80 | 81 | def summary_result(type_data, D, files_test, thresh): 82 | result = ['data-type filename anomaly-score threshold abnormal-judgement'] 83 | for type_test in D.keys(): 84 | for i_file in range(len(D[type_test])): 85 | D_max = np.max(D[type_test][i_file]) 86 | flg_abnormal = int(thresh <= D_max) 87 | 88 | result.append('%s %s %.3f %.3f %d' % 89 | (type_data, files_test[type_test][i_file], 90 | D_max, thresh, flg_abnormal)) 91 | 92 | filename_txt = '%s/%s/%s_result.txt' % (args.path_result, type_data, type_data) 93 | np.savetxt(filename_txt, result, fmt='%s') 94 | 95 | 96 | def apply_patchcore_inference(args, type_data, feat_ext, patchcore, cfg_draw): 97 | print('\n----> inference-only PatchCore processing in %s start' % type_data) 98 | tic() 99 | 100 | # read images 101 | MVTecDatasetOnlyTest(type_data) 102 | 103 | # load neighbor 104 | patchcore.reset_faiss_index() 105 | patchcore.load_faiss_index(type_data) 106 | 107 | # extract features 108 | feat_test = {} 109 | for type_test in MVTecDatasetOnlyTest.imgs_test.keys(): 110 | feat_test[type_test] = feat_ext.extract(MVTecDatasetOnlyTest.imgs_test[type_test], 111 | case='test (case:%s)' % type_test) 112 | 113 | # compute anomaly score maps via k-NN search against the coreset features 114 | D, D_max, I = patchcore.localization(feat_test, feat_ext.HW_map()) 115 | 116 | toc(tag=('----> inference-only PatchCore processing in %s end, elapsed time' % type_data)) 117 | 118 | if args.verbose:
119 | imgs_coreset = patchcore.load_coreset(type_data) 120 | 121 | draw_heatmap(type_data, cfg_draw, D, None, D_max, I, 122 | MVTecDatasetOnlyTest.imgs_test, MVTecDatasetOnlyTest.files_test, 123 | imgs_coreset, feat_ext.HW_map()) 124 | 125 | # load optimal threshold 126 | thresh = np.loadtxt('%s/%s_thr.txt' % (args.path_trained, type_data)) 127 | 128 | # summary test result 129 | summary_result(type_data, D, MVTecDatasetOnlyTest.files_test, thresh) 130 | 131 | 132 | def main(args): 133 | ConfigData(args, mode_train=False) # static define for speed-up 134 | cfg_feat = ConfigFeat(args, mode_train=False) 135 | cfg_patchcore = ConfigPatchCore(args, mode_train=False) 136 | cfg_draw = ConfigDraw(args, mode_train=False) 137 | 138 | with open('%s/args.json' % args.path_trained, mode='r') as f: 139 | args_trained = json.load(f) 140 | ConfigData.follow(args_trained) 141 | cfg_feat.follow(args_trained) 142 | cfg_patchcore.follow(args_trained) 143 | cfg_draw.follow(args_trained) 144 | 145 | feat_ext = FeatExtract(cfg_feat) 146 | patchcore = PatchCore(cfg_patchcore) 147 | 148 | os.makedirs(args.path_result, exist_ok=True) 149 | for type_data in ConfigData.types_data: 150 | os.makedirs('%s/%s' % (args.path_result, type_data), exist_ok=True) 151 | 152 | # loop for types of data 153 | for type_data in ConfigData.types_data: 154 | apply_patchcore_inference(args, type_data, feat_ext, patchcore, cfg_draw) 155 | 156 | 157 | if __name__ == '__main__': 158 | args = arg_parser() 159 | main(args) 160 | -------------------------------------------------------------------------------- /models/feat_extract.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from tqdm import tqdm 4 | import torch 5 | import torch.nn.functional as F 6 | import torchvision 7 | from torchinfo import summary 8 | 9 | 10 | def print_weight_device(backbone): 11 | for name, module in backbone.named_modules(): 12 | if hasattr(module, 'weight') and getattr(module, 'weight') is not None: 13 | weight_device = module.weight.device 14 | weight_shape = module.weight.shape 15 | print(f"{name}.weight - Device: {weight_device}, Shape: {weight_shape}") 16 | 17 | if hasattr(module, 'bias') and getattr(module, 'bias') is not None: 18 | bias_device = module.bias.device 19 | print(f"{name}.bias - Device: {bias_device}") 20 | 21 | 22 | class FeatExtract: 23 | def __init__(self, cfg_feat): 24 | self.device = cfg_feat.device 25 | self.batch_size = cfg_feat.batch_size 26 | self.shape_input = cfg_feat.SHAPE_INPUT 27 | self.layer_map = cfg_feat.layer_map 28 | self.layer_weight = cfg_feat.layer_weight 29 | self.size_patch = cfg_feat.size_patch 30 | self.dim_each_feat = cfg_feat.dim_each_feat 31 | self.dim_merge_feat = cfg_feat.dim_merge_feat 32 | self.MEAN = cfg_feat.MEAN 33 | self.STD = cfg_feat.STD 34 | self.merge_dst_index = self.layer_map.index(cfg_feat.layer_merge_ref) 35 | 36 | self.padding = int((self.size_patch - 1) / 2) 37 | self.stride = 1 # fixed temporarily... 
38 | 39 | code = 'self.backbone = %s(weights=%s)' % (cfg_feat.backbone, cfg_feat.weight) 40 | exec(code) 41 | summary(self.backbone, input_size=(1, 3, *self.shape_input)) 42 | self.backbone.eval() 43 | 44 | # Executing summary will force the backbone device to be changed to cuda, so do to(device) after summary 45 | self.backbone.to(self.device) 46 | 47 | self.feat = [] 48 | for layer_map in self.layer_map: 49 | code = 'self.backbone.%s.register_forward_hook(self.hook)' % layer_map 50 | exec(code) 51 | 52 | # dummy forward 53 | x = torch.zeros(1, 3, self.shape_input[0], self.shape_input[1]) # RGB 54 | x = x.to(self.device) 55 | self.feat = [] 56 | with torch.no_grad(): 57 | _ = self.backbone(x) 58 | 59 | # https://github.com/amazon-science/patchcore-inspection/blob/main/src/patchcore/patchcore.py#L295 60 | self.unfolder = torch.nn.Unfold(kernel_size=self.size_patch, stride=self.stride, 61 | padding=self.padding, dilation=1) 62 | self.patch_shapes = [] 63 | for i in range(len(self.feat)): 64 | number_of_total_patches = [] 65 | for s in self.feat[i].shape[-2:]: 66 | n_patches = ( 67 | s + 2 * self.padding - 1 * (self.size_patch - 1) - 1 68 | ) / self.stride + 1 69 | number_of_total_patches.append(int(n_patches)) 70 | self.patch_shapes.append(number_of_total_patches) 71 | 72 | def hook(self, module, input, output): 73 | self.feat.append(output.detach().cpu()) 74 | 75 | def HW_map(self): 76 | return self.patch_shapes[self.merge_dst_index] 77 | 78 | def normalize(self, input): 79 | x = torch.from_numpy(input.astype(np.float32)) 80 | x = x.to(self.device) 81 | x = x / 255 82 | x = x - self.MEAN 83 | x = x / self.STD 84 | x = x.unsqueeze(0).permute(0, 3, 1, 2) 85 | return x 86 | 87 | # return : is_train=True->torch.Tensor, is_train=False->dict 88 | def extract(self, imgs, case='train (case:good)', batch_size_patchfy=50, show_progress=True): 89 | # feature extract for train and aggregate split-image for explain 90 | x_batch = [] 91 | self.feat = [] 92 | for i_img in tqdm(range(len(imgs)), desc='extract feature for %s' % case, disable=not show_progress): 93 | img = imgs[i_img] 94 | x = self.normalize(img) 95 | x_batch.append(x) 96 | 97 | if (len(x_batch) == self.batch_size) | (i_img == (len(imgs) - 1)): 98 | with torch.no_grad(): 99 | _ = self.backbone(torch.vstack(x_batch)) 100 | x_batch = [] 101 | 102 | # adjust 103 | feat = [] 104 | num_layer = len(self.layer_map) 105 | for i_layer_map in range(num_layer): 106 | feat.append(torch.vstack(self.feat[i_layer_map::num_layer])) 107 | 108 | # patchfy (consider out of memory) 109 | num_patch_per_image = self.HW_map()[0] * self.HW_map()[1] 110 | num_patch = len(imgs) * num_patch_per_image 111 | feat_patchfy = torch.zeros(num_patch, self.dim_merge_feat) 112 | 113 | num_patchfy_process = (len(feat) * 3) + 1 114 | num_iter = np.ceil(len(imgs) / batch_size_patchfy) 115 | pbar = tqdm(total=int(num_patchfy_process * num_iter), desc='patchfy feature', disable=not show_progress) 116 | for i_batch in range(0, len(imgs), batch_size_patchfy): 117 | feat_tmp = [] 118 | for feat_layer in feat: 119 | feat_tmp.append(feat_layer[i_batch:(i_batch + batch_size_patchfy)]) 120 | i_from = i_batch * num_patch_per_image 121 | i_to = (i_batch + batch_size_patchfy) * num_patch_per_image 122 | feat_patchfy[i_from:i_to] = self.patchfy(feat_tmp, pbar) 123 | pbar.close() 124 | 125 | return feat_patchfy 126 | 127 | def patchfy(self, feat, pbar, batch_size_interp=2000): 128 | 129 | with torch.no_grad(): 130 | # unfold 131 | for i in range(len(feat)): 132 | _feat = feat[i] 133 | 
BC_before_unfold = _feat.shape[:2] 134 | # (B, C, H, W) -> (B, CPHPW, HW) 135 | _feat = self.unfolder(_feat) 136 | # (B, CPHPW, HW) -> (B, C, PH, PW, HW) 137 | _feat = _feat.reshape(*BC_before_unfold, 138 | self.size_patch, self.size_patch, -1) 139 | # (B, C, PH, PW, HW) -> (B, HW, C, PW, HW) 140 | _feat = _feat.permute(0, 4, 1, 2, 3) 141 | feat[i] = _feat 142 | pbar.update(1) 143 | 144 | # expand small feat to fit large features 145 | for i in range(0, len(feat)): 146 | if i == self.merge_dst_index: 147 | continue 148 | 149 | _feat = feat[i] 150 | patch_dims = self.patch_shapes[i] 151 | # (B, HW, C, PW, HW) -> (B, H, W, C, PH, PW) 152 | _feat = _feat.reshape(_feat.shape[0], patch_dims[0], 153 | patch_dims[1], *_feat.shape[2:]) 154 | # (B, H, W, C, PH, PW) -> (B, C, PH, PW, H, W) 155 | _feat = _feat.permute(0, -3, -2, -1, 1, 2) 156 | perm_base_shape = _feat.shape 157 | # (B, C, PH, PW, H, W) -> (BCPHPW, H, W) 158 | _feat = _feat.reshape(-1, *_feat.shape[-2:]) 159 | # (BCPHPW, H, W) -> (BCPHPW, H_max, W_max) 160 | feat_dst = torch.zeros([len(_feat), self.HW_map()[0], self.HW_map()[1]]) 161 | for i_batch in range(0, len(_feat), batch_size_interp): 162 | feat_tmp = _feat[i_batch:(i_batch + batch_size_interp)] 163 | feat_tmp = feat_tmp.unsqueeze(1).to(self.device) 164 | feat_tmp = F.interpolate(feat_tmp, 165 | size=(self.HW_map()[0], self.HW_map()[1]), 166 | mode="bilinear", align_corners=False) 167 | feat_dst[i_batch:(i_batch + batch_size_interp)] = feat_tmp.squeeze(1).cpu() 168 | _feat = feat_dst 169 | # _feat = F.interpolate(_feat.unsqueeze(1), 170 | # size=(self.HW_map()[0], self.HW_map()[1]), 171 | # mode="bilinear", align_corners=False) 172 | # _feat = _feat.squeeze(1) 173 | # for i_batch in range(0, len(_feat), 10000): 174 | # print(torch.sum(torch.abs(_feat[i_batch:(i_batch + 10000)] - __feat[i_batch:(i_batch + 10000)]))) 175 | # (BCPHPW, H_max, W_max) -> (B, C, PH, PW, H_max, W_max) 176 | _feat = _feat.reshape(*perm_base_shape[:-2], 177 | self.HW_map()[0], self.HW_map()[1]) 178 | # (B, C, PH, PW, H_max, W_max) -> (B, H_max, W_max, C, PH, PW) 179 | _feat = _feat.permute(0, -2, -1, 1, 2, 3) 180 | # (B, H_max, W_max, C, PH, PW) -> (B, H_maxW_max, C, PH, PW) 181 | _feat = _feat.reshape(len(_feat), -1, *_feat.shape[-3:]) 182 | feat[i] = _feat 183 | pbar.update(1) 184 | 185 | # aggregate feature vectors 186 | # (B, H, W, C, PH, PW) -> (BHW, C, PH, PW) 187 | feat = [x.reshape(-1, *x.shape[-3:]) for x in feat] 188 | pbar.update(1) 189 | 190 | # adaptive average pooling for each feature vector 191 | for i in range(len(feat)): 192 | _feat = feat[i] * self.layer_weight[i] 193 | 194 | # (BHW, C, PH, PW) -> (BHW, 1, CPHPW) 195 | _feat = _feat.reshape(len(_feat), 1, -1) 196 | 197 | # (BHW, 1, CPHPW) -> (BHW, D_e) 198 | _feat = F.adaptive_avg_pool1d(_feat, self.dim_each_feat).squeeze(1) 199 | feat[i] = _feat 200 | pbar.update(1) 201 | 202 | # concat the two feature vectors and adaptive average pooling 203 | # (BHW, D_e) -> (BHW, D_e*2) 204 | feat = torch.stack(feat, dim=1) 205 | """Returns reshaped and average pooled feat.""" 206 | # batchsize x number_of_layers x input_dim -> batchsize x target_dim 207 | # (BHW, D_e*2) -> (BHW, D_m) 208 | feat = feat.reshape(len(feat), 1, -1) 209 | feat = F.adaptive_avg_pool1d(feat, self.dim_merge_feat) 210 | feat = feat.reshape(len(feat), -1) 211 | pbar.update(1) 212 | 213 | return feat 214 | -------------------------------------------------------------------------------- /models/patchcore.py: 
--------------------------------------------------------------------------------
/models/patchcore.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | import numpy as np
4 | import cv2
5 | import torch
6 | from tqdm import tqdm
7 | import faiss
8 | from scipy.ndimage import gaussian_filter
9 | 
10 | 
11 | class PatchCore:
12 |     def __init__(self, cfg_patchcore):
13 |         self.device = cfg_patchcore.device
14 |         self.k = cfg_patchcore.k
15 |         self.dim_coreset_feat = cfg_patchcore.dim_coreset_feat
16 |         self.num_split_seq = cfg_patchcore.num_split_seq
17 |         self.percentage_coreset = cfg_patchcore.percentage_coreset
18 |         self.dim_sampling = cfg_patchcore.dim_sampling
19 |         self.num_initial_coreset = cfg_patchcore.num_initial_coreset
20 |         self.shape_stretch = cfg_patchcore.shape_stretch
21 |         self.pixel_outer_decay = cfg_patchcore.pixel_outer_decay
22 | 
23 |         # prep knn index
24 |         if self.device.type != 'cuda':
25 |             self.idx_feat = faiss.IndexFlatL2(self.dim_coreset_feat)
26 |         else:
27 |             self.idx_feat = faiss.GpuIndexFlatL2(faiss.StandardGpuResources(),
28 |                                                  self.dim_coreset_feat,
29 |                                                  faiss.GpuIndexFlatConfig())
30 | 
31 |         if self.dim_sampling is not None:
32 |             # prep mapper that projects features to a lower dimension for sampling
33 |             self.mapper = torch.nn.Linear(self.dim_coreset_feat, self.dim_sampling,
34 |                                           bias=False).to(self.device)
35 | 
36 |         self.path_trained = cfg_patchcore.path_trained
37 |         os.makedirs(self.path_trained, exist_ok=True)
38 | 
39 | 
40 | 
41 | 
42 |     def compute_greedy_coreset_idx(self, feat):
43 |         feat = feat.to(self.device)
44 |         with torch.no_grad():
45 |             feat_proj = self.mapper(feat) if self.dim_sampling is not None else feat
46 | 
47 |         _num_initial_coreset = np.clip(self.num_initial_coreset,
48 |                                        None, len(feat_proj))
49 |         start_points = np.random.choice(len(feat_proj), _num_initial_coreset,
50 |                                         replace=False).tolist()
51 | 
52 |         # compute batchwise squared Euclidean distances with PyTorch
53 |         mat_A = feat_proj
54 |         mat_B = feat_proj[start_points]
55 |         A_x_A = mat_A.unsqueeze(1).bmm(mat_A.unsqueeze(2)).reshape(-1, 1)
56 |         B_x_B = mat_B.unsqueeze(1).bmm(mat_B.unsqueeze(2)).reshape(1, -1)
57 |         A_x_B = mat_A.mm(mat_B.T)
58 |         # sqrt is not needed: it does not change the ranking of distances
59 |         mat_dist = (-2 * A_x_B + A_x_A + B_x_B).clamp(0, None)
60 | 
61 |         dist_coreset_anchor = torch.mean(mat_dist, axis=-1, keepdims=True)
62 | 
63 |         idx_coreset = []
64 |         num_coreset_samples = int(len(feat_proj) * self.percentage_coreset)
65 | 
66 |         with torch.no_grad():
67 |             for _ in tqdm(range(num_coreset_samples), desc="sampling"):
68 |                 idx_select = torch.argmax(dist_coreset_anchor).item()
69 |                 idx_coreset.append(idx_select)
70 | 
71 |                 mat_A = feat_proj
72 |                 mat_B = feat_proj[[idx_select]]
73 |                 # compute batchwise squared Euclidean distances with PyTorch
74 |                 A_x_A = mat_A.unsqueeze(1).bmm(mat_A.unsqueeze(2)).reshape(-1, 1)
75 |                 B_x_B = mat_B.unsqueeze(1).bmm(mat_B.unsqueeze(2)).reshape(1, -1)
76 |                 A_x_B = mat_A.mm(mat_B.T)
77 |                 # sqrt is not needed here either
78 |                 mat_select_dist = (-2 * A_x_B + A_x_A + B_x_B).clamp(0, None)
79 | 
80 |                 dist_coreset_anchor = torch.cat([dist_coreset_anchor,
81 |                                                  mat_select_dist], dim=-1)
82 |                 dist_coreset_anchor = torch.min(dist_coreset_anchor,
83 |                                                 dim=1).values.reshape(-1, 1)
84 | 
85 |         idx_coreset = np.array(idx_coreset)
86 |         return idx_coreset
87 | 
88 |     def reset_faiss_index(self):
89 |         self.idx_feat.reset()
90 | 
91 |     def add_neighbor(self, feat_train):
92 |         self.idx_feat.add(feat_train.numpy())
93 | 
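    # a minimal round trip with the same faiss calls used by this class
    # (dimensions and k are illustrative, assuming dim_coreset_feat=1024):
    #   index = faiss.IndexFlatL2(1024)
    #   index.add(feat_train.numpy())      # cf. add_neighbor above
    #   D, I = index.search(feat_test, 5)  # squared L2 distances and neighbor indices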
94 |     def localization(self, feat_test, HW_map, show_progress=True):
95 |         D = {}
96 |         D_max = -9999
97 |         I = {}
98 |         # loop over test cases
99 |         for type_test in feat_test.keys():
100 |             D[type_test] = []
101 |             I[type_test] = []
102 | 
103 |             # loop over test data
104 |             _feat_test = feat_test[type_test]
105 |             _feat_test = _feat_test.reshape(-1, (HW_map[0] * HW_map[1]),
106 |                                             self.dim_coreset_feat)
107 |             _feat_test = _feat_test.numpy()
108 |             num_data = len(_feat_test)
109 |             for i in tqdm(range(num_data), desc='localization (case:%s)' % type_test,
110 |                           disable=not show_progress):
111 |                 # measure distance pixelwise
112 |                 score_map, I_tmp = self.measure_dist_pixelwise(_feat_test[i], HW_map)
113 |                 # decay the score of outer pixels (provisional heuristic)
114 |                 if self.pixel_outer_decay > 0:
115 |                     score_map[:self.pixel_outer_decay, :] *= 0.6
116 |                     score_map[-self.pixel_outer_decay:, :] *= 0.6
117 |                     score_map[:, :self.pixel_outer_decay] *= 0.6
118 |                     score_map[:, -self.pixel_outer_decay:] *= 0.6
119 |                 # stock score map
120 |                 D[type_test].append(score_map)
121 |                 D_max = max(D_max, np.max(score_map))
122 |                 I[type_test].append(I_tmp)
123 | 
124 |             # cast list to numpy array
125 |             D[type_test] = np.array(D[type_test])
126 |             I[type_test] = np.array(I[type_test])
127 | 
128 |         return D, D_max, I
129 | 
130 |     def measure_dist_pixelwise(self, feat_test, HW_map):
131 |         # k nearest neighbors against the coreset memory bank
132 |         D, I = self.idx_feat.search(feat_test, self.k)
133 |         D = np.mean(D, axis=-1)
134 | 
135 |         # transform to scoremap
136 |         score_map = D.reshape(*HW_map)
137 |         score_map = cv2.resize(score_map, (self.shape_stretch[1], self.shape_stretch[0]))
138 | 
139 |         # apply gaussian smoothing on the score map
140 |         score_map_smooth = gaussian_filter(score_map, sigma=4)
141 | 
142 |         return score_map_smooth, I
143 | 
144 |     def save_faiss_index(self, type_data):
145 |         path_faiss_idx = '%s/%s.idx' % (self.path_trained, type_data)
146 |         idx_feat_cpu = faiss.index_gpu_to_cpu(self.idx_feat)
147 |         faiss.write_index(idx_feat_cpu, path_faiss_idx)
148 | 
149 |     def load_faiss_index(self, type_data):
150 |         path_faiss_idx = '%s/%s.idx' % (self.path_trained, type_data)
151 |         self.idx_feat = faiss.read_index(path_faiss_idx)
152 | 
153 |     def load_faiss_index_direct(self, path_faiss_idx):
154 |         self.idx_feat = faiss.read_index(path_faiss_idx)
155 | 
156 |     def pickup_patch(self, idx_patch, imgs, HW_map, size_receptive):
157 |         h = imgs.shape[-3]
158 |         w = imgs.shape[-2]
159 | 
160 |         # calculate the half size of a patch crop
161 |         h_half = int((size_receptive - 1) / 2)
162 |         w_half = int((size_receptive - 1) / 2)
163 | 
164 |         # calculate center coordinates of the patch crops
165 |         y_pitch = np.arange(0, (h - 1 + 1e-10), ((h - 1) / (HW_map[0] - 1)))
166 |         y_pitch = np.round(y_pitch).astype(np.int16)
167 |         y_pitch = y_pitch + h_half
168 |         x_pitch = np.arange(0, (w - 1 + 1e-10), ((w - 1) / (HW_map[1] - 1)))
169 |         x_pitch = np.round(x_pitch).astype(np.int16)
170 |         x_pitch = x_pitch + w_half
171 |         # pad the images so border patches can be cropped in full
172 |         imgs = np.pad(imgs, ((0, 0), (h_half, h_half), (w_half, w_half), (0, 0)))
173 | 
174 |         # collect image pieces
175 |         img_piece_list = []
176 |         for i_patch in idx_patch:
177 |             i_img = i_patch // (HW_map[0] * HW_map[1])
178 |             i_HW = i_patch % (HW_map[0] * HW_map[1])
179 |             i_H = i_HW // HW_map[1]
180 |             i_W = i_HW % HW_map[1]
181 | 
182 |             img = imgs[i_img]
183 |             y = y_pitch[i_H]
184 |             x = x_pitch[i_W]
185 |             img_piece = img[(y - h_half):(y + h_half + 1), (x - w_half):(x + w_half + 1)]
186 |             img_piece_list.append(img_piece)
187 | 
188 |         img_piece_array = np.stack(img_piece_list)
189 | 
190 |         return img_piece_array
191 | 
192 |     def save_coreset(self, idx_coreset, type_data, imgs_train, HW_map, size_receptive):
193 |         # save patch images of coreset
194 |         imgs_coreset = 
self.pickup_patch(idx_coreset, imgs_train, HW_map, size_receptive) 195 | path_coreset_img = '%s/%s_img_coreset.npy' % (self.path_trained, type_data) 196 | np.save(path_coreset_img, imgs_coreset) 197 | 198 | return imgs_coreset 199 | 200 | def load_coreset(self, type_data): 201 | # load patch images of coreset 202 | path_coreset_img = '%s/%s_img_coreset.npy' % (self.path_trained, type_data) 203 | imgs_coreset = np.load(path_coreset_img) 204 | 205 | return imgs_coreset 206 | 207 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | faiss-gpu 2 | matplotlib 3 | numpy 4 | numba 5 | opencv-python 6 | scikit-learn 7 | scipy 8 | torch 9 | torchinfo 10 | torchvision 11 | tqdm 12 | -------------------------------------------------------------------------------- /traintest.py: -------------------------------------------------------------------------------- 1 | import os 2 | from argparse import ArgumentParser 3 | import json 4 | 5 | import numpy as np 6 | import torch 7 | import multiprocessing as mp 8 | 9 | from utils.config import ConfigData, ConfigFeat, ConfigPatchCore, ConfigDraw 10 | from utils.tictoc import tic, toc 11 | from utils.metrics import calc_imagewise_metrics, calc_pixelwise_metrics 12 | from utils.visualize import draw_curve, draw_distance_graph, draw_heatmap 13 | from datasets.mvtec_dataset import MVTecDataset 14 | from models.feat_extract import FeatExtract 15 | from models.patchcore import PatchCore 16 | 17 | 18 | def arg_parser(): 19 | parser = ArgumentParser() 20 | # environment related 21 | parser.add_argument('-n', '--num_cpu_max', default=4, type=int, 22 | help='number of CPUs for parallel reading input images') 23 | parser.add_argument('-c', '--cpu', action='store_true', help='use cpu') 24 | 25 | # I/O and visualization related 26 | parser.add_argument('-pd', '--path_data', type=str, default='./mvtec_anomaly_detection', 27 | help='parent path of input data path') 28 | parser.add_argument('-pt', '--path_trained', type=str, default='./trained', 29 | help='output path of trained products') 30 | parser.add_argument('-pr', '--path_result', type=str, default='./result', 31 | help='output path of figure image as the evaluation result') 32 | parser.add_argument('-v', '--verbose', action='store_true', 33 | help='save visualization of localization') 34 | parser.add_argument('-srf', '--size_receptive', type=int, default=9, 35 | help='estimate and specify receptive field size (odd number)') 36 | parser.add_argument('-mv', '--mode_visualize', type=str, default='eval', 37 | choices=['eval', 'infer'], help='set mode, [eval] or [infer]') 38 | parser.add_argument('-sm', '--score_max', type=float, default=None, 39 | help='value for normalization of visualizing') 40 | 41 | # data loader related 42 | parser.add_argument('-bs', '--batch_size', type=int, default=16, 43 | help='batch-size for feature extraction by ImageNet model') 44 | parser.add_argument('-sr', '--size_resize', nargs=2, type=int, default=[256, 256], 45 | help='size of resizing input image') 46 | parser.add_argument('-sc', '--size_crop', nargs=2, type=int, default=[224, 224], 47 | help='size of cropping after resize') 48 | parser.add_argument('-fh', '--flip_horz', action='store_true', help='flip horizontal') 49 | parser.add_argument('-fv', '--flip_vert', action='store_true', help='flip vertical') 50 | parser.add_argument('-tt', '--types_data', nargs='*', type=str, default=None) 51 | # feature 
extraction related
52 |     parser.add_argument('-b', '--backbone', type=str,
53 |                         default='torchvision.models.wide_resnet50_2',
54 |                         help='specify torchvision model with the full path')
55 |     parser.add_argument('-w', '--weight', type=str,
56 |                         default='torchvision.models.Wide_ResNet50_2_Weights.IMAGENET1K_V1',
57 |                         help='specify the trained weights of torchvision model with the full path')
58 |     parser.add_argument('-lm', '--layer_map', nargs='+', type=str,
59 |                         default=['layer2[-1]', 'layer3[-1]'],
60 |                         help='specify layers to extract feature map')
61 |     parser.add_argument('-lw', '--layer_weight', nargs='+', type=float,
62 |                         default=[1.0, 1.0],
63 |                         help='specify layer weights for merging feature maps')
64 |     parser.add_argument('-lmr', '--layer_merge_ref', type=str, default='layer2[-1]',
65 |                         help='specify the layer to use as a reference for spatial size when merging feature maps')
66 | 
67 |     # patchification related
68 |     parser.add_argument('-sp', '--size_patch', type=int, default=3,
69 |                         help='patch size on the feature map for increasing the receptive field')
70 |     parser.add_argument('-de', '--dim_each_feat', type=int, default=1024,
71 |                         help='dimension of each extracted layer feature (at 1st adaptive average pooling)')
72 |     parser.add_argument('-dm', '--dim_merge_feat', type=int, default=1024,
73 |                         help='dimension after layer feature merging (at 2nd adaptive average pooling)')
74 | 
75 |     # coreset related
76 |     parser.add_argument('-s', '--seed', type=int, default=0,
77 |                         help='specify a random-seed for k-center-greedy')
78 |     parser.add_argument('-ns', '--num_split_seq', type=int, default=1,
79 |                         help='number of sequential splits of training images for k-center-greedy')
80 |     parser.add_argument('-pc', '--percentage_coreset', type=float, default=0.01,
81 |                         help='percentage of coreset to all patch features')
82 |     parser.add_argument('-ds', '--dim_sampling', type=int, default=128,
83 |                         help='dimension to project features onto for sampling')
84 |     parser.add_argument('-ni', '--num_initial_coreset', type=int, default=10,
85 |                         help='number of randomly selected samples to initialize the coreset')
86 |     # Nearest-Neighbor related
87 |     parser.add_argument('-k', '--k', type=int, default=5,
88 |                         help='nearest neighbor\'s k for coreset searching')
89 |     # post processing related
90 |     parser.add_argument('-pod', '--pixel_outer_decay', type=int, default=0,
91 |                         help='number of outer pixels to decay anomaly score')
92 | 
93 |     args = parser.parse_args()
94 | 
95 |     print('args =\n', args)
96 |     return args
97 | 
98 | 
99 | def check_args(args):
100 |     assert 0 < args.num_cpu_max < os.cpu_count()
101 |     assert os.path.isdir(args.path_data)
102 |     assert (args.size_receptive % 2) == 1
103 |     assert args.size_receptive > 0
104 |     if args.score_max is not None:
105 |         assert args.score_max > 0
106 |     assert args.batch_size > 0
107 |     assert args.size_resize[0] > 0
108 |     assert args.size_resize[1] > 0
109 |     assert args.size_crop[0] > 0
110 |     assert args.size_crop[1] > 0
111 |     if args.types_data is not None:
112 |         for type_data in args.types_data:
113 |             assert os.path.isdir('%s/%s' % (args.path_data, type_data))
114 |     assert len(args.layer_map) == len(args.layer_weight)
115 |     assert args.layer_merge_ref in args.layer_map
116 |     assert args.size_patch > 0
117 |     assert args.dim_each_feat > 0
118 |     assert args.dim_merge_feat > 0
119 |     assert args.num_split_seq > 0
120 |     assert 0.0 < args.percentage_coreset <= 1.0
121 |     assert args.dim_sampling > 0
122 |     assert args.num_initial_coreset > 0
123 |     assert args.k > 0
124 |     assert args.pixel_outer_decay >= 0
125 | 
126 | 
127 | def 
set_seed(seed, gpu=True): 128 | np.random.seed(seed) 129 | torch.manual_seed(seed) 130 | if gpu: 131 | torch.cuda.manual_seed(seed) 132 | torch.cuda.manual_seed_all(seed) 133 | 134 | 135 | def apply_patchcore(type_data, feat_ext, patchcore, cfg_draw): 136 | print('\n----> PatchCore processing in %s start' % type_data) 137 | tic() 138 | 139 | # read images 140 | MVTecDataset(type_data) 141 | 142 | # reset neighbor 143 | patchcore.reset_faiss_index() 144 | 145 | # reset total index 146 | idx_coreset_total = [] 147 | 148 | # loop of split-sequential to apply k-center-greedy 149 | num_pitch = int(np.ceil(len(MVTecDataset.imgs_train) / patchcore.num_split_seq)) 150 | for i_split in range(patchcore.num_split_seq): 151 | # extract features 152 | i_from = i_split * num_pitch 153 | i_to = min(((i_split + 1) * num_pitch), len(MVTecDataset.imgs_train)) 154 | if patchcore.num_split_seq > 1: 155 | print('[split%02d] image index range is %d~%d' % (i_split, i_from, (i_to - 1))) 156 | feat_train = feat_ext.extract(MVTecDataset.imgs_train[i_from:i_to]) 157 | 158 | # coreset-reduced patch-feature memory bank 159 | idx_coreset = patchcore.compute_greedy_coreset_idx(feat_train) 160 | feat_train = feat_train[idx_coreset] 161 | 162 | # add feature as neighbor 163 | patchcore.add_neighbor(feat_train) 164 | 165 | # stock index 166 | offset_split = i_from * feat_ext.HW_map()[0] * feat_ext.HW_map()[1] 167 | idx_coreset_total.append(idx_coreset + offset_split) 168 | 169 | # save faiss index 170 | patchcore.save_faiss_index(type_data) 171 | 172 | # concat index 173 | idx_coreset_total = np.hstack(idx_coreset_total) 174 | 175 | # save and get images of coreset 176 | imgs_coreset = patchcore.save_coreset(idx_coreset_total, type_data, 177 | MVTecDataset.imgs_train, feat_ext.HW_map(), 178 | cfg_draw.size_receptive) 179 | 180 | # extract features 181 | feat_test = {} 182 | for type_test in MVTecDataset.imgs_test.keys(): 183 | feat_test[type_test] = feat_ext.extract(MVTecDataset.imgs_test[type_test], 184 | case='test (case:%s)' % type_test) 185 | 186 | # Sub-Image Anomaly Detection with Deep Pyramid Correspondences 187 | D, D_max, I = patchcore.localization(feat_test, feat_ext.HW_map()) 188 | 189 | # measure per image 190 | fpr_img, tpr_img, rocauc_img, pre_img, rec_img, prauc_img = calc_imagewise_metrics(D) 191 | print('%s imagewise ROCAUC: %.3f' % (type_data, rocauc_img)) 192 | 193 | # measure per pixel 194 | (fpr_pix, tpr_pix, rocauc_pix, 195 | pre_pix, rec_pix, prauc_pix, thresh_opt) = calc_pixelwise_metrics(D, MVTecDataset.gts_test) 196 | print('%s pixelwise ROCAUC: %.3f' % (type_data, rocauc_pix)) 197 | 198 | # save optimal threshold 199 | np.savetxt('%s/%s_thr.txt' % (args.path_trained, type_data), 200 | np.array([thresh_opt]), fmt='%.3f') 201 | 202 | toc(tag=('----> PatchCore processing in %s end, elapsed time' % type_data)) 203 | 204 | draw_distance_graph(type_data, cfg_draw, D, rocauc_img) 205 | if cfg_draw.verbose: 206 | draw_heatmap(type_data, cfg_draw, D, MVTecDataset.gts_test, D_max, I, 207 | MVTecDataset.imgs_test, MVTecDataset.files_test, 208 | imgs_coreset, feat_ext.HW_map()) 209 | 210 | return [fpr_img, tpr_img, rocauc_img, pre_img, rec_img, prauc_img, 211 | fpr_pix, tpr_pix, rocauc_pix, pre_pix, rec_pix, prauc_pix] 212 | 213 | 214 | def main(args): 215 | ConfigData(args) # static define for speed-up 216 | cfg_feat = ConfigFeat(args) 217 | cfg_patchcore = ConfigPatchCore(args) 218 | cfg_draw = ConfigDraw(args) 219 | 220 | feat_ext = FeatExtract(cfg_feat) 221 | patchcore = PatchCore(cfg_patchcore) 222 | 
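    # a typical invocation of this script, with flags as defined in arg_parser
    # above (the dataset path is a placeholder):
    #   python traintest.py -pd ./mvtec_anomaly_detection -pc 0.01 -k 5 -v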
223 | os.makedirs(args.path_result, exist_ok=True) 224 | for type_data in ConfigData.types_data: 225 | os.makedirs('%s/%s' % (args.path_result, type_data), exist_ok=True) 226 | 227 | fpr_img = {} 228 | tpr_img = {} 229 | rocauc_img = {} 230 | pre_img = {} 231 | rec_img = {} 232 | prauc_img = {} 233 | fpr_pix = {} 234 | tpr_pix = {} 235 | rocauc_pix = {} 236 | pre_pix = {} 237 | rec_pix = {} 238 | prauc_pix = {} 239 | 240 | # loop for types of data 241 | for type_data in ConfigData.types_data: 242 | set_seed(seed=args.seed, gpu=(not args.cpu)) 243 | 244 | result = apply_patchcore(type_data, feat_ext, patchcore, cfg_draw) 245 | 246 | fpr_img[type_data] = result[0] 247 | tpr_img[type_data] = result[1] 248 | rocauc_img[type_data] = result[2] 249 | pre_img[type_data] = result[3] 250 | rec_img[type_data] = result[4] 251 | prauc_img[type_data] = result[5] 252 | 253 | fpr_pix[type_data] = result[6] 254 | tpr_pix[type_data] = result[7] 255 | rocauc_pix[type_data] = result[8] 256 | pre_pix[type_data] = result[9] 257 | rec_pix[type_data] = result[10] 258 | prauc_pix[type_data] = result[11] 259 | 260 | rocauc_img_mean = np.array([rocauc_img[type_data] for type_data in ConfigData.types_data]) 261 | rocauc_img_mean = np.mean(rocauc_img_mean) 262 | prauc_img_mean = np.array([prauc_img[type_data] for type_data in ConfigData.types_data]) 263 | prauc_img_mean = np.mean(prauc_img_mean) 264 | rocauc_pix_mean = np.array([rocauc_pix[type_data] for type_data in ConfigData.types_data]) 265 | rocauc_pix_mean = np.mean(rocauc_pix_mean) 266 | prauc_pix_mean = np.array([prauc_pix[type_data] for type_data in ConfigData.types_data]) 267 | prauc_pix_mean = np.mean(prauc_pix_mean) 268 | 269 | draw_curve(cfg_draw, fpr_img, tpr_img, rocauc_img, rocauc_img_mean, 270 | fpr_pix, tpr_pix, rocauc_pix, rocauc_pix_mean) 271 | draw_curve(cfg_draw, rec_img, pre_img, prauc_img, prauc_img_mean, 272 | rec_pix, pre_pix, prauc_pix, prauc_pix_mean, False) 273 | 274 | for type_data in ConfigData.types_data: 275 | print('rocauc_img[%s] = %.3f' % (type_data, rocauc_img[type_data])) 276 | print('rocauc_img[mean] = %.3f' % rocauc_img_mean) 277 | for type_data in ConfigData.types_data: 278 | print('prauc_img[%s] = %.3f' % (type_data, prauc_img[type_data])) 279 | print('prauc_img[mean] = %.3f' % prauc_img_mean) 280 | for type_data in ConfigData.types_data: 281 | print('rocauc_pix[%s] = %.3f' % (type_data, rocauc_pix[type_data])) 282 | print('rocauc_pix[mean] = %.3f' % rocauc_pix_mean) 283 | for type_data in ConfigData.types_data: 284 | print('prauc_pix[%s] = %.3f' % (type_data, prauc_pix[type_data])) 285 | print('prauc_pix[mean] = %.3f' % prauc_pix_mean) 286 | 287 | if __name__ == '__main__': 288 | args = arg_parser() 289 | check_args(args) 290 | main(args) 291 | 292 | with open('%s/args.json' % args.path_trained, mode='w') as f: 293 | json.dump(args.__dict__, f, indent=4) 294 | 295 | -------------------------------------------------------------------------------- /utils/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import cv2 4 | import torch 5 | 6 | # https://pytorch.org/vision/main/models/generated/torchvision.models.wide_resnet50_2.html 7 | MEAN = torch.FloatTensor([[[0.485, 0.456, 0.406]]]) 8 | STD = torch.FloatTensor([[[0.229, 0.224, 0.225]]]) 9 | 10 | 11 | class ConfigData: 12 | @classmethod 13 | def __init__(cls, args, mode_train=True): 14 | # file reading related 15 | if mode_train: 16 | cls.path_data = args.path_data 17 | else: 18 | cls.path_data = 
args.path_data
19 |             cls.path_video = args.path_video
20 |         cls.num_cpu_max = args.num_cpu_max
21 |         if mode_train:
22 |             cls.shuffle = (args.num_split_seq > 1)  # for k-center-greedy split
23 |         else:
24 |             cls.shuffle = None
25 | 
26 |         if mode_train:
27 |             # input format related
28 |             cls.SHAPE_MIDDLE = (args.size_resize[0], args.size_resize[1])  # (H, W)
29 |             cls.SHAPE_INPUT = (args.size_crop[0], args.size_crop[1])  # (H, W)
30 |             cls.pixel_cut = (int((cls.SHAPE_MIDDLE[0] - cls.SHAPE_INPUT[0]) / 2),
31 |                              int((cls.SHAPE_MIDDLE[1] - cls.SHAPE_INPUT[1]) / 2))  # (H, W)
32 |         else:
33 |             cls.SHAPE_MIDDLE = None
34 |             cls.SHAPE_INPUT = None
35 |             cls.pixel_cut = None
36 | 
37 |         if mode_train:
38 |             # augmentation related
39 |             cls.flip_horz = args.flip_horz
40 |             cls.flip_vert = args.flip_vert
41 |         else:
42 |             cls.flip_horz = None
43 |             cls.flip_vert = None
44 | 
45 |         # collect types of data
46 |         if args.types_data is None:
47 |             types_data = [d for d in os.listdir(args.path_data)
48 |                           if os.path.isdir('%s/%s' % (args.path_data, d))]
49 |             cls.types_data = np.sort(np.array(types_data))
50 |         else:
51 |             cls.types_data = np.sort(np.array(args.types_data))
52 | 
53 |     @classmethod
54 |     def follow(cls, args_trained):
55 |         # input format related
56 |         cls.SHAPE_MIDDLE = (args_trained['size_resize'][0],
57 |                             args_trained['size_resize'][1])  # (H, W)
58 |         cls.SHAPE_INPUT = (args_trained['size_crop'][0],
59 |                            args_trained['size_crop'][1])  # (H, W)
60 |         cls.pixel_cut = (int((cls.SHAPE_MIDDLE[0] - cls.SHAPE_INPUT[0]) / 2),
61 |                          int((cls.SHAPE_MIDDLE[1] - cls.SHAPE_INPUT[1]) / 2))  # (H, W)
62 | 
63 | 
64 | class ConfigFeat:
65 |     def __init__(self, args, mode_train=True):
66 |         # adjust to environment
67 |         if args.cpu:
68 |             self.device = torch.device('cpu')
69 |         else:
70 |             self.device = torch.device('cuda:0')
71 | 
72 |         # batch-size for feature extraction by ImageNet model
73 |         self.batch_size = args.batch_size
74 | 
75 |         if mode_train:
76 |             # input format related
77 |             self.SHAPE_INPUT = (args.size_crop[0], args.size_crop[1])  # (H, W)
78 | 
79 |             # base network
80 |             self.backbone = args.backbone
81 |             self.weight = args.weight
82 | 
83 |             # layer specification
84 |             self.layer_map = args.layer_map
85 |             self.layer_weight = args.layer_weight
86 |             self.layer_merge_ref = args.layer_merge_ref
87 | 
88 |             # patch pixel of feature map for increasing receptive field size
89 |             self.size_patch = args.size_patch
90 | 
91 |             # dimension of each layer feature (at 1st adaptive average pooling)
92 |             self.dim_each_feat = args.dim_each_feat
93 |             # dimension after layer feature merging (at 2nd adaptive average pooling)
94 |             self.dim_merge_feat = args.dim_merge_feat
95 |         else:
96 |             self.SHAPE_INPUT = None
97 |             self.backbone = None
98 |             self.weight = None
99 |             self.layer_map = None
100 |             self.layer_weight = None
101 |             self.layer_merge_ref = None
102 |             self.size_patch = None
103 |             self.dim_each_feat = None
104 |             self.dim_merge_feat = None
105 | 
106 |         # adjust to the network's learning policy and the data conditions
107 |         self.MEAN = MEAN.to(self.device)
108 |         self.STD = STD.to(self.device)
109 | 
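    # follow() below restores the train-time settings from the args.json that
    # traintest.py saves; a sketch, assuming the default -pt path:
    #   with open('./trained/args.json') as f:
    #       cfg_feat.follow(json.load(f))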
110 |     def follow(self, args_trained):
111 |         # input format related
112 |         self.SHAPE_INPUT = (args_trained['size_crop'][0],
113 |                             args_trained['size_crop'][1])  # (H, W)
114 |         # base network
115 |         self.backbone = args_trained['backbone']
116 |         self.weight = args_trained['weight']
117 |         # layer specification
118 |         self.layer_map = args_trained['layer_map']
119 |         self.layer_weight = args_trained['layer_weight']
120 |         self.layer_merge_ref = args_trained['layer_merge_ref']
121 |         # patch pixel of feature map for increasing receptive field size
122 |         self.size_patch = args_trained['size_patch']
123 |         # dimension of each layer feature (at 1st adaptive average pooling)
124 |         self.dim_each_feat = args_trained['dim_each_feat']
125 |         # dimension after layer feature merging (at 2nd adaptive average pooling)
126 |         self.dim_merge_feat = args_trained['dim_merge_feat']
127 | 
128 | 
129 | class ConfigPatchCore:
130 |     def __init__(self, args, mode_train=True):
131 |         # adjust to environment
132 |         if args.cpu:
133 |             self.device = torch.device('cpu')
134 |         else:
135 |             self.device = torch.device('cuda:0')
136 | 
137 |         if mode_train:
138 |             # dimension after layer feature merging (at 2nd adaptive average pooling)
139 |             self.dim_coreset_feat = args.dim_merge_feat
140 | 
141 |             # number of sequential splits for applying k-center-greedy
142 |             self.num_split_seq = args.num_split_seq
143 |             # percentage of coreset to all patch features
144 |             self.percentage_coreset = args.percentage_coreset
145 |             # dimension to project features onto for sampling
146 |             self.dim_sampling = args.dim_sampling
147 |             # number of randomly selected samples to initialize the coreset
148 |             self.num_initial_coreset = args.num_initial_coreset
149 | 
150 |             # input format related
151 |             self.shape_stretch = (args.size_crop[0], args.size_crop[1])  # (H, W)
152 |         else:
153 |             self.dim_coreset_feat = None
154 |             self.num_split_seq = None
155 |             self.percentage_coreset = None
156 |             self.dim_sampling = None
157 |             self.num_initial_coreset = None
158 |             self.shape_stretch = None
159 | 
160 |         # number of nearest neighbors to get patch images
161 |         self.k = args.k
162 | 
163 |         # consideration for the outer edge
164 |         self.pixel_outer_decay = args.pixel_outer_decay
165 | 
166 |         # output path of trained artifacts
167 |         self.path_trained = args.path_trained
168 | 
169 |     def follow(self, args_trained):
170 |         # dimension after layer feature merging (at 2nd adaptive average pooling)
171 |         self.dim_coreset_feat = args_trained['dim_merge_feat']
172 |         # input format related
173 |         self.shape_stretch = (args_trained['size_crop'][0],
174 |                               args_trained['size_crop'][1])  # (H, W)
175 | 
176 | 
177 | class ConfigDraw:
178 |     def __init__(self, args, mode_train=True):
179 |         # output detail or not (takes a long time...)
180 | self.verbose = args.verbose 181 | 182 | # value for normalization of visualizing 183 | self.score_max = args.score_max 184 | 185 | # output filename related 186 | self.k = args.k 187 | 188 | if mode_train: 189 | # output filename related 190 | self.percentage_coreset = args.percentage_coreset 191 | 192 | # receptive field size 193 | self.size_receptive = args.size_receptive 194 | 195 | # aspect_ratio of output figure 196 | self.aspect_figure = args.size_crop[1] / args.size_crop[0] # W / H 197 | self.aspect_figure = np.round(self.aspect_figure, decimals=1) 198 | 199 | # visualize mode 200 | self.mode_visualize = args.mode_visualize 201 | self.mode_video = False 202 | else: 203 | self.percentage_coreset = None 204 | self.size_receptive = None 205 | self.aspect_figure = None 206 | self.mode_visualize = 'infer' 207 | if args.path_video is None: 208 | self.mode_video = False 209 | else: 210 | self.mode_video = True 211 | capture = cv2.VideoCapture(args.path_video) 212 | self.fps_video = capture.get(cv2.CAP_PROP_FPS) 213 | capture.release() 214 | 215 | # output path of figure 216 | self.path_result = args.path_result 217 | 218 | def follow(self, args_trained): 219 | # output filename related 220 | self.percentage_coreset = args_trained['percentage_coreset'] 221 | # receptive field size 222 | self.size_receptive = args_trained['size_receptive'] 223 | # aspect_ratio of output figure 224 | self.aspect_figure = args_trained['size_crop'][1] / args_trained['size_crop'][0] # W / H 225 | self.aspect_figure = np.round(self.aspect_figure, decimals=1) 226 | -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm import tqdm 3 | from sklearn.metrics import auc, roc_curve, precision_recall_curve 4 | 5 | 6 | def calc_imagewise_metrics(D, type_normal='good'): 7 | D_list = [] 8 | y_list = [] 9 | 10 | pbar = tqdm(total=int(len(D.keys()) + 5), 11 | desc='calculate imagewise metrics') 12 | for type_test in D.keys(): 13 | for i in range(len(D[type_test])): 14 | D_tmp = np.max(D[type_test][i]) 15 | y_tmp = int(type_test != type_normal) 16 | 17 | D_list.append(D_tmp) 18 | y_list.append(y_tmp) 19 | pbar.update(1) 20 | 21 | D_flat_list = np.array(D_list).reshape(-1) 22 | y_flat_list = np.array(y_list).reshape(-1) 23 | pbar.update(1) 24 | 25 | fpr, tpr, _ = roc_curve(y_flat_list, D_flat_list) 26 | pbar.update(1) 27 | 28 | rocauc = auc(fpr, tpr) 29 | pbar.update(1) 30 | 31 | pre, rec, _ = precision_recall_curve(y_flat_list, D_flat_list) 32 | pbar.update(1) 33 | 34 | # https://sinyi-chou.github.io/python-sklearn-precision-recall/ 35 | prauc = auc(rec, pre) 36 | pbar.update(1) 37 | pbar.close() 38 | 39 | return fpr, tpr, rocauc, pre, rec, prauc 40 | 41 | 42 | def calc_pixelwise_metrics(D, y): 43 | D_list = [] 44 | y_list = [] 45 | 46 | pbar = tqdm(total=int(len(D.keys()) + 6), 47 | desc='calculate pixelwise metrics') 48 | for type_test in D.keys(): 49 | for i in range(len(D[type_test])): 50 | D_tmp = D[type_test][i] 51 | y_tmp = y[type_test][i] 52 | 53 | D_list.append(D_tmp) 54 | y_list.append(y_tmp) 55 | pbar.update(1) 56 | 57 | D_flat_list = np.array(D_list).reshape(-1) 58 | y_flat_list = np.array(y_list).reshape(-1) 59 | pbar.update(1) 60 | 61 | fpr, tpr, _ = roc_curve(y_flat_list, D_flat_list) 62 | pbar.update(1) 63 | 64 | rocauc = auc(fpr, tpr) 65 | pbar.update(1) 66 | 67 | pre, rec, thresh = precision_recall_curve(y_flat_list, D_flat_list) 68 | 
pbar.update(1)
69 | 
70 |     prauc = auc(rec, pre)
71 |     pbar.update(1)
72 | 
73 |     # https://github.com/xiahaifeng1995/PaDiM-Anomaly-Detection-Localization-master/blob/main/main.py#L193C1-L200C1
74 |     # get optimal threshold
75 |     a = 2 * pre * rec
76 |     b = pre + rec
77 |     f1 = np.divide(a, b, out=np.zeros_like(a), where=(b != 0))
78 |     i_opt = np.argmax(f1)
79 |     thresh_opt = thresh[i_opt]
80 |     pbar.update(1)
81 |     pbar.close()
82 | 
83 |     print('pixelwise optimal threshold:%.3f (precision:%.3f, recall:%.3f)' %
84 |           (thresh_opt, pre[i_opt], rec[i_opt]))
85 | 
86 |     return fpr, tpr, rocauc, pre, rec, prauc, thresh_opt
87 | 
--------------------------------------------------------------------------------
/utils/tictoc.py:
--------------------------------------------------------------------------------
1 | import time
2 | 
3 | 
4 | def tic():
5 |     # requires the time module imported above
6 |     global start_time_tictoc
7 |     start_time_tictoc = time.time()
8 | 
9 | 
10 | def toc(tag="elapsed time"):
11 |     if "start_time_tictoc" in globals():
12 |         print("{}: {:.1f} [sec]".format(tag, time.time() - start_time_tictoc))
13 |     else:
14 |         print("tic has not been called")
15 | 
--------------------------------------------------------------------------------
/utils/visualize.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import cv2
4 | import matplotlib.pyplot as plt
5 | from tqdm import tqdm
6 | 
7 | 
8 | # https://github.com/gsurma/cnn_explainer/blob/main/utils.py
9 | def overlay_heatmap_on_image(img, heatmap, ratio_img=0.5):
10 |     img = img.astype(np.float32)
11 | 
12 |     heatmap = 1 - np.clip(heatmap, 0, 1)
13 |     heatmap = cv2.applyColorMap((heatmap * 255).astype(np.uint8), cv2.COLORMAP_JET)
14 |     heatmap = heatmap.astype(np.float32)
15 | 
16 |     overlay = (img * ratio_img) + (heatmap * (1 - ratio_img))
17 |     overlay = np.clip(overlay, 0, 255)
18 |     overlay = overlay.astype(np.uint8)
19 |     return overlay
20 | 
21 | 
22 | def draw_distance_graph(type_data, cfg_draw, D, rocauc_img):
23 |     D_list = {}
24 |     for type_test in D.keys():
25 |         D_list[type_test] = []
26 |         for i_D in range(len(D[type_test])):
27 |             D_tmp = np.max(D[type_test][i_D])
28 |             D_list[type_test].append(D_tmp)
29 |         D_list[type_test] = np.array(D_list[type_test])
30 | 
31 |     plt.figure(figsize=(10, 8), dpi=100, facecolor='white')
32 | 
33 |     # 'good' first
34 |     N_test = 0
35 |     type_test = 'good'
36 |     plt.subplot(2, 1, 1)
37 |     plt.scatter((np.arange(len(D_list[type_test])) + N_test), D_list[type_test],
38 |                 alpha=0.5, label=type_test)
39 |     plt.subplot(2, 1, 2)
40 |     plt.hist(D_list[type_test], alpha=0.5, label=type_test, bins=10)
41 | 
42 |     # other than 'good'
43 |     N_test += len(D_list[type_test])
44 |     types_test = np.array([k for k in D_list.keys() if k != 'good'])
45 |     for type_test in types_test:
46 |         plt.subplot(2, 1, 1)
47 |         plt.scatter((np.arange(len(D_list[type_test])) + N_test), D_list[type_test], alpha=0.5, label=type_test)
48 |         plt.subplot(2, 1, 2)
49 |         plt.hist(D_list[type_test], alpha=0.5, label=type_test, bins=10)
50 |         N_test += len(D_list[type_test])
51 | 
52 |     plt.subplot(2, 1, 1)
53 |     plt.title('imagewise ROCAUC %% : %.3f' % rocauc_img)
54 |     plt.grid()
55 |     plt.legend(loc='upper left')
56 |     plt.subplot(2, 1, 2)
57 |     plt.grid()
58 |     plt.legend(loc='upper right')
59 |     plt.gcf().tight_layout()
60 |     plt.gcf().savefig(('%s/%s/pred-dist_%s_p%04d_k%02d_rocaucimg%04d.png' %
61 |                        (cfg_draw.path_result, type_data,
62 |                         type_data, (cfg_draw.percentage_coreset * 1000),
63 |                         cfg_draw.k, round(rocauc_img * 1000))))
64 |     plt.clf()
65 |     plt.close()
66 | 
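# note: overlay_heatmap_on_image above expects a score map already normalized
# to [0, 1]; draw_heatmap below therefore passes score_map / score_max, e.g.:
#   overlay = overlay_heatmap_on_image(img, score_map / score_max)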
67 | 
68 | def draw_heatmap(type_data, cfg_draw, D, y, D_max, I, imgs_test, files_test,
69 |                  imgs_coreset, HW_map):
70 | 
71 |     if cfg_draw.mode_visualize == 'eval':
72 |         fig_width = 10 * max(1, cfg_draw.aspect_figure)
73 |         fig_height = 18
74 |         pixel_cut = [160, 140, 60, 60]  # [top, bottom, left, right]
75 |     else:
76 |         fig_width = 20 * max(1, cfg_draw.aspect_figure)
77 |         fig_height = 16
78 |         pixel_cut = [140, 120, 180, 180]  # [top, bottom, left, right]
79 |     dpi = 100
80 | 
81 |     for type_test in D.keys():
82 | 
83 |         if cfg_draw.mode_video:
84 |             filename_out = ('%s/%s/localization_%s_%s_p%04d_k%02d.mp4' %
85 |                             (cfg_draw.path_result, type_data,
86 |                              type_data, files_test[type_test][0].split('.')[0],
87 |                              (cfg_draw.percentage_coreset * 1000), cfg_draw.k))
88 |             # build writer
89 |             codecs = 'mp4v'
90 |             fourcc = cv2.VideoWriter_fourcc(*codecs)
91 |             width = (fig_width * dpi) - pixel_cut[2] - pixel_cut[3]
92 |             height = (fig_height * dpi) - pixel_cut[0] - pixel_cut[1]
93 |             writer = cv2.VideoWriter(filename_out, fourcc, cfg_draw.fps_video,
94 |                                      (width, height), True)
95 | 
96 |         desc = '[verbose mode] visualize localization (case:%s)' % type_test
97 |         for i_D in tqdm(range(len(D[type_test])), desc=desc):
98 |             file = files_test[type_test][i_D]
99 |             img = imgs_test[type_test][i_D]
100 |             score_map = D[type_test][i_D]
101 |             if cfg_draw.score_max is None:
102 |                 score_max = D_max
103 |             else:
104 |                 score_max = cfg_draw.score_max
105 |             if y is not None:
106 |                 gt = y[type_test][i_D]
107 | 
108 |             I_tmp = I[type_test][i_D, :, 0]
109 |             img_patch = assemble_patch(I_tmp, imgs_coreset, HW_map)
110 | 
111 |             plt.figure(figsize=(fig_width, fig_height), dpi=dpi, facecolor='white')
112 |             plt.rcParams['font.size'] = 10
113 | 
114 |             score_map_reg = score_map / score_max
115 | 
116 |             if cfg_draw.mode_visualize == 'eval':
117 |                 plt.subplot2grid((7, 3), (0, 0), rowspan=1, colspan=1)
118 |                 plt.imshow(img)
119 |                 plt.title('%s : %s' % (file.split('/')[-2], file.split('/')[-1]))
120 | 
121 |                 plt.subplot2grid((7, 3), (0, 1), rowspan=1, colspan=1)
122 |                 plt.imshow(gt)
123 | 
124 |                 plt.subplot2grid((7, 3), (0, 2), rowspan=1, colspan=1)
125 |                 plt.imshow(score_map)
126 |                 plt.colorbar()
127 |                 plt.title('max score : %.2f' % score_max)
128 | 
129 |                 plt.subplot2grid((42, 2), (7, 0), rowspan=10, colspan=1)
130 |                 plt.imshow(overlay_heatmap_on_image(img, score_map_reg))
131 | 
132 |                 plt.subplot2grid((42, 2), (7, 1), rowspan=10, colspan=1)
133 |                 plt.imshow((img.astype(np.float32) * score_map_reg[..., None]).astype(np.uint8))
134 | 
135 |                 plt.subplot2grid((21, 1), (10, 0), rowspan=11, colspan=1)
136 |                 plt.imshow(img_patch, interpolation='none')
137 |                 plt.title('patch images created with top1-NN')
138 | 
139 |             elif cfg_draw.mode_visualize == 'infer':
140 |                 plt.subplot(2, 2, 1)
141 |                 plt.imshow(img)
142 |                 plt.title('%s : %s' % (file.split('/')[-2], file.split('/')[-1]))
143 | 
144 |                 plt.subplot(2, 2, 2)
145 |                 plt.imshow(score_map)
146 |                 plt.colorbar()
147 |                 plt.title('max score : %.2f' % score_max)
148 | 
149 |                 plt.subplot(2, 2, 3)
150 |                 plt.imshow(overlay_heatmap_on_image(img, score_map_reg))
151 | 
152 |                 plt.subplot(2, 2, 4)
153 |                 plt.imshow(img_patch, interpolation='none')
154 |                 plt.title('patch images created with top1-NN')
155 | 
156 |             score_tmp = np.max(score_map) / score_max * 100
157 |             plt.gcf().canvas.draw()
158 |             img_figure = np.frombuffer(plt.gcf().canvas.tostring_rgb(), dtype='uint8')
159 |             img_figure = img_figure.reshape(fig_height * dpi, -1, 3)
160 |             img_figure = img_figure[pixel_cut[0]:(img_figure.shape[0] - pixel_cut[1]), 
161 | pixel_cut[2]:(img_figure.shape[1] - pixel_cut[3])] 162 | 163 | if not cfg_draw.mode_video: 164 | filename_out = ('%s/%s/localization_%s_%s_%s_p%04d_k%02d_s%03d.png' % 165 | (cfg_draw.path_result, type_data, 166 | type_data, type_test, os.path.basename(file).split('.')[0], 167 | (cfg_draw.percentage_coreset * 1000), cfg_draw.k, 168 | round(score_tmp))) 169 | cv2.imwrite(filename_out, img_figure[..., ::-1]) 170 | else: 171 | writer.write(img_figure[..., ::-1]) 172 | 173 | plt.clf() 174 | plt.close() 175 | 176 | if cfg_draw.mode_video: 177 | writer.release() 178 | 179 | 180 | def assemble_patch(idx_patch, imgs_coreset, HW_map): 181 | size_receptive = imgs_coreset.shape[1] 182 | img_patch = np.zeros([(HW_map[0] * size_receptive), 183 | (HW_map[1] * size_receptive), 3], dtype=np.uint8) 184 | 185 | # reset counter 186 | i_y = 0 187 | i_x = 0 188 | 189 | # loop of patch feature index 190 | for i_patch in idx_patch: 191 | # tile... 192 | img_piece = imgs_coreset[i_patch] 193 | y = i_y * size_receptive 194 | x = i_x * size_receptive 195 | img_patch[y:(y + size_receptive), x:(x + size_receptive)] = img_piece 196 | 197 | # count-up 198 | i_x += 1 199 | if i_x == HW_map[1]: 200 | i_x = 0 201 | i_y += 1 202 | 203 | return img_patch 204 | 205 | 206 | def draw_curve(cfg_draw, x_img, y_img, auc_img, auc_img_mean, 207 | x_pix, y_pix, auc_pix, auc_pix_mean, flg_roc=True): 208 | if flg_roc: 209 | idx = 'ROC' 210 | lbl_x = 'False Positive Rate' 211 | lbl_y = 'True Positive Rate' 212 | else: 213 | idx = 'PR' 214 | lbl_x = 'Recall' 215 | lbl_y = 'Precision' 216 | 217 | plt.figure(figsize=(15, 6), dpi=100, facecolor='white') 218 | for type_data in x_img.keys(): 219 | plt.subplot(1, 2, 1) 220 | plt.plot(x_img[type_data], y_img[type_data], 221 | label='%s %sAUC: %.3f' % (type_data, idx, auc_img[type_data])) 222 | plt.subplot(1, 2, 2) 223 | plt.plot(x_pix[type_data], y_pix[type_data], 224 | label='%s %sAUC: %.3f' % (type_data, idx, auc_pix[type_data])) 225 | 226 | plt.subplot(1, 2, 1) 227 | plt.title('imagewise %sAUC %% : mean %.3f' % (idx, auc_img_mean)) 228 | plt.grid() 229 | plt.legend(loc='lower right') 230 | plt.xlabel(lbl_x) 231 | plt.ylabel(lbl_y) 232 | plt.subplot(1, 2, 2) 233 | plt.title('pixelwise %sAUC %% : mean %.3f' % (idx, auc_pix_mean)) 234 | plt.grid() 235 | plt.legend(loc='lower right') 236 | plt.xlabel(lbl_x) 237 | plt.ylabel(lbl_y) 238 | plt.gcf().tight_layout() 239 | plt.gcf().savefig('%s/%s-curve_p%04d_k%02d_aucimg%04d_aucpix%04d.png' % 240 | (cfg_draw.path_result, idx.lower(), 241 | (cfg_draw.percentage_coreset * 1000), cfg_draw.k, 242 | round(auc_img_mean * 1000), round(auc_pix_mean * 1000))) 243 | plt.clf() 244 | plt.close() 245 | --------------------------------------------------------------------------------
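For orientation, here is a hedged sketch of how the pieces above fit together at inference time. The config objects, the 'bottle' category, and the shape of imgs are illustrative assumptions; the repository's own infer.py is the actual entry point:

# sketch only: cfg_feat / cfg_patchcore are assumed to be built via the
# follow() methods in utils/config.py, and imgs is an (N, H, W, 3) uint8
# array preprocessed like the training data
from models.feat_extract import FeatExtract
from models.patchcore import PatchCore

feat_ext = FeatExtract(cfg_feat)
patchcore = PatchCore(cfg_patchcore)

patchcore.load_faiss_index('bottle')  # reads trained/bottle.idx written by save_faiss_index
feat = feat_ext.extract(imgs, case='test (case:good)')  # (N * H'W', dim_merge_feat)

D, D_max, I = patchcore.localization({'good': feat}, feat_ext.HW_map())
score_imagewise = D['good'].reshape(len(D['good']), -1).max(axis=-1)  # one score per image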