├── capture
│   └── capture.jpg
├── requirements.txt
├── embeddings
│   └── carpet
│       └── embedding.pickle
├── sampling_methods
│   ├── sampling_def.py
│   └── kcenter_greedy.py
├── .gitignore
├── README.md
├── LICENSE
└── train.py
/capture/capture.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hcw-00/PatchCore_anomaly_detection/HEAD/capture/capture.jpg
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | opencv-python==4.5.2.52
2 | scikit-learn==0.24.2
3 | pytorch-lightning==1.3.3
4 | faiss-gpu==1.7.1.post3
--------------------------------------------------------------------------------
/embeddings/carpet/embedding.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hcw-00/PatchCore_anomaly_detection/HEAD/embeddings/carpet/embedding.pickle
--------------------------------------------------------------------------------
/sampling_methods/sampling_def.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Abstract class for sampling methods.
16 |
17 | Provides interface to sampling methods that allow same signature
18 | for select_batch. Each subclass implements select_batch_ with the desired
19 | signature for readability.
20 | """
21 |
22 | from __future__ import absolute_import
23 | from __future__ import division
24 | from __future__ import print_function
25 |
26 | import abc
27 | import numpy as np
28 |
class SamplingMethod(object):
    """Common interface for sampling methods.

    Subclasses implement ``select_batch_`` with whatever keyword signature
    they need; callers always go through ``select_batch``, which forwards its
    keyword arguments unchanged.
    """

    # NOTE: ``__metaclass__`` is the Python 2 spelling. Python 3 ignores this
    # attribute, so the ``abc.abstractmethod`` markers below are not enforced
    # there. It is kept as-is so existing subclasses and instantiation
    # behaviour are unchanged.
    __metaclass__ = abc.ABCMeta

    @abc.abstractmethod
    def __init__(self, X, y, seed, **kwargs):
        """Store the data, the labels and the seed on the instance.

        Args:
          X: data to sample from; ``flatten_X`` reads its ``shape``.
          y: labels, stored as given.
          seed: seed value, stored as given.
          **kwargs: accepted for subclass signatures; ignored here.
        """
        self.X = X
        self.y = y
        self.seed = seed

    def flatten_X(self):
        """Return ``self.X`` with every axis after the first merged into one.

        Returns:
          ``self.X`` itself when it has at most two axes, otherwise a
          ``(shape[0], prod(shape[1:]))`` reshaped array.
        """
        shape = self.X.shape
        flat_X = self.X
        if len(shape) > 2:
            # np.prod instead of np.product: the latter alias was removed in
            # NumPy 2.0.
            flat_X = np.reshape(self.X, (shape[0], np.prod(shape[1:])))
        return flat_X

    @abc.abstractmethod
    def select_batch_(self):
        """Select a batch; subclasses override this with their own signature."""
        return

    def select_batch(self, **kwargs):
        """Forward all keyword arguments to ``select_batch_`` and return its result."""
        return self.select_batch_(**kwargs)

    def to_dict(self):
        """Return ``None``; the base class has nothing to serialise."""
        return None
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # embedding.pickle
132 | ./embeddings/
133 |
134 | *.png
135 | *.jpg
136 | *.ckpt
137 | test/*
138 | MVTec/*
--------------------------------------------------------------------------------
/sampling_methods/kcenter_greedy.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Returns points that minimize the maximum distance of any point to a center.
16 |
17 | Implements the k-Center-Greedy method in
18 | Ozan Sener and Silvio Savarese. A Geometric Approach to Active Learning for
19 | Convolutional Neural Networks. https://arxiv.org/abs/1708.00489 2017
20 |
21 | Distance metric defaults to l2 distance. Features used to calculate distance
22 | are either raw features or if a model has transform method then uses the output
23 | of model.transform(X).
24 |
25 | Can be extended to a robust k centers algorithm that ignores a certain number of
26 | outlier datapoints. The resulting centers are the solution to a mixed-integer program.
27 | """
28 |
29 | from __future__ import absolute_import
30 | from __future__ import division
31 | from __future__ import print_function
32 |
33 | import numpy as np
34 | from sklearn.metrics import pairwise_distances
35 | from sampling_methods.sampling_def import SamplingMethod
36 |
37 |
class kCenterGreedy(SamplingMethod):
    """Greedy k-center selection over the rows of a feature matrix.

    ``min_distances`` holds, for every row of ``features``, the distance to
    its closest selected center as an ``(n_obs, 1)`` array, or ``None`` while
    no center has been processed yet.
    """

    def __init__(self, X, y, seed, metric='euclidean'):
        """Set up the sampler state.

        Args:
          X: data to sample from; flattened to 2-D via ``flatten_X``.
          y: labels, stored as given.
          seed: stored for parity with ``SamplingMethod.__init__``; not
            consumed by this class.
          metric: metric name handed to ``pairwise_distances``.
        """
        self.X = X
        self.y = y
        self.seed = seed
        self.flat_X = self.flatten_X()
        self.name = 'kcenter'
        self.features = self.flat_X
        self.metric = metric
        self.min_distances = None
        self.n_obs = self.X.shape[0]
        self.already_selected = []

    def update_distances(self, cluster_centers, only_new=True, reset_dist=False):
        """Update min distances given cluster centers.

        Args:
          cluster_centers: indices of cluster centers
          only_new: only calculate distance for newly selected points and update
            min_distances.
          reset_dist: whether to reset min_distances.
        """
        if reset_dist:
            self.min_distances = None
        if only_new:
            cluster_centers = [d for d in cluster_centers
                               if d not in self.already_selected]
        # len() rather than truthiness so an index array is accepted as well
        # as a list.
        if len(cluster_centers) > 0:
            # Update min_distances for all examples given new cluster centers.
            x = self.features[cluster_centers]
            dist = pairwise_distances(self.features, x, metric=self.metric)
            # Collapse to one column first. Comparing the (n, 1) running
            # minimum with an (n, k) matrix would broadcast to (n, k) and
            # corrupt the state whenever several centers arrive at once.
            closest = np.min(dist, axis=1).reshape(-1, 1)

            if self.min_distances is None:
                self.min_distances = closest
            else:
                self.min_distances = np.minimum(self.min_distances, closest)

    def select_batch_(self, model, already_selected, N, **kwargs):
        """
        Diversity promoting active learning method that greedily forms a batch
        to minimize the maximum distance to a cluster center among all unlabeled
        datapoints.

        Args:
          model: object with a ``transform`` method applied to the original
            (unflattened) X; if calling it fails, the current features
            (flat_X unless replaced by an earlier call) are used instead
          already_selected: index of datapoints already selected
          N: batch size

        Returns:
          indices of points selected to minimize distance to cluster centers
        """
        try:
            # Assumes that the transform function takes in original data and
            # not flattened data.
            print('Getting transformed features...')
            transformed = model.transform(self.X)
        except Exception:
            # Best-effort fallback, e.g. when model is None or has no
            # transform. Only the transform call is guarded, so errors in
            # the distance computation are no longer silently masked.
            print('Using flat_X as features.')
            self.update_distances(already_selected, only_new=True, reset_dist=False)
        else:
            self.features = transformed
            print('Calculating distances...')
            self.update_distances(already_selected, only_new=False, reset_dist=True)

        new_batch = []

        for _ in range(N):
            if self.min_distances is None:
                # No center has been processed yet: initialize with a randomly
                # selected datapoint. The original tested
                # `self.already_selected is None`, which never held because
                # the attribute starts as [], so np.argmax(None) was used.
                ind = np.random.choice(np.arange(self.n_obs))
            else:
                ind = np.argmax(self.min_distances)
            # New examples should not be in already selected since those points
            # should have min_distance of zero to a cluster center.
            assert ind not in already_selected

            self.update_distances([ind], only_new=True, reset_dist=False)
            new_batch.append(ind)

        # min_distances is still None when nothing was selected at all
        # (N == 0 with no prior centers); skip the report instead of failing.
        if self.min_distances is not None:
            print('Maximum distance from cluster centers is %0.2f'
                  % np.max(self.min_distances))

        self.already_selected = already_selected

        return new_batch
123 |
124 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PatchCore anomaly detection
2 | Unofficial implementation of PatchCore(new SOTA) anomaly detection model
3 |
4 |
5 | Original Paper :
6 | Towards Total Recall in Industrial Anomaly Detection (Jun 2021)
7 | Karsten Roth, Latha Pemula, Joaquin Zepeda, Bernhard Schölkopf, Thomas Brox, Peter Gehler
8 |
9 |
10 | https://arxiv.org/abs/2106.08265
11 | https://paperswithcode.com/sota/anomaly-detection-on-mvtec-ad
12 |
13 | 
14 |
15 |
16 | updates(21/06/21) :
17 | - I used sklearn's SparseRandomProjection(eps=0.9) for the random projection. I'm not confident about this choice.
18 | - I think the exact value of "b nearest patch-features" is not given in the paper, so I set it to 9. (args.n_neighbors)
19 | - For nearest-neighbor search, the authors used "faiss", but that is not implemented in this code yet.
20 | - The sample embeddings/carpet/embedding.pickle was created with coreset_sampling_ratio=0.001.
21 |
22 | updates(21/06/26) :
23 | - A critical [issue](https://github.com/hcw-00/PatchCore_anomaly_detection/issues/3#issue-930229038) related to "locally aware patch" was raised and fixed. The score table has been updated.
24 |
25 | ### Usage
26 | ~~~
27 | # install python 3.6, torch==1.8.1, torchvision==0.9.1
28 | pip install -r requirements.txt
29 |
30 | python train.py --phase train or test --dataset_path .../mvtec_anomaly_detection --category carpet --project_root_path path/to/save/results --coreset_sampling_ratio 0.01 --n_neighbors 9
31 |
32 | # for fast try just specify your dataset_path and run
33 | python train.py --phase test --dataset_path .../mvtec_anomaly_detection --project_root_path ./
34 | ~~~
35 |
36 | ### MVTecAD AUROC score (PatchCore-1%, mean of n trials)
37 | | Category | Paper<br>(image-level) | This code<br>(image-level) | Paper<br>(pixel-level) | This code<br>(pixel-level) |
38 | | :-----: | :-: | :-: | :-: | :-: |
39 | | carpet | 0.980 | 0.991(1) | 0.989 | 0.989(1) |
40 | | grid | 0.986 | 0.975(1) | 0.986 | 0.975(1) |
41 | | leather | 1.000 | 1.000(1) | 0.993 | 0.991(1) |
42 | | tile | 0.994 | 0.994(1) | 0.961 | 0.949(1) |
43 | | wood | 0.992 | 0.989(1) | 0.951 | 0.936(1) |
44 | | bottle | 1.000 | 1.000(1) | 0.985 | 0.981(1) |
45 | | cable | 0.993 | 0.995(1) | 0.982 | 0.983(1) |
46 | | capsule | 0.980 | 0.976(1) | 0.988 | 0.989(1) |
47 | | hazelnut | 1.000 | 1.000(1) | 0.986 | 0.985(1) |
48 | | metal nut | 0.997 | 0.999(1) | 0.984 | 0.984(1) |
49 | | pill | 0.970 | 0.959(1) | 0.971 | 0.977(1) |
50 | | screw | 0.964 | 0.949(1) | 0.992 | 0.977(1) |
51 | | toothbrush | 1.000 | 1.000(1) | 0.985 | 0.986(1) |
52 | | transistor | 0.999 | 1.000(1) | 0.949 | 0.972(1) |
53 | | zipper | 0.992 | 0.995(1) | 0.988 | 0.984(1) |
54 | | mean | 0.990 | 0.988 | 0.980 | 0.977 |
55 |
56 | ### Code Reference
57 | kcenter algorithm :
58 | https://github.com/google/active-learning
59 | embedding concat function :
60 | https://github.com/xiahaifeng1995/PaDiM-Anomaly-Detection-Localization-master
61 |
62 |
83 |