├── LICENSE ├── dataset.py ├── de_resnet.py ├── main.py ├── readme.md ├── resnet.py └── test.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 hq-deng 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | from PIL import Image 3 | import os 4 | import torch 5 | import glob 6 | from torchvision.datasets import MNIST, CIFAR10, FashionMNIST, ImageFolder 7 | import numpy as np 8 | 9 | def get_data_transforms(size, isize): 10 | mean_train = [0.485, 0.456, 0.406] 11 | std_train = [0.229, 0.224, 0.225] 12 | data_transforms = transforms.Compose([ 13 | transforms.Resize((size, size)), 14 | transforms.ToTensor(), 15 | transforms.CenterCrop(isize), 16 | #transforms.CenterCrop(args.input_size), 17 | transforms.Normalize(mean=mean_train, 18 | std=std_train)]) 19 | gt_transforms = transforms.Compose([ 20 | transforms.Resize((size, size)), 21 | transforms.CenterCrop(isize), 22 | transforms.ToTensor()]) 23 | return data_transforms, gt_transforms 24 | 25 | 26 | 27 | class MVTecDataset(torch.utils.data.Dataset): 28 | def __init__(self, root, transform, gt_transform, phase): 29 | if phase == 'train': 30 | self.img_path = os.path.join(root, 'train') 31 | else: 32 | self.img_path = os.path.join(root, 'test') 33 | self.gt_path = os.path.join(root, 'ground_truth') 34 | self.transform = transform 35 | self.gt_transform = gt_transform 36 | # load dataset 37 | self.img_paths, self.gt_paths, self.labels, self.types = self.load_dataset() # self.labels => good : 0, anomaly : 1 38 | 39 | def load_dataset(self): 40 | 41 | img_tot_paths = [] 42 | gt_tot_paths = [] 43 | tot_labels = [] 44 | tot_types = [] 45 | 46 | defect_types = os.listdir(self.img_path) 47 | 48 | for defect_type in defect_types: 49 | if defect_type == 'good': 50 | img_paths = glob.glob(os.path.join(self.img_path, defect_type) + "/*.png") 51 | img_tot_paths.extend(img_paths) 52 | gt_tot_paths.extend([0] * len(img_paths)) 53 | tot_labels.extend([0] * len(img_paths)) 54 | tot_types.extend(['good'] * len(img_paths)) 55 | else: 56 | img_paths = glob.glob(os.path.join(self.img_path, defect_type) + "/*.png") 57 | gt_paths = glob.glob(os.path.join(self.gt_path, defect_type) + "/*.png") 58 | img_paths.sort() 59 | 
gt_paths.sort() 60 | img_tot_paths.extend(img_paths) 61 | gt_tot_paths.extend(gt_paths) 62 | tot_labels.extend([1] * len(img_paths)) 63 | tot_types.extend([defect_type] * len(img_paths)) 64 | 65 | assert len(img_tot_paths) == len(gt_tot_paths), "Something wrong with test and ground truth pair!" 66 | 67 | return img_tot_paths, gt_tot_paths, tot_labels, tot_types 68 | 69 | def __len__(self): 70 | return len(self.img_paths) 71 | 72 | def __getitem__(self, idx): 73 | img_path, gt, label, img_type = self.img_paths[idx], self.gt_paths[idx], self.labels[idx], self.types[idx] 74 | img = Image.open(img_path).convert('RGB') 75 | img = self.transform(img) 76 | if gt == 0: 77 | gt = torch.zeros([1, img.size()[-2], img.size()[-1]]) 78 | else: 79 | gt = Image.open(gt) 80 | gt = self.gt_transform(gt) 81 | 82 | assert img.size()[1:] == gt.size()[1:], "image.size != gt.size !!!" 83 | 84 | return img, gt, label, img_type 85 | 86 | def load_data(dataset_name='mnist', normal_class=0, batch_size=16): 87 | 88 | if dataset_name == 'cifar10': 89 | img_transform = transforms.Compose([ 90 | transforms.Resize((32, 32)), 91 | #transforms.CenterCrop(28), 92 | transforms.ToTensor(), 93 | transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) 94 | ]) 95 | 96 | os.makedirs("./Dataset/CIFAR10/train", exist_ok=True) 97 | dataset = CIFAR10('./Dataset/CIFAR10/train', train=True, download=True, transform=img_transform) 98 | print("CIFAR10 DataLoader Called...") 99 | print("All Train Data: ", dataset.data.shape) 100 | dataset.data = dataset.data[np.array(dataset.targets) == normal_class] 101 | dataset.targets = [normal_class] * dataset.data.shape[0] 102 | print("Normal Train Data: ", dataset.data.shape) 103 | 104 | os.makedirs("./Dataset/CIFAR10/test", exist_ok=True) 105 | test_set = CIFAR10("./Dataset/CIFAR10/test", train=False, download=True, transform=img_transform) 106 | print("Test Data:", test_set.data.shape) 107 | 108 | elif dataset_name == 'mnist': 109 | img_transform = transforms.Compose([ 110 | transforms.Resize((32, 32)), 111 | transforms.ToTensor() 112 | ]) 113 | 114 | os.makedirs("./Dataset/MNIST/train", exist_ok=True) 115 | dataset = MNIST('./Dataset/MNIST/train', train=True, download=True, transform=img_transform) 116 | print("MNIST DataLoader Called...") 117 | print("All Train Data: ", dataset.data.shape) 118 | dataset.data = dataset.data[np.array(dataset.targets) == normal_class] 119 | dataset.targets = [normal_class] * dataset.data.shape[0] 120 | print("Normal Train Data: ", dataset.data.shape) 121 | 122 | os.makedirs("./Dataset/MNIST/test", exist_ok=True) 123 | test_set = MNIST("./Dataset/MNIST/test", train=False, download=True, transform=img_transform) 124 | print("Test Data:", test_set.data.shape) 125 | 126 | elif dataset_name == 'fashionmnist': 127 | img_transform = transforms.Compose([ 128 | transforms.Resize((32, 32)), 129 | transforms.ToTensor() 130 | ]) 131 | 132 | os.makedirs("./Dataset/FashionMNIST/train", exist_ok=True) 133 | dataset = FashionMNIST('./Dataset/FashionMNIST/train', train=True, download=True, transform=img_transform) 134 | print("FashionMNIST DataLoader Called...") 135 | print("All Train Data: ", dataset.data.shape) 136 | dataset.data = dataset.data[np.array(dataset.targets) == normal_class] 137 | dataset.targets = [normal_class] * dataset.data.shape[0] 138 | print("Normal Train Data: ", dataset.data.shape) 139 | 140 | os.makedirs("./Dataset/FashionMNIST/test", exist_ok=True) 141 | test_set = FashionMNIST("./Dataset/FashionMNIST/test", train=False, 
download=True, transform=img_transform) 142 | print("Test Data:", test_set.data.shape) 143 | 144 | 145 | elif dataset_name == 'retina': 146 | data_path = 'Dataset/OCT2017/train' 147 | 148 | orig_transform = transforms.Compose([ 149 | transforms.Resize([128, 128]), 150 | transforms.ToTensor() 151 | ]) 152 | 153 | dataset = ImageFolder(root=data_path, transform=orig_transform) 154 | 155 | test_data_path = 'Dataset/OCT2017/test' 156 | test_set = ImageFolder(root=test_data_path, transform=orig_transform) 157 | 158 | else: 159 | raise Exception( 160 | "You entered {}, which is not a valid dataset for this repository!".format(dataset_name)) 161 | 162 | train_dataloader = torch.utils.data.DataLoader( 163 | dataset, 164 | batch_size=batch_size, 165 | shuffle=True, 166 | ) 167 | test_dataloader = torch.utils.data.DataLoader( 168 | test_set, 169 | batch_size=1, 170 | shuffle=False, 171 | ) 172 | 173 | return train_dataloader, test_dataloader 174 | -------------------------------------------------------------------------------- /de_resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | import torch.nn as nn 4 | try: 5 | from torch.hub import load_state_dict_from_url 6 | except ImportError: 7 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 8 | from typing import Type, Any, Callable, Union, List, Optional 9 | 10 | 11 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 12 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 13 | 'wide_resnet50_2', 'wide_resnet101_2'] 14 | 15 | 16 | model_urls = { 17 | 'resnet18': 'https://download.pytorch.org/models/resnet18-f37072fd.pth', 18 | 'resnet34': 'https://download.pytorch.org/models/resnet34-b627a593.pth', 19 | 'resnet50': 'https://download.pytorch.org/models/resnet50-0676ba61.pth', 20 | 'resnet101': 'https://download.pytorch.org/models/resnet101-63fe2227.pth', 21 | 'resnet152': 'https://download.pytorch.org/models/resnet152-394f9c45.pth', 22 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 23 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 24 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 25 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 26 | } 27 | 28 | 29 | def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: 30 | """3x3 convolution with padding""" 31 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 32 | padding=dilation, groups=groups, bias=False, dilation=dilation) 33 | 34 | 35 | def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: 36 | """1x1 convolution""" 37 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 38 | 39 | def deconv2x2(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.ConvTranspose2d: 40 | """2x2 transposed convolution (used in place of strided convolution for upsampling)""" 41 | return nn.ConvTranspose2d(in_planes, out_planes, kernel_size=2, stride=stride, 42 | groups=groups, bias=False, dilation=dilation) 43 | 44 | 45 | class BasicBlock(nn.Module): 46 | expansion: int = 1 47 | 48 | def __init__( 49 | self, 50 | inplanes: int, 51 | planes: int, 52 | stride: int = 1, 53 | upsample: Optional[nn.Module] = None, 54 | groups: int = 1, 55 | base_width: int = 64, 56 | dilation: int = 1, 57 | norm_layer: Optional[Callable[..., 
nn.Module]] = None 58 | ) -> None: 59 | super(BasicBlock, self).__init__() 60 | if norm_layer is None: 61 | norm_layer = nn.BatchNorm2d 62 | if groups != 1 or base_width != 64: 63 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 64 | if dilation > 1: 65 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 66 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 67 | if stride == 2: 68 | self.conv1 = deconv2x2(inplanes, planes, stride) 69 | else: 70 | self.conv1 = conv3x3(inplanes, planes, stride) 71 | self.bn1 = norm_layer(planes) 72 | self.relu = nn.ReLU(inplace=True) 73 | self.conv2 = conv3x3(planes, planes) 74 | self.bn2 = norm_layer(planes) 75 | self.upsample = upsample 76 | self.stride = stride 77 | 78 | def forward(self, x: Tensor) -> Tensor: 79 | identity = x 80 | 81 | out = self.conv1(x) 82 | out = self.bn1(out) 83 | out = self.relu(out) 84 | 85 | out = self.conv2(out) 86 | out = self.bn2(out) 87 | 88 | if self.upsample is not None: 89 | identity = self.upsample(x) 90 | 91 | out += identity 92 | out = self.relu(out) 93 | 94 | return out 95 | 96 | 97 | class Bottleneck(nn.Module): 98 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 99 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 100 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 101 | # This variant is also known as ResNet V1.5 and improves accuracy according to 102 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 103 | 104 | expansion: int = 4 105 | 106 | def __init__( 107 | self, 108 | inplanes: int, 109 | planes: int, 110 | stride: int = 1, 111 | upsample: Optional[nn.Module] = None, 112 | groups: int = 1, 113 | base_width: int = 64, 114 | dilation: int = 1, 115 | norm_layer: Optional[Callable[..., nn.Module]] = None 116 | ) -> None: 117 | super(Bottleneck, self).__init__() 118 | if norm_layer is None: 119 | norm_layer = nn.BatchNorm2d 120 | width = int(planes * (base_width / 64.)) * groups 121 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 122 | self.conv1 = conv1x1(inplanes, width) 123 | self.bn1 = norm_layer(width) 124 | if stride == 2: 125 | self.conv2 = deconv2x2(width, width, stride, groups, dilation) 126 | else: 127 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 128 | self.bn2 = norm_layer(width) 129 | self.conv3 = conv1x1(width, planes * self.expansion) 130 | self.bn3 = norm_layer(planes * self.expansion) 131 | self.relu = nn.ReLU(inplace=True) 132 | self.upsample = upsample 133 | self.stride = stride 134 | 135 | def forward(self, x: Tensor) -> Tensor: 136 | identity = x 137 | 138 | out = self.conv1(x) 139 | out = self.bn1(out) 140 | out = self.relu(out) 141 | 142 | out = self.conv2(out) 143 | out = self.bn2(out) 144 | out = self.relu(out) 145 | 146 | out = self.conv3(out) 147 | out = self.bn3(out) 148 | 149 | if self.upsample is not None: 150 | identity = self.upsample(x) 151 | 152 | out += identity 153 | out = self.relu(out) 154 | 155 | return out 156 | 157 | 158 | class ResNet(nn.Module): 159 | 160 | def __init__( 161 | self, 162 | block: Type[Union[BasicBlock, Bottleneck]], 163 | layers: List[int], 164 | num_classes: int = 1000, 165 | zero_init_residual: bool = False, 166 | groups: int = 1, 167 | width_per_group: int = 64, 168 | replace_stride_with_dilation: Optional[List[bool]] = None, 169 | norm_layer: 
Optional[Callable[..., nn.Module]] = None 170 | ) -> None: 171 | super(ResNet, self).__init__() 172 | if norm_layer is None: 173 | norm_layer = nn.BatchNorm2d 174 | self._norm_layer = norm_layer 175 | 176 | self.inplanes = 512 * block.expansion 177 | self.dilation = 1 178 | if replace_stride_with_dilation is None: 179 | # each element in the tuple indicates if we should replace 180 | # the 2x2 stride with a dilated convolution instead 181 | replace_stride_with_dilation = [False, False, False] 182 | if len(replace_stride_with_dilation) != 3: 183 | raise ValueError("replace_stride_with_dilation should be None " 184 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 185 | self.groups = groups 186 | self.base_width = width_per_group 187 | #self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 188 | # bias=False) 189 | #self.bn1 = norm_layer(self.inplanes) 190 | #self.relu = nn.ReLU(inplace=True) 191 | #self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 192 | self.layer1 = self._make_layer(block, 256, layers[0], stride=2) 193 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 194 | dilate=replace_stride_with_dilation[0]) 195 | self.layer3 = self._make_layer(block, 64, layers[2], stride=2, 196 | dilate=replace_stride_with_dilation[1]) 197 | #self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 198 | # dilate=replace_stride_with_dilation[2]) 199 | #self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 200 | #self.fc = nn.Linear(512 * block.expansion, num_classes) 201 | 202 | for m in self.modules(): 203 | if isinstance(m, nn.Conv2d): 204 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 205 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 206 | nn.init.constant_(m.weight, 1) 207 | nn.init.constant_(m.bias, 0) 208 | 209 | # Zero-initialize the last BN in each residual branch, 210 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
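# (Concretely: with zero_init_residual=True the last BN of every residual
# branch starts at zero, so at initialization out = relu(identity + 0) and
# each block is an exact identity map.)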
211 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 212 | if zero_init_residual: 213 | for m in self.modules(): 214 | if isinstance(m, Bottleneck): 215 | nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] 216 | elif isinstance(m, BasicBlock): 217 | nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] 218 | 219 | def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int, 220 | stride: int = 1, dilate: bool = False) -> nn.Sequential: 221 | norm_layer = self._norm_layer 222 | upsample = None 223 | previous_dilation = self.dilation 224 | if dilate: 225 | self.dilation *= stride 226 | stride = 1 227 | if stride != 1 or self.inplanes != planes * block.expansion: 228 | upsample = nn.Sequential( 229 | deconv2x2(self.inplanes, planes * block.expansion, stride), 230 | norm_layer(planes * block.expansion), 231 | ) 232 | 233 | layers = [] 234 | layers.append(block(self.inplanes, planes, stride, upsample, self.groups, 235 | self.base_width, previous_dilation, norm_layer)) 236 | self.inplanes = planes * block.expansion 237 | for _ in range(1, blocks): 238 | layers.append(block(self.inplanes, planes, groups=self.groups, 239 | base_width=self.base_width, dilation=self.dilation, 240 | norm_layer=norm_layer)) 241 | 242 | return nn.Sequential(*layers) 243 | 244 | def _forward_impl(self, x: Tensor) -> Tensor: 245 | # See note [TorchScript super()] 246 | #x = self.conv1(x) 247 | #x = self.bn1(x) 248 | #x = self.relu(x) 249 | #x = self.maxpool(x) 250 | 251 | feature_a = self.layer1(x) # 512*8*8->256*16*16 252 | feature_b = self.layer2(feature_a) # 256*16*16->128*32*32 253 | feature_c = self.layer3(feature_b) # 128*32*32->64*64*64 254 | #feature_d = self.layer4(feature_c) # 64*64*64->128*32*32 255 | 256 | #x = self.avgpool(feature_d) 257 | #x = torch.flatten(x, 1) 258 | #x = self.fc(x) 259 | 260 | return [feature_c, feature_b, feature_a] 261 | 262 | def forward(self, x: Tensor) -> Tensor: 263 | return self._forward_impl(x) 264 | 265 | 266 | def _resnet( 267 | arch: str, 268 | block: Type[Union[BasicBlock, Bottleneck]], 269 | layers: List[int], 270 | pretrained: bool, 271 | progress: bool, 272 | **kwargs: Any 273 | ) -> ResNet: 274 | model = ResNet(block, layers, **kwargs) 275 | if pretrained: 276 | state_dict = load_state_dict_from_url(model_urls[arch], 277 | progress=progress) 278 | #for k,v in list(state_dict.items()): 279 | # if 'layer4' in k or 'fc' in k: 280 | # state_dict.pop(k) 281 | model.load_state_dict(state_dict) 282 | return model 283 | 284 | 285 | def de_resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 286 | r"""ResNet-18 model from 287 | `"Deep Residual Learning for Image Recognition" `_. 288 | Args: 289 | pretrained (bool): If True, returns a model pre-trained on ImageNet 290 | progress (bool): If True, displays a progress bar of the download to stderr 291 | """ 292 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 293 | **kwargs) 294 | 295 | 296 | def de_resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 297 | r"""ResNet-34 model from 298 | `"Deep Residual Learning for Image Recognition" `_. 
299 | Args: 300 | pretrained (bool): If True, returns a model pre-trained on ImageNet 301 | progress (bool): If True, displays a progress bar of the download to stderr 302 | """ 303 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 304 | **kwargs) 305 | 306 | 307 | def de_resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 308 | r"""ResNet-50 model from 309 | `"Deep Residual Learning for Image Recognition" `_. 310 | Args: 311 | pretrained (bool): If True, returns a model pre-trained on ImageNet 312 | progress (bool): If True, displays a progress bar of the download to stderr 313 | """ 314 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 315 | **kwargs) 316 | 317 | 318 | def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 319 | r"""ResNet-101 model from 320 | `"Deep Residual Learning for Image Recognition" `_. 321 | Args: 322 | pretrained (bool): If True, returns a model pre-trained on ImageNet 323 | progress (bool): If True, displays a progress bar of the download to stderr 324 | """ 325 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 326 | **kwargs) 327 | 328 | 329 | def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 330 | r"""ResNet-152 model from 331 | `"Deep Residual Learning for Image Recognition" `_. 332 | Args: 333 | pretrained (bool): If True, returns a model pre-trained on ImageNet 334 | progress (bool): If True, displays a progress bar of the download to stderr 335 | """ 336 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 337 | **kwargs) 338 | 339 | 340 | def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 341 | r"""ResNeXt-50 32x4d model from 342 | `"Aggregated Residual Transformation for Deep Neural Networks" `_. 343 | Args: 344 | pretrained (bool): If True, returns a model pre-trained on ImageNet 345 | progress (bool): If True, displays a progress bar of the download to stderr 346 | """ 347 | kwargs['groups'] = 32 348 | kwargs['width_per_group'] = 4 349 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 350 | pretrained, progress, **kwargs) 351 | 352 | 353 | def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 354 | r"""ResNeXt-101 32x8d model from 355 | `"Aggregated Residual Transformation for Deep Neural Networks" `_. 356 | Args: 357 | pretrained (bool): If True, returns a model pre-trained on ImageNet 358 | progress (bool): If True, displays a progress bar of the download to stderr 359 | """ 360 | kwargs['groups'] = 32 361 | kwargs['width_per_group'] = 8 362 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 363 | pretrained, progress, **kwargs) 364 | 365 | 366 | def de_wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 367 | r"""Wide ResNet-50-2 model from 368 | `"Wide Residual Networks" `_. 369 | The model is the same as ResNet except for the bottleneck number of channels 370 | which is twice larger in every block. The number of channels in outer 1x1 371 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 372 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 
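In this repository the de_-prefixed constructors build the mirrored decoder used for reverse distillation: the residual stages run in reverse order (layer planes 256 -> 128 -> 64, starting from a 512 * expansion-channel bottleneck) and each stride-2 stage swaps its strided convolution for a 2x2 transposed convolution (deconv2x2), so the bottleneck embedding is upsampled 2x per stage and the forward pass returns its three feature maps largest-first, aligning scale-for-scale with the encoder's [feature_a, feature_b, feature_c] outputs.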
373 | Args: 374 | pretrained (bool): If True, returns a model pre-trained on ImageNet 375 | progress (bool): If True, displays a progress bar of the download to stderr 376 | """ 377 | kwargs['width_per_group'] = 64 * 2 378 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], 379 | pretrained, progress, **kwargs) 380 | 381 | 382 | def de_wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 383 | r"""Wide ResNet-101-2 model from 384 | `"Wide Residual Networks" <https://arxiv.org/abs/1605.07146>`_. 385 | The model is the same as ResNet except for the bottleneck number of channels 386 | which is twice larger in every block. The number of channels in outer 1x1 387 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 388 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 389 | Args: 390 | pretrained (bool): If True, returns a model pre-trained on ImageNet 391 | progress (bool): If True, displays a progress bar of the download to stderr 392 | """ 393 | kwargs['width_per_group'] = 64 * 2 394 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], 395 | pretrained, progress, **kwargs) -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # Training entry point: Anomaly Detection via Reverse Distillation from One-Class Embedding (CVPR 2022). 2 | 3 | 4 | 5 | 6 | import torch 7 | from dataset import get_data_transforms 8 | from torchvision.datasets import ImageFolder 9 | import numpy as np 10 | import random 11 | import os 12 | from torch.utils.data import DataLoader 13 | from resnet import resnet18, resnet34, resnet50, wide_resnet50_2 14 | from de_resnet import de_resnet18, de_resnet34, de_wide_resnet50_2, de_resnet50 15 | from dataset import MVTecDataset 16 | import torch.backends.cudnn as cudnn 17 | import argparse 18 | from test import evaluation, visualization, test 19 | from torch.nn import functional as F 20 | 21 | def count_parameters(model): 22 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 23 | 24 | 25 | def setup_seed(seed): 26 | torch.manual_seed(seed) 27 | torch.cuda.manual_seed_all(seed) 28 | np.random.seed(seed) 29 | random.seed(seed) 30 | torch.backends.cudnn.deterministic = True 31 | torch.backends.cudnn.benchmark = False 32 | 33 | def loss_function(a, b):  # cosine-distance loss between encoder and decoder feature pyramids 34 | #mse_loss = torch.nn.MSELoss() 35 | cos_loss = torch.nn.CosineSimilarity() 36 | loss = 0 37 | for item in range(len(a)): 38 | #print(a[item].shape) 39 | #print(b[item].shape) 40 | #loss += 0.1*mse_loss(a[item], b[item]) 41 | loss += torch.mean(1-cos_loss(a[item].view(a[item].shape[0],-1), 42 | b[item].view(b[item].shape[0],-1))) 43 | return loss 44 | 45 | def loss_concat(a, b): 46 | mse_loss = torch.nn.MSELoss() 47 | cos_loss = torch.nn.CosineSimilarity() 48 | loss = 0 49 | a_map = [] 50 | b_map = [] 51 | size = a[0].shape[-1] 52 | for item in range(len(a)): 53 | #loss += mse_loss(a[item], b[item]) 54 | a_map.append(F.interpolate(a[item], size=size, mode='bilinear', align_corners=True)) 55 | b_map.append(F.interpolate(b[item], size=size, mode='bilinear', align_corners=True)) 56 | a_map = torch.cat(a_map,1) 57 | b_map = torch.cat(b_map,1) 58 | loss += torch.mean(1-cos_loss(a_map,b_map)) 59 | return loss 60 | 61 | def train(_class_): 62 | print(_class_) 63 | epochs = 200 64 | learning_rate = 0.005 65 | batch_size = 16 66 | image_size = 256 67 | 68 | device = 
'cuda' if torch.cuda.is_available() else 'cpu' 69 | print(device) 70 | 71 | data_transform, gt_transform = get_data_transforms(image_size, image_size) 72 | train_path = './mvtec/' + _class_ + '/train' 73 | test_path = './mvtec/' + _class_ 74 | ckp_path = './checkpoints/' + 'wres50_'+_class_+'.pth' 75 | train_data = ImageFolder(root=train_path, transform=data_transform) 76 | test_data = MVTecDataset(root=test_path, transform=data_transform, gt_transform=gt_transform, phase="test") 77 | train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True) 78 | test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=False) 79 | 80 | encoder, bn = wide_resnet50_2(pretrained=True) 81 | encoder = encoder.to(device) 82 | bn = bn.to(device) 83 | encoder.eval()  # the teacher encoder stays frozen; only bn and decoder are trained 84 | decoder = de_wide_resnet50_2(pretrained=False) 85 | decoder = decoder.to(device) 86 | 87 | optimizer = torch.optim.Adam(list(decoder.parameters())+list(bn.parameters()), lr=learning_rate, betas=(0.5,0.999)) 88 | 89 | 90 | for epoch in range(epochs): 91 | bn.train() 92 | decoder.train() 93 | loss_list = [] 94 | for img, label in train_dataloader: 95 | img = img.to(device) 96 | inputs = encoder(img) 97 | outputs = decoder(bn(inputs)) 98 | loss = loss_function(inputs, outputs) 99 | optimizer.zero_grad() 100 | loss.backward() 101 | optimizer.step() 102 | loss_list.append(loss.item()) 103 | print('epoch [{}/{}], loss:{:.4f}'.format(epoch + 1, epochs, np.mean(loss_list))) 104 | if (epoch + 1) % 10 == 0: 105 | auroc_px, auroc_sp, aupro_px = evaluation(encoder, bn, decoder, test_dataloader, device) 106 | print('Pixel AUROC: {:.3f}, Sample AUROC: {:.3f}, Pixel AUPRO: {:.3f}'.format(auroc_px, auroc_sp, aupro_px)) 107 | torch.save({'bn': bn.state_dict(), 108 | 'decoder': decoder.state_dict()}, ckp_path) 109 | return auroc_px, auroc_sp, aupro_px 110 | 111 | 112 | 113 | 114 | if __name__ == '__main__': 115 | 116 | setup_seed(111) 117 | item_list = ['carpet', 'bottle', 'hazelnut', 'leather', 'cable', 'capsule', 'grid', 'pill', 118 | 'transistor', 'metal_nut', 'screw','toothbrush', 'zipper', 'tile', 'wood'] 119 | for i in item_list: 120 | train(i) 121 | 122 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | ## CVPR2022 - Anomaly Detection via Reverse Distillation from One-Class Embedding 2 | ## Implementation (Official Code ⭐️ ⭐️ ⭐️ ) 3 | 4 | 1. Environment 5 | > pytorch == 1.9.1 6 | 7 | > torchvision == 0.10.1 8 | 9 | > numpy == 1.20.3 10 | 11 | > scipy == 1.7.1 12 | 13 | > scikit-learn == 1.0 14 | 15 | > Pillow == 8.3.2 16 | 2. Dataset 17 | > Download the MVTec AD dataset from [MVTec AD: MVTec Software](https://www.mvtec.com/company/research/datasets/mvtec-ad/) and unpack the "mvtec" folder into the code folder. 18 | 3. Train and Test the Model 19 | Both the training and evaluation routines are implemented in main.py; training hyper-parameters (200 epochs, Adam with lr = 0.005, batch size 16, 256x256 inputs) are hard-coded in train(). 
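To re-evaluate a single category from its saved checkpoint without retraining, a minimal sketch (it assumes the ./checkpoints/wres50_<class>.pth naming used by train(); 'bottle' is just an example category):

import torch
from dataset import get_data_transforms, MVTecDataset
from resnet import wide_resnet50_2
from de_resnet import de_wide_resnet50_2
from test import evaluation

device = 'cuda' if torch.cuda.is_available() else 'cpu'
data_transform, gt_transform = get_data_transforms(256, 256)
test_data = MVTecDataset(root='./mvtec/bottle', transform=data_transform,
                         gt_transform=gt_transform, phase='test')
test_loader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=False)

encoder, bn = wide_resnet50_2(pretrained=True)   # frozen teacher + one-class bottleneck
decoder = de_wide_resnet50_2(pretrained=False)   # student decoder
encoder, bn, decoder = encoder.to(device).eval(), bn.to(device), decoder.to(device)

ckp = torch.load('./checkpoints/wres50_bottle.pth', map_location=device)
bn.load_state_dict(ckp['bn'])
decoder.load_state_dict(ckp['decoder'])

# evaluation() switches bn/decoder to eval mode and returns (pixel AUROC, sample AUROC, pixel AUPRO)
print(evaluation(encoder, bn, decoder, test_loader, device))

To train and evaluate all 15 MVTec categories: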
20 | > python main.py 21 | 22 | ## Reference 23 | @InProceedings{Deng_2022_CVPR, 24 | author = {Deng, Hanqiu and Li, Xingyu}, 25 | title = {Anomaly Detection via Reverse Distillation From One-Class Embedding}, 26 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 27 | month = {June}, 28 | year = {2022}, 29 | pages = {9737-9746}} 30 | -------------------------------------------------------------------------------- /resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | import torch.nn as nn 4 | try: 5 | from torch.hub import load_state_dict_from_url 6 | except ImportError: 7 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 8 | from typing import Type, Any, Callable, Union, List, Optional 9 | 10 | 11 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 12 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 13 | 'wide_resnet50_2', 'wide_resnet101_2'] 14 | 15 | 16 | model_urls = { 17 | 'resnet18': 'https://download.pytorch.org/models/resnet18-f37072fd.pth', 18 | 'resnet34': 'https://download.pytorch.org/models/resnet34-b627a593.pth', 19 | 'resnet50': 'https://download.pytorch.org/models/resnet50-0676ba61.pth', 20 | 'resnet101': 'https://download.pytorch.org/models/resnet101-63fe2227.pth', 21 | 'resnet152': 'https://download.pytorch.org/models/resnet152-394f9c45.pth', 22 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 23 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 24 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 25 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 26 | } 27 | 28 | 29 | def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: 30 | """3x3 convolution with padding""" 31 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 32 | padding=dilation, groups=groups, bias=False, dilation=dilation) 33 | 34 | 35 | def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: 36 | """1x1 convolution""" 37 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 38 | 39 | 40 | class BasicBlock(nn.Module): 41 | expansion: int = 1 42 | 43 | def __init__( 44 | self, 45 | inplanes: int, 46 | planes: int, 47 | stride: int = 1, 48 | downsample: Optional[nn.Module] = None, 49 | groups: int = 1, 50 | base_width: int = 64, 51 | dilation: int = 1, 52 | norm_layer: Optional[Callable[..., nn.Module]] = None 53 | ) -> None: 54 | super(BasicBlock, self).__init__() 55 | if norm_layer is None: 56 | norm_layer = nn.BatchNorm2d 57 | if groups != 1 or base_width != 64: 58 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 59 | if dilation > 1: 60 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 61 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 62 | self.conv1 = conv3x3(inplanes, planes, stride) 63 | self.bn1 = norm_layer(planes) 64 | self.relu = nn.ReLU(inplace=True) 65 | self.conv2 = conv3x3(planes, planes) 66 | self.bn2 = norm_layer(planes) 67 | self.downsample = downsample 68 | self.stride = stride 69 | 70 | def forward(self, x: Tensor) -> Tensor: 71 | identity = x 72 | 73 | out = self.conv1(x) 74 | out = self.bn1(out) 75 | out = self.relu(out) 76 | 77 | out = 
self.conv2(out) 78 | out = self.bn2(out) 79 | 80 | if self.downsample is not None: 81 | identity = self.downsample(x) 82 | 83 | out += identity 84 | out = self.relu(out) 85 | 86 | return out 87 | 88 | 89 | class Bottleneck(nn.Module): 90 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 91 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 92 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 93 | # This variant is also known as ResNet V1.5 and improves accuracy according to 94 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 95 | 96 | expansion: int = 4 97 | 98 | def __init__( 99 | self, 100 | inplanes: int, 101 | planes: int, 102 | stride: int = 1, 103 | downsample: Optional[nn.Module] = None, 104 | groups: int = 1, 105 | base_width: int = 64, 106 | dilation: int = 1, 107 | norm_layer: Optional[Callable[..., nn.Module]] = None 108 | ) -> None: 109 | super(Bottleneck, self).__init__() 110 | if norm_layer is None: 111 | norm_layer = nn.BatchNorm2d 112 | width = int(planes * (base_width / 64.)) * groups 113 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 114 | self.conv1 = conv1x1(inplanes, width) 115 | self.bn1 = norm_layer(width) 116 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 117 | self.bn2 = norm_layer(width) 118 | self.conv3 = conv1x1(width, planes * self.expansion) 119 | self.bn3 = norm_layer(planes * self.expansion) 120 | self.relu = nn.ReLU(inplace=True) 121 | self.downsample = downsample 122 | self.stride = stride 123 | 124 | def forward(self, x: Tensor) -> Tensor: 125 | identity = x 126 | 127 | out = self.conv1(x) 128 | out = self.bn1(out) 129 | out = self.relu(out) 130 | 131 | out = self.conv2(out) 132 | out = self.bn2(out) 133 | out = self.relu(out) 134 | 135 | out = self.conv3(out) 136 | out = self.bn3(out) 137 | 138 | if self.downsample is not None: 139 | identity = self.downsample(x) 140 | 141 | out += identity 142 | out = self.relu(out) 143 | 144 | return out 145 | 146 | 147 | class ResNet(nn.Module): 148 | 149 | def __init__( 150 | self, 151 | block: Type[Union[BasicBlock, Bottleneck]], 152 | layers: List[int], 153 | num_classes: int = 1000, 154 | zero_init_residual: bool = False, 155 | groups: int = 1, 156 | width_per_group: int = 64, 157 | replace_stride_with_dilation: Optional[List[bool]] = None, 158 | norm_layer: Optional[Callable[..., nn.Module]] = None 159 | ) -> None: 160 | super(ResNet, self).__init__() 161 | if norm_layer is None: 162 | norm_layer = nn.BatchNorm2d 163 | self._norm_layer = norm_layer 164 | 165 | self.inplanes = 64 166 | self.dilation = 1 167 | if replace_stride_with_dilation is None: 168 | # each element in the tuple indicates if we should replace 169 | # the 2x2 stride with a dilated convolution instead 170 | replace_stride_with_dilation = [False, False, False] 171 | if len(replace_stride_with_dilation) != 3: 172 | raise ValueError("replace_stride_with_dilation should be None " 173 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 174 | self.groups = groups 175 | self.base_width = width_per_group 176 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 177 | bias=False) 178 | self.bn1 = norm_layer(self.inplanes) 179 | self.relu = nn.ReLU(inplace=True) 180 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 181 | self.layer1 = self._make_layer(block, 64, layers[0]) 182 
| self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 183 | dilate=replace_stride_with_dilation[0]) 184 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 185 | dilate=replace_stride_with_dilation[1]) 186 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 187 | dilate=replace_stride_with_dilation[2]) 188 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 189 | self.fc = nn.Linear(512 * block.expansion, num_classes) 190 | 191 | for m in self.modules(): 192 | if isinstance(m, nn.Conv2d): 193 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 194 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 195 | nn.init.constant_(m.weight, 1) 196 | nn.init.constant_(m.bias, 0) 197 | 198 | # Zero-initialize the last BN in each residual branch, 199 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 200 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 201 | if zero_init_residual: 202 | for m in self.modules(): 203 | if isinstance(m, Bottleneck): 204 | nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] 205 | elif isinstance(m, BasicBlock): 206 | nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] 207 | 208 | def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int, 209 | stride: int = 1, dilate: bool = False) -> nn.Sequential: 210 | norm_layer = self._norm_layer 211 | downsample = None 212 | previous_dilation = self.dilation 213 | if dilate: 214 | self.dilation *= stride 215 | stride = 1 216 | if stride != 1 or self.inplanes != planes * block.expansion: 217 | downsample = nn.Sequential( 218 | conv1x1(self.inplanes, planes * block.expansion, stride), 219 | norm_layer(planes * block.expansion), 220 | ) 221 | 222 | layers = [] 223 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 224 | self.base_width, previous_dilation, norm_layer)) 225 | self.inplanes = planes * block.expansion 226 | for _ in range(1, blocks): 227 | layers.append(block(self.inplanes, planes, groups=self.groups, 228 | base_width=self.base_width, dilation=self.dilation, 229 | norm_layer=norm_layer)) 230 | 231 | return nn.Sequential(*layers) 232 | 233 | def _forward_impl(self, x: Tensor) -> Tensor: 234 | # See note [TorchScript super()] 235 | x = self.conv1(x) 236 | x = self.bn1(x) 237 | x = self.relu(x) 238 | x = self.maxpool(x) 239 | 240 | feature_a = self.layer1(x) 241 | feature_b = self.layer2(feature_a) 242 | feature_c = self.layer3(feature_b) 243 | feature_d = self.layer4(feature_c) 244 | 245 | 246 | return [feature_a, feature_b, feature_c] 247 | 248 | def forward(self, x: Tensor) -> Tensor: 249 | return self._forward_impl(x) 250 | 251 | 252 | def _resnet( 253 | arch: str, 254 | block: Type[Union[BasicBlock, Bottleneck]], 255 | layers: List[int], 256 | pretrained: bool, 257 | progress: bool, 258 | **kwargs: Any 259 | ) -> ResNet: 260 | model = ResNet(block, layers, **kwargs) 261 | if pretrained: 262 | state_dict = load_state_dict_from_url(model_urls[arch], 263 | progress=progress) 264 | #for k,v in list(state_dict.items()): 265 | # if 'layer4' in k or 'fc' in k: 266 | # state_dict.pop(k) 267 | model.load_state_dict(state_dict) 268 | return model 269 | 270 | class AttnBasicBlock(nn.Module): 271 | expansion: int = 1 272 | 273 | def __init__( 274 | self, 275 | inplanes: int, 276 | planes: int, 277 | stride: int = 1, 278 | downsample: Optional[nn.Module] = None, 279 | groups: int = 1, 280 | base_width: int = 64, 281 | 
dilation: int = 1, 282 | norm_layer: Optional[Callable[..., nn.Module]] = None, 283 | attention: bool = True, 284 | ) -> None: 285 | super(AttnBasicBlock, self).__init__() 286 | self.attention = attention 287 | #print("Attention:", self.attention) 288 | if norm_layer is None: 289 | norm_layer = nn.BatchNorm2d 290 | if groups != 1 or base_width != 64: 291 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 292 | if dilation > 1: 293 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 294 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 295 | self.conv1 = conv3x3(inplanes, planes, stride) 296 | self.bn1 = norm_layer(planes) 297 | self.relu = nn.ReLU(inplace=True) 298 | self.conv2 = conv3x3(planes, planes) 299 | self.bn2 = norm_layer(planes) 300 | #self.cbam = GLEAM(planes, 16) 301 | self.downsample = downsample 302 | self.stride = stride 303 | 304 | def forward(self, x: Tensor) -> Tensor: 305 | #if self.attention: 306 | # x = self.cbam(x) 307 | identity = x 308 | 309 | out = self.conv1(x) 310 | out = self.bn1(out) 311 | out = self.relu(out) 312 | 313 | out = self.conv2(out) 314 | out = self.bn2(out) 315 | 316 | 317 | if self.downsample is not None: 318 | identity = self.downsample(x) 319 | 320 | out += identity 321 | out = self.relu(out) 322 | 323 | return out 324 | 325 | class AttnBottleneck(nn.Module): 326 | 327 | expansion: int = 4 328 | 329 | def __init__( 330 | self, 331 | inplanes: int, 332 | planes: int, 333 | stride: int = 1, 334 | downsample: Optional[nn.Module] = None, 335 | groups: int = 1, 336 | base_width: int = 64, 337 | dilation: int = 1, 338 | norm_layer: Optional[Callable[..., nn.Module]] = None, 339 | attention: bool = True, 340 | ) -> None: 341 | super(AttnBottleneck, self).__init__() 342 | self.attention = attention 343 | #print("Attention:",self.attention) 344 | if norm_layer is None: 345 | norm_layer = nn.BatchNorm2d 346 | width = int(planes * (base_width / 64.)) * groups 347 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 348 | self.conv1 = conv1x1(inplanes, width) 349 | self.bn1 = norm_layer(width) 350 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 351 | self.bn2 = norm_layer(width) 352 | self.conv3 = conv1x1(width, planes * self.expansion) 353 | self.bn3 = norm_layer(planes * self.expansion) 354 | self.relu = nn.ReLU(inplace=True) 355 | #self.cbam = GLEAM([int(planes * self.expansion/4), 356 | # int(planes * self.expansion//2), 357 | # planes * self.expansion], 16) 358 | self.downsample = downsample 359 | self.stride = stride 360 | 361 | def forward(self, x: Tensor) -> Tensor: 362 | #if self.attention: 363 | # x = self.cbam(x) 364 | identity = x 365 | 366 | out = self.conv1(x) 367 | out = self.bn1(out) 368 | out = self.relu(out) 369 | 370 | out = self.conv2(out) 371 | out = self.bn2(out) 372 | out = self.relu(out) 373 | 374 | out = self.conv3(out) 375 | out = self.bn3(out) 376 | 377 | if self.downsample is not None: 378 | identity = self.downsample(x) 379 | 380 | 381 | out += identity 382 | out = self.relu(out) 383 | 384 | return out 385 | 386 | class BN_layer(nn.Module): 387 | def __init__(self, 388 | block: Type[Union[BasicBlock, Bottleneck]], 389 | layers: int, 390 | groups: int = 1, 391 | width_per_group: int = 64, 392 | norm_layer: Optional[Callable[..., nn.Module]] = None, 393 | ): 394 | super(BN_layer, self).__init__() 395 | if norm_layer is None: 396 | norm_layer = nn.BatchNorm2d 397 | self._norm_layer = norm_layer 398 | 
self.groups = groups 399 | self.base_width = width_per_group 400 | self.inplanes = 256 * block.expansion 401 | self.dilation = 1 402 | self.bn_layer = self._make_layer(block, 512, layers, stride=2) 403 | 404 | self.conv1 = conv3x3(64 * block.expansion, 128 * block.expansion, 2) 405 | self.bn1 = norm_layer(128 * block.expansion) 406 | self.relu = nn.ReLU(inplace=True) 407 | self.conv2 = conv3x3(128 * block.expansion, 256 * block.expansion, 2) 408 | self.bn2 = norm_layer(256 * block.expansion) 409 | self.conv3 = conv3x3(128 * block.expansion, 256 * block.expansion, 2) 410 | self.bn3 = norm_layer(256 * block.expansion) 411 | 412 | self.conv4 = conv1x1(1024 * block.expansion, 512 * block.expansion, 1) 413 | self.bn4 = norm_layer(512 * block.expansion) 414 | 415 | 416 | for m in self.modules(): 417 | if isinstance(m, nn.Conv2d): 418 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 419 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 420 | nn.init.constant_(m.weight, 1) 421 | nn.init.constant_(m.bias, 0) 422 | 423 | def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int, 424 | stride: int = 1, dilate: bool = False) -> nn.Sequential: 425 | norm_layer = self._norm_layer 426 | downsample = None 427 | previous_dilation = self.dilation 428 | if dilate: 429 | self.dilation *= stride 430 | stride = 1 431 | if stride != 1 or self.inplanes != planes * block.expansion: 432 | downsample = nn.Sequential( 433 | conv1x1(self.inplanes*3, planes * block.expansion, stride), 434 | norm_layer(planes * block.expansion), 435 | ) 436 | 437 | layers = [] 438 | layers.append(block(self.inplanes*3, planes, stride, downsample, self.groups, 439 | self.base_width, previous_dilation, norm_layer)) 440 | self.inplanes = planes * block.expansion 441 | for _ in range(1, blocks): 442 | layers.append(block(self.inplanes, planes, groups=self.groups, 443 | base_width=self.base_width, dilation=self.dilation, 444 | norm_layer=norm_layer)) 445 | 446 | return nn.Sequential(*layers) 447 | 448 | def _forward_impl(self, x: Tensor) -> Tensor: 449 | # See note [TorchScript super()] 450 | #x = self.cbam(x) 451 | l1 = self.relu(self.bn2(self.conv2(self.relu(self.bn1(self.conv1(x[0])))))) 452 | l2 = self.relu(self.bn3(self.conv3(x[1]))) 453 | feature = torch.cat([l1,l2,x[2]],1) 454 | output = self.bn_layer(feature) 455 | #x = self.avgpool(feature_d) 456 | #x = torch.flatten(x, 1) 457 | #x = self.fc(x) 458 | 459 | return output.contiguous() 460 | 461 | def forward(self, x: Tensor) -> Tensor: 462 | return self._forward_impl(x) 463 | 464 | 465 | def resnet18(pretrained: bool = False, progress: bool = True,**kwargs: Any) -> ResNet: 466 | r"""ResNet-18 model from 467 | `"Deep Residual Learning for Image Recognition" `_. 468 | Args: 469 | pretrained (bool): If True, returns a model pre-trained on ImageNet 470 | progress (bool): If True, displays a progress bar of the download to stderr 471 | """ 472 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 473 | **kwargs), BN_layer(AttnBasicBlock,2,**kwargs) 474 | 475 | 476 | def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 477 | r"""ResNet-34 model from 478 | `"Deep Residual Learning for Image Recognition" `_. 
479 | Args: 480 | pretrained (bool): If True, returns a model pre-trained on ImageNet 481 | progress (bool): If True, displays a progress bar of the download to stderr 482 | """ 483 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 484 | **kwargs), BN_layer(AttnBasicBlock,3,**kwargs) 485 | 486 | 487 | def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 488 | r"""ResNet-50 model from 489 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_. 490 | Args: 491 | pretrained (bool): If True, returns a model pre-trained on ImageNet 492 | progress (bool): If True, displays a progress bar of the download to stderr 493 | """ 494 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 495 | **kwargs), BN_layer(AttnBottleneck,3,**kwargs) 496 | 497 | 498 | def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 499 | r"""ResNet-101 model from 500 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_. 501 | Args: 502 | pretrained (bool): If True, returns a model pre-trained on ImageNet 503 | progress (bool): If True, displays a progress bar of the download to stderr 504 | """ 505 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 506 | **kwargs), BN_layer(AttnBottleneck,3,**kwargs)  # a Bottleneck encoder needs a Bottleneck-style BN_layer for matching channel counts 507 | 508 | 509 | def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 510 | r"""ResNet-152 model from 511 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_. 512 | Args: 513 | pretrained (bool): If True, returns a model pre-trained on ImageNet 514 | progress (bool): If True, displays a progress bar of the download to stderr 515 | """ 516 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 517 | **kwargs), BN_layer(AttnBottleneck,3,**kwargs) 518 | 519 | 520 | def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 521 | r"""ResNeXt-50 32x4d model from 522 | `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/abs/1611.05431>`_. 523 | Args: 524 | pretrained (bool): If True, returns a model pre-trained on ImageNet 525 | progress (bool): If True, displays a progress bar of the download to stderr 526 | """ 527 | kwargs['groups'] = 32 528 | kwargs['width_per_group'] = 4 529 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 530 | pretrained, progress, **kwargs) 531 | 532 | 533 | def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 534 | r"""ResNeXt-101 32x8d model from 535 | `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/abs/1611.05431>`_. 536 | Args: 537 | pretrained (bool): If True, returns a model pre-trained on ImageNet 538 | progress (bool): If True, displays a progress bar of the download to stderr 539 | """ 540 | kwargs['groups'] = 32 541 | kwargs['width_per_group'] = 8 542 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 543 | pretrained, progress, **kwargs) 544 | 545 | 546 | def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 547 | r"""Wide ResNet-50-2 model from 548 | `"Wide Residual Networks" <https://arxiv.org/abs/1605.07146>`_. 549 | The model is the same as ResNet except for the bottleneck number of channels 550 | which is twice larger in every block. The number of channels in outer 1x1 551 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 552 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 
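Note: unlike torchvision, this constructor returns a tuple (backbone, BN_layer): the ImageNet-pretrained teacher encoder together with the trainable one-class bottleneck (AttnBottleneck blocks) that fuses the encoder's three feature maps into the single embedding fed to the decoder.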
553 | Args: 554 | pretrained (bool): If True, returns a model pre-trained on ImageNet 555 | progress (bool): If True, displays a progress bar of the download to stderr 556 | """ 557 | kwargs['width_per_group'] = 64 * 2 558 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], 559 | pretrained, progress, **kwargs), BN_layer(AttnBottleneck,3,**kwargs) 560 | 561 | 562 | def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 563 | r"""Wide ResNet-101-2 model from 564 | `"Wide Residual Networks" `_. 565 | The model is the same as ResNet except for the bottleneck number of channels 566 | which is twice larger in every block. The number of channels in outer 1x1 567 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 568 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 569 | Args: 570 | pretrained (bool): If True, returns a model pre-trained on ImageNet 571 | progress (bool): If True, displays a progress bar of the download to stderr 572 | """ 573 | kwargs['width_per_group'] = 64 * 2 574 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], 575 | pretrained, progress, **kwargs), BN_layer(AttnBottleneck,3,**kwargs) 576 | 577 | 578 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from dataset import get_data_transforms, load_data 3 | from torchvision.datasets import ImageFolder 4 | import numpy as np 5 | from torch.utils.data import DataLoader 6 | from resnet import resnet18, resnet34, resnet50, wide_resnet50_2 7 | from de_resnet import de_resnet18, de_resnet50, de_wide_resnet50_2 8 | from dataset import MVTecDataset 9 | from torch.nn import functional as F 10 | from sklearn.metrics import roc_auc_score 11 | import cv2 12 | import matplotlib.pyplot as plt 13 | from sklearn.metrics import auc 14 | from skimage import measure 15 | import pandas as pd 16 | from numpy import ndarray 17 | from statistics import mean 18 | from scipy.ndimage import gaussian_filter 19 | from sklearn import manifold 20 | from matplotlib.ticker import NullFormatter 21 | from scipy.spatial.distance import pdist 22 | import matplotlib 23 | import pickle 24 | 25 | def cal_anomaly_map(fs_list, ft_list, out_size=224, amap_mode='mul'): 26 | if amap_mode == 'mul': 27 | anomaly_map = np.ones([out_size, out_size]) 28 | else: 29 | anomaly_map = np.zeros([out_size, out_size]) 30 | a_map_list = [] 31 | for i in range(len(ft_list)): 32 | fs = fs_list[i] 33 | ft = ft_list[i] 34 | #fs_norm = F.normalize(fs, p=2) 35 | #ft_norm = F.normalize(ft, p=2) 36 | a_map = 1 - F.cosine_similarity(fs, ft) 37 | a_map = torch.unsqueeze(a_map, dim=1) 38 | a_map = F.interpolate(a_map, size=out_size, mode='bilinear', align_corners=True) 39 | a_map = a_map[0, 0, :, :].to('cpu').detach().numpy() 40 | a_map_list.append(a_map) 41 | if amap_mode == 'mul': 42 | anomaly_map *= a_map 43 | else: 44 | anomaly_map += a_map 45 | return anomaly_map, a_map_list 46 | 47 | def show_cam_on_image(img, anomaly_map): 48 | #if anomaly_map.shape != img.shape: 49 | # anomaly_map = cv2.applyColorMap(np.uint8(anomaly_map), cv2.COLORMAP_JET) 50 | cam = np.float32(anomaly_map)/255 + np.float32(img)/255 51 | cam = cam / np.max(cam) 52 | return np.uint8(255 * cam) 53 | 54 | def min_max_norm(image): 55 | a_min, a_max = image.min(), image.max() 56 | return (image-a_min)/(a_max - a_min) 57 | 58 | def cvt2heatmap(gray): 59 | heatmap = 
cv2.applyColorMap(np.uint8(gray), cv2.COLORMAP_JET) 60 | return heatmap 61 | 62 | 63 | 64 | def evaluation(encoder, bn, decoder, dataloader,device,_class_=None): 65 | #_, t_bn = resnet50(pretrained=True) 66 | #bn.load_state_dict(bn.state_dict()) 67 | bn.eval() 68 | #bn.training = False 69 | #t_bn.to(device) 70 | #t_bn.load_state_dict(bn.state_dict()) 71 | decoder.eval() 72 | gt_list_px = [] 73 | pr_list_px = [] 74 | gt_list_sp = [] 75 | pr_list_sp = [] 76 | aupro_list = [] 77 | with torch.no_grad(): 78 | for img, gt, label, _ in dataloader: 79 | 80 | img = img.to(device) 81 | inputs = encoder(img) 82 | outputs = decoder(bn(inputs)) 83 | anomaly_map, _ = cal_anomaly_map(inputs, outputs, img.shape[-1], amap_mode='a') 84 | anomaly_map = gaussian_filter(anomaly_map, sigma=4) 85 | gt[gt > 0.5] = 1 86 | gt[gt <= 0.5] = 0 87 | if label.item()!=0: 88 | aupro_list.append(compute_pro(gt.squeeze(0).cpu().numpy().astype(int), 89 | anomaly_map[np.newaxis,:,:])) 90 | gt_list_px.extend(gt.cpu().numpy().astype(int).ravel()) 91 | pr_list_px.extend(anomaly_map.ravel()) 92 | gt_list_sp.append(np.max(gt.cpu().numpy().astype(int))) 93 | pr_list_sp.append(np.max(anomaly_map)) 94 | 95 | #ano_score = (pr_list_sp - np.min(pr_list_sp)) / (np.max(pr_list_sp) - np.min(pr_list_sp)) 96 | #vis_data = {} 97 | #vis_data['Anomaly Score'] = ano_score 98 | #vis_data['Ground Truth'] = np.array(gt_list_sp) 99 | # print(type(vis_data)) 100 | # np.save('vis.npy',vis_data) 101 | #with open('{}_vis.pkl'.format(_class_), 'wb') as f: 102 | # pickle.dump(vis_data, f, pickle.HIGHEST_PROTOCOL) 103 | 104 | 105 | auroc_px = round(roc_auc_score(gt_list_px, pr_list_px), 3) 106 | auroc_sp = round(roc_auc_score(gt_list_sp, pr_list_sp), 3) 107 | return auroc_px, auroc_sp, round(np.mean(aupro_list),3) 108 | 109 | def test(_class_): 110 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 111 | print(device) 112 | print(_class_) 113 | 114 | data_transform, gt_transform = get_data_transforms(256, 256) 115 | test_path = '../mvtec/' + _class_ 116 | ckp_path = './checkpoints/' + 'rm_1105_wres50_ff_mm_' + _class_ + '.pth' 117 | test_data = MVTecDataset(root=test_path, transform=data_transform, gt_transform=gt_transform, phase="test") 118 | test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=False) 119 | encoder, bn = wide_resnet50_2(pretrained=True) 120 | encoder = encoder.to(device) 121 | bn = bn.to(device) 122 | encoder.eval() 123 | decoder = de_wide_resnet50_2(pretrained=False) 124 | decoder = decoder.to(device) 125 | ckp = torch.load(ckp_path) 126 | for k, v in list(ckp['bn'].items()): 127 | if 'memory' in k: 128 | ckp['bn'].pop(k) 129 | decoder.load_state_dict(ckp['decoder']) 130 | bn.load_state_dict(ckp['bn']) 131 | auroc_px, auroc_sp, aupro_px = evaluation(encoder, bn, decoder, test_dataloader, device,_class_) 132 | print(_class_,':',auroc_px,',',auroc_sp,',',aupro_px) 133 | return auroc_px 134 | 135 | import os 136 | 137 | def visualization(_class_): 138 | print(_class_) 139 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 140 | print(device) 141 | 142 | data_transform, gt_transform = get_data_transforms(256, 256) 143 | test_path = '../mvtec/' + _class_ 144 | ckp_path = './checkpoints/' + 'rm_1105_wres50_ff_mm_'+_class_+'.pth' 145 | test_data = MVTecDataset(root=test_path, transform=data_transform, gt_transform=gt_transform, phase="test") 146 | test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=False) 147 | 148 | encoder, bn = wide_resnet50_2(pretrained=True) 149 | encoder 
import os

def visualization(_class_):
    print(_class_)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(device)

    data_transform, gt_transform = get_data_transforms(256, 256)
    test_path = '../mvtec/' + _class_
    ckp_path = './checkpoints/' + 'rm_1105_wres50_ff_mm_' + _class_ + '.pth'
    test_data = MVTecDataset(root=test_path, transform=data_transform, gt_transform=gt_transform, phase="test")
    test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=False)

    encoder, bn = wide_resnet50_2(pretrained=True)
    encoder = encoder.to(device)
    bn = bn.to(device)
    encoder.eval()
    decoder = de_wide_resnet50_2(pretrained=False)
    decoder = decoder.to(device)
    ckp = torch.load(ckp_path, map_location=device)
    for k in list(ckp['bn'].keys()):
        if 'memory' in k:
            ckp['bn'].pop(k)
    decoder.load_state_dict(ckp['decoder'])
    bn.load_state_dict(ckp['bn'])
    decoder.eval()
    bn.eval()

    count = 0
    with torch.no_grad():
        for img, gt, label, _ in test_dataloader:
            # Only anomalous samples are visualized.
            if label.item() == 0:
                continue

            img = img.to(device)
            inputs = encoder(img)
            outputs = decoder(bn(inputs))

            # Build the anomaly map from the last (coarsest) feature pair only.
            anomaly_map, _ = cal_anomaly_map([inputs[-1]], [outputs[-1]], img.shape[-1], amap_mode='a')
            anomaly_map = gaussian_filter(anomaly_map, sigma=4)
            ano_map = min_max_norm(anomaly_map)
            ano_map = cvt2heatmap(ano_map * 255)
            # Swap channels for OpenCV and rescale the normalized tensor back to uint8.
            img = cv2.cvtColor(img.permute(0, 2, 3, 1).cpu().numpy()[0] * 255, cv2.COLOR_BGR2RGB)
            img = np.uint8(min_max_norm(img) * 255)
            ano_map = show_cam_on_image(img, ano_map)
            plt.imshow(ano_map)
            plt.axis('off')
            plt.show()
            # Optionally save the overlay and mask instead of showing them:
            # cv2.imwrite('./results_all/' + _class_ + '/' + str(count) + '_ad.png', ano_map)
            # cv2.imwrite('./results/' + _class_ + '_' + str(count) + '_gt.png', gt)
            gt = gt.cpu().numpy().astype(int)[0][0] * 255

            count += 1
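
# Editor's sketch: the commented-out lines above hint at how overlays were
# saved to './results_all/<class>/'. A helper along those lines; the function
# name and argument layout are our assumptions, only the directory scheme is
# taken from the original comments.
def save_overlay(ano_map, _class_, count, out_root='./results_all'):
    # Create the per-class output directory, then write the heatmap overlay.
    out_dir = os.path.join(out_root, _class_)
    os.makedirs(out_dir, exist_ok=True)
    cv2.imwrite(os.path.join(out_dir, str(count) + '_ad.png'), ano_map)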
def vis_nd(name, _class_):
    print(name, ':', _class_)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(device)

    ckp_path = './checkpoints/' + name + '_' + str(_class_) + '.pth'
    train_dataloader, test_dataloader = load_data(name, _class_, batch_size=16)

    encoder, bn = resnet18(pretrained=True)
    encoder = encoder.to(device)
    bn = bn.to(device)
    encoder.eval()
    decoder = de_resnet18(pretrained=False)
    decoder = decoder.to(device)

    ckp = torch.load(ckp_path, map_location=device)
    decoder.load_state_dict(ckp['decoder'])
    bn.load_state_dict(ckp['bn'])
    decoder.eval()
    bn.eval()

    gt_list_sp = []
    prmax_list_sp = []
    prmean_list_sp = []

    # Create the output directory up front; cv2.imwrite fails silently otherwise.
    os.makedirs('./nd_results', exist_ok=True)

    count = 0
    with torch.no_grad():
        for img, label in test_dataloader:
            # Grayscale datasets (MNIST, FashionMNIST) need three channels.
            if img.shape[1] == 1:
                img = img.repeat(1, 3, 1, 1)
            img = img.to(device)
            inputs = encoder(img)
            outputs = decoder(bn(inputs))

            anomaly_map, _ = cal_anomaly_map(inputs, outputs, img.shape[-1], amap_mode='a')
            ano_map = min_max_norm(anomaly_map)
            ano_map = cvt2heatmap(ano_map * 255)
            img = cv2.cvtColor(img.permute(0, 2, 3, 1).cpu().numpy()[0] * 255, cv2.COLOR_BGR2RGB)
            img = np.uint8(min_max_norm(img) * 255)
            cv2.imwrite('./nd_results/' + name + '_' + str(_class_) + '_' + str(count) + '_org.png', img)
            ano_map = show_cam_on_image(img, ano_map)
            cv2.imwrite('./nd_results/' + name + '_' + str(_class_) + '_' + str(count) + '_ad.png', ano_map)

            gt_list_sp.extend(label.cpu().data.numpy())
            prmax_list_sp.append(np.max(anomaly_map))
            # Despite the name, the "mean" score is the sum over the anomaly map.
            prmean_list_sp.append(np.sum(anomaly_map))
            count += 1

    # One-vs-rest relabeling: the normal class becomes 0, everything else 1.
    gt_list_sp = np.array(gt_list_sp)
    indx1 = gt_list_sp == _class_
    indx2 = gt_list_sp != _class_
    gt_list_sp[indx1] = 0
    gt_list_sp[indx2] = 1

    prmean_list_sp = np.array(prmean_list_sp)
    ano_score = (prmean_list_sp - np.min(prmean_list_sp)) / (np.max(prmean_list_sp) - np.min(prmean_list_sp))
    vis_data = {}
    vis_data['Anomaly Score'] = ano_score
    vis_data['Ground Truth'] = np.array(gt_list_sp)
    with open('vis.pkl', 'wb') as f:
        pickle.dump(vis_data, f, pickle.HIGHEST_PROTOCOL)


def compute_pro(masks: ndarray, amaps: ndarray, num_th: int = 200) -> float:
    """Compute the area under the per-region overlap (PRO) curve, restricted to FPR in [0, 0.3].

    Args:
        masks (ndarray): All binary masks in test. masks.shape -> (num_test_data, h, w)
        amaps (ndarray): All anomaly maps in test. amaps.shape -> (num_test_data, h, w)
        num_th (int, optional): Number of thresholds
    """
    assert isinstance(amaps, ndarray), "type(amaps) must be ndarray"
    assert isinstance(masks, ndarray), "type(masks) must be ndarray"
    assert amaps.ndim == 3, "amaps.ndim must be 3 (num_test_data, h, w)"
    assert masks.ndim == 3, "masks.ndim must be 3 (num_test_data, h, w)"
    assert amaps.shape == masks.shape, "amaps.shape and masks.shape must be same"
    assert set(masks.flatten()) == {0, 1}, "set(masks.flatten()) must be {0, 1}"
    assert isinstance(num_th, int), "type(num_th) must be int"

    # np.bool was removed in NumPy 1.24; use the builtin bool dtype instead.
    binary_amaps = np.zeros_like(amaps, dtype=bool)
    rows = []

    min_th = amaps.min()
    max_th = amaps.max()
    delta = (max_th - min_th) / num_th

    for th in np.arange(min_th, max_th, delta):
        binary_amaps[amaps <= th] = 0
        binary_amaps[amaps > th] = 1

        # Per-region overlap: the fraction of each ground-truth region covered.
        pros = []
        for binary_amap, mask in zip(binary_amaps, masks):
            for region in measure.regionprops(measure.label(mask)):
                axes0_ids = region.coords[:, 0]
                axes1_ids = region.coords[:, 1]
                tp_pixels = binary_amap[axes0_ids, axes1_ids].sum()
                pros.append(tp_pixels / region.area)

        inverse_masks = 1 - masks
        fp_pixels = np.logical_and(inverse_masks, binary_amaps).sum()
        fpr = fp_pixels / inverse_masks.sum()

        # DataFrame.append was removed in pandas 2.0; collect rows in a list instead.
        rows.append({"pro": mean(pros), "fpr": fpr, "threshold": th})

    df = pd.DataFrame(rows, columns=["pro", "fpr", "threshold"])

    # Keep the low-FPR part of the curve and rescale FPR from [0, 0.3] to [0, 1].
    df = df[df["fpr"] < 0.3]
    df["fpr"] = df["fpr"] / df["fpr"].max()

    pro_auc = auc(df["fpr"], df["pro"])
    return pro_auc
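
# Editor's sketch: a synthetic sanity check for compute_pro(), not part of the
# original file. An anomaly map that matches the mask (plus a little noise)
# should score a PRO-AUC close to 1.0; the shapes follow the
# (num_test_data, h, w) contract documented above.
def _pro_sanity_check():
    rng = np.random.default_rng(0)
    masks = np.zeros((4, 64, 64), dtype=int)
    masks[:, 20:40, 20:40] = 1  # one square defect per image
    amaps = masks.astype(float) + 0.05 * rng.random(masks.shape)
    print('PRO-AUC on near-perfect maps:', compute_pro(masks, amaps))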

def detection(encoder, bn, decoder, dataloader, device, _class_):
    bn.eval()
    decoder.eval()
    gt_list_sp = []
    prmax_list_sp = []
    prmean_list_sp = []
    with torch.no_grad():
        for img, label in dataloader:
            img = img.to(device)
            if img.shape[1] == 1:
                img = img.repeat(1, 3, 1, 1)
            label = label.to(device)
            inputs = encoder(img)
            outputs = decoder(bn(inputs))
            anomaly_map, _ = cal_anomaly_map(inputs, outputs, img.shape[-1], 'acc')
            anomaly_map = gaussian_filter(anomaly_map, sigma=4)

            gt_list_sp.extend(label.cpu().data.numpy())
            prmax_list_sp.append(np.max(anomaly_map))
            # As in vis_nd, the "mean" score is actually the map's sum.
            prmean_list_sp.append(np.sum(anomaly_map))

    # One-vs-rest relabeling: the normal class becomes 0, everything else 1.
    gt_list_sp = np.array(gt_list_sp)
    indx1 = gt_list_sp == _class_
    indx2 = gt_list_sp != _class_
    gt_list_sp[indx1] = 0
    gt_list_sp[indx2] = 1

    auroc_sp_max = round(roc_auc_score(gt_list_sp, prmax_list_sp), 4)
    auroc_sp_mean = round(roc_auc_score(gt_list_sp, prmean_list_sp), 4)
    return auroc_sp_max, auroc_sp_mean
--------------------------------------------------------------------------------