├── vit_version
│   ├── clip
│   │   ├── __init__.py
│   │   ├── bpe_simple_vocab_16e6.txt.gz
│   │   ├── utils.py
│   │   ├── prompts.py
│   │   ├── simple_tokenizer.py
│   │   ├── clip.py
│   │   └── model.py
│   ├── clip_rar
│   │   ├── __init__.py
│   │   ├── bpe_simple_vocab_16e6.txt.gz
│   │   ├── utils.py
│   │   ├── prompts.py
│   │   ├── simple_tokenizer.py
│   │   └── clip.py
│   ├── clip_bn.py
│   ├── clip_decoder.py
│   ├── train_vit.py
│   └── test_vit.py
├── requirements.txt
├── utils
│   ├── utils.py
│   ├── perlin.py
│   └── vis.py
├── README.md
├── datasets
│   ├── MvTec3D.py
│   ├── MvTec.py
│   ├── VisA.py
│   └── database.py
├── models
│   ├── recons_net.py
│   ├── loss.py
│   └── de_resnet.py
├── scripts
│   └── fore_extractor.py
├── train.py
└── test.py
/vit_version/clip/__init__.py:
--------------------------------------------------------------------------------
1 | from .clip import *
2 |
--------------------------------------------------------------------------------
/vit_version/clip_rar/__init__.py:
--------------------------------------------------------------------------------
1 | from .clip import *
2 |
--------------------------------------------------------------------------------
/vit_version/clip/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hui-design/AAND/HEAD/vit_version/clip/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/vit_version/clip_rar/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hui-design/AAND/HEAD/vit_version/clip_rar/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | pandas
3 | scipy
4 | scikit-learn
5 | pillow
6 | seaborn
7 | opencv-python
8 | scikit-image
9 | tqdm
10 | imgaug
11 | ftfy
12 | regex
--------------------------------------------------------------------------------
/vit_version/clip/utils.py:
--------------------------------------------------------------------------------
1 | from itertools import repeat
2 | import collections.abc
3 |
4 | from torch import nn as nn
5 | from torchvision.ops.misc import FrozenBatchNorm2d
6 |
7 |
8 | def freeze_batch_norm_2d(module, module_match={}, name=''):
9 | """
10 | Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
11 | itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
12 | returned. Otherwise, the module is walked recursively and submodules are converted in place.
13 |
14 | Args:
15 | module (torch.nn.Module): Any PyTorch module.
16 | module_match (dict): Dictionary of full module names to freeze (all if empty)
17 | name (str): Full module name (prefix)
18 |
19 | Returns:
20 | torch.nn.Module: Resulting module
21 |
22 | Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
23 | """
24 | res = module
25 | is_match = True
26 | if module_match:
27 | is_match = name in module_match
28 | if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)):
29 | res = FrozenBatchNorm2d(module.num_features)
30 | res.num_features = module.num_features
31 | res.affine = module.affine
32 | if module.affine:
33 | res.weight.data = module.weight.data.clone().detach()
34 | res.bias.data = module.bias.data.clone().detach()
35 | res.running_mean.data = module.running_mean.data
36 | res.running_var.data = module.running_var.data
37 | res.eps = module.eps
38 | else:
39 | for child_name, child in module.named_children():
40 | full_child_name = '.'.join([name, child_name]) if name else child_name
41 | new_child = freeze_batch_norm_2d(child, module_match, full_child_name)
42 | if new_child is not child:
43 | res.add_module(child_name, new_child)
44 | return res
45 |
46 |
47 | # From PyTorch internals
48 | def _ntuple(n):
49 | def parse(x):
50 | if isinstance(x, collections.abc.Iterable):
51 | return x
52 | return tuple(repeat(x, n))
53 | return parse
54 |
55 |
56 | to_1tuple = _ntuple(1)
57 | to_2tuple = _ntuple(2)
58 | to_3tuple = _ntuple(3)
59 | to_4tuple = _ntuple(4)
60 | to_ntuple = lambda n, x: _ntuple(n)(x)
61 |
--------------------------------------------------------------------------------
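
A minimal usage sketch (not part of the repository), assuming `freeze_batch_norm_2d` and `to_2tuple` from the file above are in scope; the torchvision backbone is an illustrative choice:

```python
# Hypothetical usage of the helpers defined in clip/utils.py above.
import torchvision

backbone = torchvision.models.resnet18()      # any module containing BatchNorm2d layers
backbone = freeze_batch_norm_2d(backbone)     # BatchNorm2d/SyncBatchNorm -> FrozenBatchNorm2d
print(type(backbone.bn1).__name__)            # FrozenBatchNorm2d

print(to_2tuple(7))                           # (7, 7)  -- scalar broadcast to a pair
print(to_2tuple((3, 5)))                      # (3, 5)  -- iterables pass through unchanged
```
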
/vit_version/clip_rar/utils.py:
--------------------------------------------------------------------------------
1 | from itertools import repeat
2 | import collections.abc
3 |
4 | from torch import nn as nn
5 | from torchvision.ops.misc import FrozenBatchNorm2d
6 |
7 |
8 | def freeze_batch_norm_2d(module, module_match={}, name=''):
9 | """
10 | Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
11 | itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
12 | returned. Otherwise, the module is walked recursively and submodules are converted in place.
13 |
14 | Args:
15 | module (torch.nn.Module): Any PyTorch module.
16 | module_match (dict): Dictionary of full module names to freeze (all if empty)
17 | name (str): Full module name (prefix)
18 |
19 | Returns:
20 | torch.nn.Module: Resulting module
21 |
22 | Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
23 | """
24 | res = module
25 | is_match = True
26 | if module_match:
27 | is_match = name in module_match
28 | if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)):
29 | res = FrozenBatchNorm2d(module.num_features)
30 | res.num_features = module.num_features
31 | res.affine = module.affine
32 | if module.affine:
33 | res.weight.data = module.weight.data.clone().detach()
34 | res.bias.data = module.bias.data.clone().detach()
35 | res.running_mean.data = module.running_mean.data
36 | res.running_var.data = module.running_var.data
37 | res.eps = module.eps
38 | else:
39 | for child_name, child in module.named_children():
40 | full_child_name = '.'.join([name, child_name]) if name else child_name
41 | new_child = freeze_batch_norm_2d(child, module_match, full_child_name)
42 | if new_child is not child:
43 | res.add_module(child_name, new_child)
44 | return res
45 |
46 |
47 | # From PyTorch internals
48 | def _ntuple(n):
49 | def parse(x):
50 | if isinstance(x, collections.abc.Iterable):
51 | return x
52 | return tuple(repeat(x, n))
53 | return parse
54 |
55 |
56 | to_1tuple = _ntuple(1)
57 | to_2tuple = _ntuple(2)
58 | to_3tuple = _ntuple(3)
59 | to_4tuple = _ntuple(4)
60 | to_ntuple = lambda n, x: _ntuple(n)(x)
61 |
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import random
3 | import numpy as np
4 |
5 | def setup_seed(seed):
6 | torch.manual_seed(seed)
7 | torch.cuda.manual_seed_all(seed)
8 | np.random.seed(seed)
9 | random.seed(seed)
10 | torch.backends.cudnn.deterministic = True
11 | torch.backends.cudnn.benchmark = False
12 |
13 | def get_dataset_name(data_path):
14 | if 'mvtec' in data_path.lower() and 'mvtec3d' not in data_path.lower():
15 | return 'mvtec'
16 |
17 | elif 'visa' in data_path.lower():
18 | return 'VisA'
19 |
20 | elif 'mvtec3d' in data_path.lower():
21 | return 'mvtec3d'
22 |
23 | else:
24 | raise ValueError('No such dataset')
25 |
26 |
27 | # Image handling classes.
28 | class PatchMaker:
29 | def __init__(self, patchsize, stride=None):
30 | self.patchsize = patchsize
31 | self.stride = stride
32 |
33 | def patchify(self, features, return_spatial_info=False):
34 | """Convert a tensor into a tensor of respective patches.
35 | Args:
37 | features: [torch.Tensor, bs x c x h x w]
38 | Returns:
39 | patches: [torch.Tensor, bs x (h//stride * w//stride) x c x patchsize x
40 | patchsize]
40 | """
41 | padding = int((self.patchsize - 1) / 2)
42 | unfolder = torch.nn.Unfold(
43 | kernel_size=self.patchsize, stride=self.stride, padding=padding, dilation=1
44 | )
45 | unfolded_features = unfolder(features)
46 | number_of_total_patches = []
47 | for s in features.shape[-2:]:
48 | n_patches = (
49 | s + 2 * padding - 1 * (self.patchsize - 1) - 1
50 | ) / self.stride + 1
51 | number_of_total_patches.append(int(n_patches))
52 | unfolded_features = unfolded_features.reshape(
53 | *features.shape[:2], self.patchsize, self.patchsize, -1
54 | )
55 | unfolded_features = unfolded_features.permute(0, 4, 1, 2, 3)
56 |
57 | if return_spatial_info:
58 | return unfolded_features, number_of_total_patches
59 | return unfolded_features
60 |
61 | def unpatch_scores(self, x, batchsize):
62 | return x.reshape(batchsize, -1, *x.shape[1:])
63 |
64 | def score(self, x):
65 | was_numpy = False
66 | if isinstance(x, np.ndarray):
67 | was_numpy = True
68 | x = torch.from_numpy(x)
69 | while x.ndim > 1:
70 | x = torch.max(x, dim=-1).values
71 | if was_numpy:
72 | return x.numpy()
73 | return x
--------------------------------------------------------------------------------
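
A minimal usage sketch (not part of the repository) for the `PatchMaker` class above: patchify a dummy feature map, then collapse per-patch scores back to one score per image; all tensors are illustrative:

```python
# Hypothetical usage of PatchMaker defined in utils/utils.py above.
import torch

pm = PatchMaker(patchsize=3, stride=1)
features = torch.randn(2, 64, 16, 16)                      # [bs, c, h, w]
patches, spatial = pm.patchify(features, return_spatial_info=True)
print(patches.shape, spatial)                               # [2, 256, 64, 3, 3], [16, 16]

patch_scores = torch.rand(2 * 16 * 16)                      # one anomaly score per patch
per_image = pm.unpatch_scores(patch_scores, batchsize=2)    # [2, 256]
image_scores = pm.score(per_image)                          # max over patches -> shape [2]
```
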
/vit_version/clip/prompts.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import clip
4 | import pdb
5 |
6 | # class Prompts(nn.Module):
7 | # def __init__(self, initials=None, clip_model=None, len_text=None, pretrained=None):
8 | # super(Prompts,self).__init__()
9 | # print("The initial prompts are:",initials)
10 | # # tokenized, embedding
11 | # self.tokenized_prompts = clip.tokenize(initials)
12 | # embedding = clip_model.token_embedding(self.tokenized_prompts.cuda())
13 | # self.prefix, self.suffix = embedding[:, :1, :], embedding[:, 1+len_text:, :]
14 | # # learnable embedding
15 | # ctx_vectors = torch.empty(2, len_text, 768, dtype=embedding.dtype)
16 | # nn.init.normal_(ctx_vectors, std=0.02)
17 | # self.embedding_learn = nn.Parameter(ctx_vectors)
18 |
19 | # # state_dict = torch.load(pretrained)
20 | # # self.embedding_prompt = state_dict['embedding_prompt'].cuda()
21 |
22 | # def forward(self, clip_model):
23 | # # pdb.set_trace()
24 | # embedding_prompt = torch.cat((self.prefix, self.embedding_learn, self.suffix), dim=1)
25 | # text_features = clip_model.encode_text(embedding_prompt, self.tokenized_prompts)
26 | # text_features /= text_features.norm(dim=-1, keepdim=True)
27 | # return text_features
28 |
29 |
30 | class Prompts(nn.Module):
31 | def __init__(self, initials=None, clip_model=None, len_text=None, pretrained=None, device='cuda'):
32 | super(Prompts,self).__init__()
33 | # initials[0] += ' normal person'
34 | # initials[1] += ' abnormal person'
35 | print("The initial prompts are:",initials)
36 | # tokenized, embedding
37 | # pdb.set_trace()
38 | self.tokenized_prompts = clip.tokenize(initials).to(device)
39 | self.model = clip_model
40 | embedding = self.model.token_embedding(self.tokenized_prompts).to(device)
41 | self.prefix, self.suffix = embedding[:, :1, :], embedding[:, 1+len_text:, :]
42 | self.device = device
43 | # self.prefix.requires_grad = False
44 | # self.suffix.requires_grad = False
45 | if pretrained is None:
46 | ctx_vectors = torch.zeros_like(embedding[:, 1:1+len_text, :])
47 | nn.init.normal_(ctx_vectors, std=0.02)
48 | self.embedding_prompt = nn.Parameter(ctx_vectors.requires_grad_()).to(device)
49 | # self.embedding_prompt = nn.Parameter(embedding[:, 1:1+len_text, :].requires_grad_()).to(device)
50 | # self.embedding_prompt = nn.ParameterList([nn.Parameter(torch.zeros(2,768).requires_grad_())]).to(device)
51 | # self.embedding_prompt = nn.Parameter(torch.randn(2,512).requires_grad_())
52 | else:
53 | state_dict = torch.load(pretrained)['embedding_prompt']
54 | self.embedding_prompt = state_dict.to(self.device)
55 |
56 | # ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype)
57 | # nn.init.normal_(ctx_vectors, std=0.02)
58 | # text_features = self.embedding_prompt
59 | # text_features = text_features / text_features.norm(dim=-1, keepdim=True)
60 |
61 | def forward(self):
62 | # pdb.set_trace()  # leftover debug breakpoint; disabled so forward() runs
63 | self.embedding_prompt_cat = torch.cat((self.prefix, self.embedding_prompt, self.suffix), dim=1)
64 | text_features = self.model.encode_text(self.embedding_prompt_cat, self.tokenized_prompts)
65 | text_features = text_features / text_features.norm(dim=-1, keepdim=True)
66 | return text_features
67 |
68 |
69 |
70 | class Linear(nn.Module):
71 | def __init__(self, in_dim=512, out_dim=512):
72 | super(Linear,self).__init__()
73 | self.mlp = nn.Linear(in_dim, out_dim)
74 | # self.mlp = nn.Sequential(nn.Linear(in_dim, out_dim),
75 | # nn.ReLU(),
76 | # nn.Linear(in_dim, out_dim))
77 |
78 |
79 | def forward(self, x):
80 | return self.mlp(x)
--------------------------------------------------------------------------------
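
A minimal usage sketch (not part of the repository), assuming a CLIP model loaded through this repository's `clip` package, whose modified `encode_text` accepts prompt embeddings; the checkpoint name, prompt strings, and `len_text` value are illustrative assumptions:

```python
# Hypothetical usage of the Prompts module defined in clip/prompts.py above.
import torch
import clip

device = 'cuda' if torch.cuda.is_available() else 'cpu'
clip_model, _ = clip.load("ViT-L/14@336px", device=device)        # assumed checkpoint name
prompts = Prompts(initials=["a photo of a flawless object",
                            "a photo of a damaged object"],
                  clip_model=clip_model, len_text=4, device=device)
text_features = prompts()   # [2, embed_dim] L2-normalized text embeddings
```
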
/vit_version/clip_rar/prompts.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import clip
4 | import pdb
5 |
6 | # class Prompts(nn.Module):
7 | # def __init__(self, initials=None, clip_model=None, len_text=None, pretrained=None):
8 | # super(Prompts,self).__init__()
9 | # print("The initial prompts are:",initials)
10 | # # tokenized, embedding
11 | # self.tokenized_prompts = clip.tokenize(initials)
12 | # embedding = clip_model.token_embedding(self.tokenized_prompts.cuda())
13 | # self.prefix, self.suffix = embedding[:, :1, :], embedding[:, 1+len_text:, :]
14 | # # learnable embedding
15 | # ctx_vectors = torch.empty(2, len_text, 768, dtype=embedding.dtype)
16 | # nn.init.normal_(ctx_vectors, std=0.02)
17 | # self.embedding_learn = nn.Parameter(ctx_vectors)
18 |
19 | # # state_dict = torch.load(pretrained)
20 | # # self.embedding_prompt = state_dict['embedding_prompt'].cuda()
21 |
22 | # def forward(self, clip_model):
23 | # # pdb.set_trace()
24 | # embedding_prompt = torch.cat((self.prefix, self.embedding_learn, self.suffix), dim=1)
25 | # text_features = clip_model.encode_text(embedding_prompt, self.tokenized_prompts)
26 | # text_features /= text_features.norm(dim=-1, keepdim=True)
27 | # return text_features
28 |
29 |
30 | class Prompts(nn.Module):
31 | def __init__(self, initials=None, clip_model=None, len_text=None, pretrained=None, device='cuda'):
32 | super(Prompts,self).__init__()
33 | # initials[0] += ' normal person'
34 | # initials[1] += ' abnormal person'
35 | print("The initial prompts are:",initials)
36 | # tokenized, embedding
37 | # pdb.set_trace()
38 | self.tokenized_prompts = clip.tokenize(initials).to(device)
39 | self.model = clip_model
40 | embedding = self.model.token_embedding(self.tokenized_prompts).to(device)
41 | self.prefix, self.suffix = embedding[:, :1, :], embedding[:, 1+len_text:, :]
42 | self.device = device
43 | # self.prefix.requires_grad = False
44 | # self.suffix.requires_grad = False
45 | if pretrained is None:
46 | ctx_vectors = torch.zeros_like(embedding[:, 1:1+len_text, :])
47 | nn.init.normal_(ctx_vectors, std=0.02)
48 | self.embedding_prompt = nn.Parameter(ctx_vectors.requires_grad_()).to(device)
49 | # self.embedding_prompt = nn.Parameter(embedding[:, 1:1+len_text, :].requires_grad_()).to(device)
50 | # self.embedding_prompt = nn.ParameterList([nn.Parameter(torch.zeros(2,768).requires_grad_())]).to(device)
51 | # self.embedding_prompt = nn.Parameter(torch.randn(2,512).requires_grad_())
52 | else:
53 | state_dict = torch.load(pretrained)['embedding_prompt']
54 | self.embedding_prompt = state_dict.to(self.device)
55 |
56 | # ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype)
57 | # nn.init.normal_(ctx_vectors, std=0.02)
58 | # text_features = self.embedding_prompt
59 | # text_features = text_features / text_features.norm(dim=-1, keepdim=True)
60 |
61 | def forward(self):
62 | # pdb.set_trace()  # leftover debug breakpoint; disabled so forward() runs
63 | self.embedding_prompt_cat = torch.cat((self.prefix, self.embedding_prompt, self.suffix), dim=1)
64 | text_features = self.model.encode_text(self.embedding_prompt_cat, self.tokenized_prompts)
65 | text_features = text_features / text_features.norm(dim=-1, keepdim=True)
66 | return text_features
67 |
68 |
69 |
70 | class Linear(nn.Module):
71 | def __init__(self, in_dim=512, out_dim=512):
72 | super(Linear,self).__init__()
73 | self.mlp = nn.Linear(in_dim, out_dim)
74 | # self.mlp = nn.Sequential(nn.Linear(in_dim, out_dim),
75 | # nn.ReLU(),
76 | # nn.Linear(in_dim, out_dim))
77 |
78 |
79 | def forward(self, x):
80 | return self.mlp(x)
--------------------------------------------------------------------------------
/vit_version/clip_bn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 | import torch.nn as nn
4 | try:
5 | from torch.hub import load_state_dict_from_url
6 | except ImportError:
7 | from torch.utils.model_zoo import load_url as load_state_dict_from_url
8 | import os, pdb, sys
9 | sys.path.append(os.getcwd())
10 | from typing import Type, Any, Callable, Union, List, Optional
11 | from models.recons_net import RAR_single, MLP
12 | from models.resnet_rar import BasicBlock, Bottleneck, AttnBottleneck, conv1x1, conv3x3
13 | import pdb
14 |
15 |
16 | class BN_layer(nn.Module):
17 | def __init__(self,
18 | block: Type[Union[BasicBlock, Bottleneck]],
19 | layers: int,
20 | groups: int = 1,
21 | width_per_group: int = 64,
22 | norm_layer: Optional[Callable[..., nn.Module]] = None,
23 | ):
24 | super(BN_layer, self).__init__()
25 | if norm_layer is None:
26 | norm_layer = nn.BatchNorm2d
27 | self._norm_layer = norm_layer
28 | self.groups = groups
29 | self.base_width = width_per_group
30 | self.inplanes = 768
31 | self.dilation = 1
32 | # self.bn_layer = self._make_layer(block, 768, layers, stride=1)
33 |
34 | self.conv1 = conv1x1(768, 768, 1)
35 | self.bn1 = norm_layer(768)
36 | self.relu = nn.ReLU(inplace=True)
37 | self.conv2 = conv1x1(768, 768, 1)
38 | self.bn2 = norm_layer(768)
39 | self.conv3 = conv1x1(768, 768, 1)
40 | self.bn3 = norm_layer(768)
41 | # self.conv3 = conv3x3(128 * block.expansion, 256 * block.expansion, 2)
42 | # self.bn3 = norm_layer(256 * block.expansion)
43 |
44 | self.conv4 = conv1x1(768*3, 768, 1)
45 | self.bn4 = norm_layer(768)
46 |
47 |
48 | for m in self.modules():
49 | if isinstance(m, nn.Conv2d):
50 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
51 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
52 | nn.init.constant_(m.weight, 1)
53 | nn.init.constant_(m.bias, 0)
54 |
55 | # def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int,
56 | # stride: int = 1, dilate: bool = False) -> nn.Sequential:
57 | # norm_layer = self._norm_layer
58 | # downsample = None
59 | # previous_dilation = self.dilation
60 | # if dilate:
61 | # self.dilation *= stride
62 | # stride = 1
63 | # if stride != 1 or self.inplanes != planes * block.expansion:
64 | # downsample = nn.Sequential(
65 | # conv1x1(self.inplanes, planes, stride),
66 | # norm_layer(planes * block.expansion),
67 | # )
68 |
69 | # layers = []
70 | # # layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
71 | # # self.base_width, previous_dilation, norm_layer))
72 | # # self.inplanes = planes * block.expansion
73 | # self.inplances = planes
74 | # for _ in range(blocks):
75 | # layers.append(block(self.inplanes, planes, groups=self.groups,
76 | # base_width=self.base_width, dilation=self.dilation,
77 | # norm_layer=norm_layer))
78 |
79 | # return nn.Sequential(*layers)
80 |
81 | def _forward_impl(self, x: Tensor) -> Tensor:
82 | # See note [TorchScript super()]
83 | #x = self.cbam(x)
84 | l1 = self.relu(self.bn1(self.conv1(x[0])))
85 | l2 = self.relu(self.bn2(self.conv2(x[1])))
86 | l3 = self.relu(self.bn3(self.conv3(x[2])))
87 | feature = torch.cat([l1,l2,l3],1)
88 | output = self.conv4(feature)
89 | # output = self.bn_layer(feature)
90 | #x = self.avgpool(feature_d)
91 | #x = torch.flatten(x, 1)
92 | #x = self.fc(x)
93 |
94 | return output.contiguous()
95 |
96 | def forward(self, x: Tensor) -> Tensor:
97 | return self._forward_impl(x)
98 |
99 | if __name__ == '__main__':
100 | x = [torch.randn((1,768,16,16))]*3
101 | bn = BN_layer(AttnBottleneck, 3)
102 | # pdb.set_trace()  # debug breakpoint, disabled
103 | x_new = bn(x)
--------------------------------------------------------------------------------
/utils/perlin.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import numpy as np
4 |
5 | def lerp_np(x,y,w):
6 | fin_out = (y-x)*w + x
7 | return fin_out
8 |
9 | def generate_fractal_noise_2d(shape, res, octaves=1, persistence=0.5):
10 | noise = np.zeros(shape)
11 | frequency = 1
12 | amplitude = 1
13 | for _ in range(octaves):
14 | noise += amplitude * generate_perlin_noise_2d(shape, (frequency*res[0], frequency*res[1]))
15 | frequency *= 2
16 | amplitude *= persistence
17 | return noise
18 |
19 |
20 | def generate_perlin_noise_2d(shape, res):
21 | def f(t):
22 | return 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3
23 |
24 | delta = (res[0] / shape[0], res[1] / shape[1])
25 | d = (shape[0] // res[0], shape[1] // res[1])
26 | grid = np.mgrid[0:res[0]:delta[0], 0:res[1]:delta[1]].transpose(1, 2, 0) % 1
27 | # Gradients
28 | angles = 2 * np.pi * np.random.rand(res[0] + 1, res[1] + 1)
29 | gradients = np.dstack((np.cos(angles), np.sin(angles)))
30 | g00 = gradients[0:-1, 0:-1].repeat(d[0], 0).repeat(d[1], 1)
31 | g10 = gradients[1:, 0:-1].repeat(d[0], 0).repeat(d[1], 1)
32 | g01 = gradients[0:-1, 1:].repeat(d[0], 0).repeat(d[1], 1)
33 | g11 = gradients[1:, 1:].repeat(d[0], 0).repeat(d[1], 1)
34 | # Ramps
35 | n00 = np.sum(grid * g00, 2)
36 | n10 = np.sum(np.dstack((grid[:, :, 0] - 1, grid[:, :, 1])) * g10, 2)
37 | n01 = np.sum(np.dstack((grid[:, :, 0], grid[:, :, 1] - 1)) * g01, 2)
38 | n11 = np.sum(np.dstack((grid[:, :, 0] - 1, grid[:, :, 1] - 1)) * g11, 2)
39 | # Interpolation
40 | t = f(grid)
41 | n0 = n00 * (1 - t[:, :, 0]) + t[:, :, 0] * n10
42 | n1 = n01 * (1 - t[:, :, 0]) + t[:, :, 0] * n11
43 | return np.sqrt(2) * ((1 - t[:, :, 1]) * n0 + t[:, :, 1] * n1)
44 |
45 |
46 | def rand_perlin_2d_np(shape, res, fade=lambda t: 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3):
47 | delta = (res[0] / shape[0], res[1] / shape[1])
48 | d = (shape[0] // res[0], shape[1] // res[1])
49 | grid = np.mgrid[0:res[0]:delta[0], 0:res[1]:delta[1]].transpose(1, 2, 0) % 1
50 |
51 | angles = 2 * math.pi * np.random.rand(res[0] + 1, res[1] + 1)
52 | gradients = np.stack((np.cos(angles), np.sin(angles)), axis=-1)
53 | tt = np.repeat(np.repeat(gradients,d[0],axis=0),d[1],axis=1)
54 |
55 | tile_grads = lambda slice1, slice2: np.repeat(np.repeat(gradients[slice1[0]:slice1[1], slice2[0]:slice2[1]],d[0],axis=0),d[1],axis=1)
56 | dot = lambda grad, shift: (
57 | np.stack((grid[:shape[0], :shape[1], 0] + shift[0], grid[:shape[0], :shape[1], 1] + shift[1]),
58 | axis=-1) * grad[:shape[0], :shape[1]]).sum(axis=-1)
59 |
60 | n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
61 | n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
62 | n01 = dot(tile_grads([0, -1], [1, None]), [0, -1])
63 | n11 = dot(tile_grads([1, None], [1, None]), [-1, -1])
64 | t = fade(grid[:shape[0], :shape[1]])
65 | return math.sqrt(2) * lerp_np(lerp_np(n00, n10, t[..., 0]), lerp_np(n01, n11, t[..., 0]), t[..., 1])
66 |
67 |
68 | def rand_perlin_2d(shape, res, fade=lambda t: 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3):
69 | delta = (res[0] / shape[0], res[1] / shape[1])
70 | d = (shape[0] // res[0], shape[1] // res[1])
71 |
72 | grid = torch.stack(torch.meshgrid(torch.arange(0, res[0], delta[0]), torch.arange(0, res[1], delta[1])), dim=-1) % 1
73 | angles = 2 * math.pi * torch.rand(res[0] + 1, res[1] + 1)
74 | gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)
75 |
76 | tile_grads = lambda slice1, slice2: gradients[slice1[0]:slice1[1], slice2[0]:slice2[1]].repeat_interleave(d[0],
77 | 0).repeat_interleave(
78 | d[1], 1)
79 | dot = lambda grad, shift: (
80 | torch.stack((grid[:shape[0], :shape[1], 0] + shift[0], grid[:shape[0], :shape[1], 1] + shift[1]),
81 | dim=-1) * grad[:shape[0], :shape[1]]).sum(dim=-1)
82 |
83 | n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
84 |
85 | n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
86 | n01 = dot(tile_grads([0, -1], [1, None]), [0, -1])
87 | n11 = dot(tile_grads([1, None], [1, None]), [-1, -1])
88 | t = fade(grid[:shape[0], :shape[1]])
89 | return math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1])
90 |
91 |
92 | def rand_perlin_2d_octaves(shape, res, octaves=1, persistence=0.5):
93 | noise = torch.zeros(shape)
94 | frequency = 1
95 | amplitude = 1
96 | for _ in range(octaves):
97 | noise += amplitude * rand_perlin_2d(shape, (frequency * res[0], frequency * res[1]))
98 | frequency *= 2
99 | amplitude *= persistence
100 | return noise
--------------------------------------------------------------------------------
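
A minimal usage sketch (not part of the repository): turning the Perlin noise above into a binary mask for anomaly synthesis in the spirit of DRAEM; the resolution, octave count, and 0.5 threshold are illustrative assumptions:

```python
# Hypothetical usage of rand_perlin_2d_octaves defined in utils/perlin.py above.
import torch

noise = rand_perlin_2d_octaves((256, 256), res=(8, 8), octaves=2)   # [256, 256] torch.Tensor
anomaly_mask = (noise > 0.5).float()                                 # region to corrupt with texture
print(anomaly_mask.shape, anomaly_mask.mean().item())                # fraction of masked pixels
```
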
/vit_version/clip_decoder.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | from typing import Tuple, Union
3 | import os, pdb, sys
4 | sys.path.append(os.getcwd())
5 | import numpy as np
6 | import torch
7 | import torch.nn.functional as F
8 | from torch import nn
9 | import math
10 | import pdb
11 | from models.recons_net import patch_to_tensor, tensor_to_patch
12 | from clip.model import LayerNorm, QuickGELU
13 |
14 |
15 | class QuickGELU(nn.Module):
16 | def forward(self, x: torch.Tensor):
17 | return x * torch.sigmoid(1.702 * x)
18 |
19 |
20 | class ResidualAttentionBlock(nn.Module):
21 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
22 | super().__init__()
23 |
24 | self.attn = nn.MultiheadAttention(d_model, n_head)
25 | self.ln_1 = LayerNorm(d_model)
26 | self.mlp = nn.Sequential(OrderedDict([
27 | ("c_fc", nn.Linear(d_model, d_model * 4)),
28 | ("gelu", QuickGELU()),
29 | ("c_proj", nn.Linear(d_model * 4, d_model))
30 | ]))
31 | self.ln_2 = LayerNorm(d_model)
32 | self.attn_mask = attn_mask
33 |
34 | def attention(self, x: torch.Tensor):
35 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
36 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
37 |
38 | def forward(self, x: torch.Tensor):
39 | # pdb.set_trace()
40 | x = x + self.attention(self.ln_1(x))
41 | x = x + self.mlp(self.ln_2(x))
42 | return x
43 |
44 |
45 | class Transformer(nn.Module):
46 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
47 | super().__init__()
48 | self.width = width
49 | self.layers = layers
50 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
51 |
52 | # def forward(self, x: torch.Tensor):
53 | # return self.resblocks(x)
54 |
55 | def forward(self, x, f_list=[]):
56 | # pdb.set_trace()
57 | out_tokens = []
58 | idx = 0
59 | for r in self.resblocks:
60 | # pdb.set_trace()
61 | idx+=1
62 | x = r(x)
63 | if idx in f_list:
64 | if len(x)==2:
65 | out_tokens.append(x[0])
66 | out_tokens.append(x[1])
67 | else:
68 | out_tokens.append(x.permute(1,0,2))
69 | return x, out_tokens #self.resblocks(x)
70 |
71 |
72 | class de_VisionTransformer(nn.Module):
73 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
74 | super().__init__()
75 | self.input_resolution = input_resolution
76 | self.grid_size = (input_resolution // patch_size) # modify
77 | self.output_dim = output_dim
78 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
79 |
80 | scale = width ** -0.5
81 | self.class_embedding = nn.Parameter(scale * torch.randn(width))
82 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
83 | self.ln_pre = LayerNorm(width)
84 |
85 | self.transformer = Transformer(width, layers, heads)
86 |
87 | self.ln_post = LayerNorm(width)
88 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
89 |
90 | def forward(self, x: torch.Tensor, f_list: list):
91 | # x = self.conv1(x) # shape = [*, width, grid, grid]
92 | # x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
93 | # x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
94 | # x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
95 | x = tensor_to_patch(x) + self.positional_embedding.to(x.dtype)[None,1:]
96 | # pdb.set_trace()
97 | x = self.ln_pre(x)
98 |
99 | x = x.permute(1, 0, 2) # NLD -> LND
100 | x, patch_tokens = self.transformer(x, f_list)
101 | x = x.permute(1, 0, 2) # LND -> NLD
102 |
103 | # x = self.ln_post(x[:, 0, :])
104 | # # pdb.set_trace()
105 | # if self.proj is not None:
106 | # x = x @ self.proj
107 |
108 | patch_tokens = [patch_to_tensor(x) for x in patch_tokens]
109 |
110 | return x, patch_tokens
111 |
112 |
113 | if __name__ == '__main__':
114 | x = torch.randn((1,768,16,16))
115 | decoder_SS = de_VisionTransformer(input_resolution=256, patch_size=16, width=768, layers=6, heads=12, output_dim=512)
116 | # pdb.set_trace()
117 | _, x_new = decoder_SS(x, f_list=[2,4,6])
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
5 |
6 | ## Overview
7 |
8 |
9 |
12 |
13 |
17 | There are two underlying assumptions in the KD-based anomaly detection framework. **Assumption I**: the teacher model can represent two separable distributions for normal and abnormal patterns; **Assumption II**: the student model can only reconstruct the normal distribution. In this paper, we propose a simple yet effective two-stage anomaly detection framework, termed AAND, which comprises an Anomaly Amplification stage (**Stage I**) to address Assumption I and a Normality Distillation stage (**Stage II**) to address Assumption II.
18 |
19 | ## Author
20 | [Canhui Tang](https://scholar.google.com/citations?hl=zh-CN&user=TKqkrnUAAAAJ), [Sanping Zhou](https://scholar.google.cz/citations?hl=zh-CN&user=2Drvv44AAAAJ), Yizhe Li, Yonghao Dong, [Le Wang](https://scholar.google.cz/citations?hl=zh-CN&user=RypRCUQAAAAJ)
21 |
22 | Xi'an Jiaotong University
23 |
24 | ## News
25 | 🔥 2025.09: Awaiting SAE Decision approval
26 |
27 | 🔥 2025.05: Accepted with Mandatory Minor Revisions
28 |
29 | 🔥 2024.06: Another of our KD-based projects, [VAND-GNL](https://github.com/Hui-design/VAND-GNL), won 2nd place in the CVPR 2024 [VAND2.0 Challenge](https://www.hackster.io/contests/openvino2024#challengeNav)
30 |
31 | ## 🔧 Installation
32 |
33 | Please use the following command for installation.
34 |
35 | ```bash
36 | # It is recommended to create a new environment
37 | conda create -n AAND python==3.8
38 | conda activate AAND
39 |
40 | pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113
41 |
42 | # Install packages and other dependencies
43 | pip install -r requirements.txt
44 | ```
45 |
46 |
47 | ## 💾 Dataset
48 |
49 | - [MVTec AD](https://www.mvtec.com/company/research/datasets/mvtec-ad)
50 | - [VisA](https://amazon-visual-anomaly.s3.us-west-2.amazonaws.com/VisA_20220922.tar), use [visa.py](https://github.com/ByChelsea/VAND-APRIL-GAN/blob/master/data/visa.py) to generate meta.json
51 | - [MVTec-3D](https://www.mvtec.com/company/research/datasets/mvtec-3d-ad), we use only the RGB images, referred to as **MVTec3D-RGB** in our paper.
52 | - [DRAEM_dtd](https://www.robots.ox.ac.uk/~vgg/data/dtd/), used as the auxiliary texture dataset for synthesizing anomalies, as in [DRAEM](https://github.com/VitjanZ/DRAEM).
53 | ```
54 |
55 | ├── mvtec
56 | │   ├── bottle
57 | │   │   ├── train
58 | │   │   ├── test
59 | │   │   ├── ground_truth
60 | │   ├── ...
61 | 
62 | ├── VisA
63 | │   ├── meta.json
64 | │   ├── candle
65 | │   │   ├── Data
66 | │   │   ├── image_anno.csv
67 | │   ├── ...
68 | 
69 | ├── mvtec3d
70 | │   ├── bagel
71 | │   │   ├── train
72 | │   │   │   ├── good
73 | │   │   │   │   ├── rgb (we only use rgb)
74 | │   │   │   │   ├── xyz
75 | │   │   ├── test
76 | │   ├── ...
77 | 
78 | ├── DRAEM_dtd
79 | │   ├── dtd
80 | │   │   ├── images
81 | │   ├── ...
82 | ```
83 |
84 | ## Preprocessing
85 | - Extract foreground masks for the **training** images.
86 |
87 | ```bash
88 | python scripts/fore_extractor.py --data_path <data_root>/<dataset_name>/ --aux_path <aux_root>/dtd/images/  # <dataset_name> is mvtec, VisA, or mvtec3d
89 | ```
90 |
91 |
102 |
103 | ## 🚅 Training
104 | You can train models on mvtec, VisA, or mvtec3d with the following command:
105 | ```bash
106 | python train.py --data_root <data_root>/<dataset_name>/  # <dataset_name> is mvtec, VisA, or mvtec3d
107 | ```
108 |
109 |
110 | ## ⛳ Testing
111 | You can test the trained models on mvtec, VisA, or mvtec3d with the following command:
112 | ```bash
113 | python test.py --data_root <data_root>/<dataset_name>/  # <dataset_name> is mvtec, VisA, or mvtec3d
114 | ```
115 |
116 | ## Citation
117 |
118 | ```bibtex
119 | @article{tang2024advancing,
120 | title={Advancing Pre-trained Teacher: Towards Robust Feature Discrepancy for Anomaly Detection},
121 | author={Tang, Canhui and Zhou, Sanping and Li, Yizhe and Dong, Yonghao and Wang, Le},
122 | journal={arXiv preprint arXiv:2405.02068},
123 | year={2024}
124 | }
125 | ```
126 |
127 |
128 | ## Acknowledgements
129 | - [RD4AD](https://github.com/hq-deng/RD4AD)
130 | - [DRAEM](https://github.com/VitjanZ/DRAEM)
131 |
132 |
133 |
134 |
--------------------------------------------------------------------------------
/vit_version/clip/simple_tokenizer.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | import html
3 | import os
4 | from functools import lru_cache
5 |
6 | import ftfy
7 | import regex as re
8 |
9 |
10 | @lru_cache()
11 | def default_bpe():
12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
13 |
14 |
15 | @lru_cache()
16 | def bytes_to_unicode():
17 | """
18 | Returns list of utf-8 byte and a corresponding list of unicode strings.
19 | The reversible bpe codes work on unicode strings.
20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
22 | This is a significant percentage of your normal, say, 32K bpe vocab.
23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
24 | And avoids mapping to whitespace/control characters the bpe code barfs on.
25 | """
26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
27 | cs = bs[:]
28 | n = 0
29 | for b in range(2**8):
30 | if b not in bs:
31 | bs.append(b)
32 | cs.append(2**8+n)
33 | n += 1
34 | cs = [chr(n) for n in cs]
35 | return dict(zip(bs, cs))
36 |
37 |
38 | def get_pairs(word):
39 | """Return set of symbol pairs in a word.
40 | Word is represented as tuple of symbols (symbols being variable-length strings).
41 | """
42 | pairs = set()
43 | prev_char = word[0]
44 | for char in word[1:]:
45 | pairs.add((prev_char, char))
46 | prev_char = char
47 | return pairs
48 |
49 |
50 | def basic_clean(text):
51 | text = ftfy.fix_text(text)
52 | text = html.unescape(html.unescape(text))
53 | return text.strip()
54 |
55 |
56 | def whitespace_clean(text):
57 | text = re.sub(r'\s+', ' ', text)
58 | text = text.strip()
59 | return text
60 |
61 |
62 | class SimpleTokenizer(object):
63 | def __init__(self, bpe_path: str = default_bpe()):
64 | self.byte_encoder = bytes_to_unicode()
65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
67 | merges = merges[1:49152-256-2+1]
68 | merges = [tuple(merge.split()) for merge in merges]
69 | vocab = list(bytes_to_unicode().values())
70 | vocab = vocab + [v+'</w>' for v in vocab]
71 | for merge in merges:
72 | vocab.append(''.join(merge))
73 | vocab.extend(['<|startoftext|>', '<|endoftext|>'])
74 | self.encoder = dict(zip(vocab, range(len(vocab))))
75 | self.decoder = {v: k for k, v in self.encoder.items()}
76 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
79 |
80 | def bpe(self, token):
81 | if token in self.cache:
82 | return self.cache[token]
83 | word = tuple(token[:-1]) + ( token[-1] + '</w>',)
84 | pairs = get_pairs(word)
85 |
86 | if not pairs:
87 | return token+'</w>'
88 |
89 | while True:
90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
91 | if bigram not in self.bpe_ranks:
92 | break
93 | first, second = bigram
94 | new_word = []
95 | i = 0
96 | while i < len(word):
97 | try:
98 | j = word.index(first, i)
99 | new_word.extend(word[i:j])
100 | i = j
101 | except:
102 | new_word.extend(word[i:])
103 | break
104 |
105 | if word[i] == first and i < len(word)-1 and word[i+1] == second:
106 | new_word.append(first+second)
107 | i += 2
108 | else:
109 | new_word.append(word[i])
110 | i += 1
111 | new_word = tuple(new_word)
112 | word = new_word
113 | if len(word) == 1:
114 | break
115 | else:
116 | pairs = get_pairs(word)
117 | word = ' '.join(word)
118 | self.cache[token] = word
119 | return word
120 |
121 | def encode(self, text):
122 | bpe_tokens = []
123 | text = whitespace_clean(basic_clean(text)).lower()
124 | for token in re.findall(self.pat, text):
125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
127 | return bpe_tokens
128 |
129 | def decode(self, tokens):
130 | text = ''.join([self.decoder[token] for token in tokens])
131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
132 | return text
133 |
--------------------------------------------------------------------------------
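
A minimal usage sketch (not part of the repository): round-tripping a prompt through the tokenizer above; it assumes `bpe_simple_vocab_16e6.txt.gz` sits next to the module, and the prompt string is illustrative:

```python
# Hypothetical usage of SimpleTokenizer defined in clip/simple_tokenizer.py above.
tokenizer = SimpleTokenizer()
ids = tokenizer.encode("a photo of a damaged object")
print(ids)                      # BPE token ids
print(tokenizer.decode(ids))    # recovers the lower-cased prompt text
```
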
/vit_version/clip_rar/simple_tokenizer.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | import html
3 | import os
4 | from functools import lru_cache
5 |
6 | import ftfy
7 | import regex as re
8 |
9 |
10 | @lru_cache()
11 | def default_bpe():
12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
13 |
14 |
15 | @lru_cache()
16 | def bytes_to_unicode():
17 | """
18 | Returns list of utf-8 byte and a corresponding list of unicode strings.
19 | The reversible bpe codes work on unicode strings.
20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
22 | This is a significant percentage of your normal, say, 32K bpe vocab.
23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
24 | And avoids mapping to whitespace/control characters the bpe code barfs on.
25 | """
26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
27 | cs = bs[:]
28 | n = 0
29 | for b in range(2**8):
30 | if b not in bs:
31 | bs.append(b)
32 | cs.append(2**8+n)
33 | n += 1
34 | cs = [chr(n) for n in cs]
35 | return dict(zip(bs, cs))
36 |
37 |
38 | def get_pairs(word):
39 | """Return set of symbol pairs in a word.
40 | Word is represented as tuple of symbols (symbols being variable-length strings).
41 | """
42 | pairs = set()
43 | prev_char = word[0]
44 | for char in word[1:]:
45 | pairs.add((prev_char, char))
46 | prev_char = char
47 | return pairs
48 |
49 |
50 | def basic_clean(text):
51 | text = ftfy.fix_text(text)
52 | text = html.unescape(html.unescape(text))
53 | return text.strip()
54 |
55 |
56 | def whitespace_clean(text):
57 | text = re.sub(r'\s+', ' ', text)
58 | text = text.strip()
59 | return text
60 |
61 |
62 | class SimpleTokenizer(object):
63 | def __init__(self, bpe_path: str = default_bpe()):
64 | self.byte_encoder = bytes_to_unicode()
65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
67 | merges = merges[1:49152-256-2+1]
68 | merges = [tuple(merge.split()) for merge in merges]
69 | vocab = list(bytes_to_unicode().values())
70 | vocab = vocab + [v+'</w>' for v in vocab]
71 | for merge in merges:
72 | vocab.append(''.join(merge))
73 | vocab.extend(['<|startoftext|>', '<|endoftext|>'])
74 | self.encoder = dict(zip(vocab, range(len(vocab))))
75 | self.decoder = {v: k for k, v in self.encoder.items()}
76 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
79 |
80 | def bpe(self, token):
81 | if token in self.cache:
82 | return self.cache[token]
83 | word = tuple(token[:-1]) + ( token[-1] + '</w>',)
84 | pairs = get_pairs(word)
85 |
86 | if not pairs:
87 | return token+'</w>'
88 |
89 | while True:
90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
91 | if bigram not in self.bpe_ranks:
92 | break
93 | first, second = bigram
94 | new_word = []
95 | i = 0
96 | while i < len(word):
97 | try:
98 | j = word.index(first, i)
99 | new_word.extend(word[i:j])
100 | i = j
101 | except:
102 | new_word.extend(word[i:])
103 | break
104 |
105 | if word[i] == first and i < len(word)-1 and word[i+1] == second:
106 | new_word.append(first+second)
107 | i += 2
108 | else:
109 | new_word.append(word[i])
110 | i += 1
111 | new_word = tuple(new_word)
112 | word = new_word
113 | if len(word) == 1:
114 | break
115 | else:
116 | pairs = get_pairs(word)
117 | word = ' '.join(word)
118 | self.cache[token] = word
119 | return word
120 |
121 | def encode(self, text):
122 | bpe_tokens = []
123 | text = whitespace_clean(basic_clean(text)).lower()
124 | for token in re.findall(self.pat, text):
125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
127 | return bpe_tokens
128 |
129 | def decode(self, tokens):
130 | text = ''.join([self.decoder[token] for token in tokens])
131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
132 | return text
133 |
--------------------------------------------------------------------------------
/utils/vis.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import random
3 | import torch
4 | from torchvision import transforms
5 | from PIL import ImageFilter
6 | from sklearn.manifold import TSNE
7 | import seaborn as sns
8 | import matplotlib.pyplot as plt
9 | import cv2, pdb, os
10 | from PIL import Image
11 | from torch.nn import functional as F
12 | import pandas as pd
13 |
14 | IMAGENET_MEAN = [0.485, 0.456, 0.406]
15 | IMAGENET_STD = [0.229, 0.224, 0.225]
16 |
17 | resize_lib = {0: torch.nn.AdaptiveAvgPool2d((64, 64)),
18 | 1: torch.nn.AdaptiveAvgPool2d((32, 32)) ,
19 | 2: torch.nn.AdaptiveAvgPool2d((16, 16))} # note: other downsampling strategies could be tried here
20 |
21 | def vis_anomaly_images(imgs, mask_ori, obj):
22 | image_mean = np.array(IMAGENET_MEAN).reshape(1,1,3)
23 | image_std = np.array(IMAGENET_STD).reshape(1,1,3)
24 | B, C, H, W = imgs.shape
25 | os.makedirs(f'imgs/{obj}', exist_ok=True)
26 | masks_list = [resize_lib[i](mask_ori) for i in range(3)]
27 | for i in range(B):
28 | img = imgs[i].permute(1,2,0).cpu().numpy() # [H,W,3]
29 | img = ((img * image_std + image_mean)*255).astype(np.uint8)
30 | img = Image.fromarray(img)
31 | img.save(f'imgs/{obj}/img{i}.jpg')
32 |
33 | # pdb.set_trace()
34 | for level in range(3):
35 | mask = (masks_list[level][i][0]>0.3).float().cpu().numpy() # [H,W]
36 | mean = mask.mean()
37 | mask = (mask * 255).astype(np.uint8)
38 | mask = Image.fromarray(mask)
39 | mask.save(f'imgs/{obj}/mask{i}_{level}.jpg')
40 |
41 | mask_ = (mask_ori[i].permute(1,2,0).cpu().numpy().squeeze() * 255).astype(np.uint8)
42 | mask_ = Image.fromarray(mask_)
43 | mask_.save(f'imgs/{obj}/mask{i}.jpg')
44 |
45 | def vis_gt(gt):
46 | gt = Image.fromarray(gt.astype(np.uint8) * 255, mode='L')
47 | gt.save('gt.png')
48 |
49 |
50 | def vis_hotmap(fs, ss, a_map):
51 | B, C, H, W = fs.shape
52 | os.makedirs('imgs/a_map', exist_ok=True)
53 | # pdb.set_trace()
54 | fig = plt.figure()
55 | a_map = a_map.squeeze().cpu().detach()
56 | sns.heatmap(data=a_map)
57 | plt.savefig(f'imgs/a_map/a_map.png')
58 |
59 | tsne = TSNE(n_components=1)
60 | fs = fs.squeeze().permute(1,2,0).reshape(H*W, C).detach().cpu()
61 | ss = ss.squeeze().permute(1,2,0).reshape(H*W, C).detach().cpu()
62 |
63 | # L1-normalize the features (p=1)
64 | fs = F.normalize(fs, p=1)
65 | ss = F.normalize(ss, p=1)
66 | cat_tsne = tsne.fit_transform(np.vstack((fs, ss)))
67 | fs_tsne = cat_tsne[:H*W].reshape(H,W)
68 | ss_tsne = cat_tsne[H*W:].reshape(H,W)
69 |
70 | fig = plt.figure()
71 | sns.heatmap(data=fs_tsne)
72 | plt.savefig(f'imgs/a_map/fs.png')
73 |
74 | fig = plt.figure()
75 | sns.heatmap(data=ss_tsne)
76 | plt.savefig(f'imgs/a_map/ss.png')
77 | # pdb.set_trace()  # debug breakpoint, disabled
78 | plt.close('all')
79 |
80 | def vis_hotmap_single(a_map):
81 | fig = plt.figure()
82 | sns.heatmap(data=a_map, xticklabels=[], yticklabels=[], cbar=False)
83 | plt.savefig(f'atten.png')
84 | # pdb.set_trace()
85 | plt.close('all')
86 |
87 | def tsne_vis_single(memory, name):
88 | tsne = TSNE(n_components=2)
89 | data = tsne.fit_transform(memory)
90 | plt.figure()
91 | plt.scatter(data[:,0], data[:,1])
92 | # plt.legend((s1,s2),('memory','anomaly') ,loc = 'best')
93 | plt.savefig(f'{name}.png')
94 | plt.close('all')
95 |
96 | def normalize(pred, max_value=None, min_value=None):
97 | if max_value is None or min_value is None:
98 | return (pred - pred.min()) / (pred.max() - pred.min())
99 | else:
100 | return (pred - min_value) / (max_value - min_value)
101 |
102 | def apply_ad_scoremap(rgb_path, scoremap, cls, alpha=0.5):
103 | img_size = scoremap.shape[0]
104 | # pdb.set_trace()
105 | image = cv2.cvtColor(cv2.resize(cv2.imread(rgb_path), (img_size, img_size)), cv2.COLOR_BGR2RGB) # RGB
106 | mask = normalize(scoremap)
107 | # vis = apply_ad_scoremap(vis, mask)
108 | np_image = np.asarray(image, dtype=float)
109 | scoremap[scoremap>1] =1
110 | scoremap = (scoremap * 255).astype(np.uint8)
111 | scoremap = cv2.applyColorMap(scoremap, cv2.COLORMAP_JET)
112 | scoremap = cv2.cvtColor(scoremap, cv2.COLOR_BGR2RGB)
113 | scoremap = (alpha * np_image + (1 - alpha) * scoremap).astype(np.uint8)
114 |
115 | vis = cv2.cvtColor(scoremap, cv2.COLOR_RGB2BGR) # BGR
116 | save_vis = 'imgs'
117 | if not os.path.exists(save_vis):
118 | os.makedirs(save_vis)
119 | cv2.imwrite(f'{save_vis}/{cls}.png', vis)
120 |
121 |
122 | def plot_hist(normal, anomaly):
123 | width = 7
124 | fig = plt.figure(figsize=(width, width*0.7))
125 | print(f'normal: {normal.min():.4f} {normal.max():.4f} anomaly: {anomaly.min():.4f} {anomaly.max():.4f}')
126 | # plt.hist(normal,bins=50,label='normal sample',alpha=0.5)
127 | # plt.hist(anomaly,bins=50,label='anomalous sample',alpha=0.5)
128 | # normal = pd.Series(normal)  # convert the array to a pandas Series
129 | # normal.plot(kind = 'kde',label = 'normal')
130 | # anomaly = pd.Series(anomaly)  # convert the array to a pandas Series
131 | # anomaly.plot(kind = 'kde',label = 'anomaly')
132 | import seaborn as sn
133 | sn.kdeplot(normal, color="blue", shade=True, label='normal samples', bw_adjust=0.3)
134 | sn.kdeplot(anomaly, color="red", shade=True, label='anomalous samples', bw_adjust=0.3)
135 | plt.legend(fontsize=14)
136 | plt.tick_params(labelsize=13)
137 | plt.ylabel('Density',fontsize=12.5)
138 | plt.xlabel('Anomaly score',fontsize=12.5)
139 | plt.savefig(f'hist_SS.png')
140 |
141 | def plot_line(data, name):
142 | plt.figure()
143 | plt.plot(data)  # plot the data as a curve
144 | print(f'save to visual/{name}.png')
145 | plt.savefig(f'visual/{name}.png', bbox_inches='tight', dpi=1000)
--------------------------------------------------------------------------------
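
A minimal usage sketch (not part of the repository): overlaying a per-pixel anomaly score map on its source image with `apply_ad_scoremap`; the image path, class name, and random scores are illustrative assumptions, and the blend is written to `imgs/bottle.png`:

```python
# Hypothetical usage of apply_ad_scoremap defined in utils/vis.py above.
import numpy as np

scoremap = np.random.rand(256, 256)    # per-pixel anomaly scores in [0, 1]
apply_ad_scoremap('sample.png', scoremap, cls='bottle', alpha=0.5)   # 'sample.png' is an assumed local image
```
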
/datasets/MvTec3D.py:
--------------------------------------------------------------------------------
1 | from torchvision import transforms
2 | from PIL import Image
3 | import os
4 | import torch
5 | import glob
6 | from torch.utils.data import Dataset
7 | import numpy as np
8 | import pdb
9 | from torch.utils.data import Dataset
10 | from torchvision import transforms as T
11 | import imgaug.augmenters as iaa
12 | # import albumentations as A
13 | from utils.perlin import rand_perlin_2d_np
14 | import random
15 | import cv2
16 | from datasets.database import BaseAnomalyDetectionDataset, SynthesisDataset, resize_organized_pc
17 |
18 |
19 | def mvtec3d_classes():
20 | return [
21 | "bagel",
22 | "cable_gland",
23 | "carrot",
24 | "cookie",
25 | "dowel",
26 | "foam",
27 | "peach",
28 | "potato",
29 | "rope",
30 | "tire",
31 | ]
32 |
45 |
46 |
47 |
48 | class TrainDataset(BaseAnomalyDetectionDataset):
49 | def __init__(self, class_name, img_size, dataset_path='datasets/eyecandies_preprocessed'):
50 | super().__init__(split="train", class_name=class_name, img_size=img_size, dataset_path=dataset_path)
51 | self.img_paths, self.labels = self.load_dataset() # self.labels => good : 0, anomaly : 1
52 |
53 | def load_dataset(self):
54 | img_tot_paths = []
55 | tot_labels = []
56 | rgb_paths = glob.glob(os.path.join(self.img_path, 'good', 'rgb') + "/*.png")
57 | rgb_paths.sort()
58 | img_tot_paths.extend(rgb_paths)
59 | tot_labels.extend([0] * len(rgb_paths))
60 | return img_tot_paths, tot_labels
61 |
62 | def __len__(self):
63 | return len(self.img_paths)
64 |
65 | def __getitem__(self, idx):
66 | rgb_path, label = self.img_paths[idx], self.labels[idx]
67 | img = Image.open(rgb_path).convert('RGB')
68 | img = self.rgb_transform(img)
69 |
70 | return img, label
71 |
72 |
73 | class TestDataset(BaseAnomalyDetectionDataset):
74 | def __init__(self, class_name, img_size, dataset_path='datasets/eyecandies_preprocessed'):
75 | super().__init__(split="test", class_name=class_name, img_size=img_size, dataset_path=dataset_path)
76 | self.size = img_size
77 | self.gt_transform = transforms.Compose([
78 | transforms.Resize((img_size, img_size), interpolation=transforms.InterpolationMode.NEAREST),
79 | transforms.ToTensor()])
80 | self.img_paths, self.gt_paths, self.labels = self.load_dataset() # self.labels => good : 0, anomaly : 1
81 |
82 | def load_dataset(self):
83 | img_tot_paths = []
84 | gt_tot_paths = []
85 | tot_labels = []
86 | defect_types = os.listdir(self.img_path)
87 |
88 | for defect_type in defect_types:
89 | if defect_type == 'good':
90 | rgb_paths = glob.glob(os.path.join(self.img_path, defect_type, 'rgb') + "/*.png")
91 | rgb_paths.sort()
92 | img_tot_paths.extend(rgb_paths)
93 | gt_tot_paths.extend([0] * len(rgb_paths))
94 | tot_labels.extend([0] * len(rgb_paths))
95 | else:
96 | rgb_paths = glob.glob(os.path.join(self.img_path, defect_type, 'rgb') + "/*.png")
97 | gt_paths = glob.glob(os.path.join(self.img_path, defect_type, 'gt') + "/*.png")
98 | rgb_paths.sort()
99 | gt_paths.sort()
100 |
101 | img_tot_paths.extend(rgb_paths)
102 | gt_tot_paths.extend(gt_paths)
103 | tot_labels.extend([1] * len(rgb_paths))
104 |
105 | assert len(img_tot_paths) == len(gt_tot_paths), "Something wrong with test and ground truth pair!"
106 |
107 | return img_tot_paths, gt_tot_paths, tot_labels
108 |
109 | def __len__(self):
110 | return len(self.img_paths)
111 |
112 | def __getitem__(self, idx):
113 | rgb_path, gt, label = self.img_paths[idx], self.gt_paths[idx], self.labels[idx]
114 | img_original = Image.open(rgb_path).convert('RGB')
115 | img = self.rgb_transform(img_original)
116 |
117 | if gt == 0:
118 | gt = torch.zeros(
119 | [1, self.size, self.size])
120 | else:
121 | gt = Image.open(gt).convert('L')
122 | gt = self.gt_transform(gt)
123 | gt = torch.where(gt > 0.5, 1., .0)
124 |
125 | return img, gt[:1], label, rgb_path
126 |
127 |
128 | class AnomalyDataset(SynthesisDataset):
129 | def __init__(self, class_name, dataset_path, img_size, aux_path):
130 | super().__init__(class_name=class_name, img_size=img_size, dataset_path=dataset_path, aux_path=aux_path)
131 | self.img_paths, self.labels = self.load_dataset() # self.labels => good : 0, anomaly : 1
132 |
133 | def load_dataset(self):
134 | img_tot_paths = []
135 | tot_labels = []
136 | rgb_paths = glob.glob(os.path.join(self.img_path, 'good', 'rgb') + "/*.png")
137 | tiff_paths = glob.glob(os.path.join(self.img_path, 'good', 'xyz') + "/*.tiff")
138 | rgb_paths.sort()
139 | tiff_paths.sort()
140 | sample_paths = list(zip(rgb_paths, tiff_paths))
141 | img_tot_paths.extend(sample_paths)
142 | tot_labels.extend([0] * len(sample_paths))
143 | return img_tot_paths, tot_labels
144 |
145 |
146 | def __len__(self):
147 | return len(self.labels)
148 |
149 | def __getitem__(self, idx):
150 | img_path, label = self.img_paths[idx], self.labels[idx]
151 | rgb_path = img_path[0]
152 | tiff_path = img_path[1]
153 | class_name = tiff_path.split("/")[2]
154 |
155 | img = Image.open(rgb_path).convert('RGB')
156 | img = self.rgb_transform(img)
157 |
158 | img_file = rgb_path.split('/')[-1]
159 | fg_path = os.path.join(f'fg_mask/{self.dataset_name}/{self.cls}/train', img_file)
160 | fg_mask = Image.open(fg_path)
161 | fg_mask = np.asarray(fg_mask)[:, :, np.newaxis] # [H, W, 1]
162 | resized_depth_map = resize_organized_pc(fg_mask, img_size=self.size)
163 | fore_mask = resized_depth_map > 0
164 |
165 | # modify
166 | anomaly_source_idx = torch.randint(0, len(self.anomaly_source_paths), (1,)).item()
167 | augmented_image, anomaly_mask, has_anomaly = self.transform_image(img, fore_mask,
168 | self.anomaly_source_paths[anomaly_source_idx])
169 | # # downsample the mask
170 | # anomaly_mask = torch.from_numpy(anomaly_mask).unsqueeze(0)
171 |
172 | # denoising
173 | # return {"ori_img": img, "aug_img": augmented_image, "label": has_anomaly, "anomaly_mask": anomaly_mask}
174 |
175 | return {"img": augmented_image, "label": has_anomaly, "anomaly_mask": anomaly_mask, "fore_mask": fore_mask.float()}
176 |
177 |
--------------------------------------------------------------------------------
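
A minimal usage sketch (not part of the repository) that wraps the MVTec3D training split above in a `DataLoader`; the dataset path, class name, and image size are illustrative assumptions, and `rgb_transform` comes from `datasets/database.py`:

```python
# Hypothetical usage of TrainDataset defined in datasets/MvTec3D.py above.
from torch.utils.data import DataLoader

train_set = TrainDataset(class_name='bagel', img_size=256, dataset_path='./mvtec3d')
loader = DataLoader(train_set, batch_size=8, shuffle=True, num_workers=4)
imgs, labels = next(iter(loader))
print(imgs.shape, labels)   # e.g. [8, 3, 256, 256] and a tensor of zeros (all "good")
```
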
/models/recons_net.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import torch.nn.functional as F
4 | import math
5 | # from loss import resize_lib, mask_thresh
6 | from utils.vis import *
7 | import pdb
8 |
9 | def pair_cosine(a, b):
10 | cos_sim = torch.einsum('bnc,bmc->bnm', a, b) # [B, N, M]
11 | B, N, M = cos_sim.shape
12 | a_norm = torch.sqrt((a**2).sum(dim=-1)) # [B, N]
13 | a_norm = a_norm.unsqueeze(-1).expand(-1,-1,M) # [B, N, M]
14 | b_norm = torch.sqrt((b**2).sum(dim=-1)) # [B, M]
15 | b_norm = b_norm.unsqueeze(1).expand(-1,N,-1) # [B, N, M]
16 | cos_sim = cos_sim / (a_norm*b_norm)
17 | return cos_sim
18 |
19 | resize_lib = {0: torch.nn.AdaptiveAvgPool2d((64, 64)),
20 | 1: torch.nn.AdaptiveAvgPool2d((32, 32)) ,
21 | 2: torch.nn.AdaptiveAvgPool2d((16, 16))} # note: other downsampling strategies could be tried here
22 | mask_thresh = 0.3
23 |
24 | def patch_to_tensor(x):
25 | B, N, C = x.shape
26 | H, W = int(math.sqrt(N)), int(math.sqrt(N))
27 | x = x.reshape(B, H, W, C).permute(0,3,1,2) # [B, C, H, W]
28 | return x
29 |
30 | def tensor_to_patch(x):
31 | B, C, H, W = x.shape
32 | x = x.reshape(B, C, H*W).permute(0,2,1) # [B, N, C]
33 | return x
34 |
35 | class RAR_single(nn.Module):
36 | def __init__(self, d_model=[], size_list=[], num_item=50):
37 | super(RAR_single, self).__init__()
38 | self.d_model = d_model
39 | self.num_item = num_item
40 | self.pos_enc = []
41 | self.projs = []
42 | self.projs2 = []
43 | self.K_list = []
44 | hidden_dim = min(2*d_model[0], 1024)
45 | for i, f_dim in enumerate(d_model):
46 | self.projs.append(nn.Sequential(
47 | nn.Conv2d(f_dim, hidden_dim, 1),
48 | nn.ReLU(),
49 | nn.Conv2d(hidden_dim, hidden_dim, 1),
50 | nn.ReLU(),
51 | nn.Conv2d(hidden_dim, f_dim, 1))
52 | )
53 | # self.K_list.append(nn.Parameter(torch.randn(1, num_item+1, f_dim)))
54 | # self.V_list.append(nn.Parameter(torch.randn(1, 1, f_dim)))
55 | self.K_list.append(nn.Parameter(torch.randn(1, num_item*2, f_dim)))
56 | self.projs2.append(nn.Sequential(
57 | nn.Conv2d(f_dim, hidden_dim, 1),
58 | nn.ReLU(),
59 | nn.Conv2d(hidden_dim, hidden_dim, 1),
60 | nn.ReLU(),
61 | nn.Conv2d(hidden_dim, f_dim, 1)
62 | )
63 | )
64 | self.pos_enc.append(positionalencoding2d(f_dim, size_list[i], size_list[i]))
65 |
66 | self.projs = nn.ModuleList(self.projs)
67 | self.projs2 = nn.ModuleList(self.projs2)
68 |         self.K_list = nn.ParameterList(self.K_list) # without ParameterList these parameters would not be registered or updated
69 | self.sigmoid = nn.Sigmoid()
70 | self.tanh = nn.Tanh()
71 |
72 | def regular_score(self,score):
73 | score = torch.where(torch.isnan(score), torch.zeros_like(score), score)
74 | score = torch.where(torch.isinf(score), torch.zeros_like(score), score)
75 | return score
76 |
77 | def forward(
78 | self,
79 | inputs,
80 | flag,
81 | atten_mask=None,
82 | ):
83 | inputs_recons = []
84 | attention_list = []
85 | for i in range(len(self.projs)):
86 | # Query
87 | # pdb.set_trace()
88 | pos_enc = self.pos_enc[i].unsqueeze(0).expand(inputs[i].shape[0],-1,-1,-1).to(inputs[i].device) # [B, C, H, W]
89 | q = self.projs[i](inputs[i]+pos_enc) # [B, C, H, W]
90 | q = tensor_to_patch(q) # [B, N, C]
91 | # Key, Value
92 | # pdb.set_trace()
93 | k = self.K_list[i].expand(q.shape[0], -1, -1).to(q.device) # [B, M, C]
94 | # v = self.V_list[i].expand(q.shape[0], -1, -1).to(q.device) # [B, M, C]
95 | # pdb.set_trace()
96 | noise = self.projs2[i](inputs[i]) # [B, N, C]
97 | noise = torch.clamp(self.tanh(noise), min=-1, max=1)
98 | v = tensor_to_patch(noise * inputs[i]) # [B, N, C]
99 |
100 | cos_sim = pair_cosine(q, k) # [B, N, M]
101 | attention_scores = F.softmax(cos_sim / 0.2, dim=-1) # [B, N, M]
102 | attention_scores = attention_scores[:, :, self.num_item:].sum(dim=-1, keepdim=True) # [B,N,1]
103 | attention_list.append(attention_scores)
104 | # if flag:
105 | # mask = (attention_scores >= 0.2).float()
106 | # attention_scores = attention_scores * mask
107 |
108 | if atten_mask is not None:
109 | # pdb.set_trace()
110 | index = {256:0, 512:1, 1024:2}
111 | atten_mask = (resize_lib[index[self.d_model[0]]](atten_mask).reshape(attention_scores.shape[0], -1) > 0.3) # [B,N]
112 | attention_scores = atten_mask.unsqueeze(-1).float() # [B,N,1]
113 |
114 | # weighted Value
115 | # pdb.set_trace()
116 | # q_recons = torch.matmul(attention_scores, v) # [B, N, C]
117 | q_recons = attention_scores * v
118 | # q_recons = attention_scores * v * torch.exp(attention_scores-0.5)
119 | q_recons = patch_to_tensor(q_recons) # [B, C, H, W]
120 | inputs_recons.append(q_recons)
121 |
122 |
123 | return inputs_recons[0], attention_list[0]
124 |
125 |
126 |
127 | class MLP(nn.Module):
128 | def __init__(self, f_dim):
129 | super(MLP, self).__init__()
130 | self.projs = nn.Sequential(
131 | nn.Conv2d(f_dim, 2*f_dim, 1),
132 | nn.ReLU(),
133 | nn.Conv2d(2*f_dim, 2*f_dim, 1),
134 | nn.ReLU(),
135 | nn.Conv2d(2*f_dim, f_dim, 1)
136 | )
137 |
138 | def forward(
139 | self,
140 | inputs,
141 | flag,
142 | ):
143 | inputs_recons = self.projs(inputs[0])
144 |
145 | return inputs_recons, None
146 |
147 |
148 |
149 | def positionalencoding2d(D, H, W):
150 | """
151 | :param D: dimension of the model
152 | :param H: H of the positions
153 | :param W: W of the positions
154 | :return: DxHxW position matrix
155 | """
156 | if D % 4 != 0:
157 |         raise ValueError("Cannot use sin/cos positional encoding with a dimension that is not a multiple of 4 (got dim={:d})".format(D))
158 | P = torch.zeros(D, H, W)
159 | # Each dimension use half of D
160 | D = D // 2
161 | div_term = torch.exp(torch.arange(0.0, D, 2) * -(math.log(1e4) / D))
162 | pos_w = torch.arange(0.0, W).unsqueeze(1)
163 | pos_h = torch.arange(0.0, H).unsqueeze(1)
164 | P[0:D:2, :, :] = torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, H, 1)
165 | P[1:D:2, :, :] = torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, H, 1)
166 | P[D::2, :, :] = torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, W)
167 | P[D+1::2,:, :] = torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, W)
168 | return P
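
A minimal usage sketch for RAR_single above (not part of the file; the single 1024-dim feature
level, the 16x16 resolution and the batch size are example assumptions):

    import torch
    from models.recons_net import RAR_single

    # one feature level of shape [B, C, H, W], with C drawn from d_model and H = W from size_list
    rar = RAR_single(d_model=[1024], size_list=[16], num_item=50)
    feat = torch.randn(2, 1024, 16, 16)
    recons, atten = rar([feat], flag=False)  # recons: [2, 1024, 16, 16], atten: [2, 256, 1]
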
--------------------------------------------------------------------------------
/scripts/fore_extractor.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | from PIL import Image
3 | import torchvision.transforms as transforms
4 | import numpy as np
5 | import os, pdb, sys
6 | sys.path.append(os.getcwd())
7 | from tqdm import tqdm
8 | import json, glob
9 | from scipy.ndimage import gaussian_filter
10 | import argparse
11 |
12 |
13 | def extract_mvtec(data_root, save_root, obj_names):
14 | obj_list1 = ['carpet','grid','leather','tile','wood','transistor']
15 | obj_list2 = ['capsule', 'screw', 'zipper']
16 | obj_list3 = ['hazelnut', 'metal_nut', 'pill', 'toothbrush']
17 |     obj_list4 = ['cable', 'bottle']
18 |
19 | for obj in tqdm(obj_names):
20 | rgb_paths = glob.glob(os.path.join(data_root, obj, 'train', 'good') + "/*.png")
21 | rgb_paths.sort()
22 | save_dir = f'{save_root}/{obj}/train'
23 | os.makedirs(save_dir, exist_ok=True)
24 |
25 | for rgb_path in rgb_paths:
26 | img = Image.open(rgb_path).convert('RGB')
27 | image_transforms = transforms.Compose([
28 | transforms.Grayscale(1)
29 | ])
30 | img = image_transforms(img)
31 | img = np.asarray(img)
32 | img = gaussian_filter(img, sigma=8)
33 | # pdb.set_trace()
34 | if obj in obj_list1:
35 | ret, new_img = cv2.threshold(img, 0, 255, cv2.THRESH_BINARY)
36 | new_img = Image.fromarray(new_img)
37 | elif obj in obj_list2:
38 | ret, new_img = cv2.threshold(img, 150, 255, cv2.THRESH_BINARY)
39 | new_img = Image.fromarray(255-new_img)
40 | elif obj in obj_list3:
41 | ret, new_img = cv2.threshold(img, 50, 255, cv2.THRESH_BINARY)
42 | new_img = Image.fromarray(new_img)
43 | elif obj in obj_list4:
44 | H, W = img.shape
45 | new_img = np.zeros((H, W), dtype=np.uint8)
46 | center_x, center_y = W // 2, H // 2
47 | radius = int(0.45 * H)
48 | for x in range(W):
49 | for y in range(H):
50 | dist = np.sqrt((x - center_x) ** 2 + (y - center_y) ** 2)
51 | if dist <= radius:
52 | new_img[y, x] = 255
53 | new_img = Image.fromarray(new_img)
54 | file_name = rgb_path.split('/')[-1].split('.')[0]
55 | new_img.save(f"{save_dir}/{file_name}.png")
56 |
57 |
58 | def extract_visa(data_root, save_root, obj_names):
59 | obj_list1 = ['candle', 'cashew', 'chewinggum', 'fryum', 'macaroni1', 'macaroni2', 'pcb1']
60 | obj_list2 = ['pcb2','pcb3','pipe_fryum']
61 | obj_list3 = ['pcb4']
62 | obj_list4 = ['capsules']
63 |
64 | for obj in obj_names:
65 | meta_info = json.load(open(f'{data_root}/meta.json', 'r'))['train']
66 | rgb_paths = meta_info[obj]
67 | rgb_paths = [x["img_path"] for x in rgb_paths]
68 | rgb_paths.sort()
69 | save_dir = f'{save_root}/{obj}/train'
70 | os.makedirs(save_dir, exist_ok=True)
71 |
72 | for rgb_path in tqdm(rgb_paths):
73 | rgb_path = f'{data_root}/{rgb_path}'
74 | img = Image.open(rgb_path).convert('RGB')
75 | image_transforms = transforms.Compose([
76 | transforms.Grayscale(1)
77 | ])
78 | img = image_transforms(img)
79 | img = np.asarray(img)
80 | img = gaussian_filter(img, sigma=8)
81 | if obj in obj_list1:
82 | ret, new_img = cv2.threshold(img, 150, 255, cv2.THRESH_BINARY)
83 | new_img = Image.fromarray(new_img)
84 | elif obj in obj_list2:
85 | ret, new_img = cv2.threshold(img, 100, 255, cv2.THRESH_BINARY)
86 | new_img = Image.fromarray(new_img)
87 | elif obj in obj_list3:
88 | ret, new_img = cv2.threshold(img, 50, 255, cv2.THRESH_BINARY)
89 | new_img = Image.fromarray(new_img)
90 | elif obj in obj_list4:
91 | ret, new_img = cv2.threshold(img, 100, 255, cv2.THRESH_BINARY)
92 | new_img = Image.fromarray(255 - new_img)
93 | file_name = rgb_path.split('/')[-1].split('.')[0]
94 | new_img.save(f"{save_dir}/{file_name}.png")
95 |
96 |
97 | def extract_mvtec3d(data_root, save_root, obj_names):
98 | obj_list1 = ["bagel","cable_gland", "carrot", "cookie", "dowel", "peach","potato", "rope",]
99 | obj_list2 = ["foam", "tire",]
100 |
101 | for obj in tqdm(obj_names):
102 | rgb_paths = glob.glob(os.path.join(data_root, obj, 'train', 'good', 'rgb') + "/*.png")
103 | rgb_paths.sort()
104 | save_dir = f'{save_root}/{obj}/train'
105 | os.makedirs(save_dir, exist_ok=True)
106 |
107 | for rgb_path in rgb_paths:
108 | img = Image.open(rgb_path).convert('RGB')
109 | image_transforms = transforms.Compose([
110 | transforms.Grayscale(1)
111 | ])
112 | img = image_transforms(img)
113 | img = np.asarray(img)
114 | img = gaussian_filter(img, sigma=8)
115 | # pdb.set_trace()
116 | if obj in obj_list1:
117 | ret, new_img = cv2.threshold(img, 50, 255, cv2.THRESH_BINARY)
118 | new_img = Image.fromarray(new_img)
119 | elif obj in obj_list2:
120 | H, W = img.shape
121 | new_img = np.zeros((H, W), dtype=np.uint8)
122 | center_x, center_y = W // 2, H // 2
123 | half_width = int(0.35 * W)
124 | half_height = int(0.35 * H)
125 | for x in range(W):
126 | for y in range(H):
127 | if (center_x - half_width) <= x <= (center_x + half_width) and (center_y - half_height) <= y <= (center_y + half_height):
128 | new_img[y, x] = 255
129 | new_img = Image.fromarray(new_img)
130 |
131 | file_name = rgb_path.split('/')[-1].split('.')[0]
132 | new_img.save(f"{save_dir}/{file_name}.png")
133 | # pdb.set_trace()
134 |
135 |
136 |
137 | if __name__ == '__main__':
138 | parser = argparse.ArgumentParser()
139 | parser.add_argument('--data_path', type=str, default='/data4/tch/AD_data/mvtec')
140 | args = parser.parse_args()
141 |
142 | if 'mvtec' in args.data_path.lower() and 'mvtec3d' not in args.data_path.lower():
143 | from datasets.MvTec import mvtec_classes
144 | obj_names = mvtec_classes()
145 | save_root = 'fg_mask/mvtec'
146 | extract_mvtec(args.data_path, save_root, obj_names)
147 |
148 | elif 'visa' in args.data_path.lower():
149 | from datasets.VisA import visa_classes
150 | obj_names = visa_classes()
151 | save_root = 'fg_mask/VisA'
152 | extract_visa(args.data_path, save_root, obj_names)
153 |
154 | elif 'mvtec3d' in args.data_path.lower():
155 | from datasets.MvTec3D import mvtec3d_classes
156 | obj_names = mvtec3d_classes()
157 | save_root = 'fg_mask/mvtec3d'
158 | extract_mvtec3d(args.data_path, save_root, obj_names)
159 |
160 | else:
161 | print('no such dataset')
162 |
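
As a rough usage sketch (the dataset root below is a placeholder, not a path shipped with the
repo), the MVTec foreground masks can be produced either via the argparse entry point
(python scripts/fore_extractor.py --data_path /path/to/mvtec) or programmatically:

    from scripts.fore_extractor import extract_mvtec
    from datasets.MvTec import mvtec_classes

    extract_mvtec('/path/to/mvtec', 'fg_mask/mvtec', mvtec_classes())
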
--------------------------------------------------------------------------------
/datasets/MvTec.py:
--------------------------------------------------------------------------------
1 | from torchvision import transforms
2 | from PIL import Image
3 | import os
4 | import torch
5 | import glob
6 | from torch.utils.data import Dataset
7 | import numpy as np
8 | import pdb
9 | import os, sys
10 | sys.path.append(os.getcwd())
11 | from torch.utils.data import Dataset
12 | from torchvision import transforms as T
13 | import imgaug.augmenters as iaa
14 | # import albumentations as A
15 | from utils.perlin import rand_perlin_2d_np
16 | from utils.utils import get_dataset_name
17 | import random
18 | from utils.vis import vis_anomaly_images
19 | import cv2
20 | from datasets.database import BaseAnomalyDetectionDataset, SynthesisDataset, resize_organized_pc
21 |
22 |
23 | def mvtec_classes():
24 | return [
25 | 'carpet',
26 | 'grid',
27 | 'leather',
28 | 'tile',
29 | 'wood',
30 | 'bottle',
31 | 'cable',
32 | 'capsule',
33 | 'hazelnut',
34 | 'metal_nut',
35 | 'pill',
36 | 'screw',
37 | 'toothbrush',
38 | 'transistor',
39 | 'zipper'
40 | ]
41 |
42 |
43 |
44 |
45 | class TrainDataset(BaseAnomalyDetectionDataset):
46 | def __init__(self, class_name, img_size, dataset_path='datasets/eyecandies_preprocessed'):
47 | super().__init__(split="train", class_name=class_name, img_size=img_size, dataset_path=dataset_path)
48 | self.img_paths, self.labels = self.load_dataset() # self.labels => good : 0, anomaly : 1
49 |
50 | def load_dataset(self):
51 | img_tot_paths = []
52 | tot_labels = []
53 | rgb_paths = glob.glob(os.path.join(self.img_path, 'good') + "/*.png")
54 | rgb_paths.sort()
55 | img_tot_paths.extend(rgb_paths)
56 | tot_labels.extend([0] * len(rgb_paths))
57 | return img_tot_paths, tot_labels
58 |
59 | def __len__(self):
60 | return len(self.img_paths)
61 |
62 | def __getitem__(self, idx):
63 | rgb_path, label = self.img_paths[idx], self.labels[idx]
64 | img = Image.open(rgb_path).convert('RGB')
65 | img = self.rgb_transform(img)
66 |
67 | return img, label
68 |
69 |
70 | class TestDataset(BaseAnomalyDetectionDataset):
71 | def __init__(self, class_name, img_size, dataset_path='datasets/eyecandies_preprocessed'):
72 | super().__init__(split="test", class_name=class_name, img_size=img_size, dataset_path=dataset_path)
73 | self.size = img_size
74 | self.gt_transform = transforms.Compose([
75 | transforms.Resize((img_size, img_size), interpolation=transforms.InterpolationMode.NEAREST),
76 | transforms.ToTensor()])
77 | self.img_paths, self.gt_paths, self.labels = self.load_dataset() # self.labels => good : 0, anomaly : 1
78 |
79 | def load_dataset(self):
80 | img_tot_paths = []
81 | gt_tot_paths = []
82 | tot_labels = []
83 | defect_types = os.listdir(self.img_path)
84 |
85 | for defect_type in defect_types:
86 | if defect_type == 'good':
87 | rgb_paths = glob.glob(os.path.join(self.img_path, defect_type) + "/*.png")
88 | rgb_paths.sort()
89 | img_tot_paths.extend(rgb_paths)
90 | gt_tot_paths.extend([0] * len(rgb_paths))
91 | tot_labels.extend([0] * len(rgb_paths))
92 | else:
93 | rgb_paths = glob.glob(os.path.join(self.img_path, defect_type) + "/*.png")
94 | gt_paths = glob.glob(os.path.join(self.gt_path, defect_type) + "/*.png")
95 | rgb_paths.sort()
96 | gt_paths.sort()
97 |
98 | img_tot_paths.extend(rgb_paths)
99 | gt_tot_paths.extend(gt_paths)
100 | tot_labels.extend([1] * len(rgb_paths))
101 |
102 | assert len(img_tot_paths) == len(gt_tot_paths), "Something wrong with test and ground truth pair!"
103 |
104 | return img_tot_paths, gt_tot_paths, tot_labels
105 |
106 | def __len__(self):
107 | return len(self.img_paths)
108 |
109 | def __getitem__(self, idx):
110 | rgb_path, gt, label = self.img_paths[idx], self.gt_paths[idx], self.labels[idx]
111 | # pdb.set_trace()
112 | img_original = Image.open(rgb_path).convert('RGB')
113 | img = self.rgb_transform(img_original)
114 |
115 | if gt == 0:
116 | gt = torch.zeros(
117 | [1, self.size, self.size])
118 | else:
119 | gt = Image.open(gt).convert('L')
120 | gt = self.gt_transform(gt)
121 | gt = torch.where(gt > 0.5, 1., .0)
122 |
123 | return img, gt[:1], label, rgb_path
124 |
125 |
126 | class AnomalyDataset(SynthesisDataset):
127 | def __init__(self, class_name, dataset_path, img_size, aux_path):
128 | super().__init__(class_name=class_name, img_size=img_size, dataset_path=dataset_path, aux_path=aux_path)
129 | self.img_paths, self.labels = self.load_dataset() # self.labels => good : 0, anomaly : 1
130 |
131 | def load_dataset(self):
132 | img_tot_paths = []
133 | tot_labels = []
134 | rgb_paths = glob.glob(os.path.join(self.img_path, 'good') + "/*.png")
135 | # tiff_paths = glob.glob(os.path.join(self.img_path, 'good', 'xyz') + "/*.tiff")
136 | rgb_paths.sort()
137 | # tiff_paths.sort()
138 | # sample_paths = list(zip(rgb_paths, tiff_paths))
139 | # img_tot_paths.extend(sample_paths)
140 | # tot_labels.extend([0] * len(sample_paths))
141 | img_tot_paths.extend(rgb_paths)
142 | tot_labels.extend([0] * len(rgb_paths))
143 | return img_tot_paths, tot_labels
144 |
145 | def __len__(self):
146 | return len(self.labels)
147 |
148 | def __getitem__(self, idx):
149 | img_path, label = self.img_paths[idx], self.labels[idx]
150 | rgb_path = img_path
151 | # tiff_path = img_path[1]
152 | class_name = img_path.split("/")[2]
153 |
154 | img = Image.open(rgb_path).convert('RGB')
155 | img = self.rgb_transform(img)
156 |
157 | img_file = rgb_path.split('/')[-1]
158 | fg_path = os.path.join(f'fg_mask/{self.dataset_name}/{self.cls}/train', img_file)
159 | fg_mask = Image.open(fg_path)
160 | fg_mask = np.asarray(fg_mask)[:, :, np.newaxis] # [H, W, 1]
161 | resized_depth_map = resize_organized_pc(fg_mask, img_size=self.size)
162 | fore_mask = resized_depth_map > 0
163 |
164 | # # pdb.set_trace()
165 | # vis_anomaly_images(img.unsqueeze(0), fore_mask.unsqueeze(0).float(), class_name)
166 |
167 | # modify
168 | anomaly_source_idx = torch.randint(0, len(self.anomaly_source_paths), (1,)).item()
169 | augmented_image, anomaly_mask, has_anomaly = self.transform_image(img, fore_mask,
170 | self.anomaly_source_paths[anomaly_source_idx])
171 | # # downsample the mask
172 | # anomaly_mask = torch.from_numpy(anomaly_mask).unsqueeze(0)
173 |
174 | # if has_anomaly == 1.0:
175 | # # pdb.set_trace()
176 | # vis_anomaly_images(torch.from_numpy(augmented_image)[None], torch.from_numpy(anomaly_mask)[None], class_name)
177 |
178 | return {"img": augmented_image, "label": has_anomaly, "anomaly_mask": anomaly_mask, "fore_mask": fore_mask.float()}
179 |
180 |
181 | if __name__ == '__main__':
182 | train_dat = AnomalyDataset(class_name='transistor', img_size=256, dataset_path='/data4/tch/AD_data/mvtec', aux_path='/data4/tch/AD_data/DRAEM_dtd/dtd/images')
183 |
184 | train_dat.__getitem__(0)
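
A minimal loading sketch for the datasets above (illustration only; the dataset root is a
placeholder and is assumed to resolve to 'mvtec' via utils.utils.get_dataset_name):

    import torch
    from datasets.MvTec import TrainDataset

    train_set = TrainDataset(class_name='bottle', img_size=256, dataset_path='/path/to/mvtec')
    loader = torch.utils.data.DataLoader(train_set, batch_size=8, shuffle=True)
    img, label = train_set[0]  # img: [3, 256, 256] ImageNet-normalized tensor, label: 0 (good)
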
--------------------------------------------------------------------------------
/models/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import functional as F
3 | import torch.nn as nn
4 | from models.recons_net import pair_cosine
5 | import pdb
6 | import numpy as np
7 | from utils.vis import vis_hotmap_single
8 |
9 | resize_lib = {0: torch.nn.AdaptiveAvgPool2d((64, 64)),
10 | 1: torch.nn.AdaptiveAvgPool2d((32, 32)) ,
11 |               2: torch.nn.AdaptiveAvgPool2d((16, 16))} # note: other downsampling schemes could be tried
12 | mask_thresh = 0.3
13 |
14 | def get_recons_loss(a, b):
15 | # mse_loss = torch.nn.MSELoss()
16 | # mse_loss = torch.nn.MSELoss(reduction='none')
17 | cos_loss = torch.nn.CosineSimilarity()
18 | loss = 0
19 | # pdb.set_trace()
20 | for item in range(len(a)):
21 | # loss += torch.sqrt(mse_loss(a[item], b[item]).sum(dim=1)).mean() # [B,C,H,W]->[B,H,W]
22 | # loss += torch.mean(1-cos_loss(a[item].reshape(a[item].shape[0],-1),
23 | # b[item].reshape(b[item].shape[0],-1)))
24 | loss += torch.mean(1-cos_loss(a[item],b[item]))
25 | return loss
26 |
27 | def get_orth_loss(a, b):
28 | #mse_loss = torch.nn.MSELoss()
29 | cos_loss = torch.nn.CosineSimilarity()
30 | loss = 0
31 | for item in range(len(a)):
32 | #print(a[item].shape)
33 | #print(b[item].shape)
34 | #loss += 0.1*mse_loss(a[item], b[item])
35 |         # loss += torch.mean(cos_loss(a[item], b[item])) # dim=-1 # why does this make the gradients so large?
36 | # a: [1, 50, C]
37 | # loss += torch.mean( F.relu(cos_loss(a[item].view(a[item].shape[0],-1),
38 |         #                     b[item].view(b[item].shape[0],-1)) ) ) # why did it converge exactly to 0? it had been written wrongly as 1-
39 |
40 | loss += torch.mean( pair_cosine(a[item], b[item]) )
41 | return loss
42 |
43 |
44 | def loss_fucntion(a, b):
45 | #mse_loss = torch.nn.MSELoss()
46 | cos_loss = torch.nn.CosineSimilarity()
47 | loss = 0
48 | for item in range(len(a)):
49 | #print(a[item].shape)
50 | #print(b[item].shape)
51 | #loss += 0.1*mse_loss(a[item], b[item])
52 | # pdb.set_trace()
53 | B, C, H, W = a[item].shape
54 | f_t = a[item].reshape(B,C,H*W)
55 | f_s = b[item].reshape(B,C,H*W)
56 |         # weight = 1 / cos_loss(f_t, f_s).detach() # [B,C,H*W] -> [B,H*W] higher similarity -> smaller weight
57 |         # pdb.set_trace()
58 |         loss_i = 1 - cos_loss(f_t, f_s) # [B,C,H*W] -> [B,H*W], higher similarity -> smaller loss
59 |         # loss += torch.mean(loss_i.topk(k=int(0.1*H*W), dim=1)[0]) # keep the locations with the largest loss
60 |         loss += torch.mean(loss_i.topk(k=10, dim=1)[0]) # keep the 10 locations with the largest loss
61 | loss += torch.mean(1-cos_loss(a[item].reshape(a[item].shape[0],-1),
62 | b[item].reshape(b[item].shape[0],-1)))
63 | # loss += torch.mean(weight * loss_i) / 3
64 | # loss += torch.mean(loss_i)
65 |
66 | return loss
67 |
68 |
69 | def entropy(p):
70 | # p: (N*M)
71 | logits = -p * torch.log(p + 0.0001)
72 | entropy = torch.sum(logits, dim=-1) # (N)
73 | return torch.mean(entropy)
74 |
75 | def get_entropy_loss(attention_list, mask_ori):
76 | # attention_list[[B,N,M],[]], mask_ori[B,1,H,W]
77 | entropy_loss = 0.0
78 | num_item = attention_list[0].shape[-1]
79 | for i, atten in enumerate(attention_list):
80 | mask = resize_lib[i](mask_ori).reshape(-1) > 0.3 # [B*N]
81 | atten = atten.reshape(-1, num_item) # [B*N, M]
82 | # pdb.set_trace()
83 | if mask.sum() != 0:
84 |             entropy_loss += entropy(atten[mask]) # anomalous (masked) positions only; num_item could be increased
85 | return entropy_loss
86 |
87 |
88 | def get_focal_loss(pred_list, mask_ori, vit=False):
89 | # bce_loss = nn.BCELoss()
90 | bce_loss = BCEFocalLoss()
91 | total_loss = 0.0
92 | num_item = pred_list[0].shape[-1]
93 | P_sum, R_sum = 0, 0
94 | for i, pred in enumerate(pred_list):
95 | mask = (resize_lib[i](mask_ori).reshape(-1) > 0.3).float() # [B*N]
96 | if vit:
97 | mask = (resize_lib[2](mask_ori).reshape(-1) > 0.3).float() # [B*N]
98 | pred = pred.reshape(-1)
99 | loss = bce_loss(pred, mask)
100 | pred = (pred>0.5).float()
101 | PT = (pred.bool() & mask.bool()).sum()
102 | P = PT / (pred.sum()+1e-5)
103 | R = PT / (mask.sum()+1e-5)
104 | P_sum += P
105 | R_sum += R
106 | total_loss += loss
107 | return total_loss, P_sum/3, R_sum/3
108 |
109 |
110 | def get_amplify_loss(inputs, inputs_AT, mask_ori, fore_mask_ori=None, vit=False):
111 | cos_loss = torch.nn.CosineSimilarity()
112 | normal_loss = 0
113 | anomaly_loss = 0
114 | # 64, 32, 16
115 | for i in range(len(inputs)):
116 | # pdb.set_trace()
117 | B, C, h, w = inputs[i].shape
118 | mask = resize_lib[i](mask_ori).reshape(-1) > mask_thresh # [B*h*w]
119 | if vit:
120 | mask = resize_lib[2](mask_ori).reshape(-1) > mask_thresh
121 |
122 | if fore_mask_ori is not None:
123 | fore_mask = resize_lib[i](fore_mask_ori).reshape(-1) > mask_thresh
124 | if vit:
125 | fore_mask = resize_lib[2](fore_mask_ori).reshape(-1) > mask_thresh
126 |
127 | input = inputs[i].permute(0,2,3,1).reshape(-1, C) # [B,C,h,w]->[B,h,w,C]->[B*h*w,C]
128 | input_AT = inputs_AT[i].permute(0,2,3,1).reshape(-1, C) # [B,C,h,w]->[B,h,w,C]->[B*h*w,C]
129 | # normal unchange
130 | normal_loss += torch.mean(1-cos_loss(input[~mask], input_AT[~mask]))
131 |
132 | if fore_mask_ori is not None:
133 | normal_mask = (~mask)&(fore_mask)
134 | else:
135 | normal_mask = ~mask
136 | n_idx = np.random.permutation(normal_mask.sum().item())[:5000]
137 | a_idx = np.random.permutation(mask.sum().item())[:1000]
138 | input_normal, input_AT_normal = input[normal_mask][n_idx], input_AT[normal_mask][n_idx]
139 | if mask.sum() > 0:
140 | input_anomaly, input_AT_anomaly = input[mask][a_idx], input_AT[mask][a_idx]
141 | s_anomaly = pair_cosine(input_normal.unsqueeze(0), input_anomaly.unsqueeze(0))[0]
142 | s_AT_anomaly = pair_cosine(input_normal.unsqueeze(0), input_AT_anomaly.unsqueeze(0))[0]
143 | weight = 1
144 | anomaly_loss += torch.mean(F.relu(s_AT_anomaly - (s_anomaly-0.3)) * weight)
145 | else:
146 | anomaly_loss += 0
147 |
148 | return normal_loss, anomaly_loss
149 |
150 |
151 | def get_FnormLoss(inputs, mask_ori, vit=False):
152 | fnorm_loss = 0.0
153 | afnorm_loss = 0.0
154 | count = 0
155 | for i in range(len(inputs)):
156 | B, C, h, w = inputs[i].shape
157 | mask = resize_lib[i](mask_ori).reshape(-1) > mask_thresh # [B*h*w]
158 | if vit:
159 | mask = resize_lib[2](mask_ori).reshape(-1) > mask_thresh
160 | input = inputs[i].permute(0,2,3,1).reshape(-1, C) # [B,C,h,w]->[B,h,w,C]->[B*h*w,C]
161 | fnorm_loss += (input[~mask] ** 2).mean()
162 | if mask.sum() != 0:
163 | afnorm_loss += (input[mask] ** 2).mean()
164 | # afnorm_loss += F.relu(math.sqrt(C) - temp)
165 | count += 1
166 | fnorm_loss = fnorm_loss / 3
167 | if count > 0:
168 | afnorm_loss = afnorm_loss / count
169 | return fnorm_loss, afnorm_loss
170 |
171 |
172 | class BCEFocalLoss(torch.nn.Module):
173 | """
174 |     Binary focal loss with a fixed alpha.
175 | """
176 | def __init__(self, gamma=2, alpha=0.8, reduction='elementwise_mean'):
177 | super().__init__()
178 | self.gamma = gamma
179 | self.alpha = alpha
180 | self.reduction = reduction
181 |
182 | def forward(self, pt, target):
183 | # pt = torch.sigmoid(_input)
184 | # pdb.set_trace()
185 | alpha = self.alpha
186 | loss = - alpha * (1 - pt) ** self.gamma * target * torch.log(pt) - \
187 | (1 - alpha) * pt ** self.gamma * (1 - target) * torch.log(1 - pt)
188 | if self.reduction == 'elementwise_mean':
189 | loss = torch.mean(loss)
190 | elif self.reduction == 'sum':
191 | loss = torch.sum(loss)
192 | return loss
193 |
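
A small sanity-check sketch for BCEFocalLoss above (shapes are illustrative; the loss expects
probabilities in (0, 1), e.g. sigmoid outputs, rather than raw logits):

    import torch
    from models.loss import BCEFocalLoss

    pred = torch.sigmoid(torch.randn(8, 256))    # predicted anomaly probabilities per location
    target = (torch.rand(8, 256) > 0.9).float()  # sparse binary anomaly labels
    loss = BCEFocalLoss(gamma=2, alpha=0.8)(pred, target)
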
--------------------------------------------------------------------------------
/datasets/VisA.py:
--------------------------------------------------------------------------------
1 | from torchvision import transforms
2 | from PIL import Image
3 | import os
4 | import torch
5 | import glob
6 | from torch.utils.data import Dataset
7 | import numpy as np
8 | import pdb
9 | from torch.utils.data import Dataset
10 | from torchvision import transforms as T
11 | import imgaug.augmenters as iaa
12 | # import albumentations as A
13 | from utils.perlin import rand_perlin_2d_np
14 | import random
15 | from utils.vis import vis_anomaly_images
16 | import cv2
17 | import json
18 | from datasets.database import BaseAnomalyDetectionDataset, SynthesisDataset, resize_organized_pc
19 |
20 |
21 | def visa_classes():
22 | return[
23 | 'candle',
24 | 'capsules',
25 | 'cashew',
26 | 'chewinggum',
27 | 'fryum',
28 | 'macaroni1',
29 | 'macaroni2',
30 | 'pcb1',
31 | 'pcb2',
32 | 'pcb3',
33 | 'pcb4',
34 | 'pipe_fryum'
35 | ]
36 |
37 |
38 |
39 | class TrainDataset(BaseAnomalyDetectionDataset):
40 | def __init__(self, class_name, img_size, dataset_path='datasets/eyecandies_preprocessed'):
41 | super().__init__(split="train", class_name=class_name, img_size=img_size, dataset_path=dataset_path)
42 | self.split = 'train'
43 | self.meta_info = json.load(open(f'{dataset_path}/meta.json', 'r'))[self.split]
44 | self.cls_name = class_name
45 | self.dataset_path = dataset_path
46 | self.img_paths, self.labels = self.load_dataset() # self.labels => good : 0, anomaly : 1
47 |
48 | def load_dataset(self):
49 | img_tot_paths = []
50 | tot_labels = []
51 | # rgb_paths = glob.glob(os.path.join(self.img_path, 'good') + "/*.png")
52 | # pdb.set_trace()
53 | rgb_paths = self.meta_info[self.cls_name]
54 | rgb_paths = [f'{self.dataset_path}/{x["img_path"]}' for x in rgb_paths]
55 | rgb_paths.sort()
56 | img_tot_paths.extend(rgb_paths)
57 | tot_labels.extend([0] * len(rgb_paths))
58 | return img_tot_paths, tot_labels
59 |
60 | def __len__(self):
61 | return len(self.img_paths)
62 |
63 | def __getitem__(self, idx):
64 | rgb_path, label = self.img_paths[idx], self.labels[idx]
65 | img = Image.open(rgb_path).convert('RGB')
66 | img = self.rgb_transform(img)
67 |
68 | return img, label
69 |
70 |
71 | class TestDataset(BaseAnomalyDetectionDataset):
72 | def __init__(self, class_name, img_size, dataset_path='datasets/eyecandies_preprocessed'):
73 | super().__init__(split="test", class_name=class_name, img_size=img_size, dataset_path=dataset_path)
74 | self.split = 'test'
75 | self.meta_info = json.load(open(f'{dataset_path}/meta.json', 'r'))[self.split]
76 | self.cls_name = class_name
77 | self.dataset_path = dataset_path
78 | self.size = img_size
79 | self.gt_transform = transforms.Compose([
80 | transforms.Resize((img_size, img_size), interpolation=transforms.InterpolationMode.NEAREST),
81 | transforms.ToTensor()])
82 | self.img_paths, self.gt_paths, self.labels = self.load_dataset() # self.labels => good : 0, anomaly : 1
83 |
84 | def load_dataset(self):
85 | img_tot_paths = []
86 | gt_tot_paths = []
87 | tot_labels = []
88 | # defect_types = os.listdir(self.img_path)
89 | # pdb.set_trace()
90 | rgb_paths = self.meta_info[self.cls_name]
91 | # normal samples
92 | normal_paths = [f'{self.dataset_path}/{x["img_path"]}' for x in rgb_paths if x['anomaly']==0]
93 | normal_paths.sort()
94 | img_tot_paths.extend(normal_paths)
95 | gt_tot_paths.extend([0] * len(normal_paths))
96 | tot_labels.extend([0] * len(normal_paths))
97 | # anomaly samples
98 | anomaly_paths = [f'{self.dataset_path}/{x["img_path"]}' for x in rgb_paths if x['anomaly']==1]
99 | mask_paths = [f'{self.dataset_path}/{x["mask_path"]}' for x in rgb_paths if x['anomaly']==1]
100 | anomaly_paths.sort()
101 | mask_paths.sort()
102 | img_tot_paths.extend(anomaly_paths)
103 | gt_tot_paths.extend(mask_paths)
104 | tot_labels.extend([1] * len(anomaly_paths))
105 |
106 | # for defect_type in defect_types:
107 | # if defect_type == 'good':
108 | # rgb_paths = glob.glob(os.path.join(self.img_path, defect_type) + "/*.png")
109 | # rgb_paths.sort()
110 | # img_tot_paths.extend(rgb_paths)
111 | # gt_tot_paths.extend([0] * len(rgb_paths))
112 | # tot_labels.extend([0] * len(rgb_paths))
113 | # else:
114 | # rgb_paths = glob.glob(os.path.join(self.img_path, defect_type) + "/*.png")
115 | # gt_paths = glob.glob(os.path.join(self.gt_path, defect_type) + "/*.png")
116 | # rgb_paths.sort()
117 | # gt_paths.sort()
118 |
119 | # img_tot_paths.extend(rgb_paths)
120 | # gt_tot_paths.extend(gt_paths)
121 | # tot_labels.extend([1] * len(rgb_paths))
122 |
123 | assert len(img_tot_paths) == len(gt_tot_paths), "Something wrong with test and ground truth pair!"
124 |
125 | return img_tot_paths, gt_tot_paths, tot_labels
126 |
127 | def __len__(self):
128 | return len(self.img_paths)
129 |
130 | def __getitem__(self, idx):
131 | rgb_path, gt, label = self.img_paths[idx], self.gt_paths[idx], self.labels[idx]
132 | img_original = Image.open(rgb_path).convert('RGB')
133 | img = self.rgb_transform(img_original)
134 |
135 | if gt == 0:
136 | gt = torch.zeros(
137 | [1, self.size, self.size])
138 | else:
139 | # pdb.set_trace()
140 | # gt = Image.open(gt).convert('L')
141 |             gt = np.array(Image.open(gt).convert('L')) > 0
142 | gt = Image.fromarray(gt.astype(np.uint8) * 255, mode='L')
143 | gt = self.gt_transform(gt)
144 | gt = torch.where(gt > 0.5, 1., .0)
145 |
146 | # img_file = rgb_path.split('/')[-1].split('.')[0]
147 | # # pdb.set_trace()
148 | # fg_path = os.path.join(f'fg_mask/{self.cls}/test', img_file + '.png')
149 | # if os.path.exists(fg_path):
150 | # fg_mask = Image.open(fg_path)
151 | # fg_mask = np.asarray(fg_mask)[:, :, np.newaxis] # [H, W, 1]
152 | # resized_depth_map = resize_organized_pc(fg_mask, img_size=self.size)
153 | # fore_mask = resized_depth_map > 0
154 | # else:
155 | # fore_mask = None
156 |
157 | return img, gt[:1], label, rgb_path
158 |
159 |
160 | class AnomalyDataset(SynthesisDataset):
161 | def __init__(self, class_name, dataset_path, img_size, aux_path):
162 | super().__init__(class_name=class_name, img_size=img_size, dataset_path=dataset_path, aux_path=aux_path)
163 | self.split = 'train'
164 | self.dataset_path = dataset_path
165 | self.meta_info = json.load(open(f'{dataset_path}/meta.json', 'r'))[self.split]
166 | self.cls_name = class_name
167 | self.img_paths, self.labels = self.load_dataset() # self.labels => good : 0, anomaly : 1
168 |
169 | def load_dataset(self):
170 | img_tot_paths = []
171 | tot_labels = []
172 | # rgb_paths = glob.glob(os.path.join(self.img_path, 'good') + "/*.png")
173 | # pdb.set_trace()
174 | rgb_paths = self.meta_info[self.cls_name]
175 | rgb_paths = [f'{self.dataset_path}/{x["img_path"]}' for x in rgb_paths]
176 | # tiff_paths = glob.glob(os.path.join(self.img_path, 'good', 'xyz') + "/*.tiff")
177 | rgb_paths.sort()
178 | # tiff_paths.sort()
179 | # sample_paths = list(zip(rgb_paths, tiff_paths))
180 | # img_tot_paths.extend(sample_paths)
181 | # tot_labels.extend([0] * len(sample_paths))
182 | img_tot_paths.extend(rgb_paths)
183 | tot_labels.extend([0] * len(rgb_paths))
184 | return img_tot_paths, tot_labels
185 |
186 |
187 | def __len__(self):
188 | return len(self.labels)
189 |
190 | def __getitem__(self, idx):
191 | img_path, label = self.img_paths[idx], self.labels[idx]
192 | rgb_path = img_path
193 | # tiff_path = img_path[1]
194 | class_name = img_path.split("/")[2]
195 |
196 | img = Image.open(rgb_path).convert('RGB')
197 | img = self.rgb_transform(img)
198 |
199 | img_file = rgb_path.split('/')[-1].split('.')[0]
200 | # pdb.set_trace()
201 | fg_path = os.path.join(f'fg_mask/{self.dataset_name}/{self.cls}/train', img_file + '.png')
202 | if os.path.exists(fg_path):
203 | fg_mask = Image.open(fg_path)
204 | fg_mask = np.asarray(fg_mask)[:, :, np.newaxis] # [H, W, 1]
205 | resized_depth_map = resize_organized_pc(fg_mask, img_size=self.size)
206 | fore_mask = resized_depth_map > 0
207 | else:
208 | fore_mask = None
209 |
210 |
211 | # pdb.set_trace()
212 | # vis_anomaly_images(img.unsqueeze(0), fore_mask.unsqueeze(0).float(), class_name)
213 |
214 | # modify
215 | anomaly_source_idx = torch.randint(0, len(self.anomaly_source_paths), (1,)).item()
216 | augmented_image, anomaly_mask, has_anomaly = self.transform_image(img, fore_mask,
217 | self.anomaly_source_paths[anomaly_source_idx])
218 | # # downsample the mask
219 | # anomaly_mask = torch.from_numpy(anomaly_mask).unsqueeze(0)
220 |
221 | # if has_anomaly == 1.0:
222 | # pdb.set_trace()
223 | # vis_anomaly_images(torch.from_numpy(augmented_image)[None], torch.from_numpy(anomaly_mask)[None], class_name)
224 |
225 | return {"img": augmented_image, "label": has_anomaly, "anomaly_mask": anomaly_mask, "fore_mask": fore_mask.float()}
226 |
--------------------------------------------------------------------------------
/datasets/database.py:
--------------------------------------------------------------------------------
1 | from torchvision import transforms
2 | from PIL import Image
3 | import os
4 | import torch
5 | import glob
6 | from torch.utils.data import Dataset
7 | import numpy as np
8 | import pdb
9 | from torch.utils.data import Dataset
10 | from torchvision import transforms as T
11 | import imgaug.augmenters as iaa
12 | # import albumentations as A
13 | from utils.perlin import rand_perlin_2d_np
14 | from utils.utils import get_dataset_name
15 | import random
16 | from utils.vis import vis_anomaly_images
17 | import cv2
18 |
19 | def resize_organized_pc(organized_pc, img_size=256, tensor_out=True):
20 | torch_organized_pc = torch.tensor(organized_pc).permute(2, 0, 1).unsqueeze(dim=0).contiguous()
21 | torch_resized_organized_pc = torch.nn.functional.interpolate(torch_organized_pc, size=img_size,
22 | mode='nearest')
23 | if tensor_out:
24 | return torch_resized_organized_pc.squeeze(dim=0).contiguous()
25 | else:
26 | return torch_resized_organized_pc.squeeze().permute(1, 2, 0).contiguous().numpy()
27 |
28 | class BaseAnomalyDetectionDataset(Dataset):
29 | def __init__(self, split, class_name, img_size, dataset_path='datasets/eyecandies_preprocessed'):
30 | self.IMAGENET_MEAN = [0.485, 0.456, 0.406]
31 | self.IMAGENET_STD = [0.229, 0.224, 0.225]
32 | self.cls = class_name
33 | self.size = img_size
34 | self.dataset_name = get_dataset_name(dataset_path)
35 | if self.dataset_name == 'mvtec':
36 | self.img_path = os.path.join(dataset_path, self.cls, split)
37 | self.gt_path = os.path.join(dataset_path, self.cls, 'ground_truth')
38 | elif self.dataset_name == 'mvtec3d':
39 | self.img_path = os.path.join(dataset_path, self.cls, split)
40 | self.rgb_transform = transforms.Compose(
41 | [transforms.Resize((self.size, self.size), interpolation=transforms.InterpolationMode.BICUBIC),
42 | transforms.ToTensor(),
43 | transforms.Normalize(mean=self.IMAGENET_MEAN, std=self.IMAGENET_STD)])
44 |
45 |
46 | class SynthesisDataset(BaseAnomalyDetectionDataset):
47 | def __init__(self, class_name, dataset_path, img_size, aux_path):
48 | super().__init__(split="train", class_name=class_name, img_size=img_size, dataset_path=dataset_path)
49 | self.anomaly_source_paths = sorted(glob.glob(aux_path + "/*/*.jpg"))
50 |
51 | self.augmenters = [iaa.GammaContrast((0.5,2.0),per_channel=True),
52 | iaa.MultiplyAndAddToBrightness(mul=(0.8,1.2),add=(-30,30)),
53 | iaa.pillike.EnhanceSharpness(),
54 | iaa.AddToHueAndSaturation((-50,50),per_channel=True),
55 | iaa.Solarize(0.5, threshold=(32,128)),
56 | iaa.Posterize(),
57 | iaa.Invert(),
58 | iaa.pillike.Autocontrast(),
59 | iaa.pillike.Equalize(),
60 | iaa.Affine(rotate=(-45, 45))
61 | ]
62 | # self.noise_transform = transforms.Compose(
63 | # [transforms.Normalize(mean=self.IMAGENET_MEAN, std=self.IMAGENET_STD)])
64 |
65 | self.rot = iaa.Sequential([iaa.Affine(rotate=(-90, 90))])
66 |
67 | self.resize64 = torch.nn.AdaptiveAvgPool2d((64, 64))
68 | self.resize32 = torch.nn.AdaptiveAvgPool2d((32, 32))
69 | self.resize16 = torch.nn.AdaptiveAvgPool2d((16, 16))
70 |
71 |     # randomly select 3 of the augmentation methods
72 | def randAugmenter(self):
73 | aug_ind = np.random.choice(np.arange(len(self.augmenters)), 3, replace=False)
74 | aug = iaa.Sequential([self.augmenters[aug_ind[0]],
75 | self.augmenters[aug_ind[1]],
76 | self.augmenters[aug_ind[2]]]
77 | )
78 | return aug
79 |
80 | def augment_image(self, image, anomaly_source_path, fore_mask):
81 | aug = self.randAugmenter()
82 | perlin_scale = 6
83 | min_perlin_scale = 0
84 |
85 | anomaly_source_img = cv2.imread(anomaly_source_path)
86 | anomaly_source_img = cv2.resize(anomaly_source_img, dsize=(self.size, self.size))
87 |
88 | anomaly_img_augmented = aug(image=anomaly_source_img)
89 | perlin_scalex = 2 ** (torch.randint(min_perlin_scale, perlin_scale, (1,)).numpy()[0])
90 | perlin_scaley = 2 ** (torch.randint(min_perlin_scale, perlin_scale, (1,)).numpy()[0])
91 |
92 | if fore_mask is not None:
93 | count = 0
94 | while True:
95 | count += 1
96 | perlin_noise = rand_perlin_2d_np((self.size, self.size), (perlin_scalex, perlin_scaley))
97 | perlin_noise = self.rot(image=perlin_noise)
98 | threshold = 0.5
99 | # modify
100 | perlin_thr = np.where(perlin_noise > threshold, np.ones_like(perlin_noise), np.zeros_like(perlin_noise))
101 | perlin_thr = np.expand_dims(perlin_thr, axis=2)
102 | perlin_thr = perlin_thr * fore_mask
103 | # pdb.set_trace()
104 | if perlin_thr.sum() > 4 or count > 10:
105 | break
106 | else:
107 | perlin_noise = rand_perlin_2d_np((self.size, self.size), (perlin_scalex, perlin_scaley))
108 | perlin_noise = self.rot(image=perlin_noise)
109 | threshold = 0.5
110 | # modify
111 | perlin_thr = np.where(perlin_noise > threshold, np.ones_like(perlin_noise), np.zeros_like(perlin_noise))
112 | perlin_thr = np.expand_dims(perlin_thr, axis=2)
113 | #
114 |         # modify: replace '/255' with ImageNet normalization
115 |         # check the maximum of img_thr against the minimum of image
116 | # img_thr = anomaly_img_augmented.astype(np.float32) * perlin_thr / 255.0
117 | img_thr = anomaly_img_augmented.astype(np.float32) * perlin_thr / 255.0
118 | image_mean = np.array(self.IMAGENET_MEAN).reshape(1,1,3)
119 | image_std = np.array(self.IMAGENET_STD).reshape(1,1,3)
120 | img_thr = (img_thr - image_mean) / image_std
121 | #
122 |
123 | beta = torch.rand(1).numpy()[0] * 0.8
124 | augmented_image = image * (1 - perlin_thr) + (1 - beta) * img_thr + beta * image * (
125 | perlin_thr)
126 |
127 | augmented_image = augmented_image.astype(np.float32)
128 | msk = (perlin_thr).astype(np.float32)
129 | augmented_image = msk * augmented_image + (1-msk)*image # This line is unnecessary and can be deleted
130 | has_anomaly = 1.0
131 |
132 | no_anomaly = torch.rand(1).numpy()[0]
133 | if no_anomaly > 0.8:
134 | image = image.astype(np.float32)
135 | return image, np.zeros_like(perlin_thr, dtype=np.float32), np.array([0.0],dtype=np.float32)
136 |         else: # synthesize an anomaly with probability 0.8
137 | augmented_image = augmented_image.astype(np.float32)
138 | msk = (perlin_thr).astype(np.float32)
139 | augmented_image = msk * augmented_image + (1-msk)*image # This line is unnecessary and can be deleted
140 | has_anomaly = 1.0
141 | if np.sum(msk) == 0:
142 | has_anomaly=0.0
143 | return augmented_image, msk, np.array([has_anomaly],dtype=np.float32)
144 |
145 |
146 | def transform_image(self, image, fore_mask, anomaly_source_path):
147 | image = image.permute(1,2,0).numpy()
148 | if fore_mask is not None:
149 | fore_mask = fore_mask.permute(1,2,0).numpy()
150 |
151 | # normalize the image to 0.0~1.0
152 | augmented_image, anomaly_mask, has_anomaly = self.augment_image(image, anomaly_source_path, fore_mask)
153 | augmented_image = np.transpose(augmented_image, (2, 0, 1))
154 | image = np.transpose(image, (2, 0, 1))
155 | anomaly_mask = np.transpose(anomaly_mask, (2, 0, 1))
156 | return augmented_image, anomaly_mask, has_anomaly
157 |
158 |
159 |
160 | class CutPasteDataSet(BaseAnomalyDetectionDataset):
161 | def __init__(self, class_name, img_size, dataset_path='datasets/eyecandies_preprocessed'):
162 | super().__init__(split="train", class_name=class_name, img_size=img_size, dataset_path=dataset_path)
163 | self.img_paths, self.labels = self.load_dataset() # self.labels => good : 0, anomaly : 1
164 | self.cutpaste_transform = CutPaste_fg(type='3way')
165 | self.gt_transform = transforms.Compose([
166 | transforms.Resize((img_size, img_size), interpolation=transforms.InterpolationMode.NEAREST),
167 | transforms.ToTensor()])
168 |
169 | def load_dataset(self):
170 | img_tot_paths = []
171 | tot_labels = []
172 | rgb_paths = glob.glob(os.path.join(self.img_path, 'good', 'rgb') + "/*.png")
173 | tiff_paths = glob.glob(os.path.join(self.img_path, 'good', 'xyz') + "/*.tiff")
174 | rgb_paths.sort()
175 | tiff_paths.sort()
176 | sample_paths = list(zip(rgb_paths, tiff_paths))
177 | img_tot_paths.extend(sample_paths)
178 | tot_labels.extend([0] * len(sample_paths))
179 | return img_tot_paths, tot_labels
180 |
181 | def __len__(self):
182 | return len(self.img_paths)
183 |
184 | def __getitem__(self, idx):
185 | img_path, label = self.img_paths[idx], self.labels[idx]
186 | rgb_path = img_path[0]
187 | tiff_path = img_path[1]
188 | img = Image.open(rgb_path).convert('RGB')
189 | # img.save("imgs/ori_img.jpg")
190 | # pdb.set_trace()
191 | organized_pc = read_tiff_organized_pc(tiff_path)
192 | depth_map = organized_pc_to_depth_map(organized_pc)[:, :, np.newaxis]
193 | resized_depth_map = resize_organized_pc(depth_map, img_size=[img.size[1],img.size[0]]).squeeze(0).numpy()
194 | fore_mask = resized_depth_map > 0
195 |
196 | cutpaste, anomaly_mask = self.cutpaste_transform(img, fore_mask)
197 | # cutpaste = self.copy_paste(img, fore_mask)
198 | # cutpaste.save("imgs/aug_img1.jpg")
199 | # cutpaste_scar.save("imgs/aug_img2.jpg")
200 | cutpaste = Image.fromarray(cutpaste)
201 | anomaly_mask = Image.fromarray(anomaly_mask)
202 | aug_img = self.rgb_transform(cutpaste)
203 | anomaly_mask = self.gt_transform(anomaly_mask)
204 |
205 | return {"img": aug_img, "anomaly_mask": anomaly_mask, "fore_mask": fore_mask}
206 |
207 |
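
A minimal sketch of the Perlin-thresholded mask used by SynthesisDataset.augment_image above
(the 256 resolution and the (8, 8) scales are example values; the class itself draws the scales
as random powers of two between 1 and 32):

    import numpy as np
    import torch
    from utils.perlin import rand_perlin_2d_np

    size = 256
    noise = rand_perlin_2d_np((size, size), (8, 8))            # [H, W] Perlin noise
    perlin_thr = (noise > 0.5).astype(np.float32)[..., None]   # [H, W, 1] binary anomaly mask
    beta = float(torch.rand(1)) * 0.8
    # blending as in augment_image: image * (1 - mask) + (1 - beta) * texture + beta * image * mask
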
--------------------------------------------------------------------------------
/vit_version/clip/clip.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import os
3 | import urllib
4 | import warnings
5 | from typing import Any, Union, List
6 | from pkg_resources import packaging
7 |
8 | import torch
9 | from PIL import Image
10 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
11 | from tqdm import tqdm
12 | import pdb
13 |
14 | from .model import build_model
15 | # from .model_plus import build_model
16 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer
17 |
18 | try:
19 | from torchvision.transforms import InterpolationMode
20 | BICUBIC = InterpolationMode.BICUBIC
21 | except ImportError:
22 | BICUBIC = Image.BICUBIC
23 |
24 |
25 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
26 | warnings.warn("PyTorch version 1.7.1 or higher is recommended")
27 |
28 |
29 | __all__ = ["available_models", "load", "tokenize"]
30 | _tokenizer = _Tokenizer()
31 |
32 | _MODELS = {
33 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
34 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
35 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
36 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
37 | "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
38 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
39 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
40 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
41 | "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
42 | }
43 |
44 |
45 | def _download(url: str, root: str):
46 | os.makedirs(root, exist_ok=True)
47 | filename = os.path.basename(url)
48 |
49 | expected_sha256 = url.split("/")[-2]
50 | download_target = os.path.join(root, filename)
51 |
52 | if os.path.exists(download_target) and not os.path.isfile(download_target):
53 | raise RuntimeError(f"{download_target} exists and is not a regular file")
54 |
55 | if os.path.isfile(download_target):
56 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
57 | return download_target
58 | else:
59 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
60 |
61 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
62 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
63 | while True:
64 | buffer = source.read(8192)
65 | if not buffer:
66 | break
67 |
68 | output.write(buffer)
69 | loop.update(len(buffer))
70 |
71 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
72 |         raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not match")
73 |
74 | return download_target
75 |
76 |
77 | def _convert_image_to_rgb(image):
78 | return image.convert("RGB")
79 |
80 |
81 | def _transform(n_px):
82 | return Compose([
83 |         Resize((n_px, n_px), interpolation=BICUBIC), # Modify: force a square resize (the original CLIP transform uses Resize(n_px))
84 | CenterCrop((n_px, n_px)),
85 | _convert_image_to_rgb,
86 | ToTensor(),
87 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
88 | ])
89 |
90 |
91 |
92 | def available_models() -> List[str]:
93 | """Returns the names of available CLIP models"""
94 | return list(_MODELS.keys())
95 |
96 |
97 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None, image_size: int = None):
98 | """Load a CLIP model
99 |
100 | Parameters
101 | ----------
102 | name : str
103 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
104 |
105 | device : Union[str, torch.device]
106 | The device to put the loaded model
107 |
108 | jit : bool
109 | Whether to load the optimized JIT model or more hackable non-JIT model (default).
110 |
111 | download_root: str
112 | path to download the model files; by default, it uses "~/.cache/clip"
113 |
114 | Returns
115 | -------
116 | model : torch.nn.Module
117 | The CLIP model
118 |
119 | preprocess : Callable[[PIL.Image], torch.Tensor]
120 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
121 | """
122 | if name in _MODELS:
123 | model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
124 | elif os.path.isfile(name):
125 | model_path = name
126 | else:
127 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
128 |
129 | # pdb.set_trace()
130 | with open(model_path, 'rb') as opened_file:
131 | try:
132 | # loading JIT archive
133 | model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
134 | state_dict = None
135 | except RuntimeError:
136 | # loading saved state dict
137 | if jit:
138 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
139 | jit = False
140 | state_dict = torch.load(opened_file, map_location="cpu")
141 |
142 | if not jit:
143 | model = build_model(state_dict or model.state_dict(), image_size).to(device)
144 | if str(device) == "cpu":
145 | model.float()
146 | return model, _transform(image_size)
147 |
148 | # patch the device names
149 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
150 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
151 |
152 | def patch_device(module):
153 | try:
154 | graphs = [module.graph] if hasattr(module, "graph") else []
155 | except RuntimeError:
156 | graphs = []
157 |
158 | if hasattr(module, "forward1"):
159 | graphs.append(module.forward1.graph)
160 |
161 | for graph in graphs:
162 | for node in graph.findAllNodes("prim::Constant"):
163 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
164 | node.copyAttributes(device_node)
165 |
166 | model.apply(patch_device)
167 | patch_device(model.encode_image)
168 | patch_device(model.encode_text)
169 |
170 | # patch dtype to float32 on CPU
171 | if str(device) == "cpu":
172 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
173 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
174 | float_node = float_input.node()
175 |
176 | def patch_float(module):
177 | try:
178 | graphs = [module.graph] if hasattr(module, "graph") else []
179 | except RuntimeError:
180 | graphs = []
181 |
182 | if hasattr(module, "forward1"):
183 | graphs.append(module.forward1.graph)
184 |
185 | for graph in graphs:
186 | for node in graph.findAllNodes("aten::to"):
187 | inputs = list(node.inputs())
188 | for i in [1, 2]: # dtype can be the second or third argument to aten::to()
189 | if inputs[i].node()["value"] == 5:
190 | inputs[i].node().copyAttributes(float_node)
191 |
192 | model.apply(patch_float)
193 | patch_float(model.encode_image)
194 | patch_float(model.encode_text)
195 |
196 | model.float()
197 |
198 | return model, _transform(image_size)
199 |
200 |
201 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
202 | """
203 | Returns the tokenized representation of given input string(s)
204 |
205 | Parameters
206 | ----------
207 | texts : Union[str, List[str]]
208 | An input string or a list of input strings to tokenize
209 |
210 | context_length : int
211 | The context length to use; all CLIP models use 77 as the context length
212 |
213 | truncate: bool
214 | Whether to truncate the text in case its encoding is longer than the context length
215 |
216 | Returns
217 | -------
218 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
219 | We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
220 | """
221 | if isinstance(texts, str):
222 | texts = [texts]
223 |
224 | sot_token = _tokenizer.encoder["<|startoftext|>"]
225 | eot_token = _tokenizer.encoder["<|endoftext|>"]
226 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
227 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
228 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
229 | else:
230 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)
231 |
232 | for i, tokens in enumerate(all_tokens):
233 | if len(tokens) > context_length:
234 | if truncate:
235 | tokens = tokens[:context_length]
236 | tokens[-1] = eot_token
237 | else:
238 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
239 | result[i, :len(tokens)] = torch.tensor(tokens)
240 |
241 | return result
242 |
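
A usage sketch for the loader above (note that, unlike stock CLIP, image_size must be passed
explicitly here because it is forwarded to build_model and _transform; the model name and the
224 resolution are just example choices):

    import torch
    from vit_version.clip import load, tokenize

    model, preprocess = load("ViT-B/16", device="cpu", image_size=224)
    tokens = tokenize(["a photo of a normal object", "a photo of a damaged object"])
    with torch.no_grad():
        text_features = model.encode_text(tokens)
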
--------------------------------------------------------------------------------
/vit_version/clip_rar/clip.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import os
3 | import urllib
4 | import warnings
5 | from typing import Any, Union, List
6 | from pkg_resources import packaging
7 |
8 | import torch
9 | from PIL import Image
10 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
11 | from tqdm import tqdm
12 | import pdb
13 |
14 | from .model import build_model
15 | # from .model_plus import build_model
16 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer
17 |
18 | try:
19 | from torchvision.transforms import InterpolationMode
20 | BICUBIC = InterpolationMode.BICUBIC
21 | except ImportError:
22 | BICUBIC = Image.BICUBIC
23 |
24 |
25 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
26 | warnings.warn("PyTorch version 1.7.1 or higher is recommended")
27 |
28 |
29 | __all__ = ["available_models", "load", "tokenize"]
30 | _tokenizer = _Tokenizer()
31 |
32 | _MODELS = {
33 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
34 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
35 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
36 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
37 | "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
38 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
39 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
40 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
41 | "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
42 | }
43 |
44 |
45 | def _download(url: str, root: str):
46 | os.makedirs(root, exist_ok=True)
47 | filename = os.path.basename(url)
48 |
49 | expected_sha256 = url.split("/")[-2]
50 | download_target = os.path.join(root, filename)
51 |
52 | if os.path.exists(download_target) and not os.path.isfile(download_target):
53 | raise RuntimeError(f"{download_target} exists and is not a regular file")
54 |
55 | if os.path.isfile(download_target):
56 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
57 | return download_target
58 | else:
59 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
60 |
61 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
62 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
63 | while True:
64 | buffer = source.read(8192)
65 | if not buffer:
66 | break
67 |
68 | output.write(buffer)
69 | loop.update(len(buffer))
70 |
71 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
72 |         raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not match")
73 |
74 | return download_target
75 |
76 |
77 | def _convert_image_to_rgb(image):
78 | return image.convert("RGB")
79 |
80 |
81 | def _transform(n_px):
82 | return Compose([
83 |         Resize((n_px, n_px), interpolation=BICUBIC), # Modify: force a square resize (the original CLIP transform uses Resize(n_px))
84 | CenterCrop((n_px, n_px)),
85 | _convert_image_to_rgb,
86 | ToTensor(),
87 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
88 | ])
89 |
90 |
91 |
92 | def available_models() -> List[str]:
93 | """Returns the names of available CLIP models"""
94 | return list(_MODELS.keys())
95 |
96 |
97 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None, image_size: int = None):
98 | """Load a CLIP model
99 |
100 | Parameters
101 | ----------
102 | name : str
103 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
104 |
105 | device : Union[str, torch.device]
106 | The device to put the loaded model
107 |
108 | jit : bool
109 | Whether to load the optimized JIT model or more hackable non-JIT model (default).
110 |
111 | download_root: str
112 | path to download the model files; by default, it uses "~/.cache/clip"
113 |
114 | Returns
115 | -------
116 | model : torch.nn.Module
117 | The CLIP model
118 |
119 | preprocess : Callable[[PIL.Image], torch.Tensor]
120 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
121 | """
122 | if name in _MODELS:
123 | model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
124 | elif os.path.isfile(name):
125 | model_path = name
126 | else:
127 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
128 |
129 | # pdb.set_trace()
130 | with open(model_path, 'rb') as opened_file:
131 | try:
132 | # loading JIT archive
133 | model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
134 | state_dict = None
135 | except RuntimeError:
136 | # loading saved state dict
137 | if jit:
138 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
139 | jit = False
140 | state_dict = torch.load(opened_file, map_location="cpu")
141 |
142 | if not jit:
143 | model = build_model(state_dict or model.state_dict(), image_size).to(device)
144 | if str(device) == "cpu":
145 | model.float()
146 | return model, _transform(image_size)
147 |
148 | # patch the device names
149 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
150 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
151 |
152 | def patch_device(module):
153 | try:
154 | graphs = [module.graph] if hasattr(module, "graph") else []
155 | except RuntimeError:
156 | graphs = []
157 |
158 | if hasattr(module, "forward1"):
159 | graphs.append(module.forward1.graph)
160 |
161 | for graph in graphs:
162 | for node in graph.findAllNodes("prim::Constant"):
163 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
164 | node.copyAttributes(device_node)
165 |
166 | model.apply(patch_device)
167 | patch_device(model.encode_image)
168 | patch_device(model.encode_text)
169 |
170 | # patch dtype to float32 on CPU
171 | if str(device) == "cpu":
172 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
173 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
174 | float_node = float_input.node()
175 |
176 | def patch_float(module):
177 | try:
178 | graphs = [module.graph] if hasattr(module, "graph") else []
179 | except RuntimeError:
180 | graphs = []
181 |
182 | if hasattr(module, "forward1"):
183 | graphs.append(module.forward1.graph)
184 |
185 | for graph in graphs:
186 | for node in graph.findAllNodes("aten::to"):
187 | inputs = list(node.inputs())
188 | for i in [1, 2]: # dtype can be the second or third argument to aten::to()
189 | if inputs[i].node()["value"] == 5:
190 | inputs[i].node().copyAttributes(float_node)
191 |
192 | model.apply(patch_float)
193 | patch_float(model.encode_image)
194 | patch_float(model.encode_text)
195 |
196 | model.float()
197 |
198 | return model, _transform(image_size)
199 |
200 |
201 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
202 | """
203 | Returns the tokenized representation of given input string(s)
204 |
205 | Parameters
206 | ----------
207 | texts : Union[str, List[str]]
208 | An input string or a list of input strings to tokenize
209 |
210 | context_length : int
211 | The context length to use; all CLIP models use 77 as the context length
212 |
213 | truncate: bool
214 | Whether to truncate the text in case its encoding is longer than the context length
215 |
216 | Returns
217 | -------
218 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
219 | We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
220 | """
221 | if isinstance(texts, str):
222 | texts = [texts]
223 |
224 | sot_token = _tokenizer.encoder["<|startoftext|>"]
225 | eot_token = _tokenizer.encoder["<|endoftext|>"]
226 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
227 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
228 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
229 | else:
230 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)
231 |
232 | for i, tokens in enumerate(all_tokens):
233 | if len(tokens) > context_length:
234 | if truncate:
235 | tokens = tokens[:context_length]
236 | tokens[-1] = eot_token
237 | else:
238 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
239 | result[i, :len(tokens)] = torch.tensor(tokens)
240 |
241 | return result
242 |
--------------------------------------------------------------------------------
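A minimal usage sketch for the CLIP loader in vit_version/clip/clip.py above, mirroring how train_vit.py and test_vit.py call it (the image path is a placeholder, and the f_list layer indices are simply the ones used elsewhere in this repo):

    import torch
    from PIL import Image
    import vit_version.clip as clip

    # builds ViT-B/16 at a 256x256 input resolution and returns the matching preprocess transform
    model, preprocess = clip.load("ViT-B/16", device="cpu", download_root="vit_version/clip", image_size=256)
    image = preprocess(Image.open("example.png")).unsqueeze(0)  # hypothetical input image
    with torch.no_grad():
        _, features = model.encode_image(image, f_list=[4, 8, 12])  # intermediate ViT features
        text = clip.tokenize(["a photo of an object"])
        text_features = model.encode_text(text)
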
/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision.datasets import ImageFolder
3 | import numpy as np
4 | import random
5 | import time
6 | import os
7 | import torch.nn as nn
8 | from torch.utils.data import DataLoader
9 | from models.resnet import resnet18, resnet34, resnet50, wide_resnet50_2
10 | from models.resnet_rar import resnet18, resnet34, resnet50, wide_resnet50_2_rar
11 | from models.de_resnet import de_resnet18, de_resnet34, de_wide_resnet50_2, de_resnet50
12 | # from models.de_resnet_mem import de_wide_resnet50_2_mem
13 | import torch.backends.cudnn as cudnn
14 | import argparse
15 | from test import evaluation_Stage1, evaluation_Stage2
16 | from torch.nn import functional as F
17 | from models.loss import *
18 | from utils.vis import vis_anomaly_images
19 | from models.recons_net import *
20 | from utils.utils import setup_seed, get_dataset_name
21 | import pdb
22 | from utils.vis import *
23 | from tqdm import tqdm
24 | from utils.utils import PatchMaker
25 |
26 | def count_parameters(model):
27 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
28 |
29 |
30 | def train_Stage1(args, _class_, epochs=100, eval_interval=10, lr=0.0002):
31 | print("training Advanced Teacher...")
32 | print(_class_)
33 |
34 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
35 | print(device)
36 |
37 | Stage1_ckp_path = './checkpoints_Stage1/' + 'wres50_'+_class_+'.pth'
38 | os.makedirs('checkpoints_Stage1', exist_ok=True)
39 |     train_data = AnomalyDataset(class_name=_class_, img_size=args.image_size, dataset_path=args.data_root, aux_path=args.aux_path) # Note: contains both normal images and synthesized anomalies
40 | test_data = TestDataset(class_name=_class_, img_size=args.image_size, dataset_path=args.data_root)
41 | train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True, num_workers=4)
42 | test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=False)
43 |
44 | encoder, _ = wide_resnet50_2(pretrained=True)
45 | encoder = encoder.to(device)
46 | encoder.eval()
47 |     # load ImageNet-pretrained weights
48 | encoder_AT, _ = wide_resnet50_2_rar(pretrained=True)
49 | encoder_AT = encoder_AT.to(device)
50 |
51 | for name, para in encoder_AT.named_parameters():
52 | if 'feat_recons' in name:
53 | para.requires_grad = True
54 | else:
55 | para.requires_grad = False
56 | # print([name for name, para in encoder_AT.named_parameters() if para.requires_grad])
57 | # pdb.set_trace()
58 | optimizer_Stage1 = torch.optim.Adam((filter(lambda p: p.requires_grad, encoder_AT.parameters())), lr=lr, betas=(0.5,0.999))
59 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer_Stage1, [epochs*0.8, epochs*0.9], gamma=0.2, last_epoch=-1)
60 |
61 | best_auroc_sp = -1
62 | best_auroc_px = -1
63 | best_P = -1
64 | loss_per_epoch = {'fnorm_loss': [], 'afnorm_loss': [], 'test_fnorm_loss': [], 'test_afnorm_loss': []}
65 |
66 | # auroc_px, auroc_sp, aupro_px, test_fnorm_loss, test_afnorm_loss = evaluation_Stage1(encoder, encoder_AT, test_dataloader, device)
67 | # print('Pixel Auroc:{:.3f}, Sample Auroc{:.3f}, Pixel Aupro{:.3}'.format(auroc_px, auroc_sp, aupro_px))
68 | # print(f'test_fnorm_loss: {np.mean(test_fnorm_loss):.4f} test_afnorm_loss: {np.mean(test_afnorm_loss):.4f}')
69 | # pdb.set_trace()
70 | for epoch in tqdm(range(epochs)):
71 |         encoder_AT.train() # set train mode at every epoch; eval mode is restored during evaluation
72 |         # keep BatchNorm layers frozen in eval mode
73 | for name, module in encoder_AT.named_modules():
74 | if isinstance(module, nn.BatchNorm2d):
75 | module.eval()
76 | if isinstance(module, nn.BatchNorm1d):
77 | module.eval()
78 |
79 | normal_loss_list, anomaly_loss_list, fnorm_loss_list, afnorm_loss_list, focal_loss_list = [], [], [], [], []
80 | P_list, R_list = [], []
81 | for batch_data in train_dataloader:
82 | # load data
83 | img = batch_data["img"].to(device)
84 |             anomaly_mask = batch_data["anomaly_mask"].to(device) # Note: should half of the batch be normal images?
85 | fore_mask = batch_data["fore_mask"].to(device)
86 | # vis_anomaly_images(img, anomaly_mask, _class_)
87 | with torch.no_grad():
88 | inputs = encoder(img)
89 | _, inputs_AT, delta, atten = encoder_AT(img, flag=False, atten_mask=anomaly_mask)
90 | # pdb.set_trace()
91 | loss_focal, P, R = get_focal_loss(atten, anomaly_mask)
92 | loss_normal, loss_anomaly = get_amplify_loss(inputs, inputs_AT, anomaly_mask, fore_mask)
93 | loss = loss_focal*1.0 + loss_anomaly * 0.1
94 | # Res Loss: the residual of anomalies (afnorm) should increase
95 | fnorm_loss, afnorm_loss = get_FnormLoss(delta, anomaly_mask)
96 | fnorm_loss_list.append(fnorm_loss.item())
97 | if afnorm_loss != 0.0:
98 | afnorm_loss_list.append(afnorm_loss.item())
99 |
100 | optimizer_Stage1.zero_grad()
101 | loss.backward()
102 | optimizer_Stage1.step()
103 |
104 | normal_loss_list.append(loss_normal.item())
105 | loss_anomaly = loss_anomaly.item() if torch.is_tensor(loss_anomaly) else loss_anomaly
106 | anomaly_loss_list.append(loss_anomaly)
107 | focal_loss_list.append(loss_focal.item())
108 | P_list.append(P.item())
109 | R_list.append(R.item())
110 |
111 | # if np.isnan(np.mean(anomaly_loss_list)):
112 | # pdb.set_trace()
113 | scheduler.step() # modify
114 | print(f'epoch [{epoch + 1}/{epochs}], focal_loss:{np.mean(focal_loss_list):.4f}\t, anomaly_loss:{np.mean(anomaly_loss_list):.4f}\t'
115 | f'fnorm_loss:{np.mean(fnorm_loss_list):.4f}, a_fnorm_loss:{np.mean(afnorm_loss_list):.4f}'
116 | f'P:{np.mean(P_list):.2f}, R:{np.mean(R_list):.2f}')
117 | loss_per_epoch['fnorm_loss'].append(np.mean(fnorm_loss_list))
118 | loss_per_epoch['afnorm_loss'].append(np.mean(afnorm_loss_list))
119 |
120 | if (epoch==0) or ((epoch + 1) % eval_interval == 0):
121 |             ## evaluate performance to select the best checkpoint, a common practice in industrial anomaly detection (e.g., SimpleNet, RD++, BGAD)
122 | auroc_px, auroc_sp, aupro_px, test_fnorm_loss, test_afnorm_loss, eval_P = evaluation_Stage1(encoder, encoder_AT, test_dataloader, device)
123 | print('Pixel Auroc:{:.3f}, Sample Auroc{:.3f}, Pixel Aupro{:.3}'.format(auroc_px, auroc_sp, aupro_px))
124 | print(f'test_fnorm_loss: {np.mean(test_fnorm_loss):.4f} test_afnorm_loss: {np.mean(test_afnorm_loss):.4f}')
125 | loss_per_epoch['test_fnorm_loss'].append(np.mean(test_fnorm_loss))
126 | loss_per_epoch['test_afnorm_loss'].append(np.mean(test_afnorm_loss))
127 |
128 | # if auroc_sp > best_auroc_sp and epoch > 5:
129 | if eval_P > best_P and epoch > 5:
130 | best_P = eval_P
131 | best_auroc_sp = auroc_sp
132 | best_auroc_px = auroc_px
133 | torch.save({'encoder_AT': encoder_AT.state_dict()}, Stage1_ckp_path)
134 | # 'reconsKV_fs': recons_rar.state_dict()}, Stage1_ckp_path)
135 | if auroc_sp > 0.999 and epoch > 30:
136 | break
137 | return best_auroc_px, best_auroc_sp, aupro_px
138 |
139 |
140 | def train_Stage2(args, _class_, epochs=100, eval_interval=10, lr=0.005):
141 | print("training Stubborn Student...")
142 | print(_class_)
143 |
144 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
145 | print(device)
146 |
147 | Stage1_ckp_path = './checkpoints_Stage1/' + 'wres50_'+_class_+'.pth'
148 | Stage2_ckp_path = './checkpoints_Stage2/' + 'wres50_'+ _class_+'.pth'
149 | os.makedirs('checkpoints_Stage2', exist_ok=True)
150 | # only normal data
151 | train_data = TrainDataset(class_name=_class_, img_size=args.image_size, dataset_path=args.data_root)
152 | test_data = TestDataset(class_name=_class_, img_size=args.image_size, dataset_path=args.data_root)
153 | train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True, num_workers=4)
154 | test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=False)
155 |
156 | encoder_AT, bn = wide_resnet50_2_rar(pretrained=True)
157 | encoder_AT = encoder_AT.to(device)
158 |     encoder_AT.load_state_dict(torch.load(Stage1_ckp_path)['encoder_AT']) # load the Stage-1 checkpoint
159 | encoder_AT.eval()
160 |
161 | encoder_pre, _ = wide_resnet50_2(pretrained=True)
162 | encoder_pre = encoder_pre.to(device)
163 | encoder_pre.eval()
164 |
165 | bn = bn.to(device)
166 | # decoder_SS = de_wide_resnet50_2_mem(pretrained=False)
167 | decoder_SS = de_wide_resnet50_2(pretrained=False)
168 | decoder_SS = decoder_SS.to(device)
169 |
170 | optimizer_Stage2 = torch.optim.Adam(list(decoder_SS.parameters())+list(bn.parameters()), lr=lr, betas=(0.5,0.999))
171 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer_Stage2, [epochs*0.8, epochs*0.9], gamma=0.2, last_epoch=-1)
172 |
173 | best_auroc_sp = -1
174 | best_auroc_px = -1
175 | for epoch in tqdm(range(epochs)):
176 | bn.train()
177 | decoder_SS.train()
178 |
179 | kd_loss_list = []
180 | recons_loss_list = []
181 | orth_loss_list = []
182 | for batch_data in train_dataloader:
183 | # load data
184 | img = batch_data[0].to(device)
185 | with torch.no_grad():
186 | _, inputs, _, _ = encoder_AT(img, flag=True)
187 | # inputs = encoder_pre(img)
188 |
189 | outputs = decoder_SS(bn(inputs))
190 |
191 |             kd_loss = loss_fucntion(inputs, outputs) # this line stays unchanged
192 | loss = kd_loss
193 |
194 | optimizer_Stage2.zero_grad()
195 | loss.backward()
196 | optimizer_Stage2.step()
197 |
198 | kd_loss_list.append(kd_loss.item())
199 |
200 | scheduler.step() # modify
201 | print('epoch [{}/{}], kd_loss:{:.4f}'.format(epoch + 1, epochs, np.mean(kd_loss_list)))
202 |
203 | if (epoch + 1) % eval_interval == 0:
204 | auroc_px, auroc_sp, aupro_px, _, _, _ = evaluation_Stage2(encoder_AT, bn, decoder_SS, test_dataloader, device)
205 | print('Pixel Auroc:{:.3f}, Sample Auroc{:.3f}, Pixel Aupro{:.3}'.format(auroc_px, auroc_sp, aupro_px))
206 |
207 | if (auroc_sp+auroc_px) > (best_auroc_sp+best_auroc_px) and epoch > 5:
208 | best_auroc_sp = auroc_sp
209 | best_auroc_px = auroc_px
210 | torch.save({'bn': bn.state_dict(),
211 | 'decoder_ss': decoder_SS.state_dict()}, Stage2_ckp_path)
212 | if auroc_sp > 0.999 and epoch > 5:
213 | break
214 | return best_auroc_px, best_auroc_sp, aupro_px
215 |
216 |
217 | if __name__ == '__main__':
218 | parser = argparse.ArgumentParser()
219 | parser.add_argument('--data_root', type=str, default='/data4/tch/AD_data/mvtec')
220 | parser.add_argument('--aux_path', type=str, default='/data4/tch/AD_data/DRAEM_dtd/dtd/images')
221 | parser.add_argument('--batch_size', type=int, default=16)
222 | parser.add_argument('--image_size', type=int, default=256)
223 | args = parser.parse_args()
224 |
225 | dataset_name = get_dataset_name(args.data_root)
226 | if dataset_name == 'mvtec':
227 | from datasets.MvTec import TrainDataset, TestDataset, AnomalyDataset
228 | from datasets.MvTec import mvtec_classes
229 | item_list = mvtec_classes()
230 | elif dataset_name == 'VisA':
231 | from datasets.VisA import TrainDataset, TestDataset, AnomalyDataset
232 | from datasets.VisA import visa_classes
233 | item_list = visa_classes()
234 | elif dataset_name == 'mvtec3d':
235 | from datasets.MvTec3D import TrainDataset, TestDataset, AnomalyDataset
236 | from datasets.MvTec3D import mvtec3d_classes
237 | item_list = mvtec3d_classes()
238 |
239 | setup_seed(0)
240 |
241 | auroc_Stage1_list = []
242 | auroc_Stage2_list = []
243 |
244 | for i in item_list:
245 | _, auroc_sp, _ = train_Stage1(args, i, epochs=100, eval_interval=5, lr=0.005)
246 | auroc_Stage1_list.append(auroc_sp)
247 | _, auroc_sp, _ = train_Stage2(args, i, epochs=120, eval_interval=5, lr=0.005)
248 | auroc_Stage2_list.append(auroc_sp)
249 | auroc_Stage1_mean = np.mean(auroc_Stage1_list)
250 | auroc_Stage2_mean = np.mean(auroc_Stage2_list)
251 |
252 | with open('results.txt', 'a') as f:
253 | f.write(time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time()))+'\n')
254 | f.write('Stage1: ')
255 | for item in auroc_Stage1_list:
256 | f.write(f"{item:.4f}")
257 | f.write(' ')
258 | f.write(f'avg: {auroc_Stage1_mean:.4f}\n')
259 |
260 | f.write('Stage2: ')
261 | for item in auroc_Stage2_list:
262 | f.write(f"{item:.4f}")
263 | f.write(' ')
264 | f.write(f'avg: {auroc_Stage2_mean:.4f}\n')
265 |
--------------------------------------------------------------------------------
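train.py above runs the two training stages back to back for every class of the dataset selected by --data_root: Stage 1 fits the Advanced Teacher (checkpoints_Stage1/), Stage 2 fits the BN layer and the Stubborn Student decoder (checkpoints_Stage2/), and the per-class image AUROCs are appended to results.txt. A minimal invocation sketch, assuming MVTec-AD and the DTD auxiliary textures are available locally (both paths are placeholders):

    python train.py \
        --data_root /path/to/mvtec \
        --aux_path /path/to/dtd/images \
        --batch_size 16 \
        --image_size 256
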
/vit_version/train_vit.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision.datasets import ImageFolder
3 | import numpy as np
4 | import random
5 | import time
6 | import os
7 | import torch.nn as nn
8 | from torch.utils.data import DataLoader
9 | import os, pdb, sys
10 | sys.path.append(os.getcwd())
11 | from models.resnet import resnet18, resnet34, resnet50, wide_resnet50_2
12 | from models.resnet_rar import resnet18, resnet34, resnet50, wide_resnet50_2_rar
13 | from models.de_resnet import de_resnet18, de_resnet34, de_wide_resnet50_2, de_resnet50
14 | # from models.de_resnet_mem import de_wide_resnet50_2_mem
15 | import torch.backends.cudnn as cudnn
16 | import argparse
17 | from vit_version.test_vit import evaluation_Stage1, evaluation_Stage2
18 | from torch.nn import functional as F
19 | from models.loss import *
20 | from utils.vis import vis_anomaly_images
21 | from models.recons_net import *
22 | from utils.utils import setup_seed, get_dataset_name
23 | import pdb
24 | from utils.vis import *
25 | from tqdm import tqdm
26 | from utils.utils import PatchMaker
27 | # import clip, clip_rar
28 | # from clip_bn import BN_layer, AttnBottleneck
29 | # from clip_decoder import de_VisionTransformer
30 | import vit_version.clip as clip
31 | import vit_version.clip_rar as clip_rar
32 | from vit_version.clip_bn import BN_layer, AttnBottleneck
33 | from vit_version.clip_decoder import de_VisionTransformer
34 |
35 |
36 | def count_parameters(model):
37 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
38 |
39 |
40 | def train_Stage1(args, _class_, epochs=100, eval_interval=10, lr=0.0002):
41 | print("training Advanced Teacher...")
42 | print(_class_)
43 |
44 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
45 | print(device)
46 |
47 | Stage1_ckp_path = 'vit_version/checkpoints_Stage1/' + 'wres50_'+_class_+'.pth'
48 | os.makedirs('vit_version/checkpoints_Stage1', exist_ok=True)
49 |     train_data = AnomalyDataset(class_name=_class_, img_size=args.image_size, dataset_path=args.data_root, aux_path=args.aux_path) # Note: contains both normal images and synthesized anomalies
50 | test_data = TestDataset(class_name=_class_, img_size=args.image_size, dataset_path=args.data_root)
51 | train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True, num_workers=4)
52 | test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=False)
53 |
54 | # encoder, _ = clip.load("ViT-L/14@336px", download_root='./clip', device=torch.device("cpu"), image_size=336) # Modify1
55 | encoder, _ = clip.load("ViT-B/16", download_root='vit_version/clip', device=torch.device("cpu"), image_size=args.image_size)
56 | encoder = encoder.to(device)
57 | encoder.eval()
58 |
59 | encoder_AT, _ = clip_rar.load("ViT-B/16", download_root='vit_version/clip', device=torch.device("cpu"), image_size=args.image_size)
60 | encoder_AT = encoder_AT.to(device)
61 |
62 | for name, para in encoder_AT.named_parameters():
63 | if 'feat_recons' in name:
64 | para.requires_grad = True
65 | else:
66 | para.requires_grad = False
67 | # print([name for name, para in encoder_AT.named_parameters() if para.requires_grad])
68 | # pdb.set_trace()
69 | optimizer_Stage1 = torch.optim.Adam((filter(lambda p: p.requires_grad, encoder_AT.parameters())), lr=lr, betas=(0.5,0.999))
70 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer_Stage1, [epochs*0.8, epochs*0.9], gamma=0.2, last_epoch=-1)
71 |
72 | best_auroc_sp = -1
73 | best_auroc_px = -1
74 | best_P = -1
75 | loss_per_epoch = {'fnorm_loss': [], 'afnorm_loss': [], 'test_fnorm_loss': [], 'test_afnorm_loss': []}
76 |
77 | # auroc_px, auroc_sp, aupro_px, test_fnorm_loss, test_afnorm_loss = evaluation_Stage1(encoder, encoder_AT, test_dataloader, device)
78 | # print('Pixel Auroc:{:.3f}, Sample Auroc{:.3f}, Pixel Aupro{:.3}'.format(auroc_px, auroc_sp, aupro_px))
79 | # print(f'test_fnorm_loss: {np.mean(test_fnorm_loss):.4f} test_afnorm_loss: {np.mean(test_afnorm_loss):.4f}')
80 | # pdb.set_trace()
81 | for epoch in tqdm(range(epochs)):
82 | encoder_AT.train()
83 | for name, module in encoder_AT.named_modules():
84 | if isinstance(module, nn.BatchNorm2d):
85 | module.eval()
86 | if isinstance(module, nn.BatchNorm1d):
87 | module.eval()
88 |
89 | normal_loss_list, anomaly_loss_list, fnorm_loss_list, afnorm_loss_list, focal_loss_list = [], [], [], [], []
90 | P_list, R_list = [], []
91 | for batch_data in train_dataloader:
92 | # load data
93 | img = batch_data["img"].to(device)
94 |             anomaly_mask = batch_data["anomaly_mask"].to(device) # Note: should half of the batch be normal images?
95 | fore_mask = batch_data["fore_mask"].to(device)
96 | # vis_anomaly_images(img, anomaly_mask, _class_)
97 |
98 | with torch.no_grad():
99 | _, inputs = encoder.encode_image(img, f_list=[4,8,12])
100 |
101 | _, inputs_AT, atten, delta = encoder_AT.encode_image(img, f_list=[4,8,12])
102 | # pdb.set_trace()
103 | loss_focal, P, R = get_focal_loss(atten, anomaly_mask, vit=True)
104 | loss_normal, loss_anomaly = get_amplify_loss(inputs, inputs_AT, anomaly_mask, fore_mask, vit=True)
105 | loss = loss_focal*1.0 + loss_anomaly *0.1 + loss_normal*0
106 | # Res Loss:
107 | fnorm_loss, afnorm_loss = get_FnormLoss(delta, anomaly_mask, vit=True)
108 | fnorm_loss_list.append(fnorm_loss.item())
109 | if afnorm_loss != 0.0:
110 | afnorm_loss_list.append(afnorm_loss.item())
111 |
112 | optimizer_Stage1.zero_grad()
113 | loss.backward()
114 | optimizer_Stage1.step()
115 |
116 | normal_loss_list.append(loss_normal.item())
117 | loss_anomaly = loss_anomaly.item() if torch.is_tensor(loss_anomaly) else loss_anomaly
118 | anomaly_loss_list.append(loss_anomaly)
119 | focal_loss_list.append(loss_focal.item())
120 | P_list.append(P.item())
121 | R_list.append(R.item())
122 |
123 | # if np.isnan(np.mean(anomaly_loss_list)):
124 | # pdb.set_trace()
125 | scheduler.step() # modify
126 | print(f'epoch [{epoch + 1}/{epochs}], focal_loss:{np.mean(focal_loss_list):.4f}\t, anomaly_loss:{np.mean(anomaly_loss_list):.4f}\t'
127 |               f'fnorm_loss:{np.mean(fnorm_loss_list):.4f}, a_fnorm_loss:{np.mean(afnorm_loss_list):.4f}, '
128 | f'P:{np.mean(P_list):.2f}, R:{np.mean(R_list):.2f}')
129 | loss_per_epoch['fnorm_loss'].append(np.mean(fnorm_loss_list))
130 | loss_per_epoch['afnorm_loss'].append(np.mean(afnorm_loss_list))
131 |
132 | if (epoch==0) or ((epoch + 1) % eval_interval == 0):
133 |             ## evaluate performance to select the best checkpoint, a common practice in industrial anomaly detection (e.g., SimpleNet, RD++, BGAD)
134 | auroc_px, auroc_sp, aupro_px, test_fnorm_loss, test_afnorm_loss, eval_P = evaluation_Stage1(encoder, encoder_AT, test_dataloader, device)
135 | print('Pixel Auroc:{:.3f}, Sample Auroc{:.3f}, Pixel Aupro{:.3}'.format(auroc_px, auroc_sp, aupro_px))
136 | print(f'test_fnorm_loss: {np.mean(test_fnorm_loss):.4f} test_afnorm_loss: {np.mean(test_afnorm_loss):.4f}')
137 | loss_per_epoch['test_fnorm_loss'].append(np.mean(test_fnorm_loss))
138 | loss_per_epoch['test_afnorm_loss'].append(np.mean(test_afnorm_loss))
139 |
140 | # if auroc_sp > best_auroc_sp and epoch > 5:
141 | if eval_P > best_P and epoch > 5:
142 | best_P = eval_P
143 | best_auroc_sp = auroc_sp
144 | best_auroc_px = auroc_px
145 | torch.save({'encoder_AT': encoder_AT.state_dict()}, Stage1_ckp_path)
146 | # 'reconsKV_fs': recons_rar.state_dict()}, Stage1_ckp_path)
147 | if auroc_sp > 0.999 and epoch > 5:
148 | break
149 | return best_auroc_px, best_auroc_sp, aupro_px
150 |
151 |
152 | def train_Stage2(args, _class_, epochs=100, eval_interval=10, lr=0.005):
153 | print("training Stubborn Student...")
154 | print(_class_)
155 |
156 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
157 | print(device)
158 |
159 | Stage1_ckp_path = 'vit_version/checkpoints_Stage1/' + 'wres50_'+_class_+'.pth'
160 | Stage2_ckp_path = 'vit_version/checkpoints_Stage2/' + 'wres50_'+ _class_+'.pth'
161 | os.makedirs('vit_version/checkpoints_Stage2', exist_ok=True)
162 |     # only normal data
163 | train_data = TrainDataset(class_name=_class_, img_size=args.image_size, dataset_path=args.data_root)
164 | test_data = TestDataset(class_name=_class_, img_size=args.image_size, dataset_path=args.data_root)
165 | train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True, num_workers=4)
166 | test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=False)
167 |
168 | #
169 | # encoder, _ = clip.load("ViT-L/14@336px", download_root='./clip', device=torch.device("cpu"), image_size=336) # Modify1
170 | encoder_pre, _ = clip.load("ViT-B/16", download_root='vit_version/clip', device=torch.device("cpu"), image_size=args.image_size)
171 | encoder_pre = encoder_pre.to(device)
172 | encoder_pre.eval()
173 | #
174 | encoder_AT, _ = clip_rar.load("ViT-B/16", download_root='vit_version/clip', device=torch.device("cpu"), image_size=args.image_size)
175 | encoder_AT = encoder_AT.to(device)
176 |     encoder_AT.load_state_dict(torch.load(Stage1_ckp_path)['encoder_AT']) # load the Stage-1 checkpoint
177 | encoder_AT.eval()
178 |
179 | #
180 | bn = BN_layer(AttnBottleneck, 3)
181 | bn = bn.to(device)
182 | decoder_SS = de_VisionTransformer(input_resolution=256, patch_size=16, width=768, layers=6, heads=12, output_dim=512)
183 | decoder_SS = decoder_SS.to(device)
184 |
185 | optimizer_Stage2 = torch.optim.Adam(list(decoder_SS.parameters())+list(bn.parameters()), lr=lr, betas=(0.5,0.999))
186 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer_Stage2, [epochs*0.8, epochs*0.9], gamma=0.2, last_epoch=-1)
187 |
188 | best_auroc_sp = -1
189 | best_auroc_px = -1
190 | for epoch in tqdm(range(epochs)):
191 | bn.train()
192 | decoder_SS.train()
193 |
194 | kd_loss_list = []
195 | recons_loss_list = []
196 | orth_loss_list = []
197 | for batch_data in train_dataloader:
198 | # load data
199 | img = batch_data[0].to(device)
200 | #
201 | with torch.no_grad():
202 | _, inputs, atten, delta = encoder_AT.encode_image(img, f_list=[4,8,12])
203 | # inputs = encoder_pre(img)
204 |
205 | # outputs = decoder_SS(bn(inputs))
206 | _, outputs = decoder_SS(bn(inputs), f_list=[2,4,6])
207 |
208 |             kd_loss = loss_fucntion(inputs, outputs) # this line stays unchanged
209 | loss = kd_loss
210 |
211 | optimizer_Stage2.zero_grad()
212 | loss.backward()
213 | optimizer_Stage2.step()
214 |
215 | kd_loss_list.append(kd_loss.item())
216 |
217 | scheduler.step() # modify
218 | print('epoch [{}/{}], kd_loss:{:.4f}'.format(epoch + 1, epochs, np.mean(kd_loss_list)))
219 |
220 | if (epoch + 1) % eval_interval == 0:
221 | auroc_px, auroc_sp, aupro_px, _, _, _ = evaluation_Stage2(encoder_AT, bn, decoder_SS, test_dataloader, device)
222 | print('Pixel Auroc:{:.3f}, Sample Auroc{:.3f}, Pixel Aupro{:.3}'.format(auroc_px, auroc_sp, aupro_px))
223 |
224 | if auroc_sp > best_auroc_sp and epoch > 5:
225 | best_auroc_sp = auroc_sp
226 | best_auroc_px = auroc_px
227 | torch.save({'bn': bn.state_dict(),
228 | 'decoder_ss': decoder_SS.state_dict()}, Stage2_ckp_path)
229 | if auroc_sp > 0.999 and epoch > 5:
230 | break
231 | return best_auroc_px, best_auroc_sp, aupro_px
232 |
233 |
234 | if __name__ == '__main__':
235 | parser = argparse.ArgumentParser()
236 | parser.add_argument('--data_root', type=str, default='/data4/tch/AD_data/mvtec')
237 | parser.add_argument('--aux_path', type=str, default='/data4/tch/AD_data/DRAEM_dtd/dtd/images')
238 | parser.add_argument('--L_rar', type=int, default=50)
239 | parser.add_argument('--batch_size', type=int, default=16)
240 | parser.add_argument('--image_size', type=int, default=256)
241 | args = parser.parse_args()
242 |
243 | dataset_name = get_dataset_name(args.data_root)
244 | if dataset_name == 'mvtec':
245 | from datasets.MvTec import TrainDataset, TestDataset, AnomalyDataset
246 | from datasets.MvTec import mvtec_classes
247 | item_list = mvtec_classes()
248 | elif dataset_name == 'VisA':
249 | from datasets.VisA import TrainDataset, TestDataset, AnomalyDataset
250 | from datasets.VisA import visa_classes
251 | item_list = visa_classes()
252 | elif dataset_name == 'mvtec3d':
253 | from datasets.MvTec3D import TrainDataset, TestDataset, AnomalyDataset
254 | from datasets.MvTec3D import mvtec3d_classes
255 | item_list = mvtec3d_classes()
256 |
257 | setup_seed(111)
258 |
259 | auroc_Stage1_list = []
260 | auroc_Stage2_list = []
261 |
262 | for i in item_list:
263 | _, auroc_sp, _ = train_Stage1(args, i, epochs=100, eval_interval=5, lr=0.005)
264 | auroc_Stage1_list.append(auroc_sp)
265 | _, auroc_sp, _ = train_Stage2(args, i, epochs=120, eval_interval=10, lr=0.005)
266 | auroc_Stage2_list.append(auroc_sp)
267 | auroc_Stage1_mean = np.mean(auroc_Stage1_list)
268 | auroc_Stage2_mean = np.mean(auroc_Stage2_list)
269 |
270 | with open('results.txt', 'a') as f:
271 | f.write(time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time()))+'\n')
272 | f.write('Stage1: ')
273 | for item in auroc_Stage1_list:
274 | f.write(f"{item:.4f}")
275 | f.write(' ')
276 | f.write(f'avg: {auroc_Stage1_mean:.4f}\n')
277 |
278 | f.write('Stage2: ')
279 | for item in auroc_Stage2_list:
280 | f.write(f"{item:.4f}")
281 | f.write(' ')
282 | f.write(f'avg: {auroc_Stage2_mean:.4f}\n')
283 |
--------------------------------------------------------------------------------
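train_vit.py above is the ViT variant of the same two-stage recipe: it swaps the WideResNet-50 backbones for CLIP ViT-B/16 encoders (vit_version/clip and vit_version/clip_rar) and writes checkpoints under vit_version/checkpoints_Stage1 and vit_version/checkpoints_Stage2. Since it appends os.getcwd() to sys.path, it is intended to be launched from the repository root; a sketch with placeholder paths:

    python vit_version/train_vit.py \
        --data_root /path/to/mvtec \
        --aux_path /path/to/dtd/images \
        --image_size 256
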
/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision.datasets import ImageFolder
3 | import numpy as np
4 | from torch.utils.data import DataLoader
5 | from models.resnet import resnet18, resnet34, resnet50, wide_resnet50_2
6 | from models.de_resnet import de_resnet18, de_resnet50, de_wide_resnet50_2
7 | from models.resnet_rar import resnet18, resnet34, resnet50, wide_resnet50_2_rar
8 | # from models.de_resnet_mem import de_wide_resnet50_2_mem
9 | from torch.nn import functional as F
10 | from sklearn.metrics import roc_auc_score, precision_recall_curve
11 | import cv2
12 | import matplotlib.pyplot as plt
13 | from sklearn.metrics import auc
14 | from skimage import measure
15 | import pandas as pd
16 | from numpy import ndarray
17 | from statistics import mean
18 | from scipy.ndimage import gaussian_filter
19 | from sklearn import manifold
20 | from matplotlib.ticker import NullFormatter
21 | from scipy.spatial.distance import pdist
22 | import matplotlib
23 | import pickle
24 | import argparse
25 | # from MvTec3D import TestDataset
26 | # from MvTec import TestDataset
27 | import pdb
28 | from utils.vis import *
29 | from utils.utils import PatchMaker, get_dataset_name
30 | from models.loss import *
31 | import time
32 | from models.recons_net import *
33 | from tqdm import tqdm
34 |
35 |
36 | def cal_anomaly_map(fs_list, ft_list, out_size=224, amap_mode='mul', vis=False):
37 | if amap_mode == 'mul':
38 | anomaly_map = np.ones([out_size, out_size])
39 | else:
40 | anomaly_map = np.zeros([out_size, out_size])
41 | a_map_list = []
42 | mse_loss = torch.nn.MSELoss(reduction='none')
43 | for i in range(len(ft_list)):
44 | # pdb.set_trace()
45 | fs = fs_list[i]
46 | ft = ft_list[i]
47 | #fs_norm = F.normalize(fs, p=2)
48 | #ft_norm = F.normalize(ft, p=2)
49 | # pdb.set_trace()
50 | a_map = 1 - F.cosine_similarity(fs, ft) # cos_loss [1,C,H,W]->[1,H,W]
51 | # a_map2 = torch.sqrt(mse_loss(fs, ft).sum(dim=1)) # mse_loss [1,C,H,W]->[1,H,W]
52 | # if i == 1 and vis:
53 | # vis_hotmap(fs, ft, a_map)
54 |
55 | a_map = torch.unsqueeze(a_map, dim=1)
56 | a_map = F.interpolate(a_map, size=out_size, mode='bilinear', align_corners=True)
57 | a_map = a_map[0, 0, :, :].to('cpu').detach().numpy()
58 | a_map_list.append(a_map)
59 | if amap_mode == 'mul':
60 | anomaly_map *= a_map
61 | else:
62 | anomaly_map += a_map
63 |
64 | return anomaly_map, None
65 |
66 |
67 | def evaluation_Stage1(encoder, encoder_AT, dataloader, device, _class_=None):
68 | encoder.eval()
69 | encoder_AT.eval()
70 | # reconsKV.eval()
71 | gt_list_px = []
72 | pr_list_px = []
73 | gt_list_sp = []
74 | pr_list_sp = []
75 | aupro_list = []
76 | fnorm_loss_list, afnorm_loss_list = [], []
77 | P_list, R_list = [], []
78 | with torch.no_grad():
79 | for img, gt, label, _ in dataloader:
80 | img = img.to(device)
81 | gt = gt.to(device)
82 | inputs = encoder(img)
83 | # FS
84 | _, inputs_AT, delta, atten = encoder_AT(img, flag=True)
85 | # pdb.set_trace()
86 | loss_atten, P, R = get_focal_loss(atten, gt)
87 | fnorm_loss, afnorm_loss = get_FnormLoss(delta, gt)
88 | fnorm_loss_list.append(fnorm_loss.item())
89 | if afnorm_loss != 0.0:
90 | afnorm_loss_list.append(afnorm_loss.item())
91 | P_list.append(P.item())
92 | R_list.append(R.item())
93 |
94 | anomaly_map, _ = cal_anomaly_map(inputs, inputs_AT, img.shape[-1], amap_mode='a')
95 | anomaly_map = gaussian_filter(anomaly_map, sigma=4)
96 | gt[gt > 0.5] = 1
97 | gt[gt <= 0.5] = 0
98 | if label.item()!=0:
99 | aupro_list.append(0.0)
100 | gt_list_px.extend(gt.cpu().numpy().astype(int).ravel())
101 | pr_list_px.extend(anomaly_map.ravel())
102 | gt_list_sp.append(np.max(gt.cpu().numpy().astype(int)))
103 | pr_list_sp.append(np.max(anomaly_map))
104 |
105 | # print(afnorm_loss_list, 'len:', len(afnorm_loss_list))
106 |
107 | pr_list_sp = np.array(pr_list_sp)
108 | gt_list_sp = np.array(gt_list_sp)
109 | normal = pr_list_sp[gt_list_sp==0]
110 | anomaly = pr_list_sp[gt_list_sp==1]
111 | print(f'normal: {normal.min():.4f} {normal.max():.4f} anomaly: {anomaly.min():.4f} {anomaly.max():.4f}\t'
112 | f'P:{np.mean(P_list):.2f}, R:{np.mean(R_list):.2f}')
113 |
114 | auroc_px = round(roc_auc_score(gt_list_px, pr_list_px), 3)
115 | auroc_sp = round(roc_auc_score(gt_list_sp, pr_list_sp), 3)
116 | return auroc_px, auroc_sp, round(np.mean(aupro_list),3), fnorm_loss_list, afnorm_loss_list, np.mean(P_list)
117 |
118 |
119 | def evaluation_Stage2(encoder_AT, bn, decoder_SS, dataloader, device, _class_=None):
120 | encoder_AT.eval()
121 | bn.eval()
122 | decoder_SS.eval()
123 | gt_list_px = []
124 | pr_list_px = []
125 | gt_list_sp = []
126 | pr_list_sp = []
127 | aupro_list = []
128 | time_list = []
129 | normal_feats, anomaly_feats = [], []
130 | with torch.no_grad():
131 | # for img, gt, label, rgb_path, fore_mask in dataloader:
132 | for img, gt, label, rgb_path in dataloader:
133 | # break
134 | img = img.to(device)
135 | tic = time.time()
136 | torch.cuda.synchronize()
137 | # FS
138 | _, inputs, _, _ = encoder_AT(img, flag=True)
139 | # inputs = encoder_AT(img)
140 | # SS
141 | outputs = decoder_SS(bn(inputs))
142 | torch.cuda.synchronize()
143 | time_list.append(time.time() - tic)
144 |
145 | anomaly_map, _ = cal_anomaly_map(inputs, outputs, img.shape[-1], amap_mode='a', vis=label.bool())
146 | anomaly_map = gaussian_filter(anomaly_map, sigma=4)
147 | gt[gt > 0.5] = 1
148 | gt[gt <= 0.5] = 0
149 | # if label.item() == 0:
150 | # pdb.set_trace()
151 | # apply_ad_scoremap(rgb_path[0], anomaly_map, _class_)
152 | if label.item()!=0:
153 | # if rgb_path == '/data4/tch/AD_data/mvtec/zipper/test/fabric_interior/007.png':
154 | # pdb.set_trace()
155 | # pdb.set_trace()
156 | # aupro_list.append(compute_pro(gt.squeeze(0).cpu().numpy().astype(int),
157 | # anomaly_map[np.newaxis,:,:]))
158 | aupro_list.append(0.0)
159 |
160 | anomaly_map = anomaly_map
161 | gt_list_px.extend(gt.cpu().numpy().astype(int).ravel())
162 | pr_list_px.extend(anomaly_map.ravel())
163 | gt_list_sp.append(np.max(gt.cpu().numpy().astype(int)))
164 | pr_list_sp.append(np.max(anomaly_map))
165 |
166 | pr_list_px = np.array(pr_list_px)
167 | gt_list_px = np.array(gt_list_px)
168 | pr_list_sp = np.array(pr_list_sp)
169 | gt_list_sp = np.array(gt_list_sp)
170 |
171 | normal_sp = pr_list_sp[gt_list_sp==0]
172 | anomaly_sp = pr_list_sp[gt_list_sp==1]
173 |
174 | print(f'normal: {normal_sp.min():.4f} {normal_sp.max():.4f} anomaly: {anomaly_sp.min():.4f} {anomaly_sp.max():.4f}')
175 |
176 | precision, recall, thresholds = precision_recall_curve(gt_list_sp.astype(int), pr_list_sp)
177 | F1_scores = np.divide(
178 | 2 * precision * recall,
179 | precision + recall,
180 | out=np.zeros_like(precision),
181 | where=(precision + recall) != 0,
182 | )
183 | optimal_threshold = thresholds[np.argmax(F1_scores)]
184 | fpr_optim = np.mean(normal_sp > optimal_threshold)
185 | fnr_optim = np.mean(anomaly_sp < optimal_threshold)
186 | print('thresh:', optimal_threshold)
187 | print('fpr:', fpr_optim, 'fnr', fnr_optim)
188 |
189 | auroc_px = round(roc_auc_score(gt_list_px, pr_list_px), 3)
190 | auroc_sp = round(roc_auc_score(gt_list_sp, pr_list_sp), 3)
191 | return auroc_px, auroc_sp, round(np.mean(aupro_list),3), time_list, pr_list_sp, gt_list_sp
192 |
193 |
194 | def test(_class_, args):
195 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
196 | print(device)
197 | print(_class_)
198 |
199 | Stage1_ckp_path = './checkpoints_Stage1/' + 'wres50_'+_class_+'.pth'
200 | Stage2_ckp_path = './checkpoints_Stage2/' + 'wres50_'+ _class_+'.pth'
201 | image_size = 256
202 | test_data = TestDataset(class_name=_class_, img_size=args.image_size, dataset_path=args.data_root)
203 | test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=False)
204 |
205 | # Advanced Teacher
206 | encoder_AT, bn = wide_resnet50_2_rar(pretrained=False)
207 | encoder_AT = encoder_AT.to(device)
208 | encoder_AT.load_state_dict(torch.load(Stage1_ckp_path)['encoder_AT'])
209 | encoder_AT.eval()
210 |
211 | # Vanilla Teacher
212 | encoder_pre, _ = wide_resnet50_2(pretrained=True)
213 | encoder_pre = encoder_pre.to(device)
214 | encoder_pre.eval()
215 |
216 | bn = bn.to(device)
217 | decoder_SS = de_wide_resnet50_2(pretrained=False)
218 | decoder_SS = decoder_SS.to(device)
219 |
220 | SS_ckp = torch.load(Stage2_ckp_path)
221 | for k, v in list(SS_ckp['bn'].items()):
222 | if 'memory' in k:
223 | SS_ckp['bn'].pop(k)
224 | decoder_SS.load_state_dict(SS_ckp['decoder_ss'])
225 | bn.load_state_dict(SS_ckp['bn'])
226 |
227 | # total_params = compute_params([encoder_AT, bn, decoder_SS])
228 | # print(f'total params: {total_params/1e6} M') #{sum([x.nelement() for x in self.model.parameters()])/1000000.} M
229 |
230 | # baseline, RD
231 | # auroc_px, auroc_sp, aupro_px, time_list, preds, labels = evaluation_Stage2(encoder_pre, bn, decoder_SS, test_dataloader, device, _class_)
232 | # Ours
233 | auroc_px, auroc_sp, aupro_px, time_list, preds, labels = evaluation_Stage2(encoder_AT, bn, decoder_SS, test_dataloader, device, _class_)
234 | # pdb.set_trace()
235 | print(_class_,':',auroc_px,',',auroc_sp,',',aupro_px)
236 | return auroc_px, auroc_sp, aupro_px, time_list, preds, labels
237 |
238 |
239 | def compute_params(nets):
240 | total_params = 0
241 | for net in nets:
242 | print(sum(p.numel() for p in net.parameters()))
243 | total_params += sum(p.numel() for p in net.parameters())
244 | return total_params
245 |
246 | def compute_pro(masks: ndarray, amaps: ndarray, num_th: int = 200) -> float:
247 |
248 |     """Compute the area under the per-region overlap (PRO) curve over the 0 to 0.3 FPR range
249 | Args:
250 | category (str): Category of product
251 | masks (ndarray): All binary masks in test. masks.shape -> (num_test_data, h, w)
252 | amaps (ndarray): All anomaly maps in test. amaps.shape -> (num_test_data, h, w)
253 | num_th (int, optional): Number of thresholds
254 | """
255 |
256 | assert isinstance(amaps, ndarray), "type(amaps) must be ndarray"
257 | assert isinstance(masks, ndarray), "type(masks) must be ndarray"
258 | assert amaps.ndim == 3, "amaps.ndim must be 3 (num_test_data, h, w)"
259 | assert masks.ndim == 3, "masks.ndim must be 3 (num_test_data, h, w)"
260 | assert amaps.shape == masks.shape, "amaps.shape and masks.shape must be same"
261 | assert set(masks.flatten()) == {0, 1}, "set(masks.flatten()) must be {0, 1}"
262 | assert isinstance(num_th, int), "type(num_th) must be int"
263 |
264 | df = pd.DataFrame([], columns=["pro", "fpr", "threshold"])
265 | binary_amaps = np.zeros_like(amaps, dtype=np.bool_)
266 |
267 | min_th = amaps.min()
268 | max_th = amaps.max()
269 | delta = (max_th - min_th) / num_th
270 |
271 | for th in np.arange(min_th, max_th, delta):
272 | binary_amaps[amaps <= th] = 0
273 | binary_amaps[amaps > th] = 1
274 |
275 | pros = []
276 | for binary_amap, mask in zip(binary_amaps, masks):
277 | for region in measure.regionprops(measure.label(mask)):
278 | axes0_ids = region.coords[:, 0]
279 | axes1_ids = region.coords[:, 1]
280 | tp_pixels = binary_amap[axes0_ids, axes1_ids].sum()
281 | pros.append(tp_pixels / region.area)
282 |
283 | inverse_masks = 1 - masks
284 | fp_pixels = np.logical_and(inverse_masks, binary_amaps).sum()
285 | fpr = fp_pixels / inverse_masks.sum()
286 |
287 | # df = df.append({"pro": mean(pros), "fpr": fpr, "threshold": th}, ignore_index=True)
288 | df = pd.concat([df, pd.DataFrame([{"pro": mean(pros), "fpr": fpr, "threshold": th}])], ignore_index=True)
289 |
290 |
291 |     # keep thresholds with FPR below 0.3, then rescale that FPR range to 0 ~ 1
292 | df = df[df["fpr"] < 0.3]
293 | df["fpr"] = df["fpr"] / df["fpr"].max()
294 |
295 | pro_auc = auc(df["fpr"], df["pro"])
296 | return pro_auc
297 |
298 |
299 | if __name__ == '__main__':
300 | # from main import setup_seed
301 | # setup_seed(111)
302 | parser = argparse.ArgumentParser()
303 | parser.add_argument('--data_root', type=str, default='/data4/tch/AD_data/mvtec3d')
304 | parser.add_argument('--image_size', type=int, default=256)
305 | args = parser.parse_args()
306 |
307 | dataset_name = get_dataset_name(args.data_root)
308 | if dataset_name == 'mvtec':
309 | from datasets.MvTec import TrainDataset, TestDataset, AnomalyDataset
310 | from datasets.MvTec import mvtec_classes
311 | item_list = mvtec_classes()
312 | elif dataset_name == 'VisA':
313 | from datasets.VisA import TrainDataset, TestDataset, AnomalyDataset
314 | from datasets.VisA import visa_classes
315 | item_list = visa_classes()
316 | elif dataset_name == 'mvtec3d':
317 | from datasets.MvTec3D import TrainDataset, TestDataset, AnomalyDataset
318 | from datasets.MvTec3D import mvtec3d_classes
319 | item_list = mvtec3d_classes()
320 |
321 | # all_time_list = []
322 | # for i in item_list:
323 | # time_list = test(i)
324 | # all_time_list += time_list
325 | # # print(len(all_time_list))
326 | # print('FPS: ', 1 / np.mean(all_time_list))
327 |
328 | p_auc_list, i_auc_list, p_pro_list = [], [], []
329 | label_list, pred_list = [], []
330 | for i in item_list:
331 | p_auc, i_auc, p_pro, _, preds, labels = test(i, args)
332 | p_auc_list.append(p_auc)
333 | i_auc_list.append(i_auc)
334 | p_pro_list.append(p_pro)
335 | label_list += list(labels)
336 | pred_list += list(preds)
337 | # all_time_list += time_list
338 | # print(len(all_time_list))
339 | # print('FPS: ', 1 / np.mean(all_time_list))
340 | i_auc_mean = np.mean(i_auc_list)
341 | p_auc_mean = np.mean(p_auc_list)
342 | p_pro_mean = np.mean(p_pro_list)
343 |
344 | normal = np.array(pred_list)[np.array(label_list)==0]
345 | anomaly = np.array(pred_list)[np.array(label_list)==1]
346 | os.makedirs(f'results/{dataset_name}',exist_ok=True)
347 | np.save(f'results/{dataset_name}/normal_scores', normal)
348 | np.save(f'results/{dataset_name}/anomaly_scores', anomaly)
349 | # normal = np.load('normal_scores.npy')
350 | # anomaly = np.load('anomaly_scores.npy')
351 | # pdb.set_trace()
352 | # print(anomaly)
353 |
354 | # width = 7
355 | # fig = plt.figure(figsize=(width, width*0.75))
356 | # print(f'normal: {normal.min():.4f} {normal.max():.4f} anomaly: {anomaly.min():.4f} {anomaly.max():.4f}')
357 | # plt.hist(normal,bins=20,label='normal sample',alpha=0.5)
358 | # plt.hist(anomaly,bins=20,label='anomalous sample',alpha=0.5)
359 | # plt.legend(fontsize=15)
360 | # plt.tick_params(labelsize=12.5)
361 | # plt.savefig(f'hist_SS.png')
362 |
363 |
364 | with open('results2.txt', 'a') as f:
365 | f.write(time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time()))+'\n')
366 | f.write('i_auc: ')
367 | for item in i_auc_list:
368 | f.write(f"{item:.4f}")
369 | f.write(' ')
370 | f.write(f'avg: {i_auc_mean:.4f}\n')
371 |
372 | f.write('p_auc: ')
373 | for item in p_auc_list:
374 | f.write(f"{item:.4f}")
375 | f.write(' ')
376 | f.write(f'avg: {p_auc_mean:.4f}\n')
377 |
378 | f.write('p_pro: ')
379 | for item in p_pro_list:
380 | f.write(f"{item:.4f}")
381 | f.write(' ')
382 | f.write(f'avg: {p_pro_mean:.4f}\n')
--------------------------------------------------------------------------------
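test.py above evaluates the saved Stage-1/Stage-2 checkpoints class by class, prints per-class pixel- and image-level AUROC, saves the normal/anomaly score arrays under results/<dataset_name>/, and appends the averaged metrics to results2.txt. A minimal invocation sketch (the data path is a placeholder; the checkpoints produced by train.py must already exist):

    python test.py --data_root /path/to/mvtec --image_size 256
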
/vit_version/test_vit.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision.datasets import ImageFolder
3 | import numpy as np
4 | from torch.utils.data import DataLoader
5 | import os, pdb, sys
6 | sys.path.append(os.getcwd())
7 | from models.resnet import resnet18, resnet34, resnet50, wide_resnet50_2
8 | from models.de_resnet import de_resnet18, de_resnet50, de_wide_resnet50_2
9 | from models.resnet_rar import resnet18, resnet34, resnet50, wide_resnet50_2_rar
10 | # from models.de_resnet_mem import de_wide_resnet50_2_mem
11 | from torch.nn import functional as F
12 | from sklearn.metrics import roc_auc_score, precision_recall_curve
13 | import cv2
14 | import matplotlib.pyplot as plt
15 | from sklearn.metrics import auc
16 | from skimage import measure
17 | import pandas as pd
18 | from numpy import ndarray
19 | from statistics import mean
20 | from scipy.ndimage import gaussian_filter
21 | from sklearn import manifold
22 | from matplotlib.ticker import NullFormatter
23 | from scipy.spatial.distance import pdist
24 | import matplotlib
25 | import pickle
26 | import argparse
27 | # from MvTec3D import TestDataset
28 | # from MvTec import TestDataset
29 | import pdb
30 | from utils.vis import *
31 | from utils.utils import PatchMaker, get_dataset_name
32 | from models.loss import *
33 | import time
34 | from models.recons_net import *
35 | from tqdm import tqdm
36 | # import clip
37 | # from clip_bn import BN_layer, AttnBottleneck
38 | # from clip_decoder import de_VisionTransformer
39 | import vit_version.clip as clip
40 | import vit_version.clip_rar as clip_rar
41 | from vit_version.clip_bn import BN_layer, AttnBottleneck
42 | from vit_version.clip_decoder import de_VisionTransformer
43 |
44 |
45 |
46 | def cal_anomaly_map(fs_list, ft_list, out_size=224, amap_mode='mul', vis=False):
47 | if amap_mode == 'mul':
48 | anomaly_map = np.ones([out_size, out_size])
49 | else:
50 | anomaly_map = np.zeros([out_size, out_size])
51 | a_map_list = []
52 | mse_loss = torch.nn.MSELoss(reduction='none')
53 | for i in range(len(ft_list)):
54 | # pdb.set_trace()
55 | fs = fs_list[i]
56 | ft = ft_list[i]
57 | #fs_norm = F.normalize(fs, p=2)
58 | #ft_norm = F.normalize(ft, p=2)
59 | # pdb.set_trace()
60 | a_map = 1 - F.cosine_similarity(fs, ft) # cos_loss [1,C,H,W]->[1,H,W]
61 | # a_map2 = torch.sqrt(mse_loss(fs, ft).sum(dim=1)) # mse_loss [1,C,H,W]->[1,H,W]
62 | # if i == 1 and vis:
63 | # vis_hotmap(fs, ft, a_map)
64 |
65 | a_map = torch.unsqueeze(a_map, dim=1)
66 | a_map = F.interpolate(a_map, size=out_size, mode='bilinear', align_corners=True)
67 | a_map = a_map[0, 0, :, :].to('cpu').detach().numpy()
68 | a_map_list.append(a_map)
69 | if amap_mode == 'mul':
70 | anomaly_map *= a_map
71 | else:
72 | anomaly_map += a_map
73 |
74 | return anomaly_map, None
75 |
76 |
77 | def evaluation_Stage1(encoder, encoder_AT, dataloader, device, _class_=None):
78 | encoder.eval()
79 | encoder_AT.eval()
80 | # reconsKV.eval()
81 | gt_list_px = []
82 | pr_list_px = []
83 | gt_list_sp = []
84 | pr_list_sp = []
85 | aupro_list = []
86 | fnorm_loss_list, afnorm_loss_list = [], []
87 | P_list, R_list = [], []
88 | with torch.no_grad():
89 | for img, gt, label, _ in dataloader:
90 | img = img.to(device)
91 | gt = gt.to(device)
92 |
93 | _, inputs = encoder.encode_image(img, f_list=[4,8,12])
94 | _, inputs_AT, atten, delta = encoder_AT.encode_image(img, f_list=[4,8,12])
95 | # pdb.set_trace()
96 | loss_atten, P, R = get_focal_loss(atten, gt, vit=True)
97 | fnorm_loss, afnorm_loss = get_FnormLoss(delta, gt, vit=True)
98 | fnorm_loss_list.append(fnorm_loss.item())
99 | if afnorm_loss != 0.0:
100 | afnorm_loss_list.append(afnorm_loss.item())
101 | P_list.append(P.item())
102 | R_list.append(R.item())
103 |
104 | anomaly_map, _ = cal_anomaly_map(inputs, inputs_AT, img.shape[-1], amap_mode='a')
105 | anomaly_map = gaussian_filter(anomaly_map, sigma=4)
106 | gt[gt > 0.5] = 1
107 | gt[gt <= 0.5] = 0
108 | if label.item()!=0:
109 | aupro_list.append(0.0)
110 | gt_list_px.extend(gt.cpu().numpy().astype(int).ravel())
111 | pr_list_px.extend(anomaly_map.ravel())
112 | gt_list_sp.append(np.max(gt.cpu().numpy().astype(int)))
113 | pr_list_sp.append(np.max(anomaly_map))
114 |
115 | # print(afnorm_loss_list, 'len:', len(afnorm_loss_list))
116 |
117 | pr_list_sp = np.array(pr_list_sp)
118 | gt_list_sp = np.array(gt_list_sp)
119 | normal = pr_list_sp[gt_list_sp==0]
120 | anomaly = pr_list_sp[gt_list_sp==1]
121 | print(f'normal: {normal.min():.4f} {normal.max():.4f} anomaly: {anomaly.min():.4f} {anomaly.max():.4f}\t'
122 | f'P:{np.mean(P_list):.2f}, R:{np.mean(R_list):.2f}')
123 |
124 | auroc_px = round(roc_auc_score(gt_list_px, pr_list_px), 3)
125 | auroc_sp = round(roc_auc_score(gt_list_sp, pr_list_sp), 3)
126 | return auroc_px, auroc_sp, round(np.mean(aupro_list),3), fnorm_loss_list, afnorm_loss_list, np.mean(P_list)
127 |
128 |
129 | def evaluation_Stage2(encoder_AT, bn, decoder_SS, dataloader, device, _class_=None):
130 | encoder_AT.eval()
131 | bn.eval()
132 | decoder_SS.eval()
133 | gt_list_px = []
134 | pr_list_px = []
135 | gt_list_sp = []
136 | pr_list_sp = []
137 | aupro_list = []
138 | time_list = []
139 | normal_feats, anomaly_feats = [], []
140 | with torch.no_grad():
141 | # for img, gt, label, rgb_path, fore_mask in dataloader:
142 | for img, gt, label, rgb_path in dataloader:
143 | # break
144 | img = img.to(device)
145 | tic = time.time()
146 | torch.cuda.synchronize()
147 | # FS
148 | _, inputs, atten, delta = encoder_AT.encode_image(img, f_list=[4,8,12])
149 | # inputs = encoder_AT(img)
150 | # SS
151 | # outputs = decoder_SS(bn(inputs))
152 | _, outputs = decoder_SS(bn(inputs), f_list=[2,4,6])
153 | torch.cuda.synchronize()
154 | time_list.append(time.time() - tic)
155 |
156 | anomaly_map, _ = cal_anomaly_map(inputs, outputs, img.shape[-1], amap_mode='a', vis=label.bool())
157 | anomaly_map = gaussian_filter(anomaly_map, sigma=4)
158 | gt[gt > 0.5] = 1
159 | gt[gt <= 0.5] = 0
160 | # if label.item() == 0:
161 | # pdb.set_trace()
162 | # apply_ad_scoremap(rgb_path[0], anomaly_map, _class_)
163 | if label.item()!=0:
164 | # if rgb_path == '/data4/tch/AD_data/mvtec/zipper/test/fabric_interior/007.png':
165 | # pdb.set_trace()
166 | # pdb.set_trace()
167 | # aupro_list.append(compute_pro(gt.squeeze(0).cpu().numpy().astype(int),
168 | # anomaly_map[np.newaxis,:,:]))
169 | aupro_list.append(0.0)
170 |
171 | anomaly_map = anomaly_map
172 | gt_list_px.extend(gt.cpu().numpy().astype(int).ravel())
173 | pr_list_px.extend(anomaly_map.ravel())
174 | gt_list_sp.append(np.max(gt.cpu().numpy().astype(int)))
175 | pr_list_sp.append(np.max(anomaly_map))
176 |
177 | pr_list_px = np.array(pr_list_px)
178 | gt_list_px = np.array(gt_list_px)
179 | pr_list_sp = np.array(pr_list_sp)
180 | gt_list_sp = np.array(gt_list_sp)
181 | # np.save('results/saved_data/zipper/pr_list_sp_zipper', pr_list_sp)
182 | # np.save('results/saved_data/zipper/gt_list_sp_zipper', gt_list_sp)
183 |
184 | normal_sp = pr_list_sp[gt_list_sp==0]
185 | anomaly_sp = pr_list_sp[gt_list_sp==1]
186 |
187 | print(f'normal: {normal_sp.min():.4f} {normal_sp.max():.4f} anomaly: {anomaly_sp.min():.4f} {anomaly_sp.max():.4f}')
188 |
189 | precision, recall, thresholds = precision_recall_curve(gt_list_sp.astype(int), pr_list_sp)
190 | F1_scores = np.divide(
191 | 2 * precision * recall,
192 | precision + recall,
193 | out=np.zeros_like(precision),
194 | where=(precision + recall) != 0,
195 | )
196 | optimal_threshold = thresholds[np.argmax(F1_scores)]
197 | fpr_optim = np.mean(normal_sp > optimal_threshold)
198 | fnr_optim = np.mean(anomaly_sp < optimal_threshold)
199 | print('thresh:', optimal_threshold)
200 | print('fpr:', fpr_optim, 'fnr', fnr_optim)
201 |
202 | auroc_px = round(roc_auc_score(gt_list_px, pr_list_px), 3)
203 | auroc_sp = round(roc_auc_score(gt_list_sp, pr_list_sp), 3)
204 | return auroc_px, auroc_sp, round(np.mean(aupro_list),3), time_list, pr_list_sp, gt_list_sp
205 |
206 |
207 | def test(_class_, args):
208 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
209 | print(device)
210 | print(_class_)
211 |
212 | Stage1_ckp_path = 'vit_version/checkpoints_Stage1/' + 'wres50_'+_class_+'.pth'
213 | Stage2_ckp_path = 'vit_version/checkpoints_Stage2/' + 'wres50_'+ _class_+'.pth'
214 | image_size = 256
215 | test_data = TestDataset(class_name=_class_, img_size=args.image_size, dataset_path=args.data_root)
216 | test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=False)
217 |
218 | encoder_AT, _ = clip_rar.load("ViT-B/16", download_root='vit_version/clip', device=torch.device("cpu"), image_size=args.image_size)
219 | encoder_AT = encoder_AT.to(device)
220 |     encoder_AT.load_state_dict(torch.load(Stage1_ckp_path)['encoder_AT']) # load the Stage-1 checkpoint
221 | encoder_AT.eval()
222 |
223 | # encoder_pre, _ = clip.load("ViT-L/14@336px", download_root='./clip', device=torch.device("cpu"), image_size=336)
224 | encoder_pre, _ = clip.load("ViT-B/16", download_root='vit_version/clip', device=torch.device("cpu"), image_size=args.image_size)
225 | encoder_pre = encoder_pre.to(device)
226 | encoder_pre.eval()
227 | # pdb.set_trace()
228 |
229 | bn = BN_layer(AttnBottleneck, 3)
230 | bn = bn.to(device)
231 | decoder_SS = de_VisionTransformer(input_resolution=256, patch_size=16, width=768, layers=6, heads=12, output_dim=512)
232 | decoder_SS = decoder_SS.to(device)
233 |
234 | SS_ckp = torch.load(Stage2_ckp_path)
235 | for k, v in list(SS_ckp['bn'].items()):
236 | if 'memory' in k:
237 | SS_ckp['bn'].pop(k)
238 | decoder_SS.load_state_dict(SS_ckp['decoder_ss'])
239 | bn.load_state_dict(SS_ckp['bn'])
240 |
241 | # total_params = compute_params([encoder_AT, bn, decoder_SS])
242 | # print(f'total params: {total_params/1e6} M') #{sum([x.nelement() for x in self.model.parameters()])/1000000.} M
243 |
244 | auroc_px, auroc_sp, aupro_px, time_list, preds, labels = evaluation_Stage2(encoder_AT, bn, decoder_SS, test_dataloader, device, _class_)
245 | print(_class_,':',auroc_px,',',auroc_sp,',',aupro_px)
246 | return auroc_px, auroc_sp, aupro_px, time_list, preds, labels
247 |
248 |
249 | def compute_params(nets):
250 | total_params = 0
251 | for net in nets:
252 | print(sum(p.numel() for p in net.parameters()))
253 | total_params += sum(p.numel() for p in net.parameters())
254 | return total_params
255 |
256 | def compute_pro(masks: ndarray, amaps: ndarray, num_th: int = 200) -> float:
257 |
258 |     """Compute the area under the per-region overlap (PRO) curve over the 0 to 0.3 FPR range
259 | Args:
260 | category (str): Category of product
261 | masks (ndarray): All binary masks in test. masks.shape -> (num_test_data, h, w)
262 | amaps (ndarray): All anomaly maps in test. amaps.shape -> (num_test_data, h, w)
263 | num_th (int, optional): Number of thresholds
264 | """
265 |
266 | assert isinstance(amaps, ndarray), "type(amaps) must be ndarray"
267 | assert isinstance(masks, ndarray), "type(masks) must be ndarray"
268 | assert amaps.ndim == 3, "amaps.ndim must be 3 (num_test_data, h, w)"
269 | assert masks.ndim == 3, "masks.ndim must be 3 (num_test_data, h, w)"
270 | assert amaps.shape == masks.shape, "amaps.shape and masks.shape must be same"
271 | assert set(masks.flatten()) == {0, 1}, "set(masks.flatten()) must be {0, 1}"
272 | assert isinstance(num_th, int), "type(num_th) must be int"
273 |
274 | df = pd.DataFrame([], columns=["pro", "fpr", "threshold"])
275 | binary_amaps = np.zeros_like(amaps, dtype=np.bool_)
276 |
277 | min_th = amaps.min()
278 | max_th = amaps.max()
279 | delta = (max_th - min_th) / num_th
280 |
281 | for th in np.arange(min_th, max_th, delta):
282 | binary_amaps[amaps <= th] = 0
283 | binary_amaps[amaps > th] = 1
284 |
285 | pros = []
286 | for binary_amap, mask in zip(binary_amaps, masks):
287 | for region in measure.regionprops(measure.label(mask)):
288 | axes0_ids = region.coords[:, 0]
289 | axes1_ids = region.coords[:, 1]
290 | tp_pixels = binary_amap[axes0_ids, axes1_ids].sum()
291 | pros.append(tp_pixels / region.area)
292 |
293 | inverse_masks = 1 - masks
294 | fp_pixels = np.logical_and(inverse_masks, binary_amaps).sum()
295 | fpr = fp_pixels / inverse_masks.sum()
296 |
297 | # df = df.append({"pro": mean(pros), "fpr": fpr, "threshold": th}, ignore_index=True) # old pandas version
298 | df = pd.concat([df, pd.DataFrame([{"pro": mean(pros), "fpr": fpr, "threshold": th}])], ignore_index=True)
299 |
300 |     # Keep thresholds with FPR below 0.3, then rescale that FPR range to [0, 1]
301 | df = df[df["fpr"] < 0.3]
302 | df["fpr"] = df["fpr"] / df["fpr"].max()
303 |
304 | pro_auc = auc(df["fpr"], df["pro"])
305 | return pro_auc
306 |
307 |
308 | # if __name__ == '__main__':
309 | # from utils.utils import setup_seed
310 | # setup_seed(111)
311 | # item_list = ['carpet', 'bottle', 'hazelnut', 'leather', 'cable', 'capsule', 'grid', 'pill',
312 | # 'transistor', 'metal_nut', 'screw','toothbrush', 'zipper', 'tile', 'wood']
313 | # # from MvTec3D import mvtec3d_classes
314 | # # item_list = mvtec3d_classes()
315 |
316 | # for i in item_list:
317 | # test(i)
318 |
319 | if __name__ == '__main__':
320 | # from main import setup_seed
321 | # setup_seed(111)
322 | parser = argparse.ArgumentParser()
323 | parser.add_argument('--data_root', type=str, default='/data4/tch/AD_data/mvtec3d')
324 | parser.add_argument('--image_size', type=int, default=256)
325 | args = parser.parse_args()
326 |
327 | dataset_name = get_dataset_name(args.data_root)
328 | if dataset_name == 'mvtec':
329 | from datasets.MvTec import TrainDataset, TestDataset, AnomalyDataset
330 | from datasets.MvTec import mvtec_classes
331 | item_list = mvtec_classes()
332 | elif dataset_name == 'VisA':
333 | from datasets.VisA import TrainDataset, TestDataset, AnomalyDataset
334 | from datasets.VisA import visa_classes
335 | item_list = visa_classes()
336 | elif dataset_name == 'mvtec3d':
337 | from datasets.MvTec3D import TrainDataset, TestDataset, AnomalyDataset
338 | from datasets.MvTec3D import mvtec3d_classes
339 | item_list = mvtec3d_classes()
340 |
341 | # all_time_list = []
342 | # for i in item_list:
343 | # time_list = test(i)
344 | # all_time_list += time_list
345 | # # print(len(all_time_list))
346 | # print('FPS: ', 1 / np.mean(all_time_list))
347 |
348 | p_auc_list, i_auc_list, p_pro_list = [], [], []
349 | label_list, pred_list = [], []
350 | for i in item_list:
351 | p_auc, i_auc, p_pro, _, preds, labels = test(i, args)
352 | p_auc_list.append(p_auc)
353 | i_auc_list.append(i_auc)
354 | p_pro_list.append(p_pro)
355 | label_list += list(labels)
356 | pred_list += list(preds)
357 | # all_time_list += time_list
358 | # print(len(all_time_list))
359 | # print('FPS: ', 1 / np.mean(all_time_list))
360 | i_auc_mean = np.mean(i_auc_list)
361 | p_auc_mean = np.mean(p_auc_list)
362 | p_pro_mean = np.mean(p_pro_list)
363 |
364 | normal = np.array(pred_list)[np.array(label_list)==0]
365 | anomaly = np.array(pred_list)[np.array(label_list)==1]
366 | os.makedirs(f'results/{dataset_name}',exist_ok=True)
367 | np.save(f'results/{dataset_name}/normal_scores_vit', normal)
368 | np.save(f'results/{dataset_name}/anomaly_scores_vit', anomaly)
369 | # np.save('normal_scores', normal)
370 | # np.save('anomaly_scores', anomaly)
371 | # normal = np.load('normal_scores.npy')
372 | # anomaly = np.load('anomaly_scores.npy')
373 | # pdb.set_trace()
374 | # print(anomaly)
375 |
376 | # width = 7
377 | # fig = plt.figure(figsize=(width, width*0.75))
378 | # print(f'normal: {normal.min():.4f} {normal.max():.4f} anomaly: {anomaly.min():.4f} {anomaly.max():.4f}')
379 | # plt.hist(normal,bins=20,label='normal sample',alpha=0.5)
380 | # plt.hist(anomaly,bins=20,label='anomalous sample',alpha=0.5)
381 | # plt.legend(fontsize=15)
382 | # plt.tick_params(labelsize=12.5)
383 | # plt.savefig(f'hist_SS.png')
384 |
385 |
386 | with open('results2.txt', 'a') as f:
387 | f.write(time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time()))+'\n')
388 | f.write('i_auc: ')
389 | for item in i_auc_list:
390 | f.write(f"{item:.4f}")
391 | f.write(' ')
392 | f.write(f'avg: {i_auc_mean:.4f}\n')
393 |
394 | f.write('p_auc: ')
395 | for item in p_auc_list:
396 | f.write(f"{item:.4f}")
397 | f.write(' ')
398 | f.write(f'avg: {p_auc_mean:.4f}\n')
399 |
400 | f.write('p_pro: ')
401 | for item in p_pro_list:
402 | f.write(f"{item:.4f}")
403 | f.write(' ')
404 | f.write(f'avg: {p_pro_mean:.4f}\n')
--------------------------------------------------------------------------------
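A quick sanity check of the PRO metric defined above. This is an illustrative sketch only: the synthetic masks and anomaly maps are made up here, and compute_pro is assumed to be importable from (or defined in) the test script above.

    import numpy as np
    # compute_pro is assumed to be in scope (defined in the test script above)

    rng = np.random.default_rng(0)
    h = w = 64
    masks = np.zeros((4, h, w), dtype=np.int64)
    masks[:, 16:32, 16:32] = 1                      # one square ground-truth defect per image
    amaps = rng.random((4, h, w))
    amaps[:, 16:32, 16:32] += 1.0                   # anomaly scores higher inside the defect
    pro_auc = compute_pro(masks, amaps, num_th=50)  # area under the PRO curve for FPR in [0, 0.3]
    print(f"PRO-AUC: {pro_auc:.4f}")                # expected to be high for this easy synthetic case
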
/models/de_resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 | import torch.nn as nn
4 | try:
5 | from torch.hub import load_state_dict_from_url
6 | except ImportError:
7 | from torch.utils.model_zoo import load_url as load_state_dict_from_url
8 | from typing import Type, Any, Callable, Union, List, Optional
9 | import pdb
10 |
11 |
12 | __all__ = ['ResNet', 'de_resnet18', 'de_resnet34', 'de_resnet50', 'resnet101',
13 |            'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
14 |            'de_wide_resnet50_2', 'de_wide_resnet101_2']
15 |
16 |
17 | model_urls = {
18 | 'resnet18': 'https://download.pytorch.org/models/resnet18-f37072fd.pth',
19 | 'resnet34': 'https://download.pytorch.org/models/resnet34-b627a593.pth',
20 | 'resnet50': 'https://download.pytorch.org/models/resnet50-0676ba61.pth',
21 | 'resnet101': 'https://download.pytorch.org/models/resnet101-63fe2227.pth',
22 | 'resnet152': 'https://download.pytorch.org/models/resnet152-394f9c45.pth',
23 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
24 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
25 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
26 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
27 | }
28 |
29 |
30 | def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
31 | """3x3 convolution with padding"""
32 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
33 | padding=dilation, groups=groups, bias=False, dilation=dilation)
34 |
35 |
36 | def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
37 | """1x1 convolution"""
38 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
39 |
40 | def deconv2x2(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.ConvTranspose2d:
41 |     """2x2 transposed convolution (upsamples the input when stride > 1)"""
42 | return nn.ConvTranspose2d(in_planes, out_planes, kernel_size=2, stride=stride,
43 | groups=groups, bias=False, dilation=dilation)
44 |
45 |
46 | class BasicBlock(nn.Module):
47 | expansion: int = 1
48 |
49 | def __init__(
50 | self,
51 | inplanes: int,
52 | planes: int,
53 | stride: int = 1,
54 | upsample: Optional[nn.Module] = None,
55 | groups: int = 1,
56 | base_width: int = 64,
57 | dilation: int = 1,
58 | norm_layer: Optional[Callable[..., nn.Module]] = None
59 | ) -> None:
60 | super(BasicBlock, self).__init__()
61 | if norm_layer is None:
62 | norm_layer = nn.BatchNorm2d
63 | if groups != 1 or base_width != 64:
64 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
65 | if dilation > 1:
66 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
67 |         # Both self.conv1 and self.upsample layers upsample the input when stride != 1 (via deconv2x2)
68 | if stride == 2:
69 | self.conv1 = deconv2x2(inplanes, planes, stride)
70 | else:
71 | self.conv1 = conv3x3(inplanes, planes, stride)
72 | self.bn1 = norm_layer(planes)
73 | self.relu = nn.ReLU(inplace=True)
74 | self.conv2 = conv3x3(planes, planes)
75 | self.bn2 = norm_layer(planes)
76 | self.upsample = upsample
77 | self.stride = stride
78 |
79 | def forward(self, x: Tensor) -> Tensor:
80 | identity = x
81 |
82 | out = self.conv1(x)
83 | out = self.bn1(out)
84 | out = self.relu(out)
85 |
86 | out = self.conv2(out)
87 | out = self.bn2(out)
88 |
89 | if self.upsample is not None:
90 | identity = self.upsample(x)
91 |
92 | out += identity
93 | out = self.relu(out)
94 |
95 | return out
96 |
97 |
98 | class Bottleneck(nn.Module):
99 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
100 | # while original implementation places the stride at the first 1x1 convolution(self.conv1)
101 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
102 | # This variant is also known as ResNet V1.5 and improves accuracy according to
103 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
104 |
105 | expansion: int = 4
106 |
107 | def __init__(
108 | self,
109 | inplanes: int,
110 | planes: int,
111 | stride: int = 1,
112 | upsample: Optional[nn.Module] = None,
113 | groups: int = 1,
114 | base_width: int = 64,
115 | dilation: int = 1,
116 | norm_layer: Optional[Callable[..., nn.Module]] = None
117 | ) -> None:
118 | super(Bottleneck, self).__init__()
119 | if norm_layer is None:
120 | norm_layer = nn.BatchNorm2d
121 | width = int(planes * (base_width / 64.)) * groups
122 |         # Both self.conv2 and self.upsample layers upsample the input when stride != 1 (via deconv2x2)
123 | self.conv1 = conv1x1(inplanes, width)
124 | self.bn1 = norm_layer(width)
125 | if stride == 2:
126 | self.conv2 = deconv2x2(width, width, stride, groups, dilation)
127 | else:
128 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
129 | self.bn2 = norm_layer(width)
130 | self.conv3 = conv1x1(width, planes * self.expansion)
131 | self.bn3 = norm_layer(planes * self.expansion)
132 | self.relu = nn.ReLU(inplace=True)
133 | self.upsample = upsample
134 | self.stride = stride
135 |
136 | def forward(self, x: Tensor) -> Tensor:
137 | identity = x
138 |
139 | out = self.conv1(x)
140 | out = self.bn1(out)
141 | out = self.relu(out)
142 |
143 | out = self.conv2(out)
144 | out = self.bn2(out)
145 | out = self.relu(out)
146 |
147 | out = self.conv3(out)
148 | out = self.bn3(out)
149 |
150 | if self.upsample is not None:
151 | identity = self.upsample(x)
152 |
153 | out += identity
154 | out = self.relu(out)
155 |
156 | return out
157 |
158 |
159 | class ResNet(nn.Module):
160 |
161 | def __init__(
162 | self,
163 | block: Type[Union[BasicBlock, Bottleneck]],
164 | layers: List[int],
165 | num_classes: int = 1000,
166 | zero_init_residual: bool = False,
167 | groups: int = 1,
168 | width_per_group: int = 64,
169 | replace_stride_with_dilation: Optional[List[bool]] = None,
170 | norm_layer: Optional[Callable[..., nn.Module]] = None
171 | ) -> None:
172 | super(ResNet, self).__init__()
173 | if norm_layer is None:
174 | norm_layer = nn.BatchNorm2d
175 | self._norm_layer = norm_layer
176 |
177 | self.inplanes = 512 * block.expansion
178 | self.dilation = 1
179 | if replace_stride_with_dilation is None:
180 | # each element in the tuple indicates if we should replace
181 | # the 2x2 stride with a dilated convolution instead
182 | replace_stride_with_dilation = [False, False, False]
183 | if len(replace_stride_with_dilation) != 3:
184 | raise ValueError("replace_stride_with_dilation should be None "
185 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
186 | self.groups = groups
187 | self.base_width = width_per_group
188 | #self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
189 | # bias=False)
190 | #self.bn1 = norm_layer(self.inplanes)
191 | #self.relu = nn.ReLU(inplace=True)
192 | #self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
193 | self.layer1 = self._make_layer(block, 256, layers[0], stride=2)
194 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
195 | dilate=replace_stride_with_dilation[0])
196 | self.layer3 = self._make_layer(block, 64, layers[2], stride=2,
197 | dilate=replace_stride_with_dilation[1])
198 | #self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
199 | # dilate=replace_stride_with_dilation[2])
200 | #self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
201 | #self.fc = nn.Linear(512 * block.expansion, num_classes)
202 |
203 | for m in self.modules():
204 | if isinstance(m, nn.Conv2d):
205 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
206 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
207 | nn.init.constant_(m.weight, 1)
208 | nn.init.constant_(m.bias, 0)
209 |
210 | # Zero-initialize the last BN in each residual branch,
211 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
212 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
213 | if zero_init_residual:
214 | for m in self.modules():
215 | if isinstance(m, Bottleneck):
216 | nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type]
217 | elif isinstance(m, BasicBlock):
218 | nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type]
219 |
220 | def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int,
221 | stride: int = 1, dilate: bool = False) -> nn.Sequential:
222 | norm_layer = self._norm_layer
223 | upsample = None
224 | previous_dilation = self.dilation
225 | if dilate:
226 | self.dilation *= stride
227 | stride = 1
228 | if stride != 1 or self.inplanes != planes * block.expansion:
229 | upsample = nn.Sequential(
230 | deconv2x2(self.inplanes, planes * block.expansion, stride),
231 | norm_layer(planes * block.expansion),
232 | )
233 |
234 | layers = []
235 | layers.append(block(self.inplanes, planes, stride, upsample, self.groups,
236 | self.base_width, previous_dilation, norm_layer))
237 | self.inplanes = planes * block.expansion
238 | for _ in range(1, blocks):
239 | layers.append(block(self.inplanes, planes, groups=self.groups,
240 | base_width=self.base_width, dilation=self.dilation,
241 | norm_layer=norm_layer))
242 |
243 | return nn.Sequential(*layers)
244 |
245 | def _forward_impl(self, x: Tensor) -> Tensor:
246 | # See note [TorchScript super()]
247 | #x = self.conv1(x)
248 | #x = self.bn1(x)
249 | #x = self.relu(x)
250 | #x = self.maxpool(x)
251 | # pdb.set_trace()
252 | feature_a = self.layer1(x) # 512*8*8->256*16*16
253 | feature_b = self.layer2(feature_a) # 256*16*16->128*32*32
254 | feature_c = self.layer3(feature_b) # 128*32*32->64*64*64
255 | #feature_d = self.layer4(feature_c) # 64*64*64->128*32*32
256 |
257 | #x = self.avgpool(feature_d)
258 | #x = torch.flatten(x, 1)
259 | #x = self.fc(x)
260 |
261 | return [feature_c, feature_b, feature_a]
262 |
263 | def forward(self, x: Tensor) -> Tensor:
264 | return self._forward_impl(x)
265 |
266 |
267 | def _resnet(
268 | arch: str,
269 | block: Type[Union[BasicBlock, Bottleneck]],
270 | layers: List[int],
271 | pretrained: bool,
272 | progress: bool,
273 | **kwargs: Any
274 | ) -> ResNet:
275 | model = ResNet(block, layers, **kwargs)
276 | if pretrained:
277 | state_dict = load_state_dict_from_url(model_urls[arch],
278 | progress=progress)
279 | #for k,v in list(state_dict.items()):
280 | # if 'layer4' in k or 'fc' in k:
281 | # state_dict.pop(k)
282 | model.load_state_dict(state_dict)
283 | return model
284 |
285 |
286 | def de_resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
287 | r"""ResNet-18 model from
288 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_.
289 | Args:
290 | pretrained (bool): If True, returns a model pre-trained on ImageNet
291 | progress (bool): If True, displays a progress bar of the download to stderr
292 | """
293 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
294 | **kwargs)
295 |
296 |
297 | def de_resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
298 | r"""ResNet-34 model from
299 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_.
300 | Args:
301 | pretrained (bool): If True, returns a model pre-trained on ImageNet
302 | progress (bool): If True, displays a progress bar of the download to stderr
303 | """
304 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
305 | **kwargs)
306 |
307 |
308 | def de_resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
309 | r"""ResNet-50 model from
310 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_.
311 | Args:
312 | pretrained (bool): If True, returns a model pre-trained on ImageNet
313 | progress (bool): If True, displays a progress bar of the download to stderr
314 | """
315 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
316 | **kwargs)
317 |
318 |
319 | def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
320 | r"""ResNet-101 model from
321 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_.
322 | Args:
323 | pretrained (bool): If True, returns a model pre-trained on ImageNet
324 | progress (bool): If True, displays a progress bar of the download to stderr
325 | """
326 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
327 | **kwargs)
328 |
329 |
330 | def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
331 | r"""ResNet-152 model from
332 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_.
333 | Args:
334 | pretrained (bool): If True, returns a model pre-trained on ImageNet
335 | progress (bool): If True, displays a progress bar of the download to stderr
336 | """
337 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
338 | **kwargs)
339 |
340 |
341 | def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
342 | r"""ResNeXt-50 32x4d model from
343 |     `"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/abs/1611.05431>`_.
344 | Args:
345 | pretrained (bool): If True, returns a model pre-trained on ImageNet
346 | progress (bool): If True, displays a progress bar of the download to stderr
347 | """
348 | kwargs['groups'] = 32
349 | kwargs['width_per_group'] = 4
350 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
351 | pretrained, progress, **kwargs)
352 |
353 |
354 | def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
355 | r"""ResNeXt-101 32x8d model from
356 |     `"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/abs/1611.05431>`_.
357 | Args:
358 | pretrained (bool): If True, returns a model pre-trained on ImageNet
359 | progress (bool): If True, displays a progress bar of the download to stderr
360 | """
361 | kwargs['groups'] = 32
362 | kwargs['width_per_group'] = 8
363 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
364 | pretrained, progress, **kwargs)
365 |
366 |
367 | def de_wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
368 | r"""Wide ResNet-50-2 model from
369 |     `"Wide Residual Networks" <https://arxiv.org/abs/1605.07146>`_.
370 |     The model is the same as ResNet except that the number of bottleneck channels
371 |     is twice as large in every block. The number of channels in the outer 1x1
372 |     convolutions is the same, e.g. the last block in ResNet-50 has 2048-512-2048
373 |     channels, while in Wide ResNet-50-2 it has 2048-1024-2048.
374 | Args:
375 | pretrained (bool): If True, returns a model pre-trained on ImageNet
376 | progress (bool): If True, displays a progress bar of the download to stderr
377 | """
378 | kwargs['width_per_group'] = 64 * 2
379 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
380 | pretrained, progress, **kwargs)
381 |
382 |
383 | def de_wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
384 | r"""Wide ResNet-101-2 model from
385 |     `"Wide Residual Networks" <https://arxiv.org/abs/1605.07146>`_.
386 |     The model is the same as ResNet except that the number of bottleneck channels
387 |     is twice as large in every block. The number of channels in the outer 1x1
388 |     convolutions is the same, e.g. the last block in ResNet-50 has 2048-512-2048
389 |     channels, while in Wide ResNet-50-2 it has 2048-1024-2048.
390 | Args:
391 | pretrained (bool): If True, returns a model pre-trained on ImageNet
392 | progress (bool): If True, displays a progress bar of the download to stderr
393 | """
394 | kwargs['width_per_group'] = 64 * 2
395 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
396 | pretrained, progress, **kwargs)
--------------------------------------------------------------------------------
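A rough illustration of how this decoder runs a ResNet "in reverse", upsampling the deepest bottleneck feature back to three scales. This is a minimal sketch: the import path follows the repository tree, and the expected shapes follow the channel/resolution comments in _forward_impl.

    import torch
    from models.de_resnet import de_wide_resnet50_2  # import path assumed from the repo layout

    decoder = de_wide_resnet50_2(pretrained=False)
    bottleneck = torch.randn(2, 2048, 8, 8)            # deepest feature map (512 * Bottleneck.expansion channels)
    with torch.no_grad():
        f_large, f_mid, f_small = decoder(bottleneck)  # returned largest resolution first
    print(f_large.shape, f_mid.shape, f_small.shape)
    # torch.Size([2, 256, 64, 64]) torch.Size([2, 512, 32, 32]) torch.Size([2, 1024, 16, 16])
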
/vit_version/clip/model.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | from typing import Tuple, Union
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn.functional as F
7 | from torch import nn
8 | from .utils import to_2tuple
9 | import math
10 | import pdb
11 | from models.recons_net import patch_to_tensor
12 |
13 |
14 | class Bottleneck(nn.Module):
15 | expansion = 4
16 |
17 | def __init__(self, inplanes, planes, stride=1):
18 | super().__init__()
19 |
20 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
21 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
22 | self.bn1 = nn.BatchNorm2d(planes)
23 | self.relu1 = nn.ReLU(inplace=True)
24 |
25 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
26 | self.bn2 = nn.BatchNorm2d(planes)
27 | self.relu2 = nn.ReLU(inplace=True)
28 |
29 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
30 |
31 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
32 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
33 | self.relu3 = nn.ReLU(inplace=True)
34 |
35 | self.downsample = None
36 | self.stride = stride
37 |
38 | if stride > 1 or inplanes != planes * Bottleneck.expansion:
39 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
40 | self.downsample = nn.Sequential(OrderedDict([
41 | ("-1", nn.AvgPool2d(stride)),
42 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
43 | ("1", nn.BatchNorm2d(planes * self.expansion))
44 | ]))
45 |
46 | def forward(self, x: torch.Tensor):
47 | identity = x
48 |
49 | out = self.relu1(self.bn1(self.conv1(x)))
50 | out = self.relu2(self.bn2(self.conv2(out)))
51 | out = self.avgpool(out)
52 | out = self.bn3(self.conv3(out))
53 |
54 | if self.downsample is not None:
55 | identity = self.downsample(x)
56 |
57 | out += identity
58 | out = self.relu3(out)
59 | return out
60 |
61 |
62 | class AttentionPool2d(nn.Module):
63 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
64 | super().__init__()
65 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
66 | self.k_proj = nn.Linear(embed_dim, embed_dim)
67 | self.q_proj = nn.Linear(embed_dim, embed_dim)
68 | self.v_proj = nn.Linear(embed_dim, embed_dim)
69 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
70 | self.num_heads = num_heads
71 |
72 | def forward(self, x):
73 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
74 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
75 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
76 | x, _ = F.multi_head_attention_forward(
77 | query=x, key=x, value=x,
78 | embed_dim_to_check=x.shape[-1],
79 | num_heads=self.num_heads,
80 | q_proj_weight=self.q_proj.weight,
81 | k_proj_weight=self.k_proj.weight,
82 | v_proj_weight=self.v_proj.weight,
83 | in_proj_weight=None,
84 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
85 | bias_k=None,
86 | bias_v=None,
87 | add_zero_attn=False,
88 | dropout_p=0,
89 | out_proj_weight=self.c_proj.weight,
90 | out_proj_bias=self.c_proj.bias,
91 | use_separate_proj_weight=True,
92 | training=self.training,
93 | need_weights=False
94 | )
95 |
96 | return x[0]
97 |
98 |
99 | class ModifiedResNet(nn.Module):
100 | """
101 | A ResNet class that is similar to torchvision's but contains the following changes:
102 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
103 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
104 | - The final pooling layer is a QKV attention instead of an average pool
105 | """
106 |
107 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
108 | super().__init__()
109 | self.output_dim = output_dim
110 | self.input_resolution = input_resolution
111 |
112 | # the 3-layer stem
113 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
114 | self.bn1 = nn.BatchNorm2d(width // 2)
115 | self.relu1 = nn.ReLU(inplace=True)
116 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
117 | self.bn2 = nn.BatchNorm2d(width // 2)
118 | self.relu2 = nn.ReLU(inplace=True)
119 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
120 | self.bn3 = nn.BatchNorm2d(width)
121 | self.relu3 = nn.ReLU(inplace=True)
122 | self.avgpool = nn.AvgPool2d(2)
123 |
124 | # residual layers
125 | self._inplanes = width # this is a *mutable* variable used during construction
126 | self.layer1 = self._make_layer(width, layers[0])
127 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
128 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
129 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
130 |
131 | embed_dim = width * 32 # the ResNet feature dimension
132 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
133 |
134 | def _make_layer(self, planes, blocks, stride=1):
135 | layers = [Bottleneck(self._inplanes, planes, stride)]
136 |
137 | self._inplanes = planes * Bottleneck.expansion
138 | for _ in range(1, blocks):
139 | layers.append(Bottleneck(self._inplanes, planes))
140 |
141 | return nn.Sequential(*layers)
142 |
143 | def forward(self, x):
144 | def stem(x):
145 | x = self.relu1(self.bn1(self.conv1(x)))
146 | x = self.relu2(self.bn2(self.conv2(x)))
147 | x = self.relu3(self.bn3(self.conv3(x)))
148 | x = self.avgpool(x)
149 | return x
150 |
151 | x = x.type(self.conv1.weight.dtype)
152 | x = stem(x)
153 | x1 = self.layer1(x)
154 | x2 = self.layer2(x1)
155 | x3 = self.layer3(x2)
156 | x4 = self.layer4(x3)
157 | y = self.attnpool(x4)
158 |
159 | return y, [x, x1, x2, x3, x4]
160 |
161 |
162 | class LayerNorm(nn.LayerNorm):
163 | """Subclass torch's LayerNorm to handle fp16."""
164 |
165 | def forward(self, x: torch.Tensor):
166 | orig_type = x.dtype
167 | ret = super().forward(x.type(torch.float32))
168 | return ret.type(orig_type)
169 |
170 |
171 | class QuickGELU(nn.Module):
172 | def forward(self, x: torch.Tensor):
173 | return x * torch.sigmoid(1.702 * x)
174 |
175 |
176 | class ResidualAttentionBlock(nn.Module):
177 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
178 | super().__init__()
179 |
180 | self.attn = nn.MultiheadAttention(d_model, n_head)
181 | self.ln_1 = LayerNorm(d_model)
182 | self.mlp = nn.Sequential(OrderedDict([
183 | ("c_fc", nn.Linear(d_model, d_model * 4)),
184 | ("gelu", QuickGELU()),
185 | ("c_proj", nn.Linear(d_model * 4, d_model))
186 | ]))
187 | self.ln_2 = LayerNorm(d_model)
188 | self.attn_mask = attn_mask
189 |
190 | def attention(self, x: torch.Tensor):
191 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
192 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
193 |
194 | def forward(self, x: torch.Tensor):
195 | # pdb.set_trace()
196 | x = x + self.attention(self.ln_1(x))
197 | x = x + self.mlp(self.ln_2(x))
198 | return x
199 |
200 |
201 | class Transformer(nn.Module):
202 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
203 | super().__init__()
204 | self.width = width
205 | self.layers = layers
206 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
207 |
208 | # def forward(self, x: torch.Tensor):
209 | # return self.resblocks(x)
210 |
211 | def forward(self, x, f_list=[]):
212 | # pdb.set_trace()
213 | out_tokens = []
214 | idx = 0
215 | for r in self.resblocks:
216 | # pdb.set_trace()
217 | idx+=1
218 | x = r(x)
219 | if idx in f_list:
220 | if len(x)==2:
221 | out_tokens.append(x[0])
222 | out_tokens.append(x[1])
223 | else:
224 | out_tokens.append(x.permute(1,0,2))
225 | return x, out_tokens #self.resblocks(x)
226 |
227 |
228 | class VisionTransformer(nn.Module):
229 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
230 | super().__init__()
231 | self.input_resolution = input_resolution
232 | self.grid_size = (input_resolution // patch_size) # modify
233 | self.output_dim = output_dim
234 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
235 |
236 | scale = width ** -0.5
237 | self.class_embedding = nn.Parameter(scale * torch.randn(width))
238 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
239 | self.ln_pre = LayerNorm(width)
240 |
241 | self.transformer = Transformer(width, layers, heads)
242 |
243 | self.ln_post = LayerNorm(width)
244 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
245 |
246 | def forward(self, x: torch.Tensor, f_list: list):
247 | # pdb.set_trace()
248 | x = self.conv1(x) # shape = [*, width, grid, grid]
249 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
250 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
251 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
252 | x = x + self.positional_embedding.to(x.dtype)
253 | # pdb.set_trace()
254 | x = self.ln_pre(x)
255 |
256 | x = x.permute(1, 0, 2) # NLD -> LND
257 | x, patch_tokens = self.transformer(x, f_list)
258 | x = x.permute(1, 0, 2) # LND -> NLD
259 |
260 | x = self.ln_post(x[:, 0, :])
261 | # pdb.set_trace()
262 | if self.proj is not None:
263 | x = x @ self.proj
264 |
265 | assert len(patch_tokens) == 3
266 | patch_tokens = [patch_to_tensor(x[:,1:]) for x in patch_tokens]
267 |
268 | return x, patch_tokens
269 |
270 |
271 | class CLIP(nn.Module):
272 | def __init__(self,
273 | embed_dim: int,
274 | # vision
275 | image_resolution: int,
276 | vision_layers: Union[Tuple[int, int, int, int], int],
277 | vision_width: int,
278 | vision_patch_size: int,
279 | # text
280 | context_length: int,
281 | vocab_size: int,
282 | transformer_width: int,
283 | transformer_heads: int,
284 | transformer_layers: int
285 | ):
286 | super().__init__()
287 |
288 | self.context_length = context_length
289 |
290 | if isinstance(vision_layers, (tuple, list)):
291 | vision_heads = vision_width * 32 // 64
292 | self.visual = ModifiedResNet(
293 | layers=vision_layers,
294 | output_dim=embed_dim,
295 | heads=vision_heads,
296 | input_resolution=image_resolution,
297 | width=vision_width
298 | )
299 | else:
300 | vision_heads = vision_width // 64
301 | self.visual = VisionTransformer(
302 | input_resolution=image_resolution,
303 | patch_size=vision_patch_size,
304 | width=vision_width,
305 | layers=vision_layers,
306 | heads=vision_heads,
307 | output_dim=embed_dim
308 | )
309 |
310 | self.transformer = Transformer(
311 | width=transformer_width,
312 | layers=transformer_layers,
313 | heads=transformer_heads,
314 | attn_mask=self.build_attention_mask()
315 | )
316 |
317 | self.vocab_size = vocab_size
318 | self.token_embedding = nn.Embedding(vocab_size, transformer_width)
319 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
320 | self.ln_final = LayerNorm(transformer_width)
321 |
322 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
323 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
324 |
325 | self.initialize_parameters()
326 |
327 | def initialize_parameters(self):
328 | nn.init.normal_(self.token_embedding.weight, std=0.02)
329 | nn.init.normal_(self.positional_embedding, std=0.01)
330 |
331 | if isinstance(self.visual, ModifiedResNet):
332 | if self.visual.attnpool is not None:
333 | std = self.visual.attnpool.c_proj.in_features ** -0.5
334 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
335 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
336 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
337 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
338 |
339 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
340 | for name, param in resnet_block.named_parameters():
341 | if name.endswith("bn3.weight"):
342 | nn.init.zeros_(param)
343 |
344 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
345 | attn_std = self.transformer.width ** -0.5
346 | fc_std = (2 * self.transformer.width) ** -0.5
347 | for block in self.transformer.resblocks:
348 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
349 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
350 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
351 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
352 |
353 | if self.text_projection is not None:
354 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
355 |
356 | def build_attention_mask(self):
357 | # lazily create causal attention mask, with full attention between the vision tokens
358 | # pytorch uses additive attention mask; fill with -inf
359 | mask = torch.empty(self.context_length, self.context_length)
360 | mask.fill_(float("-inf"))
361 | mask.triu_(1) # zero out the lower diagonal
362 | return mask
363 |
364 | @property
365 | def dtype(self):
366 | return self.visual.conv1.weight.dtype
367 |
368 | def encode_image(self, image, f_list):
369 | # pdb.set_trace()
370 | return self.visual(image.type(self.dtype), f_list)
371 |
372 | # def encode_text(self, text):
373 | # # pdb.set_trace()
374 | # x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
375 |
376 | # x = x + self.positional_embedding.type(self.dtype)
377 | # x = x.permute(1, 0, 2) # NLD -> LND
378 | # x, _ = self.transformer(x)
379 | # x = x.permute(1, 0, 2) # LND -> NLD
380 | # x = self.ln_final(x).type(self.dtype)
381 |
382 | # # x.shape = [batch_size, n_ctx, transformer.width]
383 | # # take features from the eot embedding (eot_token is the highest number in each sequence)
384 | # x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
385 |
386 | # return x
387 |
388 | def encode_text(self, prompts, tokenized_prompts):
389 |
390 | x = prompts + self.positional_embedding.type(self.dtype)
391 |
392 | x = x.permute(1, 0, 2) # NLD -> LND
393 | x, _ = self.transformer(x)
394 | x = x.permute(1, 0, 2) # LND -> NLD
395 | x = self.ln_final(x).type(self.dtype)
396 |
397 | x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection
398 |
399 | return x
400 |
401 | def forward(self, image, text):
402 | image_features = self.encode_image(image)
403 | text_features = self.encode_text(text)
404 |
405 | # normalized features
406 | image_features = image_features / image_features.norm(dim=1, keepdim=True)
407 | text_features = text_features / text_features.norm(dim=1, keepdim=True)
408 |
409 | # cosine similarity as logits
410 | logit_scale = self.logit_scale.exp()
411 | logits_per_image = logit_scale * image_features @ text_features.t()
412 | logits_per_text = logits_per_image.t()
413 |
414 | # shape = [global_batch_size, global_batch_size]
415 | return logits_per_image, logits_per_text
416 |
417 |
418 | def convert_weights(model: nn.Module):
419 | """Convert applicable model parameters to fp16"""
420 |
421 | def _convert_weights_to_fp16(l):
422 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
423 | l.weight.data = l.weight.data.half()
424 | if l.bias is not None:
425 | l.bias.data = l.bias.data.half()
426 |
427 | if isinstance(l, nn.MultiheadAttention):
428 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
429 | tensor = getattr(l, attr)
430 | if tensor is not None:
431 | tensor.data = tensor.data.half()
432 |
433 | for name in ["text_projection", "proj"]:
434 | if hasattr(l, name):
435 | attr = getattr(l, name)
436 | if attr is not None:
437 | attr.data = attr.data.half()
438 |
439 | model.apply(_convert_weights_to_fp16)
440 |
441 |
442 | def build_model(state_dict: dict, image_size: int):
443 | vit = "visual.proj" in state_dict
444 | # pdb.set_trace()
445 | if vit:
446 | vision_width = state_dict["visual.conv1.weight"].shape[0]
447 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
448 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
449 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
450 | image_resolution = vision_patch_size * grid_size
451 | else:
452 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
453 | vision_layers = tuple(counts)
454 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
455 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
456 | vision_patch_size = None
457 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
458 | image_resolution = output_width * 32
459 |
460 | embed_dim = state_dict["text_projection"].shape[1]
461 | context_length = state_dict["positional_embedding"].shape[0]
462 | vocab_size = state_dict["token_embedding.weight"].shape[0]
463 | transformer_width = state_dict["ln_final.weight"].shape[0]
464 | transformer_heads = transformer_width // 64
465 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
466 |
467 | # model = CLIP(
468 | # embed_dim,
469 | # image_resolution, vision_layers, vision_width, vision_patch_size,
470 | # context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
471 | # )
472 |
473 | model = CLIP(
474 | embed_dim,
475 | image_size, vision_layers, vision_width, vision_patch_size,
476 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
477 | )
478 |
479 | for key in ["input_resolution", "context_length", "vocab_size"]:
480 | if key in state_dict:
481 | del state_dict[key]
482 | # pdb.set_trace()
483 | resize_pos_embed(state_dict, model) # modify
484 | convert_weights(model)
485 | model.load_state_dict(state_dict)
486 | return model.eval()
487 |
488 |
489 | def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True):
490 | # Rescale the grid of position embeddings when loading from state_dict
491 | flag = 1
492 | old_pos_embed = state_dict.get('visual.positional_embedding', None)
493 | if old_pos_embed is None:
494 | flag = 0
495 | old_pos_embed = state_dict.get('visual.attnpool.positional_embedding', None)
496 | if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):
497 | return
498 | grid_size = to_2tuple(model.visual.grid_size)
499 | extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
500 | new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
501 | if new_seq_len == old_pos_embed.shape[0]:
502 | return
503 |
504 | if extra_tokens:
505 | pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
506 | else:
507 | pos_emb_tok, pos_emb_img = None, old_pos_embed
508 | old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
509 |
510 |     print(f'Resizing position embedding grid-size from {old_grid_size} to {grid_size}')
511 | pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
512 | pos_emb_img = F.interpolate(
513 | pos_emb_img,
514 | size=grid_size,
515 | mode=interpolation,
516 | # antialias=antialias,
517 | align_corners=False,
518 | )
519 | pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
520 | if pos_emb_tok is not None:
521 | new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
522 | else:
523 | new_pos_embed = pos_emb_img
524 | if flag:
525 | state_dict['visual.positional_embedding'] = new_pos_embed
526 | else:
527 | state_dict['visual.attnpool.positional_embedding'] = new_pos_embed
528 |
--------------------------------------------------------------------------------
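To see how resize_pos_embed above adapts a pretrained CLIP positional embedding to a different input resolution, here is a minimal sketch. The import path and the SimpleNamespace stand-in for a model are assumptions made purely for illustration (run from the repository root so the package imports resolve).

    import torch
    from types import SimpleNamespace
    from vit_version.clip.model import resize_pos_embed  # assumed import path

    width = 768
    # A ViT-B/16 checkpoint at 224px has a 14x14 grid plus one class token: 197 positions.
    state_dict = {'visual.positional_embedding': torch.randn(14 * 14 + 1, width)}
    # Target model expects 256px / patch 16 -> 16x16 grid; only visual.grid_size is consulted here.
    fake_model = SimpleNamespace(visual=SimpleNamespace(grid_size=16))
    resize_pos_embed(state_dict, fake_model)
    print(state_dict['visual.positional_embedding'].shape)  # torch.Size([257, 768])
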