├── .gitignore ├── ACKNOWLEDGEMENTS ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── constant.py ├── data ├── __init__.py ├── data_utils.py └── mvtec_dataset.py ├── eval.py ├── model ├── __init__.py ├── destseg.py ├── losses.py ├── metrics.py └── model_utils.py ├── requirements.txt ├── scripts └── download_dataset.sh └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # custom 132 | saved_model 133 | datasets 134 | .out 135 | logs 136 | 137 | .vscode 138 | -------------------------------------------------------------------------------- /ACKNOWLEDGEMENTS: -------------------------------------------------------------------------------- 1 | Acknowledgements 2 | Portions of this Software may utilize the following copyrighted 3 | material, the use of which is hereby acknowledged. 
4 | _____________________ 5 | 6 | Zavrtanik, Vitjan and Kristan, Matej and Skocaj, Danijel 7 | Copyright (c) 2021 VitjanZ 8 | 9 | Permission is hereby granted, free of charge, to any person obtaining a copy 10 | of this software and associated documentation files (the "Software"), to deal 11 | in the Software without restriction, including without limitation the rights 12 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 13 | copies of the Software, and to permit persons to whom the Software is 14 | furnished to do so, subject to the following conditions: 15 | 16 | The above copyright notice and this permission notice shall be included in all 17 | copies or substantial portions of the Software. 18 | 19 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 20 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 21 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 22 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 23 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 24 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 25 | SOFTWARE. -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to making participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 
45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the open source team at [opensource-conduct@group.apple.com](mailto:opensource-conduct@group.apple.com). All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org), version 1.4, 71 | available at [https://www.contributor-covenant.org/version/1/4/code-of-conduct.html](https://www.contributor-covenant.org/version/1/4/code-of-conduct.html) -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contribution Guide 2 | 3 | Thanks for your interest in contributing. This project was released to accompany a research paper for purposes of reproducibility, and beyond its publication there are limited plans for future development of the repository. 4 | 5 | While we welcome new pull requests and issues, please note that our response may be limited. Forks and out-of-tree improvements are strongly encouraged. 6 | 7 | ## Before you get started 8 | 9 | By submitting a pull request, you represent that you have the right to license your contribution to Apple and the community, and agree by submitting the patch that your contributions are licensed under the [LICENSE](LICENSE). 10 | 11 | We ask that all community members read and observe our [Code of Conduct](CODE_OF_CONDUCT.md). -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (C) 2022 Apple Inc. All Rights Reserved. 2 | 3 | IMPORTANT: This Apple software is supplied to you by Apple 4 | Inc. ("Apple") in consideration of your agreement to the following 5 | terms, and your use, installation, modification or redistribution of 6 | this Apple software constitutes acceptance of these terms. If you do 7 | not agree with these terms, please do not use, install, modify or 8 | redistribute this Apple software.
9 | 10 | In consideration of your agreement to abide by the following terms, and 11 | subject to these terms, Apple grants you a personal, non-exclusive 12 | license, under Apple's copyrights in this original Apple software (the 13 | "Apple Software"), to use, reproduce, modify and redistribute the Apple 14 | Software, with or without modifications, in source and/or binary forms; 15 | provided that if you redistribute the Apple Software in its entirety and 16 | without modifications, you must retain this notice and the following 17 | text and disclaimers in all such redistributions of the Apple Software. 18 | Neither the name, trademarks, service marks or logos of Apple Inc. may 19 | be used to endorse or promote products derived from the Apple Software 20 | without specific prior written permission from Apple. Except as 21 | expressly stated in this notice, no other rights or licenses, express or 22 | implied, are granted by Apple herein, including but not limited to any 23 | patent rights that may be infringed by your derivative works or by other 24 | works in which the Apple Software may be incorporated. 25 | 26 | The Apple Software is provided by Apple on an "AS IS" basis. APPLE 27 | MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION 28 | THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS 29 | FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND 30 | OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS. 31 | 32 | IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL 33 | OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 34 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 35 | INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION, 36 | MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED 37 | AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE), 38 | STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE 39 | POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DeSTSeg 2 | 3 | Official PyTorch implementation of [DeSTSeg](https://openaccess.thecvf.com/content/CVPR2023/html/Zhang_DeSTSeg_Segmentation_Guided_Denoising_Student-Teacher_for_Anomaly_Detection_CVPR_2023_paper.html) - CVPR 2023 4 | ## Datasets 5 | 6 | We use the MVTec AD dataset for experiments. To simulate anomalous images, the Describable Textures Dataset (DTD) is also adopted in our work. Users can run the **download_dataset.sh** script to download both datasets directly. 7 | 8 | ``` 9 | ./scripts/download_dataset.sh 10 | ``` 11 | 12 | ## Installation 13 | 14 | Please install the dependency packages with **pip** using the following command: 15 | 16 | ``` 17 | pip install -r requirements.txt 18 | ``` 19 | 20 | ## Training and Testing 21 | 22 | To get started, users can run the following command to train the model on all categories of the MVTec AD dataset: 23 | 24 | ``` 25 | python train.py --gpu_id 0 --num_workers 16 26 | ``` 27 | 28 | Users can also customize the default training parameters by overriding arguments such as `--bs`, `--lr_de_st`, `--lr_res`, `--lr_seghead`, `--steps`, `--de_st_steps`, `--eval_per_steps`, `--log_per_steps`, `--gamma` and `--T`.
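For example, to train with a smaller batch size and a longer schedule, these arguments can be combined as below (the values shown are purely illustrative and are not the defaults used in the paper):

```
python train.py --gpu_id 0 --num_workers 16 --bs 16 --lr_de_st 0.2 --steps 10000 --de_st_steps 2000
```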
29 | 30 | To specify the training categories and the corresponding data augmentation strategies, please add the argument `--custom_training_category` and then add the categories after the arguments `--no_rotation_category`, `--slight_rotation_category` and `--rotation_category`. For example, to train the `screw` category and the `tile` category with no data augmentation strategy, just run the following command: 31 | 32 | ``` 33 | python train.py --gpu_id 0 --num_workers 16 --custom_training_category --no_rotation_category screw tile 34 | ``` 35 | 36 | To test the performance of the model, users can run the following command: 37 | 38 | ``` 39 | python eval.py --gpu_id 0 --num_workers 16 40 | ``` 41 | 42 | ## Pretrained Checkpoints 43 | 44 | Download pretrained checkpoints [here](https://www.icloud.com.cn/iclouddrive/0a3OPg_3wcMs38yDpWNRnRW9Q#saved%5Fmodel) and put the checkpoints under `/saved_model/`. 45 | 46 | ## Citation 47 | 48 | ``` 49 | @inproceedings{zhang2023destseg, 50 | title={DeSTSeg: Segmentation Guided Denoising Student-Teacher for Anomaly Detection}, 51 | author={Zhang, Xuan and Li, Shiyu and Li, Xi and Huang, Ping and Shan, Jiulong and Chen, Ting}, 52 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 53 | pages={3914--3923}, 54 | year={2023} 55 | } 56 | ``` 57 | -------------------------------------------------------------------------------- /constant.py: -------------------------------------------------------------------------------- 1 | ALL_CATEGORY = [ 2 | "capsule", 3 | "metal_nut", 4 | "pill", 5 | "toothbrush", 6 | "transistor", 7 | "wood", 8 | "zipper", 9 | "cable", 10 | "bottle", 11 | "grid", 12 | "hazelnut", 13 | "leather", 14 | "tile", 15 | "carpet", 16 | "screw", 17 | ] 18 | RESIZE_SHAPE = [256, 256] # width * height 19 | NORMALIZE_MEAN = [0.485, 0.456, 0.406] 20 | NORMALIZE_STD = [0.229, 0.224, 0.225] 21 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import data_utils, mvtec_dataset 2 | -------------------------------------------------------------------------------- /data/data_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import cv2 4 | import imgaug.augmenters as iaa 5 | import numpy as np 6 | import torch 7 | 8 | """The scripts here are copied from DRAEM: https://github.com/VitjanZ/DRAEM""" 9 | 10 | 11 | def lerp_np(x, y, w): 12 | fin_out = (y - x) * w + x 13 | return fin_out 14 | 15 | 16 | def rand_perlin_2d_np( 17 | shape, res, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3 18 | ): 19 | delta = (res[0] / shape[0], res[1] / shape[1]) 20 | d = (shape[0] // res[0], shape[1] // res[1]) 21 | grid = np.mgrid[0 : res[0] : delta[0], 0 : res[1] : delta[1]].transpose(1, 2, 0) % 1 22 | 23 | angles = 2 * math.pi * np.random.rand(res[0] + 1, res[1] + 1) 24 | gradients = np.stack((np.cos(angles), np.sin(angles)), axis=-1) 25 | tt = np.repeat(np.repeat(gradients, d[0], axis=0), d[1], axis=1) 26 | 27 | tile_grads = lambda slice1, slice2: cv2.resize( 28 | np.repeat( 29 | np.repeat( 30 | gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]], d[0], axis=0 31 | ), 32 | d[1], 33 | axis=1, 34 | ), 35 | dsize=(shape[1], shape[0]), 36 | ) 37 | dot = lambda grad, shift: ( 38 | np.stack( 39 | ( 40 | grid[: shape[0], : shape[1], 0] + shift[0], 41 | grid[: shape[0], : shape[1], 1] + shift[1], 42 | ), 43 | axis=-1, 44 | ) 45 | * grad[: shape[0], : shape[1]] 46 | ).sum(axis=-1) 47 | 48 | n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0]) 49 | n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0]) 50 | n01 = dot(tile_grads([0, -1], [1, None]), [0, -1]) 51 | n11 = dot(tile_grads([1, None], [1, None]), [-1, -1]) 52 | t = fade(grid[: shape[0], : shape[1]]) 53 | return math.sqrt(2) * lerp_np( 54 | lerp_np(n00, n10, t[..., 0]), lerp_np(n01, n11, t[..., 0]), t[..., 1] 55 | ) 56 | 57 | 58 | rot = iaa.Sequential([iaa.Affine(rotate=(-90, 90))]) 59 | 60 | 61 | def perlin_noise(image, dtd_image, aug_prob=1.0): 62 | image = np.array(image, dtype=np.float32) 63 | dtd_image = np.array(dtd_image, dtype=np.float32) 64 | shape = image.shape[:2] 65 | min_perlin_scale, max_perlin_scale = 0, 6 66 | t_x = torch.randint(min_perlin_scale, max_perlin_scale, (1,)).numpy()[0] 67 | t_y = torch.randint(min_perlin_scale, max_perlin_scale, (1,)).numpy()[0] 68 | perlin_scalex, perlin_scaley = 2**t_x, 2**t_y 69 | 70 | perlin_noise = rand_perlin_2d_np(shape, (perlin_scalex, perlin_scaley)) 71 | 72 | perlin_noise = rot(images=perlin_noise) 73 | perlin_noise = np.expand_dims(perlin_noise, axis=2) 74 | threshold = 0.5 75 | perlin_thr = np.where( 76 | perlin_noise > threshold, 77 | np.ones_like(perlin_noise), 78 | np.zeros_like(perlin_noise), 79 | ) 80 | 81 | img_thr = dtd_image * perlin_thr / 255.0 82 | image = image / 255.0 83 | 84 | beta = torch.rand(1).numpy()[0] * 0.8 85 | image_aug = ( 86 | image * (1 - perlin_thr) + (1 - beta) * img_thr + beta * image * (perlin_thr) 87 | ) 88 | image_aug = image_aug.astype(np.float32) 89 | 90 | no_anomaly = torch.rand(1).numpy()[0] 91 | 92 | if no_anomaly > aug_prob: 93 | return image, np.zeros_like(perlin_thr) 94 | else: 95 | msk = (perlin_thr).astype(np.float32) 96 | msk = msk.transpose(2, 0, 1) 97 | 98 | return image_aug, msk 99 | -------------------------------------------------------------------------------- /data/mvtec_dataset.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | 4 | import numpy as np 5 | import torch 6 | 
import torchvision.transforms as transforms 7 | from PIL import Image 8 | from torch.utils.data import Dataset 9 | 10 | from data.data_utils import perlin_noise 11 | 12 | 13 | class MVTecDataset(Dataset): 14 | def __init__( 15 | self, 16 | is_train, 17 | mvtec_dir, 18 | resize_shape=[256, 256], 19 | normalize_mean=[0.485, 0.456, 0.406], 20 | normalize_std=[0.229, 0.224, 0.225], 21 | dtd_dir=None, 22 | rotate_90=False, 23 | random_rotate=0, 24 | ): 25 | super().__init__() 26 | self.resize_shape = resize_shape 27 | self.is_train = is_train 28 | if is_train: 29 | self.mvtec_paths = sorted(glob.glob(mvtec_dir + "/*.png")) 30 | self.dtd_paths = sorted(glob.glob(dtd_dir + "/*/*.jpg")) 31 | self.rotate_90 = rotate_90 32 | self.random_rotate = random_rotate 33 | else: 34 | self.mvtec_paths = sorted(glob.glob(mvtec_dir + "/*/*.png")) 35 | self.mask_preprocessing = transforms.Compose( 36 | [ 37 | transforms.ToTensor(), 38 | transforms.Resize( 39 | size=(self.resize_shape[1], self.resize_shape[0]), 40 | interpolation=transforms.InterpolationMode.BILINEAR, 41 | antialias=True, 42 | ), 43 | ] 44 | ) 45 | self.final_preprocessing = transforms.Compose( 46 | [ 47 | transforms.ToTensor(), 48 | transforms.Normalize(normalize_mean, normalize_std), 49 | ] 50 | ) 51 | 52 | def __len__(self): 53 | return len(self.mvtec_paths) 54 | 55 | def __getitem__(self, index): 56 | image = Image.open(self.mvtec_paths[index]).convert("RGB") 57 | image = image.resize(self.resize_shape, Image.BILINEAR) 58 | 59 | if self.is_train: 60 | dtd_index = torch.randint(0, len(self.dtd_paths), (1,)).item() 61 | dtd_image = Image.open(self.dtd_paths[dtd_index]).convert("RGB") 62 | dtd_image = dtd_image.resize(self.resize_shape, Image.BILINEAR) 63 | 64 | fill_color = (114, 114, 114) 65 | # rotate_90 66 | if self.rotate_90: 67 | degree = np.random.choice(np.array([0, 90, 180, 270])) 68 | image = image.rotate( 69 | degree, fillcolor=fill_color, resample=Image.BILINEAR 70 | ) 71 | # random_rotate 72 | if self.random_rotate > 0: 73 | degree = np.random.uniform(-self.random_rotate, self.random_rotate) 74 | image = image.rotate( 75 | degree, fillcolor=fill_color, resample=Image.BILINEAR 76 | ) 77 | 78 | # perlin_noise implementation 79 | aug_image, aug_mask = perlin_noise(image, dtd_image, aug_prob=1.0) 80 | aug_image = self.final_preprocessing(aug_image) 81 | 82 | image = self.final_preprocessing(image) 83 | return {"img_aug": aug_image, "img_origin": image, "mask": aug_mask} 84 | else: 85 | image = self.final_preprocessing(image) 86 | dir_path, file_name = os.path.split(self.mvtec_paths[index]) 87 | base_dir = os.path.basename(dir_path) 88 | if base_dir == "good": 89 | mask = torch.zeros_like(image[:1]) 90 | else: 91 | mask_path = os.path.join(dir_path, "../../ground_truth/") 92 | mask_path = os.path.join(mask_path, base_dir) 93 | mask_file_name = file_name.split(".")[0] + "_mask.png" 94 | mask_path = os.path.join(mask_path, mask_file_name) 95 | mask = Image.open(mask_path) 96 | mask = self.mask_preprocessing(mask) 97 | mask = torch.where( 98 | mask < 0.5, torch.zeros_like(mask), torch.ones_like(mask) 99 | ) 100 | return {"img": image, "mask": mask} 101 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | import warnings 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | from tensorboardX import SummaryWriter 9 | from torch.utils.data import DataLoader 10 | 
from torchmetrics import AUROC, AveragePrecision 11 | 12 | from constant import RESIZE_SHAPE, NORMALIZE_MEAN, NORMALIZE_STD, ALL_CATEGORY 13 | from data.mvtec_dataset import MVTecDataset 14 | from model.destseg import DeSTSeg 15 | from model.metrics import AUPRO, IAPS 16 | 17 | warnings.filterwarnings("ignore") 18 | 19 | 20 | def evaluate(args, category, model, visualizer, global_step=0): 21 | model.eval() 22 | with torch.no_grad(): 23 | dataset = MVTecDataset( 24 | is_train=False, 25 | mvtec_dir=args.mvtec_path + category + "/test/", 26 | resize_shape=RESIZE_SHAPE, 27 | normalize_mean=NORMALIZE_MEAN, 28 | normalize_std=NORMALIZE_STD, 29 | ) 30 | dataloader = DataLoader( 31 | dataset, batch_size=args.bs, shuffle=False, num_workers=args.num_workers 32 | ) 33 | de_st_IAPS = IAPS().cuda() 34 | de_st_AUPRO = AUPRO().cuda() 35 | de_st_AUROC = AUROC().cuda() 36 | de_st_AP = AveragePrecision().cuda() 37 | de_st_detect_AUROC = AUROC().cuda() 38 | seg_IAPS = IAPS().cuda() 39 | seg_AUPRO = AUPRO().cuda() 40 | seg_AUROC = AUROC().cuda() 41 | seg_AP = AveragePrecision().cuda() 42 | seg_detect_AUROC = AUROC().cuda() 43 | 44 | for _, sample_batched in enumerate(dataloader): 45 | img = sample_batched["img"].cuda() 46 | mask = sample_batched["mask"].to(torch.int64).cuda() 47 | 48 | output_segmentation, output_de_st, output_de_st_list = model(img) 49 | 50 | output_segmentation = F.interpolate( 51 | output_segmentation, 52 | size=mask.size()[2:], 53 | mode="bilinear", 54 | align_corners=False, 55 | ) 56 | output_de_st = F.interpolate( 57 | output_de_st, size=mask.size()[2:], mode="bilinear", align_corners=False 58 | ) 59 | 60 | mask_sample = torch.max(mask.view(mask.size(0), -1), dim=1)[0] 61 | output_segmentation_sample, _ = torch.sort( 62 | output_segmentation.view(output_segmentation.size(0), -1), 63 | dim=1, 64 | descending=True, 65 | ) 66 | output_segmentation_sample = torch.mean( 67 | output_segmentation_sample[:, : args.T], dim=1 68 | ) 69 | output_de_st_sample, _ = torch.sort( 70 | output_de_st.view(output_de_st.size(0), -1), dim=1, descending=True 71 | ) 72 | output_de_st_sample = torch.mean(output_de_st_sample[:, : args.T], dim=1) 73 | 74 | de_st_IAPS.update(output_de_st, mask) 75 | de_st_AUPRO.update(output_de_st, mask) 76 | de_st_AP.update(output_de_st.flatten(), mask.flatten()) 77 | de_st_AUROC.update(output_de_st.flatten(), mask.flatten()) 78 | de_st_detect_AUROC.update(output_de_st_sample, mask_sample) 79 | 80 | seg_IAPS.update(output_segmentation, mask) 81 | seg_AUPRO.update(output_segmentation, mask) 82 | seg_AP.update(output_segmentation.flatten(), mask.flatten()) 83 | seg_AUROC.update(output_segmentation.flatten(), mask.flatten()) 84 | seg_detect_AUROC.update(output_segmentation_sample, mask_sample) 85 | 86 | iap_de_st, iap90_de_st = de_st_IAPS.compute() 87 | aupro_de_st, ap_de_st, auc_de_st, auc_detect_de_st = ( 88 | de_st_AUPRO.compute(), 89 | de_st_AP.compute(), 90 | de_st_AUROC.compute(), 91 | de_st_detect_AUROC.compute(), 92 | ) 93 | iap_seg, iap90_seg = seg_IAPS.compute() 94 | aupro_seg, ap_seg, auc_seg, auc_detect_seg = ( 95 | seg_AUPRO.compute(), 96 | seg_AP.compute(), 97 | seg_AUROC.compute(), 98 | seg_detect_AUROC.compute(), 99 | ) 100 | 101 | visualizer.add_scalar("DeST_IAP", iap_de_st, global_step) 102 | visualizer.add_scalar("DeST_IAP90", iap90_de_st, global_step) 103 | visualizer.add_scalar("DeST_AUPRO", aupro_de_st, global_step) 104 | visualizer.add_scalar("DeST_AP", ap_de_st, global_step) 105 | visualizer.add_scalar("DeST_AUC", auc_de_st, global_step) 106 | 
visualizer.add_scalar("DeST_detect_AUC", auc_detect_de_st, global_step) 107 | 108 | visualizer.add_scalar("DeSTSeg_IAP", iap_seg, global_step) 109 | visualizer.add_scalar("DeSTSeg_IAP90", iap90_seg, global_step) 110 | visualizer.add_scalar("DeSTSeg_AUPRO", aupro_seg, global_step) 111 | visualizer.add_scalar("DeSTSeg_AP", ap_seg, global_step) 112 | visualizer.add_scalar("DeSTSeg_AUC", auc_seg, global_step) 113 | visualizer.add_scalar("DeSTSeg_detect_AUC", auc_detect_seg, global_step) 114 | 115 | print("Eval at step", global_step) 116 | print("================================") 117 | print("Denoising Student-Teacher (DeST)") 118 | print("pixel_AUC:", round(float(auc_de_st), 4)) 119 | print("pixel_AP:", round(float(ap_de_st), 4)) 120 | print("PRO:", round(float(aupro_de_st), 4)) 121 | print("image_AUC:", round(float(auc_detect_de_st), 4)) 122 | print("IAP:", round(float(iap_de_st), 4)) 123 | print("IAP90:", round(float(iap90_de_st), 4)) 124 | print() 125 | print("Segmentation Guided Denoising Student-Teacher (DeSTSeg)") 126 | print("pixel_AUC:", round(float(auc_seg), 4)) 127 | print("pixel_AP:", round(float(ap_seg), 4)) 128 | print("PRO:", round(float(aupro_seg), 4)) 129 | print("image_AUC:", round(float(auc_detect_seg), 4)) 130 | print("IAP:", round(float(iap_seg), 4)) 131 | print("IAP90:", round(float(iap90_seg), 4)) 132 | print() 133 | 134 | de_st_IAPS.reset() 135 | de_st_AUPRO.reset() 136 | de_st_AUROC.reset() 137 | de_st_AP.reset() 138 | de_st_detect_AUROC.reset() 139 | seg_IAPS.reset() 140 | seg_AUPRO.reset() 141 | seg_AUROC.reset() 142 | seg_AP.reset() 143 | seg_detect_AUROC.reset() 144 | 145 | 146 | def test(args, category): 147 | if not os.path.exists(args.log_path): 148 | os.makedirs(args.log_path) 149 | 150 | run_name = f"DeSTSeg_MVTec_test_{category}" 151 | if os.path.exists(os.path.join(args.log_path, run_name + "/")): 152 | shutil.rmtree(os.path.join(args.log_path, run_name + "/")) 153 | 154 | visualizer = SummaryWriter(log_dir=os.path.join(args.log_path, run_name + "/")) 155 | 156 | model = DeSTSeg(dest=True, ed=True).cuda() 157 | 158 | assert os.path.exists( 159 | os.path.join(args.checkpoint_path, args.base_model_name + category + ".pckl") 160 | ) 161 | model.load_state_dict( 162 | torch.load( 163 | os.path.join( 164 | args.checkpoint_path, args.base_model_name + category + ".pckl" 165 | ) 166 | ) 167 | ) 168 | 169 | evaluate(args, category, model, visualizer) 170 | 171 | 172 | if __name__ == "__main__": 173 | parser = argparse.ArgumentParser() 174 | 175 | parser.add_argument("--gpu_id", type=int, default=0) 176 | parser.add_argument("--num_workers", type=int, default=16) 177 | 178 | parser.add_argument("--mvtec_path", type=str, default="./datasets/mvtec/") 179 | parser.add_argument("--dtd_path", type=str, default="./datasets/dtd/images/") 180 | parser.add_argument("--checkpoint_path", type=str, default="./saved_model/") 181 | parser.add_argument("--base_model_name", type=str, default="DeSTSeg_MVTec_5000_") 182 | parser.add_argument("--log_path", type=str, default="./logs/") 183 | 184 | parser.add_argument("--bs", type=int, default=32) 185 | parser.add_argument("--T", type=int, default=100) # for image-level inference 186 | 187 | parser.add_argument("--category", nargs="*", type=str, default=ALL_CATEGORY) 188 | args = parser.parse_args() 189 | 190 | obj_list = args.category 191 | for obj in obj_list: 192 | assert obj in ALL_CATEGORY 193 | 194 | with torch.cuda.device(args.gpu_id): 195 | for obj in obj_list: 196 | print(obj) 197 | test(args, obj) 198 | 
-------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from . import destseg, metrics, losses, model_utils 2 | -------------------------------------------------------------------------------- /model/destseg.py: -------------------------------------------------------------------------------- 1 | import timm 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from model.model_utils import ASPP, BasicBlock, l2_normalize, make_layer 7 | 8 | 9 | class TeacherNet(nn.Module): 10 | def __init__(self): 11 | super().__init__() 12 | self.encoder = timm.create_model( 13 | "resnet18", 14 | pretrained=True, 15 | features_only=True, 16 | out_indices=[1, 2, 3], 17 | ) 18 | # freeze teacher model 19 | for param in self.parameters(): 20 | param.requires_grad = False 21 | 22 | def forward(self, x): 23 | self.eval() 24 | x1, x2, x3 = self.encoder(x) 25 | return (x1, x2, x3) 26 | 27 | 28 | class StudentNet(nn.Module): 29 | def __init__(self, ed=True): 30 | super().__init__() 31 | self.ed = ed 32 | if self.ed: 33 | self.decoder_layer4 = make_layer(BasicBlock, 512, 512, 2) 34 | self.decoder_layer3 = make_layer(BasicBlock, 512, 256, 2) 35 | self.decoder_layer2 = make_layer(BasicBlock, 256, 128, 2) 36 | self.decoder_layer1 = make_layer(BasicBlock, 128, 64, 2) 37 | 38 | for m in self.modules(): 39 | if isinstance(m, nn.Conv2d): 40 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 41 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 42 | nn.init.constant_(m.weight, 1) 43 | nn.init.constant_(m.bias, 0) 44 | 45 | self.encoder = timm.create_model( 46 | "resnet18", 47 | pretrained=False, 48 | features_only=True, 49 | out_indices=[1, 2, 3, 4], 50 | ) 51 | 52 | def forward(self, x): 53 | x1, x2, x3, x4 = self.encoder(x) 54 | if not self.ed: 55 | return (x1, x2, x3) 56 | x = x4 57 | b4 = self.decoder_layer4(x) 58 | b3 = F.interpolate(b4, size=x3.size()[2:], mode="bilinear", align_corners=False) 59 | b3 = self.decoder_layer3(b3) 60 | b2 = F.interpolate(b3, size=x2.size()[2:], mode="bilinear", align_corners=False) 61 | b2 = self.decoder_layer2(b2) 62 | b1 = F.interpolate(b2, size=x1.size()[2:], mode="bilinear", align_corners=False) 63 | b1 = self.decoder_layer1(b1) 64 | return (b1, b2, b3) 65 | 66 | 67 | class SegmentationNet(nn.Module): 68 | def __init__(self, inplanes=448): 69 | super().__init__() 70 | self.res = make_layer(BasicBlock, inplanes, 256, 2) 71 | 72 | for m in self.modules(): 73 | if isinstance(m, nn.Conv2d): 74 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 75 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 76 | nn.init.constant_(m.weight, 1) 77 | nn.init.constant_(m.bias, 0) 78 | 79 | self.head = nn.Sequential( 80 | ASPP(256, 256, [6, 12, 18]), 81 | nn.Conv2d(256, 256, 3, padding=1, bias=False), 82 | nn.BatchNorm2d(256), 83 | nn.ReLU(inplace=True), 84 | nn.Conv2d(256, 1, 1), 85 | ) 86 | 87 | def forward(self, x): 88 | x = self.res(x) 89 | x = self.head(x) 90 | x = torch.sigmoid(x) 91 | return x 92 | 93 | 94 | class DeSTSeg(nn.Module): 95 | def __init__(self, dest=True, ed=True): 96 | super().__init__() 97 | self.teacher_net = TeacherNet() 98 | self.student_net = StudentNet(ed) 99 | self.dest = dest 100 | self.segmentation_net = SegmentationNet(inplanes=448) 101 | 102 | def forward(self, img_aug, img_origin=None): 103 | self.teacher_net.eval() 104 | 105 | if img_origin is None: # for 
inference 106 | img_origin = img_aug.clone() 107 | 108 | outputs_teacher_aug = [ 109 | l2_normalize(output_t.detach()) for output_t in self.teacher_net(img_aug) 110 | ] 111 | outputs_student_aug = [ 112 | l2_normalize(output_s) for output_s in self.student_net(img_aug) 113 | ] 114 | output = torch.cat( 115 | [ 116 | F.interpolate( 117 | -output_t * output_s, 118 | size=outputs_student_aug[0].size()[2:], 119 | mode="bilinear", 120 | align_corners=False, 121 | ) 122 | for output_t, output_s in zip(outputs_teacher_aug, outputs_student_aug) 123 | ], 124 | dim=1, 125 | ) 126 | 127 | output_segmentation = self.segmentation_net(output) 128 | 129 | if self.dest: 130 | outputs_student = outputs_student_aug 131 | else: 132 | outputs_student = [ 133 | l2_normalize(output_s) for output_s in self.student_net(img_origin) 134 | ] 135 | outputs_teacher = [ 136 | l2_normalize(output_t.detach()) for output_t in self.teacher_net(img_origin) 137 | ] 138 | 139 | output_de_st_list = [] 140 | for output_t, output_s in zip(outputs_teacher, outputs_student): 141 | a_map = 1 - torch.sum(output_s * output_t, dim=1, keepdim=True) 142 | output_de_st_list.append(a_map) 143 | output_de_st = torch.cat( 144 | [ 145 | F.interpolate( 146 | output_de_st_instance, 147 | size=outputs_student[0].size()[2:], 148 | mode="bilinear", 149 | align_corners=False, 150 | ) 151 | for output_de_st_instance in output_de_st_list 152 | ], 153 | dim=1, 154 | ) # [N, 3, H, W] 155 | output_de_st = torch.prod(output_de_st, dim=1, keepdim=True) 156 | 157 | return output_segmentation, output_de_st, output_de_st_list 158 | -------------------------------------------------------------------------------- /model/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def cosine_similarity_loss(output_de_st_list): 6 | loss = 0 7 | for instance in output_de_st_list: 8 | _, _, h, w = instance.shape 9 | loss += torch.sum(instance) / (h * w) 10 | return loss 11 | 12 | 13 | def focal_loss(inputs, targets, alpha=-1, gamma=4, reduction="mean"): 14 | inputs = inputs.float() 15 | targets = targets.float() 16 | ce_loss = F.binary_cross_entropy(inputs, targets, reduction="none") 17 | p_t = inputs * targets + (1 - inputs) * (1 - targets) 18 | loss = ce_loss * ((1 - p_t) ** gamma) 19 | 20 | if alpha >= 0: 21 | alpha_t = alpha * targets + (1 - alpha) * (1 - targets) 22 | loss = alpha_t * loss 23 | 24 | if reduction == "mean": 25 | loss = loss.mean() 26 | elif reduction == "sum": 27 | loss = loss.sum() 28 | 29 | return loss 30 | 31 | 32 | def l1_loss(inputs, targets, reduction="mean"): 33 | return F.l1_loss(inputs, targets, reduction=reduction) 34 | -------------------------------------------------------------------------------- /model/metrics.py: -------------------------------------------------------------------------------- 1 | from bisect import bisect_left 2 | from typing import Any, Callable, List, Optional, Tuple 3 | 4 | import numpy as np 5 | import torch 6 | from anomalib.utils.metrics.plotting_utils import plot_figure 7 | from anomalib.utils.metrics.pro import ( 8 | connected_components_cpu, 9 | connected_components_gpu, 10 | ) 11 | from matplotlib.figure import Figure 12 | from torch import Tensor 13 | from torchmetrics import Metric 14 | from torchmetrics.functional import auc, roc 15 | from torchmetrics.utilities.data import dim_zero_cat 16 | 17 | 18 | class AUPRO(Metric): 19 | """Area under per region overlap (AUPRO) Metric. 
Copy from anomalib: https://github.com/openvinotoolkit/anomalib""" 20 | 21 | is_differentiable: bool = False 22 | higher_is_better: Optional[bool] = None 23 | full_state_update: bool = False 24 | preds: List[Tensor] 25 | target: List[Tensor] 26 | 27 | def __init__( 28 | self, 29 | compute_on_step: bool = True, 30 | dist_sync_on_step: bool = False, 31 | process_group: Optional[Any] = None, 32 | dist_sync_fn: Callable = None, 33 | fpr_limit: float = 0.3, 34 | ) -> None: 35 | super().__init__( 36 | compute_on_step=compute_on_step, 37 | dist_sync_on_step=dist_sync_on_step, 38 | process_group=process_group, 39 | dist_sync_fn=dist_sync_fn, 40 | ) 41 | 42 | self.add_state( 43 | "preds", default=[], dist_reduce_fx="cat" 44 | ) # pylint: disable=not-callable 45 | self.add_state( 46 | "target", default=[], dist_reduce_fx="cat" 47 | ) # pylint: disable=not-callable 48 | self.register_buffer("fpr_limit", torch.tensor(fpr_limit)) 49 | 50 | def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore 51 | """Update state with new values. 52 | Args: 53 | preds (Tensor): predictions of the model 54 | target (Tensor): ground truth targets 55 | """ 56 | self.target.append(target) 57 | self.preds.append(preds) 58 | 59 | def _compute(self) -> Tuple[Tensor, Tensor]: 60 | """Compute the pro/fpr value-pairs until the fpr specified by self.fpr_limit. 61 | It leverages the fact that the overlap corresponds to the tpr, and thus computes the overall 62 | PRO curve by aggregating per-region tpr/fpr values produced by ROC-construction. 63 | Raises: 64 | ValueError: ValueError is raised if self.target doesn't conform with requirements imposed by kornia for 65 | connected component analysis. 66 | Returns: 67 | Tuple[Tensor, Tensor]: tuple containing final fpr and tpr values. 68 | """ 69 | target = dim_zero_cat(self.target) 70 | preds = dim_zero_cat(self.preds) 71 | 72 | # check and prepare target for labeling via kornia 73 | if target.min() < 0 or target.max() > 1: 74 | raise ValueError( 75 | ( 76 | f"kornia.contrib.connected_components expects input to lie in the interval [0, 1], but found " 77 | f"interval was [{target.min()}, {target.max()}]." 78 | ) 79 | ) 80 | # target = target.unsqueeze(1) # kornia expects N1HW format 81 | target = target.type(torch.float) # kornia expects FloatTensor 82 | if target.is_cuda: 83 | cca = connected_components_gpu(target) 84 | else: 85 | cca = connected_components_cpu(target) 86 | 87 | preds = preds.flatten() 88 | cca = cca.flatten() 89 | target = target.flatten() 90 | 91 | # compute the global fpr-size 92 | fpr: Tensor = roc(preds, target)[0] # only need fpr 93 | output_size = torch.where(fpr <= self.fpr_limit)[0].size(0) 94 | 95 | # compute the PRO curve by aggregating per-region tpr/fpr curves/values. 96 | tpr = torch.zeros(output_size, device=preds.device, dtype=torch.float) 97 | fpr = torch.zeros(output_size, device=preds.device, dtype=torch.float) 98 | new_idx = torch.arange(0, output_size, device=preds.device, dtype=torch.float) 99 | 100 | # Loop over the labels, computing per-region tpr/fpr curves, and aggregating them. 101 | # Note that, since the groundtruth is different for every all to `roc`, we also get 102 | # different/unique tpr/fpr curves (i.e. len(_fpr_idx) is different for every call). 103 | # We therefore need to resample per-region curves to a fixed sampling ratio (defined above). 
104 | labels = cca.unique()[1:] # 0 is background 105 | background = cca == 0 106 | _fpr: Tensor 107 | _tpr: Tensor 108 | for label in labels: 109 | interp: bool = False 110 | new_idx[-1] = output_size - 1 111 | mask = cca == label 112 | # Need to calculate label-wise roc on union of background & mask, as otherwise we wrongly consider other 113 | # label in labels as FPs. We also don't need to return the thresholds 114 | _fpr, _tpr = roc(preds[background | mask], mask[background | mask])[:-1] 115 | 116 | # catch edge-case where ROC only has fpr vals > self.fpr_limit 117 | if _fpr[_fpr <= self.fpr_limit].max() == 0: 118 | _fpr_limit = _fpr[_fpr > self.fpr_limit].min() 119 | else: 120 | _fpr_limit = self.fpr_limit 121 | 122 | _fpr_idx = torch.where(_fpr <= _fpr_limit)[0] 123 | # if computed roc curve is not specified sufficiently close to self.fpr_limit, 124 | # we include the closest higher tpr/fpr pair and linearly interpolate the tpr/fpr point at self.fpr_limit 125 | if not torch.allclose(_fpr[_fpr_idx].max(), self.fpr_limit): 126 | _tmp_idx = torch.searchsorted(_fpr, self.fpr_limit) 127 | _fpr_idx = torch.cat([_fpr_idx, _tmp_idx.unsqueeze_(0)]) 128 | _slope = 1 - ( 129 | (_fpr[_tmp_idx] - self.fpr_limit) 130 | / (_fpr[_tmp_idx] - _fpr[_tmp_idx - 1]) 131 | ) 132 | interp = True 133 | 134 | _fpr = _fpr[_fpr_idx] 135 | _tpr = _tpr[_fpr_idx] 136 | 137 | _fpr_idx = _fpr_idx.float() 138 | _fpr_idx /= _fpr_idx.max() 139 | _fpr_idx *= new_idx.max() 140 | 141 | if interp: 142 | # last point will be sampled at self.fpr_limit 143 | new_idx[-1] = _fpr_idx[-2] + ((_fpr_idx[-1] - _fpr_idx[-2]) * _slope) 144 | 145 | _tpr = self.interp1d(_fpr_idx, _tpr, new_idx) 146 | _fpr = self.interp1d(_fpr_idx, _fpr, new_idx) 147 | tpr += _tpr 148 | fpr += _fpr 149 | 150 | # Actually perform the averaging 151 | tpr /= labels.size(0) 152 | fpr /= labels.size(0) 153 | return fpr, tpr 154 | 155 | def compute(self) -> Tensor: 156 | """Fist compute PRO curve, then compute and scale area under the curve. 157 | Returns: 158 | Tensor: Value of the AUPRO metric 159 | """ 160 | fpr, tpr = self._compute() 161 | 162 | aupro = auc(fpr, tpr, reorder=True) 163 | aupro = aupro / fpr[-1] # normalize the area 164 | 165 | return aupro 166 | 167 | def generate_figure(self) -> Tuple[Figure, str]: 168 | """Generate a figure containing the PRO curve and the AUPRO. 169 | Returns: 170 | Tuple[Figure, str]: Tuple containing both the figure and the figure title to be used for logging 171 | """ 172 | fpr, tpr = self._compute() 173 | aupro = self.compute() 174 | 175 | xlim = (0.0, self.fpr_limit.detach_().cpu().numpy()) 176 | ylim = (0.0, 1.0) 177 | xlabel = "Global FPR" 178 | ylabel = "Averaged Per-Region TPR" 179 | loc = "lower right" 180 | title = "PRO" 181 | 182 | fig, _axis = plot_figure( 183 | fpr, tpr, aupro, xlim, ylim, xlabel, ylabel, loc, title 184 | ) 185 | 186 | return fig, "PRO" 187 | 188 | @staticmethod 189 | def interp1d(old_x: Tensor, old_y: Tensor, new_x: Tensor) -> Tensor: 190 | """Function to interpolate a 1D signal linearly to new sampling points. 191 | Args: 192 | old_x (Tensor): original 1-D x values (same size as y) 193 | old_y (Tensor): original 1-D y values (same size as x) 194 | new_x (Tensor): x-values where y should be interpolated at 195 | Returns: 196 | Tensor: y-values at corresponding new_x values. 
197 | """ 198 | 199 | # Compute slope 200 | eps = torch.finfo(old_y.dtype).eps 201 | slope = (old_y[1:] - old_y[:-1]) / (eps + (old_x[1:] - old_x[:-1])) 202 | 203 | # Prepare idx for linear interpolation 204 | idx = torch.searchsorted(old_x, new_x) 205 | 206 | # searchsorted looks for the index where the values must be inserted 207 | # to preserve order, but we actually want the preceeding index. 208 | idx -= 1 209 | # we clamp the index, because the number of intervals = old_x.size(0) -1, 210 | # and the left neighbour should hence be at most number of intervals -1, i.e. old_x.size(0) - 2 211 | idx = torch.clamp(idx, 0, old_x.size(0) - 2) 212 | 213 | # perform actual linear interpolation 214 | y_new = old_y[idx] + slope[idx] * (new_x - old_x[idx]) 215 | 216 | return y_new 217 | 218 | 219 | class IAPS(Metric): 220 | """Implementation of the instance average precision (IAP) score in our paper""" 221 | 222 | is_differentiable: bool = False 223 | higher_is_better: Optional[bool] = None 224 | full_state_update: bool = False 225 | preds: List[Tensor] 226 | target: List[Tensor] 227 | 228 | def __init__( 229 | self, 230 | compute_on_step: bool = True, 231 | dist_sync_on_step: bool = False, 232 | process_group: Optional[Any] = None, 233 | dist_sync_fn: Callable = None, 234 | ioi_thresh: float = 0.5, 235 | recall_thresh: float = 0.9, # the k% of the metric IAP@k in our paper 236 | ) -> None: 237 | super().__init__( 238 | compute_on_step=compute_on_step, 239 | dist_sync_on_step=dist_sync_on_step, 240 | process_group=process_group, 241 | dist_sync_fn=dist_sync_fn, 242 | ) 243 | 244 | self.add_state( 245 | "preds", default=[], dist_reduce_fx="cat" 246 | ) # pylint: disable=not-callable 247 | self.add_state( 248 | "target", default=[], dist_reduce_fx="cat" 249 | ) # pylint: disable=not-callable 250 | self.ioi_thresh = ioi_thresh 251 | self.recall_thresh = recall_thresh 252 | 253 | def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore 254 | """Update state with new values. 255 | Args: 256 | preds (Tensor): predictions of the model 257 | target (Tensor): ground truth targets 258 | """ 259 | self.target.append(target) 260 | self.preds.append(preds) 261 | 262 | def compute(self): 263 | target = dim_zero_cat(self.target) 264 | preds = dim_zero_cat(self.preds) 265 | 266 | # check and prepare target for labeling via kornia 267 | if target.min() < 0 or target.max() > 1: 268 | raise ValueError( 269 | ( 270 | f"kornia.contrib.connected_components expects input to lie in the interval [0, 1], but found " 271 | f"interval was [{target.min()}, {target.max()}]." 
272 | ) 273 | ) 274 | target = target.type(torch.float) # kornia expects FloatTensor 275 | if target.is_cuda: 276 | cca = connected_components_gpu(target) 277 | else: 278 | cca = connected_components_cpu(target) 279 | 280 | preds = preds.flatten() 281 | cca = cca.flatten() 282 | target = target.flatten() 283 | 284 | labels = cca.unique()[1:] 285 | 286 | ins_scores = [] 287 | 288 | for label in labels: 289 | mask = cca == label 290 | heatmap_ins, _ = preds[mask].sort(descending=True) 291 | ind = np.int64(self.ioi_thresh * len(heatmap_ins)) 292 | ins_scores.append(float(heatmap_ins[ind])) 293 | 294 | if len(ins_scores) == 0: 295 | raise Exception("gt_masks all zeros") 296 | 297 | ins_scores.sort() 298 | 299 | recall = [] 300 | precision = [] 301 | 302 | for i, score in enumerate(ins_scores): 303 | recall.append(1 - i / len(ins_scores)) 304 | tp = torch.sum(preds * target >= score) 305 | tpfp = torch.sum(preds >= score) 306 | precision.append(float(tp / tpfp)) 307 | 308 | for i in range(0, len(precision) - 1): 309 | precision[i + 1] = max(precision[i + 1], precision[i]) 310 | ap_score = sum(precision) / len(ins_scores) 311 | recall = recall[::-1] 312 | precision = precision[::-1] 313 | k = bisect_left(recall, self.recall_thresh) 314 | return ap_score, precision[k] 315 | -------------------------------------------------------------------------------- /model/model_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Optional 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch import Tensor 7 | 8 | 9 | def conv3x3( 10 | in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1 11 | ) -> nn.Conv2d: 12 | """3x3 convolution with padding""" 13 | return nn.Conv2d( 14 | in_planes, 15 | out_planes, 16 | kernel_size=3, 17 | stride=stride, 18 | padding=dilation, 19 | groups=groups, 20 | bias=False, 21 | dilation=dilation, 22 | ) 23 | 24 | 25 | def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: 26 | """1x1 convolution""" 27 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 28 | 29 | 30 | def make_layer(block, inplanes, planes, blocks, stride=1, norm_layer=None): 31 | if norm_layer is None: 32 | norm_layer = nn.BatchNorm2d 33 | downsample = None 34 | if stride != 1 or inplanes != planes * block.expansion: 35 | downsample = nn.Sequential( 36 | conv1x1(inplanes, planes * block.expansion, stride), 37 | norm_layer(planes * block.expansion), 38 | ) 39 | 40 | layers = [] 41 | layers.append(block(inplanes, planes, stride, downsample, norm_layer=norm_layer)) 42 | inplanes = planes * block.expansion 43 | for _ in range(1, blocks): 44 | layers.append(block(inplanes, planes, norm_layer=norm_layer)) 45 | 46 | return nn.Sequential(*layers) 47 | 48 | 49 | def l2_normalize(input, dim=1, eps=1e-12): 50 | denom = torch.sqrt(torch.sum(input**2, dim=dim, keepdim=True)) 51 | return input / (denom + eps) 52 | 53 | 54 | class BasicBlock(nn.Module): 55 | expansion: int = 1 56 | 57 | def __init__( 58 | self, 59 | inplanes: int, 60 | planes: int, 61 | stride: int = 1, 62 | downsample: Optional[nn.Module] = None, 63 | groups: int = 1, 64 | base_width: int = 64, 65 | dilation: int = 1, 66 | norm_layer: Optional[Callable[..., nn.Module]] = None, 67 | ) -> None: 68 | super().__init__() 69 | if norm_layer is None: 70 | norm_layer = nn.BatchNorm2d 71 | if groups != 1 or base_width != 64: 72 | raise ValueError("BasicBlock only supports 
groups=1 and base_width=64") 73 | if dilation > 1: 74 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 75 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 76 | self.conv1 = conv3x3(inplanes, planes, stride) 77 | self.bn1 = norm_layer(planes) 78 | self.relu = nn.ReLU(inplace=True) 79 | self.conv2 = conv3x3(planes, planes) 80 | self.bn2 = norm_layer(planes) 81 | self.downsample = downsample 82 | self.stride = stride 83 | 84 | def forward(self, x: Tensor) -> Tensor: 85 | identity = x 86 | 87 | out = self.conv1(x) 88 | out = self.bn1(out) 89 | out = self.relu(out) 90 | 91 | out = self.conv2(out) 92 | out = self.bn2(out) 93 | 94 | if self.downsample is not None: 95 | identity = self.downsample(x) 96 | 97 | out += identity 98 | out = self.relu(out) 99 | 100 | return out 101 | 102 | 103 | def get_norm_layer(norm: str): 104 | norm = { 105 | "BN": nn.BatchNorm2d, 106 | "LN": nn.LayerNorm, 107 | }[norm.upper()] 108 | return norm 109 | 110 | 111 | def get_act_layer(act: str): 112 | act = { 113 | "relu": nn.ReLU, 114 | "relu6": nn.ReLU6, 115 | "swish": nn.SiLU, 116 | "mish": nn.Mish, 117 | "leaky_relu": nn.LeakyReLU, 118 | "sigmoid": nn.Sigmoid, 119 | "gelu": nn.GELU, 120 | }[act.lower()] 121 | return act 122 | 123 | 124 | class ConvNormAct2d(nn.Module): 125 | def __init__( 126 | self, 127 | in_channels, 128 | out_channels, 129 | kernel_size, 130 | stride=1, 131 | padding="same", 132 | dilation=1, 133 | groups=1, 134 | conv_kwargs=None, 135 | norm_layer=None, 136 | norm_kwargs=None, 137 | act_layer=None, 138 | act_kwargs=None, 139 | ): 140 | super(ConvNormAct2d, self).__init__() 141 | 142 | conv_kwargs = {} 143 | if norm_layer: 144 | conv_kwargs["bias"] = False 145 | if padding == "same" and stride > 1: 146 | # if kernel_size is even, -1 is must 147 | padding = (kernel_size - 1) // 2 148 | 149 | self.conv = self._build_conv( 150 | in_channels, 151 | out_channels, 152 | kernel_size, 153 | stride, 154 | padding, 155 | dilation, 156 | groups, 157 | conv_kwargs, 158 | ) 159 | self.norm = None 160 | if norm_layer: 161 | norm_kwargs = {} 162 | self.norm = get_norm_layer(norm_layer)( 163 | num_features=out_channels, **norm_kwargs 164 | ) 165 | self.act = None 166 | if act_layer: 167 | act_kwargs = {} 168 | self.act = get_act_layer(act_layer)(**act_kwargs) 169 | 170 | def _build_conv( 171 | self, 172 | in_channels, 173 | out_channels, 174 | kernel_size, 175 | stride, 176 | padding, 177 | dilation, 178 | groups, 179 | conv_kwargs, 180 | ): 181 | return nn.Conv2d( 182 | in_channels=in_channels, 183 | out_channels=out_channels, 184 | kernel_size=kernel_size, 185 | stride=stride, 186 | padding=padding, 187 | dilation=dilation, 188 | groups=groups, 189 | **conv_kwargs, 190 | ) 191 | 192 | def forward(self, x): 193 | x = self.conv(x) 194 | if self.norm: 195 | x = self.norm(x) 196 | if self.act: 197 | x = self.act(x) 198 | return x 199 | 200 | 201 | class ASPP(nn.Module): 202 | def __init__(self, input_channels, output_channels, atrous_rates): 203 | super(ASPP, self).__init__() 204 | modules = [] 205 | modules.append( 206 | nn.Sequential( 207 | nn.AdaptiveAvgPool2d(1), 208 | ConvNormAct2d( 209 | input_channels, 210 | output_channels, 211 | kernel_size=1, 212 | norm_layer="BN", 213 | act_layer="RELU", 214 | ), 215 | ) 216 | ) 217 | for atrous_rate in atrous_rates: 218 | conv_norm_act = ConvNormAct2d 219 | modules.append( 220 | conv_norm_act( 221 | in_channels=input_channels, 222 | out_channels=output_channels, 223 | kernel_size=1 if atrous_rate == 1 else 3, 
224 | padding=0 if atrous_rate == 1 else atrous_rate, 225 | dilation=atrous_rate, 226 | norm_layer="BN", 227 | act_layer="RELU", 228 | ) 229 | ) 230 | 231 | self.aspp_feature_extractors = nn.ModuleList(modules) 232 | self.aspp_fusion_layer = ConvNormAct2d( 233 | (1 + len(atrous_rates)) * output_channels, 234 | output_channels, 235 | kernel_size=3, 236 | norm_layer="BN", 237 | act_layer="RELU", 238 | ) 239 | 240 | def forward(self, x): 241 | res = [] 242 | for aspp_feature_extractor in self.aspp_feature_extractors: 243 | res.append(aspp_feature_extractor(x)) 244 | res[0] = F.interpolate( 245 | input=res[0], size=x.shape[2:], mode="bilinear", align_corners=False 246 | ) # resize back after global-avg-pooling layer 247 | res = torch.cat(res, dim=1) 248 | res = self.aspp_fusion_layer(res) 249 | return res 250 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | anomalib==0.4.0 2 | imgaug==0.4.0 3 | matplotlib==3.7.1 4 | numpy==1.24.2 5 | opencv_python_headless==4.7.0.72 6 | Pillow==9.5.0 7 | tensorboardX==2.6 8 | timm==0.6.12 9 | torch==2.0.0 10 | torchmetrics==0.10.3 11 | torchvision==0.15.1 12 | -------------------------------------------------------------------------------- /scripts/download_dataset.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | mkdir datasets 4 | cd datasets 5 | # Download describable textures dataset 6 | wget https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz 7 | tar -xf dtd-r1.0.1.tar.gz 8 | rm dtd-r1.0.1.tar.gz 9 | 10 | mkdir mvtec 11 | cd mvtec 12 | # Download MVTec anomaly detection dataset 13 | wget https://www.mydrive.ch/shares/38536/3830184030e49fe74747669442f0f282/download/420938113-1629952094/mvtec_anomaly_detection.tar.xz 14 | tar -xf mvtec_anomaly_detection.tar.xz 15 | rm mvtec_anomaly_detection.tar.xz 16 | 17 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | import warnings 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | from tensorboardX import SummaryWriter 9 | from torch.utils.data import DataLoader 10 | 11 | from constant import RESIZE_SHAPE, NORMALIZE_MEAN, NORMALIZE_STD, ALL_CATEGORY 12 | from data.mvtec_dataset import MVTecDataset 13 | from eval import evaluate 14 | from model.destseg import DeSTSeg 15 | from model.losses import cosine_similarity_loss, focal_loss, l1_loss 16 | 17 | warnings.filterwarnings("ignore") 18 | 19 | 20 | def train(args, category, rotate_90=False, random_rotate=0): 21 | if not os.path.exists(args.checkpoint_path): 22 | os.makedirs(args.checkpoint_path) 23 | if not os.path.exists(args.log_path): 24 | os.makedirs(args.log_path) 25 | 26 | run_name = f"{args.run_name_head}_{args.steps}_{category}" 27 | if os.path.exists(os.path.join(args.log_path, run_name + "/")): 28 | shutil.rmtree(os.path.join(args.log_path, run_name + "/")) 29 | 30 | visualizer = SummaryWriter(log_dir=os.path.join(args.log_path, run_name + "/")) 31 | 32 | model = DeSTSeg(dest=True, ed=True).cuda() 33 | 34 | seg_optimizer = torch.optim.SGD( 35 | [ 36 | {"params": model.segmentation_net.res.parameters(), "lr": args.lr_res}, 37 | {"params": model.segmentation_net.head.parameters(), "lr": args.lr_seghead}, 38 | ], 39 | lr=0.001, 40 | momentum=0.9, 41 | 
weight_decay=1e-4, 42 | nesterov=False, 43 | ) 44 | de_st_optimizer = torch.optim.SGD( 45 | [ 46 | {"params": model.student_net.parameters(), "lr": args.lr_de_st}, 47 | ], 48 | lr=0.4, 49 | momentum=0.9, 50 | weight_decay=1e-4, 51 | nesterov=False, 52 | ) 53 | 54 | dataset = MVTecDataset( 55 | is_train=True, 56 | mvtec_dir=args.mvtec_path + category + "/train/good/", 57 | resize_shape=RESIZE_SHAPE, 58 | normalize_mean=NORMALIZE_MEAN, 59 | normalize_std=NORMALIZE_STD, 60 | dtd_dir=args.dtd_path, 61 | rotate_90=rotate_90, 62 | random_rotate=random_rotate, 63 | ) 64 | 65 | dataloader = DataLoader( 66 | dataset, 67 | batch_size=args.bs, 68 | shuffle=True, 69 | num_workers=args.num_workers, 70 | drop_last=True, 71 | ) 72 | 73 | global_step = 0 74 | 75 | flag = True 76 | 77 | while flag: 78 | for _, sample_batched in enumerate(dataloader): 79 | seg_optimizer.zero_grad() 80 | de_st_optimizer.zero_grad() 81 | img_origin = sample_batched["img_origin"].cuda() 82 | img_aug = sample_batched["img_aug"].cuda() 83 | mask = sample_batched["mask"].cuda() 84 | 85 | if global_step < args.de_st_steps: 86 | model.student_net.train() 87 | model.segmentation_net.eval() 88 | else: 89 | model.student_net.eval() 90 | model.segmentation_net.train() 91 | 92 | output_segmentation, output_de_st, output_de_st_list = model( 93 | img_aug, img_origin 94 | ) 95 | 96 | mask = F.interpolate( 97 | mask, 98 | size=output_segmentation.size()[2:], 99 | mode="bilinear", 100 | align_corners=False, 101 | ) 102 | mask = torch.where( 103 | mask < 0.5, torch.zeros_like(mask), torch.ones_like(mask) 104 | ) 105 | 106 | cosine_loss_val = cosine_similarity_loss(output_de_st_list) 107 | focal_loss_val = focal_loss(output_segmentation, mask, gamma=args.gamma) 108 | l1_loss_val = l1_loss(output_segmentation, mask) 109 | 110 | if global_step < args.de_st_steps: 111 | total_loss_val = cosine_loss_val 112 | total_loss_val.backward() 113 | de_st_optimizer.step() 114 | else: 115 | total_loss_val = focal_loss_val + l1_loss_val 116 | total_loss_val.backward() 117 | seg_optimizer.step() 118 | 119 | global_step += 1 120 | 121 | visualizer.add_scalar("cosine_loss", cosine_loss_val, global_step) 122 | visualizer.add_scalar("focal_loss", focal_loss_val, global_step) 123 | visualizer.add_scalar("l1_loss", l1_loss_val, global_step) 124 | visualizer.add_scalar("total_loss", total_loss_val, global_step) 125 | 126 | if global_step % args.eval_per_steps == 0: 127 | evaluate(args, category, model, visualizer, global_step) 128 | 129 | if global_step % args.log_per_steps == 0: 130 | if global_step < args.de_st_steps: 131 | print( 132 | f"Training at global step {global_step}, cosine loss: {round(float(cosine_loss_val), 4)}" 133 | ) 134 | else: 135 | print( 136 | f"Training at global step {global_step}, focal loss: {round(float(focal_loss_val), 4)}, l1 loss: {round(float(l1_loss_val), 4)}" 137 | ) 138 | 139 | if global_step >= args.steps: 140 | flag = False 141 | break 142 | 143 | torch.save( 144 | model.state_dict(), os.path.join(args.checkpoint_path, run_name + ".pckl") 145 | ) 146 | 147 | 148 | if __name__ == "__main__": 149 | parser = argparse.ArgumentParser() 150 | 151 | parser.add_argument("--gpu_id", type=int, default=0) 152 | parser.add_argument("--num_workers", type=int, default=16) 153 | 154 | parser.add_argument("--mvtec_path", type=str, default="./datasets/mvtec/") 155 | parser.add_argument("--dtd_path", type=str, default="./datasets/dtd/images/") 156 | parser.add_argument("--checkpoint_path", type=str, default="./saved_model/") 157 | 
parser.add_argument("--run_name_head", type=str, default="DeSTSeg_MVTec") 158 | parser.add_argument("--log_path", type=str, default="./logs/") 159 | 160 | parser.add_argument("--bs", type=int, default=32) 161 | parser.add_argument("--lr_de_st", type=float, default=0.4) 162 | parser.add_argument("--lr_res", type=float, default=0.1) 163 | parser.add_argument("--lr_seghead", type=float, default=0.01) 164 | parser.add_argument("--steps", type=int, default=5000) 165 | parser.add_argument( 166 | "--de_st_steps", type=int, default=1000 167 | ) # steps of training the denoising student model 168 | parser.add_argument("--eval_per_steps", type=int, default=1000) 169 | parser.add_argument("--log_per_steps", type=int, default=50) 170 | parser.add_argument("--gamma", type=float, default=4) # for focal loss 171 | parser.add_argument("--T", type=int, default=100) # for image-level inference 172 | 173 | parser.add_argument( 174 | "--custom_training_category", action="store_true", default=False 175 | ) 176 | parser.add_argument("--no_rotation_category", nargs="*", type=str, default=list()) 177 | parser.add_argument( 178 | "--slight_rotation_category", nargs="*", type=str, default=list() 179 | ) 180 | parser.add_argument("--rotation_category", nargs="*", type=str, default=list()) 181 | 182 | args = parser.parse_args() 183 | 184 | if args.custom_training_category: 185 | no_rotation_category = args.no_rotation_category 186 | slight_rotation_category = args.slight_rotation_category 187 | rotation_category = args.rotation_category 188 | # check 189 | for category in ( 190 | no_rotation_category + slight_rotation_category + rotation_category 191 | ): 192 | assert category in ALL_CATEGORY 193 | else: 194 | no_rotation_category = [ 195 | "capsule", 196 | "metal_nut", 197 | "pill", 198 | "toothbrush", 199 | "transistor", 200 | ] 201 | slight_rotation_category = [ 202 | "wood", 203 | "zipper", 204 | "cable", 205 | ] 206 | rotation_category = [ 207 | "bottle", 208 | "grid", 209 | "hazelnut", 210 | "leather", 211 | "tile", 212 | "carpet", 213 | "screw", 214 | ] 215 | 216 | with torch.cuda.device(args.gpu_id): 217 | for obj in no_rotation_category: 218 | print(obj) 219 | train(args, obj) 220 | 221 | for obj in slight_rotation_category: 222 | print(obj) 223 | train(args, obj, rotate_90=False, random_rotate=5) 224 | 225 | for obj in rotation_category: 226 | print(obj) 227 | train(args, obj, rotate_90=True, random_rotate=5) 228 | --------------------------------------------------------------------------------