├── .gitignore
├── LICENSE
├── README.md
├── [demo] model inference.ipynb
├── [example] anomaly_simulation_strategy.ipynb
├── [example] model overview.ipynb
├── anomaly_mask.json
├── assets
│   └── memseg.gif
├── configs.yaml
├── data
│   ├── __init__.py
│   ├── dataset.py
│   ├── factory.py
│   └── perlin.py
├── docker
│   └── Dockerfile
├── focal_loss.py
├── log.py
├── main.py
├── metrics
│   ├── __init__.py
│   ├── generic_util.py
│   └── pro_curve_util.py
├── models
│   ├── __init__.py
│   ├── coordatt.py
│   ├── decoder.py
│   ├── memory_module.py
│   ├── memseg.py
│   └── msff.py
├── requirements.txt
├── saved_model
│   ├── MemSeg-bottle
│   │   ├── best_score.json
│   │   └── latest_score.json
│   ├── MemSeg-cable
│   │   ├── best_score.json
│   │   └── latest_score.json
│   ├── MemSeg-capsule
│   │   ├── best_score.json
│   │   └── latest_score.json
│   ├── MemSeg-carpet
│   │   ├── best_score.json
│   │   └── latest_score.json
│   ├── MemSeg-grid
│   │   ├── best_score.json
│   │   └── latest_score.json
│   ├── MemSeg-hazelnut
│   │   ├── best_score.json
│   │   └── latest_score.json
│   ├── MemSeg-leather
│   │   ├── best_score.json
│   │   └── latest_score.json
│   ├── MemSeg-metal_nut
│   │   ├── best_score.json
│   │   └── latest_score.json
│   ├── MemSeg-pill
│   │   ├── best_score.json
│   │   └── latest_score.json
│   ├── MemSeg-screw
│   │   ├── best_score.json
│   │   └── latest_score.json
│   ├── MemSeg-tile
│   │   ├── best_score.json
│   │   └── latest_score.json
│   ├── MemSeg-toothbrush
│   │   ├── best_score.json
│   │   └── latest_score.json
│   ├── MemSeg-transistor
│   │   ├── best_score.json
│   │   └── latest_score.json
│   ├── MemSeg-wood
│   │   ├── best_score.json
│   │   └── latest_score.json
│   └── MemSeg-zipper
│       ├── best_score.json
│       └── latest_score.json
├── scheduler.py
├── train.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | ### Python ###
2 | # Byte-compiled / optimized / DLL files
3 | __pycache__/
4 | *.py[cod]
5 | *$py.class
6 |
7 | # C extensions
8 | *.so
9 |
10 | # Distribution / packaging
11 | .Python
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 | cover/
54 |
55 | # Translations
56 | *.mo
57 | *.pot
58 |
59 | # Django stuff:
60 | *.log
61 | local_settings.py
62 | db.sqlite3
63 | db.sqlite3-journal
64 |
65 | # Flask stuff:
66 | instance/
67 | .webassets-cache
68 |
69 | # Scrapy stuff:
70 | .scrapy
71 |
72 | # Sphinx documentation
73 | docs/_build/
74 |
75 | # PyBuilder
76 | .pybuilder/
77 | target/
78 |
79 | # Jupyter Notebook
80 | .ipynb_checkpoints
81 |
82 | # IPython
83 | profile_default/
84 | ipython_config.py
85 |
86 | # pyenv
87 | # For a library or package, you might want to ignore these files since the code is
88 | # intended to run in multiple environments; otherwise, check them in:
89 | # .python-version
90 |
91 | # pipenv
92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
95 | # install all needed dependencies.
96 | #Pipfile.lock
97 |
98 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
99 | __pypackages__/
100 |
101 | # Celery stuff
102 | celerybeat-schedule
103 | celerybeat.pid
104 |
105 | # SageMath parsed files
106 | *.sage.py
107 |
108 | # Environments
109 | .env
110 | .venv
111 | env/
112 | venv/
113 | ENV/
114 | env.bak/
115 | venv.bak/
116 |
117 | # Spyder project settings
118 | .spyderproject
119 | .spyproject
120 |
121 | # Rope project settings
122 | .ropeproject
123 |
124 | # mkdocs documentation
125 | /site
126 |
127 | # mypy
128 | .mypy_cache/
129 | .dmypy.json
130 | dmypy.json
131 |
132 | # Pyre type checker
133 | .pyre/
134 |
135 | # pytype static type analyzer
136 | .pytype/
137 |
138 | # Cython debug symbols
139 | cython_debug/
140 |
141 | # End of https://www.toptal.com/developers/gitignore/api/python
142 |
143 | # results
144 | *.pt
145 | wandb/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Jaehyuk Heo
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MemSeg
2 | Unofficial re-implementation of [MemSeg: A semi-supervised method for image surface defect detection using differences and commonalities](https://arxiv.org/abs/2205.00908)
3 |
4 | # Environments
5 |
6 | - Docker image: nvcr.io/nvidia/pytorch:20.12-py3
7 |
8 | ```
9 | einops==0.5.0
10 | timm==0.5.4
11 | wandb==0.12.17
12 | omegaconf
13 | imgaug==0.4.0
14 | ```
15 |
16 |
17 | # Process
18 |
19 | ## 1. Anomaly Simulation Strategy
20 |
21 | - [notebook](https://github.com/TooTouch/MemSeg/blob/main/%5Bexample%5D%20anomaly_simulation_strategy.ipynb)
22 | - Describable Textures Dataset (DTD) [ [download](https://www.google.com/search?q=dtd+texture+dataset&rlz=1C5CHFA_enKR999KR999&oq=dtd+texture+dataset&aqs=chrome..69i57j69i60.2253j0j7&sourceid=chrome&ie=UTF-8) ]
23 |
24 |
25 |
26 |
27 |
28 | ## 2. Model Process
29 |
30 | - [notebook](https://github.com/TooTouch/MemSeg/blob/main/%5Bexample%5D%20model%20overview.ipynb)
31 |
32 |
33 |
34 |
35 |
36 |
37 | # Run
38 |
39 | **Example**
40 |
41 | ```bash
42 | python main.py configs=configs.yaml DATASET.target=bottle
43 | ```
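
Because `main.py` builds its config by merging `OmegaConf.from_cli()` over `configs.yaml`, any key in the config can be overridden from the command line with dot notation, e.g.:

```bash
python main.py configs=configs.yaml DATASET.target=leather TRAIN.num_training_steps=3000 TRAIN.use_wandb=False
```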
44 |
45 | ## Demo
46 |
47 | ```bash
48 | voila "[demo] model inference.ipynb" --port ${port} --Voila.ip ${ip}
49 | ```
50 |
51 | 
52 |
53 | # Results
54 |
55 | - **Backbone**: ResNet18
56 |
57 | | target | AUROC-image | AUROC-pixel | AUPRO-pixel |
58 | |:-----------|--------------:|--------------:|--------------:|
59 | | leather | 100 | 98.83 | 99.09 |
60 | | pill | 97.05 | 98.29 | 97.96 |
61 | | carpet | 99.12 | 97.54 | 97.02 |
62 | | hazelnut | 100 | 97.78 | 99 |
63 | | tile | 99.86 | 99.38 | 98.81 |
64 | | cable | 92.5 | 82.3 | 87.31 |
65 | | toothbrush | 100 | 99.28 | 98.56 |
66 | | transistor | 96.5 | 76.29 | 86.06 |
67 | | zipper | 99.95 | 97.94 | 97.26 |
68 | | metal_nut | 99.46 | 88.48 | 95 |
69 | | grid | 99.83 | 98.37 | 98.53 |
70 | | bottle | 100 | 98.79 | 98.36 |
71 | | capsule | 95.41 | 98.43 | 97.73 |
72 | | screw | 94.86 | 95.08 | 94 |
73 | | wood | 100 | 97.54 | 97.62 |
74 | | **Average** | 98.3 | 94.96 | 96.15 |
75 |
76 | # Citation
77 |
78 | ```
79 | @article{DBLP:journals/corr/abs-2205-00908,
80 | author = {Minghui Yang and
81 | Peng Wu and
82 | Jing Liu and
83 | Hui Feng},
84 | title = {MemSeg: {A} semi-supervised method for image surface defect detection
85 | using differences and commonalities},
86 | journal = {CoRR},
87 | volume = {abs/2205.00908},
88 | year = {2022},
89 | url = {https://doi.org/10.48550/arXiv.2205.00908},
90 | doi = {10.48550/arXiv.2205.00908},
91 | eprinttype = {arXiv},
92 | eprint = {2205.00908},
93 | timestamp = {Tue, 03 May 2022 15:52:06 +0200},
94 | biburl = {https://dblp.org/rec/journals/corr/abs-2205-00908.bib},
95 | bibsource = {dblp computer science bibliography, https://dblp.org}
96 | }
97 | ```
98 |
--------------------------------------------------------------------------------
/[demo] model inference.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Model Select"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": 9,
13 | "metadata": {
14 | "scrolled": false
15 | },
16 | "outputs": [
17 | {
18 | "data": {
19 | "application/vnd.jupyter.widget-view+json": {
20 | "model_id": "91564dc2a09d465baa8420dd651407c7",
21 | "version_major": 2,
22 | "version_minor": 0
23 | },
24 | "text/plain": [
25 | "HBox(children=(Dropdown(description='Model:', index=11, options=('leather', 'pill', 'carpet', 'hazelnut', 'til…"
26 | ]
27 | },
28 | "metadata": {},
29 | "output_type": "display_data"
30 | },
31 | {
32 | "data": {
33 | "application/vnd.jupyter.widget-view+json": {
34 | "model_id": "76944799ab664d25b738c192c8268f1c",
35 | "version_major": 2,
36 | "version_minor": 0
37 | },
38 | "text/plain": [
39 | "Output()"
40 | ]
41 | },
42 | "metadata": {},
43 | "output_type": "display_data"
44 | }
45 | ],
46 | "source": [
47 | "from data import create_dataset, create_dataloader\n",
48 | "\n",
49 | "import os\n",
50 | "import torch\n",
51 | "from omegaconf import OmegaConf\n",
52 | "from timm import create_model\n",
53 | "from models import MemSeg\n",
54 | "\n",
55 | "import ipywidgets as widgets\n",
56 | "from IPython.display import display, clear_output\n",
57 | "import matplotlib.pyplot as plt\n",
58 | "\n",
59 | "cfg = OmegaConf.load('./configs.yaml')\n",
60 | "\n",
61 | "# ====================================\n",
62 | "# Select Model\n",
63 | "# ====================================\n",
64 | "\n",
65 | "def load_model(target_name):\n",
66 | " global model\n",
67 | " global testset \n",
68 | " \n",
69 | " testset = create_dataset(\n",
70 | " datadir = cfg['DATASET']['datadir'],\n",
71 | " target = target_name, \n",
72 | " is_train = False,\n",
73 | " resize = cfg['DATASET']['resize'],\n",
74 | " texture_source_dir = cfg['DATASET']['texture_source_dir'],\n",
75 | " structure_grid_size = cfg['DATASET']['structure_grid_size'],\n",
76 | " transparency_range = cfg['DATASET']['transparency_range'],\n",
77 | " perlin_scale = cfg['DATASET']['perlin_scale'], \n",
78 | " min_perlin_scale = cfg['DATASET']['min_perlin_scale'], \n",
79 | " perlin_noise_threshold = cfg['DATASET']['perlin_noise_threshold']\n",
80 | " )\n",
81 | " \n",
82 | " memory_bank = torch.load(f'./saved_model/MemSeg-{target_name}/memory_bank.pt')\n",
83 | " memory_bank.device = 'cpu'\n",
84 | " for k in memory_bank.memory_information.keys():\n",
85 | " memory_bank.memory_information[k] = memory_bank.memory_information[k].cpu()\n",
86 | "\n",
87 | " feature_extractor = feature_extractor = create_model(\n",
88 | " cfg['MODEL']['feature_extractor_name'], \n",
89 | " pretrained = True, \n",
90 | " features_only = True\n",
91 | " )\n",
92 | " model = MemSeg(\n",
93 | " memory_bank = memory_bank,\n",
94 | " feature_extractor = feature_extractor\n",
95 | " )\n",
96 | "\n",
97 | " model.load_state_dict(torch.load(f'./saved_model/MemSeg-{target_name}/best_model.pt'))\n",
98 | " model.eval()\n",
99 | "\n",
100 | "# ====================================\n",
101 | "# Visualization\n",
102 | "# ====================================\n",
103 | "\n",
104 | "\n",
105 | "def result_plot(idx):\n",
106 | " input_i, mask_i, target_i = testset[idx]\n",
107 | "\n",
108 | " output_i = model(input_i.unsqueeze(0)).detach()\n",
109 | " output_i = torch.nn.functional.softmax(output_i, dim=1)\n",
110 | "\n",
111 | " def minmax_scaling(img):\n",
112 | " return (((img - img.min()) / (img.max() - img.min())) * 255).to(torch.uint8)\n",
113 | "\n",
114 | " fig, ax = plt.subplots(1,4, figsize=(15,10))\n",
115 | " \n",
116 | " ax[0].imshow(minmax_scaling(input_i.permute(1,2,0)))\n",
117 | " ax[0].set_title('Input: {}'.format('Normal' if target_i == 0 else 'Abnormal'))\n",
118 | " ax[1].imshow(mask_i, cmap='gray')\n",
119 | " ax[1].set_title('Ground Truth')\n",
120 | " ax[2].imshow(output_i[0][1], cmap='gray')\n",
121 | " ax[2].set_title('Predicted Mask')\n",
122 | " ax[3].imshow(minmax_scaling(input_i.permute(1,2,0)), alpha=1)\n",
123 | " ax[3].imshow(output_i[0][1], cmap='gray', alpha=0.5)\n",
124 | " ax[3].set_title(f'Input X Predicted Mask')\n",
125 | " \n",
126 | " plt.show()\n",
127 | " \n",
128 | "\n",
129 | "\n",
130 | "\n",
131 | "\n",
132 | "# ====================================\n",
133 | "# widgets\n",
134 | "# ====================================\n",
135 | "\n",
136 | "\n",
137 | "target_list = widgets.Dropdown(\n",
138 | " options=[m.split('-')[1] for m in os.listdir('./saved_model')],\n",
139 | " value='capsule',\n",
140 | " description='Model:',\n",
141 | " disabled=False,\n",
142 | ")\n",
143 | "button = widgets.Button(description=\"Model Change\")\n",
144 | "output = widgets.Output()\n",
145 | "\n",
146 | "\n",
147 | "@output.capture()\n",
148 | "def on_button_clicked(b):\n",
149 | " clear_output(wait=True)\n",
150 | " load_model(target_name=target_list.value)\n",
151 | " \n",
152 | " # vizualization\n",
153 | " file_list = widgets.Dropdown(\n",
154 | " options=[(file_path, i) for i, file_path in enumerate(testset.file_list)],\n",
155 | " value=0,\n",
156 | " description='image:',\n",
157 | " )\n",
158 | "\n",
159 | " \n",
160 | " widgets.interact(result_plot, idx=file_list)\n",
161 | "\n",
162 | "button.on_click(on_button_clicked)\n",
163 | "\n",
164 | "\n",
165 | "display(widgets.HBox([target_list, button]), output)\n",
166 | "\n"
167 | ]
168 | }
169 | ],
170 | "metadata": {
171 | "kernelspec": {
172 | "display_name": "Python 3 (ipykernel)",
173 | "language": "python",
174 | "name": "python3"
175 | },
176 | "language_info": {
177 | "codemirror_mode": {
178 | "name": "ipython",
179 | "version": 3
180 | },
181 | "file_extension": ".py",
182 | "mimetype": "text/x-python",
183 | "name": "python",
184 | "nbconvert_exporter": "python",
185 | "pygments_lexer": "ipython3",
186 | "version": "3.8.10"
187 | }
188 | },
189 | "nbformat": 4,
190 | "nbformat_minor": 4
191 | }
192 |
--------------------------------------------------------------------------------
/anomaly_mask.json:
--------------------------------------------------------------------------------
1 | {
2 | "hazelnut": {
3 | "use_mask" : true,
4 | "bg_threshold" : 50,
5 | "bg_reverse" : true
6 | },
7 | "leather": {
8 | "use_mask" : false,
9 | "bg_threshold" : null,
10 | "bg_reverse" : null
11 | },
12 | "metal_nut": {
13 | "use_mask" : true,
14 | "bg_threshold" : 40,
15 | "bg_reverse" : true
16 | },
17 | "tile": {
18 | "use_mask" : false,
19 | "bg_threshold" : null,
20 | "bg_reverse" : null
21 | },
22 | "wood": {
23 | "use_mask" : false,
24 | "bg_threshold" : null,
25 | "bg_reverse" : null
26 | },
27 | "grid": {
28 | "use_mask" : false,
29 | "bg_threshold" : null,
30 | "bg_reverse" : null
31 | },
32 | "cable": {
33 | "use_mask" : false,
34 | "bg_threshold" : 150,
35 | "bg_reverse" : true
36 | },
37 | "capsule": {
38 | "use_mask" : true,
39 | "bg_threshold" : 120,
40 | "bg_reverse" : false
41 | },
42 | "transistor": {
43 | "use_mask" : true,
44 | "bg_threshold" : 90,
45 | "bg_reverse" : false
46 | },
47 | "carpet": {
48 | "use_mask" : false,
49 | "bg_threshold" : null,
50 | "bg_reverse" : null
51 | },
52 | "bottle": {
53 | "use_mask" : true,
54 | "bg_threshold" : 250,
55 | "bg_reverse" : false
56 | },
57 | "screw": {
58 | "use_mask" : true,
59 | "bg_threshold" : 110,
60 | "bg_reverse" : false
61 | },
62 | "zipper": {
63 | "use_mask" : true,
64 | "bg_threshold" : 100,
65 | "bg_reverse" : false
66 | },
67 | "pill": {
68 | "use_mask" : true,
69 | "bg_threshold" : 100,
70 | "bg_reverse" : true
71 | },
72 | "toothbrush": {
73 | "use_mask" : true,
74 | "bg_threshold" : 30,
75 | "bg_reverse" : true
76 | }
77 | }
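
These per-target settings control the foreground mask used during anomaly synthesis: `use_mask` toggles masking, `bg_threshold` is the grayscale cutoff passed to `cv2.threshold`, and `bg_reverse` flips which side of the threshold counts as background. A minimal sketch of how `main.py` merges the entry for the selected target into `cfg.DATASET` (JSON is valid YAML, so `OmegaConf.load` reads this file directly):

```python
from omegaconf import OmegaConf

cfg = OmegaConf.load('configs.yaml')
cfg.DATASET.target = 'bottle'  # normally supplied on the command line

# merge the per-target mask settings into the dataset config, as main.py does
target_cfg = OmegaConf.load(cfg.DATASET.anomaly_mask_info)
cfg.DATASET = OmegaConf.merge(cfg.DATASET, target_cfg[cfg.DATASET.target])

print(cfg.DATASET.use_mask, cfg.DATASET.bg_threshold, cfg.DATASET.bg_reverse)
# True 250 False
```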
--------------------------------------------------------------------------------
/assets/memseg.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TooTouch/MemSeg/836bd465a9b14422f92666dc29dc36edce2692d0/assets/memseg.gif
--------------------------------------------------------------------------------
/configs.yaml:
--------------------------------------------------------------------------------
1 | EXP_NAME: MemSeg
2 | SEED: 42
3 |
4 | DATASET:
5 | datadir: /datasets/MVTec
6 | texture_source_dir: /datasets/dtd/images
7 | anomaly_mask_info: ./anomaly_mask.json
8 | target: null
9 | resize:
10 | - 288
11 | - 288
12 | imagesize: 256
13 | structure_grid_size: 8
14 | transparency_range:
15 |     - 0.15 # lower bound
16 | - 1. # upper bound
17 | perlin_scale: 6
18 | min_perlin_scale: 0
19 | perlin_noise_threshold: 0.5
20 |
21 | DATALOADER:
22 | batch_size: 8
23 | num_workers: 0
24 |
25 | MEMORYBANK:
26 | nb_memory_sample: 30
27 |
28 | MODEL:
29 | feature_extractor_name: resnet18
30 |
31 | TRAIN:
32 | batch_size: 8
33 | num_training_steps: 5000
34 | l1_weight: 0.6
35 | focal_weight: 0.4
36 | focal_alpha: null
37 | focal_gamma: 4
38 | use_wandb: True
39 |
40 | OPTIMIZER:
41 | lr: 0.003
42 | weight_decay: 0.0005
43 |
44 | SCHEDULER:
45 | min_lr: 0.0001
46 | warmup_ratio: 0.1
47 | use_scheduler: True
48 |
49 | LOG:
50 | log_interval: 1
51 | eval_interval: 100
52 |
53 | RESULT:
54 | savedir: ./saved_model
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .perlin import *
2 | from .factory import *
3 | from .dataset import *
--------------------------------------------------------------------------------
/data/dataset.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import os
3 | import numpy as np
4 | from PIL import Image
5 | from glob import glob
6 | from einops import rearrange
7 |
8 | import torch
9 | from torch.utils.data import Dataset
10 | import torchvision.transforms as transforms
11 | import imgaug.augmenters as iaa
12 |
13 | from .perlin import rand_perlin_2d_np
14 |
15 | from typing import List
16 |
32 | IMAGENET_MEAN = [0.485, 0.456, 0.406]
33 | IMAGENET_STD = [0.229, 0.224, 0.225]
34 |
35 | class MemSegDataset(Dataset):
36 | def __init__(
37 | self, datadir: str, target: str, is_train: bool, to_memory: bool = False,
38 | resize: List[int] = [256, 256], imagesize: int = 224,
39 |         texture_source_dir: str = None, structure_grid_size: int = 8,
40 | transparency_range: List[float] = [0.15, 1.],
41 | perlin_scale: int = 6, min_perlin_scale: int = 0, perlin_noise_threshold: float = 0.5,
42 | use_mask: bool = True, bg_threshold: float = 100, bg_reverse: bool = False
43 | ):
44 | # mode
45 | self.is_train = is_train
46 | self.to_memory = to_memory
47 |
48 | # load image file list
49 | self.datadir = datadir
50 | self.target = target
51 | self.file_list = glob(os.path.join(self.datadir, self.target, 'train/*/*' if is_train else 'test/*/*'))
52 |
53 | # synthetic anomaly
54 | if self.is_train and not self.to_memory:
55 | # load texture image file list
56 | self.texture_source_file_list = glob(os.path.join(texture_source_dir,'*/*')) if texture_source_dir else None
57 |
58 | # perlin noise
59 | self.perlin_scale = perlin_scale
60 | self.min_perlin_scale = min_perlin_scale
61 | self.perlin_noise_threshold = perlin_noise_threshold
62 |
63 | # structure
64 | self.structure_grid_size = structure_grid_size
65 |
66 | # anomaly mixing
67 | self.transparency_range = transparency_range
68 |
69 | # mask setting
70 | self.use_mask = use_mask
71 | self.bg_threshold = bg_threshold
72 | self.bg_reverse = bg_reverse
73 |
74 | # transform
75 | self.resize = list(resize)
76 | self.transform_img = [
77 | transforms.ToPILImage(),
78 | transforms.CenterCrop(imagesize),
79 | transforms.ToTensor(),
80 | transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
81 | ]
82 | self.transform_img = transforms.Compose(self.transform_img)
83 |
84 | self.transform_mask = [
85 | transforms.ToPILImage(),
86 | transforms.CenterCrop(imagesize),
87 | transforms.ToTensor(),
88 | ]
89 | self.transform_mask = transforms.Compose(self.transform_mask)
90 |
91 |         # synthetic anomaly switch
92 | self.anomaly_switch = False
93 |
94 | def __getitem__(self, idx):
95 |
96 | file_path = self.file_list[idx]
97 |
98 | # image
99 | img = Image.open(file_path).convert("RGB").resize(self.resize)
100 | img = np.array(img)
101 |
102 | # target
103 | target = 0 if 'good' in self.file_list[idx] else 1
104 |
105 | # mask
106 | if 'good' in file_path:
107 | mask = np.zeros(self.resize, dtype=np.float32)
108 | else:
109 | mask = Image.open(file_path.replace('test','ground_truth').replace('.png','_mask.png')).resize(self.resize)
110 | mask = np.array(mask)
111 |
112 | ## anomaly source
113 | if self.is_train and not self.to_memory:
114 | if self.anomaly_switch:
115 | img, mask = self.generate_anomaly(img=img, texture_img_list=self.texture_source_file_list)
116 | target = 1
117 | self.anomaly_switch = False
118 |
119 | mask = torch.Tensor(mask)
120 | else:
121 | self.anomaly_switch = True
122 |
123 | img = self.transform_img(img)
124 | mask = self.transform_mask(mask).squeeze()
125 |
126 | return img, mask, target
127 |
128 |
129 | def rand_augment(self):
130 | augmenters = [
131 | iaa.GammaContrast((0.5,2.0),per_channel=True),
132 | iaa.MultiplyAndAddToBrightness(mul=(0.8,1.2),add=(-30,30)),
133 | iaa.pillike.EnhanceSharpness(),
134 | iaa.AddToHueAndSaturation((-50,50),per_channel=True),
135 | iaa.Solarize(0.5, threshold=(32,128)),
136 | iaa.Posterize(),
137 | iaa.Invert(),
138 | iaa.pillike.Autocontrast(),
139 | iaa.pillike.Equalize(),
140 | iaa.Affine(rotate=(-45, 45))
141 | ]
142 |
143 | aug_idx = np.random.choice(np.arange(len(augmenters)), 3, replace=False)
144 | aug = iaa.Sequential([
145 | augmenters[aug_idx[0]],
146 | augmenters[aug_idx[1]],
147 | augmenters[aug_idx[2]]
148 | ])
149 |
150 | return aug
151 |
152 | def generate_anomaly(self, img: np.ndarray, texture_img_list: list = None) -> List[np.ndarray]:
153 | '''
154 | step 1. generate mask
155 | - target foreground mask
156 | - perlin noise mask
157 |
158 | step 2. generate texture or structure anomaly
159 | - texture: load DTD
160 | - structure: we first perform random adjustment of mirror symmetry, rotation, brightness, saturation,
161 | and hue on the input image 𝐼 . Then the preliminary processed image is uniformly divided into a 4×8 grid
162 | and randomly arranged to obtain the disordered image 𝐼
163 |
164 | step 3. blending image and anomaly source
165 | '''
166 |
167 | # step 1. generate mask
168 | img_size = img.shape[:-1] # H x W
169 |
170 | ## target foreground mask
171 | if self.use_mask:
172 | target_foreground_mask = self.generate_target_foreground_mask(img=img)
173 | else:
174 | target_foreground_mask = np.ones(self.resize)
175 |
176 | ## perlin noise mask
177 | perlin_noise_mask = self.generate_perlin_noise_mask(img_size=img_size)
178 |
179 | ## mask
180 | mask = perlin_noise_mask * target_foreground_mask
181 | mask_expanded = np.expand_dims(mask, axis=2)
182 |
183 | # step 2. generate texture or structure anomaly
184 |
185 | ## anomaly source
186 | anomaly_source_img = self.anomaly_source(img=img, texture_img_list=texture_img_list)
187 |
188 | ## mask anomaly parts
189 | factor = np.random.uniform(*self.transparency_range, size=1)[0]
190 | anomaly_source_img = factor * (mask_expanded * anomaly_source_img) + (1 - factor) * (mask_expanded * img)
191 |
192 | # step 3. blending image and anomaly source
193 | anomaly_source_img = ((- mask_expanded + 1) * img) + anomaly_source_img
194 |
195 | return (anomaly_source_img.astype(np.uint8), mask)
196 |
197 | def generate_target_foreground_mask(self, img: np.ndarray) -> np.ndarray:
198 | # convert RGB into GRAY scale
199 | img_gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
200 |
201 | # generate binary mask of gray scale image
202 | _, target_background_mask = cv2.threshold(img_gray, self.bg_threshold, 255, cv2.THRESH_BINARY)
203 |         target_background_mask = target_background_mask.astype(bool).astype(int)
204 |
205 | # invert mask for foreground mask
206 | if self.bg_reverse:
207 | target_foreground_mask = target_background_mask
208 | else:
209 | target_foreground_mask = -(target_background_mask - 1)
210 |
211 | return target_foreground_mask
212 |
213 | def generate_perlin_noise_mask(self, img_size: tuple) -> np.ndarray:
214 | # define perlin noise scale
215 | perlin_scalex = 2 ** (torch.randint(self.min_perlin_scale, self.perlin_scale, (1,)).numpy()[0])
216 | perlin_scaley = 2 ** (torch.randint(self.min_perlin_scale, self.perlin_scale, (1,)).numpy()[0])
217 |
218 | # generate perlin noise
219 | perlin_noise = rand_perlin_2d_np(img_size, (perlin_scalex, perlin_scaley))
220 |
221 | # apply affine transform
222 | rot = iaa.Affine(rotate=(-90, 90))
223 | perlin_noise = rot(image=perlin_noise)
224 |
225 | # make a mask by applying threshold
226 | mask_noise = np.where(
227 | perlin_noise > self.perlin_noise_threshold,
228 | np.ones_like(perlin_noise),
229 | np.zeros_like(perlin_noise)
230 | )
231 |
232 | return mask_noise
233 |
234 | def anomaly_source(self, img: np.ndarray, texture_img_list: list = None) -> np.ndarray:
235 | p = np.random.uniform() if texture_img_list else 1.0
236 | if p < 0.5:
237 | idx = np.random.choice(len(texture_img_list))
238 | img_size = img.shape[:-1] # H x W
239 | anomaly_source_img = self._texture_source(img_size=img_size, texture_img_path=texture_img_list[idx])
240 | else:
241 | anomaly_source_img = self._structure_source(img=img)
242 |
243 | return anomaly_source_img
244 |
245 | def _texture_source(self, img_size: tuple, texture_img_path: str) -> np.ndarray:
246 | texture_source_img = cv2.imread(texture_img_path)
247 | texture_source_img = cv2.cvtColor(texture_source_img, cv2.COLOR_BGR2RGB)
248 |         texture_source_img = cv2.resize(texture_source_img, dsize=(img_size[1], img_size[0])).astype(np.float32)  # cv2.resize expects (W, H)
249 |
250 | return texture_source_img
251 |
252 | def _structure_source(self, img: np.ndarray) -> np.ndarray:
253 | structure_source_img = self.rand_augment()(image=img)
254 |
255 | img_size = img.shape[:-1] # H x W
256 |
257 |         assert img_size[0] % self.structure_grid_size == 0, 'image size should be evenly divisible by structure_grid_size'
258 | grid_w = img_size[1] // self.structure_grid_size
259 | grid_h = img_size[0] // self.structure_grid_size
260 |
261 | structure_source_img = rearrange(
262 | tensor = structure_source_img,
263 | pattern = '(h gh) (w gw) c -> (h w) gw gh c',
264 | gw = grid_w,
265 | gh = grid_h
266 | )
267 | disordered_idx = np.arange(structure_source_img.shape[0])
268 | np.random.shuffle(disordered_idx)
269 |
270 | structure_source_img = rearrange(
271 | tensor = structure_source_img[disordered_idx],
272 | pattern = '(h w) gw gh c -> (h gh) (w gw) c',
273 | h = self.structure_grid_size,
274 | w = self.structure_grid_size
275 | ).astype(np.float32)
276 |
277 | return structure_source_img
278 |
279 | def __len__(self):
280 | return len(self.file_list)
281 |
282 |
283 |
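A minimal usage sketch of the dataset above; the two dataset roots are placeholders for local MVTec AD and DTD downloads, and the mask settings are the `bottle` entry from `anomaly_mask.json`:

```python
from data.dataset import MemSegDataset

trainset = MemSegDataset(
    datadir            = '/datasets/MVTec',       # placeholder MVTec AD root
    target             = 'bottle',
    is_train           = True,
    resize             = [288, 288],
    imagesize          = 256,
    texture_source_dir = '/datasets/dtd/images',  # placeholder DTD root
    use_mask           = True,
    bg_threshold       = 250,
    bg_reverse         = False,
)

# __getitem__ alternates normal and synthetic samples: every second call
# flips `anomaly_switch` and blends a texture or structure anomaly into
# the image, returning the corresponding pixel mask and target label.
img, mask, target = trainset[0]
print(img.shape, mask.shape, target)  # torch.Size([3, 256, 256]) torch.Size([256, 256]) 0
```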
--------------------------------------------------------------------------------
/data/factory.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import DataLoader
2 | from typing import List
3 | from .dataset import MemSegDataset
4 |
5 | def create_dataset(
6 | datadir: str, target: str, is_train: bool, to_memory: bool = False,
7 | resize: List[int] = [256, 256], imagesize: int = 224,
8 |     texture_source_dir: str = None, structure_grid_size: int = 8,
9 | transparency_range: List[float] = [0.15, 1.],
10 | perlin_scale: int = 6, min_perlin_scale: int = 0, perlin_noise_threshold: float = 0.5,
11 | use_mask: bool = True, bg_threshold: float = 100, bg_reverse: bool = False
12 | ):
13 | dataset = MemSegDataset(
14 | datadir = datadir,
15 | target = target,
16 | is_train = is_train,
17 | to_memory = to_memory,
18 | resize = resize,
19 | imagesize = imagesize,
20 | texture_source_dir = texture_source_dir,
21 | structure_grid_size = structure_grid_size,
22 | transparency_range = transparency_range,
23 | perlin_scale = perlin_scale,
24 | min_perlin_scale = min_perlin_scale,
25 | perlin_noise_threshold = perlin_noise_threshold,
26 | use_mask = use_mask,
27 | bg_threshold = bg_threshold,
28 | bg_reverse = bg_reverse
29 | )
30 |
31 | return dataset
32 |
33 |
34 | def create_dataloader(dataset, train: bool, batch_size: int = 16, num_workers: int = 1):
35 | dataloader = DataLoader(
36 | dataset,
37 | shuffle = train,
38 | batch_size = batch_size,
39 | num_workers = num_workers
40 | )
41 |
42 | return dataloader
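
A short sketch mirroring how `main.py` uses these helpers for evaluation (the dataset path is a placeholder):

```python
from data import create_dataset, create_dataloader

testset = create_dataset(
    datadir   = '/datasets/MVTec',  # placeholder MVTec AD root
    target    = 'bottle',
    is_train  = False,
    resize    = [288, 288],
    imagesize = 256,
)
testloader = create_dataloader(testset, train=False, batch_size=8, num_workers=0)

imgs, masks, targets = next(iter(testloader))
print(imgs.shape, masks.shape, targets.shape)
# torch.Size([8, 3, 256, 256]) torch.Size([8, 256, 256]) torch.Size([8])
```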
--------------------------------------------------------------------------------
/data/perlin.py:
--------------------------------------------------------------------------------
1 | # https://github.com/VitjanZ/DRAEM/blob/main/perlin.py
2 |
3 | import torch
4 | import math
5 | import numpy as np
6 |
7 | def lerp_np(x,y,w):
8 | fin_out = (y-x)*w + x
9 | return fin_out
10 |
11 | def generate_fractal_noise_2d(shape, res, octaves=1, persistence=0.5):
12 | noise = np.zeros(shape)
13 | frequency = 1
14 | amplitude = 1
15 | for _ in range(octaves):
16 | noise += amplitude * generate_perlin_noise_2d(shape, (frequency*res[0], frequency*res[1]))
17 | frequency *= 2
18 | amplitude *= persistence
19 | return noise
20 |
21 |
22 | def generate_perlin_noise_2d(shape, res):
23 | def f(t):
24 | return 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3
25 |
26 | delta = (res[0] / shape[0], res[1] / shape[1])
27 | d = (shape[0] // res[0], shape[1] // res[1])
28 | grid = np.mgrid[0:res[0]:delta[0], 0:res[1]:delta[1]].transpose(1, 2, 0) % 1
29 | # Gradients
30 | angles = 2 * np.pi * np.random.rand(res[0] + 1, res[1] + 1)
31 | gradients = np.dstack((np.cos(angles), np.sin(angles)))
32 | g00 = gradients[0:-1, 0:-1].repeat(d[0], 0).repeat(d[1], 1)
33 | g10 = gradients[1:, 0:-1].repeat(d[0], 0).repeat(d[1], 1)
34 | g01 = gradients[0:-1, 1:].repeat(d[0], 0).repeat(d[1], 1)
35 | g11 = gradients[1:, 1:].repeat(d[0], 0).repeat(d[1], 1)
36 | # Ramps
37 | n00 = np.sum(grid * g00, 2)
38 | n10 = np.sum(np.dstack((grid[:, :, 0] - 1, grid[:, :, 1])) * g10, 2)
39 | n01 = np.sum(np.dstack((grid[:, :, 0], grid[:, :, 1] - 1)) * g01, 2)
40 | n11 = np.sum(np.dstack((grid[:, :, 0] - 1, grid[:, :, 1] - 1)) * g11, 2)
41 | # Interpolation
42 | t = f(grid)
43 | n0 = n00 * (1 - t[:, :, 0]) + t[:, :, 0] * n10
44 | n1 = n01 * (1 - t[:, :, 0]) + t[:, :, 0] * n11
45 | return np.sqrt(2) * ((1 - t[:, :, 1]) * n0 + t[:, :, 1] * n1)
46 |
47 |
48 | def rand_perlin_2d_np(shape, res, fade=lambda t: 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3):
49 | delta = (res[0] / shape[0], res[1] / shape[1])
50 | d = (shape[0] // res[0], shape[1] // res[1])
51 | grid = np.mgrid[0:res[0]:delta[0], 0:res[1]:delta[1]].transpose(1, 2, 0) % 1
52 |
53 | angles = 2 * math.pi * np.random.rand(res[0] + 1, res[1] + 1)
54 | gradients = np.stack((np.cos(angles), np.sin(angles)), axis=-1)
56 |
57 | tile_grads = lambda slice1, slice2: np.repeat(np.repeat(gradients[slice1[0]:slice1[1], slice2[0]:slice2[1]],d[0],axis=0),d[1],axis=1)
58 | dot = lambda grad, shift: (
59 | np.stack((grid[:shape[0], :shape[1], 0] + shift[0], grid[:shape[0], :shape[1], 1] + shift[1]),
60 | axis=-1) * grad[:shape[0], :shape[1]]).sum(axis=-1)
61 |
62 | n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
63 | n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
64 | n01 = dot(tile_grads([0, -1], [1, None]), [0, -1])
65 | n11 = dot(tile_grads([1, None], [1, None]), [-1, -1])
66 | t = fade(grid[:shape[0], :shape[1]])
67 | return math.sqrt(2) * lerp_np(lerp_np(n00, n10, t[..., 0]), lerp_np(n01, n11, t[..., 0]), t[..., 1])
68 |
69 |
70 | def rand_perlin_2d(shape, res, fade=lambda t: 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3):
71 | delta = (res[0] / shape[0], res[1] / shape[1])
72 | d = (shape[0] // res[0], shape[1] // res[1])
73 |
74 | grid = torch.stack(torch.meshgrid(torch.arange(0, res[0], delta[0]), torch.arange(0, res[1], delta[1])), dim=-1) % 1
75 | angles = 2 * math.pi * torch.rand(res[0] + 1, res[1] + 1)
76 | gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)
77 |
78 | tile_grads = lambda slice1, slice2: gradients[slice1[0]:slice1[1], slice2[0]:slice2[1]].repeat_interleave(d[0],
79 | 0).repeat_interleave(
80 | d[1], 1)
81 | dot = lambda grad, shift: (
82 | torch.stack((grid[:shape[0], :shape[1], 0] + shift[0], grid[:shape[0], :shape[1], 1] + shift[1]),
83 | dim=-1) * grad[:shape[0], :shape[1]]).sum(dim=-1)
84 |
85 | n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
86 |
87 | n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
88 | n01 = dot(tile_grads([0, -1], [1, None]), [0, -1])
89 | n11 = dot(tile_grads([1, None], [1, None]), [-1, -1])
90 | t = fade(grid[:shape[0], :shape[1]])
91 | return math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1])
92 |
93 |
94 | def rand_perlin_2d_octaves(shape, res, octaves=1, persistence=0.5):
95 | noise = torch.zeros(shape)
96 | frequency = 1
97 | amplitude = 1
98 | for _ in range(octaves):
99 | noise += amplitude * rand_perlin_2d(shape, (frequency * res[0], frequency * res[1]))
100 | frequency *= 2
101 | amplitude *= persistence
102 | return noise
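
A quick sketch of how `MemSegDataset.generate_perlin_noise_mask` uses this module; note that `res` must divide `shape` exactly:

```python
import numpy as np
from data.perlin import rand_perlin_2d_np

# one noise field at a power-of-two scale, thresholded into a binary
# mask as in configs.yaml (perlin_noise_threshold: 0.5)
noise = rand_perlin_2d_np((256, 256), (2 ** 3, 2 ** 3))
mask = np.where(noise > 0.5, 1.0, 0.0)
print(noise.shape, mask.mean())  # (256, 256) followed by the anomalous-pixel fraction
```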
--------------------------------------------------------------------------------
/docker/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvcr.io/nvidia/pytorch:20.12-py3
2 |
3 | RUN apt-get update && apt-get upgrade -y && apt-get install -y vim && apt-get install -y git
4 | RUN pip install --upgrade pip
5 |
6 | WORKDIR /
7 |
8 | ARG UNAME
9 | ARG UID
10 | ARG GID
11 | RUN groupadd -g $GID -o $UNAME
12 | RUN useradd -m -u $UID -g $GID -o -s /bin/bash $UNAME
13 | USER $UNAME
14 |
15 |
16 |
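A build/run sketch under assumed conventions (the `memseg` tag and the mount path are placeholders); the build args map your host user into the container:

```bash
docker build ./docker -t memseg \
    --build-arg UNAME=$(whoami) --build-arg UID=$(id -u) --build-arg GID=$(id -g)
docker run --gpus all -it --rm -v "$PWD":/workspace/MemSeg memseg
```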
--------------------------------------------------------------------------------
/focal_loss.py:
--------------------------------------------------------------------------------
1 | # https://github.com/mbsariyildiz/focal-loss.pytorch/blob/master/focalloss.py
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | class FocalLoss(nn.Module):
7 |
8 | def __init__(self, smooth=1e-5, gamma=0, alpha=None, size_average=True):
9 | super(FocalLoss, self).__init__()
10 | self.gamma = gamma
11 | self.alpha = alpha
12 | self.smooth = smooth
13 | if isinstance(alpha, (float, int)): self.alpha = torch.Tensor([alpha, 1 - alpha])
14 | if isinstance(alpha, list): self.alpha = torch.Tensor(alpha)
15 | self.size_average = size_average
16 |
17 | def forward(self, input, target):
18 | if input.dim()>2:
19 | input = input.view(input.size(0), input.size(1), -1) # N,C,H,W => N,C,H*W
20 | input = input.transpose(1, 2) # N,C,H*W => N,H*W,C
21 | input = input.contiguous().view(-1, input.size(2)) # N,H*W,C => N*H*W,C
22 | target = target.view(-1, 1)
23 |
24 | pt = input
25 | logpt = (pt + 1e-5).log()
26 |
27 | # add label smoothing
28 | num_class = input.shape[1]
29 | idx = target.cpu().long()
30 |
31 | one_hot_key = torch.FloatTensor(target.size(0), num_class).zero_()
32 | one_hot_key = one_hot_key.scatter_(1, idx, 1)
33 | if one_hot_key.device != input.device:
34 | one_hot_key = one_hot_key.to(input.device)
35 |
36 | if self.smooth:
37 | one_hot_key = torch.clamp(
38 | one_hot_key, self.smooth, 1.0 - self.smooth)
39 | logpt = logpt * one_hot_key
40 |
41 | if self.alpha is not None:
42 | if self.alpha.type() != input.data.type():
43 | self.alpha = self.alpha.type_as(input.data)
44 | at = self.alpha.gather(0, target.data.view(-1))
45 | logpt = logpt * at
46 |
47 | loss = (-1 * (1 - pt)**self.gamma * logpt).sum(1)
48 | if self.size_average: return loss.mean()
49 | else: return loss.sum()
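
Note that this loss takes `log(input)` directly, so it expects class probabilities rather than raw logits. A minimal usage sketch with the hyperparameters from `configs.yaml` and dummy tensors shaped like MemSeg's two-channel output:

```python
import torch
import torch.nn.functional as F
from focal_loss import FocalLoss

criterion = FocalLoss(gamma=4, alpha=None)

logits = torch.randn(8, 2, 256, 256)          # dummy N x C x H x W predictions
targets = torch.randint(0, 2, (8, 256, 256))  # dummy binary pixel labels

loss = criterion(F.softmax(logits, dim=1), targets)  # softmax first: probabilities in
print(loss.item())
```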
--------------------------------------------------------------------------------
/log.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import logging.handlers
3 |
4 |
5 | class FormatterNoInfo(logging.Formatter):
6 | def __init__(self, fmt='%(levelname)s: %(message)s'):
7 | logging.Formatter.__init__(self, fmt)
8 |
9 | def format(self, record):
10 | if record.levelno == logging.INFO:
11 | return str(record.getMessage())
12 | return logging.Formatter.format(self, record)
13 |
14 |
15 | def setup_default_logging(default_level=logging.INFO, log_path=''):
16 | console_handler = logging.StreamHandler()
17 | console_handler.setFormatter(FormatterNoInfo())
18 | logging.root.addHandler(console_handler)
19 | logging.root.setLevel(default_level)
20 | if log_path:
21 | file_handler = logging.handlers.RotatingFileHandler(log_path, maxBytes=(1024 ** 2 * 2), backupCount=3)
22 | file_formatter = logging.Formatter("%(asctime)s - %(name)20s: [%(levelname)8s] - %(message)s")
23 | file_handler.setFormatter(file_formatter)
24 | logging.root.addHandler(file_handler)
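
A small usage sketch: INFO records print bare (via `FormatterNoInfo`) while other levels keep the `LEVEL: message` prefix, and `log_path` adds a rotating file handler:

```python
import logging
from log import setup_default_logging

setup_default_logging(log_path='./train.log')  # log_path is optional
logging.getLogger('train').info('bare info line')
logging.getLogger('train').warning('prefixed warning line')
```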
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import wandb
2 | import logging
3 | import os
4 | import torch
5 | import torch.nn as nn
6 | import argparse
7 |
8 | from omegaconf import OmegaConf
9 | from timm import create_model
10 | from data import create_dataset, create_dataloader
11 | from models import MemSeg, MemoryBank
12 | from focal_loss import FocalLoss
13 | from train import training
14 | from log import setup_default_logging
15 | from utils import torch_seed
16 | from scheduler import CosineAnnealingWarmupRestarts
17 |
18 |
19 | _logger = logging.getLogger('train')
20 |
21 |
22 | def run(cfg):
23 |
24 | # setting seed and device
25 | setup_default_logging()
26 | torch_seed(cfg.SEED)
27 |
28 | device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
29 | _logger.info('Device: {}'.format(device))
30 |
31 | # savedir
32 | cfg.EXP_NAME = cfg.EXP_NAME + f"-{cfg.DATASET.target}"
33 | savedir = os.path.join(cfg.RESULT.savedir, cfg.EXP_NAME)
34 | os.makedirs(savedir, exist_ok=True)
35 |
36 |
37 | # wandb
38 | if cfg.TRAIN.use_wandb:
39 | wandb.init(name=cfg.EXP_NAME, project='MemSeg', config=OmegaConf.to_container(cfg))
40 |
41 | # build datasets
42 | trainset = create_dataset(
43 | datadir = cfg.DATASET.datadir,
44 | target = cfg.DATASET.target,
45 | is_train = True,
46 | resize = cfg.DATASET.resize,
47 | imagesize = cfg.DATASET.imagesize,
48 | texture_source_dir = cfg.DATASET.texture_source_dir,
49 | structure_grid_size = cfg.DATASET.structure_grid_size,
50 | transparency_range = cfg.DATASET.transparency_range,
51 | perlin_scale = cfg.DATASET.perlin_scale,
52 | min_perlin_scale = cfg.DATASET.min_perlin_scale,
53 | perlin_noise_threshold = cfg.DATASET.perlin_noise_threshold,
54 | use_mask = cfg.DATASET.use_mask,
55 | bg_threshold = cfg.DATASET.bg_threshold,
56 | bg_reverse = cfg.DATASET.bg_reverse
57 | )
58 |
59 | memoryset = create_dataset(
60 | datadir = cfg.DATASET.datadir,
61 | target = cfg.DATASET.target,
62 | is_train = True,
63 | to_memory = True,
64 | resize = cfg.DATASET.resize,
65 | imagesize = cfg.DATASET.imagesize,
66 | )
67 |
68 | testset = create_dataset(
69 | datadir = cfg.DATASET.datadir,
70 | target = cfg.DATASET.target,
71 | is_train = False,
72 | resize = cfg.DATASET.resize,
73 | imagesize = cfg.DATASET.imagesize,
74 | )
75 |
76 | # build dataloader
77 | trainloader = create_dataloader(
78 | dataset = trainset,
79 | train = True,
80 | batch_size = cfg.DATALOADER.batch_size,
81 | num_workers = cfg.DATALOADER.num_workers
82 | )
83 |
84 | testloader = create_dataloader(
85 | dataset = testset,
86 | train = False,
87 | batch_size = cfg.DATALOADER.batch_size,
88 | num_workers = cfg.DATALOADER.num_workers
89 | )
90 |
91 |
92 | # build feature extractor
93 | feature_extractor = create_model(
94 | cfg.MODEL.feature_extractor_name,
95 | pretrained = True,
96 | features_only = True
97 | ).to(device)
98 | ## freeze weight of layer1,2,3
99 | for l in ['layer1','layer2','layer3']:
100 | for p in feature_extractor[l].parameters():
101 | p.requires_grad = False
102 |
103 | # build memory bank
104 | memory_bank = MemoryBank(
105 | normal_dataset = memoryset,
106 | nb_memory_sample = cfg.MEMORYBANK.nb_memory_sample,
107 | device = device
108 | )
109 | ## update normal samples and save
110 | memory_bank.update(feature_extractor=feature_extractor)
111 |     torch.save(memory_bank, os.path.join(savedir, 'memory_bank.pt'))
112 | _logger.info('Update {} normal samples in memory bank'.format(cfg.MEMORYBANK.nb_memory_sample))
113 |
114 | # build MemSeg
115 | model = MemSeg(
116 | memory_bank = memory_bank,
117 | feature_extractor = feature_extractor
118 | ).to(device)
119 |
120 | # Set training
121 | l1_criterion = nn.L1Loss()
122 | f_criterion = FocalLoss(
123 | gamma = cfg.TRAIN.focal_gamma,
124 | alpha = cfg.TRAIN.focal_alpha
125 | )
126 |
127 | optimizer = torch.optim.AdamW(
128 | params = filter(lambda p: p.requires_grad, model.parameters()),
129 | lr = cfg.OPTIMIZER.lr,
130 | weight_decay = cfg.OPTIMIZER.weight_decay
131 | )
132 |
133 | if cfg['SCHEDULER']['use_scheduler']:
134 | scheduler = CosineAnnealingWarmupRestarts(
135 | optimizer,
136 | first_cycle_steps = cfg.TRAIN.num_training_steps,
137 | max_lr = cfg.OPTIMIZER.lr,
138 | min_lr = cfg.SCHEDULER.min_lr,
139 | warmup_steps = int(cfg.TRAIN.num_training_steps * cfg.SCHEDULER.warmup_ratio)
140 | )
141 | else:
142 | scheduler = None
143 |
144 | # Fitting model
145 | training(
146 | model = model,
147 | num_training_steps = cfg.TRAIN.num_training_steps,
148 | trainloader = trainloader,
149 | validloader = testloader,
150 | criterion = [l1_criterion, f_criterion],
151 | loss_weights = [cfg.TRAIN.l1_weight, cfg.TRAIN.focal_weight],
152 | optimizer = optimizer,
153 | scheduler = scheduler,
154 | log_interval = cfg.LOG.log_interval,
155 | eval_interval = cfg.LOG.eval_interval,
156 | savedir = savedir,
157 | device = device,
158 | use_wandb = cfg.TRAIN.use_wandb
159 | )
160 |
161 |
162 |
163 | if __name__=='__main__':
164 | args = OmegaConf.from_cli()
165 | # load default config
166 | cfg = OmegaConf.load(args.configs)
167 | del args['configs']
168 |
169 | # merge config with new keys
170 | cfg = OmegaConf.merge(cfg, args)
171 |
172 | # target cfg
173 | target_cfg = OmegaConf.load(cfg.DATASET.anomaly_mask_info)
174 | cfg.DATASET = OmegaConf.merge(cfg.DATASET, target_cfg[cfg.DATASET.target])
175 |
176 | print(OmegaConf.to_yaml(cfg))
177 |
178 | run(cfg)
179 |
--------------------------------------------------------------------------------
/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | from .pro_curve_util import compute_pro
2 | from .generic_util import trapezoid
--------------------------------------------------------------------------------
/metrics/generic_util.py:
--------------------------------------------------------------------------------
1 | """
2 | Various utility functions for:
3 | - parsing user arguments.
4 | - computing the area under a curve.
5 | - generating a toy dataset to test the evaluation script.
6 | """
7 | from bisect import bisect
8 | import numpy as np
9 |
10 |
11 | def trapezoid(x, y, x_max=None):
12 | """
13 |     This function calculates the definite integral of a curve given by
14 |     x- and corresponding y-values. In contrast to, e.g., 'numpy.trapz()',
15 |     this function allows an upper bound on the integration range to be
16 |     defined by setting a value x_max.
17 |
18 | Points that do not have a finite x or y value will be ignored with a
19 | warning.
20 |
21 | Args:
22 | x: Samples from the domain of the function to integrate
23 | Need to be sorted in ascending order. May contain the same value
24 | multiple times. In that case, the order of the corresponding
25 | y values will affect the integration with the trapezoidal rule.
26 | y: Values of the function corresponding to x values.
27 |         x_max: Upper limit of the integration. The y value at x_max will be
28 | determined by interpolating between its neighbors. Must not lie
29 | outside of the range of x.
30 |
31 | Returns:
32 | Area under the curve.
33 | """
34 |
35 | x = np.asarray(x)
36 | y = np.asarray(y)
37 | finite_mask = np.logical_and(np.isfinite(x), np.isfinite(y))
38 | if not finite_mask.all():
39 | print("WARNING: Not all x and y values passed to trapezoid(...)"
40 | " are finite. Will continue with only the finite values.")
41 | x = x[finite_mask]
42 | y = y[finite_mask]
43 |
44 | # Introduce a correction term if max_x is not an element of x.
45 | correction = 0.
46 | if x_max is not None:
47 | if x_max not in x:
48 | # Get the insertion index that would keep x sorted after
49 | # np.insert(x, ins, x_max).
50 | ins = bisect(x, x_max)
51 | # x_max must be between the minimum and the maximum, so the
52 | # insertion_point cannot be zero or len(x).
53 | assert 0 < ins < len(x)
54 |
55 | # Calculate the correction term which is the integral between
56 | # the last x[ins-1] and x_max. Since we do not know the exact value
57 | # of y at x_max, we interpolate between y[ins] and y[ins-1].
58 | y_interp = y[ins - 1] + ((y[ins] - y[ins - 1]) *
59 | (x_max - x[ins - 1]) /
60 | (x[ins] - x[ins - 1]))
61 | correction = 0.5 * (y_interp + y[ins - 1]) * (x_max - x[ins - 1])
62 |
63 | # Cut off at x_max.
64 | mask = x <= x_max
65 | x = x[mask]
66 | y = y[mask]
67 |
68 | # Return area under the curve using the trapezoidal rule.
69 | return np.sum(0.5 * (y[1:] + y[:-1]) * (x[1:] - x[:-1])) + correction
70 |
71 | def generate_toy_dataset(num_images, image_width, image_height, gt_size):
72 | """Generate a toy dataset to test the evaluation script.
73 |
74 | Args:
75 | num_images: Number of images that the toy dataset contains.
76 | image_width: Width of the dataset images in pixels.
77 | image_height: Height of the dataset images in pixels.
78 | gt_size: Size of rectangular ground truth regions that are
79 | artificially generated on the dataset images.
80 |
81 | Returns:
82 | anomaly_maps: List of numpy arrays that contain random anomaly maps.
83 | ground_truth_map: Corresponding list of numpy arrays that specify a
84 | rectangular ground truth region.
85 | """
86 | # Fix a random seed for reproducibility.
87 | np.random.seed(1338)
88 |
89 | # Create synthetic evaluation data with random anomaly scores and
90 | # simple ground truth maps.
91 | anomaly_maps = []
92 | ground_truth_maps = []
93 | for _ in range(num_images):
94 | # Sample a random anomaly map.
95 | anomaly_map = np.random.random((image_height, image_width))
96 |
97 | # Construct a fixed ground truth map.
98 | ground_truth_map = np.zeros((image_height, image_width))
99 | ground_truth_map[0:gt_size, 0:gt_size] = 1
100 |
101 | anomaly_maps.append(anomaly_map)
102 | ground_truth_maps.append(ground_truth_map)
103 |
104 | return anomaly_maps, ground_truth_maps
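
A small worked check of the `x_max` correction in `trapezoid`: integrating y = x on [0, 1] up to x_max = 0.5 should give 0.5 * 0.5**2 = 0.125, and the interpolated correction term recovers that even though 0.5 is not a sample point:

```python
import numpy as np
from metrics import trapezoid

x = np.array([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
y = x.copy()
print(trapezoid(x, y, x_max=0.5))  # 0.125
```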
--------------------------------------------------------------------------------
/metrics/pro_curve_util.py:
--------------------------------------------------------------------------------
1 | """Utility function that computes a PRO curve, given pairs of anomaly and ground
2 | truth maps.
3 |
4 | The PRO curve can also be integrated up to a constant integration limit.
5 | """
6 | import numpy as np
7 | from scipy.ndimage import label
8 |
9 |
10 | def compute_pro(anomaly_maps, ground_truth_maps):
11 | """Compute the PRO curve for a set of anomaly maps with corresponding ground
12 | truth maps.
13 |
14 | Args:
15 | anomaly_maps: List of anomaly maps (2D numpy arrays) that contain a
16 | real-valued anomaly score at each pixel.
17 |
18 | ground_truth_maps: List of ground truth maps (2D numpy arrays) that
19 | contain binary-valued ground truth labels for each pixel.
20 | 0 indicates that a pixel is anomaly-free.
21 | 1 indicates that a pixel contains an anomaly.
22 |
23 | Returns:
24 | fprs: numpy array of false positive rates.
25 | pros: numpy array of corresponding PRO values.
26 | """
27 |
28 | print("Compute PRO curve...")
29 |
30 | # Structuring element for computing connected components.
31 | structure = np.ones((3, 3), dtype=int)
32 |
33 | num_ok_pixels = 0
34 | num_gt_regions = 0
35 |
36 | shape = (len(anomaly_maps),
37 | anomaly_maps[0].shape[0],
38 | anomaly_maps[0].shape[1])
39 | fp_changes = np.zeros(shape, dtype=np.uint32)
40 | assert shape[0] * shape[1] * shape[2] < np.iinfo(fp_changes.dtype).max, \
41 | 'Potential overflow when using np.cumsum(), consider using np.uint64.'
42 |
43 | pro_changes = np.zeros(shape, dtype=np.float64)
44 |
45 | for gt_ind, gt_map in enumerate(ground_truth_maps):
46 |
47 | # Compute the connected components in the ground truth map.
48 | labeled, n_components = label(gt_map, structure)
49 | num_gt_regions += n_components
50 |
51 | # Compute the mask that gives us all ok pixels.
52 | ok_mask = labeled == 0
53 | num_ok_pixels_in_map = np.sum(ok_mask)
54 | num_ok_pixels += num_ok_pixels_in_map
55 |
56 | # Compute by how much the FPR changes when each anomaly score is
57 | # added to the set of positives.
58 | # fp_change needs to be normalized later when we know the final value
59 | # of num_ok_pixels -> right now it is only the change in the number of
60 | # false positives
61 | fp_change = np.zeros_like(gt_map, dtype=fp_changes.dtype)
62 | fp_change[ok_mask] = 1
63 |
64 | # Compute by how much the PRO changes when each anomaly score is
65 | # added to the set of positives.
66 | # pro_change needs to be normalized later when we know the final value
67 | # of num_gt_regions.
68 | pro_change = np.zeros_like(gt_map, dtype=np.float64)
69 | for k in range(n_components):
70 | region_mask = labeled == (k + 1)
71 | region_size = np.sum(region_mask)
72 | pro_change[region_mask] = 1. / region_size
73 |
74 | fp_changes[gt_ind, :, :] = fp_change
75 | pro_changes[gt_ind, :, :] = pro_change
76 |
77 | # Flatten the numpy arrays before sorting.
78 | anomaly_scores_flat = np.array(anomaly_maps).ravel()
79 | fp_changes_flat = fp_changes.ravel()
80 | pro_changes_flat = pro_changes.ravel()
81 |
82 | # Sort all anomaly scores.
83 | print(f"Sort {len(anomaly_scores_flat)} anomaly scores...")
84 | sort_idxs = np.argsort(anomaly_scores_flat).astype(np.uint32)[::-1]
85 |
86 | # Info: np.take(a, ind, out=a) followed by b=a instead of
87 | # b=a[ind] showed to be more memory efficient.
88 | np.take(anomaly_scores_flat, sort_idxs, out=anomaly_scores_flat)
89 | anomaly_scores_sorted = anomaly_scores_flat
90 | np.take(fp_changes_flat, sort_idxs, out=fp_changes_flat)
91 | fp_changes_sorted = fp_changes_flat
92 | np.take(pro_changes_flat, sort_idxs, out=pro_changes_flat)
93 | pro_changes_sorted = pro_changes_flat
94 |
95 | del sort_idxs
96 |
97 | # Get the (FPR, PRO) curve values.
98 | np.cumsum(fp_changes_sorted, out=fp_changes_sorted)
99 | fp_changes_sorted = fp_changes_sorted.astype(np.float32, copy=False)
100 | np.divide(fp_changes_sorted, num_ok_pixels, out=fp_changes_sorted)
101 | fprs = fp_changes_sorted
102 |
103 | np.cumsum(pro_changes_sorted, out=pro_changes_sorted)
104 | np.divide(pro_changes_sorted, num_gt_regions, out=pro_changes_sorted)
105 | pros = pro_changes_sorted
106 |
107 | # Merge (FPR, PRO) points that occur together at the same threshold.
108 | # For those points, only the final (FPR, PRO) point should be kept.
109 | # That is because that point is the one that takes all changes
110 | # to the FPR and the PRO at the respective threshold into account.
111 | # -> keep_mask is True if the subsequent score is different from the
112 | # score at the respective position.
113 | # anomaly_scores_sorted = [7, 4, 4, 4, 3, 1, 1]
114 | # -> keep_mask = [T, F, F, T, T, F]
115 | keep_mask = np.append(np.diff(anomaly_scores_sorted) != 0, np.True_)
116 | del anomaly_scores_sorted
117 |
118 | fprs = fprs[keep_mask]
119 | pros = pros[keep_mask]
120 | del keep_mask
121 |
122 | # To mitigate the adding up of numerical errors during the np.cumsum calls,
123 | # make sure that the curve ends at (1, 1) and does not contain values > 1.
124 | np.clip(fprs, a_min=None, a_max=1., out=fprs)
125 | np.clip(pros, a_min=None, a_max=1., out=pros)
126 |
127 | # Make the fprs and pros start at 0 and end at 1.
128 | zero = np.array([0.])
129 | one = np.array([1.])
130 |
131 | return np.concatenate((zero, fprs, one)), np.concatenate((zero, pros, one))
132 |
133 |
134 | def main():
135 | """
136 | Compute the area under the PRO curve for a toy dataset and an algorithm
137 | that randomly assigns anomaly scores to each pixel. The integration
138 | limit can be specified.
139 | """
140 |
141 | from generic_util import trapezoid, generate_toy_dataset
142 |
143 | integration_limit = 0.3
144 |
145 | # Generate a toy dataset.
146 | anomaly_maps, ground_truth_maps = generate_toy_dataset(
147 | num_images=200, image_width=500, image_height=300, gt_size=10)
148 |
149 | # Compute the PRO curve for this dataset.
150 | all_fprs, all_pros = compute_pro(
151 | anomaly_maps=anomaly_maps,
152 | ground_truth_maps=ground_truth_maps)
153 |
154 | au_pro = trapezoid(all_fprs, all_pros, x_max=integration_limit)
155 | au_pro /= integration_limit
156 | print(f"AU-PRO (FPR limit: {integration_limit}): {au_pro}")
157 |
158 |
159 | if __name__ == "__main__":
160 | main()
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .coordatt import *
2 | from .decoder import *
3 | from .memseg import *
4 | from .msff import *
5 | from .memory_module import *
--------------------------------------------------------------------------------
/models/coordatt.py:
--------------------------------------------------------------------------------
1 | # https://github.com/houqb/CoordAttention/blob/main/coordatt.py
2 |
3 | import torch
4 | import torch.nn as nn
5 | import math
6 | import torch.nn.functional as F
7 |
8 | class h_sigmoid(nn.Module):
9 | def __init__(self, inplace=True):
10 | super(h_sigmoid, self).__init__()
11 | self.relu = nn.ReLU6(inplace=inplace)
12 |
13 | def forward(self, x):
14 | return self.relu(x + 3) / 6
15 |
16 | class h_swish(nn.Module):
17 | def __init__(self, inplace=True):
18 | super(h_swish, self).__init__()
19 | self.sigmoid = h_sigmoid(inplace=inplace)
20 |
21 | def forward(self, x):
22 | return x * self.sigmoid(x)
23 |
24 |
25 | class CoordAtt(nn.Module):
26 | def __init__(self, inp, oup, reduction=32):
27 | super(CoordAtt, self).__init__()
28 | self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
29 | self.pool_w = nn.AdaptiveAvgPool2d((1, None))
30 |
31 | mip = max(8, inp // reduction)
32 |
33 | self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
34 | self.bn1 = nn.BatchNorm2d(mip)
35 | self.act = h_swish()
36 |
37 | self.conv_h = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
38 | self.conv_w = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
39 |
40 |
41 | def forward(self, x):
42 | identity = x
43 |
44 | n,c,h,w = x.size()
45 | x_h = self.pool_h(x)
46 | x_w = self.pool_w(x).permute(0, 1, 3, 2)
47 |
48 | y = torch.cat([x_h, x_w], dim=2)
49 | y = self.conv1(y)
50 | y = self.bn1(y)
51 | y = self.act(y)
52 |
53 | x_h, x_w = torch.split(y, [h, w], dim=2)
54 | x_w = x_w.permute(0, 1, 3, 2)
55 |
56 | a_h = self.conv_h(x_h).sigmoid()
57 | a_w = self.conv_w(x_w).sigmoid()
58 |
59 | out = identity * a_w * a_h
60 |
61 | return out
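
A shape sketch: coordinate attention is shape-preserving, pooling along H and W separately and reweighting the input with the two resulting attention maps:

```python
import torch
from models.coordatt import CoordAtt

att = CoordAtt(inp=64, oup=64, reduction=32)
x = torch.randn(2, 64, 32, 32)
print(att(x).shape)  # torch.Size([2, 64, 32, 32])
```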
--------------------------------------------------------------------------------
/models/decoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class UpConvBlock(nn.Module):
6 | def __init__(self, in_channel, out_channel):
7 | super(UpConvBlock, self).__init__()
8 | self.blk = nn.Sequential(
9 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
10 | nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=1, padding=1),
11 | nn.BatchNorm2d(out_channel),
12 | nn.ReLU()
13 | )
14 |
15 | def forward(self, x):
16 | return self.blk(x)
17 |
18 |
19 | class Decoder(nn.Module):
20 | def __init__(self):
21 | super(Decoder, self).__init__()
22 |
23 | self.conv = nn.Conv2d(64, 48, kernel_size=3, stride=1, padding=1)
24 |
25 | self.upconv3 = UpConvBlock(512, 256)
26 | self.upconv2 = UpConvBlock(512, 128)
27 | self.upconv1 = UpConvBlock(256, 64)
28 | self.upconv0 = UpConvBlock(128, 48)
29 | self.upconv2mask = UpConvBlock(96, 48)
30 |
31 | self.final_conv = nn.Conv2d(48, 2, kernel_size=3, stride=1, padding=1)
32 |
33 | def forward(self, encoder_output, concat_features):
34 | # concat_features = [level0, level1, level2, level3]
35 | f0, f1, f2, f3 = concat_features
36 |
37 | # 512 x 8 x 8 -> 512 x 16 x 16
38 | x_up3 = self.upconv3(encoder_output)
39 | x_up3 = torch.cat([x_up3, f3], dim=1)
40 |
41 | # 512 x 16 x 16 -> 256 x 32 x 32
42 | x_up2 = self.upconv2(x_up3)
43 | x_up2 = torch.cat([x_up2, f2], dim=1)
44 |
45 | # 256 x 32 x 32 -> 128 x 64 x 64
46 | x_up1 = self.upconv1(x_up2)
47 | x_up1 = torch.cat([x_up1, f1], dim=1)
48 |
49 | # 128 x 64 x 64 -> 96 x 128 x 128
50 | x_up0 = self.upconv0(x_up1)
51 | f0 = self.conv(f0)
52 | x_up2mask = torch.cat([x_up0, f0], dim=1)
53 |
54 | # 96 x 128 x 128 -> 48 x 256 x 256
55 | x_mask = self.upconv2mask(x_up2mask)
56 |
57 |         # 48 x 256 x 256 -> 2 x 256 x 256 (two-channel normal/anomaly logits)
58 | x_mask = self.final_conv(x_mask)
59 |
60 | return x_mask
--------------------------------------------------------------------------------
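A shape sketch that checks the resolution comments in `forward` against random tensors; the 256 x 256 input resolution and the channel counts mirror the MSFF outputs and the backbone stem assumed elsewhere in this repo:

    import torch
    from models.decoder import Decoder

    decoder = Decoder()
    encoder_output = torch.randn(1, 512, 8, 8)
    concat_features = [
        torch.randn(1, 64, 128, 128),   # f0: stem-level features
        torch.randn(1, 64, 64, 64),     # f1: MSFF level-1 output
        torch.randn(1, 128, 32, 32),    # f2: MSFF level-2 output
        torch.randn(1, 256, 16, 16),    # f3: MSFF level-3 output
    ]
    mask = decoder(encoder_output, concat_features)
    print(mask.shape)  # torch.Size([1, 2, 256, 256]) -- two-channel normal/anomaly logits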
/models/memory_module.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | import numpy as np
5 | from typing import List
6 |
7 |
8 | class MemoryBank:
9 | def __init__(self, normal_dataset, nb_memory_sample: int = 30, device='cpu'):
10 | self.device = device
11 |
12 | # memory bank
13 | self.memory_information = {}
14 |
15 | # normal dataset
16 | self.normal_dataset = normal_dataset
17 |
18 | # the number of samples saved in memory bank
19 | self.nb_memory_sample = nb_memory_sample
20 |
21 |
22 | def update(self, feature_extractor):
23 | feature_extractor.eval()
24 |
25 | # define sample index
26 | samples_idx = np.arange(len(self.normal_dataset))
27 | np.random.shuffle(samples_idx)
28 |
29 | # extract features and save features into memory bank
30 | with torch.no_grad():
31 | for i in range(self.nb_memory_sample):
32 | # select image
33 | input_normal, _, _ = self.normal_dataset[samples_idx[i]]
34 | input_normal = input_normal.to(self.device)
35 |
36 | # extract features
37 | features = feature_extractor(input_normal.unsqueeze(0))
38 |
39 |                 # save features into memory bank (use `l`, not `i`, to avoid shadowing the sample index)
40 |                 for l, features_l in enumerate(features[1:-1]):
41 |                     if f'level{l}' not in self.memory_information.keys():
42 |                         self.memory_information[f'level{l}'] = features_l
43 |                     else:
44 |                         self.memory_information[f'level{l}'] = torch.cat([self.memory_information[f'level{l}'], features_l], dim=0)
45 |
46 |
47 | def _calc_diff(self, features: List[torch.Tensor]) -> torch.Tensor:
48 | # batch size X the number of samples saved in memory
49 | diff_bank = torch.zeros(features[0].size(0), self.nb_memory_sample).to(self.device)
50 |
51 | # level
52 | for l, level in enumerate(self.memory_information.keys()):
53 | # batch
54 | for b_idx, features_b in enumerate(features[l]):
55 | # calculate l2 loss
56 | diff = F.mse_loss(
57 | input = torch.repeat_interleave(features_b.unsqueeze(0), repeats=self.nb_memory_sample, dim=0),
58 | target = self.memory_information[level],
59 | reduction ='none'
60 | ).mean(dim=[1,2,3])
61 |
62 | # sum loss
63 | diff_bank[b_idx] += diff
64 |
65 | return diff_bank
66 |
67 |
68 | def select(self, features: List[torch.Tensor]) -> torch.Tensor:
69 | # calculate difference between features and normal features of memory bank
70 | diff_bank = self._calc_diff(features=features)
71 |
72 | # concatenate features with minimum difference features of memory bank
73 | for l, level in enumerate(self.memory_information.keys()):
74 |
75 | selected_features = torch.index_select(self.memory_information[level], dim=0, index=diff_bank.argmin(dim=1))
76 | diff_features = F.mse_loss(selected_features, features[l], reduction='none')
77 | features[l] = torch.cat([features[l], diff_features], dim=1)
78 |
79 | return features
80 |
--------------------------------------------------------------------------------
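A self-contained sketch of the memory-bank lifecycle. The toy dataset and extractor below are stand-ins for illustration only (the real pipeline uses the MVTec dataset from data/ and a pretrained backbone); the point is that `select` doubles each level's channel count by concatenating the per-pixel difference to the closest memorized sample:

    import torch
    from torch.utils.data import Dataset
    from models.memory_module import MemoryBank

    class ToyNormalDataset(Dataset):
        """Stand-in for the normal train split: returns (image, mask, target)."""
        def __len__(self):
            return 8
        def __getitem__(self, idx):
            return torch.randn(3, 256, 256), torch.zeros(256, 256), 0

    def toy_extractor(x):
        """Stand-in for a 5-level feature-pyramid backbone."""
        n = x.size(0)
        return [torch.randn(n, 64, 128, 128), torch.randn(n, 64, 64, 64),
                torch.randn(n, 128, 32, 32), torch.randn(n, 256, 16, 16),
                torch.randn(n, 512, 8, 8)]
    toy_extractor.eval = lambda: None   # update() calls feature_extractor.eval()

    bank = MemoryBank(ToyNormalDataset(), nb_memory_sample=4, device='cpu')
    bank.update(feature_extractor=toy_extractor)

    feats = toy_extractor(torch.randn(2, 3, 256, 256))[1:-1]  # the three mid levels
    feats = bank.select(features=feats)
    print([f.shape[1] for f in feats])   # [128, 256, 512] -- channels doubled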
/models/memseg.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from .decoder import Decoder
3 | from .msff import MSFF
4 |
5 | class MemSeg(nn.Module):
6 | def __init__(self, memory_bank, feature_extractor):
7 | super(MemSeg, self).__init__()
8 |
9 | self.memory_bank = memory_bank
10 | self.feature_extractor = feature_extractor
11 | self.msff = MSFF()
12 | self.decoder = Decoder()
13 |
14 | def forward(self, inputs):
15 | # extract features
16 | features = self.feature_extractor(inputs)
17 | f_in = features[0]
18 | f_out = features[-1]
19 | f_ii = features[1:-1]
20 |
21 |         # extract concatenated information (CI)
22 | concat_features = self.memory_bank.select(features = f_ii)
23 |
24 |         # Multi-Scale Feature Fusion (MSFF) module
25 | msff_outputs = self.msff(features = concat_features)
26 |
27 | # decoder
28 | predicted_mask = self.decoder(
29 | encoder_output = f_out,
30 | concat_features = [f_in] + msff_outputs
31 | )
32 |
33 | return predicted_mask
34 |
--------------------------------------------------------------------------------
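Wiring it together end to end. This sketch assumes a timm backbone created with features_only=True (timm is pinned in requirements.txt; the exact model name is a placeholder) -- a resnet18-style pyramid yields the 64/64/128/256/512 channel counts the MSFF and decoder expect. `ToyNormalDataset` from the memory_module sketch above stands in for the normal-only train split:

    import timm
    import torch
    from models import MemSeg, MemoryBank

    feature_extractor = timm.create_model('resnet18', pretrained=True, features_only=True)

    memory_bank = MemoryBank(normal_dataset=ToyNormalDataset(), nb_memory_sample=4, device='cpu')
    memory_bank.update(feature_extractor=feature_extractor)

    model = MemSeg(memory_bank=memory_bank, feature_extractor=feature_extractor)
    out = model(torch.randn(1, 3, 256, 256))
    print(out.shape)  # torch.Size([1, 2, 256, 256])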
/models/msff.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import math
4 | import torch.nn.functional as F
5 | from .coordatt import CoordAtt
6 |
7 | class MSFFBlock(nn.Module):
8 | def __init__(self, in_channel):
9 | super(MSFFBlock, self).__init__()
10 | self.conv1 = nn.Conv2d(in_channel, in_channel, kernel_size=3, stride=1, padding=1)
11 | self.attn = CoordAtt(in_channel, in_channel)
12 | self.conv2 = nn.Sequential(
13 | nn.Conv2d(in_channel, in_channel // 2, kernel_size=3, stride=1, padding=1),
14 | nn.Conv2d(in_channel // 2, in_channel // 2, kernel_size=3, stride=1, padding=1)
15 | )
16 |
17 | def forward(self, x):
18 | x_conv = self.conv1(x)
19 | x_att = self.attn(x)
20 |
21 | x = x_conv * x_att
22 | x = self.conv2(x)
23 | return x
24 |
25 |
26 | class MSFF(nn.Module):
27 | def __init__(self):
28 | super(MSFF, self).__init__()
29 | self.blk1 = MSFFBlock(128)
30 | self.blk2 = MSFFBlock(256)
31 | self.blk3 = MSFFBlock(512)
32 |
33 | self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
34 |
35 | self.upconv32 = nn.Sequential(
36 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
37 | nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1)
38 | )
39 | self.upconv21 = nn.Sequential(
40 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
41 | nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1)
42 | )
43 |
44 | def forward(self, features):
45 | # features = [level1, level2, level3]
46 | f1, f2, f3 = features
47 |
48 | # MSFF Module
49 | f1_k = self.blk1(f1)
50 | f2_k = self.blk2(f2)
51 | f3_k = self.blk3(f3)
52 |
53 | f2_f = f2_k + self.upconv32(f3_k)
54 | f1_f = f1_k + self.upconv21(f2_f)
55 |
56 |         # spatial attention masks, cascaded coarse-to-fine from the
57 |         # memory-difference half of each concatenated feature map
58 |
59 | m3 = f3[:,256:,...].mean(dim=1, keepdim=True)
60 | m2 = f2[:,128:,...].mean(dim=1, keepdim=True) * self.upsample(m3)
61 | m1 = f1[:,64:,...].mean(dim=1, keepdim=True) * self.upsample(m2)
62 |
63 | f1_out = f1_f * m1
64 | f2_out = f2_f * m2
65 | f3_out = f3_k * m3
66 |
67 | return [f1_out, f2_out, f3_out]
--------------------------------------------------------------------------------
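MSFF consumes the memory-concatenated features, whose channel counts have already been doubled by MemoryBank.select. A shape sketch with random inputs:

    import torch
    from models.msff import MSFF

    msff = MSFF()
    features = [
        torch.randn(1, 128, 64, 64),   # level 1: 64 backbone + 64 memory-difference channels
        torch.randn(1, 256, 32, 32),   # level 2: 128 + 128
        torch.randn(1, 512, 16, 16),   # level 3: 256 + 256
    ]
    f1_out, f2_out, f3_out = msff(features)
    print(f1_out.shape, f2_out.shape, f3_out.shape)
    # torch.Size([1, 64, 64, 64]) torch.Size([1, 128, 32, 32]) torch.Size([1, 256, 16, 16])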
/requirements.txt:
--------------------------------------------------------------------------------
1 | einops==0.5.0
2 | timm==0.5.4
3 | wandb==0.12.17
4 | omegaconf
5 | imgaug==0.4.0
6 | ipywidgets==8.1.1
7 | voila==0.5.5
--------------------------------------------------------------------------------
/saved_model/MemSeg-bottle/best_score.json:
--------------------------------------------------------------------------------
1 | {
2 | "best_step": 4599,
3 | "eval_AUROC-image": 1.0,
4 | "eval_AUROC-pixel": 0.9878949119786599,
5 | "eval_AUPRO-pixel": 0.9835623941012489
6 | }
--------------------------------------------------------------------------------
/saved_model/MemSeg-bottle/latest_score.json:
--------------------------------------------------------------------------------
1 | {
2 | "latest_step": 5000,
3 | "eval_AUROC-image": 1.0,
4 | "eval_AUROC-pixel": 0.9810207179490311,
5 | "eval_AUPRO-pixel": 0.9756835825062008
6 | }
--------------------------------------------------------------------------------
/saved_model/MemSeg-cable/best_score.json:
--------------------------------------------------------------------------------
1 | {
2 | "best_step": 2999,
3 | "eval_AUROC-image": 0.9250374812593704,
4 | "eval_AUROC-pixel": 0.8230074651568189,
5 | "eval_AUPRO-pixel": 0.8731255778978297
6 | }
--------------------------------------------------------------------------------
/saved_model/MemSeg-cable/latest_score.json:
--------------------------------------------------------------------------------
1 | {
2 | "latest_step": 5000,
3 | "eval_AUROC-image": 0.8967391304347826,
4 | "eval_AUROC-pixel": 0.7628859940117106,
5 | "eval_AUPRO-pixel": 0.8339052111761991
6 | }
--------------------------------------------------------------------------------
/saved_model/MemSeg-capsule/best_score.json:
--------------------------------------------------------------------------------
1 | {
2 | "best_step": 4499,
3 | "eval_AUROC-image": 0.9541284403669725,
4 | "eval_AUROC-pixel": 0.9842618218962903,
5 | "eval_AUPRO-pixel": 0.9772760890870018
6 | }
--------------------------------------------------------------------------------
/saved_model/MemSeg-capsule/latest_score.json:
--------------------------------------------------------------------------------
1 | {
2 | "latest_step": 5000,
3 | "eval_AUROC-image": 0.9509373753490227,
4 | "eval_AUROC-pixel": 0.9818320950550422,
5 | "eval_AUPRO-pixel": 0.9762463379799712
6 | }
--------------------------------------------------------------------------------
/saved_model/MemSeg-carpet/best_score.json:
--------------------------------------------------------------------------------
1 | {
2 | "best_step": 999,
3 | "eval_AUROC-image": 0.9911717495987158,
4 | "eval_AUROC-pixel": 0.9754415202769691,
5 | "eval_AUPRO-pixel": 0.9702277436742398
6 | }
--------------------------------------------------------------------------------
/saved_model/MemSeg-carpet/latest_score.json:
--------------------------------------------------------------------------------
1 | {
2 | "latest_step": 5000,
3 | "eval_AUROC-image": 0.9225521669341894,
4 | "eval_AUROC-pixel": 0.9521577575948166,
5 | "eval_AUPRO-pixel": 0.9304082803819657
6 | }
--------------------------------------------------------------------------------
/saved_model/MemSeg-grid/best_score.json:
--------------------------------------------------------------------------------
1 | {
2 | "best_step": 2299,
3 | "eval_AUROC-image": 0.9983291562238931,
4 | "eval_AUROC-pixel": 0.9836655532727973,
5 | "eval_AUPRO-pixel": 0.9853033944553282
6 | }
--------------------------------------------------------------------------------
/saved_model/MemSeg-grid/latest_score.json:
--------------------------------------------------------------------------------
1 | {
2 | "latest_step": 5000,
3 | "eval_AUROC-image": 0.9949874686716792,
4 | "eval_AUROC-pixel": 0.9839286697331651,
5 | "eval_AUPRO-pixel": 0.9811414483214054
6 | }
--------------------------------------------------------------------------------
/saved_model/MemSeg-hazelnut/best_score.json:
--------------------------------------------------------------------------------
1 | {
2 | "best_step": 2799,
3 | "eval_AUROC-image": 1.0,
4 | "eval_AUROC-pixel": 0.977846774892756,
5 | "eval_AUPRO-pixel": 0.9899949505024724
6 | }
--------------------------------------------------------------------------------
/saved_model/MemSeg-hazelnut/latest_score.json:
--------------------------------------------------------------------------------
1 | {
2 | "latest_step": 5000,
3 | "eval_AUROC-image": 1.0,
4 | "eval_AUROC-pixel": 0.9530533678779027,
5 | "eval_AUPRO-pixel": 0.9864527018411502
6 | }
--------------------------------------------------------------------------------
/saved_model/MemSeg-leather/best_score.json:
--------------------------------------------------------------------------------
1 | {
2 | "best_step": 2499,
3 | "eval_AUROC-image": 1.0,
4 | "eval_AUROC-pixel": 0.9882677082218918,
5 | "eval_AUPRO-pixel": 0.9908964012921226
6 | }
--------------------------------------------------------------------------------
/saved_model/MemSeg-leather/latest_score.json:
--------------------------------------------------------------------------------
1 | {
2 | "latest_step": 5000,
3 | "eval_AUROC-image": 1.0,
4 | "eval_AUROC-pixel": 0.9806788880434785,
5 | "eval_AUPRO-pixel": 0.9771980890520273
6 | }
--------------------------------------------------------------------------------
/saved_model/MemSeg-metal_nut/best_score.json:
--------------------------------------------------------------------------------
1 | {
2 | "best_step": 3599,
3 | "eval_AUROC-image": 0.9946236559139785,
4 | "eval_AUROC-pixel": 0.8848478801101993,
5 | "eval_AUPRO-pixel": 0.9500360784835721
6 | }
--------------------------------------------------------------------------------
/saved_model/MemSeg-metal_nut/latest_score.json:
--------------------------------------------------------------------------------
1 | {
2 | "latest_step": 5000,
3 | "eval_AUROC-image": 0.9912023460410557,
4 | "eval_AUROC-pixel": 0.8702097697878765,
5 | "eval_AUPRO-pixel": 0.9522523369617932
6 | }
--------------------------------------------------------------------------------
/saved_model/MemSeg-pill/best_score.json:
--------------------------------------------------------------------------------
1 | {
2 | "best_step": 3899,
3 | "eval_AUROC-image": 0.9705400981996726,
4 | "eval_AUROC-pixel": 0.9829100041961665,
5 | "eval_AUPRO-pixel": 0.9796020986582216
6 | }
--------------------------------------------------------------------------------
/saved_model/MemSeg-pill/latest_score.json:
--------------------------------------------------------------------------------
1 | {
2 | "latest_step": 5000,
3 | "eval_AUROC-image": 0.9735406437534097,
4 | "eval_AUROC-pixel": 0.9837363970095163,
5 | "eval_AUPRO-pixel": 0.9726783368058662
6 | }
--------------------------------------------------------------------------------
/saved_model/MemSeg-screw/best_score.json:
--------------------------------------------------------------------------------
1 | {
2 | "best_step": 2099,
3 | "eval_AUROC-image": 0.9485550317688051,
4 | "eval_AUROC-pixel": 0.9508182505179927,
5 | "eval_AUPRO-pixel": 0.9400450867663136
6 | }
--------------------------------------------------------------------------------
/saved_model/MemSeg-screw/latest_score.json:
--------------------------------------------------------------------------------
1 | {
2 | "latest_step": 5000,
3 | "eval_AUROC-image": 0.9321582291453167,
4 | "eval_AUROC-pixel": 0.9421926937096599,
5 | "eval_AUPRO-pixel": 0.9239317290959074
6 | }
--------------------------------------------------------------------------------
/saved_model/MemSeg-tile/best_score.json:
--------------------------------------------------------------------------------
1 | {
2 | "best_step": 2599,
3 | "eval_AUROC-image": 0.9985569985569985,
4 | "eval_AUROC-pixel": 0.9938479608847095,
5 | "eval_AUPRO-pixel": 0.9880707645224925
6 | }
--------------------------------------------------------------------------------
/saved_model/MemSeg-tile/latest_score.json:
--------------------------------------------------------------------------------
1 | {
2 | "latest_step": 5000,
3 | "eval_AUROC-image": 0.9967532467532467,
4 | "eval_AUROC-pixel": 0.9954440823852293,
5 | "eval_AUPRO-pixel": 0.9881914197355043
6 | }
--------------------------------------------------------------------------------
/saved_model/MemSeg-toothbrush/best_score.json:
--------------------------------------------------------------------------------
1 | {
2 | "best_step": 3499,
3 | "eval_AUROC-image": 1.0,
4 | "eval_AUROC-pixel": 0.9928249709185054,
5 | "eval_AUPRO-pixel": 0.985574884437306
6 | }
--------------------------------------------------------------------------------
/saved_model/MemSeg-toothbrush/latest_score.json:
--------------------------------------------------------------------------------
1 | {
2 | "latest_step": 5000,
3 | "eval_AUROC-image": 1.0,
4 | "eval_AUROC-pixel": 0.9930413807016041,
5 | "eval_AUPRO-pixel": 0.9831682986954344
6 | }
--------------------------------------------------------------------------------
/saved_model/MemSeg-transistor/best_score.json:
--------------------------------------------------------------------------------
1 | {
2 | "best_step": 4199,
3 | "eval_AUROC-image": 0.965,
4 | "eval_AUROC-pixel": 0.7629001885386938,
5 | "eval_AUPRO-pixel": 0.8606269093140095
6 | }
--------------------------------------------------------------------------------
/saved_model/MemSeg-transistor/latest_score.json:
--------------------------------------------------------------------------------
1 | {
2 | "latest_step": 5000,
3 | "eval_AUROC-image": 0.9566666666666667,
4 | "eval_AUROC-pixel": 0.7424347738574396,
5 | "eval_AUPRO-pixel": 0.8319123925254669
6 | }
--------------------------------------------------------------------------------
/saved_model/MemSeg-wood/best_score.json:
--------------------------------------------------------------------------------
1 | {
2 | "best_step": 3799,
3 | "eval_AUROC-image": 1.0,
4 | "eval_AUROC-pixel": 0.9753957059774258,
5 | "eval_AUPRO-pixel": 0.9762236672467449
6 | }
--------------------------------------------------------------------------------
/saved_model/MemSeg-wood/latest_score.json:
--------------------------------------------------------------------------------
1 | {
2 | "latest_step": 5000,
3 | "eval_AUROC-image": 0.999122807017544,
4 | "eval_AUROC-pixel": 0.972050242335017,
5 | "eval_AUPRO-pixel": 0.9735705883386424
6 | }
--------------------------------------------------------------------------------
/saved_model/MemSeg-zipper/best_score.json:
--------------------------------------------------------------------------------
1 | {
2 | "best_step": 4899,
3 | "eval_AUROC-image": 0.9994747899159664,
4 | "eval_AUROC-pixel": 0.9794035590859894,
5 | "eval_AUPRO-pixel": 0.9725967686762359
6 | }
--------------------------------------------------------------------------------
/saved_model/MemSeg-zipper/latest_score.json:
--------------------------------------------------------------------------------
1 | {
2 | "latest_step": 5000,
3 | "eval_AUROC-image": 0.9994747899159664,
4 | "eval_AUROC-pixel": 0.9773868547463116,
5 | "eval_AUPRO-pixel": 0.9692512286487627
6 | }
--------------------------------------------------------------------------------
/scheduler.py:
--------------------------------------------------------------------------------
1 | # https://github.com/katsura-jp/pytorch-cosine-annealing-with-warmup/blob/master/cosine_annealing_warmup/scheduler.py
2 |
3 | import math
4 | import torch
5 | from torch.optim.lr_scheduler import _LRScheduler
6 |
7 | class CosineAnnealingWarmupRestarts(_LRScheduler):
8 | """
9 | optimizer (Optimizer): Wrapped optimizer.
10 | first_cycle_steps (int): First cycle step size.
11 |     cycle_mult(float): Cycle steps magnification. Default: 1.
12 | max_lr(float): First cycle's max learning rate. Default: 0.1.
13 | min_lr(float): Min learning rate. Default: 0.001.
14 | warmup_steps(int): Linear warmup step size. Default: 0.
15 | gamma(float): Decrease rate of max learning rate by cycle. Default: 1.
16 | last_epoch (int): The index of last epoch. Default: -1.
17 | """
18 |
19 | def __init__(self,
20 | optimizer : torch.optim.Optimizer,
21 | first_cycle_steps : int,
22 | cycle_mult : float = 1.,
23 | max_lr : float = 0.1,
24 | min_lr : float = 0.001,
25 | warmup_steps : int = 0,
26 | gamma : float = 1.,
27 | last_epoch : int = -1
28 | ):
29 | assert warmup_steps < first_cycle_steps
30 |
31 | self.first_cycle_steps = first_cycle_steps # first cycle step size
32 | self.cycle_mult = cycle_mult # cycle steps magnification
33 | self.base_max_lr = max_lr # first max learning rate
34 | self.max_lr = max_lr # max learning rate in the current cycle
35 | self.min_lr = min_lr # min learning rate
36 | self.warmup_steps = warmup_steps # warmup step size
37 | self.gamma = gamma # decrease rate of max learning rate by cycle
38 |
39 | self.cur_cycle_steps = first_cycle_steps # first cycle step size
40 | self.cycle = 0 # cycle count
41 | self.step_in_cycle = last_epoch # step size of the current cycle
42 |
43 | super(CosineAnnealingWarmupRestarts, self).__init__(optimizer, last_epoch)
44 |
45 | # set learning rate min_lr
46 | self.init_lr()
47 |
48 | def init_lr(self):
49 | self.base_lrs = []
50 | for param_group in self.optimizer.param_groups:
51 | param_group['lr'] = self.min_lr
52 | self.base_lrs.append(self.min_lr)
53 |
54 | def get_lr(self):
55 | if self.step_in_cycle == -1:
56 | return self.base_lrs
57 | elif self.step_in_cycle < self.warmup_steps:
58 | return [(self.max_lr - base_lr)*self.step_in_cycle / self.warmup_steps + base_lr for base_lr in self.base_lrs]
59 | else:
60 | return [base_lr + (self.max_lr - base_lr) \
61 | * (1 + math.cos(math.pi * (self.step_in_cycle-self.warmup_steps) \
62 | / (self.cur_cycle_steps - self.warmup_steps))) / 2
63 | for base_lr in self.base_lrs]
64 |
65 | def step(self, epoch=None):
66 | if epoch is None:
67 | epoch = self.last_epoch + 1
68 | self.step_in_cycle = self.step_in_cycle + 1
69 | if self.step_in_cycle >= self.cur_cycle_steps:
70 | self.cycle += 1
71 | self.step_in_cycle = self.step_in_cycle - self.cur_cycle_steps
72 | self.cur_cycle_steps = int((self.cur_cycle_steps - self.warmup_steps) * self.cycle_mult) + self.warmup_steps
73 | else:
74 | if epoch >= self.first_cycle_steps:
75 | if self.cycle_mult == 1.:
76 | self.step_in_cycle = epoch % self.first_cycle_steps
77 | self.cycle = epoch // self.first_cycle_steps
78 | else:
79 | n = int(math.log((epoch / self.first_cycle_steps * (self.cycle_mult - 1) + 1), self.cycle_mult))
80 | self.cycle = n
81 | self.step_in_cycle = epoch - int(self.first_cycle_steps * (self.cycle_mult ** n - 1) / (self.cycle_mult - 1))
82 | self.cur_cycle_steps = self.first_cycle_steps * self.cycle_mult ** (n)
83 | else:
84 | self.cur_cycle_steps = self.first_cycle_steps
85 | self.step_in_cycle = epoch
86 |
87 | self.max_lr = self.base_max_lr * (self.gamma**self.cycle)
88 | self.last_epoch = math.floor(epoch)
89 | for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
90 | param_group['lr'] = lr
--------------------------------------------------------------------------------
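A short usage sketch; the optimizer and step counts here are arbitrary (configs.yaml holds the values actually used), and scheduler.step() is called once per optimization step, as in train.py:

    import torch
    from scheduler import CosineAnnealingWarmupRestarts

    optimizer = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=0.003)
    scheduler = CosineAnnealingWarmupRestarts(
        optimizer,
        first_cycle_steps=1000,   # cosine cycle length in steps
        max_lr=0.003,
        min_lr=1e-4,
        warmup_steps=100,         # linear ramp from min_lr to max_lr
    )

    for step in range(1000):
        optimizer.step()
        scheduler.step()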
/train.py:
--------------------------------------------------------------------------------
1 | import time
2 | import json
3 | import os
4 | import wandb
5 | import logging
6 |
7 | import torch
8 | import torch.nn.functional as F
9 | import numpy as np
10 | from typing import List
11 | from sklearn.metrics import roc_auc_score
12 | from metrics import compute_pro, trapezoid
13 |
14 | _logger = logging.getLogger('train')
15 |
16 | class AverageMeter:
17 | """Computes and stores the average and current value"""
18 | def __init__(self):
19 | self.reset()
20 |
21 | def reset(self):
22 | self.val = 0
23 | self.avg = 0
24 | self.sum = 0
25 | self.count = 0
26 |
27 | def update(self, val, n=1):
28 | self.val = val
29 | self.sum += val * n
30 | self.count += n
31 | self.avg = self.sum / self.count
32 |
33 |
34 |
35 | def training(model, trainloader, validloader, criterion, optimizer, scheduler, num_training_steps: int = 1000, loss_weights: List[float] = [0.6, 0.4],
36 | log_interval: int = 1, eval_interval: int = 1, savedir: str = None, use_wandb: bool = False, device: str ='cpu') -> dict:
37 |
38 | batch_time_m = AverageMeter()
39 | data_time_m = AverageMeter()
40 | losses_m = AverageMeter()
41 | l1_losses_m = AverageMeter()
42 | focal_losses_m = AverageMeter()
43 |
44 | # criterion
45 | l1_criterion, focal_criterion = criterion
46 | l1_weight, focal_weight = loss_weights
47 |
48 | # set train mode
49 | model.train()
50 |
51 | # set optimizer
52 | optimizer.zero_grad()
53 |
54 |     # training (state/eval_log are initialized here so the final logging and saving below cannot raise a NameError if evaluation never runs)
55 |     best_score, step = 0, 0
56 |     state, eval_log = {'best_step': 0}, {}
57 | train_mode = True
58 | while train_mode:
59 |
60 | end = time.time()
61 | for inputs, masks, targets in trainloader:
62 | # batch
63 | inputs, masks, targets = inputs.to(device), masks.to(device), targets.to(device)
64 |
65 | data_time_m.update(time.time() - end)
66 |
67 | # predict
68 | outputs = model(inputs)
69 | outputs = F.softmax(outputs, dim=1)
70 | l1_loss = l1_criterion(outputs[:,1,:], masks)
71 | focal_loss = focal_criterion(outputs, masks)
72 | loss = (l1_weight * l1_loss) + (focal_weight * focal_loss)
73 |
74 | loss.backward()
75 |
76 | # update weight
77 | optimizer.step()
78 | optimizer.zero_grad()
79 |
80 | # log loss
81 | l1_losses_m.update(l1_loss.item())
82 | focal_losses_m.update(focal_loss.item())
83 | losses_m.update(loss.item())
84 |
85 | batch_time_m.update(time.time() - end)
86 |
87 | # wandb
88 | if use_wandb:
89 | wandb.log({
90 | 'lr':optimizer.param_groups[0]['lr'],
91 | 'train_focal_loss':focal_losses_m.val,
92 | 'train_l1_loss':l1_losses_m.val,
93 | 'train_loss':losses_m.val
94 | },
95 | step=step)
96 |
97 | if (step+1) % log_interval == 0 or step == 0:
98 | _logger.info('TRAIN [{:>4d}/{}] '
99 | 'Loss: {loss.val:>6.4f} ({loss.avg:>6.4f}) '
100 | 'L1 Loss: {l1_loss.val:>6.4f} ({l1_loss.avg:>6.4f}) '
101 | 'Focal Loss: {focal_loss.val:>6.4f} ({focal_loss.avg:>6.4f}) '
102 | 'LR: {lr:.3e} '
103 | 'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
104 | 'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
105 | step+1, num_training_steps,
106 | loss = losses_m,
107 | l1_loss = l1_losses_m,
108 | focal_loss = focal_losses_m,
109 | lr = optimizer.param_groups[0]['lr'],
110 | batch_time = batch_time_m,
111 | rate = inputs.size(0) / batch_time_m.val,
112 | rate_avg = inputs.size(0) / batch_time_m.avg,
113 | data_time = data_time_m))
114 |
115 |
116 | if ((step+1) % eval_interval == 0 and step != 0) or (step+1) == num_training_steps:
117 | eval_metrics = evaluate(
118 | model = model,
119 | dataloader = validloader,
120 | device = device
121 | )
122 | model.train()
123 |
124 | eval_log = dict([(f'eval_{k}', v) for k, v in eval_metrics.items()])
125 |
126 | # wandb
127 | if use_wandb:
128 | wandb.log(eval_log, step=step)
129 |
130 | # checkpoint
131 | if best_score < np.mean(list(eval_metrics.values())):
132 | # save best score
133 | state = {'best_step':step}
134 | state.update(eval_log)
135 | json.dump(state, open(os.path.join(savedir, 'best_score.json'),'w'), indent='\t')
136 |
137 | # save best model
138 |                     torch.save(model.state_dict(), os.path.join(savedir, 'best_model.pt'))
139 |
140 | _logger.info('Best Score {0:.3%} to {1:.3%}'.format(best_score, np.mean(list(eval_metrics.values()))))
141 |
142 | best_score = np.mean(list(eval_metrics.values()))
143 |
144 | # scheduler
145 | if scheduler:
146 | scheduler.step()
147 |
148 | end = time.time()
149 |
150 | step += 1
151 |
152 | if step == num_training_steps:
153 | train_mode = False
154 | break
155 |
156 | # print best score and step
157 | _logger.info('Best Metric: {0:.3%} (step {1:})'.format(best_score, state['best_step']))
158 |
159 | # save latest model
160 |     torch.save(model.state_dict(), os.path.join(savedir, 'latest_model.pt'))
161 |
162 | # save latest score
163 | state = {'latest_step':step}
164 | state.update(eval_log)
165 | json.dump(state, open(os.path.join(savedir, 'latest_score.json'),'w'), indent='\t')
166 |
167 |
168 |
169 |
170 | def evaluate(model, dataloader, device: str = 'cpu'):
171 | # targets and outputs
172 | image_targets = []
173 | image_masks = []
174 | anomaly_score = []
175 | anomaly_map = []
176 |
177 | model.eval()
178 | with torch.no_grad():
179 | for idx, (inputs, masks, targets) in enumerate(dataloader):
180 | inputs, masks, targets = inputs.to(device), masks.to(device), targets.to(device)
181 |
182 | # predict
183 | outputs = model(inputs)
184 | outputs = F.softmax(outputs, dim=1)
185 | anomaly_score_i = torch.topk(torch.flatten(outputs[:,1,:], start_dim=1), 100)[0].mean(dim=1)
186 |
187 | # stack targets and outputs
188 | image_targets.extend(targets.cpu().tolist())
189 | image_masks.extend(masks.cpu().numpy())
190 |
191 | anomaly_score.extend(anomaly_score_i.cpu().tolist())
192 | anomaly_map.extend(outputs[:,1,:].cpu().numpy())
193 |
194 | # metrics
195 | image_masks = np.array(image_masks)
196 | anomaly_map = np.array(anomaly_map)
197 |
198 | auroc_image = roc_auc_score(image_targets, anomaly_score)
199 | auroc_pixel = roc_auc_score(image_masks.reshape(-1).astype(int), anomaly_map.reshape(-1))
200 | all_fprs, all_pros = compute_pro(
201 | anomaly_maps = anomaly_map,
202 | ground_truth_maps = image_masks
203 | )
204 | aupro = trapezoid(all_fprs, all_pros)
205 |
206 | metrics = {
207 | 'AUROC-image':auroc_image,
208 | 'AUROC-pixel':auroc_pixel,
209 |         'AUPRO-pixel':aupro
210 |     }
211 |
212 |
213 |     _logger.info('TEST: AUROC-image: %.3f%% | AUROC-pixel: %.3f%% | AUPRO-pixel: %.3f%%' %
214 |                 (metrics['AUROC-image']*100, metrics['AUROC-pixel']*100, metrics['AUPRO-pixel']*100))
215 |
216 |
217 | return metrics
218 |
--------------------------------------------------------------------------------
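For reference, the image-level anomaly score computed inside `evaluate` is simply the mean of the 100 largest anomaly-channel probabilities per image; in isolation (random logits for illustration):

    import torch
    import torch.nn.functional as F

    outputs = F.softmax(torch.randn(4, 2, 256, 256), dim=1)        # (B, 2, H, W) mask probabilities
    anomaly_probs = torch.flatten(outputs[:, 1, :], start_dim=1)   # (B, H*W) anomaly channel
    anomaly_score = torch.topk(anomaly_probs, 100)[0].mean(dim=1)  # (B,) per-image score
    print(anomaly_score.shape)  # torch.Size([4])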
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import random
3 | import numpy as np
4 | import os
5 |
6 | def torch_seed(random_seed):
7 | torch.manual_seed(random_seed)
8 | torch.cuda.manual_seed(random_seed)
9 | torch.cuda.manual_seed_all(random_seed) # if use multi-GPU
10 | # CUDA randomness
11 | torch.backends.cudnn.deterministic = True
12 | torch.backends.cudnn.benchmark = False
13 |
14 | np.random.seed(random_seed)
15 | random.seed(random_seed)
16 | os.environ['PYTHONHASHSEED'] = str(random_seed)
--------------------------------------------------------------------------------
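Called once at startup, before building datasets or models (the seed value itself is arbitrary):

    from utils import torch_seed

    torch_seed(42)  # also disables cuDNN benchmarking, so expect some slowdown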