├── vit_version ├── clip │ ├── __init__.py │ ├── bpe_simple_vocab_16e6.txt.gz │ ├── utils.py │ ├── prompts.py │ ├── simple_tokenizer.py │ ├── clip.py │ └── model.py ├── clip_rar │ ├── __init__.py │ ├── bpe_simple_vocab_16e6.txt.gz │ ├── utils.py │ ├── prompts.py │ ├── simple_tokenizer.py │ └── clip.py ├── clip_bn.py ├── clip_decoder.py ├── train_vit.py └── test_vit.py ├── requirements.txt ├── utils ├── utils.py ├── perlin.py └── vis.py ├── README.md ├── datasets ├── MvTec3D.py ├── MvTec.py ├── VisA.py └── database.py ├── models ├── recons_net.py ├── loss.py └── de_resnet.py ├── scripts └── fore_extractor.py ├── train.py └── test.py /vit_version/clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import * 2 | -------------------------------------------------------------------------------- /vit_version/clip_rar/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import * 2 | -------------------------------------------------------------------------------- /vit_version/clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hui-design/AAND/HEAD/vit_version/clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /vit_version/clip_rar/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hui-design/AAND/HEAD/vit_version/clip_rar/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | pandas 3 | scipy 4 | scikit-learn 5 | pillow 6 | seaborn 7 | opencv-python 8 | scikit-image 9 | tqdm 10 | imgaug 11 | ftfy 12 | regex -------------------------------------------------------------------------------- /vit_version/clip/utils.py: -------------------------------------------------------------------------------- 1 | from itertools import repeat 2 | import collections.abc 3 | 4 | from torch import nn as nn 5 | from torchvision.ops.misc import FrozenBatchNorm2d 6 | 7 | 8 | def freeze_batch_norm_2d(module, module_match={}, name=''): 9 | """ 10 | Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is 11 | itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and 12 | returned. Otherwise, the module is walked recursively and submodules are converted in place. 13 | 14 | Args: 15 | module (torch.nn.Module): Any PyTorch module. 
16 | module_match (dict): Dictionary of full module names to freeze (all if empty) 17 | name (str): Full module name (prefix) 18 | 19 | Returns: 20 | torch.nn.Module: Resulting module 21 | 22 | Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 23 | """ 24 | res = module 25 | is_match = True 26 | if module_match: 27 | is_match = name in module_match 28 | if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)): 29 | res = FrozenBatchNorm2d(module.num_features) 30 | res.num_features = module.num_features 31 | res.affine = module.affine 32 | if module.affine: 33 | res.weight.data = module.weight.data.clone().detach() 34 | res.bias.data = module.bias.data.clone().detach() 35 | res.running_mean.data = module.running_mean.data 36 | res.running_var.data = module.running_var.data 37 | res.eps = module.eps 38 | else: 39 | for child_name, child in module.named_children(): 40 | full_child_name = '.'.join([name, child_name]) if name else child_name 41 | new_child = freeze_batch_norm_2d(child, module_match, full_child_name) 42 | if new_child is not child: 43 | res.add_module(child_name, new_child) 44 | return res 45 | 46 | 47 | # From PyTorch internals 48 | def _ntuple(n): 49 | def parse(x): 50 | if isinstance(x, collections.abc.Iterable): 51 | return x 52 | return tuple(repeat(x, n)) 53 | return parse 54 | 55 | 56 | to_1tuple = _ntuple(1) 57 | to_2tuple = _ntuple(2) 58 | to_3tuple = _ntuple(3) 59 | to_4tuple = _ntuple(4) 60 | to_ntuple = lambda n, x: _ntuple(n)(x) 61 | -------------------------------------------------------------------------------- /vit_version/clip_rar/utils.py: -------------------------------------------------------------------------------- 1 | from itertools import repeat 2 | import collections.abc 3 | 4 | from torch import nn as nn 5 | from torchvision.ops.misc import FrozenBatchNorm2d 6 | 7 | 8 | def freeze_batch_norm_2d(module, module_match={}, name=''): 9 | """ 10 | Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is 11 | itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and 12 | returned. Otherwise, the module is walked recursively and submodules are converted in place. 13 | 14 | Args: 15 | module (torch.nn.Module): Any PyTorch module. 
16 | module_match (dict): Dictionary of full module names to freeze (all if empty) 17 | name (str): Full module name (prefix) 18 | 19 | Returns: 20 | torch.nn.Module: Resulting module 21 | 22 | Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 23 | """ 24 | res = module 25 | is_match = True 26 | if module_match: 27 | is_match = name in module_match 28 | if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)): 29 | res = FrozenBatchNorm2d(module.num_features) 30 | res.num_features = module.num_features 31 | res.affine = module.affine 32 | if module.affine: 33 | res.weight.data = module.weight.data.clone().detach() 34 | res.bias.data = module.bias.data.clone().detach() 35 | res.running_mean.data = module.running_mean.data 36 | res.running_var.data = module.running_var.data 37 | res.eps = module.eps 38 | else: 39 | for child_name, child in module.named_children(): 40 | full_child_name = '.'.join([name, child_name]) if name else child_name 41 | new_child = freeze_batch_norm_2d(child, module_match, full_child_name) 42 | if new_child is not child: 43 | res.add_module(child_name, new_child) 44 | return res 45 | 46 | 47 | # From PyTorch internals 48 | def _ntuple(n): 49 | def parse(x): 50 | if isinstance(x, collections.abc.Iterable): 51 | return x 52 | return tuple(repeat(x, n)) 53 | return parse 54 | 55 | 56 | to_1tuple = _ntuple(1) 57 | to_2tuple = _ntuple(2) 58 | to_3tuple = _ntuple(3) 59 | to_4tuple = _ntuple(4) 60 | to_ntuple = lambda n, x: _ntuple(n)(x) 61 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import numpy as np 4 | 5 | def setup_seed(seed): 6 | torch.manual_seed(seed) 7 | torch.cuda.manual_seed_all(seed) 8 | np.random.seed(seed) 9 | random.seed(seed) 10 | torch.backends.cudnn.deterministic = True 11 | torch.backends.cudnn.benchmark = False 12 | 13 | def get_dataset_name(data_path): 14 | if 'mvtec' in data_path.lower() and 'mvtec3d' not in data_path.lower(): 15 | return 'mvtec' 16 | 17 | elif 'visa' in data_path.lower(): 18 | return 'VisA' 19 | 20 | elif 'mvtec3d' in data_path.lower(): 21 | return 'mvtec3d' 22 | 23 | else: 24 | raise ValueError('No such dataset') 25 | 26 | 27 | # Image handling classes. 28 | class PatchMaker: 29 | def __init__(self, patchsize, stride=None): 30 | self.patchsize = patchsize 31 | self.stride = stride 32 | 33 | def patchify(self, features, return_spatial_info=False): 34 | """Convert a tensor into a tensor of respective patches. 
35 | Args: 36 | x: [torch.Tensor, bs x c x w x h] 37 | Returns: 38 | x: [torch.Tensor, bs * w//stride * h//stride, c, patchsize, 39 | patchsize] 40 | """ 41 | padding = int((self.patchsize - 1) / 2) 42 | unfolder = torch.nn.Unfold( 43 | kernel_size=self.patchsize, stride=self.stride, padding=padding, dilation=1 44 | ) 45 | unfolded_features = unfolder(features) 46 | number_of_total_patches = [] 47 | for s in features.shape[-2:]: 48 | n_patches = ( 49 | s + 2 * padding - 1 * (self.patchsize - 1) - 1 50 | ) / self.stride + 1 51 | number_of_total_patches.append(int(n_patches)) 52 | unfolded_features = unfolded_features.reshape( 53 | *features.shape[:2], self.patchsize, self.patchsize, -1 54 | ) 55 | unfolded_features = unfolded_features.permute(0, 4, 1, 2, 3) 56 | 57 | if return_spatial_info: 58 | return unfolded_features, number_of_total_patches 59 | return unfolded_features 60 | 61 | def unpatch_scores(self, x, batchsize): 62 | return x.reshape(batchsize, -1, *x.shape[1:]) 63 | 64 | def score(self, x): 65 | was_numpy = False 66 | if isinstance(x, np.ndarray): 67 | was_numpy = True 68 | x = torch.from_numpy(x) 69 | while x.ndim > 1: 70 | x = torch.max(x, dim=-1).values 71 | if was_numpy: 72 | return x.numpy() 73 | return x -------------------------------------------------------------------------------- /vit_version/clip/prompts.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import clip 4 | import pdb 5 | 6 | # class Prompts(nn.Module): 7 | # def __init__(self, initials=None, clip_model=None, len_text=None, pretrained=None): 8 | # super(Prompts,self).__init__() 9 | # print("The initial prompts are:",initials) 10 | # # tokenized, embedding 11 | # self.tokenized_prompts = clip.tokenize(initials) 12 | # embedding = clip_model.token_embedding(self.tokenized_prompts.cuda()) 13 | # self.prefix, self.suffix = embedding[:, :1, :], embedding[:, 1+len_text:, :] 14 | # # learnable embedding 15 | # ctx_vectors = torch.empty(2, len_text, 768, dtype=embedding.dtype) 16 | # nn.init.normal_(ctx_vectors, std=0.02) 17 | # self.embedding_learn = nn.Parameter(ctx_vectors) 18 | 19 | # # state_dict = torch.load(pretrained) 20 | # # self.embedding_prompt = state_dict['embedding_prompt'].cuda() 21 | 22 | # def forward(self, clip_model): 23 | # # pdb.set_trace() 24 | # embedding_prompt = torch.cat((self.prefix, self.embedding_learn, self.suffix), dim=1) 25 | # text_features = clip_model.encode_text(embedding_prompt, self.tokenized_prompts) 26 | # text_features /= text_features.norm(dim=-1, keepdim=True) 27 | # return text_features 28 | 29 | 30 | class Prompts(nn.Module): 31 | def __init__(self, initials=None, clip_model=None, len_text=None, pretrained=None, device='cuda'): 32 | super(Prompts,self).__init__() 33 | # initials[0] += ' normal person' 34 | # initials[1] += ' abnormal person' 35 | print("The initial prompts are:",initials) 36 | # tokenized, embedding 37 | # pdb.set_trace() 38 | self.tokenized_prompts = clip.tokenize(initials).to(device) 39 | self.model = clip_model 40 | embedding = self.model.token_embedding(self.tokenized_prompts).to(device) 41 | self.prefix, self.suffix = embedding[:, :1, :], embedding[:, 1+len_text:, :] 42 | self.device = device 43 | # self.prefix.requires_grad = False 44 | # self.suffix.requires_grad = False 45 | if pretrained is None: 46 | ctx_vectors = torch.zeros_like(embedding[:, 1:1+len_text, :]) 47 | nn.init.normal_(ctx_vectors, std=0.02) 48 | self.embedding_prompt = 
nn.Parameter(ctx_vectors.requires_grad_()).to(device) 49 | # self.embedding_prompt = nn.Parameter(embedding[:, 1:1+len_text, :].requires_grad_()).to(device) 50 | # self.embedding_prompt = nn.ParameterList([nn.Parameter(torch.zeros(2,768).requires_grad_())]).to(device) 51 | # self.embedding_prompt = nn.Parameter(torch.randn(2,512).requires_grad_()) 52 | else: 53 | state_dict = torch.load(pretrained)['embedding_prompt'] 54 | self.embedding_prompt = state_dict.to(self.device) 55 | 56 | # ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype) 57 | # nn.init.normal_(ctx_vectors, std=0.02) 58 | # text_features = self.embedding_prompt 59 | # text_features = text_features / text_features.norm(dim=-1, keepdim=True) 60 | 61 | def forward(self): 62 | pdb.set_trace() 63 | self.embedding_prompt_cat = torch.cat((self.prefix, self.embedding_prompt, self.suffix), dim=1) 64 | text_features = self.model.encode_text(self.embedding_prompt_cat, self.tokenized_prompts) 65 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 66 | return text_features 67 | 68 | 69 | 70 | class Linear(nn.Module): 71 | def __init__(self, in_dim=512, out_dim=512): 72 | super(Linear,self).__init__() 73 | self.mlp = nn.Linear(in_dim, out_dim) 74 | # self.mlp = nn.Sequential(nn.Linear(in_dim, out_dim), 75 | # nn.ReLU(), 76 | # nn.Linear(in_dim, out_dim)) 77 | 78 | 79 | def forward(self, x): 80 | return self.mlp(x) -------------------------------------------------------------------------------- /vit_version/clip_rar/prompts.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import clip 4 | import pdb 5 | 6 | # class Prompts(nn.Module): 7 | # def __init__(self, initials=None, clip_model=None, len_text=None, pretrained=None): 8 | # super(Prompts,self).__init__() 9 | # print("The initial prompts are:",initials) 10 | # # tokenized, embedding 11 | # self.tokenized_prompts = clip.tokenize(initials) 12 | # embedding = clip_model.token_embedding(self.tokenized_prompts.cuda()) 13 | # self.prefix, self.suffix = embedding[:, :1, :], embedding[:, 1+len_text:, :] 14 | # # learnable embedding 15 | # ctx_vectors = torch.empty(2, len_text, 768, dtype=embedding.dtype) 16 | # nn.init.normal_(ctx_vectors, std=0.02) 17 | # self.embedding_learn = nn.Parameter(ctx_vectors) 18 | 19 | # # state_dict = torch.load(pretrained) 20 | # # self.embedding_prompt = state_dict['embedding_prompt'].cuda() 21 | 22 | # def forward(self, clip_model): 23 | # # pdb.set_trace() 24 | # embedding_prompt = torch.cat((self.prefix, self.embedding_learn, self.suffix), dim=1) 25 | # text_features = clip_model.encode_text(embedding_prompt, self.tokenized_prompts) 26 | # text_features /= text_features.norm(dim=-1, keepdim=True) 27 | # return text_features 28 | 29 | 30 | class Prompts(nn.Module): 31 | def __init__(self, initials=None, clip_model=None, len_text=None, pretrained=None, device='cuda'): 32 | super(Prompts,self).__init__() 33 | # initials[0] += ' normal person' 34 | # initials[1] += ' abnormal person' 35 | print("The initial prompts are:",initials) 36 | # tokenized, embedding 37 | # pdb.set_trace() 38 | self.tokenized_prompts = clip.tokenize(initials).to(device) 39 | self.model = clip_model 40 | embedding = self.model.token_embedding(self.tokenized_prompts).to(device) 41 | self.prefix, self.suffix = embedding[:, :1, :], embedding[:, 1+len_text:, :] 42 | self.device = device 43 | # self.prefix.requires_grad = False 44 | # self.suffix.requires_grad = False 45 | if pretrained is 
None: 46 | ctx_vectors = torch.zeros_like(embedding[:, 1:1+len_text, :]) 47 | nn.init.normal_(ctx_vectors, std=0.02) 48 | self.embedding_prompt = nn.Parameter(ctx_vectors.requires_grad_()).to(device) 49 | # self.embedding_prompt = nn.Parameter(embedding[:, 1:1+len_text, :].requires_grad_()).to(device) 50 | # self.embedding_prompt = nn.ParameterList([nn.Parameter(torch.zeros(2,768).requires_grad_())]).to(device) 51 | # self.embedding_prompt = nn.Parameter(torch.randn(2,512).requires_grad_()) 52 | else: 53 | state_dict = torch.load(pretrained)['embedding_prompt'] 54 | self.embedding_prompt = state_dict.to(self.device) 55 | 56 | # ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype) 57 | # nn.init.normal_(ctx_vectors, std=0.02) 58 | # text_features = self.embedding_prompt 59 | # text_features = text_features / text_features.norm(dim=-1, keepdim=True) 60 | 61 | def forward(self): 62 | pdb.set_trace() 63 | self.embedding_prompt_cat = torch.cat((self.prefix, self.embedding_prompt, self.suffix), dim=1) 64 | text_features = self.model.encode_text(self.embedding_prompt_cat, self.tokenized_prompts) 65 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 66 | return text_features 67 | 68 | 69 | 70 | class Linear(nn.Module): 71 | def __init__(self, in_dim=512, out_dim=512): 72 | super(Linear,self).__init__() 73 | self.mlp = nn.Linear(in_dim, out_dim) 74 | # self.mlp = nn.Sequential(nn.Linear(in_dim, out_dim), 75 | # nn.ReLU(), 76 | # nn.Linear(in_dim, out_dim)) 77 | 78 | 79 | def forward(self, x): 80 | return self.mlp(x) -------------------------------------------------------------------------------- /vit_version/clip_bn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | import torch.nn as nn 4 | try: 5 | from torch.hub import load_state_dict_from_url 6 | except ImportError: 7 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 8 | import os, pdb, sys 9 | sys.path.append(os.getcwd()) 10 | from typing import Type, Any, Callable, Union, List, Optional 11 | from models.recons_net import RAR_single, MLP 12 | from models.resnet_rar import BasicBlock, Bottleneck, AttnBottleneck, conv1x1, conv3x3 13 | import pdb 14 | 15 | 16 | class BN_layer(nn.Module): 17 | def __init__(self, 18 | block: Type[Union[BasicBlock, Bottleneck]], 19 | layers: int, 20 | groups: int = 1, 21 | width_per_group: int = 64, 22 | norm_layer: Optional[Callable[..., nn.Module]] = None, 23 | ): 24 | super(BN_layer, self).__init__() 25 | if norm_layer is None: 26 | norm_layer = nn.BatchNorm2d 27 | self._norm_layer = norm_layer 28 | self.groups = groups 29 | self.base_width = width_per_group 30 | self.inplanes = 768 31 | self.dilation = 1 32 | # self.bn_layer = self._make_layer(block, 768, layers, stride=1) 33 | 34 | self.conv1 = conv1x1(768, 768, 1) 35 | self.bn1 = norm_layer(768) 36 | self.relu = nn.ReLU(inplace=True) 37 | self.conv2 = conv1x1(768, 768, 1) 38 | self.bn2 = norm_layer(768) 39 | self.conv3 = conv1x1(768, 768, 1) 40 | self.bn3 = norm_layer(768) 41 | # self.conv3 = conv3x3(128 * block.expansion, 256 * block.expansion, 2) 42 | # self.bn3 = norm_layer(256 * block.expansion) 43 | 44 | self.conv4 = conv1x1(768*3, 768, 1) 45 | self.bn4 = norm_layer(768) 46 | 47 | 48 | for m in self.modules(): 49 | if isinstance(m, nn.Conv2d): 50 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 51 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 52 | nn.init.constant_(m.weight, 1) 53 | 
nn.init.constant_(m.bias, 0) 54 | 55 | # def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int, 56 | # stride: int = 1, dilate: bool = False) -> nn.Sequential: 57 | # norm_layer = self._norm_layer 58 | # downsample = None 59 | # previous_dilation = self.dilation 60 | # if dilate: 61 | # self.dilation *= stride 62 | # stride = 1 63 | # if stride != 1 or self.inplanes != planes * block.expansion: 64 | # downsample = nn.Sequential( 65 | # conv1x1(self.inplanes, planes, stride), 66 | # norm_layer(planes * block.expansion), 67 | # ) 68 | 69 | # layers = [] 70 | # # layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 71 | # # self.base_width, previous_dilation, norm_layer)) 72 | # # self.inplanes = planes * block.expansion 73 | # self.inplances = planes 74 | # for _ in range(blocks): 75 | # layers.append(block(self.inplanes, planes, groups=self.groups, 76 | # base_width=self.base_width, dilation=self.dilation, 77 | # norm_layer=norm_layer)) 78 | 79 | # return nn.Sequential(*layers) 80 | 81 | def _forward_impl(self, x: Tensor) -> Tensor: 82 | # See note [TorchScript super()] 83 | #x = self.cbam(x) 84 | l1 = self.relu(self.bn1(self.conv1(x[0]))) 85 | l2 = self.relu(self.bn2(self.conv2(x[1]))) 86 | l3 = self.relu(self.bn3(self.conv3(x[2]))) 87 | feature = torch.cat([l1,l2,l3],1) 88 | output = self.conv4(feature) 89 | # output = self.bn_layer(feature) 90 | #x = self.avgpool(feature_d) 91 | #x = torch.flatten(x, 1) 92 | #x = self.fc(x) 93 | 94 | return output.contiguous() 95 | 96 | def forward(self, x: Tensor) -> Tensor: 97 | return self._forward_impl(x) 98 | 99 | if __name__ == '__main__': 100 | x = [torch.randn((1,768,16,16))]*3 101 | bn = BN_layer(AttnBottleneck, 3) 102 | pdb.set_trace() 103 | x_new = bn(x) -------------------------------------------------------------------------------- /utils/perlin.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import numpy as np 4 | 5 | def lerp_np(x,y,w): 6 | fin_out = (y-x)*w + x 7 | return fin_out 8 | 9 | def generate_fractal_noise_2d(shape, res, octaves=1, persistence=0.5): 10 | noise = np.zeros(shape) 11 | frequency = 1 12 | amplitude = 1 13 | for _ in range(octaves): 14 | noise += amplitude * generate_perlin_noise_2d(shape, (frequency*res[0], frequency*res[1])) 15 | frequency *= 2 16 | amplitude *= persistence 17 | return noise 18 | 19 | 20 | def generate_perlin_noise_2d(shape, res): 21 | def f(t): 22 | return 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3 23 | 24 | delta = (res[0] / shape[0], res[1] / shape[1]) 25 | d = (shape[0] // res[0], shape[1] // res[1]) 26 | grid = np.mgrid[0:res[0]:delta[0], 0:res[1]:delta[1]].transpose(1, 2, 0) % 1 27 | # Gradients 28 | angles = 2 * np.pi * np.random.rand(res[0] + 1, res[1] + 1) 29 | gradients = np.dstack((np.cos(angles), np.sin(angles))) 30 | g00 = gradients[0:-1, 0:-1].repeat(d[0], 0).repeat(d[1], 1) 31 | g10 = gradients[1:, 0:-1].repeat(d[0], 0).repeat(d[1], 1) 32 | g01 = gradients[0:-1, 1:].repeat(d[0], 0).repeat(d[1], 1) 33 | g11 = gradients[1:, 1:].repeat(d[0], 0).repeat(d[1], 1) 34 | # Ramps 35 | n00 = np.sum(grid * g00, 2) 36 | n10 = np.sum(np.dstack((grid[:, :, 0] - 1, grid[:, :, 1])) * g10, 2) 37 | n01 = np.sum(np.dstack((grid[:, :, 0], grid[:, :, 1] - 1)) * g01, 2) 38 | n11 = np.sum(np.dstack((grid[:, :, 0] - 1, grid[:, :, 1] - 1)) * g11, 2) 39 | # Interpolation 40 | t = f(grid) 41 | n0 = n00 * (1 - t[:, :, 0]) + t[:, :, 0] * n10 42 | n1 = n01 * (1 - t[:, :, 0]) + t[:, :, 0] 
* n11 43 | return np.sqrt(2) * ((1 - t[:, :, 1]) * n0 + t[:, :, 1] * n1) 44 | 45 | 46 | def rand_perlin_2d_np(shape, res, fade=lambda t: 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3): 47 | delta = (res[0] / shape[0], res[1] / shape[1]) 48 | d = (shape[0] // res[0], shape[1] // res[1]) 49 | grid = np.mgrid[0:res[0]:delta[0], 0:res[1]:delta[1]].transpose(1, 2, 0) % 1 50 | 51 | angles = 2 * math.pi * np.random.rand(res[0] + 1, res[1] + 1) 52 | gradients = np.stack((np.cos(angles), np.sin(angles)), axis=-1) 53 | tt = np.repeat(np.repeat(gradients,d[0],axis=0),d[1],axis=1) 54 | 55 | tile_grads = lambda slice1, slice2: np.repeat(np.repeat(gradients[slice1[0]:slice1[1], slice2[0]:slice2[1]],d[0],axis=0),d[1],axis=1) 56 | dot = lambda grad, shift: ( 57 | np.stack((grid[:shape[0], :shape[1], 0] + shift[0], grid[:shape[0], :shape[1], 1] + shift[1]), 58 | axis=-1) * grad[:shape[0], :shape[1]]).sum(axis=-1) 59 | 60 | n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0]) 61 | n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0]) 62 | n01 = dot(tile_grads([0, -1], [1, None]), [0, -1]) 63 | n11 = dot(tile_grads([1, None], [1, None]), [-1, -1]) 64 | t = fade(grid[:shape[0], :shape[1]]) 65 | return math.sqrt(2) * lerp_np(lerp_np(n00, n10, t[..., 0]), lerp_np(n01, n11, t[..., 0]), t[..., 1]) 66 | 67 | 68 | def rand_perlin_2d(shape, res, fade=lambda t: 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3): 69 | delta = (res[0] / shape[0], res[1] / shape[1]) 70 | d = (shape[0] // res[0], shape[1] // res[1]) 71 | 72 | grid = torch.stack(torch.meshgrid(torch.arange(0, res[0], delta[0]), torch.arange(0, res[1], delta[1])), dim=-1) % 1 73 | angles = 2 * math.pi * torch.rand(res[0] + 1, res[1] + 1) 74 | gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1) 75 | 76 | tile_grads = lambda slice1, slice2: gradients[slice1[0]:slice1[1], slice2[0]:slice2[1]].repeat_interleave(d[0], 77 | 0).repeat_interleave( 78 | d[1], 1) 79 | dot = lambda grad, shift: ( 80 | torch.stack((grid[:shape[0], :shape[1], 0] + shift[0], grid[:shape[0], :shape[1], 1] + shift[1]), 81 | dim=-1) * grad[:shape[0], :shape[1]]).sum(dim=-1) 82 | 83 | n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0]) 84 | 85 | n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0]) 86 | n01 = dot(tile_grads([0, -1], [1, None]), [0, -1]) 87 | n11 = dot(tile_grads([1, None], [1, None]), [-1, -1]) 88 | t = fade(grid[:shape[0], :shape[1]]) 89 | return math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]) 90 | 91 | 92 | def rand_perlin_2d_octaves(shape, res, octaves=1, persistence=0.5): 93 | noise = torch.zeros(shape) 94 | frequency = 1 95 | amplitude = 1 96 | for _ in range(octaves): 97 | noise += amplitude * rand_perlin_2d(shape, (frequency * res[0], frequency * res[1])) 98 | frequency *= 2 99 | amplitude *= persistence 100 | return noise -------------------------------------------------------------------------------- /vit_version/clip_decoder.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Tuple, Union 3 | import os, pdb, sys 4 | sys.path.append(os.getcwd()) 5 | import numpy as np 6 | import torch 7 | import torch.nn.functional as F 8 | from torch import nn 9 | import math 10 | import pdb 11 | from models.recons_net import patch_to_tensor, tensor_to_patch 12 | from clip.model import LayerNorm, QuickGELU 13 | 14 | 15 | class QuickGELU(nn.Module): 16 | def forward(self, x: torch.Tensor): 17 | return x * torch.sigmoid(1.702 * x) 18 | 19 | 20 | 
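# Note: the QuickGELU defined above re-implements CLIP's fast GELU approximation, x * sigmoid(1.702 * x),
# and shadows the QuickGELU imported from clip.model. The ResidualAttentionBlock below is a standard
# pre-LN Transformer block: x = x + attn(ln_1(x)), then x = x + mlp(ln_2(x)), where the MLP expands
# d_model by 4x (c_fc), applies QuickGELU, and projects back to d_model (c_proj).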
class ResidualAttentionBlock(nn.Module): 21 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 22 | super().__init__() 23 | 24 | self.attn = nn.MultiheadAttention(d_model, n_head) 25 | self.ln_1 = LayerNorm(d_model) 26 | self.mlp = nn.Sequential(OrderedDict([ 27 | ("c_fc", nn.Linear(d_model, d_model * 4)), 28 | ("gelu", QuickGELU()), 29 | ("c_proj", nn.Linear(d_model * 4, d_model)) 30 | ])) 31 | self.ln_2 = LayerNorm(d_model) 32 | self.attn_mask = attn_mask 33 | 34 | def attention(self, x: torch.Tensor): 35 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 36 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 37 | 38 | def forward(self, x: torch.Tensor): 39 | # pdb.set_trace() 40 | x = x + self.attention(self.ln_1(x)) 41 | x = x + self.mlp(self.ln_2(x)) 42 | return x 43 | 44 | 45 | class Transformer(nn.Module): 46 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): 47 | super().__init__() 48 | self.width = width 49 | self.layers = layers 50 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) 51 | 52 | # def forward(self, x: torch.Tensor): 53 | # return self.resblocks(x) 54 | 55 | def forward(self, x, f_list=[]): 56 | # pdb.set_trace() 57 | out_tokens = [] 58 | idx = 0 59 | for r in self.resblocks: 60 | # pdb.set_trace() 61 | idx+=1 62 | x = r(x) 63 | if idx in f_list: 64 | if len(x)==2: 65 | out_tokens.append(x[0]) 66 | out_tokens.append(x[1]) 67 | else: 68 | out_tokens.append(x.permute(1,0,2)) 69 | return x, out_tokens #self.resblocks(x) 70 | 71 | 72 | class de_VisionTransformer(nn.Module): 73 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): 74 | super().__init__() 75 | self.input_resolution = input_resolution 76 | self.grid_size = (input_resolution // patch_size) # modify 77 | self.output_dim = output_dim 78 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) 79 | 80 | scale = width ** -0.5 81 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 82 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) 83 | self.ln_pre = LayerNorm(width) 84 | 85 | self.transformer = Transformer(width, layers, heads) 86 | 87 | self.ln_post = LayerNorm(width) 88 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) 89 | 90 | def forward(self, x: torch.Tensor, f_list: list): 91 | # x = self.conv1(x) # shape = [*, width, grid, grid] 92 | # x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 93 | # x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 94 | # x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] 95 | x = tensor_to_patch(x) + self.positional_embedding.to(x.dtype)[None,1:] 96 | # pdb.set_trace() 97 | x = self.ln_pre(x) 98 | 99 | x = x.permute(1, 0, 2) # NLD -> LND 100 | x, patch_tokens = self.transformer(x, f_list) 101 | x = x.permute(1, 0, 2) # LND -> NLD 102 | 103 | # x = self.ln_post(x[:, 0, :]) 104 | # # pdb.set_trace() 105 | # if self.proj is not None: 106 | # x = x @ self.proj 107 | 108 | patch_tokens = [patch_to_tensor(x) for x in patch_tokens] 109 | 110 | return x, patch_tokens 111 | 112 | 113 | if __name__ == '__main__': 114 | x = 
torch.randn((1,768,16,16)) 115 | decoder_SS = de_VisionTransformer(input_resolution=256, patch_size=16, width=768, layers=6, heads=12, output_dim=512) 116 | # pdb.set_trace() 117 | _, x_new = decoder_SS(x, f_list=[2,4,6]) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 |

3 | # Advancing Pre-trained Teacher: Towards Robust Feature Discrepancy for Anomaly Detection 4 |

5 | 6 | ## Overview 7 | 8 | 9 | 12 | 13 | 17 | There are two underlying assumptions in the KD-based anomaly detection framework. **Assumption I**: The teacher model can represent two separable distributions for the normal and abnormal patterns; **Assumption II**: The student model can only reconstruct the normal distribution. In this paper, we propose a simple yet effective two-stage anomaly detection framework, termed AAND, which comprises an Anomaly Amplification Stage (**Stage I**) to address Assumption I and a Normality Distillation Stage (**Stage II**) to address Assumption II. 18 | 19 | ## Author 20 | [Canhui Tang](https://scholar.google.com/citations?hl=zh-CN&user=TKqkrnUAAAAJ), [Sanping Zhou](https://scholar.google.cz/citations?hl=zh-CN&user=2Drvv44AAAAJ), Yizhe Li, Yonghao Dong, [Le Wang](https://scholar.google.cz/citations?hl=zh-CN&user=RypRCUQAAAAJ) 21 | 22 | Xi'an Jiaotong University 23 | 24 | ## News 25 | 🔥 2025.09: Awaiting SAE Decision approval 26 | 27 | 🔥 2025.05: Accept with Mandatory Minor Revisions 28 | 29 | 🔥 2024.06: Another of our KD-based projects, [VAND-GNL](https://github.com/Hui-design/VAND-GNL), won 2nd place in the CVPR 2024 [VAND2.0 Challenge](https://www.hackster.io/contests/openvino2024#challengeNav) 30 | 31 | ## 🔧 Installation 32 | 33 | Please use the following commands for installation. 34 | 35 | ```bash 36 | # It is recommended to create a new environment 37 | conda create -n AAND python==3.8 38 | conda activate AAND 39 | 40 | pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113 41 | 42 | # Install packages and other dependencies 43 | pip install -r requirements.txt 44 | ``` 45 | 46 | 47 | ## 💾 Dataset 48 | 49 | - [MVTec AD](https://www.mvtec.com/company/research/datasets/mvtec-ad) 50 | - [VisA](https://amazon-visual-anomaly.s3.us-west-2.amazonaws.com/VisA_20220922.tar), use [visa.py](https://github.com/ByChelsea/VAND-APRIL-GAN/blob/master/data/visa.py) to generate meta.json 51 | - [MVTec-3D](https://www.mvtec.com/company/research/datasets/mvtec-3d-ad), we only use the RGB images, called **MVTec3D-RGB** in our paper. 52 | - [DRAEM_dtd](https://www.robots.ox.ac.uk/~vgg/data/dtd/), used as the auxiliary texture dataset for synthesizing anomalies, as in [DRAEM](https://github.com/VitjanZ/DRAEM). 53 | ``` 54 | 55 | ├── mvtec 56 | ├── bottle 57 | ├── train 58 | ├── test 59 | ├── ground_truth 60 | ├── ... 61 | 62 | ├── VisA 63 | ├── meta.json 64 | ├── candle 65 | ├── Data 66 | ├── image_anno.csv 67 | ├── ... 68 | 69 | ├── mvtec3d 70 | ├── bagel 71 | ├── train 72 | ├── good 73 | ├── rgb (we only use rgb) 74 | ├── xyz 75 | ├── test 76 | ├── ... 77 | 78 | ├── DRAEM_dtd 79 | ├── dtd 80 | ├── images 81 | ├── ... 82 | ``` 83 | 84 | ## Preprocessing 85 | - Extract foreground masks for the **training** images.
86 | 87 | ```bash 88 | python scripts/fore_extractor.py --data_path <your_path>/<dataset>/ --aux_path <your_path>/dtd/images/ # the <dataset> is mvtec, VisA, or mvtec3d 89 | ``` 90 | 91 | 102 | 103 | ## 🚅 Training 104 | You can train models on mvtec, VisA, or mvtec3d with the following commands: 105 | ```bash 106 | python train.py --data_root <your_path>/<dataset>/ # the <dataset> is mvtec, VisA, or mvtec3d 107 | ``` 108 | 109 | 110 | ## ⛳ Testing 111 | You can test the trained models on mvtec, VisA, or mvtec3d with the following commands: 112 | ```bash 113 | python test.py --data_root <your_path>/<dataset>/ # the <dataset> is mvtec, VisA, or mvtec3d 114 | ``` 115 | 116 | ## Citation 117 | 118 | ```bibtex 119 | @article{tang2024advancing, 120 | title={Advancing Pre-trained Teacher: Towards Robust Feature Discrepancy for Anomaly Detection}, 121 | author={Tang, Canhui and Zhou, Sanping and Li, Yizhe and Dong, Yonghao and Wang, Le}, 122 | journal={arXiv preprint arXiv:2405.02068}, 123 | year={2024} 124 | } 125 | ``` 126 | 127 | 128 | ## Acknowledgements 129 | - [RD4AD](https://github.com/hq-deng/RD4AD) 130 | - [DRAEM](https://github.com/VitjanZ/DRAEM) 131 | 132 | 133 | 134 | -------------------------------------------------------------------------------- /vit_version/clip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings).
41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 132 | return text 133 | -------------------------------------------------------------------------------- /vit_version/clip_rar/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | 
Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = 
whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 132 | return text 133 | -------------------------------------------------------------------------------- /utils/vis.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import torch 4 | from torchvision import transforms 5 | from PIL import ImageFilter 6 | from sklearn.manifold import TSNE 7 | import seaborn as sns 8 | import matplotlib.pyplot as plt 9 | import cv2, pdb, os 10 | from PIL import Image 11 | from torch.nn import functional as F 12 | import pandas as pd 13 | 14 | IMAGENET_MEAN = [0.485, 0.456, 0.406] 15 | IMAGENET_STD = [0.229, 0.224, 0.225] 16 | 17 | resize_lib = {0: torch.nn.AdaptiveAvgPool2d((64, 64)), 18 | 1: torch.nn.AdaptiveAvgPool2d((32, 32)) , 19 | 2: torch.nn.AdaptiveAvgPool2d((16, 16))} # 注意:可以尝试不同的下采样方式 20 | 21 | def vis_anomaly_images(imgs, mask_ori, obj): 22 | image_mean = np.array(IMAGENET_MEAN).reshape(1,1,3) 23 | image_std = np.array(IMAGENET_STD).reshape(1,1,3) 24 | B, C, H, W = imgs.shape 25 | os.makedirs(f'imgs/{obj}', exist_ok=True) 26 | masks_list = [resize_lib[i](mask_ori) for i in range(3)] 27 | for i in range(B): 28 | img = imgs[i].permute(1,2,0).cpu().numpy() # [H,W,3] 29 | img = ((img * image_std + image_mean)*255).astype(np.uint8) 30 | img = Image.fromarray(img) 31 | img.save(f'imgs/{obj}/img{i}.jpg') 32 | 33 | # pdb.set_trace() 34 | for level in range(3): 35 | mask = (masks_list[level][i][0]>0.3).float().cpu().numpy() # [H,W] 36 | mean = mask.mean() 37 | mask = (mask * 255).astype(np.uint8) 38 | mask = Image.fromarray(mask) 39 | mask.save(f'imgs/{obj}/mask{i}_{level}.jpg') 40 | 41 | mask_ = (mask_ori[i].permute(1,2,0).cpu().numpy().squeeze() * 255).astype(np.uint8) 42 | mask_ = Image.fromarray(mask_) 43 | mask_.save(f'imgs/{obj}/mask{i}.jpg') 44 | 45 | def vis_gt(gt): 46 | gt = Image.fromarray(gt.astype(np.uint8) * 255, mode='L') 47 | gt.save('gt.png') 48 | 49 | 50 | def vis_hotmap(fs, ss, a_map): 51 | B, C, H, W = fs.shape 52 | os.makedirs('imgs/a_map', exist_ok=True) 53 | # pdb.set_trace() 54 | fig = plt.figure() 55 | a_map = a_map.squeeze().cpu().detach() 56 | sns.heatmap(data=a_map) 57 | plt.savefig(f'imgs/a_map/a_map.png') 58 | 59 | tsne = TSNE(n_components=1) 60 | fs = fs.squeeze().permute(1,2,0).reshape(H*W, C).detach().cpu() 61 | ss = ss.squeeze().permute(1,2,0).reshape(H*W, C).detach().cpu() 62 | 63 | # L2-norm 64 | fs = F.normalize(fs, p=1) 65 | ss = F.normalize(ss, p=1) 66 | cat_tsne = tsne.fit_transform(np.vstack((fs, ss))) 67 | fs_tsne = cat_tsne[:H*W].reshape(H,W) 68 | ss_tsne = cat_tsne[H*W:].reshape(H,W) 69 | 70 | fig = plt.figure() 71 | sns.heatmap(data=fs_tsne) 72 | plt.savefig(f'imgs/a_map/fs.png') 73 | 74 | fig = plt.figure() 75 | sns.heatmap(data=ss_tsne) 76 | plt.savefig(f'imgs/a_map/ss.png') 77 | pdb.set_trace() 78 | plt.close('all') 79 | 80 | def vis_hotmap_single(a_map): 81 | fig = plt.figure() 82 | sns.heatmap(data=a_map, xticklabels=[], yticklabels=[], cbar=False) 83 | plt.savefig(f'atten.png') 84 | # pdb.set_trace() 85 | plt.close('all') 86 | 87 | def 
tsne_vis_single(memory, name): 88 | tsne = TSNE(n_components=2) 89 | data = tsne.fit_transform(memory) 90 | plt.figure() 91 | plt.scatter(data[:,0], data[:,1]) 92 | # plt.legend((s1,s2),('memory','anomaly') ,loc = 'best') 93 | plt.savefig(f'{name}.png') 94 | plt.close('all') 95 | 96 | def normalize(pred, max_value=None, min_value=None): 97 | if max_value is None or min_value is None: 98 | return (pred - pred.min()) / (pred.max() - pred.min()) 99 | else: 100 | return (pred - min_value) / (max_value - min_value) 101 | 102 | def apply_ad_scoremap(rgb_path, scoremap, cls, alpha=0.5): 103 | img_size = scoremap.shape[0] 104 | # pdb.set_trace() 105 | image = cv2.cvtColor(cv2.resize(cv2.imread(rgb_path), (img_size, img_size)), cv2.COLOR_BGR2RGB) # RGB 106 | mask = normalize(scoremap) 107 | # vis = apply_ad_scoremap(vis, mask) 108 | np_image = np.asarray(image, dtype=float) 109 | scoremap[scoremap>1] =1 110 | scoremap = (scoremap * 255).astype(np.uint8) 111 | scoremap = cv2.applyColorMap(scoremap, cv2.COLORMAP_JET) 112 | scoremap = cv2.cvtColor(scoremap, cv2.COLOR_BGR2RGB) 113 | scoremap = (alpha * np_image + (1 - alpha) * scoremap).astype(np.uint8) 114 | 115 | vis = cv2.cvtColor(scoremap, cv2.COLOR_RGB2BGR) # BGR 116 | save_vis = 'imgs' 117 | if not os.path.exists(save_vis): 118 | os.makedirs(save_vis) 119 | cv2.imwrite(f'{save_vis}/{cls}.png', vis) 120 | 121 | 122 | def plot_hist(normal, anomaly): 123 | width = 7 124 | fig = plt.figure(figsize=(width, width*0.7)) 125 | print(f'normal: {normal.min():.4f} {normal.max():.4f} anomaly: {anomaly.min():.4f} {anomaly.max():.4f}') 126 | # plt.hist(normal,bins=50,label='normal sample',alpha=0.5) 127 | # plt.hist(anomaly,bins=50,label='anomalous sample',alpha=0.5) 128 | # normal = pd.Series(normal) # 将数据由数组转换成series形式 129 | # normal.plot(kind = 'kde',label = 'normal') 130 | # anomaly = pd.Series(anomaly) # 将数据由数组转换成series形式 131 | # anomaly.plot(kind = 'kde',label = 'anomaly') 132 | import seaborn as sn 133 | sn.kdeplot(normal, color="blue", shade="True", label='normal samples', bw_adjust=0.3) 134 | sn.kdeplot(anomaly, color="red", shade="True", label='anomalous samples', bw_adjust=0.3) 135 | plt.legend(fontsize=14) 136 | plt.tick_params(labelsize=13) 137 | plt.ylabel('Density',fontsize=12.5) 138 | plt.xlabel('Anomaly score',fontsize=12.5) 139 | plt.savefig(f'hist_SS.png') 140 | 141 | def plot_line(data, name): 142 | plt.figure() 143 | plt.plot(data) # 绘制 sin(x) 曲线 144 | print(f'save to visual/{name}.png') 145 | plt.savefig(f'visual/{name}.png', bbox_inches='tight', dpi=1000) -------------------------------------------------------------------------------- /datasets/MvTec3D.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | from PIL import Image 3 | import os 4 | import torch 5 | import glob 6 | from torch.utils.data import Dataset 7 | import numpy as np 8 | import pdb 9 | from torch.utils.data import Dataset 10 | from torchvision import transforms as T 11 | import imgaug.augmenters as iaa 12 | # import albumentations as A 13 | from utils.perlin import rand_perlin_2d_np 14 | import random 15 | import cv2 16 | from datasets.database import BaseAnomalyDetectionDataset, SynthesisDataset, resize_organized_pc 17 | 18 | 19 | def mvtec3d_classes(): 20 | return [ 21 | "bagel", 22 | "cable_gland", 23 | "carrot", 24 | "cookie", 25 | "dowel", 26 | "foam", 27 | "peach", 28 | "potato", 29 | "rope", 30 | "tire", 31 | ] 32 | 33 | # return [ 34 | # "bagel", 35 | # "cable_gland", 36 | # "carrot", 37 
| # "cookie", 38 | # "dowel", 39 | # "foam", 40 | # "peach", 41 | # "potato", 42 | # "rope", 43 | # "tire", 44 | # ] 45 | 46 | 47 | 48 | class TrainDataset(BaseAnomalyDetectionDataset): 49 | def __init__(self, class_name, img_size, dataset_path='datasets/eyecandies_preprocessed'): 50 | super().__init__(split="train", class_name=class_name, img_size=img_size, dataset_path=dataset_path) 51 | self.img_paths, self.labels = self.load_dataset() # self.labels => good : 0, anomaly : 1 52 | 53 | def load_dataset(self): 54 | img_tot_paths = [] 55 | tot_labels = [] 56 | rgb_paths = glob.glob(os.path.join(self.img_path, 'good', 'rgb') + "/*.png") 57 | rgb_paths.sort() 58 | img_tot_paths.extend(rgb_paths) 59 | tot_labels.extend([0] * len(rgb_paths)) 60 | return img_tot_paths, tot_labels 61 | 62 | def __len__(self): 63 | return len(self.img_paths) 64 | 65 | def __getitem__(self, idx): 66 | rgb_path, label = self.img_paths[idx], self.labels[idx] 67 | img = Image.open(rgb_path).convert('RGB') 68 | img = self.rgb_transform(img) 69 | 70 | return img, label 71 | 72 | 73 | class TestDataset(BaseAnomalyDetectionDataset): 74 | def __init__(self, class_name, img_size, dataset_path='datasets/eyecandies_preprocessed'): 75 | super().__init__(split="test", class_name=class_name, img_size=img_size, dataset_path=dataset_path) 76 | self.size = img_size 77 | self.gt_transform = transforms.Compose([ 78 | transforms.Resize((img_size, img_size), interpolation=transforms.InterpolationMode.NEAREST), 79 | transforms.ToTensor()]) 80 | self.img_paths, self.gt_paths, self.labels = self.load_dataset() # self.labels => good : 0, anomaly : 1 81 | 82 | def load_dataset(self): 83 | img_tot_paths = [] 84 | gt_tot_paths = [] 85 | tot_labels = [] 86 | defect_types = os.listdir(self.img_path) 87 | 88 | for defect_type in defect_types: 89 | if defect_type == 'good': 90 | rgb_paths = glob.glob(os.path.join(self.img_path, defect_type, 'rgb') + "/*.png") 91 | rgb_paths.sort() 92 | img_tot_paths.extend(rgb_paths) 93 | gt_tot_paths.extend([0] * len(rgb_paths)) 94 | tot_labels.extend([0] * len(rgb_paths)) 95 | else: 96 | rgb_paths = glob.glob(os.path.join(self.img_path, defect_type, 'rgb') + "/*.png") 97 | gt_paths = glob.glob(os.path.join(self.img_path, defect_type, 'gt') + "/*.png") 98 | rgb_paths.sort() 99 | gt_paths.sort() 100 | 101 | img_tot_paths.extend(rgb_paths) 102 | gt_tot_paths.extend(gt_paths) 103 | tot_labels.extend([1] * len(rgb_paths)) 104 | 105 | assert len(img_tot_paths) == len(gt_tot_paths), "Something wrong with test and ground truth pair!" 
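# Note: for 'good' samples, gt_tot_paths stores the integer 0 instead of a mask path;
# __getitem__ below converts that 0 into an all-zero ground-truth mask of shape [1, img_size, img_size].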
106 | 107 | return img_tot_paths, gt_tot_paths, tot_labels 108 | 109 | def __len__(self): 110 | return len(self.img_paths) 111 | 112 | def __getitem__(self, idx): 113 | rgb_path, gt, label = self.img_paths[idx], self.gt_paths[idx], self.labels[idx] 114 | img_original = Image.open(rgb_path).convert('RGB') 115 | img = self.rgb_transform(img_original) 116 | 117 | if gt == 0: 118 | gt = torch.zeros( 119 | [1, self.size, self.size]) 120 | else: 121 | gt = Image.open(gt).convert('L') 122 | gt = self.gt_transform(gt) 123 | gt = torch.where(gt > 0.5, 1., .0) 124 | 125 | return img, gt[:1], label, rgb_path 126 | 127 | 128 | class AnomalyDataset(SynthesisDataset): 129 | def __init__(self, class_name, dataset_path, img_size, aux_path): 130 | super().__init__(class_name=class_name, img_size=img_size, dataset_path=dataset_path, aux_path=aux_path) 131 | self.img_paths, self.labels = self.load_dataset() # self.labels => good : 0, anomaly : 1 132 | 133 | def load_dataset(self): 134 | img_tot_paths = [] 135 | tot_labels = [] 136 | rgb_paths = glob.glob(os.path.join(self.img_path, 'good', 'rgb') + "/*.png") 137 | tiff_paths = glob.glob(os.path.join(self.img_path, 'good', 'xyz') + "/*.tiff") 138 | rgb_paths.sort() 139 | tiff_paths.sort() 140 | sample_paths = list(zip(rgb_paths, tiff_paths)) 141 | img_tot_paths.extend(sample_paths) 142 | tot_labels.extend([0] * len(sample_paths)) 143 | return img_tot_paths, tot_labels 144 | 145 | 146 | def __len__(self): 147 | return len(self.labels) 148 | 149 | def __getitem__(self, idx): 150 | img_path, label = self.img_paths[idx], self.labels[idx] 151 | rgb_path = img_path[0] 152 | tiff_path = img_path[1] 153 | class_name = tiff_path.split("/")[2] 154 | 155 | img = Image.open(rgb_path).convert('RGB') 156 | img = self.rgb_transform(img) 157 | 158 | img_file = rgb_path.split('/')[-1] 159 | fg_path = os.path.join(f'fg_mask/{self.dataset_name}/{self.cls}/train', img_file) 160 | fg_mask = Image.open(fg_path) 161 | fg_mask = np.asarray(fg_mask)[:, :, np.newaxis] # [H, W, 1] 162 | resized_depth_map = resize_organized_pc(fg_mask, img_size=self.size) 163 | fore_mask = resized_depth_map > 0 164 | 165 | # modify 166 | anomaly_source_idx = torch.randint(0, len(self.anomaly_source_paths), (1,)).item() 167 | augmented_image, anomaly_mask, has_anomaly = self.transform_image(img, fore_mask, 168 | self.anomaly_source_paths[anomaly_source_idx]) 169 | # # 下采样mask 170 | # anomaly_mask = torch.from_numpy(anomaly_mask).unsqueeze(0) 171 | 172 | # denoising 173 | # return {"ori_img": img, "aug_img": augmented_image, "label": has_anomaly, "anomaly_mask": anomaly_mask} 174 | 175 | return {"img": augmented_image, "label": has_anomaly, "anomaly_mask": anomaly_mask, "fore_mask": fore_mask.float()} 176 | 177 | -------------------------------------------------------------------------------- /models/recons_net.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | import math 5 | # from loss import resize_lib, mask_thresh 6 | from utils.vis import * 7 | import pdb 8 | 9 | def pair_cosine(a, b): 10 | cos_sim = torch.einsum('bnc,bmc->bnm', a, b) # [B, N, M] 11 | B, N, M = cos_sim.shape 12 | a_norm = torch.sqrt((a**2).sum(dim=-1)) # [B, N] 13 | a_norm = a_norm.unsqueeze(-1).expand(-1,-1,M) # [B, N, M] 14 | b_norm = torch.sqrt((b**2).sum(dim=-1)) # [B, M] 15 | b_norm = b_norm.unsqueeze(1).expand(-1,N,-1) # [B, N, M] 16 | cos_sim = cos_sim / (a_norm*b_norm) 17 | return cos_sim 18 | 19 | resize_lib 
= {0: torch.nn.AdaptiveAvgPool2d((64, 64)), 20 | 1: torch.nn.AdaptiveAvgPool2d((32, 32)) , 21 | 2: torch.nn.AdaptiveAvgPool2d((16, 16))} # 注意:可以尝试不同的下采样方式 22 | mask_thresh = 0.3 23 | 24 | def patch_to_tensor(x): 25 | B, N, C = x.shape 26 | H, W = int(math.sqrt(N)), int(math.sqrt(N)) 27 | x = x.reshape(B, H, W, C).permute(0,3,1,2) # [B, C, H, W] 28 | return x 29 | 30 | def tensor_to_patch(x): 31 | B, C, H, W = x.shape 32 | x = x.reshape(B, C, H*W).permute(0,2,1) # [B, N, C] 33 | return x 34 | 35 | class RAR_single(nn.Module): 36 | def __init__(self, d_model=[], size_list=[], num_item=50): 37 | super(RAR_single, self).__init__() 38 | self.d_model = d_model 39 | self.num_item = num_item 40 | self.pos_enc = [] 41 | self.projs = [] 42 | self.projs2 = [] 43 | self.K_list = [] 44 | hidden_dim = min(2*d_model[0], 1024) 45 | for i, f_dim in enumerate(d_model): 46 | self.projs.append(nn.Sequential( 47 | nn.Conv2d(f_dim, hidden_dim, 1), 48 | nn.ReLU(), 49 | nn.Conv2d(hidden_dim, hidden_dim, 1), 50 | nn.ReLU(), 51 | nn.Conv2d(hidden_dim, f_dim, 1)) 52 | ) 53 | # self.K_list.append(nn.Parameter(torch.randn(1, num_item+1, f_dim))) 54 | # self.V_list.append(nn.Parameter(torch.randn(1, 1, f_dim))) 55 | self.K_list.append(nn.Parameter(torch.randn(1, num_item*2, f_dim))) 56 | self.projs2.append(nn.Sequential( 57 | nn.Conv2d(f_dim, hidden_dim, 1), 58 | nn.ReLU(), 59 | nn.Conv2d(hidden_dim, hidden_dim, 1), 60 | nn.ReLU(), 61 | nn.Conv2d(hidden_dim, f_dim, 1) 62 | ) 63 | ) 64 | self.pos_enc.append(positionalencoding2d(f_dim, size_list[i], size_list[i])) 65 | 66 | self.projs = nn.ModuleList(self.projs) 67 | self.projs2 = nn.ModuleList(self.projs2) 68 | self.K_list = nn.ParameterList(self.K_list) # 如果不加这句话则paramter不会更新 69 | self.sigmoid = nn.Sigmoid() 70 | self.tanh = nn.Tanh() 71 | 72 | def regular_score(self,score): 73 | score = torch.where(torch.isnan(score), torch.zeros_like(score), score) 74 | score = torch.where(torch.isinf(score), torch.zeros_like(score), score) 75 | return score 76 | 77 | def forward( 78 | self, 79 | inputs, 80 | flag, 81 | atten_mask=None, 82 | ): 83 | inputs_recons = [] 84 | attention_list = [] 85 | for i in range(len(self.projs)): 86 | # Query 87 | # pdb.set_trace() 88 | pos_enc = self.pos_enc[i].unsqueeze(0).expand(inputs[i].shape[0],-1,-1,-1).to(inputs[i].device) # [B, C, H, W] 89 | q = self.projs[i](inputs[i]+pos_enc) # [B, C, H, W] 90 | q = tensor_to_patch(q) # [B, N, C] 91 | # Key, Value 92 | # pdb.set_trace() 93 | k = self.K_list[i].expand(q.shape[0], -1, -1).to(q.device) # [B, M, C] 94 | # v = self.V_list[i].expand(q.shape[0], -1, -1).to(q.device) # [B, M, C] 95 | # pdb.set_trace() 96 | noise = self.projs2[i](inputs[i]) # [B, N, C] 97 | noise = torch.clamp(self.tanh(noise), min=-1, max=1) 98 | v = tensor_to_patch(noise * inputs[i]) # [B, N, C] 99 | 100 | cos_sim = pair_cosine(q, k) # [B, N, M] 101 | attention_scores = F.softmax(cos_sim / 0.2, dim=-1) # [B, N, M] 102 | attention_scores = attention_scores[:, :, self.num_item:].sum(dim=-1, keepdim=True) # [B,N,1] 103 | attention_list.append(attention_scores) 104 | # if flag: 105 | # mask = (attention_scores >= 0.2).float() 106 | # attention_scores = attention_scores * mask 107 | 108 | if atten_mask is not None: 109 | # pdb.set_trace() 110 | index = {256:0, 512:1, 1024:2} 111 | atten_mask = (resize_lib[index[self.d_model[0]]](atten_mask).reshape(attention_scores.shape[0], -1) > 0.3) # [B,N] 112 | attention_scores = atten_mask.unsqueeze(-1).float() # [B,N,1] 113 | 114 | # weighted Value 115 | # pdb.set_trace() 116 | # q_recons = 
torch.matmul(attention_scores, v) # [B, N, C] 117 | q_recons = attention_scores * v 118 | # q_recons = attention_scores * v * torch.exp(attention_scores-0.5) 119 | q_recons = patch_to_tensor(q_recons) # [B, C, H, W] 120 | inputs_recons.append(q_recons) 121 | 122 | 123 | return inputs_recons[0], attention_list[0] 124 | 125 | 126 | 127 | class MLP(nn.Module): 128 | def __init__(self, f_dim): 129 | super(MLP, self).__init__() 130 | self.projs = nn.Sequential( 131 | nn.Conv2d(f_dim, 2*f_dim, 1), 132 | nn.ReLU(), 133 | nn.Conv2d(2*f_dim, 2*f_dim, 1), 134 | nn.ReLU(), 135 | nn.Conv2d(2*f_dim, f_dim, 1) 136 | ) 137 | 138 | def forward( 139 | self, 140 | inputs, 141 | flag, 142 | ): 143 | inputs_recons = self.projs(inputs[0]) 144 | 145 | return inputs_recons, None 146 | 147 | 148 | 149 | def positionalencoding2d(D, H, W): 150 | """ 151 | :param D: dimension of the model 152 | :param H: H of the positions 153 | :param W: W of the positions 154 | :return: DxHxW position matrix 155 | """ 156 | if D % 4 != 0: 157 | raise ValueError("Cannot use sin/cos positional encoding with odd dimension (got dim={:d})".format(D)) 158 | P = torch.zeros(D, H, W) 159 | # Each dimension use half of D 160 | D = D // 2 161 | div_term = torch.exp(torch.arange(0.0, D, 2) * -(math.log(1e4) / D)) 162 | pos_w = torch.arange(0.0, W).unsqueeze(1) 163 | pos_h = torch.arange(0.0, H).unsqueeze(1) 164 | P[0:D:2, :, :] = torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, H, 1) 165 | P[1:D:2, :, :] = torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, H, 1) 166 | P[D::2, :, :] = torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, W) 167 | P[D+1::2,:, :] = torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, W) 168 | return P -------------------------------------------------------------------------------- /scripts/fore_extractor.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | from PIL import Image 3 | import torchvision.transforms as transforms 4 | import numpy as np 5 | import os, pdb, sys 6 | sys.path.append(os.getcwd()) 7 | from tqdm import tqdm 8 | import json, glob 9 | from scipy.ndimage import gaussian_filter 10 | import argparse 11 | 12 | 13 | def extract_mvtec(data_root, save_root, obj_names): 14 | obj_list1 = ['carpet','grid','leather','tile','wood','transistor'] 15 | obj_list2 = ['capsule', 'screw', 'zipper'] 16 | obj_list3 = ['hazelnut', 'metal_nut', 'pill', 'toothbrush'] 17 | obj_list4 = ['cabel', 'bottle'] 18 | 19 | for obj in tqdm(obj_names): 20 | rgb_paths = glob.glob(os.path.join(data_root, obj, 'train', 'good') + "/*.png") 21 | rgb_paths.sort() 22 | save_dir = f'{save_root}/{obj}/train' 23 | os.makedirs(save_dir, exist_ok=True) 24 | 25 | for rgb_path in rgb_paths: 26 | img = Image.open(rgb_path).convert('RGB') 27 | image_transforms = transforms.Compose([ 28 | transforms.Grayscale(1) 29 | ]) 30 | img = image_transforms(img) 31 | img = np.asarray(img) 32 | img = gaussian_filter(img, sigma=8) 33 | # pdb.set_trace() 34 | if obj in obj_list1: 35 | ret, new_img = cv2.threshold(img, 0, 255, cv2.THRESH_BINARY) 36 | new_img = Image.fromarray(new_img) 37 | elif obj in obj_list2: 38 | ret, new_img = cv2.threshold(img, 150, 255, cv2.THRESH_BINARY) 39 | new_img = Image.fromarray(255-new_img) 40 | elif obj in obj_list3: 41 | ret, new_img = cv2.threshold(img, 50, 255, cv2.THRESH_BINARY) 42 | new_img = Image.fromarray(new_img) 43 | elif obj in obj_list4: 44 | H, W = img.shape 45 | new_img = np.zeros((H, W), 
dtype=np.uint8) 46 | center_x, center_y = W // 2, H // 2 47 | radius = int(0.45 * H) 48 | for x in range(W): 49 | for y in range(H): 50 | dist = np.sqrt((x - center_x) ** 2 + (y - center_y) ** 2) 51 | if dist <= radius: 52 | new_img[y, x] = 255 53 | new_img = Image.fromarray(new_img) 54 | file_name = rgb_path.split('/')[-1].split('.')[0] 55 | new_img.save(f"{save_dir}/{file_name}.png") 56 | 57 | 58 | def extract_visa(data_root, save_root, obj_names): 59 | obj_list1 = ['candle', 'cashew', 'chewinggum', 'fryum', 'macaroni1', 'macaroni2', 'pcb1'] 60 | obj_list2 = ['pcb2','pcb3','pipe_fryum'] 61 | obj_list3 = ['pcb4'] 62 | obj_list4 = ['capsules'] 63 | 64 | for obj in obj_names: 65 | meta_info = json.load(open(f'{data_root}/meta.json', 'r'))['train'] 66 | rgb_paths = meta_info[obj] 67 | rgb_paths = [x["img_path"] for x in rgb_paths] 68 | rgb_paths.sort() 69 | save_dir = f'{save_root}/{obj}/train' 70 | os.makedirs(save_dir, exist_ok=True) 71 | 72 | for rgb_path in tqdm(rgb_paths): 73 | rgb_path = f'{data_root}/{rgb_path}' 74 | img = Image.open(rgb_path).convert('RGB') 75 | image_transforms = transforms.Compose([ 76 | transforms.Grayscale(1) 77 | ]) 78 | img = image_transforms(img) 79 | img = np.asarray(img) 80 | img = gaussian_filter(img, sigma=8) 81 | if obj in obj_list1: 82 | ret, new_img = cv2.threshold(img, 150, 255, cv2.THRESH_BINARY) 83 | new_img = Image.fromarray(new_img) 84 | elif obj in obj_list2: 85 | ret, new_img = cv2.threshold(img, 100, 255, cv2.THRESH_BINARY) 86 | new_img = Image.fromarray(new_img) 87 | elif obj in obj_list3: 88 | ret, new_img = cv2.threshold(img, 50, 255, cv2.THRESH_BINARY) 89 | new_img = Image.fromarray(new_img) 90 | elif obj in obj_list4: 91 | ret, new_img = cv2.threshold(img, 100, 255, cv2.THRESH_BINARY) 92 | new_img = Image.fromarray(255 - new_img) 93 | file_name = rgb_path.split('/')[-1].split('.')[0] 94 | new_img.save(f"{save_dir}/{file_name}.png") 95 | 96 | 97 | def extract_mvtec3d(data_root, save_root, obj_names): 98 | obj_list1 = ["bagel","cable_gland", "carrot", "cookie", "dowel", "peach","potato", "rope",] 99 | obj_list2 = ["foam", "tire",] 100 | 101 | for obj in tqdm(obj_names): 102 | rgb_paths = glob.glob(os.path.join(data_root, obj, 'train', 'good', 'rgb') + "/*.png") 103 | rgb_paths.sort() 104 | save_dir = f'{save_root}/{obj}/train' 105 | os.makedirs(save_dir, exist_ok=True) 106 | 107 | for rgb_path in rgb_paths: 108 | img = Image.open(rgb_path).convert('RGB') 109 | image_transforms = transforms.Compose([ 110 | transforms.Grayscale(1) 111 | ]) 112 | img = image_transforms(img) 113 | img = np.asarray(img) 114 | img = gaussian_filter(img, sigma=8) 115 | # pdb.set_trace() 116 | if obj in obj_list1: 117 | ret, new_img = cv2.threshold(img, 50, 255, cv2.THRESH_BINARY) 118 | new_img = Image.fromarray(new_img) 119 | elif obj in obj_list2: 120 | H, W = img.shape 121 | new_img = np.zeros((H, W), dtype=np.uint8) 122 | center_x, center_y = W // 2, H // 2 123 | half_width = int(0.35 * W) 124 | half_height = int(0.35 * H) 125 | for x in range(W): 126 | for y in range(H): 127 | if (center_x - half_width) <= x <= (center_x + half_width) and (center_y - half_height) <= y <= (center_y + half_height): 128 | new_img[y, x] = 255 129 | new_img = Image.fromarray(new_img) 130 | 131 | file_name = rgb_path.split('/')[-1].split('.')[0] 132 | new_img.save(f"{save_dir}/{file_name}.png") 133 | # pdb.set_trace() 134 | 135 | 136 | 137 | if __name__ == '__main__': 138 | parser = argparse.ArgumentParser() 139 | parser.add_argument('--data_path', type=str, 
default='/data4/tch/AD_data/mvtec') 140 | args = parser.parse_args() 141 | 142 | if 'mvtec' in args.data_path.lower() and 'mvtec3d' not in args.data_path.lower(): 143 | from datasets.MvTec import mvtec_classes 144 | obj_names = mvtec_classes() 145 | save_root = 'fg_mask/mvtec' 146 | extract_mvtec(args.data_path, save_root, obj_names) 147 | 148 | elif 'visa' in args.data_path.lower(): 149 | from datasets.VisA import visa_classes 150 | obj_names = visa_classes() 151 | save_root = 'fg_mask/VisA' 152 | extract_visa(args.data_path, save_root, obj_names) 153 | 154 | elif 'mvtec3d' in args.data_path.lower(): 155 | from datasets.MvTec3D import mvtec3d_classes 156 | obj_names = mvtec3d_classes() 157 | save_root = 'fg_mask/mvtec3d' 158 | extract_mvtec3d(args.data_path, save_root, obj_names) 159 | 160 | else: 161 | print('no such dataset') 162 | -------------------------------------------------------------------------------- /datasets/MvTec.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | from PIL import Image 3 | import os 4 | import torch 5 | import glob 6 | from torch.utils.data import Dataset 7 | import numpy as np 8 | import pdb 9 | import os, sys 10 | sys.path.append(os.getcwd()) 11 | from torch.utils.data import Dataset 12 | from torchvision import transforms as T 13 | import imgaug.augmenters as iaa 14 | # import albumentations as A 15 | from utils.perlin import rand_perlin_2d_np 16 | from utils.utils import get_dataset_name 17 | import random 18 | from utils.vis import vis_anomaly_images 19 | import cv2 20 | from datasets.database import BaseAnomalyDetectionDataset, SynthesisDataset, resize_organized_pc 21 | 22 | 23 | def mvtec_classes(): 24 | return [ 25 | 'carpet', 26 | 'grid', 27 | 'leather', 28 | 'tile', 29 | 'wood', 30 | 'bottle', 31 | 'cable', 32 | 'capsule', 33 | 'hazelnut', 34 | 'metal_nut', 35 | 'pill', 36 | 'screw', 37 | 'toothbrush', 38 | 'transistor', 39 | 'zipper' 40 | ] 41 | 42 | 43 | 44 | 45 | class TrainDataset(BaseAnomalyDetectionDataset): 46 | def __init__(self, class_name, img_size, dataset_path='datasets/eyecandies_preprocessed'): 47 | super().__init__(split="train", class_name=class_name, img_size=img_size, dataset_path=dataset_path) 48 | self.img_paths, self.labels = self.load_dataset() # self.labels => good : 0, anomaly : 1 49 | 50 | def load_dataset(self): 51 | img_tot_paths = [] 52 | tot_labels = [] 53 | rgb_paths = glob.glob(os.path.join(self.img_path, 'good') + "/*.png") 54 | rgb_paths.sort() 55 | img_tot_paths.extend(rgb_paths) 56 | tot_labels.extend([0] * len(rgb_paths)) 57 | return img_tot_paths, tot_labels 58 | 59 | def __len__(self): 60 | return len(self.img_paths) 61 | 62 | def __getitem__(self, idx): 63 | rgb_path, label = self.img_paths[idx], self.labels[idx] 64 | img = Image.open(rgb_path).convert('RGB') 65 | img = self.rgb_transform(img) 66 | 67 | return img, label 68 | 69 | 70 | class TestDataset(BaseAnomalyDetectionDataset): 71 | def __init__(self, class_name, img_size, dataset_path='datasets/eyecandies_preprocessed'): 72 | super().__init__(split="test", class_name=class_name, img_size=img_size, dataset_path=dataset_path) 73 | self.size = img_size 74 | self.gt_transform = transforms.Compose([ 75 | transforms.Resize((img_size, img_size), interpolation=transforms.InterpolationMode.NEAREST), 76 | transforms.ToTensor()]) 77 | self.img_paths, self.gt_paths, self.labels = self.load_dataset() # self.labels => good : 0, anomaly : 1 78 | 79 | def load_dataset(self): 80 | img_tot_paths 
= [] 81 | gt_tot_paths = [] 82 | tot_labels = [] 83 | defect_types = os.listdir(self.img_path) 84 | 85 | for defect_type in defect_types: 86 | if defect_type == 'good': 87 | rgb_paths = glob.glob(os.path.join(self.img_path, defect_type) + "/*.png") 88 | rgb_paths.sort() 89 | img_tot_paths.extend(rgb_paths) 90 | gt_tot_paths.extend([0] * len(rgb_paths)) 91 | tot_labels.extend([0] * len(rgb_paths)) 92 | else: 93 | rgb_paths = glob.glob(os.path.join(self.img_path, defect_type) + "/*.png") 94 | gt_paths = glob.glob(os.path.join(self.gt_path, defect_type) + "/*.png") 95 | rgb_paths.sort() 96 | gt_paths.sort() 97 | 98 | img_tot_paths.extend(rgb_paths) 99 | gt_tot_paths.extend(gt_paths) 100 | tot_labels.extend([1] * len(rgb_paths)) 101 | 102 | assert len(img_tot_paths) == len(gt_tot_paths), "Something wrong with test and ground truth pair!" 103 | 104 | return img_tot_paths, gt_tot_paths, tot_labels 105 | 106 | def __len__(self): 107 | return len(self.img_paths) 108 | 109 | def __getitem__(self, idx): 110 | rgb_path, gt, label = self.img_paths[idx], self.gt_paths[idx], self.labels[idx] 111 | # pdb.set_trace() 112 | img_original = Image.open(rgb_path).convert('RGB') 113 | img = self.rgb_transform(img_original) 114 | 115 | if gt == 0: 116 | gt = torch.zeros( 117 | [1, self.size, self.size]) 118 | else: 119 | gt = Image.open(gt).convert('L') 120 | gt = self.gt_transform(gt) 121 | gt = torch.where(gt > 0.5, 1., .0) 122 | 123 | return img, gt[:1], label, rgb_path 124 | 125 | 126 | class AnomalyDataset(SynthesisDataset): 127 | def __init__(self, class_name, dataset_path, img_size, aux_path): 128 | super().__init__(class_name=class_name, img_size=img_size, dataset_path=dataset_path, aux_path=aux_path) 129 | self.img_paths, self.labels = self.load_dataset() # self.labels => good : 0, anomaly : 1 130 | 131 | def load_dataset(self): 132 | img_tot_paths = [] 133 | tot_labels = [] 134 | rgb_paths = glob.glob(os.path.join(self.img_path, 'good') + "/*.png") 135 | # tiff_paths = glob.glob(os.path.join(self.img_path, 'good', 'xyz') + "/*.tiff") 136 | rgb_paths.sort() 137 | # tiff_paths.sort() 138 | # sample_paths = list(zip(rgb_paths, tiff_paths)) 139 | # img_tot_paths.extend(sample_paths) 140 | # tot_labels.extend([0] * len(sample_paths)) 141 | img_tot_paths.extend(rgb_paths) 142 | tot_labels.extend([0] * len(rgb_paths)) 143 | return img_tot_paths, tot_labels 144 | 145 | def __len__(self): 146 | return len(self.labels) 147 | 148 | def __getitem__(self, idx): 149 | img_path, label = self.img_paths[idx], self.labels[idx] 150 | rgb_path = img_path 151 | # tiff_path = img_path[1] 152 | class_name = img_path.split("/")[2] 153 | 154 | img = Image.open(rgb_path).convert('RGB') 155 | img = self.rgb_transform(img) 156 | 157 | img_file = rgb_path.split('/')[-1] 158 | fg_path = os.path.join(f'fg_mask/{self.dataset_name}/{self.cls}/train', img_file) 159 | fg_mask = Image.open(fg_path) 160 | fg_mask = np.asarray(fg_mask)[:, :, np.newaxis] # [H, W, 1] 161 | resized_depth_map = resize_organized_pc(fg_mask, img_size=self.size) 162 | fore_mask = resized_depth_map > 0 163 | 164 | # # pdb.set_trace() 165 | # vis_anomaly_images(img.unsqueeze(0), fore_mask.unsqueeze(0).float(), class_name) 166 | 167 | # modify 168 | anomaly_source_idx = torch.randint(0, len(self.anomaly_source_paths), (1,)).item() 169 | augmented_image, anomaly_mask, has_anomaly = self.transform_image(img, fore_mask, 170 | self.anomaly_source_paths[anomaly_source_idx]) 171 | # # 下采样mask 172 | # anomaly_mask = torch.from_numpy(anomaly_mask).unsqueeze(0) 173 | 174 | # 
if has_anomaly == 1.0: 175 | # # pdb.set_trace() 176 | # vis_anomaly_images(torch.from_numpy(augmented_image)[None], torch.from_numpy(anomaly_mask)[None], class_name) 177 | 178 | return {"img": augmented_image, "label": has_anomaly, "anomaly_mask": anomaly_mask, "fore_mask": fore_mask.float()} 179 | 180 | 181 | if __name__ == '__main__': 182 | train_dat = AnomalyDataset(class_name='transistor', img_size=256, dataset_path='/data4/tch/AD_data/mvtec', aux_path='/data4/tch/AD_data/DRAEM_dtd/dtd/images') 183 | 184 | train_dat.__getitem__(0) -------------------------------------------------------------------------------- /models/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | import torch.nn as nn 4 | from models.recons_net import pair_cosine 5 | import pdb 6 | import numpy as np 7 | from utils.vis import vis_hotmap_single 8 | 9 | resize_lib = {0: torch.nn.AdaptiveAvgPool2d((64, 64)), 10 | 1: torch.nn.AdaptiveAvgPool2d((32, 32)) , 11 | 2: torch.nn.AdaptiveAvgPool2d((16, 16))} # 注意:可以尝试不同的下采样方式 12 | mask_thresh = 0.3 13 | 14 | def get_recons_loss(a, b): 15 | # mse_loss = torch.nn.MSELoss() 16 | # mse_loss = torch.nn.MSELoss(reduction='none') 17 | cos_loss = torch.nn.CosineSimilarity() 18 | loss = 0 19 | # pdb.set_trace() 20 | for item in range(len(a)): 21 | # loss += torch.sqrt(mse_loss(a[item], b[item]).sum(dim=1)).mean() # [B,C,H,W]->[B,H,W] 22 | # loss += torch.mean(1-cos_loss(a[item].reshape(a[item].shape[0],-1), 23 | # b[item].reshape(b[item].shape[0],-1))) 24 | loss += torch.mean(1-cos_loss(a[item],b[item])) 25 | return loss 26 | 27 | def get_orth_loss(a, b): 28 | #mse_loss = torch.nn.MSELoss() 29 | cos_loss = torch.nn.CosineSimilarity() 30 | loss = 0 31 | for item in range(len(a)): 32 | #print(a[item].shape) 33 | #print(b[item].shape) 34 | #loss += 0.1*mse_loss(a[item], b[item]) 35 | # loss += torch.mean(cos_loss(a[item], b[item])) # dim=-1 # 为什么用这个的话梯度就很大呢? 
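        # For reference: the active formulation below uses pair_cosine from models/recons_net.py,
        # which returns the full [B, N, M] matrix of pairwise cosine similarities, so taking its mean
        # pushes the two sets of vectors towards mutual (near-)orthogonality, whereas the commented-out
        # nn.CosineSimilarity variants only compare corresponding entries along a single dimension.
        # A minimal self-check sketch (illustrative only, not executed as part of this function):
        #   a0 = torch.nn.functional.normalize(torch.randn(2, 50, 64), dim=-1)
        #   b0 = torch.nn.functional.normalize(torch.randn(2, 50, 64), dim=-1)
        #   # for unit-norm inputs, pair_cosine reduces to the plain inner-product einsum
        #   assert torch.allclose(pair_cosine(a0, b0),
        #                         torch.einsum('bnc,bmc->bnm', a0, b0), atol=1e-5)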
36 | # a: [1, 50, C] 37 | # loss += torch.mean( F.relu(cos_loss(a[item].view(a[item].shape[0],-1), 38 | # b[item].view(b[item].shape[0],-1)) ) ) # 为什么刚好收敛到了0呢?写错了,写成了1- 39 | 40 | loss += torch.mean( pair_cosine(a[item], b[item]) ) 41 | return loss 42 | 43 | 44 | def loss_fucntion(a, b): 45 | #mse_loss = torch.nn.MSELoss() 46 | cos_loss = torch.nn.CosineSimilarity() 47 | loss = 0 48 | for item in range(len(a)): 49 | #print(a[item].shape) 50 | #print(b[item].shape) 51 | #loss += 0.1*mse_loss(a[item], b[item]) 52 | # pdb.set_trace() 53 | B, C, H, W = a[item].shape 54 | f_t = a[item].reshape(B,C,H*W) 55 | f_s = b[item].reshape(B,C,H*W) 56 | # weight = 1 / cos_loss(f_t, f_s).detach() # [B,C,H*W] -> [B,H*W] 相似度越高,权重越小 57 | # pdb.set_trace() 58 | loss_i = 1 - cos_loss(f_t, f_s) # [B,C,H*W] -> [B,H*W], 相似度越高,loss越小 59 | # loss += torch.mean(loss_i.topk(k=int(0.1*H*W), dim=1)[0]) # 选择损失大的 60 | loss += torch.mean(loss_i.topk(k=10, dim=1)[0]) # 选择损失大的 61 | loss += torch.mean(1-cos_loss(a[item].reshape(a[item].shape[0],-1), 62 | b[item].reshape(b[item].shape[0],-1))) 63 | # loss += torch.mean(weight * loss_i) / 3 64 | # loss += torch.mean(loss_i) 65 | 66 | return loss 67 | 68 | 69 | def entropy(p): 70 | # p: (N*M) 71 | logits = -p * torch.log(p + 0.0001) 72 | entropy = torch.sum(logits, dim=-1) # (N) 73 | return torch.mean(entropy) 74 | 75 | def get_entropy_loss(attention_list, mask_ori): 76 | # attention_list[[B,N,M],[]], mask_ori[B,1,H,W] 77 | entropy_loss = 0.0 78 | num_item = attention_list[0].shape[-1] 79 | for i, atten in enumerate(attention_list): 80 | mask = resize_lib[i](mask_ori).reshape(-1) > 0.3 # [B*N] 81 | atten = atten.reshape(-1, num_item) # [B*N, M] 82 | # pdb.set_trace() 83 | if mask.sum() != 0: 84 | entropy_loss += entropy(atten[mask]) # mask, anomaly 增大 num_item: 85 | return entropy_loss 86 | 87 | 88 | def get_focal_loss(pred_list, mask_ori, vit=False): 89 | # bce_loss = nn.BCELoss() 90 | bce_loss = BCEFocalLoss() 91 | total_loss = 0.0 92 | num_item = pred_list[0].shape[-1] 93 | P_sum, R_sum = 0, 0 94 | for i, pred in enumerate(pred_list): 95 | mask = (resize_lib[i](mask_ori).reshape(-1) > 0.3).float() # [B*N] 96 | if vit: 97 | mask = (resize_lib[2](mask_ori).reshape(-1) > 0.3).float() # [B*N] 98 | pred = pred.reshape(-1) 99 | loss = bce_loss(pred, mask) 100 | pred = (pred>0.5).float() 101 | PT = (pred.bool() & mask.bool()).sum() 102 | P = PT / (pred.sum()+1e-5) 103 | R = PT / (mask.sum()+1e-5) 104 | P_sum += P 105 | R_sum += R 106 | total_loss += loss 107 | return total_loss, P_sum/3, R_sum/3 108 | 109 | 110 | def get_amplify_loss(inputs, inputs_AT, mask_ori, fore_mask_ori=None, vit=False): 111 | cos_loss = torch.nn.CosineSimilarity() 112 | normal_loss = 0 113 | anomaly_loss = 0 114 | # 64, 32, 16 115 | for i in range(len(inputs)): 116 | # pdb.set_trace() 117 | B, C, h, w = inputs[i].shape 118 | mask = resize_lib[i](mask_ori).reshape(-1) > mask_thresh # [B*h*w] 119 | if vit: 120 | mask = resize_lib[2](mask_ori).reshape(-1) > mask_thresh 121 | 122 | if fore_mask_ori is not None: 123 | fore_mask = resize_lib[i](fore_mask_ori).reshape(-1) > mask_thresh 124 | if vit: 125 | fore_mask = resize_lib[2](fore_mask_ori).reshape(-1) > mask_thresh 126 | 127 | input = inputs[i].permute(0,2,3,1).reshape(-1, C) # [B,C,h,w]->[B,h,w,C]->[B*h*w,C] 128 | input_AT = inputs_AT[i].permute(0,2,3,1).reshape(-1, C) # [B,C,h,w]->[B,h,w,C]->[B*h*w,C] 129 | # normal unchange 130 | normal_loss += torch.mean(1-cos_loss(input[~mask], input_AT[~mask])) 131 | 132 | if fore_mask_ori is not None: 133 | normal_mask = 
(~mask)&(fore_mask) 134 | else: 135 | normal_mask = ~mask 136 | n_idx = np.random.permutation(normal_mask.sum().item())[:5000] 137 | a_idx = np.random.permutation(mask.sum().item())[:1000] 138 | input_normal, input_AT_normal = input[normal_mask][n_idx], input_AT[normal_mask][n_idx] 139 | if mask.sum() > 0: 140 | input_anomaly, input_AT_anomaly = input[mask][a_idx], input_AT[mask][a_idx] 141 | s_anomaly = pair_cosine(input_normal.unsqueeze(0), input_anomaly.unsqueeze(0))[0] 142 | s_AT_anomaly = pair_cosine(input_normal.unsqueeze(0), input_AT_anomaly.unsqueeze(0))[0] 143 | weight = 1 144 | anomaly_loss += torch.mean(F.relu(s_AT_anomaly - (s_anomaly-0.3)) * weight) 145 | else: 146 | anomaly_loss += 0 147 | 148 | return normal_loss, anomaly_loss 149 | 150 | 151 | def get_FnormLoss(inputs, mask_ori, vit=False): 152 | fnorm_loss = 0.0 153 | afnorm_loss = 0.0 154 | count = 0 155 | for i in range(len(inputs)): 156 | B, C, h, w = inputs[i].shape 157 | mask = resize_lib[i](mask_ori).reshape(-1) > mask_thresh # [B*h*w] 158 | if vit: 159 | mask = resize_lib[2](mask_ori).reshape(-1) > mask_thresh 160 | input = inputs[i].permute(0,2,3,1).reshape(-1, C) # [B,C,h,w]->[B,h,w,C]->[B*h*w,C] 161 | fnorm_loss += (input[~mask] ** 2).mean() 162 | if mask.sum() != 0: 163 | afnorm_loss += (input[mask] ** 2).mean() 164 | # afnorm_loss += F.relu(math.sqrt(C) - temp) 165 | count += 1 166 | fnorm_loss = fnorm_loss / 3 167 | if count > 0: 168 | afnorm_loss = afnorm_loss / count 169 | return fnorm_loss, afnorm_loss 170 | 171 | 172 | class BCEFocalLoss(torch.nn.Module): 173 | """ 174 | 二分类的Focalloss alpha 固定 175 | """ 176 | def __init__(self, gamma=2, alpha=0.8, reduction='elementwise_mean'): 177 | super().__init__() 178 | self.gamma = gamma 179 | self.alpha = alpha 180 | self.reduction = reduction 181 | 182 | def forward(self, pt, target): 183 | # pt = torch.sigmoid(_input) 184 | # pdb.set_trace() 185 | alpha = self.alpha 186 | loss = - alpha * (1 - pt) ** self.gamma * target * torch.log(pt) - \ 187 | (1 - alpha) * pt ** self.gamma * (1 - target) * torch.log(1 - pt) 188 | if self.reduction == 'elementwise_mean': 189 | loss = torch.mean(loss) 190 | elif self.reduction == 'sum': 191 | loss = torch.sum(loss) 192 | return loss 193 | -------------------------------------------------------------------------------- /datasets/VisA.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | from PIL import Image 3 | import os 4 | import torch 5 | import glob 6 | from torch.utils.data import Dataset 7 | import numpy as np 8 | import pdb 9 | from torch.utils.data import Dataset 10 | from torchvision import transforms as T 11 | import imgaug.augmenters as iaa 12 | # import albumentations as A 13 | from utils.perlin import rand_perlin_2d_np 14 | import random 15 | from utils.vis import vis_anomaly_images 16 | import cv2 17 | import json 18 | from datasets.database import BaseAnomalyDetectionDataset, SynthesisDataset, resize_organized_pc 19 | 20 | 21 | def visa_classes(): 22 | return[ 23 | 'candle', 24 | 'capsules', 25 | 'cashew', 26 | 'chewinggum', 27 | 'fryum', 28 | 'macaroni1', 29 | 'macaroni2', 30 | 'pcb1', 31 | 'pcb2', 32 | 'pcb3', 33 | 'pcb4', 34 | 'pipe_fryum' 35 | ] 36 | 37 | 38 | 39 | class TrainDataset(BaseAnomalyDetectionDataset): 40 | def __init__(self, class_name, img_size, dataset_path='datasets/eyecandies_preprocessed'): 41 | super().__init__(split="train", class_name=class_name, img_size=img_size, dataset_path=dataset_path) 42 | self.split = 'train' 
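        # meta.json, as read by the loaders in this file, is assumed to have the layout
        #   {"train"/"test": {class_name: [{"img_path": ..., "anomaly": 0/1, "mask_path": ...}, ...]}}
        # (only the keys actually accessed in this file are listed); img_path/mask_path entries are
        # joined onto dataset_path before loading.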
43 | self.meta_info = json.load(open(f'{dataset_path}/meta.json', 'r'))[self.split] 44 | self.cls_name = class_name 45 | self.dataset_path = dataset_path 46 | self.img_paths, self.labels = self.load_dataset() # self.labels => good : 0, anomaly : 1 47 | 48 | def load_dataset(self): 49 | img_tot_paths = [] 50 | tot_labels = [] 51 | # rgb_paths = glob.glob(os.path.join(self.img_path, 'good') + "/*.png") 52 | # pdb.set_trace() 53 | rgb_paths = self.meta_info[self.cls_name] 54 | rgb_paths = [f'{self.dataset_path}/{x["img_path"]}' for x in rgb_paths] 55 | rgb_paths.sort() 56 | img_tot_paths.extend(rgb_paths) 57 | tot_labels.extend([0] * len(rgb_paths)) 58 | return img_tot_paths, tot_labels 59 | 60 | def __len__(self): 61 | return len(self.img_paths) 62 | 63 | def __getitem__(self, idx): 64 | rgb_path, label = self.img_paths[idx], self.labels[idx] 65 | img = Image.open(rgb_path).convert('RGB') 66 | img = self.rgb_transform(img) 67 | 68 | return img, label 69 | 70 | 71 | class TestDataset(BaseAnomalyDetectionDataset): 72 | def __init__(self, class_name, img_size, dataset_path='datasets/eyecandies_preprocessed'): 73 | super().__init__(split="test", class_name=class_name, img_size=img_size, dataset_path=dataset_path) 74 | self.split = 'test' 75 | self.meta_info = json.load(open(f'{dataset_path}/meta.json', 'r'))[self.split] 76 | self.cls_name = class_name 77 | self.dataset_path = dataset_path 78 | self.size = img_size 79 | self.gt_transform = transforms.Compose([ 80 | transforms.Resize((img_size, img_size), interpolation=transforms.InterpolationMode.NEAREST), 81 | transforms.ToTensor()]) 82 | self.img_paths, self.gt_paths, self.labels = self.load_dataset() # self.labels => good : 0, anomaly : 1 83 | 84 | def load_dataset(self): 85 | img_tot_paths = [] 86 | gt_tot_paths = [] 87 | tot_labels = [] 88 | # defect_types = os.listdir(self.img_path) 89 | # pdb.set_trace() 90 | rgb_paths = self.meta_info[self.cls_name] 91 | # normal samples 92 | normal_paths = [f'{self.dataset_path}/{x["img_path"]}' for x in rgb_paths if x['anomaly']==0] 93 | normal_paths.sort() 94 | img_tot_paths.extend(normal_paths) 95 | gt_tot_paths.extend([0] * len(normal_paths)) 96 | tot_labels.extend([0] * len(normal_paths)) 97 | # anomaly samples 98 | anomaly_paths = [f'{self.dataset_path}/{x["img_path"]}' for x in rgb_paths if x['anomaly']==1] 99 | mask_paths = [f'{self.dataset_path}/{x["mask_path"]}' for x in rgb_paths if x['anomaly']==1] 100 | anomaly_paths.sort() 101 | mask_paths.sort() 102 | img_tot_paths.extend(anomaly_paths) 103 | gt_tot_paths.extend(mask_paths) 104 | tot_labels.extend([1] * len(anomaly_paths)) 105 | 106 | # for defect_type in defect_types: 107 | # if defect_type == 'good': 108 | # rgb_paths = glob.glob(os.path.join(self.img_path, defect_type) + "/*.png") 109 | # rgb_paths.sort() 110 | # img_tot_paths.extend(rgb_paths) 111 | # gt_tot_paths.extend([0] * len(rgb_paths)) 112 | # tot_labels.extend([0] * len(rgb_paths)) 113 | # else: 114 | # rgb_paths = glob.glob(os.path.join(self.img_path, defect_type) + "/*.png") 115 | # gt_paths = glob.glob(os.path.join(self.gt_path, defect_type) + "/*.png") 116 | # rgb_paths.sort() 117 | # gt_paths.sort() 118 | 119 | # img_tot_paths.extend(rgb_paths) 120 | # gt_tot_paths.extend(gt_paths) 121 | # tot_labels.extend([1] * len(rgb_paths)) 122 | 123 | assert len(img_tot_paths) == len(gt_tot_paths), "Something wrong with test and ground truth pair!" 
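        # Note: gt_tot_paths mixes the integer 0 (normal samples without a mask file) with mask file
        # paths for anomalous samples; __getitem__ below branches on `gt == 0` and substitutes an
        # all-zero ground-truth map of shape [1, size, size] for the normal case.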
124 | 125 | return img_tot_paths, gt_tot_paths, tot_labels 126 | 127 | def __len__(self): 128 | return len(self.img_paths) 129 | 130 | def __getitem__(self, idx): 131 | rgb_path, gt, label = self.img_paths[idx], self.gt_paths[idx], self.labels[idx] 132 | img_original = Image.open(rgb_path).convert('RGB') 133 | img = self.rgb_transform(img_original) 134 | 135 | if gt == 0: 136 | gt = torch.zeros( 137 | [1, self.size, self.size]) 138 | else: 139 | # pdb.set_trace() 140 | # gt = Image.open(gt).convert('L') 141 | gt= np.array(Image.open(gt).convert('L')) > 0 142 | gt = Image.fromarray(gt.astype(np.uint8) * 255, mode='L') 143 | gt = self.gt_transform(gt) 144 | gt = torch.where(gt > 0.5, 1., .0) 145 | 146 | # img_file = rgb_path.split('/')[-1].split('.')[0] 147 | # # pdb.set_trace() 148 | # fg_path = os.path.join(f'fg_mask/{self.cls}/test', img_file + '.png') 149 | # if os.path.exists(fg_path): 150 | # fg_mask = Image.open(fg_path) 151 | # fg_mask = np.asarray(fg_mask)[:, :, np.newaxis] # [H, W, 1] 152 | # resized_depth_map = resize_organized_pc(fg_mask, img_size=self.size) 153 | # fore_mask = resized_depth_map > 0 154 | # else: 155 | # fore_mask = None 156 | 157 | return img, gt[:1], label, rgb_path 158 | 159 | 160 | class AnomalyDataset(SynthesisDataset): 161 | def __init__(self, class_name, dataset_path, img_size, aux_path): 162 | super().__init__(class_name=class_name, img_size=img_size, dataset_path=dataset_path, aux_path=aux_path) 163 | self.split = 'train' 164 | self.dataset_path = dataset_path 165 | self.meta_info = json.load(open(f'{dataset_path}/meta.json', 'r'))[self.split] 166 | self.cls_name = class_name 167 | self.img_paths, self.labels = self.load_dataset() # self.labels => good : 0, anomaly : 1 168 | 169 | def load_dataset(self): 170 | img_tot_paths = [] 171 | tot_labels = [] 172 | # rgb_paths = glob.glob(os.path.join(self.img_path, 'good') + "/*.png") 173 | # pdb.set_trace() 174 | rgb_paths = self.meta_info[self.cls_name] 175 | rgb_paths = [f'{self.dataset_path}/{x["img_path"]}' for x in rgb_paths] 176 | # tiff_paths = glob.glob(os.path.join(self.img_path, 'good', 'xyz') + "/*.tiff") 177 | rgb_paths.sort() 178 | # tiff_paths.sort() 179 | # sample_paths = list(zip(rgb_paths, tiff_paths)) 180 | # img_tot_paths.extend(sample_paths) 181 | # tot_labels.extend([0] * len(sample_paths)) 182 | img_tot_paths.extend(rgb_paths) 183 | tot_labels.extend([0] * len(rgb_paths)) 184 | return img_tot_paths, tot_labels 185 | 186 | 187 | def __len__(self): 188 | return len(self.labels) 189 | 190 | def __getitem__(self, idx): 191 | img_path, label = self.img_paths[idx], self.labels[idx] 192 | rgb_path = img_path 193 | # tiff_path = img_path[1] 194 | class_name = img_path.split("/")[2] 195 | 196 | img = Image.open(rgb_path).convert('RGB') 197 | img = self.rgb_transform(img) 198 | 199 | img_file = rgb_path.split('/')[-1].split('.')[0] 200 | # pdb.set_trace() 201 | fg_path = os.path.join(f'fg_mask/{self.dataset_name}/{self.cls}/train', img_file + '.png') 202 | if os.path.exists(fg_path): 203 | fg_mask = Image.open(fg_path) 204 | fg_mask = np.asarray(fg_mask)[:, :, np.newaxis] # [H, W, 1] 205 | resized_depth_map = resize_organized_pc(fg_mask, img_size=self.size) 206 | fore_mask = resized_depth_map > 0 207 | else: 208 | fore_mask = None 209 | 210 | 211 | # pdb.set_trace() 212 | # vis_anomaly_images(img.unsqueeze(0), fore_mask.unsqueeze(0).float(), class_name) 213 | 214 | # modify 215 | anomaly_source_idx = torch.randint(0, len(self.anomaly_source_paths), (1,)).item() 216 | augmented_image, anomaly_mask, 
has_anomaly = self.transform_image(img, fore_mask, 217 | self.anomaly_source_paths[anomaly_source_idx]) 218 | # # 下采样mask 219 | # anomaly_mask = torch.from_numpy(anomaly_mask).unsqueeze(0) 220 | 221 | # if has_anomaly == 1.0: 222 | # pdb.set_trace() 223 | # vis_anomaly_images(torch.from_numpy(augmented_image)[None], torch.from_numpy(anomaly_mask)[None], class_name) 224 | 225 | return {"img": augmented_image, "label": has_anomaly, "anomaly_mask": anomaly_mask, "fore_mask": fore_mask.float()} 226 | -------------------------------------------------------------------------------- /datasets/database.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | from PIL import Image 3 | import os 4 | import torch 5 | import glob 6 | from torch.utils.data import Dataset 7 | import numpy as np 8 | import pdb 9 | from torch.utils.data import Dataset 10 | from torchvision import transforms as T 11 | import imgaug.augmenters as iaa 12 | # import albumentations as A 13 | from utils.perlin import rand_perlin_2d_np 14 | from utils.utils import get_dataset_name 15 | import random 16 | from utils.vis import vis_anomaly_images 17 | import cv2 18 | 19 | def resize_organized_pc(organized_pc, img_size=256, tensor_out=True): 20 | torch_organized_pc = torch.tensor(organized_pc).permute(2, 0, 1).unsqueeze(dim=0).contiguous() 21 | torch_resized_organized_pc = torch.nn.functional.interpolate(torch_organized_pc, size=img_size, 22 | mode='nearest') 23 | if tensor_out: 24 | return torch_resized_organized_pc.squeeze(dim=0).contiguous() 25 | else: 26 | return torch_resized_organized_pc.squeeze().permute(1, 2, 0).contiguous().numpy() 27 | 28 | class BaseAnomalyDetectionDataset(Dataset): 29 | def __init__(self, split, class_name, img_size, dataset_path='datasets/eyecandies_preprocessed'): 30 | self.IMAGENET_MEAN = [0.485, 0.456, 0.406] 31 | self.IMAGENET_STD = [0.229, 0.224, 0.225] 32 | self.cls = class_name 33 | self.size = img_size 34 | self.dataset_name = get_dataset_name(dataset_path) 35 | if self.dataset_name == 'mvtec': 36 | self.img_path = os.path.join(dataset_path, self.cls, split) 37 | self.gt_path = os.path.join(dataset_path, self.cls, 'ground_truth') 38 | elif self.dataset_name == 'mvtec3d': 39 | self.img_path = os.path.join(dataset_path, self.cls, split) 40 | self.rgb_transform = transforms.Compose( 41 | [transforms.Resize((self.size, self.size), interpolation=transforms.InterpolationMode.BICUBIC), 42 | transforms.ToTensor(), 43 | transforms.Normalize(mean=self.IMAGENET_MEAN, std=self.IMAGENET_STD)]) 44 | 45 | 46 | class SynthesisDataset(BaseAnomalyDetectionDataset): 47 | def __init__(self, class_name, dataset_path, img_size, aux_path): 48 | super().__init__(split="train", class_name=class_name, img_size=img_size, dataset_path=dataset_path) 49 | self.anomaly_source_paths = sorted(glob.glob(aux_path + "/*/*.jpg")) 50 | 51 | self.augmenters = [iaa.GammaContrast((0.5,2.0),per_channel=True), 52 | iaa.MultiplyAndAddToBrightness(mul=(0.8,1.2),add=(-30,30)), 53 | iaa.pillike.EnhanceSharpness(), 54 | iaa.AddToHueAndSaturation((-50,50),per_channel=True), 55 | iaa.Solarize(0.5, threshold=(32,128)), 56 | iaa.Posterize(), 57 | iaa.Invert(), 58 | iaa.pillike.Autocontrast(), 59 | iaa.pillike.Equalize(), 60 | iaa.Affine(rotate=(-45, 45)) 61 | ] 62 | # self.noise_transform = transforms.Compose( 63 | # [transforms.Normalize(mean=self.IMAGENET_MEAN, std=self.IMAGENET_STD)]) 64 | 65 | self.rot = iaa.Sequential([iaa.Affine(rotate=(-90, 90))]) 66 | 67 | 
self.resize64 = torch.nn.AdaptiveAvgPool2d((64, 64)) 68 | self.resize32 = torch.nn.AdaptiveAvgPool2d((32, 32)) 69 | self.resize16 = torch.nn.AdaptiveAvgPool2d((16, 16)) 70 | 71 | # 随机选择3种数据增强方式 72 | def randAugmenter(self): 73 | aug_ind = np.random.choice(np.arange(len(self.augmenters)), 3, replace=False) 74 | aug = iaa.Sequential([self.augmenters[aug_ind[0]], 75 | self.augmenters[aug_ind[1]], 76 | self.augmenters[aug_ind[2]]] 77 | ) 78 | return aug 79 | 80 | def augment_image(self, image, anomaly_source_path, fore_mask): 81 | aug = self.randAugmenter() 82 | perlin_scale = 6 83 | min_perlin_scale = 0 84 | 85 | anomaly_source_img = cv2.imread(anomaly_source_path) 86 | anomaly_source_img = cv2.resize(anomaly_source_img, dsize=(self.size, self.size)) 87 | 88 | anomaly_img_augmented = aug(image=anomaly_source_img) 89 | perlin_scalex = 2 ** (torch.randint(min_perlin_scale, perlin_scale, (1,)).numpy()[0]) 90 | perlin_scaley = 2 ** (torch.randint(min_perlin_scale, perlin_scale, (1,)).numpy()[0]) 91 | 92 | if fore_mask is not None: 93 | count = 0 94 | while True: 95 | count += 1 96 | perlin_noise = rand_perlin_2d_np((self.size, self.size), (perlin_scalex, perlin_scaley)) 97 | perlin_noise = self.rot(image=perlin_noise) 98 | threshold = 0.5 99 | # modify 100 | perlin_thr = np.where(perlin_noise > threshold, np.ones_like(perlin_noise), np.zeros_like(perlin_noise)) 101 | perlin_thr = np.expand_dims(perlin_thr, axis=2) 102 | perlin_thr = perlin_thr * fore_mask 103 | # pdb.set_trace() 104 | if perlin_thr.sum() > 4 or count > 10: 105 | break 106 | else: 107 | perlin_noise = rand_perlin_2d_np((self.size, self.size), (perlin_scalex, perlin_scaley)) 108 | perlin_noise = self.rot(image=perlin_noise) 109 | threshold = 0.5 110 | # modify 111 | perlin_thr = np.where(perlin_noise > threshold, np.ones_like(perlin_noise), np.zeros_like(perlin_noise)) 112 | perlin_thr = np.expand_dims(perlin_thr, axis=2) 113 | # 114 | # modify, '/255' 改成 imagenet 的归一化 115 | # 测试一下img_thr的最大值和image的最小值 116 | # img_thr = anomaly_img_augmented.astype(np.float32) * perlin_thr / 255.0 117 | img_thr = anomaly_img_augmented.astype(np.float32) * perlin_thr / 255.0 118 | image_mean = np.array(self.IMAGENET_MEAN).reshape(1,1,3) 119 | image_std = np.array(self.IMAGENET_STD).reshape(1,1,3) 120 | img_thr = (img_thr - image_mean) / image_std 121 | # 122 | 123 | beta = torch.rand(1).numpy()[0] * 0.8 124 | augmented_image = image * (1 - perlin_thr) + (1 - beta) * img_thr + beta * image * ( 125 | perlin_thr) 126 | 127 | augmented_image = augmented_image.astype(np.float32) 128 | msk = (perlin_thr).astype(np.float32) 129 | augmented_image = msk * augmented_image + (1-msk)*image # This line is unnecessary and can be deleted 130 | has_anomaly = 1.0 131 | 132 | no_anomaly = torch.rand(1).numpy()[0] 133 | if no_anomaly > 0.8: 134 | image = image.astype(np.float32) 135 | return image, np.zeros_like(perlin_thr, dtype=np.float32), np.array([0.0],dtype=np.float32) 136 | else: # 0.8概率产生异常 137 | augmented_image = augmented_image.astype(np.float32) 138 | msk = (perlin_thr).astype(np.float32) 139 | augmented_image = msk * augmented_image + (1-msk)*image # This line is unnecessary and can be deleted 140 | has_anomaly = 1.0 141 | if np.sum(msk) == 0: 142 | has_anomaly=0.0 143 | return augmented_image, msk, np.array([has_anomaly],dtype=np.float32) 144 | 145 | 146 | def transform_image(self, image, fore_mask, anomaly_source_path): 147 | image = image.permute(1,2,0).numpy() 148 | if fore_mask is not None: 149 | fore_mask = fore_mask.permute(1,2,0).numpy() 150 | 151 
| # normalize the image to 0.0~1.0 152 | augmented_image, anomaly_mask, has_anomaly = self.augment_image(image, anomaly_source_path, fore_mask) 153 | augmented_image = np.transpose(augmented_image, (2, 0, 1)) 154 | image = np.transpose(image, (2, 0, 1)) 155 | anomaly_mask = np.transpose(anomaly_mask, (2, 0, 1)) 156 | return augmented_image, anomaly_mask, has_anomaly 157 | 158 | 159 | 160 | class CutPasteDataSet(BaseAnomalyDetectionDataset): 161 | def __init__(self, class_name, img_size, dataset_path='datasets/eyecandies_preprocessed'): 162 | super().__init__(split="train", class_name=class_name, img_size=img_size, dataset_path=dataset_path) 163 | self.img_paths, self.labels = self.load_dataset() # self.labels => good : 0, anomaly : 1 164 | self.cutpaste_transform = CutPaste_fg(type='3way') 165 | self.gt_transform = transforms.Compose([ 166 | transforms.Resize((img_size, img_size), interpolation=transforms.InterpolationMode.NEAREST), 167 | transforms.ToTensor()]) 168 | 169 | def load_dataset(self): 170 | img_tot_paths = [] 171 | tot_labels = [] 172 | rgb_paths = glob.glob(os.path.join(self.img_path, 'good', 'rgb') + "/*.png") 173 | tiff_paths = glob.glob(os.path.join(self.img_path, 'good', 'xyz') + "/*.tiff") 174 | rgb_paths.sort() 175 | tiff_paths.sort() 176 | sample_paths = list(zip(rgb_paths, tiff_paths)) 177 | img_tot_paths.extend(sample_paths) 178 | tot_labels.extend([0] * len(sample_paths)) 179 | return img_tot_paths, tot_labels 180 | 181 | def __len__(self): 182 | return len(self.img_paths) 183 | 184 | def __getitem__(self, idx): 185 | img_path, label = self.img_paths[idx], self.labels[idx] 186 | rgb_path = img_path[0] 187 | tiff_path = img_path[1] 188 | img = Image.open(rgb_path).convert('RGB') 189 | # img.save("imgs/ori_img.jpg") 190 | # pdb.set_trace() 191 | organized_pc = read_tiff_organized_pc(tiff_path) 192 | depth_map = organized_pc_to_depth_map(organized_pc)[:, :, np.newaxis] 193 | resized_depth_map = resize_organized_pc(depth_map, img_size=[img.size[1],img.size[0]]).squeeze(0).numpy() 194 | fore_mask = resized_depth_map > 0 195 | 196 | cutpaste, anomaly_mask = self.cutpaste_transform(img, fore_mask) 197 | # cutpaste = self.copy_paste(img, fore_mask) 198 | # cutpaste.save("imgs/aug_img1.jpg") 199 | # cutpaste_scar.save("imgs/aug_img2.jpg") 200 | cutpaste = Image.fromarray(cutpaste) 201 | anomaly_mask = Image.fromarray(anomaly_mask) 202 | aug_img = self.rgb_transform(cutpaste) 203 | anomaly_mask = self.gt_transform(anomaly_mask) 204 | 205 | return {"img": aug_img, "anomaly_mask": anomaly_mask, "fore_mask": fore_mask} 206 | 207 | -------------------------------------------------------------------------------- /vit_version/clip/clip.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from typing import Any, Union, List 6 | from pkg_resources import packaging 7 | 8 | import torch 9 | from PIL import Image 10 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 11 | from tqdm import tqdm 12 | import pdb 13 | 14 | from .model import build_model 15 | # from .model_plus import build_model 16 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 17 | 18 | try: 19 | from torchvision.transforms import InterpolationMode 20 | BICUBIC = InterpolationMode.BICUBIC 21 | except ImportError: 22 | BICUBIC = Image.BICUBIC 23 | 24 | 25 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"): 26 | warnings.warn("PyTorch 
version 1.7.1 or higher is recommended") 27 | 28 | 29 | __all__ = ["available_models", "load", "tokenize"] 30 | _tokenizer = _Tokenizer() 31 | 32 | _MODELS = { 33 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 34 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 35 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 36 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 37 | "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", 38 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 39 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 40 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", 41 | "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", 42 | } 43 | 44 | 45 | def _download(url: str, root: str): 46 | os.makedirs(root, exist_ok=True) 47 | filename = os.path.basename(url) 48 | 49 | expected_sha256 = url.split("/")[-2] 50 | download_target = os.path.join(root, filename) 51 | 52 | if os.path.exists(download_target) and not os.path.isfile(download_target): 53 | raise RuntimeError(f"{download_target} exists and is not a regular file") 54 | 55 | if os.path.isfile(download_target): 56 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 57 | return download_target 58 | else: 59 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 60 | 61 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 62 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: 63 | while True: 64 | buffer = source.read(8192) 65 | if not buffer: 66 | break 67 | 68 | output.write(buffer) 69 | loop.update(len(buffer)) 70 | 71 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 72 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") 73 | 74 | return download_target 75 | 76 | 77 | def _convert_image_to_rgb(image): 78 | return image.convert("RGB") 79 | 80 | 81 | def _transform(n_px): 82 | return Compose([ 83 | Resize((n_px, n_px), interpolation=BICUBIC), # Modify, 之前是(n_px, n_px) 84 | CenterCrop((n_px, n_px)), 85 | _convert_image_to_rgb, 86 | ToTensor(), 87 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 88 | ]) 89 | 90 | 91 | 92 | def available_models() -> List[str]: 93 | """Returns the names of available CLIP models""" 94 | return list(_MODELS.keys()) 95 | 96 | 97 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None, image_size: int = None): 98 | """Load a CLIP model 99 | 100 | Parameters 101 | ---------- 102 | name : str 103 | A model name listed 
by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 104 | 105 | device : Union[str, torch.device] 106 | The device to put the loaded model 107 | 108 | jit : bool 109 | Whether to load the optimized JIT model or more hackable non-JIT model (default). 110 | 111 | download_root: str 112 | path to download the model files; by default, it uses "~/.cache/clip" 113 | 114 | Returns 115 | ------- 116 | model : torch.nn.Module 117 | The CLIP model 118 | 119 | preprocess : Callable[[PIL.Image], torch.Tensor] 120 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 121 | """ 122 | if name in _MODELS: 123 | model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) 124 | elif os.path.isfile(name): 125 | model_path = name 126 | else: 127 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 128 | 129 | # pdb.set_trace() 130 | with open(model_path, 'rb') as opened_file: 131 | try: 132 | # loading JIT archive 133 | model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval() 134 | state_dict = None 135 | except RuntimeError: 136 | # loading saved state dict 137 | if jit: 138 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") 139 | jit = False 140 | state_dict = torch.load(opened_file, map_location="cpu") 141 | 142 | if not jit: 143 | model = build_model(state_dict or model.state_dict(), image_size).to(device) 144 | if str(device) == "cpu": 145 | model.float() 146 | return model, _transform(image_size) 147 | 148 | # patch the device names 149 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 150 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 151 | 152 | def patch_device(module): 153 | try: 154 | graphs = [module.graph] if hasattr(module, "graph") else [] 155 | except RuntimeError: 156 | graphs = [] 157 | 158 | if hasattr(module, "forward1"): 159 | graphs.append(module.forward1.graph) 160 | 161 | for graph in graphs: 162 | for node in graph.findAllNodes("prim::Constant"): 163 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 164 | node.copyAttributes(device_node) 165 | 166 | model.apply(patch_device) 167 | patch_device(model.encode_image) 168 | patch_device(model.encode_text) 169 | 170 | # patch dtype to float32 on CPU 171 | if str(device) == "cpu": 172 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 173 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 174 | float_node = float_input.node() 175 | 176 | def patch_float(module): 177 | try: 178 | graphs = [module.graph] if hasattr(module, "graph") else [] 179 | except RuntimeError: 180 | graphs = [] 181 | 182 | if hasattr(module, "forward1"): 183 | graphs.append(module.forward1.graph) 184 | 185 | for graph in graphs: 186 | for node in graph.findAllNodes("aten::to"): 187 | inputs = list(node.inputs()) 188 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 189 | if inputs[i].node()["value"] == 5: 190 | inputs[i].node().copyAttributes(float_node) 191 | 192 | model.apply(patch_float) 193 | patch_float(model.encode_image) 194 | patch_float(model.encode_text) 195 | 196 | model.float() 197 | 198 | return model, _transform(image_size) 199 | 200 | 201 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, 
truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]: 202 | """ 203 | Returns the tokenized representation of given input string(s) 204 | 205 | Parameters 206 | ---------- 207 | texts : Union[str, List[str]] 208 | An input string or a list of input strings to tokenize 209 | 210 | context_length : int 211 | The context length to use; all CLIP models use 77 as the context length 212 | 213 | truncate: bool 214 | Whether to truncate the text in case its encoding is longer than the context length 215 | 216 | Returns 217 | ------- 218 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]. 219 | We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long. 220 | """ 221 | if isinstance(texts, str): 222 | texts = [texts] 223 | 224 | sot_token = _tokenizer.encoder["<|startoftext|>"] 225 | eot_token = _tokenizer.encoder["<|endoftext|>"] 226 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 227 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"): 228 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 229 | else: 230 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.int) 231 | 232 | for i, tokens in enumerate(all_tokens): 233 | if len(tokens) > context_length: 234 | if truncate: 235 | tokens = tokens[:context_length] 236 | tokens[-1] = eot_token 237 | else: 238 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 239 | result[i, :len(tokens)] = torch.tensor(tokens) 240 | 241 | return result 242 | -------------------------------------------------------------------------------- /vit_version/clip_rar/clip.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from typing import Any, Union, List 6 | from pkg_resources import packaging 7 | 8 | import torch 9 | from PIL import Image 10 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 11 | from tqdm import tqdm 12 | import pdb 13 | 14 | from .model import build_model 15 | # from .model_plus import build_model 16 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 17 | 18 | try: 19 | from torchvision.transforms import InterpolationMode 20 | BICUBIC = InterpolationMode.BICUBIC 21 | except ImportError: 22 | BICUBIC = Image.BICUBIC 23 | 24 | 25 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"): 26 | warnings.warn("PyTorch version 1.7.1 or higher is recommended") 27 | 28 | 29 | __all__ = ["available_models", "load", "tokenize"] 30 | _tokenizer = _Tokenizer() 31 | 32 | _MODELS = { 33 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 34 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 35 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 36 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 37 | "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", 38 | "ViT-B/32": 
"https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 39 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 40 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", 41 | "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", 42 | } 43 | 44 | 45 | def _download(url: str, root: str): 46 | os.makedirs(root, exist_ok=True) 47 | filename = os.path.basename(url) 48 | 49 | expected_sha256 = url.split("/")[-2] 50 | download_target = os.path.join(root, filename) 51 | 52 | if os.path.exists(download_target) and not os.path.isfile(download_target): 53 | raise RuntimeError(f"{download_target} exists and is not a regular file") 54 | 55 | if os.path.isfile(download_target): 56 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 57 | return download_target 58 | else: 59 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 60 | 61 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 62 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: 63 | while True: 64 | buffer = source.read(8192) 65 | if not buffer: 66 | break 67 | 68 | output.write(buffer) 69 | loop.update(len(buffer)) 70 | 71 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 72 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") 73 | 74 | return download_target 75 | 76 | 77 | def _convert_image_to_rgb(image): 78 | return image.convert("RGB") 79 | 80 | 81 | def _transform(n_px): 82 | return Compose([ 83 | Resize((n_px, n_px), interpolation=BICUBIC), # Modify, 之前是(n_px, n_px) 84 | CenterCrop((n_px, n_px)), 85 | _convert_image_to_rgb, 86 | ToTensor(), 87 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 88 | ]) 89 | 90 | 91 | 92 | def available_models() -> List[str]: 93 | """Returns the names of available CLIP models""" 94 | return list(_MODELS.keys()) 95 | 96 | 97 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None, image_size: int = None): 98 | """Load a CLIP model 99 | 100 | Parameters 101 | ---------- 102 | name : str 103 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 104 | 105 | device : Union[str, torch.device] 106 | The device to put the loaded model 107 | 108 | jit : bool 109 | Whether to load the optimized JIT model or more hackable non-JIT model (default). 
110 | 111 | download_root: str 112 | path to download the model files; by default, it uses "~/.cache/clip" 113 | 114 | Returns 115 | ------- 116 | model : torch.nn.Module 117 | The CLIP model 118 | 119 | preprocess : Callable[[PIL.Image], torch.Tensor] 120 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 121 | """ 122 | if name in _MODELS: 123 | model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) 124 | elif os.path.isfile(name): 125 | model_path = name 126 | else: 127 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 128 | 129 | # pdb.set_trace() 130 | with open(model_path, 'rb') as opened_file: 131 | try: 132 | # loading JIT archive 133 | model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval() 134 | state_dict = None 135 | except RuntimeError: 136 | # loading saved state dict 137 | if jit: 138 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") 139 | jit = False 140 | state_dict = torch.load(opened_file, map_location="cpu") 141 | 142 | if not jit: 143 | model = build_model(state_dict or model.state_dict(), image_size).to(device) 144 | if str(device) == "cpu": 145 | model.float() 146 | return model, _transform(image_size) 147 | 148 | # patch the device names 149 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 150 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 151 | 152 | def patch_device(module): 153 | try: 154 | graphs = [module.graph] if hasattr(module, "graph") else [] 155 | except RuntimeError: 156 | graphs = [] 157 | 158 | if hasattr(module, "forward1"): 159 | graphs.append(module.forward1.graph) 160 | 161 | for graph in graphs: 162 | for node in graph.findAllNodes("prim::Constant"): 163 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 164 | node.copyAttributes(device_node) 165 | 166 | model.apply(patch_device) 167 | patch_device(model.encode_image) 168 | patch_device(model.encode_text) 169 | 170 | # patch dtype to float32 on CPU 171 | if str(device) == "cpu": 172 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 173 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 174 | float_node = float_input.node() 175 | 176 | def patch_float(module): 177 | try: 178 | graphs = [module.graph] if hasattr(module, "graph") else [] 179 | except RuntimeError: 180 | graphs = [] 181 | 182 | if hasattr(module, "forward1"): 183 | graphs.append(module.forward1.graph) 184 | 185 | for graph in graphs: 186 | for node in graph.findAllNodes("aten::to"): 187 | inputs = list(node.inputs()) 188 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 189 | if inputs[i].node()["value"] == 5: 190 | inputs[i].node().copyAttributes(float_node) 191 | 192 | model.apply(patch_float) 193 | patch_float(model.encode_image) 194 | patch_float(model.encode_text) 195 | 196 | model.float() 197 | 198 | return model, _transform(image_size) 199 | 200 | 201 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]: 202 | """ 203 | Returns the tokenized representation of given input string(s) 204 | 205 | Parameters 206 | ---------- 207 | texts : Union[str, List[str]] 208 | An input string or a list of input strings to tokenize 209 
| 210 | context_length : int 211 | The context length to use; all CLIP models use 77 as the context length 212 | 213 | truncate: bool 214 | Whether to truncate the text in case its encoding is longer than the context length 215 | 216 | Returns 217 | ------- 218 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]. 219 | We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long. 220 | """ 221 | if isinstance(texts, str): 222 | texts = [texts] 223 | 224 | sot_token = _tokenizer.encoder["<|startoftext|>"] 225 | eot_token = _tokenizer.encoder["<|endoftext|>"] 226 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 227 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"): 228 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 229 | else: 230 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.int) 231 | 232 | for i, tokens in enumerate(all_tokens): 233 | if len(tokens) > context_length: 234 | if truncate: 235 | tokens = tokens[:context_length] 236 | tokens[-1] = eot_token 237 | else: 238 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 239 | result[i, :len(tokens)] = torch.tensor(tokens) 240 | 241 | return result 242 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision.datasets import ImageFolder 3 | import numpy as np 4 | import random 5 | import time 6 | import os 7 | import torch.nn as nn 8 | from torch.utils.data import DataLoader 9 | from models.resnet import resnet18, resnet34, resnet50, wide_resnet50_2 10 | from models.resnet_rar import resnet18, resnet34, resnet50, wide_resnet50_2_rar 11 | from models.de_resnet import de_resnet18, de_resnet34, de_wide_resnet50_2, de_resnet50 12 | # from models.de_resnet_mem import de_wide_resnet50_2_mem 13 | import torch.backends.cudnn as cudnn 14 | import argparse 15 | from test import evaluation_Stage1, evaluation_Stage2 16 | from torch.nn import functional as F 17 | from models.loss import * 18 | from utils.vis import vis_anomaly_images 19 | from models.recons_net import * 20 | from utils.utils import setup_seed, get_dataset_name 21 | import pdb 22 | from utils.vis import * 23 | from tqdm import tqdm 24 | from utils.utils import PatchMaker 25 | 26 | def count_parameters(model): 27 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 28 | 29 | 30 | def train_Stage1(args, _class_, epochs=100, eval_interval=10, lr=0.0002): 31 | print("training Advanced Teacher...") 32 | print(_class_) 33 | 34 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 35 | print(device) 36 | 37 | Stage1_ckp_path = './checkpoints_Stage1/' + 'wres50_'+_class_+'.pth' 38 | os.makedirs('checkpoints_Stage1', exist_ok=True) 39 | train_data = AnomalyDataset(class_name=_class_, img_size=args.image_size, dataset_path=args.data_root, aux_path=args.aux_path) # 注意:既有正常又有合成的异常 40 | test_data = TestDataset(class_name=_class_, img_size=args.image_size, dataset_path=args.data_root) 41 | train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True, num_workers=4) 42 | test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=False) 43 | 44 | encoder, _ = wide_resnet50_2(pretrained=True) 45 | encoder = 
encoder.to(device) 46 | encoder.eval() 47 | # load the ImageNet-pretrained backbone for the Advanced Teacher 48 | encoder_AT, _ = wide_resnet50_2_rar(pretrained=True) 49 | encoder_AT = encoder_AT.to(device) 50 | 51 | for name, para in encoder_AT.named_parameters(): 52 | if 'feat_recons' in name: 53 | para.requires_grad = True 54 | else: 55 | para.requires_grad = False 56 | # print([name for name, para in encoder_AT.named_parameters() if para.requires_grad]) 57 | # pdb.set_trace() 58 | optimizer_Stage1 = torch.optim.Adam((filter(lambda p: p.requires_grad, encoder_AT.parameters())), lr=lr, betas=(0.5,0.999)) 59 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer_Stage1, [epochs*0.8, epochs*0.9], gamma=0.2, last_epoch=-1) 60 | 61 | best_auroc_sp = -1 62 | best_auroc_px = -1 63 | best_P = -1 64 | loss_per_epoch = {'fnorm_loss': [], 'afnorm_loss': [], 'test_fnorm_loss': [], 'test_afnorm_loss': []} 65 | 66 | # auroc_px, auroc_sp, aupro_px, test_fnorm_loss, test_afnorm_loss = evaluation_Stage1(encoder, encoder_AT, test_dataloader, device) 67 | # print('Pixel Auroc:{:.3f}, Sample Auroc{:.3f}, Pixel Aupro{:.3}'.format(auroc_px, auroc_sp, aupro_px)) 68 | # print(f'test_fnorm_loss: {np.mean(test_fnorm_loss):.4f} test_afnorm_loss: {np.mean(test_afnorm_loss):.4f}') 69 | # pdb.set_trace() 70 | for epoch in tqdm(range(epochs)): 71 | encoder_AT.train() # set train mode at every epoch; the evaluation phase switches back to eval mode 72 | # keep BatchNorm layers in eval mode so their running statistics stay frozen 73 | for name, module in encoder_AT.named_modules(): 74 | if isinstance(module, nn.BatchNorm2d): 75 | module.eval() 76 | if isinstance(module, nn.BatchNorm1d): 77 | module.eval() 78 | 79 | normal_loss_list, anomaly_loss_list, fnorm_loss_list, afnorm_loss_list, focal_loss_list = [], [], [], [], [] 80 | P_list, R_list = [], [] 81 | for batch_data in train_dataloader: 82 | # load data 83 | img = batch_data["img"].to(device) 84 | anomaly_mask = batch_data["anomaly_mask"].to(device) # note: should half of each batch be normal images?
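            # --- Batch layout (descriptive note) ---
            # AnomalyDataset yields training images together with copies corrupted by
            # DRAEM-style synthetic defects cut from the auxiliary textures under --aux_path;
            # "anomaly_mask" marks the pasted-defect pixels, and "fore_mask" (loaded below)
            # is passed to get_amplify_loss, presumably to restrict it to the object region.
            # Expected shapes are an assumption from the dataloader arguments, not asserted here:
            #   img:          [batch_size, 3, image_size, image_size]
            #   anomaly_mask: [batch_size, 1, image_size, image_size]  (1 = synthetic defect)
            #   fore_mask:    [batch_size, 1, image_size, image_size]  (1 = foreground)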
85 | fore_mask = batch_data["fore_mask"].to(device) 86 | # vis_anomaly_images(img, anomaly_mask, _class_) 87 | with torch.no_grad(): 88 | inputs = encoder(img) 89 | _, inputs_AT, delta, atten = encoder_AT(img, flag=False, atten_mask=anomaly_mask) 90 | # pdb.set_trace() 91 | loss_focal, P, R = get_focal_loss(atten, anomaly_mask) 92 | loss_normal, loss_anomaly = get_amplify_loss(inputs, inputs_AT, anomaly_mask, fore_mask) 93 | loss = loss_focal*1.0 + loss_anomaly * 0.1 94 | # Res Loss: the residual of anomalies (afnorm) should increase 95 | fnorm_loss, afnorm_loss = get_FnormLoss(delta, anomaly_mask) 96 | fnorm_loss_list.append(fnorm_loss.item()) 97 | if afnorm_loss != 0.0: 98 | afnorm_loss_list.append(afnorm_loss.item()) 99 | 100 | optimizer_Stage1.zero_grad() 101 | loss.backward() 102 | optimizer_Stage1.step() 103 | 104 | normal_loss_list.append(loss_normal.item()) 105 | loss_anomaly = loss_anomaly.item() if torch.is_tensor(loss_anomaly) else loss_anomaly 106 | anomaly_loss_list.append(loss_anomaly) 107 | focal_loss_list.append(loss_focal.item()) 108 | P_list.append(P.item()) 109 | R_list.append(R.item()) 110 | 111 | # if np.isnan(np.mean(anomaly_loss_list)): 112 | # pdb.set_trace() 113 | scheduler.step() # step the LR scheduler once per epoch 114 | print(f'epoch [{epoch + 1}/{epochs}], focal_loss:{np.mean(focal_loss_list):.4f}\t, anomaly_loss:{np.mean(anomaly_loss_list):.4f}\t' 115 | f'fnorm_loss:{np.mean(fnorm_loss_list):.4f}, a_fnorm_loss:{np.mean(afnorm_loss_list):.4f}' 116 | f', P:{np.mean(P_list):.2f}, R:{np.mean(R_list):.2f}') 117 | loss_per_epoch['fnorm_loss'].append(np.mean(fnorm_loss_list)) 118 | loss_per_epoch['afnorm_loss'].append(np.mean(afnorm_loss_list)) 119 | 120 | if (epoch==0) or ((epoch + 1) % eval_interval == 0): 121 | ## evaluate performance to choose the best checkpoint, a common practice in industrial anomaly detection (e.g., SimpleNet, RD++, BGAD) ...
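            ## Checkpoint-selection sketch (note): evaluation_Stage1 returns
            ## (auroc_px, auroc_sp, aupro_px, test_fnorm_loss, test_afnorm_loss, eval_P),
            ## where eval_P is the mean precision of the attention map over the test set.
            ## The "Advanced Teacher" kept below is the one that maximizes eval_P, e.g.
            ##     if eval_P > best_P and epoch > 5:
            ##         best_P = eval_P
            ##         torch.save({'encoder_AT': encoder_AT.state_dict()}, Stage1_ckp_path)
            ## rather than selecting directly on image-level AUROC; training also stops early
            ## once auroc_sp exceeds 0.999 after epoch 30.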
122 | auroc_px, auroc_sp, aupro_px, test_fnorm_loss, test_afnorm_loss, eval_P = evaluation_Stage1(encoder, encoder_AT, test_dataloader, device) 123 | print('Pixel Auroc:{:.3f}, Sample Auroc{:.3f}, Pixel Aupro{:.3}'.format(auroc_px, auroc_sp, aupro_px)) 124 | print(f'test_fnorm_loss: {np.mean(test_fnorm_loss):.4f} test_afnorm_loss: {np.mean(test_afnorm_loss):.4f}') 125 | loss_per_epoch['test_fnorm_loss'].append(np.mean(test_fnorm_loss)) 126 | loss_per_epoch['test_afnorm_loss'].append(np.mean(test_afnorm_loss)) 127 | 128 | # if auroc_sp > best_auroc_sp and epoch > 5: 129 | if eval_P > best_P and epoch > 5: 130 | best_P = eval_P 131 | best_auroc_sp = auroc_sp 132 | best_auroc_px = auroc_px 133 | torch.save({'encoder_AT': encoder_AT.state_dict()}, Stage1_ckp_path) 134 | # 'reconsKV_fs': recons_rar.state_dict()}, Stage1_ckp_path) 135 | if auroc_sp > 0.999 and epoch > 30: 136 | break 137 | return best_auroc_px, best_auroc_sp, aupro_px 138 | 139 | 140 | def train_Stage2(args, _class_, epochs=100, eval_interval=10, lr=0.005): 141 | print("training Stubborn Student...") 142 | print(_class_) 143 | 144 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 145 | print(device) 146 | 147 | Stage1_ckp_path = './checkpoints_Stage1/' + 'wres50_'+_class_+'.pth' 148 | Stage2_ckp_path = './checkpoints_Stage2/' + 'wres50_'+ _class_+'.pth' 149 | os.makedirs('checkpoints_Stage2', exist_ok=True) 150 | # only normal data 151 | train_data = TrainDataset(class_name=_class_, img_size=args.image_size, dataset_path=args.data_root) 152 | test_data = TestDataset(class_name=_class_, img_size=args.image_size, dataset_path=args.data_root) 153 | train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True, num_workers=4) 154 | test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=False) 155 | 156 | encoder_AT, bn = wide_resnet50_2_rar(pretrained=True) 157 | encoder_AT = encoder_AT.to(device) 158 | encoder_AT.load_state_dict(torch.load(Stage1_ckp_path)['encoder_AT']) # 加载Stage1 159 | encoder_AT.eval() 160 | 161 | encoder_pre, _ = wide_resnet50_2(pretrained=True) 162 | encoder_pre = encoder_pre.to(device) 163 | encoder_pre.eval() 164 | 165 | bn = bn.to(device) 166 | # decoder_SS = de_wide_resnet50_2_mem(pretrained=False) 167 | decoder_SS = de_wide_resnet50_2(pretrained=False) 168 | decoder_SS = decoder_SS.to(device) 169 | 170 | optimizer_Stage2 = torch.optim.Adam(list(decoder_SS.parameters())+list(bn.parameters()), lr=lr, betas=(0.5,0.999)) 171 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer_Stage2, [epochs*0.8, epochs*0.9], gamma=0.2, last_epoch=-1) 172 | 173 | best_auroc_sp = -1 174 | best_auroc_px = -1 175 | for epoch in tqdm(range(epochs)): 176 | bn.train() 177 | decoder_SS.train() 178 | 179 | kd_loss_list = [] 180 | recons_loss_list = [] 181 | orth_loss_list = [] 182 | for batch_data in train_dataloader: 183 | # load data 184 | img = batch_data[0].to(device) 185 | with torch.no_grad(): 186 | _, inputs, _, _ = encoder_AT(img, flag=True) 187 | # inputs = encoder_pre(img) 188 | 189 | outputs = decoder_SS(bn(inputs)) 190 | 191 | kd_loss = loss_fucntion(inputs, outputs) # 这句话是不变的 192 | loss = kd_loss 193 | 194 | optimizer_Stage2.zero_grad() 195 | loss.backward() 196 | optimizer_Stage2.step() 197 | 198 | kd_loss_list.append(kd_loss.item()) 199 | 200 | scheduler.step() # modify 201 | print('epoch [{}/{}], kd_loss:{:.4f}'.format(epoch + 1, epochs, np.mean(kd_loss_list))) 202 | 203 | if (epoch + 1) % eval_interval == 0: 204 | auroc_px, auroc_sp, 
aupro_px, _, _, _ = evaluation_Stage2(encoder_AT, bn, decoder_SS, test_dataloader, device) 205 | print('Pixel Auroc:{:.3f}, Sample Auroc{:.3f}, Pixel Aupro{:.3}'.format(auroc_px, auroc_sp, aupro_px)) 206 | 207 | if (auroc_sp+auroc_px) > (best_auroc_sp+best_auroc_px) and epoch > 5: 208 | best_auroc_sp = auroc_sp 209 | best_auroc_px = auroc_px 210 | torch.save({'bn': bn.state_dict(), 211 | 'decoder_ss': decoder_SS.state_dict()}, Stage2_ckp_path) 212 | if auroc_sp > 0.999 and epoch > 5: 213 | break 214 | return best_auroc_px, best_auroc_sp, aupro_px 215 | 216 | 217 | if __name__ == '__main__': 218 | parser = argparse.ArgumentParser() 219 | parser.add_argument('--data_root', type=str, default='/data4/tch/AD_data/mvtec') 220 | parser.add_argument('--aux_path', type=str, default='/data4/tch/AD_data/DRAEM_dtd/dtd/images') 221 | parser.add_argument('--batch_size', type=int, default=16) 222 | parser.add_argument('--image_size', type=int, default=256) 223 | args = parser.parse_args() 224 | 225 | dataset_name = get_dataset_name(args.data_root) 226 | if dataset_name == 'mvtec': 227 | from datasets.MvTec import TrainDataset, TestDataset, AnomalyDataset 228 | from datasets.MvTec import mvtec_classes 229 | item_list = mvtec_classes() 230 | elif dataset_name == 'VisA': 231 | from datasets.VisA import TrainDataset, TestDataset, AnomalyDataset 232 | from datasets.VisA import visa_classes 233 | item_list = visa_classes() 234 | elif dataset_name == 'mvtec3d': 235 | from datasets.MvTec3D import TrainDataset, TestDataset, AnomalyDataset 236 | from datasets.MvTec3D import mvtec3d_classes 237 | item_list = mvtec3d_classes() 238 | 239 | setup_seed(0) 240 | 241 | auroc_Stage1_list = [] 242 | auroc_Stage2_list = [] 243 | 244 | for i in item_list: 245 | _, auroc_sp, _ = train_Stage1(args, i, epochs=100, eval_interval=5, lr=0.005) 246 | auroc_Stage1_list.append(auroc_sp) 247 | _, auroc_sp, _ = train_Stage2(args, i, epochs=120, eval_interval=5, lr=0.005) 248 | auroc_Stage2_list.append(auroc_sp) 249 | auroc_Stage1_mean = np.mean(auroc_Stage1_list) 250 | auroc_Stage2_mean = np.mean(auroc_Stage2_list) 251 | 252 | with open('results.txt', 'a') as f: 253 | f.write(time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time()))+'\n') 254 | f.write('Stage1: ') 255 | for item in auroc_Stage1_list: 256 | f.write(f"{item:.4f}") 257 | f.write(' ') 258 | f.write(f'avg: {auroc_Stage1_mean:.4f}\n') 259 | 260 | f.write('Stage2: ') 261 | for item in auroc_Stage2_list: 262 | f.write(f"{item:.4f}") 263 | f.write(' ') 264 | f.write(f'avg: {auroc_Stage2_mean:.4f}\n') 265 | -------------------------------------------------------------------------------- /vit_version/train_vit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision.datasets import ImageFolder 3 | import numpy as np 4 | import random 5 | import time 6 | import os 7 | import torch.nn as nn 8 | from torch.utils.data import DataLoader 9 | import os, pdb, sys 10 | sys.path.append(os.getcwd()) 11 | from models.resnet import resnet18, resnet34, resnet50, wide_resnet50_2 12 | from models.resnet_rar import resnet18, resnet34, resnet50, wide_resnet50_2_rar 13 | from models.de_resnet import de_resnet18, de_resnet34, de_wide_resnet50_2, de_resnet50 14 | # from models.de_resnet_mem import de_wide_resnet50_2_mem 15 | import torch.backends.cudnn as cudnn 16 | import argparse 17 | from vit_version.test_vit import evaluation_Stage1, evaluation_Stage2 18 | from torch.nn import functional as F 19 | from models.loss import * 20 
| from utils.vis import vis_anomaly_images 21 | from models.recons_net import * 22 | from utils.utils import setup_seed, get_dataset_name 23 | import pdb 24 | from utils.vis import * 25 | from tqdm import tqdm 26 | from utils.utils import PatchMaker 27 | # import clip, clip_rar 28 | # from clip_bn import BN_layer, AttnBottleneck 29 | # from clip_decoder import de_VisionTransformer 30 | import vit_version.clip as clip 31 | import vit_version.clip_rar as clip_rar 32 | from vit_version.clip_bn import BN_layer, AttnBottleneck 33 | from vit_version.clip_decoder import de_VisionTransformer 34 | 35 | 36 | def count_parameters(model): 37 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 38 | 39 | 40 | def train_Stage1(args, _class_, epochs=100, eval_interval=10, lr=0.0002): 41 | print("training Advanced Teacher...") 42 | print(_class_) 43 | 44 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 45 | print(device) 46 | 47 | Stage1_ckp_path = 'vit_version/checkpoints_Stage1/' + 'wres50_'+_class_+'.pth' 48 | os.makedirs('vit_version/checkpoints_Stage1', exist_ok=True) 49 | train_data = AnomalyDataset(class_name=_class_, img_size=args.image_size, dataset_path=args.data_root, aux_path=args.aux_path) # 注意:既有正常又有合成的异常 50 | test_data = TestDataset(class_name=_class_, img_size=args.image_size, dataset_path=args.data_root) 51 | train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True, num_workers=4) 52 | test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=False) 53 | 54 | # encoder, _ = clip.load("ViT-L/14@336px", download_root='./clip', device=torch.device("cpu"), image_size=336) # Modify1 55 | encoder, _ = clip.load("ViT-B/16", download_root='vit_version/clip', device=torch.device("cpu"), image_size=args.image_size) 56 | encoder = encoder.to(device) 57 | encoder.eval() 58 | 59 | encoder_AT, _ = clip_rar.load("ViT-B/16", download_root='vit_version/clip', device=torch.device("cpu"), image_size=args.image_size) 60 | encoder_AT = encoder_AT.to(device) 61 | 62 | for name, para in encoder_AT.named_parameters(): 63 | if 'feat_recons' in name: 64 | para.requires_grad = True 65 | else: 66 | para.requires_grad = False 67 | # print([name for name, para in encoder_AT.named_parameters() if para.requires_grad]) 68 | # pdb.set_trace() 69 | optimizer_Stage1 = torch.optim.Adam((filter(lambda p: p.requires_grad, encoder_AT.parameters())), lr=lr, betas=(0.5,0.999)) 70 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer_Stage1, [epochs*0.8, epochs*0.9], gamma=0.2, last_epoch=-1) 71 | 72 | best_auroc_sp = -1 73 | best_auroc_px = -1 74 | best_P = -1 75 | loss_per_epoch = {'fnorm_loss': [], 'afnorm_loss': [], 'test_fnorm_loss': [], 'test_afnorm_loss': []} 76 | 77 | # auroc_px, auroc_sp, aupro_px, test_fnorm_loss, test_afnorm_loss = evaluation_Stage1(encoder, encoder_AT, test_dataloader, device) 78 | # print('Pixel Auroc:{:.3f}, Sample Auroc{:.3f}, Pixel Aupro{:.3}'.format(auroc_px, auroc_sp, aupro_px)) 79 | # print(f'test_fnorm_loss: {np.mean(test_fnorm_loss):.4f} test_afnorm_loss: {np.mean(test_afnorm_loss):.4f}') 80 | # pdb.set_trace() 81 | for epoch in tqdm(range(epochs)): 82 | encoder_AT.train() 83 | for name, module in encoder_AT.named_modules(): 84 | if isinstance(module, nn.BatchNorm2d): 85 | module.eval() 86 | if isinstance(module, nn.BatchNorm1d): 87 | module.eval() 88 | 89 | normal_loss_list, anomaly_loss_list, fnorm_loss_list, afnorm_loss_list, focal_loss_list = [], [], [], [], [] 90 | P_list, R_list = [], [] 91 
| for batch_data in train_dataloader: 92 | # load data 93 | img = batch_data["img"].to(device) 94 | anomaly_mask = batch_data["anomaly_mask"].to(device) # note: should half of each batch be normal images? 95 | fore_mask = batch_data["fore_mask"].to(device) 96 | # vis_anomaly_images(img, anomaly_mask, _class_) 97 | 98 | with torch.no_grad(): 99 | _, inputs = encoder.encode_image(img, f_list=[4,8,12]) 100 | 101 | _, inputs_AT, atten, delta = encoder_AT.encode_image(img, f_list=[4,8,12]) 102 | # pdb.set_trace() 103 | loss_focal, P, R = get_focal_loss(atten, anomaly_mask, vit=True) 104 | loss_normal, loss_anomaly = get_amplify_loss(inputs, inputs_AT, anomaly_mask, fore_mask, vit=True) 105 | loss = loss_focal*1.0 + loss_anomaly *0.1 + loss_normal*0 106 | # Res Loss: the residual of anomalies (afnorm) should increase 107 | fnorm_loss, afnorm_loss = get_FnormLoss(delta, anomaly_mask, vit=True) 108 | fnorm_loss_list.append(fnorm_loss.item()) 109 | if afnorm_loss != 0.0: 110 | afnorm_loss_list.append(afnorm_loss.item()) 111 | 112 | optimizer_Stage1.zero_grad() 113 | loss.backward() 114 | optimizer_Stage1.step() 115 | 116 | normal_loss_list.append(loss_normal.item()) 117 | loss_anomaly = loss_anomaly.item() if torch.is_tensor(loss_anomaly) else loss_anomaly 118 | anomaly_loss_list.append(loss_anomaly) 119 | focal_loss_list.append(loss_focal.item()) 120 | P_list.append(P.item()) 121 | R_list.append(R.item()) 122 | 123 | # if np.isnan(np.mean(anomaly_loss_list)): 124 | # pdb.set_trace() 125 | scheduler.step() # step the LR scheduler once per epoch 126 | print(f'epoch [{epoch + 1}/{epochs}], focal_loss:{np.mean(focal_loss_list):.4f}\t, anomaly_loss:{np.mean(anomaly_loss_list):.4f}\t' 127 | f'fnorm_loss:{np.mean(fnorm_loss_list):.4f}, a_fnorm_loss:{np.mean(afnorm_loss_list):.4f}' 128 | f', P:{np.mean(P_list):.2f}, R:{np.mean(R_list):.2f}') 129 | loss_per_epoch['fnorm_loss'].append(np.mean(fnorm_loss_list)) 130 | loss_per_epoch['afnorm_loss'].append(np.mean(afnorm_loss_list)) 131 | 132 | if (epoch==0) or ((epoch + 1) % eval_interval == 0): 133 | ## evaluate performance to choose the best checkpoint, a common practice in industrial anomaly detection (e.g., SimpleNet, RD++, BGAD) ...
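            ## Note on the ViT-variant Stage-1 objective used in the batch loop above:
            ## features are tapped from CLIP ViT-B/16 blocks f_list=[4, 8, 12], and the total
            ## loss is  loss = 1.0*loss_focal + 0.1*loss_anomaly + 0.0*loss_normal,
            ## i.e. the normal-consistency term is tracked for logging but carries zero weight,
            ## while the focal term supervises the attention map and the amplify term separates
            ## residuals on synthetic-defect pixels (the vit=True flag presumably adapts the
            ## losses to the ViT token layout rather than CNN feature maps).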
134 | auroc_px, auroc_sp, aupro_px, test_fnorm_loss, test_afnorm_loss, eval_P = evaluation_Stage1(encoder, encoder_AT, test_dataloader, device) 135 | print('Pixel Auroc:{:.3f}, Sample Auroc{:.3f}, Pixel Aupro{:.3}'.format(auroc_px, auroc_sp, aupro_px)) 136 | print(f'test_fnorm_loss: {np.mean(test_fnorm_loss):.4f} test_afnorm_loss: {np.mean(test_afnorm_loss):.4f}') 137 | loss_per_epoch['test_fnorm_loss'].append(np.mean(test_fnorm_loss)) 138 | loss_per_epoch['test_afnorm_loss'].append(np.mean(test_afnorm_loss)) 139 | 140 | # if auroc_sp > best_auroc_sp and epoch > 5: 141 | if eval_P > best_P and epoch > 5: 142 | best_P = eval_P 143 | best_auroc_sp = auroc_sp 144 | best_auroc_px = auroc_px 145 | torch.save({'encoder_AT': encoder_AT.state_dict()}, Stage1_ckp_path) 146 | # 'reconsKV_fs': recons_rar.state_dict()}, Stage1_ckp_path) 147 | if auroc_sp > 0.999 and epoch > 5: 148 | break 149 | return best_auroc_px, best_auroc_sp, aupro_px 150 | 151 | 152 | def train_Stage2(args, _class_, epochs=100, eval_interval=10, lr=0.005): 153 | print("training Stubborn Student...") 154 | print(_class_) 155 | 156 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 157 | print(device) 158 | 159 | Stage1_ckp_path = 'vit_version/checkpoints_Stage1/' + 'wres50_'+_class_+'.pth' 160 | Stage2_ckp_path = 'vit_version/checkpoints_Stage2/' + 'wres50_'+ _class_+'.pth' 161 | os.makedirs('vit_version/checkpoints_Stage2', exist_ok=True) 162 | # 163 | train_data = TrainDataset(class_name=_class_, img_size=args.image_size, dataset_path=args.data_root) 164 | test_data = TestDataset(class_name=_class_, img_size=args.image_size, dataset_path=args.data_root) 165 | train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True, num_workers=4) 166 | test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=False) 167 | 168 | # 169 | # encoder, _ = clip.load("ViT-L/14@336px", download_root='./clip', device=torch.device("cpu"), image_size=336) # Modify1 170 | encoder_pre, _ = clip.load("ViT-B/16", download_root='vit_version/clip', device=torch.device("cpu"), image_size=args.image_size) 171 | encoder_pre = encoder_pre.to(device) 172 | encoder_pre.eval() 173 | # 174 | encoder_AT, _ = clip_rar.load("ViT-B/16", download_root='vit_version/clip', device=torch.device("cpu"), image_size=args.image_size) 175 | encoder_AT = encoder_AT.to(device) 176 | encoder_AT.load_state_dict(torch.load(Stage1_ckp_path)['encoder_AT']) # 加载Stage1 177 | encoder_AT.eval() 178 | 179 | # 180 | bn = BN_layer(AttnBottleneck, 3) 181 | bn = bn.to(device) 182 | decoder_SS = de_VisionTransformer(input_resolution=256, patch_size=16, width=768, layers=6, heads=12, output_dim=512) 183 | decoder_SS = decoder_SS.to(device) 184 | 185 | optimizer_Stage2 = torch.optim.Adam(list(decoder_SS.parameters())+list(bn.parameters()), lr=lr, betas=(0.5,0.999)) 186 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer_Stage2, [epochs*0.8, epochs*0.9], gamma=0.2, last_epoch=-1) 187 | 188 | best_auroc_sp = -1 189 | best_auroc_px = -1 190 | for epoch in tqdm(range(epochs)): 191 | bn.train() 192 | decoder_SS.train() 193 | 194 | kd_loss_list = [] 195 | recons_loss_list = [] 196 | orth_loss_list = [] 197 | for batch_data in train_dataloader: 198 | # load data 199 | img = batch_data[0].to(device) 200 | # 201 | with torch.no_grad(): 202 | _, inputs, atten, delta = encoder_AT.encode_image(img, f_list=[4,8,12]) 203 | # inputs = encoder_pre(img) 204 | 205 | # outputs = decoder_SS(bn(inputs)) 206 | _, outputs = 
decoder_SS(bn(inputs), f_list=[2,4,6]) 207 | 208 | kd_loss = loss_fucntion(inputs, outputs) # 这句话是不变的 209 | loss = kd_loss 210 | 211 | optimizer_Stage2.zero_grad() 212 | loss.backward() 213 | optimizer_Stage2.step() 214 | 215 | kd_loss_list.append(kd_loss.item()) 216 | 217 | scheduler.step() # modify 218 | print('epoch [{}/{}], kd_loss:{:.4f}'.format(epoch + 1, epochs, np.mean(kd_loss_list))) 219 | 220 | if (epoch + 1) % eval_interval == 0: 221 | auroc_px, auroc_sp, aupro_px, _, _, _ = evaluation_Stage2(encoder_AT, bn, decoder_SS, test_dataloader, device) 222 | print('Pixel Auroc:{:.3f}, Sample Auroc{:.3f}, Pixel Aupro{:.3}'.format(auroc_px, auroc_sp, aupro_px)) 223 | 224 | if auroc_sp > best_auroc_sp and epoch > 5: 225 | best_auroc_sp = auroc_sp 226 | best_auroc_px = auroc_px 227 | torch.save({'bn': bn.state_dict(), 228 | 'decoder_ss': decoder_SS.state_dict()}, Stage2_ckp_path) 229 | if auroc_sp > 0.999 and epoch > 5: 230 | break 231 | return best_auroc_px, best_auroc_sp, aupro_px 232 | 233 | 234 | if __name__ == '__main__': 235 | parser = argparse.ArgumentParser() 236 | parser.add_argument('--data_root', type=str, default='/data4/tch/AD_data/mvtec') 237 | parser.add_argument('--aux_path', type=str, default='/data4/tch/AD_data/DRAEM_dtd/dtd/images') 238 | parser.add_argument('--L_rar', type=int, default=50) 239 | parser.add_argument('--batch_size', type=int, default=16) 240 | parser.add_argument('--image_size', type=int, default=256) 241 | args = parser.parse_args() 242 | 243 | dataset_name = get_dataset_name(args.data_root) 244 | if dataset_name == 'mvtec': 245 | from datasets.MvTec import TrainDataset, TestDataset, AnomalyDataset 246 | from datasets.MvTec import mvtec_classes 247 | item_list = mvtec_classes() 248 | elif dataset_name == 'VisA': 249 | from datasets.VisA import TrainDataset, TestDataset, AnomalyDataset 250 | from datasets.VisA import visa_classes 251 | item_list = visa_classes() 252 | elif dataset_name == 'mvtec3d': 253 | from datasets.MvTec3D import TrainDataset, TestDataset, AnomalyDataset 254 | from datasets.MvTec3D import mvtec3d_classes 255 | item_list = mvtec3d_classes() 256 | 257 | setup_seed(111) 258 | 259 | auroc_Stage1_list = [] 260 | auroc_Stage2_list = [] 261 | 262 | for i in item_list: 263 | _, auroc_sp, _ = train_Stage1(args, i, epochs=100, eval_interval=5, lr=0.005) 264 | auroc_Stage1_list.append(auroc_sp) 265 | _, auroc_sp, _ = train_Stage2(args, i, epochs=120, eval_interval=10, lr=0.005) 266 | auroc_Stage2_list.append(auroc_sp) 267 | auroc_Stage1_mean = np.mean(auroc_Stage1_list) 268 | auroc_Stage2_mean = np.mean(auroc_Stage2_list) 269 | 270 | with open('results.txt', 'a') as f: 271 | f.write(time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time()))+'\n') 272 | f.write('Stage1: ') 273 | for item in auroc_Stage1_list: 274 | f.write(f"{item:.4f}") 275 | f.write(' ') 276 | f.write(f'avg: {auroc_Stage1_mean:.4f}\n') 277 | 278 | f.write('Stage2: ') 279 | for item in auroc_Stage2_list: 280 | f.write(f"{item:.4f}") 281 | f.write(' ') 282 | f.write(f'avg: {auroc_Stage2_mean:.4f}\n') 283 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision.datasets import ImageFolder 3 | import numpy as np 4 | from torch.utils.data import DataLoader 5 | from models.resnet import resnet18, resnet34, resnet50, wide_resnet50_2 6 | from models.de_resnet import de_resnet18, de_resnet50, de_wide_resnet50_2 7 | from 
models.resnet_rar import resnet18, resnet34, resnet50, wide_resnet50_2_rar 8 | # from models.de_resnet_mem import de_wide_resnet50_2_mem 9 | from torch.nn import functional as F 10 | from sklearn.metrics import roc_auc_score, precision_recall_curve 11 | import cv2 12 | import matplotlib.pyplot as plt 13 | from sklearn.metrics import auc 14 | from skimage import measure 15 | import pandas as pd 16 | from numpy import ndarray 17 | from statistics import mean 18 | from scipy.ndimage import gaussian_filter 19 | from sklearn import manifold 20 | from matplotlib.ticker import NullFormatter 21 | from scipy.spatial.distance import pdist 22 | import matplotlib 23 | import pickle 24 | import argparse 25 | # from MvTec3D import TestDataset 26 | # from MvTec import TestDataset 27 | import pdb 28 | from utils.vis import * 29 | from utils.utils import PatchMaker, get_dataset_name 30 | from models.loss import * 31 | import time 32 | from models.recons_net import * 33 | from tqdm import tqdm 34 | 35 | 36 | def cal_anomaly_map(fs_list, ft_list, out_size=224, amap_mode='mul', vis=False): 37 | if amap_mode == 'mul': 38 | anomaly_map = np.ones([out_size, out_size]) 39 | else: 40 | anomaly_map = np.zeros([out_size, out_size]) 41 | a_map_list = [] 42 | mse_loss = torch.nn.MSELoss(reduction='none') 43 | for i in range(len(ft_list)): 44 | # pdb.set_trace() 45 | fs = fs_list[i] 46 | ft = ft_list[i] 47 | #fs_norm = F.normalize(fs, p=2) 48 | #ft_norm = F.normalize(ft, p=2) 49 | # pdb.set_trace() 50 | a_map = 1 - F.cosine_similarity(fs, ft) # cos_loss [1,C,H,W]->[1,H,W] 51 | # a_map2 = torch.sqrt(mse_loss(fs, ft).sum(dim=1)) # mse_loss [1,C,H,W]->[1,H,W] 52 | # if i == 1 and vis: 53 | # vis_hotmap(fs, ft, a_map) 54 | 55 | a_map = torch.unsqueeze(a_map, dim=1) 56 | a_map = F.interpolate(a_map, size=out_size, mode='bilinear', align_corners=True) 57 | a_map = a_map[0, 0, :, :].to('cpu').detach().numpy() 58 | a_map_list.append(a_map) 59 | if amap_mode == 'mul': 60 | anomaly_map *= a_map 61 | else: 62 | anomaly_map += a_map 63 | 64 | return anomaly_map, None 65 | 66 | 67 | def evaluation_Stage1(encoder, encoder_AT, dataloader, device, _class_=None): 68 | encoder.eval() 69 | encoder_AT.eval() 70 | # reconsKV.eval() 71 | gt_list_px = [] 72 | pr_list_px = [] 73 | gt_list_sp = [] 74 | pr_list_sp = [] 75 | aupro_list = [] 76 | fnorm_loss_list, afnorm_loss_list = [], [] 77 | P_list, R_list = [], [] 78 | with torch.no_grad(): 79 | for img, gt, label, _ in dataloader: 80 | img = img.to(device) 81 | gt = gt.to(device) 82 | inputs = encoder(img) 83 | # FS 84 | _, inputs_AT, delta, atten = encoder_AT(img, flag=True) 85 | # pdb.set_trace() 86 | loss_atten, P, R = get_focal_loss(atten, gt) 87 | fnorm_loss, afnorm_loss = get_FnormLoss(delta, gt) 88 | fnorm_loss_list.append(fnorm_loss.item()) 89 | if afnorm_loss != 0.0: 90 | afnorm_loss_list.append(afnorm_loss.item()) 91 | P_list.append(P.item()) 92 | R_list.append(R.item()) 93 | 94 | anomaly_map, _ = cal_anomaly_map(inputs, inputs_AT, img.shape[-1], amap_mode='a') 95 | anomaly_map = gaussian_filter(anomaly_map, sigma=4) 96 | gt[gt > 0.5] = 1 97 | gt[gt <= 0.5] = 0 98 | if label.item()!=0: 99 | aupro_list.append(0.0) 100 | gt_list_px.extend(gt.cpu().numpy().astype(int).ravel()) 101 | pr_list_px.extend(anomaly_map.ravel()) 102 | gt_list_sp.append(np.max(gt.cpu().numpy().astype(int))) 103 | pr_list_sp.append(np.max(anomaly_map)) 104 | 105 | # print(afnorm_loss_list, 'len:', len(afnorm_loss_list)) 106 | 107 | pr_list_sp = np.array(pr_list_sp) 108 | gt_list_sp = np.array(gt_list_sp) 109 | 
normal = pr_list_sp[gt_list_sp==0] 110 | anomaly = pr_list_sp[gt_list_sp==1] 111 | print(f'normal: {normal.min():.4f} {normal.max():.4f} anomaly: {anomaly.min():.4f} {anomaly.max():.4f}\t' 112 | f'P:{np.mean(P_list):.2f}, R:{np.mean(R_list):.2f}') 113 | 114 | auroc_px = round(roc_auc_score(gt_list_px, pr_list_px), 3) 115 | auroc_sp = round(roc_auc_score(gt_list_sp, pr_list_sp), 3) 116 | return auroc_px, auroc_sp, round(np.mean(aupro_list),3), fnorm_loss_list, afnorm_loss_list, np.mean(P_list) 117 | 118 | 119 | def evaluation_Stage2(encoder_AT, bn, decoder_SS, dataloader, device, _class_=None): 120 | encoder_AT.eval() 121 | bn.eval() 122 | decoder_SS.eval() 123 | gt_list_px = [] 124 | pr_list_px = [] 125 | gt_list_sp = [] 126 | pr_list_sp = [] 127 | aupro_list = [] 128 | time_list = [] 129 | normal_feats, anomaly_feats = [], [] 130 | with torch.no_grad(): 131 | # for img, gt, label, rgb_path, fore_mask in dataloader: 132 | for img, gt, label, rgb_path in dataloader: 133 | # break 134 | img = img.to(device) 135 | tic = time.time() 136 | torch.cuda.synchronize() 137 | # FS 138 | _, inputs, _, _ = encoder_AT(img, flag=True) 139 | # inputs = encoder_AT(img) 140 | # SS 141 | outputs = decoder_SS(bn(inputs)) 142 | torch.cuda.synchronize() 143 | time_list.append(time.time() - tic) 144 | 145 | anomaly_map, _ = cal_anomaly_map(inputs, outputs, img.shape[-1], amap_mode='a', vis=label.bool()) 146 | anomaly_map = gaussian_filter(anomaly_map, sigma=4) 147 | gt[gt > 0.5] = 1 148 | gt[gt <= 0.5] = 0 149 | # if label.item() == 0: 150 | # pdb.set_trace() 151 | # apply_ad_scoremap(rgb_path[0], anomaly_map, _class_) 152 | if label.item()!=0: 153 | # if rgb_path == '/data4/tch/AD_data/mvtec/zipper/test/fabric_interior/007.png': 154 | # pdb.set_trace() 155 | # pdb.set_trace() 156 | # aupro_list.append(compute_pro(gt.squeeze(0).cpu().numpy().astype(int), 157 | # anomaly_map[np.newaxis,:,:])) 158 | aupro_list.append(0.0) 159 | 160 | anomaly_map = anomaly_map 161 | gt_list_px.extend(gt.cpu().numpy().astype(int).ravel()) 162 | pr_list_px.extend(anomaly_map.ravel()) 163 | gt_list_sp.append(np.max(gt.cpu().numpy().astype(int))) 164 | pr_list_sp.append(np.max(anomaly_map)) 165 | 166 | pr_list_px = np.array(pr_list_px) 167 | gt_list_px = np.array(gt_list_px) 168 | pr_list_sp = np.array(pr_list_sp) 169 | gt_list_sp = np.array(gt_list_sp) 170 | 171 | normal_sp = pr_list_sp[gt_list_sp==0] 172 | anomaly_sp = pr_list_sp[gt_list_sp==1] 173 | 174 | print(f'normal: {normal_sp.min():.4f} {normal_sp.max():.4f} anomaly: {anomaly_sp.min():.4f} {anomaly_sp.max():.4f}') 175 | 176 | precision, recall, thresholds = precision_recall_curve(gt_list_sp.astype(int), pr_list_sp) 177 | F1_scores = np.divide( 178 | 2 * precision * recall, 179 | precision + recall, 180 | out=np.zeros_like(precision), 181 | where=(precision + recall) != 0, 182 | ) 183 | optimal_threshold = thresholds[np.argmax(F1_scores)] 184 | fpr_optim = np.mean(normal_sp > optimal_threshold) 185 | fnr_optim = np.mean(anomaly_sp < optimal_threshold) 186 | print('thresh:', optimal_threshold) 187 | print('fpr:', fpr_optim, 'fnr', fnr_optim) 188 | 189 | auroc_px = round(roc_auc_score(gt_list_px, pr_list_px), 3) 190 | auroc_sp = round(roc_auc_score(gt_list_sp, pr_list_sp), 3) 191 | return auroc_px, auroc_sp, round(np.mean(aupro_list),3), time_list, pr_list_sp, gt_list_sp 192 | 193 | 194 | def test(_class_, args): 195 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 196 | print(device) 197 | print(_class_) 198 | 199 | Stage1_ckp_path = './checkpoints_Stage1/' + 
'wres50_'+_class_+'.pth' 200 | Stage2_ckp_path = './checkpoints_Stage2/' + 'wres50_'+ _class_+'.pth' 201 | image_size = 256 202 | test_data = TestDataset(class_name=_class_, img_size=args.image_size, dataset_path=args.data_root) 203 | test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=False) 204 | 205 | # Advanced Teacher 206 | encoder_AT, bn = wide_resnet50_2_rar(pretrained=False) 207 | encoder_AT = encoder_AT.to(device) 208 | encoder_AT.load_state_dict(torch.load(Stage1_ckp_path)['encoder_AT']) 209 | encoder_AT.eval() 210 | 211 | # Vanilla Teacher 212 | encoder_pre, _ = wide_resnet50_2(pretrained=True) 213 | encoder_pre = encoder_pre.to(device) 214 | encoder_pre.eval() 215 | 216 | bn = bn.to(device) 217 | decoder_SS = de_wide_resnet50_2(pretrained=False) 218 | decoder_SS = decoder_SS.to(device) 219 | 220 | SS_ckp = torch.load(Stage2_ckp_path) 221 | for k, v in list(SS_ckp['bn'].items()): 222 | if 'memory' in k: 223 | SS_ckp['bn'].pop(k) 224 | decoder_SS.load_state_dict(SS_ckp['decoder_ss']) 225 | bn.load_state_dict(SS_ckp['bn']) 226 | 227 | # total_params = compute_params([encoder_AT, bn, decoder_SS]) 228 | # print(f'total params: {total_params/1e6} M') #{sum([x.nelement() for x in self.model.parameters()])/1000000.} M 229 | 230 | # baseline, RD 231 | # auroc_px, auroc_sp, aupro_px, time_list, preds, labels = evaluation_Stage2(encoder_pre, bn, decoder_SS, test_dataloader, device, _class_) 232 | # Ours 233 | auroc_px, auroc_sp, aupro_px, time_list, preds, labels = evaluation_Stage2(encoder_AT, bn, decoder_SS, test_dataloader, device, _class_) 234 | # pdb.set_trace() 235 | print(_class_,':',auroc_px,',',auroc_sp,',',aupro_px) 236 | return auroc_px, auroc_sp, aupro_px, time_list, preds, labels 237 | 238 | 239 | def compute_params(nets): 240 | total_params = 0 241 | for net in nets: 242 | print(sum(p.numel() for p in net.parameters())) 243 | total_params += sum(p.numel() for p in net.parameters()) 244 | return total_params 245 | 246 | def compute_pro(masks: ndarray, amaps: ndarray, num_th: int = 200) -> None: 247 | 248 | """Compute the area under the curve of per-region overlaping (PRO) and 0 to 0.3 FPR 249 | Args: 250 | category (str): Category of product 251 | masks (ndarray): All binary masks in test. masks.shape -> (num_test_data, h, w) 252 | amaps (ndarray): All anomaly maps in test. 
amaps.shape -> (num_test_data, h, w) 253 | num_th (int, optional): Number of thresholds 254 | """ 255 | 256 | assert isinstance(amaps, ndarray), "type(amaps) must be ndarray" 257 | assert isinstance(masks, ndarray), "type(masks) must be ndarray" 258 | assert amaps.ndim == 3, "amaps.ndim must be 3 (num_test_data, h, w)" 259 | assert masks.ndim == 3, "masks.ndim must be 3 (num_test_data, h, w)" 260 | assert amaps.shape == masks.shape, "amaps.shape and masks.shape must be same" 261 | assert set(masks.flatten()) == {0, 1}, "set(masks.flatten()) must be {0, 1}" 262 | assert isinstance(num_th, int), "type(num_th) must be int" 263 | 264 | df = pd.DataFrame([], columns=["pro", "fpr", "threshold"]) 265 | binary_amaps = np.zeros_like(amaps, dtype=np.bool_) 266 | 267 | min_th = amaps.min() 268 | max_th = amaps.max() 269 | delta = (max_th - min_th) / num_th 270 | 271 | for th in np.arange(min_th, max_th, delta): 272 | binary_amaps[amaps <= th] = 0 273 | binary_amaps[amaps > th] = 1 274 | 275 | pros = [] 276 | for binary_amap, mask in zip(binary_amaps, masks): 277 | for region in measure.regionprops(measure.label(mask)): 278 | axes0_ids = region.coords[:, 0] 279 | axes1_ids = region.coords[:, 1] 280 | tp_pixels = binary_amap[axes0_ids, axes1_ids].sum() 281 | pros.append(tp_pixels / region.area) 282 | 283 | inverse_masks = 1 - masks 284 | fp_pixels = np.logical_and(inverse_masks, binary_amaps).sum() 285 | fpr = fp_pixels / inverse_masks.sum() 286 | 287 | # df = df.append({"pro": mean(pros), "fpr": fpr, "threshold": th}, ignore_index=True) 288 | df = pd.concat([df, pd.DataFrame([{"pro": mean(pros), "fpr": fpr, "threshold": th}])], ignore_index=True) 289 | 290 | 291 | # Normalize FPR from 0 ~ 1 to 0 ~ 0.3 292 | df = df[df["fpr"] < 0.3] 293 | df["fpr"] = df["fpr"] / df["fpr"].max() 294 | 295 | pro_auc = auc(df["fpr"], df["pro"]) 296 | return pro_auc 297 | 298 | 299 | if __name__ == '__main__': 300 | # from main import setup_seed 301 | # setup_seed(111) 302 | parser = argparse.ArgumentParser() 303 | parser.add_argument('--data_root', type=str, default='/data4/tch/AD_data/mvtec3d') 304 | parser.add_argument('--image_size', type=int, default=256) 305 | args = parser.parse_args() 306 | 307 | dataset_name = get_dataset_name(args.data_root) 308 | if dataset_name == 'mvtec': 309 | from datasets.MvTec import TrainDataset, TestDataset, AnomalyDataset 310 | from datasets.MvTec import mvtec_classes 311 | item_list = mvtec_classes() 312 | elif dataset_name == 'VisA': 313 | from datasets.VisA import TrainDataset, TestDataset, AnomalyDataset 314 | from datasets.VisA import visa_classes 315 | item_list = visa_classes() 316 | elif dataset_name == 'mvtec3d': 317 | from datasets.MvTec3D import TrainDataset, TestDataset, AnomalyDataset 318 | from datasets.MvTec3D import mvtec3d_classes 319 | item_list = mvtec3d_classes() 320 | 321 | # all_time_list = [] 322 | # for i in item_list: 323 | # time_list = test(i) 324 | # all_time_list += time_list 325 | # # print(len(all_time_list)) 326 | # print('FPS: ', 1 / np.mean(all_time_list)) 327 | 328 | p_auc_list, i_auc_list, p_pro_list = [], [], [] 329 | label_list, pred_list = [], [] 330 | for i in item_list: 331 | p_auc, i_auc, p_pro, _, preds, labels = test(i, args) 332 | p_auc_list.append(p_auc) 333 | i_auc_list.append(i_auc) 334 | p_pro_list.append(p_pro) 335 | label_list += list(labels) 336 | pred_list += list(preds) 337 | # all_time_list += time_list 338 | # print(len(all_time_list)) 339 | # print('FPS: ', 1 / np.mean(all_time_list)) 340 | i_auc_mean = np.mean(i_auc_list) 341 | 
p_auc_mean = np.mean(p_auc_list) 342 | p_pro_mean = np.mean(p_pro_list) 343 | 344 | normal = np.array(pred_list)[np.array(label_list)==0] 345 | anomaly = np.array(pred_list)[np.array(label_list)==1] 346 | os.makedirs(f'results/{dataset_name}',exist_ok=True) 347 | np.save(f'results/{dataset_name}/normal_scores', normal) 348 | np.save(f'results/{dataset_name}/anomaly_scores', anomaly) 349 | # normal = np.load('normal_scores.npy') 350 | # anomaly = np.load('anomaly_scores.npy') 351 | # pdb.set_trace() 352 | # print(anomaly) 353 | 354 | # width = 7 355 | # fig = plt.figure(figsize=(width, width*0.75)) 356 | # print(f'normal: {normal.min():.4f} {normal.max():.4f} anomaly: {anomaly.min():.4f} {anomaly.max():.4f}') 357 | # plt.hist(normal,bins=20,label='normal sample',alpha=0.5) 358 | # plt.hist(anomaly,bins=20,label='anomalous sample',alpha=0.5) 359 | # plt.legend(fontsize=15) 360 | # plt.tick_params(labelsize=12.5) 361 | # plt.savefig(f'hist_SS.png') 362 | 363 | 364 | with open('results2.txt', 'a') as f: 365 | f.write(time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time()))+'\n') 366 | f.write('i_auc: ') 367 | for item in i_auc_list: 368 | f.write(f"{item:.4f}") 369 | f.write(' ') 370 | f.write(f'avg: {i_auc_mean:.4f}\n') 371 | 372 | f.write('p_auc: ') 373 | for item in p_auc_list: 374 | f.write(f"{item:.4f}") 375 | f.write(' ') 376 | f.write(f'avg: {p_auc_mean:.4f}\n') 377 | 378 | f.write('p_pro: ') 379 | for item in p_pro_list: 380 | f.write(f"{item:.4f}") 381 | f.write(' ') 382 | f.write(f'avg: {p_pro_mean:.4f}\n') -------------------------------------------------------------------------------- /vit_version/test_vit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision.datasets import ImageFolder 3 | import numpy as np 4 | from torch.utils.data import DataLoader 5 | import os, pdb, sys 6 | sys.path.append(os.getcwd()) 7 | from models.resnet import resnet18, resnet34, resnet50, wide_resnet50_2 8 | from models.de_resnet import de_resnet18, de_resnet50, de_wide_resnet50_2 9 | from models.resnet_rar import resnet18, resnet34, resnet50, wide_resnet50_2_rar 10 | # from models.de_resnet_mem import de_wide_resnet50_2_mem 11 | from torch.nn import functional as F 12 | from sklearn.metrics import roc_auc_score, precision_recall_curve 13 | import cv2 14 | import matplotlib.pyplot as plt 15 | from sklearn.metrics import auc 16 | from skimage import measure 17 | import pandas as pd 18 | from numpy import ndarray 19 | from statistics import mean 20 | from scipy.ndimage import gaussian_filter 21 | from sklearn import manifold 22 | from matplotlib.ticker import NullFormatter 23 | from scipy.spatial.distance import pdist 24 | import matplotlib 25 | import pickle 26 | import argparse 27 | # from MvTec3D import TestDataset 28 | # from MvTec import TestDataset 29 | import pdb 30 | from utils.vis import * 31 | from utils.utils import PatchMaker, get_dataset_name 32 | from models.loss import * 33 | import time 34 | from models.recons_net import * 35 | from tqdm import tqdm 36 | # import clip 37 | # from clip_bn import BN_layer, AttnBottleneck 38 | # from clip_decoder import de_VisionTransformer 39 | import vit_version.clip as clip 40 | import vit_version.clip_rar as clip_rar 41 | from vit_version.clip_bn import BN_layer, AttnBottleneck 42 | from vit_version.clip_decoder import de_VisionTransformer 43 | 44 | 45 | 46 | def cal_anomaly_map(fs_list, ft_list, out_size=224, amap_mode='mul', vis=False): 47 | if amap_mode == 'mul': 48 | 
anomaly_map = np.ones([out_size, out_size]) 49 | else: 50 | anomaly_map = np.zeros([out_size, out_size]) 51 | a_map_list = [] 52 | mse_loss = torch.nn.MSELoss(reduction='none') 53 | for i in range(len(ft_list)): 54 | # pdb.set_trace() 55 | fs = fs_list[i] 56 | ft = ft_list[i] 57 | #fs_norm = F.normalize(fs, p=2) 58 | #ft_norm = F.normalize(ft, p=2) 59 | # pdb.set_trace() 60 | a_map = 1 - F.cosine_similarity(fs, ft) # cos_loss [1,C,H,W]->[1,H,W] 61 | # a_map2 = torch.sqrt(mse_loss(fs, ft).sum(dim=1)) # mse_loss [1,C,H,W]->[1,H,W] 62 | # if i == 1 and vis: 63 | # vis_hotmap(fs, ft, a_map) 64 | 65 | a_map = torch.unsqueeze(a_map, dim=1) 66 | a_map = F.interpolate(a_map, size=out_size, mode='bilinear', align_corners=True) 67 | a_map = a_map[0, 0, :, :].to('cpu').detach().numpy() 68 | a_map_list.append(a_map) 69 | if amap_mode == 'mul': 70 | anomaly_map *= a_map 71 | else: 72 | anomaly_map += a_map 73 | 74 | return anomaly_map, None 75 | 76 | 77 | def evaluation_Stage1(encoder, encoder_AT, dataloader, device, _class_=None): 78 | encoder.eval() 79 | encoder_AT.eval() 80 | # reconsKV.eval() 81 | gt_list_px = [] 82 | pr_list_px = [] 83 | gt_list_sp = [] 84 | pr_list_sp = [] 85 | aupro_list = [] 86 | fnorm_loss_list, afnorm_loss_list = [], [] 87 | P_list, R_list = [], [] 88 | with torch.no_grad(): 89 | for img, gt, label, _ in dataloader: 90 | img = img.to(device) 91 | gt = gt.to(device) 92 | 93 | _, inputs = encoder.encode_image(img, f_list=[4,8,12]) 94 | _, inputs_AT, atten, delta = encoder_AT.encode_image(img, f_list=[4,8,12]) 95 | # pdb.set_trace() 96 | loss_atten, P, R = get_focal_loss(atten, gt, vit=True) 97 | fnorm_loss, afnorm_loss = get_FnormLoss(delta, gt, vit=True) 98 | fnorm_loss_list.append(fnorm_loss.item()) 99 | if afnorm_loss != 0.0: 100 | afnorm_loss_list.append(afnorm_loss.item()) 101 | P_list.append(P.item()) 102 | R_list.append(R.item()) 103 | 104 | anomaly_map, _ = cal_anomaly_map(inputs, inputs_AT, img.shape[-1], amap_mode='a') 105 | anomaly_map = gaussian_filter(anomaly_map, sigma=4) 106 | gt[gt > 0.5] = 1 107 | gt[gt <= 0.5] = 0 108 | if label.item()!=0: 109 | aupro_list.append(0.0) 110 | gt_list_px.extend(gt.cpu().numpy().astype(int).ravel()) 111 | pr_list_px.extend(anomaly_map.ravel()) 112 | gt_list_sp.append(np.max(gt.cpu().numpy().astype(int))) 113 | pr_list_sp.append(np.max(anomaly_map)) 114 | 115 | # print(afnorm_loss_list, 'len:', len(afnorm_loss_list)) 116 | 117 | pr_list_sp = np.array(pr_list_sp) 118 | gt_list_sp = np.array(gt_list_sp) 119 | normal = pr_list_sp[gt_list_sp==0] 120 | anomaly = pr_list_sp[gt_list_sp==1] 121 | print(f'normal: {normal.min():.4f} {normal.max():.4f} anomaly: {anomaly.min():.4f} {anomaly.max():.4f}\t' 122 | f'P:{np.mean(P_list):.2f}, R:{np.mean(R_list):.2f}') 123 | 124 | auroc_px = round(roc_auc_score(gt_list_px, pr_list_px), 3) 125 | auroc_sp = round(roc_auc_score(gt_list_sp, pr_list_sp), 3) 126 | return auroc_px, auroc_sp, round(np.mean(aupro_list),3), fnorm_loss_list, afnorm_loss_list, np.mean(P_list) 127 | 128 | 129 | def evaluation_Stage2(encoder_AT, bn, decoder_SS, dataloader, device, _class_=None): 130 | encoder_AT.eval() 131 | bn.eval() 132 | decoder_SS.eval() 133 | gt_list_px = [] 134 | pr_list_px = [] 135 | gt_list_sp = [] 136 | pr_list_sp = [] 137 | aupro_list = [] 138 | time_list = [] 139 | normal_feats, anomaly_feats = [], [] 140 | with torch.no_grad(): 141 | # for img, gt, label, rgb_path, fore_mask in dataloader: 142 | for img, gt, label, rgb_path in dataloader: 143 | # break 144 | img = img.to(device) 145 | tic = time.time() 
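            # Latency measurement (note): torch.cuda.synchronize() is called before and after
            # the Advanced-Teacher encode_image, the BN bottleneck, and the Stubborn-Student
            # decode below, so pending CUDA kernels are flushed before the elapsed time is
            # appended to time_list; the commented-out block in __main__ then reports
            # throughput as  FPS = 1 / np.mean(all_time_list).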
146 | torch.cuda.synchronize() 147 | # FS 148 | _, inputs, atten, delta = encoder_AT.encode_image(img, f_list=[4,8,12]) 149 | # inputs = encoder_AT(img) 150 | # SS 151 | # outputs = decoder_SS(bn(inputs)) 152 | _, outputs = decoder_SS(bn(inputs), f_list=[2,4,6]) 153 | torch.cuda.synchronize() 154 | time_list.append(time.time() - tic) 155 | 156 | anomaly_map, _ = cal_anomaly_map(inputs, outputs, img.shape[-1], amap_mode='a', vis=label.bool()) 157 | anomaly_map = gaussian_filter(anomaly_map, sigma=4) 158 | gt[gt > 0.5] = 1 159 | gt[gt <= 0.5] = 0 160 | # if label.item() == 0: 161 | # pdb.set_trace() 162 | # apply_ad_scoremap(rgb_path[0], anomaly_map, _class_) 163 | if label.item()!=0: 164 | # if rgb_path == '/data4/tch/AD_data/mvtec/zipper/test/fabric_interior/007.png': 165 | # pdb.set_trace() 166 | # pdb.set_trace() 167 | # aupro_list.append(compute_pro(gt.squeeze(0).cpu().numpy().astype(int), 168 | # anomaly_map[np.newaxis,:,:])) 169 | aupro_list.append(0.0) 170 | 171 | anomaly_map = anomaly_map 172 | gt_list_px.extend(gt.cpu().numpy().astype(int).ravel()) 173 | pr_list_px.extend(anomaly_map.ravel()) 174 | gt_list_sp.append(np.max(gt.cpu().numpy().astype(int))) 175 | pr_list_sp.append(np.max(anomaly_map)) 176 | 177 | pr_list_px = np.array(pr_list_px) 178 | gt_list_px = np.array(gt_list_px) 179 | pr_list_sp = np.array(pr_list_sp) 180 | gt_list_sp = np.array(gt_list_sp) 181 | # np.save('results/saved_data/zipper/pr_list_sp_zipper', pr_list_sp) 182 | # np.save('results/saved_data/zipper/gt_list_sp_zipper', gt_list_sp) 183 | 184 | normal_sp = pr_list_sp[gt_list_sp==0] 185 | anomaly_sp = pr_list_sp[gt_list_sp==1] 186 | 187 | print(f'normal: {normal_sp.min():.4f} {normal_sp.max():.4f} anomaly: {anomaly_sp.min():.4f} {anomaly_sp.max():.4f}') 188 | 189 | precision, recall, thresholds = precision_recall_curve(gt_list_sp.astype(int), pr_list_sp) 190 | F1_scores = np.divide( 191 | 2 * precision * recall, 192 | precision + recall, 193 | out=np.zeros_like(precision), 194 | where=(precision + recall) != 0, 195 | ) 196 | optimal_threshold = thresholds[np.argmax(F1_scores)] 197 | fpr_optim = np.mean(normal_sp > optimal_threshold) 198 | fnr_optim = np.mean(anomaly_sp < optimal_threshold) 199 | print('thresh:', optimal_threshold) 200 | print('fpr:', fpr_optim, 'fnr', fnr_optim) 201 | 202 | auroc_px = round(roc_auc_score(gt_list_px, pr_list_px), 3) 203 | auroc_sp = round(roc_auc_score(gt_list_sp, pr_list_sp), 3) 204 | return auroc_px, auroc_sp, round(np.mean(aupro_list),3), time_list, pr_list_sp, gt_list_sp 205 | 206 | 207 | def test(_class_, args): 208 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 209 | print(device) 210 | print(_class_) 211 | 212 | Stage1_ckp_path = 'vit_version/checkpoints_Stage1/' + 'wres50_'+_class_+'.pth' 213 | Stage2_ckp_path = 'vit_version/checkpoints_Stage2/' + 'wres50_'+ _class_+'.pth' 214 | image_size = 256 215 | test_data = TestDataset(class_name=_class_, img_size=args.image_size, dataset_path=args.data_root) 216 | test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=False) 217 | 218 | encoder_AT, _ = clip_rar.load("ViT-B/16", download_root='vit_version/clip', device=torch.device("cpu"), image_size=args.image_size) 219 | encoder_AT = encoder_AT.to(device) 220 | encoder_AT.load_state_dict(torch.load(Stage1_ckp_path)['encoder_AT']) # 加载Stage1 221 | encoder_AT.eval() 222 | 223 | # encoder_pre, _ = clip.load("ViT-L/14@336px", download_root='./clip', device=torch.device("cpu"), image_size=336) 224 | encoder_pre, _ = clip.load("ViT-B/16", 
download_root='vit_version/clip', device=torch.device("cpu"), image_size=args.image_size) 225 | encoder_pre = encoder_pre.to(device) 226 | encoder_pre.eval() 227 | # pdb.set_trace() 228 | 229 | bn = BN_layer(AttnBottleneck, 3) 230 | bn = bn.to(device) 231 | decoder_SS = de_VisionTransformer(input_resolution=256, patch_size=16, width=768, layers=6, heads=12, output_dim=512) 232 | decoder_SS = decoder_SS.to(device) 233 | 234 | SS_ckp = torch.load(Stage2_ckp_path) 235 | for k, v in list(SS_ckp['bn'].items()): 236 | if 'memory' in k: 237 | SS_ckp['bn'].pop(k) 238 | decoder_SS.load_state_dict(SS_ckp['decoder_ss']) 239 | bn.load_state_dict(SS_ckp['bn']) 240 | 241 | # total_params = compute_params([encoder_AT, bn, decoder_SS]) 242 | # print(f'total params: {total_params/1e6} M') #{sum([x.nelement() for x in self.model.parameters()])/1000000.} M 243 | 244 | auroc_px, auroc_sp, aupro_px, time_list, preds, labels = evaluation_Stage2(encoder_AT, bn, decoder_SS, test_dataloader, device, _class_) 245 | print(_class_,':',auroc_px,',',auroc_sp,',',aupro_px) 246 | return auroc_px, auroc_sp, aupro_px, time_list, preds, labels 247 | 248 | 249 | def compute_params(nets): 250 | total_params = 0 251 | for net in nets: 252 | print(sum(p.numel() for p in net.parameters())) 253 | total_params += sum(p.numel() for p in net.parameters()) 254 | return total_params 255 | 256 | def compute_pro(masks: ndarray, amaps: ndarray, num_th: int = 200) -> None: 257 | 258 | """Compute the area under the curve of per-region overlaping (PRO) and 0 to 0.3 FPR 259 | Args: 260 | category (str): Category of product 261 | masks (ndarray): All binary masks in test. masks.shape -> (num_test_data, h, w) 262 | amaps (ndarray): All anomaly maps in test. amaps.shape -> (num_test_data, h, w) 263 | num_th (int, optional): Number of thresholds 264 | """ 265 | 266 | assert isinstance(amaps, ndarray), "type(amaps) must be ndarray" 267 | assert isinstance(masks, ndarray), "type(masks) must be ndarray" 268 | assert amaps.ndim == 3, "amaps.ndim must be 3 (num_test_data, h, w)" 269 | assert masks.ndim == 3, "masks.ndim must be 3 (num_test_data, h, w)" 270 | assert amaps.shape == masks.shape, "amaps.shape and masks.shape must be same" 271 | assert set(masks.flatten()) == {0, 1}, "set(masks.flatten()) must be {0, 1}" 272 | assert isinstance(num_th, int), "type(num_th) must be int" 273 | 274 | df = pd.DataFrame([], columns=["pro", "fpr", "threshold"]) 275 | binary_amaps = np.zeros_like(amaps, dtype=np.bool_) 276 | 277 | min_th = amaps.min() 278 | max_th = amaps.max() 279 | delta = (max_th - min_th) / num_th 280 | 281 | for th in np.arange(min_th, max_th, delta): 282 | binary_amaps[amaps <= th] = 0 283 | binary_amaps[amaps > th] = 1 284 | 285 | pros = [] 286 | for binary_amap, mask in zip(binary_amaps, masks): 287 | for region in measure.regionprops(measure.label(mask)): 288 | axes0_ids = region.coords[:, 0] 289 | axes1_ids = region.coords[:, 1] 290 | tp_pixels = binary_amap[axes0_ids, axes1_ids].sum() 291 | pros.append(tp_pixels / region.area) 292 | 293 | inverse_masks = 1 - masks 294 | fp_pixels = np.logical_and(inverse_masks, binary_amaps).sum() 295 | fpr = fp_pixels / inverse_masks.sum() 296 | 297 | # df = df.append({"pro": mean(pros), "fpr": fpr, "threshold": th}, ignore_index=True) # old pandas version 298 | df = pd.concat([df, pd.DataFrame([{"pro": mean(pros), "fpr": fpr, "threshold": th}])], ignore_index=True) 299 | 300 | # Normalize FPR from 0 ~ 1 to 0 ~ 0.3 301 | df = df[df["fpr"] < 0.3] 302 | df["fpr"] = df["fpr"] / df["fpr"].max() 303 | 304 
| pro_auc = auc(df["fpr"], df["pro"]) 305 | return pro_auc 306 | 307 | 308 | # if __name__ == '__main__': 309 | # from utils.utils import setup_seed 310 | # setup_seed(111) 311 | # item_list = ['carpet', 'bottle', 'hazelnut', 'leather', 'cable', 'capsule', 'grid', 'pill', 312 | # 'transistor', 'metal_nut', 'screw','toothbrush', 'zipper', 'tile', 'wood'] 313 | # # from MvTec3D import mvtec3d_classes 314 | # # item_list = mvtec3d_classes() 315 | 316 | # for i in item_list: 317 | # test(i) 318 | 319 | if __name__ == '__main__': 320 | # from main import setup_seed 321 | # setup_seed(111) 322 | parser = argparse.ArgumentParser() 323 | parser.add_argument('--data_root', type=str, default='/data4/tch/AD_data/mvtec3d') 324 | parser.add_argument('--image_size', type=int, default=256) 325 | args = parser.parse_args() 326 | 327 | dataset_name = get_dataset_name(args.data_root) 328 | if dataset_name == 'mvtec': 329 | from datasets.MvTec import TrainDataset, TestDataset, AnomalyDataset 330 | from datasets.MvTec import mvtec_classes 331 | item_list = mvtec_classes() 332 | elif dataset_name == 'VisA': 333 | from datasets.VisA import TrainDataset, TestDataset, AnomalyDataset 334 | from datasets.VisA import visa_classes 335 | item_list = visa_classes() 336 | elif dataset_name == 'mvtec3d': 337 | from datasets.MvTec3D import TrainDataset, TestDataset, AnomalyDataset 338 | from datasets.MvTec3D import mvtec3d_classes 339 | item_list = mvtec3d_classes() 340 | 341 | # all_time_list = [] 342 | # for i in item_list: 343 | # time_list = test(i) 344 | # all_time_list += time_list 345 | # # print(len(all_time_list)) 346 | # print('FPS: ', 1 / np.mean(all_time_list)) 347 | 348 | p_auc_list, i_auc_list, p_pro_list = [], [], [] 349 | label_list, pred_list = [], [] 350 | for i in item_list: 351 | p_auc, i_auc, p_pro, _, preds, labels = test(i, args) 352 | p_auc_list.append(p_auc) 353 | i_auc_list.append(i_auc) 354 | p_pro_list.append(p_pro) 355 | label_list += list(labels) 356 | pred_list += list(preds) 357 | # all_time_list += time_list 358 | # print(len(all_time_list)) 359 | # print('FPS: ', 1 / np.mean(all_time_list)) 360 | i_auc_mean = np.mean(i_auc_list) 361 | p_auc_mean = np.mean(p_auc_list) 362 | p_pro_mean = np.mean(p_pro_list) 363 | 364 | normal = np.array(pred_list)[np.array(label_list)==0] 365 | anomaly = np.array(pred_list)[np.array(label_list)==1] 366 | os.makedirs(f'results/{dataset_name}',exist_ok=True) 367 | np.save(f'results/{dataset_name}/normal_scores_vit', normal) 368 | np.save(f'results/{dataset_name}/anomaly_scores_vit', anomaly) 369 | # np.save('normal_scores', normal) 370 | # np.save('anomaly_scores', anomaly) 371 | # normal = np.load('normal_scores.npy') 372 | # anomaly = np.load('anomaly_scores.npy') 373 | # pdb.set_trace() 374 | # print(anomaly) 375 | 376 | # width = 7 377 | # fig = plt.figure(figsize=(width, width*0.75)) 378 | # print(f'normal: {normal.min():.4f} {normal.max():.4f} anomaly: {anomaly.min():.4f} {anomaly.max():.4f}') 379 | # plt.hist(normal,bins=20,label='normal sample',alpha=0.5) 380 | # plt.hist(anomaly,bins=20,label='anomalous sample',alpha=0.5) 381 | # plt.legend(fontsize=15) 382 | # plt.tick_params(labelsize=12.5) 383 | # plt.savefig(f'hist_SS.png') 384 | 385 | 386 | with open('results2.txt', 'a') as f: 387 | f.write(time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time()))+'\n') 388 | f.write('i_auc: ') 389 | for item in i_auc_list: 390 | f.write(f"{item:.4f}") 391 | f.write(' ') 392 | f.write(f'avg: {i_auc_mean:.4f}\n') 393 | 394 | f.write('p_auc: ') 395 | for item 
in p_auc_list: 396 | f.write(f"{item:.4f}") 397 | f.write(' ') 398 | f.write(f'avg: {p_auc_mean:.4f}\n') 399 | 400 | f.write('p_pro: ') 401 | for item in p_pro_list: 402 | f.write(f"{item:.4f}") 403 | f.write(' ') 404 | f.write(f'avg: {p_pro_mean:.4f}\n') -------------------------------------------------------------------------------- /models/de_resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | import torch.nn as nn 4 | try: 5 | from torch.hub import load_state_dict_from_url 6 | except ImportError: 7 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 8 | from typing import Type, Any, Callable, Union, List, Optional 9 | import pdb 10 | 11 | 12 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 13 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 14 | 'wide_resnet50_2', 'wide_resnet101_2'] 15 | 16 | 17 | model_urls = { 18 | 'resnet18': 'https://download.pytorch.org/models/resnet18-f37072fd.pth', 19 | 'resnet34': 'https://download.pytorch.org/models/resnet34-b627a593.pth', 20 | 'resnet50': 'https://download.pytorch.org/models/resnet50-0676ba61.pth', 21 | 'resnet101': 'https://download.pytorch.org/models/resnet101-63fe2227.pth', 22 | 'resnet152': 'https://download.pytorch.org/models/resnet152-394f9c45.pth', 23 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 24 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 25 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 26 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 27 | } 28 | 29 | 30 | def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: 31 | """3x3 convolution with padding""" 32 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 33 | padding=dilation, groups=groups, bias=False, dilation=dilation) 34 | 35 | 36 | def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: 37 | """1x1 convolution""" 38 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 39 | 40 | def deconv2x2(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: 41 | """1x1 convolution""" 42 | return nn.ConvTranspose2d(in_planes, out_planes, kernel_size=2, stride=stride, 43 | groups=groups, bias=False, dilation=dilation) 44 | 45 | 46 | class BasicBlock(nn.Module): 47 | expansion: int = 1 48 | 49 | def __init__( 50 | self, 51 | inplanes: int, 52 | planes: int, 53 | stride: int = 1, 54 | upsample: Optional[nn.Module] = None, 55 | groups: int = 1, 56 | base_width: int = 64, 57 | dilation: int = 1, 58 | norm_layer: Optional[Callable[..., nn.Module]] = None 59 | ) -> None: 60 | super(BasicBlock, self).__init__() 61 | if norm_layer is None: 62 | norm_layer = nn.BatchNorm2d 63 | if groups != 1 or base_width != 64: 64 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 65 | if dilation > 1: 66 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 67 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 68 | if stride == 2: 69 | self.conv1 = deconv2x2(inplanes, planes, stride) 70 | else: 71 | self.conv1 = conv3x3(inplanes, planes, stride) 72 | self.bn1 = norm_layer(planes) 73 | self.relu = nn.ReLU(inplace=True) 74 | self.conv2 = conv3x3(planes, planes) 
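        # Note: in this decoder block, stride == 2 means *upsampling* rather than downsampling:
        # conv1 above is then deconv2x2, i.e. a 2x2 nn.ConvTranspose2d that doubles the spatial
        # resolution, and the `upsample` module plays the role that `downsample` plays in
        # torchvision's BasicBlock, adapting the identity branch in forward().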
75 | self.bn2 = norm_layer(planes) 76 | self.upsample = upsample 77 | self.stride = stride 78 | 79 | def forward(self, x: Tensor) -> Tensor: 80 | identity = x 81 | 82 | out = self.conv1(x) 83 | out = self.bn1(out) 84 | out = self.relu(out) 85 | 86 | out = self.conv2(out) 87 | out = self.bn2(out) 88 | 89 | if self.upsample is not None: 90 | identity = self.upsample(x) 91 | 92 | out += identity 93 | out = self.relu(out) 94 | 95 | return out 96 | 97 | 98 | class Bottleneck(nn.Module): 99 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 100 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 101 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 102 | # This variant is also known as ResNet V1.5 and improves accuracy according to 103 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 104 | 105 | expansion: int = 4 106 | 107 | def __init__( 108 | self, 109 | inplanes: int, 110 | planes: int, 111 | stride: int = 1, 112 | upsample: Optional[nn.Module] = None, 113 | groups: int = 1, 114 | base_width: int = 64, 115 | dilation: int = 1, 116 | norm_layer: Optional[Callable[..., nn.Module]] = None 117 | ) -> None: 118 | super(Bottleneck, self).__init__() 119 | if norm_layer is None: 120 | norm_layer = nn.BatchNorm2d 121 | width = int(planes * (base_width / 64.)) * groups 122 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 123 | self.conv1 = conv1x1(inplanes, width) 124 | self.bn1 = norm_layer(width) 125 | if stride == 2: 126 | self.conv2 = deconv2x2(width, width, stride, groups, dilation) 127 | else: 128 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 129 | self.bn2 = norm_layer(width) 130 | self.conv3 = conv1x1(width, planes * self.expansion) 131 | self.bn3 = norm_layer(planes * self.expansion) 132 | self.relu = nn.ReLU(inplace=True) 133 | self.upsample = upsample 134 | self.stride = stride 135 | 136 | def forward(self, x: Tensor) -> Tensor: 137 | identity = x 138 | 139 | out = self.conv1(x) 140 | out = self.bn1(out) 141 | out = self.relu(out) 142 | 143 | out = self.conv2(out) 144 | out = self.bn2(out) 145 | out = self.relu(out) 146 | 147 | out = self.conv3(out) 148 | out = self.bn3(out) 149 | 150 | if self.upsample is not None: 151 | identity = self.upsample(x) 152 | 153 | out += identity 154 | out = self.relu(out) 155 | 156 | return out 157 | 158 | 159 | class ResNet(nn.Module): 160 | 161 | def __init__( 162 | self, 163 | block: Type[Union[BasicBlock, Bottleneck]], 164 | layers: List[int], 165 | num_classes: int = 1000, 166 | zero_init_residual: bool = False, 167 | groups: int = 1, 168 | width_per_group: int = 64, 169 | replace_stride_with_dilation: Optional[List[bool]] = None, 170 | norm_layer: Optional[Callable[..., nn.Module]] = None 171 | ) -> None: 172 | super(ResNet, self).__init__() 173 | if norm_layer is None: 174 | norm_layer = nn.BatchNorm2d 175 | self._norm_layer = norm_layer 176 | 177 | self.inplanes = 512 * block.expansion 178 | self.dilation = 1 179 | if replace_stride_with_dilation is None: 180 | # each element in the tuple indicates if we should replace 181 | # the 2x2 stride with a dilated convolution instead 182 | replace_stride_with_dilation = [False, False, False] 183 | if len(replace_stride_with_dilation) != 3: 184 | raise ValueError("replace_stride_with_dilation should be None " 185 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 186 | 
self.groups = groups 187 | self.base_width = width_per_group 188 | #self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 189 | # bias=False) 190 | #self.bn1 = norm_layer(self.inplanes) 191 | #self.relu = nn.ReLU(inplace=True) 192 | #self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 193 | self.layer1 = self._make_layer(block, 256, layers[0], stride=2) 194 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 195 | dilate=replace_stride_with_dilation[0]) 196 | self.layer3 = self._make_layer(block, 64, layers[2], stride=2, 197 | dilate=replace_stride_with_dilation[1]) 198 | #self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 199 | # dilate=replace_stride_with_dilation[2]) 200 | #self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 201 | #self.fc = nn.Linear(512 * block.expansion, num_classes) 202 | 203 | for m in self.modules(): 204 | if isinstance(m, nn.Conv2d): 205 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 206 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 207 | nn.init.constant_(m.weight, 1) 208 | nn.init.constant_(m.bias, 0) 209 | 210 | # Zero-initialize the last BN in each residual branch, 211 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 212 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 213 | if zero_init_residual: 214 | for m in self.modules(): 215 | if isinstance(m, Bottleneck): 216 | nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] 217 | elif isinstance(m, BasicBlock): 218 | nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] 219 | 220 | def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int, 221 | stride: int = 1, dilate: bool = False) -> nn.Sequential: 222 | norm_layer = self._norm_layer 223 | upsample = None 224 | previous_dilation = self.dilation 225 | if dilate: 226 | self.dilation *= stride 227 | stride = 1 228 | if stride != 1 or self.inplanes != planes * block.expansion: 229 | upsample = nn.Sequential( 230 | deconv2x2(self.inplanes, planes * block.expansion, stride), 231 | norm_layer(planes * block.expansion), 232 | ) 233 | 234 | layers = [] 235 | layers.append(block(self.inplanes, planes, stride, upsample, self.groups, 236 | self.base_width, previous_dilation, norm_layer)) 237 | self.inplanes = planes * block.expansion 238 | for _ in range(1, blocks): 239 | layers.append(block(self.inplanes, planes, groups=self.groups, 240 | base_width=self.base_width, dilation=self.dilation, 241 | norm_layer=norm_layer)) 242 | 243 | return nn.Sequential(*layers) 244 | 245 | def _forward_impl(self, x: Tensor) -> Tensor: 246 | # See note [TorchScript super()] 247 | #x = self.conv1(x) 248 | #x = self.bn1(x) 249 | #x = self.relu(x) 250 | #x = self.maxpool(x) 251 | # pdb.set_trace() 252 | feature_a = self.layer1(x) # 512*8*8->256*16*16 253 | feature_b = self.layer2(feature_a) # 256*16*16->128*32*32 254 | feature_c = self.layer3(feature_b) # 128*32*32->64*64*64 255 | #feature_d = self.layer4(feature_c) # 64*64*64->128*32*32 256 | 257 | #x = self.avgpool(feature_d) 258 | #x = torch.flatten(x, 1) 259 | #x = self.fc(x) 260 | 261 | return [feature_c, feature_b, feature_a] 262 | 263 | def forward(self, x: Tensor) -> Tensor: 264 | return self._forward_impl(x) 265 | 266 | 267 | def _resnet( 268 | arch: str, 269 | block: Type[Union[BasicBlock, Bottleneck]], 270 | layers: List[int], 271 | pretrained: bool, 272 | progress: bool, 273 | **kwargs: Any 274 | ) -> ResNet: 
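    # Note: because this ResNet acts as a decoder (reversed channel widths, ConvTranspose2d in
    # place of strided convolutions), the torchvision checkpoints in model_urls will generally
    # not load with matching shapes; the de_* constructors below are therefore normally used
    # with the default pretrained=False.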
275 | model = ResNet(block, layers, **kwargs) 276 | if pretrained: 277 | state_dict = load_state_dict_from_url(model_urls[arch], 278 | progress=progress) 279 | #for k,v in list(state_dict.items()): 280 | # if 'layer4' in k or 'fc' in k: 281 | # state_dict.pop(k) 282 | model.load_state_dict(state_dict) 283 | return model 284 | 285 | 286 | def de_resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 287 | r"""ResNet-18 model from 288 | `"Deep Residual Learning for Image Recognition" `_. 289 | Args: 290 | pretrained (bool): If True, returns a model pre-trained on ImageNet 291 | progress (bool): If True, displays a progress bar of the download to stderr 292 | """ 293 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 294 | **kwargs) 295 | 296 | 297 | def de_resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 298 | r"""ResNet-34 model from 299 | `"Deep Residual Learning for Image Recognition" `_. 300 | Args: 301 | pretrained (bool): If True, returns a model pre-trained on ImageNet 302 | progress (bool): If True, displays a progress bar of the download to stderr 303 | """ 304 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 305 | **kwargs) 306 | 307 | 308 | def de_resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 309 | r"""ResNet-50 model from 310 | `"Deep Residual Learning for Image Recognition" `_. 311 | Args: 312 | pretrained (bool): If True, returns a model pre-trained on ImageNet 313 | progress (bool): If True, displays a progress bar of the download to stderr 314 | """ 315 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 316 | **kwargs) 317 | 318 | 319 | def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 320 | r"""ResNet-101 model from 321 | `"Deep Residual Learning for Image Recognition" `_. 322 | Args: 323 | pretrained (bool): If True, returns a model pre-trained on ImageNet 324 | progress (bool): If True, displays a progress bar of the download to stderr 325 | """ 326 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 327 | **kwargs) 328 | 329 | 330 | def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 331 | r"""ResNet-152 model from 332 | `"Deep Residual Learning for Image Recognition" `_. 333 | Args: 334 | pretrained (bool): If True, returns a model pre-trained on ImageNet 335 | progress (bool): If True, displays a progress bar of the download to stderr 336 | """ 337 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 338 | **kwargs) 339 | 340 | 341 | def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 342 | r"""ResNeXt-50 32x4d model from 343 | `"Aggregated Residual Transformation for Deep Neural Networks" `_. 344 | Args: 345 | pretrained (bool): If True, returns a model pre-trained on ImageNet 346 | progress (bool): If True, displays a progress bar of the download to stderr 347 | """ 348 | kwargs['groups'] = 32 349 | kwargs['width_per_group'] = 4 350 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 351 | pretrained, progress, **kwargs) 352 | 353 | 354 | def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 355 | r"""ResNeXt-101 32x8d model from 356 | `"Aggregated Residual Transformation for Deep Neural Networks" `_. 
357 | Args: 358 | pretrained (bool): If True, returns a model pre-trained on ImageNet 359 | progress (bool): If True, displays a progress bar of the download to stderr 360 | """ 361 | kwargs['groups'] = 32 362 | kwargs['width_per_group'] = 8 363 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 364 | pretrained, progress, **kwargs) 365 | 366 | 367 | def de_wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 368 | r"""Wide ResNet-50-2 model from 369 | `"Wide Residual Networks" `_. 370 | The model is the same as ResNet except for the bottleneck number of channels 371 | which is twice larger in every block. The number of channels in outer 1x1 372 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 373 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 374 | Args: 375 | pretrained (bool): If True, returns a model pre-trained on ImageNet 376 | progress (bool): If True, displays a progress bar of the download to stderr 377 | """ 378 | kwargs['width_per_group'] = 64 * 2 379 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], 380 | pretrained, progress, **kwargs) 381 | 382 | 383 | def de_wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 384 | r"""Wide ResNet-101-2 model from 385 | `"Wide Residual Networks" `_. 386 | The model is the same as ResNet except for the bottleneck number of channels 387 | which is twice larger in every block. The number of channels in outer 1x1 388 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 389 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 390 | Args: 391 | pretrained (bool): If True, returns a model pre-trained on ImageNet 392 | progress (bool): If True, displays a progress bar of the download to stderr 393 | """ 394 | kwargs['width_per_group'] = 64 * 2 395 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], 396 | pretrained, progress, **kwargs) -------------------------------------------------------------------------------- /vit_version/clip/model.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Tuple, Union 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | from .utils import to_2tuple 9 | import math 10 | import pdb 11 | from models.recons_net import patch_to_tensor 12 | 13 | 14 | class Bottleneck(nn.Module): 15 | expansion = 4 16 | 17 | def __init__(self, inplanes, planes, stride=1): 18 | super().__init__() 19 | 20 | # all conv layers have stride 1. 
an avgpool is performed after the second convolution when stride > 1 21 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 22 | self.bn1 = nn.BatchNorm2d(planes) 23 | self.relu1 = nn.ReLU(inplace=True) 24 | 25 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 26 | self.bn2 = nn.BatchNorm2d(planes) 27 | self.relu2 = nn.ReLU(inplace=True) 28 | 29 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 30 | 31 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 32 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 33 | self.relu3 = nn.ReLU(inplace=True) 34 | 35 | self.downsample = None 36 | self.stride = stride 37 | 38 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 39 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 40 | self.downsample = nn.Sequential(OrderedDict([ 41 | ("-1", nn.AvgPool2d(stride)), 42 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 43 | ("1", nn.BatchNorm2d(planes * self.expansion)) 44 | ])) 45 | 46 | def forward(self, x: torch.Tensor): 47 | identity = x 48 | 49 | out = self.relu1(self.bn1(self.conv1(x))) 50 | out = self.relu2(self.bn2(self.conv2(out))) 51 | out = self.avgpool(out) 52 | out = self.bn3(self.conv3(out)) 53 | 54 | if self.downsample is not None: 55 | identity = self.downsample(x) 56 | 57 | out += identity 58 | out = self.relu3(out) 59 | return out 60 | 61 | 62 | class AttentionPool2d(nn.Module): 63 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 64 | super().__init__() 65 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 66 | self.k_proj = nn.Linear(embed_dim, embed_dim) 67 | self.q_proj = nn.Linear(embed_dim, embed_dim) 68 | self.v_proj = nn.Linear(embed_dim, embed_dim) 69 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 70 | self.num_heads = num_heads 71 | 72 | def forward(self, x): 73 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC 74 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 75 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 76 | x, _ = F.multi_head_attention_forward( 77 | query=x, key=x, value=x, 78 | embed_dim_to_check=x.shape[-1], 79 | num_heads=self.num_heads, 80 | q_proj_weight=self.q_proj.weight, 81 | k_proj_weight=self.k_proj.weight, 82 | v_proj_weight=self.v_proj.weight, 83 | in_proj_weight=None, 84 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 85 | bias_k=None, 86 | bias_v=None, 87 | add_zero_attn=False, 88 | dropout_p=0, 89 | out_proj_weight=self.c_proj.weight, 90 | out_proj_bias=self.c_proj.bias, 91 | use_separate_proj_weight=True, 92 | training=self.training, 93 | need_weights=False 94 | ) 95 | 96 | return x[0] 97 | 98 | 99 | class ModifiedResNet(nn.Module): 100 | """ 101 | A ResNet class that is similar to torchvision's but contains the following changes: 102 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
103 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 104 | - The final pooling layer is a QKV attention instead of an average pool 105 | """ 106 | 107 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): 108 | super().__init__() 109 | self.output_dim = output_dim 110 | self.input_resolution = input_resolution 111 | 112 | # the 3-layer stem 113 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 114 | self.bn1 = nn.BatchNorm2d(width // 2) 115 | self.relu1 = nn.ReLU(inplace=True) 116 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 117 | self.bn2 = nn.BatchNorm2d(width // 2) 118 | self.relu2 = nn.ReLU(inplace=True) 119 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 120 | self.bn3 = nn.BatchNorm2d(width) 121 | self.relu3 = nn.ReLU(inplace=True) 122 | self.avgpool = nn.AvgPool2d(2) 123 | 124 | # residual layers 125 | self._inplanes = width # this is a *mutable* variable used during construction 126 | self.layer1 = self._make_layer(width, layers[0]) 127 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 128 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 129 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 130 | 131 | embed_dim = width * 32 # the ResNet feature dimension 132 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) 133 | 134 | def _make_layer(self, planes, blocks, stride=1): 135 | layers = [Bottleneck(self._inplanes, planes, stride)] 136 | 137 | self._inplanes = planes * Bottleneck.expansion 138 | for _ in range(1, blocks): 139 | layers.append(Bottleneck(self._inplanes, planes)) 140 | 141 | return nn.Sequential(*layers) 142 | 143 | def forward(self, x): 144 | def stem(x): 145 | x = self.relu1(self.bn1(self.conv1(x))) 146 | x = self.relu2(self.bn2(self.conv2(x))) 147 | x = self.relu3(self.bn3(self.conv3(x))) 148 | x = self.avgpool(x) 149 | return x 150 | 151 | x = x.type(self.conv1.weight.dtype) 152 | x = stem(x) 153 | x1 = self.layer1(x) 154 | x2 = self.layer2(x1) 155 | x3 = self.layer3(x2) 156 | x4 = self.layer4(x3) 157 | y = self.attnpool(x4) 158 | 159 | return y, [x, x1, x2, x3, x4] 160 | 161 | 162 | class LayerNorm(nn.LayerNorm): 163 | """Subclass torch's LayerNorm to handle fp16.""" 164 | 165 | def forward(self, x: torch.Tensor): 166 | orig_type = x.dtype 167 | ret = super().forward(x.type(torch.float32)) 168 | return ret.type(orig_type) 169 | 170 | 171 | class QuickGELU(nn.Module): 172 | def forward(self, x: torch.Tensor): 173 | return x * torch.sigmoid(1.702 * x) 174 | 175 | 176 | class ResidualAttentionBlock(nn.Module): 177 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 178 | super().__init__() 179 | 180 | self.attn = nn.MultiheadAttention(d_model, n_head) 181 | self.ln_1 = LayerNorm(d_model) 182 | self.mlp = nn.Sequential(OrderedDict([ 183 | ("c_fc", nn.Linear(d_model, d_model * 4)), 184 | ("gelu", QuickGELU()), 185 | ("c_proj", nn.Linear(d_model * 4, d_model)) 186 | ])) 187 | self.ln_2 = LayerNorm(d_model) 188 | self.attn_mask = attn_mask 189 | 190 | def attention(self, x: torch.Tensor): 191 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 192 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 193 | 194 | def forward(self, x: torch.Tensor): 195 | # pdb.set_trace() 196 | x = x + 
self.attention(self.ln_1(x)) 197 | x = x + self.mlp(self.ln_2(x)) 198 | return x 199 | 200 | 201 | class Transformer(nn.Module): 202 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): 203 | super().__init__() 204 | self.width = width 205 | self.layers = layers 206 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) 207 | 208 | # def forward(self, x: torch.Tensor): 209 | # return self.resblocks(x) 210 | 211 | def forward(self, x, f_list=[]): 212 | # pdb.set_trace() 213 | out_tokens = [] 214 | idx = 0 215 | for r in self.resblocks: 216 | # pdb.set_trace() 217 | idx+=1 218 | x = r(x) 219 | if idx in f_list: 220 | if len(x)==2: 221 | out_tokens.append(x[0]) 222 | out_tokens.append(x[1]) 223 | else: 224 | out_tokens.append(x.permute(1,0,2)) 225 | return x, out_tokens #self.resblocks(x) 226 | 227 | 228 | class VisionTransformer(nn.Module): 229 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): 230 | super().__init__() 231 | self.input_resolution = input_resolution 232 | self.grid_size = (input_resolution // patch_size) # modify 233 | self.output_dim = output_dim 234 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) 235 | 236 | scale = width ** -0.5 237 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 238 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) 239 | self.ln_pre = LayerNorm(width) 240 | 241 | self.transformer = Transformer(width, layers, heads) 242 | 243 | self.ln_post = LayerNorm(width) 244 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) 245 | 246 | def forward(self, x: torch.Tensor, f_list: list): 247 | # pdb.set_trace() 248 | x = self.conv1(x) # shape = [*, width, grid, grid] 249 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 250 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 251 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] 252 | x = x + self.positional_embedding.to(x.dtype) 253 | # pdb.set_trace() 254 | x = self.ln_pre(x) 255 | 256 | x = x.permute(1, 0, 2) # NLD -> LND 257 | x, patch_tokens = self.transformer(x, f_list) 258 | x = x.permute(1, 0, 2) # LND -> NLD 259 | 260 | x = self.ln_post(x[:, 0, :]) 261 | # pdb.set_trace() 262 | if self.proj is not None: 263 | x = x @ self.proj 264 | 265 | assert len(patch_tokens) == 3 266 | patch_tokens = [patch_to_tensor(x[:,1:]) for x in patch_tokens] 267 | 268 | return x, patch_tokens 269 | 270 | 271 | class CLIP(nn.Module): 272 | def __init__(self, 273 | embed_dim: int, 274 | # vision 275 | image_resolution: int, 276 | vision_layers: Union[Tuple[int, int, int, int], int], 277 | vision_width: int, 278 | vision_patch_size: int, 279 | # text 280 | context_length: int, 281 | vocab_size: int, 282 | transformer_width: int, 283 | transformer_heads: int, 284 | transformer_layers: int 285 | ): 286 | super().__init__() 287 | 288 | self.context_length = context_length 289 | 290 | if isinstance(vision_layers, (tuple, list)): 291 | vision_heads = vision_width * 32 // 64 292 | self.visual = ModifiedResNet( 293 | layers=vision_layers, 294 | output_dim=embed_dim, 295 | heads=vision_heads, 296 | input_resolution=image_resolution, 297 | width=vision_width 298 | ) 299 | else: 
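            # When vision_layers is a single int (ViT checkpoints), the visual encoder is the
            # VisionTransformer above, whose forward() additionally returns intermediate patch
            # tokens selected by f_list; a tuple/list of ints selects the ModifiedResNet branch.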
300 | vision_heads = vision_width // 64 301 | self.visual = VisionTransformer( 302 | input_resolution=image_resolution, 303 | patch_size=vision_patch_size, 304 | width=vision_width, 305 | layers=vision_layers, 306 | heads=vision_heads, 307 | output_dim=embed_dim 308 | ) 309 | 310 | self.transformer = Transformer( 311 | width=transformer_width, 312 | layers=transformer_layers, 313 | heads=transformer_heads, 314 | attn_mask=self.build_attention_mask() 315 | ) 316 | 317 | self.vocab_size = vocab_size 318 | self.token_embedding = nn.Embedding(vocab_size, transformer_width) 319 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) 320 | self.ln_final = LayerNorm(transformer_width) 321 | 322 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) 323 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 324 | 325 | self.initialize_parameters() 326 | 327 | def initialize_parameters(self): 328 | nn.init.normal_(self.token_embedding.weight, std=0.02) 329 | nn.init.normal_(self.positional_embedding, std=0.01) 330 | 331 | if isinstance(self.visual, ModifiedResNet): 332 | if self.visual.attnpool is not None: 333 | std = self.visual.attnpool.c_proj.in_features ** -0.5 334 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) 335 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) 336 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) 337 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) 338 | 339 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: 340 | for name, param in resnet_block.named_parameters(): 341 | if name.endswith("bn3.weight"): 342 | nn.init.zeros_(param) 343 | 344 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 345 | attn_std = self.transformer.width ** -0.5 346 | fc_std = (2 * self.transformer.width) ** -0.5 347 | for block in self.transformer.resblocks: 348 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 349 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 350 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 351 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 352 | 353 | if self.text_projection is not None: 354 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) 355 | 356 | def build_attention_mask(self): 357 | # lazily create causal attention mask, with full attention between the vision tokens 358 | # pytorch uses additive attention mask; fill with -inf 359 | mask = torch.empty(self.context_length, self.context_length) 360 | mask.fill_(float("-inf")) 361 | mask.triu_(1) # zero out the lower diagonal 362 | return mask 363 | 364 | @property 365 | def dtype(self): 366 | return self.visual.conv1.weight.dtype 367 | 368 | def encode_image(self, image, f_list): 369 | # pdb.set_trace() 370 | return self.visual(image.type(self.dtype), f_list) 371 | 372 | # def encode_text(self, text): 373 | # # pdb.set_trace() 374 | # x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 375 | 376 | # x = x + self.positional_embedding.type(self.dtype) 377 | # x = x.permute(1, 0, 2) # NLD -> LND 378 | # x, _ = self.transformer(x) 379 | # x = x.permute(1, 0, 2) # LND -> NLD 380 | # x = self.ln_final(x).type(self.dtype) 381 | 382 | # # x.shape = [batch_size, n_ctx, transformer.width] 383 | # # take features from the eot embedding (eot_token is the highest number in each sequence) 384 | # x = 
x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 385 | 386 | # return x 387 | 388 | def encode_text(self, prompts, tokenized_prompts): 389 | 390 | x = prompts + self.positional_embedding.type(self.dtype) 391 | 392 | x = x.permute(1, 0, 2) # NLD -> LND 393 | x, _ = self.transformer(x) 394 | x = x.permute(1, 0, 2) # LND -> NLD 395 | x = self.ln_final(x).type(self.dtype) 396 | 397 | x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection 398 | 399 | return x 400 | 401 | def forward(self, image, text): 402 | image_features = self.encode_image(image) 403 | text_features = self.encode_text(text) 404 | 405 | # normalized features 406 | image_features = image_features / image_features.norm(dim=1, keepdim=True) 407 | text_features = text_features / text_features.norm(dim=1, keepdim=True) 408 | 409 | # cosine similarity as logits 410 | logit_scale = self.logit_scale.exp() 411 | logits_per_image = logit_scale * image_features @ text_features.t() 412 | logits_per_text = logits_per_image.t() 413 | 414 | # shape = [global_batch_size, global_batch_size] 415 | return logits_per_image, logits_per_text 416 | 417 | 418 | def convert_weights(model: nn.Module): 419 | """Convert applicable model parameters to fp16""" 420 | 421 | def _convert_weights_to_fp16(l): 422 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): 423 | l.weight.data = l.weight.data.half() 424 | if l.bias is not None: 425 | l.bias.data = l.bias.data.half() 426 | 427 | if isinstance(l, nn.MultiheadAttention): 428 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: 429 | tensor = getattr(l, attr) 430 | if tensor is not None: 431 | tensor.data = tensor.data.half() 432 | 433 | for name in ["text_projection", "proj"]: 434 | if hasattr(l, name): 435 | attr = getattr(l, name) 436 | if attr is not None: 437 | attr.data = attr.data.half() 438 | 439 | model.apply(_convert_weights_to_fp16) 440 | 441 | 442 | def build_model(state_dict: dict, image_size: int): 443 | vit = "visual.proj" in state_dict 444 | # pdb.set_trace() 445 | if vit: 446 | vision_width = state_dict["visual.conv1.weight"].shape[0] 447 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) 448 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 449 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) 450 | image_resolution = vision_patch_size * grid_size 451 | else: 452 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] 453 | vision_layers = tuple(counts) 454 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] 455 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) 456 | vision_patch_size = None 457 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] 458 | image_resolution = output_width * 32 459 | 460 | embed_dim = state_dict["text_projection"].shape[1] 461 | context_length = state_dict["positional_embedding"].shape[0] 462 | vocab_size = state_dict["token_embedding.weight"].shape[0] 463 | transformer_width = state_dict["ln_final.weight"].shape[0] 464 | transformer_heads = transformer_width // 64 465 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) 466 | 467 | # model = CLIP( 468 | # embed_dim, 469 | # image_resolution, vision_layers, 
vision_width, vision_patch_size, 470 | # context_length, vocab_size, transformer_width, transformer_heads, transformer_layers 471 | # ) 472 | 473 | model = CLIP( 474 | embed_dim, 475 | image_size, vision_layers, vision_width, vision_patch_size, 476 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers 477 | ) 478 | 479 | for key in ["input_resolution", "context_length", "vocab_size"]: 480 | if key in state_dict: 481 | del state_dict[key] 482 | # pdb.set_trace() 483 | resize_pos_embed(state_dict, model) # modify 484 | convert_weights(model) 485 | model.load_state_dict(state_dict) 486 | return model.eval() 487 | 488 | 489 | def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True): 490 | # Rescale the grid of position embeddings when loading from state_dict 491 | flag = 1 492 | old_pos_embed = state_dict.get('visual.positional_embedding', None) 493 | if old_pos_embed is None: 494 | flag = 0 495 | old_pos_embed = state_dict.get('visual.attnpool.positional_embedding', None) 496 | if old_pos_embed is None or not hasattr(model.visual, 'grid_size'): 497 | return 498 | grid_size = to_2tuple(model.visual.grid_size) 499 | extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more) 500 | new_seq_len = grid_size[0] * grid_size[1] + extra_tokens 501 | if new_seq_len == old_pos_embed.shape[0]: 502 | return 503 | 504 | if extra_tokens: 505 | pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:] 506 | else: 507 | pos_emb_tok, pos_emb_img = None, old_pos_embed 508 | old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img)))) 509 | 510 | print('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size) 511 | pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2) 512 | pos_emb_img = F.interpolate( 513 | pos_emb_img, 514 | size=grid_size, 515 | mode=interpolation, 516 | # antialias=antialias, 517 | align_corners=False, 518 | ) 519 | pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0] 520 | if pos_emb_tok is not None: 521 | new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0) 522 | else: 523 | new_pos_embed = pos_emb_img 524 | if flag: 525 | state_dict['visual.positional_embedding'] = new_pos_embed 526 | else: 527 | state_dict['visual.attnpool.positional_embedding'] = new_pos_embed 528 | --------------------------------------------------------------------------------
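A minimal usage sketch (not part of the repository) of how the pieces in vit_version/clip/model.py fit together: build_model reconstructs CLIP from a checkpoint state_dict at a custom image_size, resize_pos_embed interpolates the visual positional embedding to the new grid, and encode_image with a list of block indices returns the projected image embedding plus intermediate patch-token maps. The checkpoint path, the loading call, and the f_list values below are illustrative assumptions, not values taken from this repo.

import torch
from vit_version.clip.model import build_model  # assumes the repo root is on sys.path

# Assumption: an OpenAI-style CLIP checkpoint saved as a TorchScript archive (e.g. ViT-B/16).
ckpt_path = "ViT-B-16.pt"  # hypothetical path
state_dict = torch.jit.load(ckpt_path, map_location="cpu").state_dict()

# Rebuild CLIP for 256x256 inputs; resize_pos_embed() interpolates the ViT positional
# embedding from the checkpoint's grid (14x14 for ViT-B/16 at 224) to the new 16x16 grid.
model = build_model(state_dict, image_size=256)
model = model.float().eval()  # convert_weights() casts to fp16; cast back for CPU inference

image = torch.randn(1, 3, 256, 256)  # stands in for a pre-processed image batch
with torch.no_grad():
    # f_list=[6, 9, 12] is an illustrative choice of intermediate transformer blocks;
    # VisionTransformer.forward() asserts that exactly three blocks are selected.
    cls_embed, patch_tokens = model.encode_image(image, f_list=[6, 9, 12])

# cls_embed: [1, embed_dim] projected CLS feature; patch_tokens: three patch-token tensors
# reshaped into spatial maps by patch_to_tensor() from models/recons_net.py.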