├── .gitignore ├── LICENSE ├── README.md ├── [demo] model inference.ipynb ├── [example] anomaly_simulation_strategy.ipynb ├── [example] model overview.ipynb ├── anomaly_mask.json ├── assets └── memseg.gif ├── configs.yaml ├── data ├── __init__.py ├── dataset.py ├── factory.py └── perlin.py ├── docker └── Dockerfile ├── focal_loss.py ├── log.py ├── main.py ├── metrics ├── __init__.py ├── generic_util.py └── pro_curve_util.py ├── models ├── __init__.py ├── coordatt.py ├── decoder.py ├── memory_module.py ├── memseg.py └── msff.py ├── requirements.txt ├── saved_model ├── MemSeg-bottle │ ├── best_score.json │ └── latest_score.json ├── MemSeg-cable │ ├── best_score.json │ └── latest_score.json ├── MemSeg-capsule │ ├── best_score.json │ └── latest_score.json ├── MemSeg-carpet │ ├── best_score.json │ └── latest_score.json ├── MemSeg-grid │ ├── best_score.json │ └── latest_score.json ├── MemSeg-hazelnut │ ├── best_score.json │ └── latest_score.json ├── MemSeg-leather │ ├── best_score.json │ └── latest_score.json ├── MemSeg-metal_nut │ ├── best_score.json │ └── latest_score.json ├── MemSeg-pill │ ├── best_score.json │ └── latest_score.json ├── MemSeg-screw │ ├── best_score.json │ └── latest_score.json ├── MemSeg-tile │ ├── best_score.json │ └── latest_score.json ├── MemSeg-toothbrush │ ├── best_score.json │ └── latest_score.json ├── MemSeg-transistor │ ├── best_score.json │ └── latest_score.json ├── MemSeg-wood │ ├── best_score.json │ └── latest_score.json └── MemSeg-zipper │ ├── best_score.json │ └── latest_score.json ├── scheduler.py ├── train.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | ### Python ### 2 | # Byte-compiled / optimized / DLL files 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | cover/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | .pybuilder/ 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | # For a library or package, you might want to ignore these files since the code is 88 | # intended to run in multiple environments; otherwise, check them in: 89 | # .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 99 | __pypackages__/ 100 | 101 | # Celery stuff 102 | celerybeat-schedule 103 | celerybeat.pid 104 | 105 | # SageMath parsed files 106 | *.sage.py 107 | 108 | # Environments 109 | .env 110 | .venv 111 | env/ 112 | venv/ 113 | ENV/ 114 | env.bak/ 115 | venv.bak/ 116 | 117 | # Spyder project settings 118 | .spyderproject 119 | .spyproject 120 | 121 | # Rope project settings 122 | .ropeproject 123 | 124 | # mkdocs documentation 125 | /site 126 | 127 | # mypy 128 | .mypy_cache/ 129 | .dmypy.json 130 | dmypy.json 131 | 132 | # Pyre type checker 133 | .pyre/ 134 | 135 | # pytype static type analyzer 136 | .pytype/ 137 | 138 | # Cython debug symbols 139 | cython_debug/ 140 | 141 | # End of https://www.toptal.com/developers/gitignore/api/python 142 | 143 | # results 144 | *.pt 145 | wandb/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Jaehyuk Heo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MemSeg 2 | Unofficial re-implementation of [MemSeg: A semi-supervised method for image surface defect detection using differences and commonalities](https://arxiv.org/abs/2205.00908) 3 | 4 | # Environments 5 | 6 | - Docker image: nvcr.io/nvidia/pytorch:20.12-py3 7 | 8 | ``` 9 | einops==0.5.0 10 | timm==0.5.4 11 | wandb==0.12.17 12 | omegaconf 13 | imgaug==0.4.0 14 | ``` 15 | 16 | 17 | # Process 18 | 19 | ## 1. Anomaly Simulation Strategy 20 | 21 | - [notebook](https://github.com/TooTouch/MemSeg/blob/main/%5Bexample%5D%20anomaly_simulation_strategy.ipynb) 22 | - Describable Textures Dataset (DTD) [ [download](https://www.robots.ox.ac.uk/~vgg/data/dtd/) ] 23 | 24 |

25 | 26 |

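The gist of the strategy, as implemented in `MemSegDataset.generate_anomaly` (`data/dataset.py`), is sketched below. This is a simplified illustration rather than the full pipeline: the function name `synthesize_anomaly` is mine, the image is assumed to be an H x W x 3 `uint8` array whose sides are divisible by the Perlin grid resolution, and the per-category foreground mask and the texture-vs-structure choice are omitted.

```python
import numpy as np
from data.perlin import rand_perlin_2d_np

def synthesize_anomaly(img, anomaly_src, threshold=0.5, transparency=(0.15, 1.0)):
    h, w = img.shape[:2]

    # step 1. binary anomaly mask: thresholded 2D Perlin noise
    perlin = rand_perlin_2d_np((h, w), (16, 16))
    mask = (perlin > threshold).astype(np.float32)[..., None]

    # step 2. blend the anomaly source into the masked region with random opacity
    factor = np.random.uniform(*transparency)
    blended = factor * (mask * anomaly_src) + (1 - factor) * (mask * img)

    # step 3. paste the blended region back onto the untouched background
    return ((1 - mask) * img + blended).astype(np.uint8), mask[..., 0]
```

In the actual dataset class, the Perlin mask is additionally intersected with a target foreground mask (per-category settings live in `anomaly_mask.json`), and the anomaly source is either a DTD texture or a grid-shuffled, color-jittered copy of the input image.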
27 | 28 | ## 2. Model Process 29 | 30 | - [notebook](https://github.com/TooTouch/MemSeg/blob/main/%5Bexample%5D%20model%20overview.ipynb) 31 | 32 |

33 | 34 |

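In code, the forward pass in `models/memseg.py` concatenates the frozen ResNet features with their most similar normal features from the memory bank, fuses them with the MSFF module, and decodes a two-channel segmentation mask. Below is a minimal end-to-end sketch based on `main.py`; the dataset path and the `bottle` target are placeholders for your own setup.

```python
import torch
from timm import create_model
from data import create_dataset
from models import MemoryBank, MemSeg

# frozen multi-scale feature extractor
feature_extractor = create_model('resnet18', pretrained=True, features_only=True)

# memory bank filled with features of normal training images only
memoryset = create_dataset(datadir='/datasets/MVTec', target='bottle',
                           is_train=True, to_memory=True,
                           resize=[288, 288], imagesize=256)
memory_bank = MemoryBank(normal_dataset=memoryset, nb_memory_sample=30, device='cpu')
memory_bank.update(feature_extractor=feature_extractor)

# MemSeg head: memory comparison -> MSFF -> decoder
model = MemSeg(memory_bank=memory_bank, feature_extractor=feature_extractor)

x = torch.randn(1, 3, 256, 256)   # imagesize in configs.yaml
mask_logits = model(x)            # -> (1, 2, 256, 256)
```

Note that the model returns raw logits; the demo notebook applies a softmax over the channel dimension and uses channel 1 as the pixel-wise anomaly map.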
35 | 36 | 37 | # Run 38 | 39 | **Example** 40 | 41 | ```bash 42 | python main.py configs=configs.yaml DATASET.target=bottle 43 | ``` 44 | 45 | ## Demo 46 | 47 | ``` 48 | voila "[demo] model inference.ipynb" --port ${port} --Voila.ip ${ip} 49 | ``` 50 | 51 | ![](https://github.com/TooTouch/MemSeg/blob/main/assets/memseg.gif) 52 | 53 | # Results 54 | 55 | - **Backbone**: ResNet18 56 | 57 | | target | AUROC-image | AUROC-pixel | AUPRO-pixel | 58 | |:-----------|--------------:|--------------:|--------------:| 59 | | leather | 100 | 98.83 | 99.09 | 60 | | pill | 97.05 | 98.29 | 97.96 | 61 | | carpet | 99.12 | 97.54 | 97.02 | 62 | | hazelnut | 100 | 97.78 | 99 | 63 | | tile | 99.86 | 99.38 | 98.81 | 64 | | cable | 92.5 | 82.3 | 87.31 | 65 | | toothbrush | 100 | 99.28 | 98.56 | 66 | | transistor | 96.5 | 76.29 | 86.06 | 67 | | zipper | 99.95 | 97.94 | 97.26 | 68 | | metal_nut | 99.46 | 88.48 | 95 | 69 | | grid | 99.83 | 98.37 | 98.53 | 70 | | bottle | 100 | 98.79 | 98.36 | 71 | | capsule | 95.41 | 98.43 | 97.73 | 72 | | screw | 94.86 | 95.08 | 94 | 73 | | wood | 100 | 97.54 | 97.62 | 74 | | **Average** | 98.3 | 94.96 | 96.15 | 75 | 76 | # Citation 77 | 78 | ``` 79 | @article{DBLP:journals/corr/abs-2205-00908, 80 | author = {Minghui Yang and 81 | Peng Wu and 82 | Jing Liu and 83 | Hui Feng}, 84 | title = {MemSeg: {A} semi-supervised method for image surface defect detection 85 | using differences and commonalities}, 86 | journal = {CoRR}, 87 | volume = {abs/2205.00908}, 88 | year = {2022}, 89 | url = {https://doi.org/10.48550/arXiv.2205.00908}, 90 | doi = {10.48550/arXiv.2205.00908}, 91 | eprinttype = {arXiv}, 92 | eprint = {2205.00908}, 93 | timestamp = {Tue, 03 May 2022 15:52:06 +0200}, 94 | biburl = {https://dblp.org/rec/journals/corr/abs-2205-00908.bib}, 95 | bibsource = {dblp computer science bibliography, https://dblp.org} 96 | } 97 | ``` 98 | -------------------------------------------------------------------------------- /[demo] model inference.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Model Select" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 9, 13 | "metadata": { 14 | "scrolled": false 15 | }, 16 | "outputs": [ 17 | { 18 | "data": { 19 | "application/vnd.jupyter.widget-view+json": { 20 | "model_id": "91564dc2a09d465baa8420dd651407c7", 21 | "version_major": 2, 22 | "version_minor": 0 23 | }, 24 | "text/plain": [ 25 | "HBox(children=(Dropdown(description='Model:', index=11, options=('leather', 'pill', 'carpet', 'hazelnut', 'til…" 26 | ] 27 | }, 28 | "metadata": {}, 29 | "output_type": "display_data" 30 | }, 31 | { 32 | "data": { 33 | "application/vnd.jupyter.widget-view+json": { 34 | "model_id": "76944799ab664d25b738c192c8268f1c", 35 | "version_major": 2, 36 | "version_minor": 0 37 | }, 38 | "text/plain": [ 39 | "Output()" 40 | ] 41 | }, 42 | "metadata": {}, 43 | "output_type": "display_data" 44 | } 45 | ], 46 | "source": [ 47 | "from data import create_dataset, create_dataloader\n", 48 | "\n", 49 | "import os\n", 50 | "import torch\n", 51 | "from omegaconf import OmegaConf\n", 52 | "from timm import create_model\n", 53 | "from models import MemSeg\n", 54 | "\n", 55 | "import ipywidgets as widgets\n", 56 | "from IPython.display import display, clear_output\n", 57 | "import matplotlib.pyplot as plt\n", 58 | "\n", 59 | "cfg = OmegaConf.load('./configs.yaml')\n", 60 | "\n", 61 | "# ====================================\n", 62 
| "# Select Model\n", 63 | "# ====================================\n", 64 | "\n", 65 | "def load_model(target_name):\n", 66 | " global model\n", 67 | " global testset \n", 68 | " \n", 69 | " testset = create_dataset(\n", 70 | " datadir = cfg['DATASET']['datadir'],\n", 71 | " target = target_name, \n", 72 | " is_train = False,\n", 73 | " resize = cfg['DATASET']['resize'],\n", 74 | " texture_source_dir = cfg['DATASET']['texture_source_dir'],\n", 75 | " structure_grid_size = cfg['DATASET']['structure_grid_size'],\n", 76 | " transparency_range = cfg['DATASET']['transparency_range'],\n", 77 | " perlin_scale = cfg['DATASET']['perlin_scale'], \n", 78 | " min_perlin_scale = cfg['DATASET']['min_perlin_scale'], \n", 79 | " perlin_noise_threshold = cfg['DATASET']['perlin_noise_threshold']\n", 80 | " )\n", 81 | " \n", 82 | " memory_bank = torch.load(f'./saved_model/MemSeg-{target_name}/memory_bank.pt')\n", 83 | " memory_bank.device = 'cpu'\n", 84 | " for k in memory_bank.memory_information.keys():\n", 85 | " memory_bank.memory_information[k] = memory_bank.memory_information[k].cpu()\n", 86 | "\n", 87 | " feature_extractor = feature_extractor = create_model(\n", 88 | " cfg['MODEL']['feature_extractor_name'], \n", 89 | " pretrained = True, \n", 90 | " features_only = True\n", 91 | " )\n", 92 | " model = MemSeg(\n", 93 | " memory_bank = memory_bank,\n", 94 | " feature_extractor = feature_extractor\n", 95 | " )\n", 96 | "\n", 97 | " model.load_state_dict(torch.load(f'./saved_model/MemSeg-{target_name}/best_model.pt'))\n", 98 | " model.eval()\n", 99 | "\n", 100 | "# ====================================\n", 101 | "# Visualization\n", 102 | "# ====================================\n", 103 | "\n", 104 | "\n", 105 | "def result_plot(idx):\n", 106 | " input_i, mask_i, target_i = testset[idx]\n", 107 | "\n", 108 | " output_i = model(input_i.unsqueeze(0)).detach()\n", 109 | " output_i = torch.nn.functional.softmax(output_i, dim=1)\n", 110 | "\n", 111 | " def minmax_scaling(img):\n", 112 | " return (((img - img.min()) / (img.max() - img.min())) * 255).to(torch.uint8)\n", 113 | "\n", 114 | " fig, ax = plt.subplots(1,4, figsize=(15,10))\n", 115 | " \n", 116 | " ax[0].imshow(minmax_scaling(input_i.permute(1,2,0)))\n", 117 | " ax[0].set_title('Input: {}'.format('Normal' if target_i == 0 else 'Abnormal'))\n", 118 | " ax[1].imshow(mask_i, cmap='gray')\n", 119 | " ax[1].set_title('Ground Truth')\n", 120 | " ax[2].imshow(output_i[0][1], cmap='gray')\n", 121 | " ax[2].set_title('Predicted Mask')\n", 122 | " ax[3].imshow(minmax_scaling(input_i.permute(1,2,0)), alpha=1)\n", 123 | " ax[3].imshow(output_i[0][1], cmap='gray', alpha=0.5)\n", 124 | " ax[3].set_title(f'Input X Predicted Mask')\n", 125 | " \n", 126 | " plt.show()\n", 127 | " \n", 128 | "\n", 129 | "\n", 130 | "\n", 131 | "\n", 132 | "# ====================================\n", 133 | "# widgets\n", 134 | "# ====================================\n", 135 | "\n", 136 | "\n", 137 | "target_list = widgets.Dropdown(\n", 138 | " options=[m.split('-')[1] for m in os.listdir('./saved_model')],\n", 139 | " value='capsule',\n", 140 | " description='Model:',\n", 141 | " disabled=False,\n", 142 | ")\n", 143 | "button = widgets.Button(description=\"Model Change\")\n", 144 | "output = widgets.Output()\n", 145 | "\n", 146 | "\n", 147 | "@output.capture()\n", 148 | "def on_button_clicked(b):\n", 149 | " clear_output(wait=True)\n", 150 | " load_model(target_name=target_list.value)\n", 151 | " \n", 152 | " # vizualization\n", 153 | " file_list = widgets.Dropdown(\n", 154 | " 
options=[(file_path, i) for i, file_path in enumerate(testset.file_list)],\n", 155 | " value=0,\n", 156 | " description='image:',\n", 157 | " )\n", 158 | "\n", 159 | " \n", 160 | " widgets.interact(result_plot, idx=file_list)\n", 161 | "\n", 162 | "button.on_click(on_button_clicked)\n", 163 | "\n", 164 | "\n", 165 | "display(widgets.HBox([target_list, button]), output)\n", 166 | "\n" 167 | ] 168 | } 169 | ], 170 | "metadata": { 171 | "kernelspec": { 172 | "display_name": "Python 3 (ipykernel)", 173 | "language": "python", 174 | "name": "python3" 175 | }, 176 | "language_info": { 177 | "codemirror_mode": { 178 | "name": "ipython", 179 | "version": 3 180 | }, 181 | "file_extension": ".py", 182 | "mimetype": "text/x-python", 183 | "name": "python", 184 | "nbconvert_exporter": "python", 185 | "pygments_lexer": "ipython3", 186 | "version": "3.8.10" 187 | } 188 | }, 189 | "nbformat": 4, 190 | "nbformat_minor": 4 191 | } 192 | -------------------------------------------------------------------------------- /anomaly_mask.json: -------------------------------------------------------------------------------- 1 | { 2 | "hazelnut": { 3 | "use_mask" : true, 4 | "bg_threshold" : 50, 5 | "bg_reverse" : true 6 | }, 7 | "leather": { 8 | "use_mask" : false, 9 | "bg_threshold" : null, 10 | "bg_reverse" : null 11 | }, 12 | "metal_nut": { 13 | "use_mask" : true, 14 | "bg_threshold" : 40, 15 | "bg_reverse" : true 16 | }, 17 | "tile": { 18 | "use_mask" : false, 19 | "bg_threshold" : null, 20 | "bg_reverse" : null 21 | }, 22 | "wood": { 23 | "use_mask" : false, 24 | "bg_threshold" : null, 25 | "bg_reverse" : null 26 | }, 27 | "grid": { 28 | "use_mask" : false, 29 | "bg_threshold" : null, 30 | "bg_reverse" : null 31 | }, 32 | "cable": { 33 | "use_mask" : false, 34 | "bg_threshold" : 150, 35 | "bg_reverse" : true 36 | }, 37 | "capsule": { 38 | "use_mask" : true, 39 | "bg_threshold" : 120, 40 | "bg_reverse" : false 41 | }, 42 | "transistor": { 43 | "use_mask" : true, 44 | "bg_threshold" : 90, 45 | "bg_reverse" : false 46 | }, 47 | "carpet": { 48 | "use_mask" : false, 49 | "bg_threshold" : null, 50 | "bg_reverse" : null 51 | }, 52 | "bottle": { 53 | "use_mask" : true, 54 | "bg_threshold" : 250, 55 | "bg_reverse" : false 56 | }, 57 | "screw": { 58 | "use_mask" : true, 59 | "bg_threshold" : 110, 60 | "bg_reverse" : false 61 | }, 62 | "zipper": { 63 | "use_mask" : true, 64 | "bg_threshold" : 100, 65 | "bg_reverse" : false 66 | }, 67 | "pill": { 68 | "use_mask" : true, 69 | "bg_threshold" : 100, 70 | "bg_reverse" : true 71 | }, 72 | "toothbrush": { 73 | "use_mask" : true, 74 | "bg_threshold" : 30, 75 | "bg_reverse" : true 76 | } 77 | } -------------------------------------------------------------------------------- /assets/memseg.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TooTouch/MemSeg/836bd465a9b14422f92666dc29dc36edce2692d0/assets/memseg.gif -------------------------------------------------------------------------------- /configs.yaml: -------------------------------------------------------------------------------- 1 | EXP_NAME: MemSeg 2 | SEED: 42 3 | 4 | DATASET: 5 | datadir: /datasets/MVTec 6 | texture_source_dir: /datasets/dtd/images 7 | anomaly_mask_info: ./anomaly_mask.json 8 | target: null 9 | resize: 10 | - 288 11 | - 288 12 | imagesize: 256 13 | structure_grid_size: 8 14 | transparency_range: 15 | - 0.15 # under bound 16 | - 1. 
# upper bound 17 |   perlin_scale: 6 18 |   min_perlin_scale: 0 19 |   perlin_noise_threshold: 0.5 20 | 21 | DATALOADER: 22 |   batch_size: 8 23 |   num_workers: 0 24 | 25 | MEMORYBANK: 26 |   nb_memory_sample: 30 27 | 28 | MODEL: 29 |   feature_extractor_name: resnet18 30 | 31 | TRAIN: 32 |   batch_size: 8 33 |   num_training_steps: 5000 34 |   l1_weight: 0.6 35 |   focal_weight: 0.4 36 |   focal_alpha: null 37 |   focal_gamma: 4 38 |   use_wandb: True 39 | 40 | OPTIMIZER: 41 |   lr: 0.003 42 |   weight_decay: 0.0005 43 | 44 | SCHEDULER: 45 |   min_lr: 0.0001 46 |   warmup_ratio: 0.1 47 |   use_scheduler: True 48 | 49 | LOG: 50 |   log_interval: 1 51 |   eval_interval: 100 52 | 53 | RESULT: 54 |   savedir: ./saved_model -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from .perlin import * 2 | from .factory import * 3 | from .dataset import * -------------------------------------------------------------------------------- /data/dataset.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | import numpy as np 4 | from PIL import Image 5 | from glob import glob 6 | from einops import rearrange 7 | 8 | import torch 9 | from torch.utils.data import Dataset 10 | from torchvision import transforms 11 | import imgaug.augmenters as iaa 12 | 13 | from .perlin import rand_perlin_2d_np 14 | 15 | from typing import List 16 | 32 | IMAGENET_MEAN = [0.485, 0.456, 0.406] 33 | IMAGENET_STD = [0.229, 0.224, 0.225] 34 | 35 | class MemSegDataset(Dataset): 36 | def __init__( 37 | self, datadir: str, target: str, is_train: bool, to_memory: bool = False, 38 | resize: List[int] = [256, 256], imagesize: int = 224, 39 | texture_source_dir: str = None, structure_grid_size: str = 8, 40 | transparency_range: List[float] = [0.15, 1.], 41 | perlin_scale: int = 6, min_perlin_scale: int = 0, perlin_noise_threshold: float = 0.5, 42 | use_mask: bool = True, bg_threshold: float = 100, bg_reverse: bool = False 43 | ): 44 | # mode 45 | self.is_train = is_train 46 | self.to_memory = to_memory 47 | 48 | # load image file list 49 | self.datadir = datadir 50 | self.target = target 51 | self.file_list = glob(os.path.join(self.datadir, self.target, 'train/*/*' if is_train else 'test/*/*')) 52 | 53 | # synthetic anomaly 54 | if self.is_train and not self.to_memory: 55 | # load texture image file list 56 | self.texture_source_file_list = glob(os.path.join(texture_source_dir,'*/*')) if texture_source_dir else None 57 | 58 | # perlin noise 59 | self.perlin_scale = perlin_scale 60 | self.min_perlin_scale = min_perlin_scale 61 | self.perlin_noise_threshold = perlin_noise_threshold 62 | 63 | # structure 64 | self.structure_grid_size = structure_grid_size 65 | 66 | # anomaly mixing 67 | self.transparency_range = transparency_range 68 | 69 | # mask setting 70 | self.use_mask = use_mask 71 | self.bg_threshold = bg_threshold 72 | self.bg_reverse = bg_reverse 73 | 74 | # transform 75 | self.resize = list(resize) 76 | self.transform_img = [ 77 | transforms.ToPILImage(), 78 | transforms.CenterCrop(imagesize), 79 | transforms.ToTensor(), 80 | 
transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD), 81 | ] 82 | self.transform_img = transforms.Compose(self.transform_img) 83 | 84 | self.transform_mask = [ 85 | transforms.ToPILImage(), 86 | transforms.CenterCrop(imagesize), 87 | transforms.ToTensor(), 88 | ] 89 | self.transform_mask = transforms.Compose(self.transform_mask) 90 | 91 | # synthetic anomaly switch 92 | self.anomaly_switch = False 93 | 94 | def __getitem__(self, idx): 95 | 96 | file_path = self.file_list[idx] 97 | 98 | # image 99 | img = Image.open(file_path).convert("RGB").resize(self.resize) 100 | img = np.array(img) 101 | 102 | # target 103 | target = 0 if 'good' in self.file_list[idx] else 1 104 | 105 | # mask 106 | if 'good' in file_path: 107 | mask = np.zeros(self.resize, dtype=np.float32) 108 | else: 109 | mask = Image.open(file_path.replace('test','ground_truth').replace('.png','_mask.png')).resize(self.resize) 110 | mask = np.array(mask) 111 | 112 | ## anomaly source 113 | if self.is_train and not self.to_memory: 114 | if self.anomaly_switch: 115 | img, mask = self.generate_anomaly(img=img, texture_img_list=self.texture_source_file_list) 116 | target = 1 117 | self.anomaly_switch = False 118 | 119 | mask = torch.Tensor(mask) 120 | else: 121 | self.anomaly_switch = True 122 | 123 | img = self.transform_img(img) 124 | mask = self.transform_mask(mask).squeeze() 125 | 126 | return img, mask, target 127 | 128 | 129 | def rand_augment(self): 130 | augmenters = [ 131 | iaa.GammaContrast((0.5,2.0),per_channel=True), 132 | iaa.MultiplyAndAddToBrightness(mul=(0.8,1.2),add=(-30,30)), 133 | iaa.pillike.EnhanceSharpness(), 134 | iaa.AddToHueAndSaturation((-50,50),per_channel=True), 135 | iaa.Solarize(0.5, threshold=(32,128)), 136 | iaa.Posterize(), 137 | iaa.Invert(), 138 | iaa.pillike.Autocontrast(), 139 | iaa.pillike.Equalize(), 140 | iaa.Affine(rotate=(-45, 45)) 141 | ] 142 | 143 | aug_idx = np.random.choice(np.arange(len(augmenters)), 3, replace=False) 144 | aug = iaa.Sequential([ 145 | augmenters[aug_idx[0]], 146 | augmenters[aug_idx[1]], 147 | augmenters[aug_idx[2]] 148 | ]) 149 | 150 | return aug 151 | 152 | def generate_anomaly(self, img: np.ndarray, texture_img_list: list = None) -> List[np.ndarray]: 153 | ''' 154 | step 1. generate mask 155 | - target foreground mask 156 | - perlin noise mask 157 | 158 | step 2. generate texture or structure anomaly 159 | - texture: load DTD 160 | - structure: we first perform random adjustment of mirror symmetry, rotation, brightness, saturation, 161 | and hue on the input image 𝐼 . Then the preliminary processed image is uniformly divided into a 4×8 grid 162 | and randomly arranged to obtain the disordered image 𝐼 163 | 164 | step 3. blending image and anomaly source 165 | ''' 166 | 167 | # step 1. generate mask 168 | img_size = img.shape[:-1] # H x W 169 | 170 | ## target foreground mask 171 | if self.use_mask: 172 | target_foreground_mask = self.generate_target_foreground_mask(img=img) 173 | else: 174 | target_foreground_mask = np.ones(self.resize) 175 | 176 | ## perlin noise mask 177 | perlin_noise_mask = self.generate_perlin_noise_mask(img_size=img_size) 178 | 179 | ## mask 180 | mask = perlin_noise_mask * target_foreground_mask 181 | mask_expanded = np.expand_dims(mask, axis=2) 182 | 183 | # step 2. 
generate texture or structure anomaly 184 | 185 | ## anomaly source 186 | anomaly_source_img = self.anomaly_source(img=img, texture_img_list=texture_img_list) 187 | 188 | ## mask anomaly parts 189 | factor = np.random.uniform(*self.transparency_range, size=1)[0] 190 | anomaly_source_img = factor * (mask_expanded * anomaly_source_img) + (1 - factor) * (mask_expanded * img) 191 | 192 | # step 3. blending image and anomaly source 193 | anomaly_source_img = ((- mask_expanded + 1) * img) + anomaly_source_img 194 | 195 | return (anomaly_source_img.astype(np.uint8), mask) 196 | 197 | def generate_target_foreground_mask(self, img: np.ndarray) -> np.ndarray: 198 | # convert RGB into GRAY scale 199 | img_gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) 200 | 201 | # generate binary mask of gray scale image 202 | _, target_background_mask = cv2.threshold(img_gray, self.bg_threshold, 255, cv2.THRESH_BINARY) 203 | target_background_mask = target_background_mask.astype(bool).astype(int) 204 | 205 | # invert mask for foreground mask 206 | if self.bg_reverse: 207 | target_foreground_mask = target_background_mask 208 | else: 209 | target_foreground_mask = -(target_background_mask - 1) 210 | 211 | return target_foreground_mask 212 | 213 | def generate_perlin_noise_mask(self, img_size: tuple) -> np.ndarray: 214 | # define perlin noise scale 215 | perlin_scalex = 2 ** (torch.randint(self.min_perlin_scale, self.perlin_scale, (1,)).numpy()[0]) 216 | perlin_scaley = 2 ** (torch.randint(self.min_perlin_scale, self.perlin_scale, (1,)).numpy()[0]) 217 | 218 | # generate perlin noise 219 | perlin_noise = rand_perlin_2d_np(img_size, (perlin_scalex, perlin_scaley)) 220 | 221 | # apply affine transform 222 | rot = iaa.Affine(rotate=(-90, 90)) 223 | perlin_noise = rot(image=perlin_noise) 224 | 225 | # make a mask by applying threshold 226 | mask_noise = np.where( 227 | perlin_noise > self.perlin_noise_threshold, 228 | np.ones_like(perlin_noise), 229 | np.zeros_like(perlin_noise) 230 | ) 231 | 232 | return mask_noise 233 | 234 | def anomaly_source(self, img: np.ndarray, texture_img_list: list = None) -> np.ndarray: 235 | p = np.random.uniform() if texture_img_list else 1.0 236 | if p < 0.5: 237 | idx = np.random.choice(len(texture_img_list)) 238 | img_size = img.shape[:-1] # H x W 239 | anomaly_source_img = self._texture_source(img_size=img_size, texture_img_path=texture_img_list[idx]) 240 | else: 241 | anomaly_source_img = self._structure_source(img=img) 242 | 243 | return anomaly_source_img 244 | 245 | def _texture_source(self, img_size: tuple, texture_img_path: str) -> np.ndarray: 246 | texture_source_img = cv2.imread(texture_img_path) 247 | texture_source_img = cv2.cvtColor(texture_source_img, cv2.COLOR_BGR2RGB) 248 | texture_source_img = cv2.resize(texture_source_img, dsize=(img_size[1], img_size[0])).astype(np.float32) # cv2.resize expects dsize as (W, H) 249 | 250 | return texture_source_img 251 | 252 | def _structure_source(self, img: np.ndarray) -> np.ndarray: 253 | structure_source_img = self.rand_augment()(image=img) 254 | 255 | img_size = img.shape[:-1] # H x W 256 | 257 | assert img_size[0] % self.structure_grid_size == 0, 'image size should be divisible by the grid size' 258 | grid_w = img_size[1] // self.structure_grid_size 259 | grid_h = img_size[0] // self.structure_grid_size 260 | 261 | structure_source_img = rearrange( 262 | tensor = structure_source_img, 263 | pattern = '(h gh) (w gw) c -> (h w) gw gh c', 264 | gw = grid_w, 265 | gh = grid_h 266 | ) 267 | disordered_idx = np.arange(structure_source_img.shape[0]) 268 | 
np.random.shuffle(disordered_idx) 269 | 270 | structure_source_img = rearrange( 271 | tensor = structure_source_img[disordered_idx], 272 | pattern = '(h w) gw gh c -> (h gh) (w gw) c', 273 | h = self.structure_grid_size, 274 | w = self.structure_grid_size 275 | ).astype(np.float32) 276 | 277 | return structure_source_img 278 | 279 | def __len__(self): 280 | return len(self.file_list) 281 | 282 | 283 | -------------------------------------------------------------------------------- /data/factory.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | from typing import List 3 | from .dataset import MemSegDataset 4 | 5 | def create_dataset( 6 | datadir: str, target: str, is_train: bool, to_memory: bool = False, 7 | resize: List[int] = [256, 256], imagesize: int = 224, 8 | texture_source_dir: str = None, structure_grid_size: str = 8, 9 | transparency_range: List[float] = [0.15, 1.], 10 | perlin_scale: int = 6, min_perlin_scale: int = 0, perlin_noise_threshold: float = 0.5, 11 | use_mask: bool = True, bg_threshold: float = 100, bg_reverse: bool = False 12 | ): 13 | dataset = MemSegDataset( 14 | datadir = datadir, 15 | target = target, 16 | is_train = is_train, 17 | to_memory = to_memory, 18 | resize = resize, 19 | imagesize = imagesize, 20 | texture_source_dir = texture_source_dir, 21 | structure_grid_size = structure_grid_size, 22 | transparency_range = transparency_range, 23 | perlin_scale = perlin_scale, 24 | min_perlin_scale = min_perlin_scale, 25 | perlin_noise_threshold = perlin_noise_threshold, 26 | use_mask = use_mask, 27 | bg_threshold = bg_threshold, 28 | bg_reverse = bg_reverse 29 | ) 30 | 31 | return dataset 32 | 33 | 34 | def create_dataloader(dataset, train: bool, batch_size: int = 16, num_workers: int = 1): 35 | dataloader = DataLoader( 36 | dataset, 37 | shuffle = train, 38 | batch_size = batch_size, 39 | num_workers = num_workers 40 | ) 41 | 42 | return dataloader -------------------------------------------------------------------------------- /data/perlin.py: -------------------------------------------------------------------------------- 1 | # https://github.com/VitjanZ/DRAEM/blob/main/perlin.py 2 | 3 | import torch 4 | import math 5 | import numpy as np 6 | 7 | def lerp_np(x,y,w): 8 | fin_out = (y-x)*w + x 9 | return fin_out 10 | 11 | def generate_fractal_noise_2d(shape, res, octaves=1, persistence=0.5): 12 | noise = np.zeros(shape) 13 | frequency = 1 14 | amplitude = 1 15 | for _ in range(octaves): 16 | noise += amplitude * generate_perlin_noise_2d(shape, (frequency*res[0], frequency*res[1])) 17 | frequency *= 2 18 | amplitude *= persistence 19 | return noise 20 | 21 | 22 | def generate_perlin_noise_2d(shape, res): 23 | def f(t): 24 | return 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3 25 | 26 | delta = (res[0] / shape[0], res[1] / shape[1]) 27 | d = (shape[0] // res[0], shape[1] // res[1]) 28 | grid = np.mgrid[0:res[0]:delta[0], 0:res[1]:delta[1]].transpose(1, 2, 0) % 1 29 | # Gradients 30 | angles = 2 * np.pi * np.random.rand(res[0] + 1, res[1] + 1) 31 | gradients = np.dstack((np.cos(angles), np.sin(angles))) 32 | g00 = gradients[0:-1, 0:-1].repeat(d[0], 0).repeat(d[1], 1) 33 | g10 = gradients[1:, 0:-1].repeat(d[0], 0).repeat(d[1], 1) 34 | g01 = gradients[0:-1, 1:].repeat(d[0], 0).repeat(d[1], 1) 35 | g11 = gradients[1:, 1:].repeat(d[0], 0).repeat(d[1], 1) 36 | # Ramps 37 | n00 = np.sum(grid * g00, 2) 38 | n10 = np.sum(np.dstack((grid[:, :, 0] - 1, grid[:, :, 1])) * g10, 2) 39 | n01 = 
np.sum(np.dstack((grid[:, :, 0], grid[:, :, 1] - 1)) * g01, 2) 40 | n11 = np.sum(np.dstack((grid[:, :, 0] - 1, grid[:, :, 1] - 1)) * g11, 2) 41 | # Interpolation 42 | t = f(grid) 43 | n0 = n00 * (1 - t[:, :, 0]) + t[:, :, 0] * n10 44 | n1 = n01 * (1 - t[:, :, 0]) + t[:, :, 0] * n11 45 | return np.sqrt(2) * ((1 - t[:, :, 1]) * n0 + t[:, :, 1] * n1) 46 | 47 | 48 | def rand_perlin_2d_np(shape, res, fade=lambda t: 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3): 49 | delta = (res[0] / shape[0], res[1] / shape[1]) 50 | d = (shape[0] // res[0], shape[1] // res[1]) 51 | grid = np.mgrid[0:res[0]:delta[0], 0:res[1]:delta[1]].transpose(1, 2, 0) % 1 52 | 53 | angles = 2 * math.pi * np.random.rand(res[0] + 1, res[1] + 1) 54 | gradients = np.stack((np.cos(angles), np.sin(angles)), axis=-1) 55 | tt = np.repeat(np.repeat(gradients,d[0],axis=0),d[1],axis=1) 56 | 57 | tile_grads = lambda slice1, slice2: np.repeat(np.repeat(gradients[slice1[0]:slice1[1], slice2[0]:slice2[1]],d[0],axis=0),d[1],axis=1) 58 | dot = lambda grad, shift: ( 59 | np.stack((grid[:shape[0], :shape[1], 0] + shift[0], grid[:shape[0], :shape[1], 1] + shift[1]), 60 | axis=-1) * grad[:shape[0], :shape[1]]).sum(axis=-1) 61 | 62 | n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0]) 63 | n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0]) 64 | n01 = dot(tile_grads([0, -1], [1, None]), [0, -1]) 65 | n11 = dot(tile_grads([1, None], [1, None]), [-1, -1]) 66 | t = fade(grid[:shape[0], :shape[1]]) 67 | return math.sqrt(2) * lerp_np(lerp_np(n00, n10, t[..., 0]), lerp_np(n01, n11, t[..., 0]), t[..., 1]) 68 | 69 | 70 | def rand_perlin_2d(shape, res, fade=lambda t: 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3): 71 | delta = (res[0] / shape[0], res[1] / shape[1]) 72 | d = (shape[0] // res[0], shape[1] // res[1]) 73 | 74 | grid = torch.stack(torch.meshgrid(torch.arange(0, res[0], delta[0]), torch.arange(0, res[1], delta[1])), dim=-1) % 1 75 | angles = 2 * math.pi * torch.rand(res[0] + 1, res[1] + 1) 76 | gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1) 77 | 78 | tile_grads = lambda slice1, slice2: gradients[slice1[0]:slice1[1], slice2[0]:slice2[1]].repeat_interleave(d[0], 79 | 0).repeat_interleave( 80 | d[1], 1) 81 | dot = lambda grad, shift: ( 82 | torch.stack((grid[:shape[0], :shape[1], 0] + shift[0], grid[:shape[0], :shape[1], 1] + shift[1]), 83 | dim=-1) * grad[:shape[0], :shape[1]]).sum(dim=-1) 84 | 85 | n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0]) 86 | 87 | n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0]) 88 | n01 = dot(tile_grads([0, -1], [1, None]), [0, -1]) 89 | n11 = dot(tile_grads([1, None], [1, None]), [-1, -1]) 90 | t = fade(grid[:shape[0], :shape[1]]) 91 | return math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]) 92 | 93 | 94 | def rand_perlin_2d_octaves(shape, res, octaves=1, persistence=0.5): 95 | noise = torch.zeros(shape) 96 | frequency = 1 97 | amplitude = 1 98 | for _ in range(octaves): 99 | noise += amplitude * rand_perlin_2d(shape, (frequency * res[0], frequency * res[1])) 100 | frequency *= 2 101 | amplitude *= persistence 102 | return noise -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/pytorch:20.12-py3 2 | 3 | RUN apt-get update && apt-get upgrade -y && apt-get install -y vim && apt-get install -y git 4 | RUN pip install --upgrade pip 5 | 6 | WORKDIR / 7 | 8 | ARG UNAME 9 | ARG UID 10 | ARG GID 11 | RUN groupadd 
-g $GID -o $UNAME 12 | RUN useradd -m -u $UID -g $GID -o -s /bin/bash $UNAME 13 | USER $UNAME 14 | 15 | 16 | -------------------------------------------------------------------------------- /focal_loss.py: -------------------------------------------------------------------------------- 1 | # https://github.com/mbsariyildiz/focal-loss.pytorch/blob/master/focalloss.py 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | class FocalLoss(nn.Module): 7 | 8 | def __init__(self, smooth=1e-5, gamma=0, alpha=None, size_average=True): 9 | super(FocalLoss, self).__init__() 10 | self.gamma = gamma 11 | self.alpha = alpha 12 | self.smooth = smooth 13 | if isinstance(alpha, (float, int)): self.alpha = torch.Tensor([alpha, 1 - alpha]) 14 | if isinstance(alpha, list): self.alpha = torch.Tensor(alpha) 15 | self.size_average = size_average 16 | 17 | def forward(self, input, target): 18 | if input.dim()>2: 19 | input = input.view(input.size(0), input.size(1), -1) # N,C,H,W => N,C,H*W 20 | input = input.transpose(1, 2) # N,C,H*W => N,H*W,C 21 | input = input.contiguous().view(-1, input.size(2)) # N,H*W,C => N*H*W,C 22 | target = target.view(-1, 1) 23 | 24 | pt = input 25 | logpt = (pt + 1e-5).log() 26 | 27 | # add label smoothing 28 | num_class = input.shape[1] 29 | idx = target.cpu().long() 30 | 31 | one_hot_key = torch.FloatTensor(target.size(0), num_class).zero_() 32 | one_hot_key = one_hot_key.scatter_(1, idx, 1) 33 | if one_hot_key.device != input.device: 34 | one_hot_key = one_hot_key.to(input.device) 35 | 36 | if self.smooth: 37 | one_hot_key = torch.clamp( 38 | one_hot_key, self.smooth, 1.0 - self.smooth) 39 | logpt = logpt * one_hot_key 40 | 41 | if self.alpha is not None: 42 | if self.alpha.type() != input.data.type(): 43 | self.alpha = self.alpha.type_as(input.data) 44 | at = self.alpha.gather(0, target.data.view(-1)) 45 | logpt = logpt * at 46 | 47 | loss = (-1 * (1 - pt)**self.gamma * logpt).sum(1) 48 | if self.size_average: return loss.mean() 49 | else: return loss.sum() -------------------------------------------------------------------------------- /log.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import logging.handlers 3 | 4 | 5 | class FormatterNoInfo(logging.Formatter): 6 | def __init__(self, fmt='%(levelname)s: %(message)s'): 7 | logging.Formatter.__init__(self, fmt) 8 | 9 | def format(self, record): 10 | if record.levelno == logging.INFO: 11 | return str(record.getMessage()) 12 | return logging.Formatter.format(self, record) 13 | 14 | 15 | def setup_default_logging(default_level=logging.INFO, log_path=''): 16 | console_handler = logging.StreamHandler() 17 | console_handler.setFormatter(FormatterNoInfo()) 18 | logging.root.addHandler(console_handler) 19 | logging.root.setLevel(default_level) 20 | if log_path: 21 | file_handler = logging.handlers.RotatingFileHandler(log_path, maxBytes=(1024 ** 2 * 2), backupCount=3) 22 | file_formatter = logging.Formatter("%(asctime)s - %(name)20s: [%(levelname)8s] - %(message)s") 23 | file_handler.setFormatter(file_formatter) 24 | logging.root.addHandler(file_handler) -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import wandb 2 | import logging 3 | import os 4 | import torch 5 | import torch.nn as nn 6 | import argparse 7 | 8 | from omegaconf import OmegaConf 9 | from timm import create_model 10 | from data import create_dataset, create_dataloader 11 | from models 
import MemSeg, MemoryBank 12 | from focal_loss import FocalLoss 13 | from train import training 14 | from log import setup_default_logging 15 | from utils import torch_seed 16 | from scheduler import CosineAnnealingWarmupRestarts 17 | 18 | 19 | _logger = logging.getLogger('train') 20 | 21 | 22 | def run(cfg): 23 | 24 | # setting seed and device 25 | setup_default_logging() 26 | torch_seed(cfg.SEED) 27 | 28 | device = 'cuda:0' if torch.cuda.is_available() else 'cpu' 29 | _logger.info('Device: {}'.format(device)) 30 | 31 | # savedir 32 | cfg.EXP_NAME = cfg.EXP_NAME + f"-{cfg.DATASET.target}" 33 | savedir = os.path.join(cfg.RESULT.savedir, cfg.EXP_NAME) 34 | os.makedirs(savedir, exist_ok=True) 35 | 36 | 37 | # wandb 38 | if cfg.TRAIN.use_wandb: 39 | wandb.init(name=cfg.EXP_NAME, project='MemSeg', config=OmegaConf.to_container(cfg)) 40 | 41 | # build datasets 42 | trainset = create_dataset( 43 | datadir = cfg.DATASET.datadir, 44 | target = cfg.DATASET.target, 45 | is_train = True, 46 | resize = cfg.DATASET.resize, 47 | imagesize = cfg.DATASET.imagesize, 48 | texture_source_dir = cfg.DATASET.texture_source_dir, 49 | structure_grid_size = cfg.DATASET.structure_grid_size, 50 | transparency_range = cfg.DATASET.transparency_range, 51 | perlin_scale = cfg.DATASET.perlin_scale, 52 | min_perlin_scale = cfg.DATASET.min_perlin_scale, 53 | perlin_noise_threshold = cfg.DATASET.perlin_noise_threshold, 54 | use_mask = cfg.DATASET.use_mask, 55 | bg_threshold = cfg.DATASET.bg_threshold, 56 | bg_reverse = cfg.DATASET.bg_reverse 57 | ) 58 | 59 | memoryset = create_dataset( 60 | datadir = cfg.DATASET.datadir, 61 | target = cfg.DATASET.target, 62 | is_train = True, 63 | to_memory = True, 64 | resize = cfg.DATASET.resize, 65 | imagesize = cfg.DATASET.imagesize, 66 | ) 67 | 68 | testset = create_dataset( 69 | datadir = cfg.DATASET.datadir, 70 | target = cfg.DATASET.target, 71 | is_train = False, 72 | resize = cfg.DATASET.resize, 73 | imagesize = cfg.DATASET.imagesize, 74 | ) 75 | 76 | # build dataloader 77 | trainloader = create_dataloader( 78 | dataset = trainset, 79 | train = True, 80 | batch_size = cfg.DATALOADER.batch_size, 81 | num_workers = cfg.DATALOADER.num_workers 82 | ) 83 | 84 | testloader = create_dataloader( 85 | dataset = testset, 86 | train = False, 87 | batch_size = cfg.DATALOADER.batch_size, 88 | num_workers = cfg.DATALOADER.num_workers 89 | ) 90 | 91 | 92 | # build feature extractor 93 | feature_extractor = create_model( 94 | cfg.MODEL.feature_extractor_name, 95 | pretrained = True, 96 | features_only = True 97 | ).to(device) 98 | ## freeze weight of layer1,2,3 99 | for l in ['layer1','layer2','layer3']: 100 | for p in feature_extractor[l].parameters(): 101 | p.requires_grad = False 102 | 103 | # build memory bank 104 | memory_bank = MemoryBank( 105 | normal_dataset = memoryset, 106 | nb_memory_sample = cfg.MEMORYBANK.nb_memory_sample, 107 | device = device 108 | ) 109 | ## update normal samples and save 110 | memory_bank.update(feature_extractor=feature_extractor) 111 | torch.save(memory_bank, os.path.join(savedir, f'memory_bank.pt')) 112 | _logger.info('Update {} normal samples in memory bank'.format(cfg.MEMORYBANK.nb_memory_sample)) 113 | 114 | # build MemSeg 115 | model = MemSeg( 116 | memory_bank = memory_bank, 117 | feature_extractor = feature_extractor 118 | ).to(device) 119 | 120 | # Set training 121 | l1_criterion = nn.L1Loss() 122 | f_criterion = FocalLoss( 123 | gamma = cfg.TRAIN.focal_gamma, 124 | alpha = cfg.TRAIN.focal_alpha 125 | ) 126 | 127 | optimizer = torch.optim.AdamW( 128 | 
params = filter(lambda p: p.requires_grad, model.parameters()), 129 | lr = cfg.OPTIMIZER.lr, 130 | weight_decay = cfg.OPTIMIZER.weight_decay 131 | ) 132 | 133 | if cfg['SCHEDULER']['use_scheduler']: 134 | scheduler = CosineAnnealingWarmupRestarts( 135 | optimizer, 136 | first_cycle_steps = cfg.TRAIN.num_training_steps, 137 | max_lr = cfg.OPTIMIZER.lr, 138 | min_lr = cfg.SCHEDULER.min_lr, 139 | warmup_steps = int(cfg.TRAIN.num_training_steps * cfg.SCHEDULER.warmup_ratio) 140 | ) 141 | else: 142 | scheduler = None 143 | 144 | # Fitting model 145 | training( 146 | model = model, 147 | num_training_steps = cfg.TRAIN.num_training_steps, 148 | trainloader = trainloader, 149 | validloader = testloader, 150 | criterion = [l1_criterion, f_criterion], 151 | loss_weights = [cfg.TRAIN.l1_weight, cfg.TRAIN.focal_weight], 152 | optimizer = optimizer, 153 | scheduler = scheduler, 154 | log_interval = cfg.LOG.log_interval, 155 | eval_interval = cfg.LOG.eval_interval, 156 | savedir = savedir, 157 | device = device, 158 | use_wandb = cfg.TRAIN.use_wandb 159 | ) 160 | 161 | 162 | 163 | if __name__=='__main__': 164 | args = OmegaConf.from_cli() 165 | # load default config 166 | cfg = OmegaConf.load(args.configs) 167 | del args['configs'] 168 | 169 | # merge config with new keys 170 | cfg = OmegaConf.merge(cfg, args) 171 | 172 | # target cfg 173 | target_cfg = OmegaConf.load(cfg.DATASET.anomaly_mask_info) 174 | cfg.DATASET = OmegaConf.merge(cfg.DATASET, target_cfg[cfg.DATASET.target]) 175 | 176 | print(OmegaConf.to_yaml(cfg)) 177 | 178 | run(cfg) 179 | -------------------------------------------------------------------------------- /metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .pro_curve_util import compute_pro 2 | from .generic_util import trapezoid -------------------------------------------------------------------------------- /metrics/generic_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various utility functions for: 3 | - parsing user arguments. 4 | - computing the area under a curve. 5 | - generating a toy dataset to test the evaluation script. 6 | """ 7 | from bisect import bisect 8 | import numpy as np 9 | 10 | 11 | def trapezoid(x, y, x_max=None): 12 | """ 13 | This function calculates the definite integral of a curve given by 14 | x- and corresponding y-values. In contrast to, e.g., 'numpy.trapz()', 15 | this function allows one to define an upper bound to the integration range by 16 | setting a value x_max. 17 | 18 | Points that do not have a finite x or y value will be ignored with a 19 | warning. 20 | 21 | Args: 22 | x: Samples from the domain of the function to integrate 23 | Need to be sorted in ascending order. May contain the same value 24 | multiple times. In that case, the order of the corresponding 25 | y values will affect the integration with the trapezoidal rule. 26 | y: Values of the function corresponding to x values. 27 | x_max: Upper limit of the integration. The y value at x_max will be 28 | determined by interpolating between its neighbors. Must not lie 29 | outside of the range of x. 30 | 31 | Returns: 32 | Area under the curve. 33 | """ 34 | 35 | x = np.asarray(x) 36 | y = np.asarray(y) 37 | finite_mask = np.logical_and(np.isfinite(x), np.isfinite(y)) 38 | if not finite_mask.all(): 39 | print("WARNING: Not all x and y values passed to trapezoid(...)" 40 | " are finite. 
Will continue with only the finite values.") 41 | x = x[finite_mask] 42 | y = y[finite_mask] 43 | 44 | # Introduce a correction term if x_max is not an element of x. 45 | correction = 0. 46 | if x_max is not None: 47 | if x_max not in x: 48 | # Get the insertion index that would keep x sorted after 49 | # np.insert(x, ins, x_max). 50 | ins = bisect(x, x_max) 51 | # x_max must be between the minimum and the maximum, so the 52 | # insertion_point cannot be zero or len(x). 53 | assert 0 < ins < len(x) 54 | 55 | # Calculate the correction term which is the integral between 56 | # the last x[ins-1] and x_max. Since we do not know the exact value 57 | # of y at x_max, we interpolate between y[ins] and y[ins-1]. 58 | y_interp = y[ins - 1] + ((y[ins] - y[ins - 1]) * 59 | (x_max - x[ins - 1]) / 60 | (x[ins] - x[ins - 1])) 61 | correction = 0.5 * (y_interp + y[ins - 1]) * (x_max - x[ins - 1]) 62 | 63 | # Cut off at x_max. 64 | mask = x <= x_max 65 | x = x[mask] 66 | y = y[mask] 67 | 68 | # Return area under the curve using the trapezoidal rule. 69 | return np.sum(0.5 * (y[1:] + y[:-1]) * (x[1:] - x[:-1])) + correction 70 | 71 | def generate_toy_dataset(num_images, image_width, image_height, gt_size): 72 | """Generate a toy dataset to test the evaluation script. 73 | 74 | Args: 75 | num_images: Number of images that the toy dataset contains. 76 | image_width: Width of the dataset images in pixels. 77 | image_height: Height of the dataset images in pixels. 78 | gt_size: Size of rectangular ground truth regions that are 79 | artificially generated on the dataset images. 80 | 81 | Returns: 82 | anomaly_maps: List of numpy arrays that contain random anomaly maps. 83 | ground_truth_maps: Corresponding list of numpy arrays that specify a 84 | rectangular ground truth region. 85 | """ 86 | # Fix a random seed for reproducibility. 87 | np.random.seed(1338) 88 | 89 | # Create synthetic evaluation data with random anomaly scores and 90 | # simple ground truth maps. 91 | anomaly_maps = [] 92 | ground_truth_maps = [] 93 | for _ in range(num_images): 94 | # Sample a random anomaly map. 95 | anomaly_map = np.random.random((image_height, image_width)) 96 | 97 | # Construct a fixed ground truth map. 98 | ground_truth_map = np.zeros((image_height, image_width)) 99 | ground_truth_map[0:gt_size, 0:gt_size] = 1 100 | 101 | anomaly_maps.append(anomaly_map) 102 | ground_truth_maps.append(ground_truth_map) 103 | 104 | return anomaly_maps, ground_truth_maps -------------------------------------------------------------------------------- /metrics/pro_curve_util.py: -------------------------------------------------------------------------------- 1 | """Utility function that computes a PRO curve, given pairs of anomaly and ground 2 | truth maps. 3 | 4 | The PRO curve can also be integrated up to a constant integration limit. 5 | """ 6 | import numpy as np 7 | from scipy.ndimage import label 8 | 9 | 10 | def compute_pro(anomaly_maps, ground_truth_maps): 11 | """Compute the PRO curve for a set of anomaly maps with corresponding ground 12 | truth maps. 13 | 14 | Args: 15 | anomaly_maps: List of anomaly maps (2D numpy arrays) that contain a 16 | real-valued anomaly score at each pixel. 17 | 18 | ground_truth_maps: List of ground truth maps (2D numpy arrays) that 19 | contain binary-valued ground truth labels for each pixel. 20 | 0 indicates that a pixel is anomaly-free. 21 | 1 indicates that a pixel contains an anomaly. 22 | 23 | Returns: 24 | fprs: numpy array of false positive rates. 
25 | pros: numpy array of corresponding PRO values. 26 | """ 27 | 28 | print("Compute PRO curve...") 29 | 30 | # Structuring element for computing connected components. 31 | structure = np.ones((3, 3), dtype=int) 32 | 33 | num_ok_pixels = 0 34 | num_gt_regions = 0 35 | 36 | shape = (len(anomaly_maps), 37 | anomaly_maps[0].shape[0], 38 | anomaly_maps[0].shape[1]) 39 | fp_changes = np.zeros(shape, dtype=np.uint32) 40 | assert shape[0] * shape[1] * shape[2] < np.iinfo(fp_changes.dtype).max, \ 41 | 'Potential overflow when using np.cumsum(), consider using np.uint64.' 42 | 43 | pro_changes = np.zeros(shape, dtype=np.float64) 44 | 45 | for gt_ind, gt_map in enumerate(ground_truth_maps): 46 | 47 | # Compute the connected components in the ground truth map. 48 | labeled, n_components = label(gt_map, structure) 49 | num_gt_regions += n_components 50 | 51 | # Compute the mask that gives us all ok pixels. 52 | ok_mask = labeled == 0 53 | num_ok_pixels_in_map = np.sum(ok_mask) 54 | num_ok_pixels += num_ok_pixels_in_map 55 | 56 | # Compute by how much the FPR changes when each anomaly score is 57 | # added to the set of positives. 58 | # fp_change needs to be normalized later when we know the final value 59 | # of num_ok_pixels -> right now it is only the change in the number of 60 | # false positives 61 | fp_change = np.zeros_like(gt_map, dtype=fp_changes.dtype) 62 | fp_change[ok_mask] = 1 63 | 64 | # Compute by how much the PRO changes when each anomaly score is 65 | # added to the set of positives. 66 | # pro_change needs to be normalized later when we know the final value 67 | # of num_gt_regions. 68 | pro_change = np.zeros_like(gt_map, dtype=np.float64) 69 | for k in range(n_components): 70 | region_mask = labeled == (k + 1) 71 | region_size = np.sum(region_mask) 72 | pro_change[region_mask] = 1. / region_size 73 | 74 | fp_changes[gt_ind, :, :] = fp_change 75 | pro_changes[gt_ind, :, :] = pro_change 76 | 77 | # Flatten the numpy arrays before sorting. 78 | anomaly_scores_flat = np.array(anomaly_maps).ravel() 79 | fp_changes_flat = fp_changes.ravel() 80 | pro_changes_flat = pro_changes.ravel() 81 | 82 | # Sort all anomaly scores. 83 | print(f"Sort {len(anomaly_scores_flat)} anomaly scores...") 84 | sort_idxs = np.argsort(anomaly_scores_flat).astype(np.uint32)[::-1] 85 | 86 | # Info: np.take(a, ind, out=a) followed by b=a instead of 87 | # b=a[ind] showed to be more memory efficient. 88 | np.take(anomaly_scores_flat, sort_idxs, out=anomaly_scores_flat) 89 | anomaly_scores_sorted = anomaly_scores_flat 90 | np.take(fp_changes_flat, sort_idxs, out=fp_changes_flat) 91 | fp_changes_sorted = fp_changes_flat 92 | np.take(pro_changes_flat, sort_idxs, out=pro_changes_flat) 93 | pro_changes_sorted = pro_changes_flat 94 | 95 | del sort_idxs 96 | 97 | # Get the (FPR, PRO) curve values. 98 | np.cumsum(fp_changes_sorted, out=fp_changes_sorted) 99 | fp_changes_sorted = fp_changes_sorted.astype(np.float32, copy=False) 100 | np.divide(fp_changes_sorted, num_ok_pixels, out=fp_changes_sorted) 101 | fprs = fp_changes_sorted 102 | 103 | np.cumsum(pro_changes_sorted, out=pro_changes_sorted) 104 | np.divide(pro_changes_sorted, num_gt_regions, out=pro_changes_sorted) 105 | pros = pro_changes_sorted 106 | 107 | # Merge (FPR, PRO) points that occur together at the same threshold. 108 | # For those points, only the final (FPR, PRO) point should be kept. 109 | # That is because that point is the one that takes all changes 110 | # to the FPR and the PRO at the respective threshold into account. 
111 | # -> keep_mask is True if the subsequent score is different from the 112 | # score at the respective position; the final element is always kept. 113 | # anomaly_scores_sorted = [7, 4, 4, 4, 3, 1, 1] 114 | # -> keep_mask = [T, F, F, T, T, F, T] 115 | keep_mask = np.append(np.diff(anomaly_scores_sorted) != 0, np.True_) 116 | del anomaly_scores_sorted 117 | 118 | fprs = fprs[keep_mask] 119 | pros = pros[keep_mask] 120 | del keep_mask 121 | 122 | # To mitigate the adding up of numerical errors during the np.cumsum calls, 123 | # make sure that the curve ends at (1, 1) and does not contain values > 1. 124 | np.clip(fprs, a_min=None, a_max=1., out=fprs) 125 | np.clip(pros, a_min=None, a_max=1., out=pros) 126 | 127 | # Make the fprs and pros start at 0 and end at 1. 128 | zero = np.array([0.]) 129 | one = np.array([1.]) 130 | 131 | return np.concatenate((zero, fprs, one)), np.concatenate((zero, pros, one)) 132 | 133 | 134 | def main(): 135 | """ 136 | Compute the area under the PRO curve for a toy dataset and an algorithm 137 | that randomly assigns anomaly scores to each pixel. The integration 138 | limit can be specified. 139 | """ 140 | 141 | from generic_util import trapezoid, generate_toy_dataset 142 | 143 | integration_limit = 0.3 144 | 145 | # Generate a toy dataset. 146 | anomaly_maps, ground_truth_maps = generate_toy_dataset( 147 | num_images=200, image_width=500, image_height=300, gt_size=10) 148 | 149 | # Compute the PRO curve for this dataset. 150 | all_fprs, all_pros = compute_pro( 151 | anomaly_maps=anomaly_maps, 152 | ground_truth_maps=ground_truth_maps) 153 | 154 | au_pro = trapezoid(all_fprs, all_pros, x_max=integration_limit) 155 | au_pro /= integration_limit 156 | print(f"AU-PRO (FPR limit: {integration_limit}): {au_pro}") 157 | 158 | 159 | if __name__ == "__main__": 160 | main() -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .coordatt import * 2 | from .decoder import * 3 | from .memseg import * 4 | from .msff import * 5 | from .memory_module import * -------------------------------------------------------------------------------- /models/coordatt.py: -------------------------------------------------------------------------------- 1 | # https://github.com/houqb/CoordAttention/blob/main/coordatt.py 2 | 3 | import torch 4 | import torch.nn as nn 5 | import math 6 | import torch.nn.functional as F 7 | 8 | class h_sigmoid(nn.Module): 9 | def __init__(self, inplace=True): 10 | super(h_sigmoid, self).__init__() 11 | self.relu = nn.ReLU6(inplace=inplace) 12 | 13 | def forward(self, x): 14 | return self.relu(x + 3) / 6 15 | 16 | class h_swish(nn.Module): 17 | def __init__(self, inplace=True): 18 | super(h_swish, self).__init__() 19 | self.sigmoid = h_sigmoid(inplace=inplace) 20 | 21 | def forward(self, x): 22 | return x * self.sigmoid(x) 23 | 24 | 25 | class CoordAtt(nn.Module): 26 | def __init__(self, inp, oup, reduction=32): 27 | super(CoordAtt, self).__init__() 28 | self.pool_h = nn.AdaptiveAvgPool2d((None, 1)) 29 | self.pool_w = nn.AdaptiveAvgPool2d((1, None)) 30 | 31 | mip = max(8, inp // reduction) 32 | 33 | self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0) 34 | self.bn1 = nn.BatchNorm2d(mip) 35 | self.act = h_swish() 36 | 37 | self.conv_h = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0) 38 | self.conv_w = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0) 39 | 40 | 41 | def forward(self, x): 42 | identity = x 43 | 44 | n,c,h,w = 
x.size() 45 | x_h = self.pool_h(x) 46 | x_w = self.pool_w(x).permute(0, 1, 3, 2) 47 | 48 | y = torch.cat([x_h, x_w], dim=2) 49 | y = self.conv1(y) 50 | y = self.bn1(y) 51 | y = self.act(y) 52 | 53 | x_h, x_w = torch.split(y, [h, w], dim=2) 54 | x_w = x_w.permute(0, 1, 3, 2) 55 | 56 | a_h = self.conv_h(x_h).sigmoid() 57 | a_w = self.conv_w(x_w).sigmoid() 58 | 59 | out = identity * a_w * a_h 60 | 61 | return out -------------------------------------------------------------------------------- /models/decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class UpConvBlock(nn.Module): 6 | def __init__(self, in_channel, out_channel): 7 | super(UpConvBlock, self).__init__() 8 | self.blk = nn.Sequential( 9 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), 10 | nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=1, padding=1), 11 | nn.BatchNorm2d(out_channel), 12 | nn.ReLU() 13 | ) 14 | 15 | def forward(self, x): 16 | return self.blk(x) 17 | 18 | 19 | class Decoder(nn.Module): 20 | def __init__(self): 21 | super(Decoder, self).__init__() 22 | 23 | self.conv = nn.Conv2d(64, 48, kernel_size=3, stride=1, padding=1) 24 | 25 | self.upconv3 = UpConvBlock(512, 256) 26 | self.upconv2 = UpConvBlock(512, 128) 27 | self.upconv1 = UpConvBlock(256, 64) 28 | self.upconv0 = UpConvBlock(128, 48) 29 | self.upconv2mask = UpConvBlock(96, 48) 30 | 31 | self.final_conv = nn.Conv2d(48, 2, kernel_size=3, stride=1, padding=1) 32 | 33 | def forward(self, encoder_output, concat_features): 34 | # concat_features = [level0, level1, level2, level3] 35 | f0, f1, f2, f3 = concat_features 36 | 37 | # 512 x 8 x 8 -> 512 x 16 x 16 38 | x_up3 = self.upconv3(encoder_output) 39 | x_up3 = torch.cat([x_up3, f3], dim=1) 40 | 41 | # 512 x 16 x 16 -> 256 x 32 x 32 42 | x_up2 = self.upconv2(x_up3) 43 | x_up2 = torch.cat([x_up2, f2], dim=1) 44 | 45 | # 256 x 32 x 32 -> 128 x 64 x 64 46 | x_up1 = self.upconv1(x_up2) 47 | x_up1 = torch.cat([x_up1, f1], dim=1) 48 | 49 | # 128 x 64 x 64 -> 96 x 128 x 128 50 | x_up0 = self.upconv0(x_up1) 51 | f0 = self.conv(f0) 52 | x_up2mask = torch.cat([x_up0, f0], dim=1) 53 | 54 | # 96 x 128 x 128 -> 48 x 256 x 256 55 | x_mask = self.upconv2mask(x_up2mask) 56 | 57 | # 48 x 256 x 256 -> 2 x 256 x 256 58 | x_mask = self.final_conv(x_mask) 59 | 60 | return x_mask -------------------------------------------------------------------------------- /models/memory_module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | import numpy as np 5 | from typing import List 6 | 7 | 8 | class MemoryBank: 9 | def __init__(self, normal_dataset, nb_memory_sample: int = 30, device='cpu'): 10 | self.device = device 11 | 12 | # memory bank 13 | self.memory_information = {} 14 | 15 | # normal dataset 16 | self.normal_dataset = normal_dataset 17 | 18 | # the number of samples saved in memory bank 19 | self.nb_memory_sample = nb_memory_sample 20 | 21 | 22 | def update(self, feature_extractor): 23 | feature_extractor.eval() 24 | 25 | # define sample index 26 | samples_idx = np.arange(len(self.normal_dataset)) 27 | np.random.shuffle(samples_idx) 28 | 29 | # extract features and save features into memory bank 30 | with torch.no_grad(): 31 | for i in range(self.nb_memory_sample): 32 | # select image 33 | input_normal, _, _ = self.normal_dataset[samples_idx[i]] 34 | input_normal = input_normal.to(self.device) 35 | 36 | # extract 
features 37 | features = feature_extractor(input_normal.unsqueeze(0)) 38 | 39 | # save features into memory bank 40 | for l, features_l in enumerate(features[1:-1]): 41 | if f'level{l}' not in self.memory_information.keys(): 42 | self.memory_information[f'level{l}'] = features_l 43 | else: 44 | self.memory_information[f'level{l}'] = torch.cat([self.memory_information[f'level{l}'], features_l], dim=0) 45 | 46 | 47 | def _calc_diff(self, features: List[torch.Tensor]) -> torch.Tensor: 48 | # batch size X the number of samples saved in memory 49 | diff_bank = torch.zeros(features[0].size(0), self.nb_memory_sample).to(self.device) 50 | 51 | # level 52 | for l, level in enumerate(self.memory_information.keys()): 53 | # batch 54 | for b_idx, features_b in enumerate(features[l]): 55 | # calculate l2 loss 56 | diff = F.mse_loss( 57 | input = torch.repeat_interleave(features_b.unsqueeze(0), repeats=self.nb_memory_sample, dim=0), 58 | target = self.memory_information[level], 59 | reduction='none' 60 | ).mean(dim=[1,2,3]) 61 | 62 | # sum loss 63 | diff_bank[b_idx] += diff 64 | 65 | return diff_bank 66 | 67 | 68 | def select(self, features: List[torch.Tensor]) -> List[torch.Tensor]: 69 | # calculate difference between features and normal features of memory bank 70 | diff_bank = self._calc_diff(features=features) 71 | 72 | # concatenate features with minimum difference features of memory bank 73 | for l, level in enumerate(self.memory_information.keys()): 74 | 75 | selected_features = torch.index_select(self.memory_information[level], dim=0, index=diff_bank.argmin(dim=1)) 76 | diff_features = F.mse_loss(selected_features, features[l], reduction='none') 77 | features[l] = torch.cat([features[l], diff_features], dim=1) 78 | 79 | return features 80 | -------------------------------------------------------------------------------- /models/memseg.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .decoder import Decoder 3 | from .msff import MSFF 4 | 5 | class MemSeg(nn.Module): 6 | def __init__(self, memory_bank, feature_extractor): 7 | super(MemSeg, self).__init__() 8 | 9 | self.memory_bank = memory_bank 10 | self.feature_extractor = feature_extractor 11 | self.msff = MSFF() 12 | self.decoder = Decoder() 13 | 14 | def forward(self, inputs): 15 | # extract features 16 | features = self.feature_extractor(inputs) 17 | f_in = features[0] 18 | f_out = features[-1] 19 | f_ii = features[1:-1] 20 | 21 | # extract concatenated information (CI) 22 | concat_features = self.memory_bank.select(features = f_ii) 23 | 24 | # Multi-scale Feature Fusion (MSFF) Module 25 | msff_outputs = self.msff(features = concat_features) 26 | 27 | # decoder 28 | predicted_mask = self.decoder( 29 | encoder_output = f_out, 30 | concat_features = [f_in] + msff_outputs 31 | ) 32 | 33 | return predicted_mask 34 | -------------------------------------------------------------------------------- /models/msff.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | import torch.nn.functional as F 5 | from .coordatt import CoordAtt 6 | 7 | class MSFFBlock(nn.Module): 8 | def __init__(self, in_channel): 9 | super(MSFFBlock, self).__init__() 10 | self.conv1 = nn.Conv2d(in_channel, in_channel, kernel_size=3, stride=1, padding=1) 11 | self.attn = CoordAtt(in_channel, in_channel) 12 | self.conv2 = nn.Sequential( 13 | nn.Conv2d(in_channel, in_channel // 2, kernel_size=3, stride=1, padding=1), 14 | 
nn.Conv2d(in_channel // 2, in_channel // 2, kernel_size=3, stride=1, padding=1) 15 | ) 16 | 17 | def forward(self, x): 18 | x_conv = self.conv1(x) 19 | x_att = self.attn(x) 20 | 21 | x = x_conv * x_att 22 | x = self.conv2(x) 23 | return x 24 | 25 | 26 | class MSFF(nn.Module): 27 | def __init__(self): 28 | super(MSFF, self).__init__() 29 | self.blk1 = MSFFBlock(128) 30 | self.blk2 = MSFFBlock(256) 31 | self.blk3 = MSFFBlock(512) 32 | 33 | self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 34 | 35 | self.upconv32 = nn.Sequential( 36 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), 37 | nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1) 38 | ) 39 | self.upconv21 = nn.Sequential( 40 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), 41 | nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1) 42 | ) 43 | 44 | def forward(self, features): 45 | # features = [level1, level2, level3] 46 | f1, f2, f3 = features 47 | 48 | # MSFF Module 49 | f1_k = self.blk1(f1) 50 | f2_k = self.blk2(f2) 51 | f3_k = self.blk3(f3) 52 | 53 | f2_f = f2_k + self.upconv32(f3_k) 54 | f1_f = f1_k + self.upconv21(f2_f) 55 | 56 | # spatial attention 57 | 58 | # mask 59 | m3 = f3[:,256:,...].mean(dim=1, keepdim=True) 60 | m2 = f2[:,128:,...].mean(dim=1, keepdim=True) * self.upsample(m3) 61 | m1 = f1[:,64:,...].mean(dim=1, keepdim=True) * self.upsample(m2) 62 | 63 | f1_out = f1_f * m1 64 | f2_out = f2_f * m2 65 | f3_out = f3_k * m3 66 | 67 | return [f1_out, f2_out, f3_out] -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | einops==0.5.0 2 | timm==0.5.4 3 | wandb==0.12.17 4 | omegaconf 5 | imgaug==0.4.0 6 | ipywidgets==8.1.1 7 | voila==0.5.5 -------------------------------------------------------------------------------- /saved_model/MemSeg-bottle/best_score.json: -------------------------------------------------------------------------------- 1 | { 2 | "best_step": 4599, 3 | "eval_AUROC-image": 1.0, 4 | "eval_AUROC-pixel": 0.9878949119786599, 5 | "eval_AUPRO-pixel": 0.9835623941012489 6 | } -------------------------------------------------------------------------------- /saved_model/MemSeg-bottle/latest_score.json: -------------------------------------------------------------------------------- 1 | { 2 | "latest_step": 5000, 3 | "eval_AUROC-image": 1.0, 4 | "eval_AUROC-pixel": 0.9810207179490311, 5 | "eval_AUPRO-pixel": 0.9756835825062008 6 | } -------------------------------------------------------------------------------- /saved_model/MemSeg-cable/best_score.json: -------------------------------------------------------------------------------- 1 | { 2 | "best_step": 2999, 3 | "eval_AUROC-image": 0.9250374812593704, 4 | "eval_AUROC-pixel": 0.8230074651568189, 5 | "eval_AUPRO-pixel": 0.8731255778978297 6 | } -------------------------------------------------------------------------------- /saved_model/MemSeg-cable/latest_score.json: -------------------------------------------------------------------------------- 1 | { 2 | "latest_step": 5000, 3 | "eval_AUROC-image": 0.8967391304347826, 4 | "eval_AUROC-pixel": 0.7628859940117106, 5 | "eval_AUPRO-pixel": 0.8339052111761991 6 | } -------------------------------------------------------------------------------- /saved_model/MemSeg-capsule/best_score.json: -------------------------------------------------------------------------------- 1 | { 2 | "best_step": 4499, 3 | 
"eval_AUROC-image": 0.9541284403669725, 4 | "eval_AUROC-pixel": 0.9842618218962903, 5 | "eval_AUPRO-pixel": 0.9772760890870018 6 | } -------------------------------------------------------------------------------- /saved_model/MemSeg-capsule/latest_score.json: -------------------------------------------------------------------------------- 1 | { 2 | "latest_step": 5000, 3 | "eval_AUROC-image": 0.9509373753490227, 4 | "eval_AUROC-pixel": 0.9818320950550422, 5 | "eval_AUPRO-pixel": 0.9762463379799712 6 | } -------------------------------------------------------------------------------- /saved_model/MemSeg-carpet/best_score.json: -------------------------------------------------------------------------------- 1 | { 2 | "best_step": 999, 3 | "eval_AUROC-image": 0.9911717495987158, 4 | "eval_AUROC-pixel": 0.9754415202769691, 5 | "eval_AUPRO-pixel": 0.9702277436742398 6 | } -------------------------------------------------------------------------------- /saved_model/MemSeg-carpet/latest_score.json: -------------------------------------------------------------------------------- 1 | { 2 | "latest_step": 5000, 3 | "eval_AUROC-image": 0.9225521669341894, 4 | "eval_AUROC-pixel": 0.9521577575948166, 5 | "eval_AUPRO-pixel": 0.9304082803819657 6 | } -------------------------------------------------------------------------------- /saved_model/MemSeg-grid/best_score.json: -------------------------------------------------------------------------------- 1 | { 2 | "best_step": 2299, 3 | "eval_AUROC-image": 0.9983291562238931, 4 | "eval_AUROC-pixel": 0.9836655532727973, 5 | "eval_AUPRO-pixel": 0.9853033944553282 6 | } -------------------------------------------------------------------------------- /saved_model/MemSeg-grid/latest_score.json: -------------------------------------------------------------------------------- 1 | { 2 | "latest_step": 5000, 3 | "eval_AUROC-image": 0.9949874686716792, 4 | "eval_AUROC-pixel": 0.9839286697331651, 5 | "eval_AUPRO-pixel": 0.9811414483214054 6 | } -------------------------------------------------------------------------------- /saved_model/MemSeg-hazelnut/best_score.json: -------------------------------------------------------------------------------- 1 | { 2 | "best_step": 2799, 3 | "eval_AUROC-image": 1.0, 4 | "eval_AUROC-pixel": 0.977846774892756, 5 | "eval_AUPRO-pixel": 0.9899949505024724 6 | } -------------------------------------------------------------------------------- /saved_model/MemSeg-hazelnut/latest_score.json: -------------------------------------------------------------------------------- 1 | { 2 | "latest_step": 5000, 3 | "eval_AUROC-image": 1.0, 4 | "eval_AUROC-pixel": 0.9530533678779027, 5 | "eval_AUPRO-pixel": 0.9864527018411502 6 | } -------------------------------------------------------------------------------- /saved_model/MemSeg-leather/best_score.json: -------------------------------------------------------------------------------- 1 | { 2 | "best_step": 2499, 3 | "eval_AUROC-image": 1.0, 4 | "eval_AUROC-pixel": 0.9882677082218918, 5 | "eval_AUPRO-pixel": 0.9908964012921226 6 | } -------------------------------------------------------------------------------- /saved_model/MemSeg-leather/latest_score.json: -------------------------------------------------------------------------------- 1 | { 2 | "latest_step": 5000, 3 | "eval_AUROC-image": 1.0, 4 | "eval_AUROC-pixel": 0.9806788880434785, 5 | "eval_AUPRO-pixel": 0.9771980890520273 6 | } -------------------------------------------------------------------------------- 
/saved_model/MemSeg-metal_nut/best_score.json: -------------------------------------------------------------------------------- 1 | { 2 | "best_step": 3599, 3 | "eval_AUROC-image": 0.9946236559139785, 4 | "eval_AUROC-pixel": 0.8848478801101993, 5 | "eval_AUPRO-pixel": 0.9500360784835721 6 | } -------------------------------------------------------------------------------- /saved_model/MemSeg-metal_nut/latest_score.json: -------------------------------------------------------------------------------- 1 | { 2 | "latest_step": 5000, 3 | "eval_AUROC-image": 0.9912023460410557, 4 | "eval_AUROC-pixel": 0.8702097697878765, 5 | "eval_AUPRO-pixel": 0.9522523369617932 6 | } -------------------------------------------------------------------------------- /saved_model/MemSeg-pill/best_score.json: -------------------------------------------------------------------------------- 1 | { 2 | "best_step": 3899, 3 | "eval_AUROC-image": 0.9705400981996726, 4 | "eval_AUROC-pixel": 0.9829100041961665, 5 | "eval_AUPRO-pixel": 0.9796020986582216 6 | } -------------------------------------------------------------------------------- /saved_model/MemSeg-pill/latest_score.json: -------------------------------------------------------------------------------- 1 | { 2 | "latest_step": 5000, 3 | "eval_AUROC-image": 0.9735406437534097, 4 | "eval_AUROC-pixel": 0.9837363970095163, 5 | "eval_AUPRO-pixel": 0.9726783368058662 6 | } -------------------------------------------------------------------------------- /saved_model/MemSeg-screw/best_score.json: -------------------------------------------------------------------------------- 1 | { 2 | "best_step": 2099, 3 | "eval_AUROC-image": 0.9485550317688051, 4 | "eval_AUROC-pixel": 0.9508182505179927, 5 | "eval_AUPRO-pixel": 0.9400450867663136 6 | } -------------------------------------------------------------------------------- /saved_model/MemSeg-screw/latest_score.json: -------------------------------------------------------------------------------- 1 | { 2 | "latest_step": 5000, 3 | "eval_AUROC-image": 0.9321582291453167, 4 | "eval_AUROC-pixel": 0.9421926937096599, 5 | "eval_AUPRO-pixel": 0.9239317290959074 6 | } -------------------------------------------------------------------------------- /saved_model/MemSeg-tile/best_score.json: -------------------------------------------------------------------------------- 1 | { 2 | "best_step": 2599, 3 | "eval_AUROC-image": 0.9985569985569985, 4 | "eval_AUROC-pixel": 0.9938479608847095, 5 | "eval_AUPRO-pixel": 0.9880707645224925 6 | } -------------------------------------------------------------------------------- /saved_model/MemSeg-tile/latest_score.json: -------------------------------------------------------------------------------- 1 | { 2 | "latest_step": 5000, 3 | "eval_AUROC-image": 0.9967532467532467, 4 | "eval_AUROC-pixel": 0.9954440823852293, 5 | "eval_AUPRO-pixel": 0.9881914197355043 6 | } -------------------------------------------------------------------------------- /saved_model/MemSeg-toothbrush/best_score.json: -------------------------------------------------------------------------------- 1 | { 2 | "best_step": 3499, 3 | "eval_AUROC-image": 1.0, 4 | "eval_AUROC-pixel": 0.9928249709185054, 5 | "eval_AUPRO-pixel": 0.985574884437306 6 | } -------------------------------------------------------------------------------- /saved_model/MemSeg-toothbrush/latest_score.json: -------------------------------------------------------------------------------- 1 | { 2 | "latest_step": 5000, 3 | "eval_AUROC-image": 1.0, 4 | 
"eval_AUROC-pixel": 0.9930413807016041, 5 | "eval_AUPRO-pixel": 0.9831682986954344 6 | } -------------------------------------------------------------------------------- /saved_model/MemSeg-transistor/best_score.json: -------------------------------------------------------------------------------- 1 | { 2 | "best_step": 4199, 3 | "eval_AUROC-image": 0.965, 4 | "eval_AUROC-pixel": 0.7629001885386938, 5 | "eval_AUPRO-pixel": 0.8606269093140095 6 | } -------------------------------------------------------------------------------- /saved_model/MemSeg-transistor/latest_score.json: -------------------------------------------------------------------------------- 1 | { 2 | "latest_step": 5000, 3 | "eval_AUROC-image": 0.9566666666666667, 4 | "eval_AUROC-pixel": 0.7424347738574396, 5 | "eval_AUPRO-pixel": 0.8319123925254669 6 | } -------------------------------------------------------------------------------- /saved_model/MemSeg-wood/best_score.json: -------------------------------------------------------------------------------- 1 | { 2 | "best_step": 3799, 3 | "eval_AUROC-image": 1.0, 4 | "eval_AUROC-pixel": 0.9753957059774258, 5 | "eval_AUPRO-pixel": 0.9762236672467449 6 | } -------------------------------------------------------------------------------- /saved_model/MemSeg-wood/latest_score.json: -------------------------------------------------------------------------------- 1 | { 2 | "latest_step": 5000, 3 | "eval_AUROC-image": 0.999122807017544, 4 | "eval_AUROC-pixel": 0.972050242335017, 5 | "eval_AUPRO-pixel": 0.9735705883386424 6 | } -------------------------------------------------------------------------------- /saved_model/MemSeg-zipper/best_score.json: -------------------------------------------------------------------------------- 1 | { 2 | "best_step": 4899, 3 | "eval_AUROC-image": 0.9994747899159664, 4 | "eval_AUROC-pixel": 0.9794035590859894, 5 | "eval_AUPRO-pixel": 0.9725967686762359 6 | } -------------------------------------------------------------------------------- /saved_model/MemSeg-zipper/latest_score.json: -------------------------------------------------------------------------------- 1 | { 2 | "latest_step": 5000, 3 | "eval_AUROC-image": 0.9994747899159664, 4 | "eval_AUROC-pixel": 0.9773868547463116, 5 | "eval_AUPRO-pixel": 0.9692512286487627 6 | } -------------------------------------------------------------------------------- /scheduler.py: -------------------------------------------------------------------------------- 1 | # https://github.com/katsura-jp/pytorch-cosine-annealing-with-warmup/blob/master/cosine_annealing_warmup/scheduler.py 2 | 3 | import math 4 | import torch 5 | from torch.optim.lr_scheduler import _LRScheduler 6 | 7 | class CosineAnnealingWarmupRestarts(_LRScheduler): 8 | """ 9 | optimizer (Optimizer): Wrapped optimizer. 10 | first_cycle_steps (int): First cycle step size. 11 | cycle_mult(float): Cycle steps magnification. Default: -1. 12 | max_lr(float): First cycle's max learning rate. Default: 0.1. 13 | min_lr(float): Min learning rate. Default: 0.001. 14 | warmup_steps(int): Linear warmup step size. Default: 0. 15 | gamma(float): Decrease rate of max learning rate by cycle. Default: 1. 16 | last_epoch (int): The index of last epoch. Default: -1. 
17 | """ 18 | 19 | def __init__(self, 20 | optimizer : torch.optim.Optimizer, 21 | first_cycle_steps : int, 22 | cycle_mult : float = 1., 23 | max_lr : float = 0.1, 24 | min_lr : float = 0.001, 25 | warmup_steps : int = 0, 26 | gamma : float = 1., 27 | last_epoch : int = -1 28 | ): 29 | assert warmup_steps < first_cycle_steps 30 | 31 | self.first_cycle_steps = first_cycle_steps # first cycle step size 32 | self.cycle_mult = cycle_mult # cycle steps magnification 33 | self.base_max_lr = max_lr # first max learning rate 34 | self.max_lr = max_lr # max learning rate in the current cycle 35 | self.min_lr = min_lr # min learning rate 36 | self.warmup_steps = warmup_steps # warmup step size 37 | self.gamma = gamma # decrease rate of max learning rate by cycle 38 | 39 | self.cur_cycle_steps = first_cycle_steps # first cycle step size 40 | self.cycle = 0 # cycle count 41 | self.step_in_cycle = last_epoch # step size of the current cycle 42 | 43 | super(CosineAnnealingWarmupRestarts, self).__init__(optimizer, last_epoch) 44 | 45 | # set learning rate min_lr 46 | self.init_lr() 47 | 48 | def init_lr(self): 49 | self.base_lrs = [] 50 | for param_group in self.optimizer.param_groups: 51 | param_group['lr'] = self.min_lr 52 | self.base_lrs.append(self.min_lr) 53 | 54 | def get_lr(self): 55 | if self.step_in_cycle == -1: 56 | return self.base_lrs 57 | elif self.step_in_cycle < self.warmup_steps: 58 | return [(self.max_lr - base_lr)*self.step_in_cycle / self.warmup_steps + base_lr for base_lr in self.base_lrs] 59 | else: 60 | return [base_lr + (self.max_lr - base_lr) \ 61 | * (1 + math.cos(math.pi * (self.step_in_cycle-self.warmup_steps) \ 62 | / (self.cur_cycle_steps - self.warmup_steps))) / 2 63 | for base_lr in self.base_lrs] 64 | 65 | def step(self, epoch=None): 66 | if epoch is None: 67 | epoch = self.last_epoch + 1 68 | self.step_in_cycle = self.step_in_cycle + 1 69 | if self.step_in_cycle >= self.cur_cycle_steps: 70 | self.cycle += 1 71 | self.step_in_cycle = self.step_in_cycle - self.cur_cycle_steps 72 | self.cur_cycle_steps = int((self.cur_cycle_steps - self.warmup_steps) * self.cycle_mult) + self.warmup_steps 73 | else: 74 | if epoch >= self.first_cycle_steps: 75 | if self.cycle_mult == 1.: 76 | self.step_in_cycle = epoch % self.first_cycle_steps 77 | self.cycle = epoch // self.first_cycle_steps 78 | else: 79 | n = int(math.log((epoch / self.first_cycle_steps * (self.cycle_mult - 1) + 1), self.cycle_mult)) 80 | self.cycle = n 81 | self.step_in_cycle = epoch - int(self.first_cycle_steps * (self.cycle_mult ** n - 1) / (self.cycle_mult - 1)) 82 | self.cur_cycle_steps = self.first_cycle_steps * self.cycle_mult ** (n) 83 | else: 84 | self.cur_cycle_steps = self.first_cycle_steps 85 | self.step_in_cycle = epoch 86 | 87 | self.max_lr = self.base_max_lr * (self.gamma**self.cycle) 88 | self.last_epoch = math.floor(epoch) 89 | for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): 90 | param_group['lr'] = lr -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import time 2 | import json 3 | import os 4 | import wandb 5 | import logging 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | import numpy as np 10 | from typing import List 11 | from sklearn.metrics import roc_auc_score 12 | from metrics import compute_pro, trapezoid 13 | 14 | _logger = logging.getLogger('train') 15 | 16 | class AverageMeter: 17 | """Computes and stores the average and current 
value""" 18 | def __init__(self): 19 | self.reset() 20 | 21 | def reset(self): 22 | self.val = 0 23 | self.avg = 0 24 | self.sum = 0 25 | self.count = 0 26 | 27 | def update(self, val, n=1): 28 | self.val = val 29 | self.sum += val * n 30 | self.count += n 31 | self.avg = self.sum / self.count 32 | 33 | 34 | 35 | def training(model, trainloader, validloader, criterion, optimizer, scheduler, num_training_steps: int = 1000, loss_weights: List[float] = [0.6, 0.4], 36 | log_interval: int = 1, eval_interval: int = 1, savedir: str = None, use_wandb: bool = False, device: str ='cpu') -> dict: 37 | 38 | batch_time_m = AverageMeter() 39 | data_time_m = AverageMeter() 40 | losses_m = AverageMeter() 41 | l1_losses_m = AverageMeter() 42 | focal_losses_m = AverageMeter() 43 | 44 | # criterion 45 | l1_criterion, focal_criterion = criterion 46 | l1_weight, focal_weight = loss_weights 47 | 48 | # set train mode 49 | model.train() 50 | 51 | # set optimizer 52 | optimizer.zero_grad() 53 | 54 | # training 55 | best_score = 0 56 | step = 0 57 | train_mode = True 58 | while train_mode: 59 | 60 | end = time.time() 61 | for inputs, masks, targets in trainloader: 62 | # batch 63 | inputs, masks, targets = inputs.to(device), masks.to(device), targets.to(device) 64 | 65 | data_time_m.update(time.time() - end) 66 | 67 | # predict 68 | outputs = model(inputs) 69 | outputs = F.softmax(outputs, dim=1) 70 | l1_loss = l1_criterion(outputs[:,1,:], masks) 71 | focal_loss = focal_criterion(outputs, masks) 72 | loss = (l1_weight * l1_loss) + (focal_weight * focal_loss) 73 | 74 | loss.backward() 75 | 76 | # update weight 77 | optimizer.step() 78 | optimizer.zero_grad() 79 | 80 | # log loss 81 | l1_losses_m.update(l1_loss.item()) 82 | focal_losses_m.update(focal_loss.item()) 83 | losses_m.update(loss.item()) 84 | 85 | batch_time_m.update(time.time() - end) 86 | 87 | # wandb 88 | if use_wandb: 89 | wandb.log({ 90 | 'lr':optimizer.param_groups[0]['lr'], 91 | 'train_focal_loss':focal_losses_m.val, 92 | 'train_l1_loss':l1_losses_m.val, 93 | 'train_loss':losses_m.val 94 | }, 95 | step=step) 96 | 97 | if (step+1) % log_interval == 0 or step == 0: 98 | _logger.info('TRAIN [{:>4d}/{}] ' 99 | 'Loss: {loss.val:>6.4f} ({loss.avg:>6.4f}) ' 100 | 'L1 Loss: {l1_loss.val:>6.4f} ({l1_loss.avg:>6.4f}) ' 101 | 'Focal Loss: {focal_loss.val:>6.4f} ({focal_loss.avg:>6.4f}) ' 102 | 'LR: {lr:.3e} ' 103 | 'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) ' 104 | 'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format( 105 | step+1, num_training_steps, 106 | loss = losses_m, 107 | l1_loss = l1_losses_m, 108 | focal_loss = focal_losses_m, 109 | lr = optimizer.param_groups[0]['lr'], 110 | batch_time = batch_time_m, 111 | rate = inputs.size(0) / batch_time_m.val, 112 | rate_avg = inputs.size(0) / batch_time_m.avg, 113 | data_time = data_time_m)) 114 | 115 | 116 | if ((step+1) % eval_interval == 0 and step != 0) or (step+1) == num_training_steps: 117 | eval_metrics = evaluate( 118 | model = model, 119 | dataloader = validloader, 120 | device = device 121 | ) 122 | model.train() 123 | 124 | eval_log = dict([(f'eval_{k}', v) for k, v in eval_metrics.items()]) 125 | 126 | # wandb 127 | if use_wandb: 128 | wandb.log(eval_log, step=step) 129 | 130 | # checkpoint 131 | if best_score < np.mean(list(eval_metrics.values())): 132 | # save best score 133 | state = {'best_step':step} 134 | state.update(eval_log) 135 | json.dump(state, open(os.path.join(savedir, 'best_score.json'),'w'), indent='\t') 136 | 137 | # save best model 
138 | torch.save(model.state_dict(), os.path.join(savedir, 'best_model.pt')) 139 | 140 | _logger.info('Best Score {0:.3%} to {1:.3%}'.format(best_score, np.mean(list(eval_metrics.values())))) 141 | 142 | best_score = np.mean(list(eval_metrics.values())) 143 | 144 | # scheduler 145 | if scheduler: 146 | scheduler.step() 147 | 148 | end = time.time() 149 | 150 | step += 1 151 | 152 | if step == num_training_steps: 153 | train_mode = False 154 | break 155 | 156 | # print best score and step 157 | _logger.info('Best Metric: {0:.3%} (step {1})'.format(best_score, state['best_step'])) 158 | 159 | # save latest model 160 | torch.save(model.state_dict(), os.path.join(savedir, 'latest_model.pt')) 161 | 162 | # save latest score 163 | state = {'latest_step':step} 164 | state.update(eval_log) 165 | json.dump(state, open(os.path.join(savedir, 'latest_score.json'),'w'), indent='\t') 166 | 167 | 168 | 169 | 170 | def evaluate(model, dataloader, device: str = 'cpu'): 171 | # targets and outputs 172 | image_targets = [] 173 | image_masks = [] 174 | anomaly_score = [] 175 | anomaly_map = [] 176 | 177 | model.eval() 178 | with torch.no_grad(): 179 | for idx, (inputs, masks, targets) in enumerate(dataloader): 180 | inputs, masks, targets = inputs.to(device), masks.to(device), targets.to(device) 181 | 182 | # predict 183 | outputs = model(inputs) 184 | outputs = F.softmax(outputs, dim=1) 185 | anomaly_score_i = torch.topk(torch.flatten(outputs[:,1,:], start_dim=1), 100)[0].mean(dim=1) 186 | 187 | # stack targets and outputs 188 | image_targets.extend(targets.cpu().tolist()) 189 | image_masks.extend(masks.cpu().numpy()) 190 | 191 | anomaly_score.extend(anomaly_score_i.cpu().tolist()) 192 | anomaly_map.extend(outputs[:,1,:].cpu().numpy()) 193 | 194 | # metrics 195 | image_masks = np.array(image_masks) 196 | anomaly_map = np.array(anomaly_map) 197 | 198 | auroc_image = roc_auc_score(image_targets, anomaly_score) 199 | auroc_pixel = roc_auc_score(image_masks.reshape(-1).astype(int), anomaly_map.reshape(-1)) 200 | all_fprs, all_pros = compute_pro( 201 | anomaly_maps = anomaly_map, 202 | ground_truth_maps = image_masks 203 | ) 204 | aupro = trapezoid(all_fprs, all_pros) 205 | 206 | metrics = { 207 | 'AUROC-image':auroc_image, 208 | 'AUROC-pixel':auroc_pixel, 209 | 'AUPRO-pixel':aupro 210 | 211 | } 212 | 213 | _logger.info('TEST: AUROC-image: %.3f%% | AUROC-pixel: %.3f%% | AUPRO-pixel: %.3f%%' % 214 | (metrics['AUROC-image'] * 100, metrics['AUROC-pixel'] * 100, metrics['AUPRO-pixel'] * 100)) 215 | 216 | 217 | return metrics 218 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import numpy as np 4 | import os 5 | 6 | def torch_seed(random_seed): 7 | torch.manual_seed(random_seed) 8 | torch.cuda.manual_seed(random_seed) 9 | torch.cuda.manual_seed_all(random_seed) # if using multi-GPU 10 | # CUDA randomness 11 | torch.backends.cudnn.deterministic = True 12 | torch.backends.cudnn.benchmark = False 13 | 14 | np.random.seed(random_seed) 15 | random.seed(random_seed) 16 | os.environ['PYTHONHASHSEED'] = str(random_seed) --------------------------------------------------------------------------------
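How the modules above fit together at inference time: the sketch below is illustrative, not the repository's actual entry point (`main.py`). The `resnet18` backbone loaded through timm's `features_only=True` API is an assumption — the backbone set in `configs.yaml` is not reproduced above, but this choice yields the 64/64/128/256/512-channel feature pyramid that `MSFF` and `Decoder` expect — and `RandomNormalDataset` is a hypothetical stand-in for the repo's MVTec dataset, which yields `(image, mask, target)` tuples.

```
import timm
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset

from models import MemSeg, MemoryBank

class RandomNormalDataset(Dataset):
    # Hypothetical stand-in for the MVTec dataset: yields (image, mask, target).
    def __len__(self):
        return 30
    def __getitem__(self, idx):
        return torch.randn(3, 256, 256), torch.zeros(256, 256), 0

device = 'cpu'

# Frozen multi-scale feature extractor (assumed backbone, see note above).
feature_extractor = timm.create_model('resnet18', pretrained=True, features_only=True).to(device)

# Fill the memory bank with features of normal samples, then assemble MemSeg.
memory_bank = MemoryBank(normal_dataset=RandomNormalDataset(), nb_memory_sample=30, device=device)
memory_bank.update(feature_extractor=feature_extractor)
model = MemSeg(memory_bank, feature_extractor).to(device).eval()

with torch.no_grad():
    inputs = torch.randn(4, 3, 256, 256).to(device)
    outputs = F.softmax(model(inputs), dim=1)   # (B, 2, 256, 256)
    anomaly_map = outputs[:, 1, :]              # pixel-level anomaly scores
    # Image-level score as in evaluate(): mean of the 100 highest pixel scores.
    anomaly_score = torch.topk(torch.flatten(anomaly_map, start_dim=1), 100)[0].mean(dim=1)

print(anomaly_score.shape)  # torch.Size([4])
```

The top-100 pooling used for the image-level score keeps it robust to a few spuriously high pixels while remaining sensitive to small defects.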