├── models
│   ├── __init__.py
│   ├── cifar10_model.py
│   ├── mvtec_base_model.py
│   ├── shanghaitech_base_model.py
│   ├── mvtec_model.py
│   └── shanghaitech_model.py
├── datasets
│   ├── __init__.py
│   ├── data_manager.py
│   ├── base.py
│   ├── cifar10.py
│   ├── shanghaitech.py
│   ├── mvtec.py
│   └── shanghaitech_test.py
├── trainers
│   ├── __init__.py
│   ├── trainer_shanghaitech.py
│   ├── train_cifar10.py
│   └── trainer_mvtec.py
├── .gitignore
├── images
│   ├── tb1.png
│   ├── tb2.png
│   ├── tb3.png
│   └── mocca.png
├── requirements.txt
├── README.md
├── test_multiple_models_cifar10.py
├── main_cifar10.py
├── utils.py
├── main_shanghaitech.py
└── main_mvtec.py
/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/trainers/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | bin/*
2 | *.sh
3 | output
4 | *.pyc
5 | *.log
6 | *__pycache__*
--------------------------------------------------------------------------------
/images/tb1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fvmassoli/mocca-anomaly-detection/HEAD/images/tb1.png
--------------------------------------------------------------------------------
/images/tb2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fvmassoli/mocca-anomaly-detection/HEAD/images/tb2.png
--------------------------------------------------------------------------------
/images/tb3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fvmassoli/mocca-anomaly-detection/HEAD/images/tb3.png
--------------------------------------------------------------------------------
/images/mocca.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fvmassoli/mocca-anomaly-detection/HEAD/images/mocca.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==1.19.5
2 | opencv-python==4.5.1.48
3 | pandas==1.0.4
4 | Pillow==7.2.0
5 | pudb==2019.2
6 | # python==3.6.9 (interpreter version, not a pip-installable package)
7 | scikit-build==0.11.1
8 | scikit-image==0.14.2
9 | scikit-learn==0.23.2
10 | sklearn==0.0
11 | tensorboard==1.14.0
12 | tensorboardX==2.0
13 | torch==1.4.0
14 | torchvision==0.4.2
15 | tqdm==4.56.0
16 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MOCCA: Multi-Layer One-Class ClassificAtion for Anomaly Detection
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 | This repository contains the code relative to the paper "[MOCCA: Multi-Layer One-Class ClassificAtion for Anomaly Detection](https://ieeexplore.ieee.org/document/9640579)" by Fabio Valerio Massoli (ISTI - CNR), Fabrizio Falchi (ISTI - CNR), Alperen Kantarci (ITU), Şeymanur Akti (ITU), Hazim Kemal Ekenel (ITU), Giuseppe Amato (ISTI - CNR).
11 |
12 | It reports a new technique to detect anomalies based on a layer-wise paradigm to exploit the features maps generated at different depths of a Deep Learning model.
13 |
14 | The paper has been accepted for publication in the [IEEE Transactions on Neural Networks and Learning Systems, Special Issue on Deep Learning for Anomaly Detection](https://ieeexplore.ieee.org/document/9640579).
15 |
16 | DOI: [10.1109/TNNLS.2021.3130074](https://doi.org/10.1109/TNNLS.2021.3130074).
17 |
18 | **Please note:**
19 | We are researchers, not a software company, and have no personnel devoted to documenting and maintaining this research code. Therefore, this code is offered "AS IS". Exact reproduction of the numbers in the paper depends on exactly reproducing many factors, including the versions of all software dependencies and the choice of underlying hardware (GPU model, etc.). You should therefore expect to re-tune your hyperparameters slightly for your setup.
20 |
21 |
22 | ## How to run the code
23 |
24 | Before running the code, make sure your system has the required packages installed. You can find them listed in the [requirements.txt](https://github.com/fvmassoli/mocca-anomaly-detection/blob/main/requirements.txt) file.
25 |
26 |
27 | Minimal usage (CIFAR10):
28 |
29 | ```
30 | python main_cifar10.py -ptr -tr -tt -zl 128 -nc -dp
31 | ```
32 |
33 | Minimal usage (MVTec):
34 |
35 | ```
36 | python main_mvtec.py -ptr -tr -tt -zl 128 -nc -dp --use-selector
37 | ```
38 |
39 | Minimal usage (ShanghaiTech):
40 |
41 | ```
42 | python main_shanghaitech.py -dp -ee -tt -zl 1024 -ll -use
43 | ```
44 |
45 | ## Reference
46 | For all the details about the training procedure and the experimental results, please have a look at the [paper](https://arxiv.org/abs/2012.12111).
47 |
48 | To cite our work, please use the following form
49 |
50 | ```
51 | @article{massoli2021mocca,
52 | title={MOCCA: Multilayer One-Class Classification for Anomaly Detection},
53 | author={Massoli, Fabio Valerio and Falchi, Fabrizio and Kantarci, Alperen and Akti, {\c{S}}eymanur and Ekenel, Hazim Kemal and Amato, Giuseppe},
54 | journal={IEEE Transactions on Neural Networks and Learning Systems},
55 | year={2021},
56 | publisher={IEEE}
57 | }
58 | ```
59 |
60 | ## Contacts
61 | If you have any question about our work, please contact [Dr. Fabio Valerio Massoli](mailto:fabio.massoli@isti.cnr.it).
62 |
63 | Have fun! :-D
64 |
--------------------------------------------------------------------------------
/datasets/data_manager.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import logging
4 |
5 | from .mvtec import MVTec_DataHolder
6 | from .cifar10 import CIFAR10_DataHolder
7 | from .shanghaitech import ShanghaiTech_DataHolder
8 |
9 |
10 | AVAILABLE_DATASETS = ('cifar10', 'ShanghaiTech', 'MVTec_Anomaly')
11 |
12 |
13 | class DataManager(object):
14 | """"Init class to manage and load data
15 |
16 | """
17 | def __init__(self, dataset_name: str, data_path: str, normal_class: int, clip_length: int=16, only_test: bool=False):
18 | """Init the DataManager
19 |
20 | Parameters
21 | ----------
22 | dataset_name : str
23 | Name of the dataset
24 | data_path : str
25 | Path to the dataset
26 |         normal_class : int
27 |             Index of the normal class (for MVTec, this is the category name as a string)
28 |         clip_length: int
29 |             Number of video frames in each clip (ShanghaiTech only)
30 |         only_test : bool
31 |             True if we are in test mode, False otherwise
32 |
33 | """
34 | self.dataset_name = dataset_name
35 | self.data_path = data_path
36 | self.normal_class = normal_class
37 | self.clip_length = clip_length
38 | self.only_test = only_test
39 |
40 | # Immediately check if the data are available
41 | self.__check_dataset()
42 |
43 | def __check_dataset(self) -> None:
44 | """Checks if the required dataset is available
45 |
46 | """
47 | assert self.dataset_name in AVAILABLE_DATASETS, f"{self.dataset_name} dataset is not available"
48 |         assert os.path.exists(self.data_path), f"{self.dataset_name} dataset not found at: \n{self.data_path}"
49 |
50 | def get_data_holder(self):
51 | """Returns the data holder for the required dataset
52 |
53 |         Returns
54 |         -------
55 |         data_holder : CIFAR10_DataHolder, ShanghaiTech_DataHolder, or MVTec_DataHolder
56 |             Holder class for the requested dataset
57 |
58 | """
59 | if self.dataset_name == 'cifar10':
60 | return CIFAR10_DataHolder(root=self.data_path, normal_class=self.normal_class)
61 |
62 | if self.dataset_name == 'ShanghaiTech':
63 | return ShanghaiTech_DataHolder(root=self.data_path,clip_length=self.clip_length)
64 |
65 | if self.dataset_name == 'MVTec_Anomaly':
66 | texture_classes = tuple(["carpet", "grid", "leather", "tile", "wood"])
67 | object_classes = tuple(["bottle", "hazelnut", "metal_nut", "screw"])
68 | # object_classes2 = tuple(["capsule", "toothbrush", "cable", "pill", "transistor", "zipper"])
69 |
70 | # check if the selected class is texture-type
71 | is_texture = self.normal_class in texture_classes
72 | if is_texture:
73 | image_size = 512
74 | patch_size = 64
75 | rotation_range = (0, 45)
76 | else:
77 | patch_size = 1
78 | image_size = 128
79 |                 # For some object-type classes the anomalies are rotations themselves,
80 |                 # so we must not apply rotation as data augmentation for those classes
81 |                 rotation_range = (-45, 45) if self.normal_class in object_classes else (0, 0)
82 |
83 | return MVTec_DataHolder(
84 | data_path=self.data_path,
85 | category=self.normal_class,
86 | image_size=image_size,
87 | patch_size=patch_size,
88 | rotation_range=rotation_range,
89 | is_texture=is_texture
90 | )
91 |
--------------------------------------------------------------------------------
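A minimal usage sketch of the flow above (the dataset path, normal class, and batch size are placeholder values):

```
from datasets.data_manager import DataManager

# Placeholder path and class; DataManager checks that the path exists at init time
data_holder = DataManager(
    dataset_name='cifar10',
    data_path='./data/cifar10',
    normal_class=5
).get_data_holder()

# Every holder exposes get_loaders() with the same signature
train_loader, test_loader = data_holder.get_loaders(batch_size=128, num_workers=4)
```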
/test_multiple_models_cifar10.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import random
4 | import logging
5 | import argparse
6 | import numpy as np
7 | from tqdm import tqdm
8 |
9 | import torch
10 |
11 | from trainer_svdd import test
12 | from datasets.main import load_dataset
13 | from models.deep_svdd.deep_svdd_mnist import MNIST_LeNet_Autoencoder, MNIST_LeNet
14 | from models.deep_svdd.deep_svdd_cifar10 import CIFAR10_LeNet_Autoencoder, CIFAR10_LeNet
15 |
16 |
17 | parser = argparse.ArgumentParser('AD')
18 | ## General config
19 | parser.add_argument('--n_jobs_dataloader', type=int, default=0, help='Number of workers for data loading. 0 means that the data will be loaded in the main process.')
20 | ## Model config
21 | parser.add_argument('-zl', '--code-length', default=32, type=int, help='Code length (default: 32)')
22 | parser.add_argument('-ck', '--model-ckp', help='Model checkpoint')
23 | ## Data
24 | parser.add_argument('-ax', '--aux-data-filename', default='/media/fabiovalerio/datasets/ti_500K_pseudo_labeled.pickle', help='Path to unlabelled data')
25 | parser.add_argument('-dn', '--dataset-name', choices=('mnist', 'cifar10'), default='mnist', help='Dataset (default: mnist)')
26 | parser.add_argument('-ul', '--unlabelled-data', action="store_true", help='Use unlabelled data (default: False)')
27 | parser.add_argument('-aux', '--unl-data-path', default="/media/fabiovalerio/datasets/ti_500K_pseudo_labeled.pickle", help='Path to unlabelled data')
28 | ## Training config
29 | parser.add_argument('-bs', '--batch-size', type=int, default=200, help='Batch size (default: 200)')
30 | parser.add_argument('-bd', '--boundary', choices=("hard", "soft"), default="soft", help='Boundary (default: soft)')
31 | parser.add_argument('-ile', '--idx-list-enc', type=int, nargs='+', default=[], help='List of indices of model encoder')
32 | args = parser.parse_args()
33 |
34 |
35 | # Get data base path
36 | _user = os.environ['USER']
37 | if _user == 'fabiovalerio':
38 | data_path = '/media/fabiovalerio/datasets'
39 | elif _user == 'fabiom':
40 | data_path = '/mnt/datone/datasets/'
41 | else:
42 | raise NotImplementedError('Username %s not configured' % _user)
43 |
44 |
45 | def main():
46 | cuda_available = torch.cuda.is_available()
47 | device = torch.device('cuda' if cuda_available else 'cpu')
48 |
49 | boundary = args.model_ckp.split('/')[-1].split('-')[-3].split('_')[-1]
50 | normal_class = int(args.model_ckp.split('/')[-1].split('-')[2].split('_')[-1])
51 | if len(args.idx_list_enc) == 0:
52 | idx_list_enc = [int(i) for i in args.model_ckp.split('/')[-1].split('-')[-1].split('_')[-1].split('.')]
53 | else:
54 | idx_list_enc = args.idx_list_enc
55 |
56 | # LOAD DATA
57 | dataset = load_dataset(args.dataset_name, data_path, normal_class, args.unlabelled_data, args.unl_data_path)
58 |
59 | print(
60 | f"Start test with params"
61 | f"\n\t\t\t\tCode length : {args.code_length}"
62 | f"\n\t\t\t\tEnc layer list : {idx_list_enc}"
63 | f"\n\t\t\t\tBoundary : {boundary}"
64 | f"\n\t\t\t\tNormal class : {normal_class}"
65 | )
66 |
67 | test_auc = []
68 | main_model_ckp_dir = args.model_ckp
69 | for m_ckp in tqdm(os.listdir(main_model_ckp_dir), total=len(os.listdir(main_model_ckp_dir)), leave=False):
70 |         net_checkpoint = os.path.join(main_model_ckp_dir, m_ckp)
71 |
72 | # Load model
73 | net = MNIST_LeNet(args.code_length) if args.dataset_name == 'mnist' else CIFAR10_LeNet(args.code_length)
74 |         st_dict = torch.load(net_checkpoint)
75 | net.load_state_dict(st_dict['net_state_dict'])
76 |
77 | # TEST
78 | test_auc_ = test(net, dataset, st_dict['R'], st_dict['c'], device, idx_list_enc, boundary, args)
79 | del net, st_dict
80 |
81 | test_auc.append(test_auc_)
82 |
83 | test_auc = np.asarray(test_auc)
84 | test_auc_m, test_auc_s = test_auc.mean(), test_auc.std()
85 | print("[")
86 | for tau in test_auc:
87 | print(tau, ", ")
88 | print("]")
89 | print(test_auc)
90 |     print(fr"{test_auc_m:.2f} $\pm$ {test_auc_s:.2f}")
91 |
92 |
93 | if __name__ == '__main__':
94 | main()
95 |
--------------------------------------------------------------------------------
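The filename parsing in main() above assumes a dash-separated checkpoint naming scheme that is not documented in the repository; the sketch below uses a hypothetical name that satisfies the split indices:

```
# Hypothetical checkpoint directory name: only the split positions matter.
# Field [2] must end with the normal class, field [-3] with the boundary,
# and field [-1] with the dot-separated encoder layer indices.
ckp_dir = 'mocca-zl_128-class_5-bd_soft-seed_0-ile_6.7.8'

fields = ckp_dir.split('-')
boundary = fields[-3].split('_')[-1]                                   # 'soft'
normal_class = int(fields[2].split('_')[-1])                           # 5
idx_list_enc = [int(i) for i in fields[-1].split('_')[-1].split('.')]  # [6, 7, 8]
print(boundary, normal_class, idx_list_enc)
```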
/datasets/base.py:
--------------------------------------------------------------------------------
1 | from abc import ABCMeta
2 | from abc import abstractmethod
3 |
4 | import torch
5 | import numpy as np
6 | from torch.utils.data import Dataset
7 | from scipy.ndimage.morphology import binary_dilation  # used by RemoveBackground below
8 |
9 | class DatasetBase(Dataset):
10 | """
11 | Base class for all datasets.
12 | """
13 | __metaclass__ = ABCMeta
14 |
15 | @abstractmethod
16 | def test(self, *args):
17 | """
18 | Sets the dataset in test mode.
19 | """
20 | pass
21 |
22 | @property
23 | @abstractmethod
24 | def shape(self):
25 | """
26 | Returns the shape of examples.
27 | """
28 | pass
29 |
30 | @abstractmethod
31 | def __len__(self):
32 | """
33 | Returns the number of examples.
34 | """
35 | pass
36 |
37 | @abstractmethod
38 | def __getitem__(self, i):
39 | """
40 | Provides the i-th example.
41 | """
42 | pass
43 |
44 |
45 | class OneClassDataset(DatasetBase):
46 | """
47 | Base class for all one-class classification datasets.
48 | """
49 | __metaclass__ = ABCMeta
50 |
51 | @abstractmethod
52 | def val(self, *args):
53 | """
54 | Sets the dataset in validation mode.
55 | """
56 | pass
57 |
58 | @property
59 | @abstractmethod
60 | def test_classes(self):
61 | """
62 |         Returns all possible test classes.
63 | """
64 | pass
65 |
66 |
67 | class VideoAnomalyDetectionDataset(DatasetBase):
68 | """
69 | Base class for all video anomaly detection datasets.
70 | """
71 | __metaclass__ = ABCMeta
72 |
73 | @property
74 | @abstractmethod
75 | def test_videos(self):
76 | """
77 | Returns all test video ids.
78 | """
79 | pass
80 |
81 |
82 | @abstractmethod
83 | def __len__(self):
84 | """
85 | Returns the number of examples.
86 | """
87 | pass
88 |
89 | @property
90 | def raw_shape(self):
91 | """
92 | Workaround!
93 | """
94 | return self.shape
95 |
96 | @abstractmethod
97 | def __getitem__(self, i):
98 | """
99 | Provides the i-th example.
100 | """
101 | pass
102 |
103 | @abstractmethod
104 | def load_test_sequence_gt(self, video_id):
105 | # type: (str) -> np.ndarray
106 | """
107 | Loads the groundtruth of a test video in memory.
108 | :param video_id: the id of the test video for which the groundtruth has to be loaded.
109 | :return: the groundtruth of the video in a np.ndarray, with shape (n_frames,).
110 | """
111 | pass
112 |
113 | @property
114 | @abstractmethod
115 | def collate_fn(self):
116 | """
117 | Returns a function that decides how to merge a list of examples in a batch.
118 | """
119 | pass
120 |
121 |
122 | class ToFloatTensor3D(object):
123 | """ Convert videos to FloatTensors """
124 | def __init__(self, normalize=True):
125 | self._normalize = normalize
126 |
127 | def __call__(self, sample):
128 |         if isinstance(sample, tuple) and len(sample) == 3:  # (X, Y, background) tuple
129 | X, Y, _ = sample
130 | else:
131 | X = sample
132 |
133 | # swap color axis because
134 | # numpy image: T x H x W x C
135 | X = X.transpose(3, 0, 1, 2)
136 | #Y = Y.transpose(3, 0, 1, 2)
137 |
138 | if self._normalize:
139 | X = X / 255.
140 |
141 | X = np.float32(X)
142 | return torch.from_numpy(X)
143 |
144 | class ToFloatTensor3DMask(object):
145 | """ Convert videos to FloatTensors """
146 | def __init__(self, normalize=True, has_x_mask=True, has_y_mask=True):
147 | self._normalize = normalize
148 | self.has_x_mask = has_x_mask
149 | self.has_y_mask = has_y_mask
150 |
151 | def __call__(self, sample):
152 | X = sample
153 | # swap color axis because
154 | # numpy image: T x H x W x C
155 | X = X.transpose(3, 0, 1, 2)
156 |
157 | X = np.float32(X)
158 |
159 | if self._normalize:
160 | if self.has_x_mask:
161 | X[:-1] = X[:-1] / 255.
162 | else:
163 | X = X / 255.
164 |
165 | return torch.from_numpy(X)
166 |
167 |
168 | class RemoveBackground:
169 |
170 | def __init__(self, threshold: float):
171 | self.threshold = threshold
172 |
173 | def __call__(self, sample: tuple) -> tuple:
174 | X, Y, background = sample
175 |
176 | mask = np.uint8(np.sum(np.abs(np.int32(X) - background), axis=-1) > self.threshold)
177 | mask = np.expand_dims(mask, axis=-1)
178 |
179 | mask = np.stack([binary_dilation(mask_frame, iterations=5) for mask_frame in mask])
180 |
181 | X *= mask
182 |
183 | return X, Y, background
--------------------------------------------------------------------------------
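A quick sketch of what ToFloatTensor3D does to a dummy clip, following the shapes noted in its __call__ (T x H x W x C in, C x T x H x W out):

```
import numpy as np
from datasets.base import ToFloatTensor3D

# Fake uint8 clip: 16 frames of 256x512 RGB
clip = np.random.randint(0, 256, size=(16, 256, 512, 3), dtype=np.uint8)
tensor = ToFloatTensor3D(normalize=True)(clip)

print(tensor.shape)         # torch.Size([3, 16, 256, 512])
print(float(tensor.max()))  # <= 1.0 after the division by 255
```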
/models/cifar10_model.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import numpy as np
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 |
9 | def init_conv(out_channels: int, k_size: int = 5) -> nn.Module:
10 | """ Init convolutional layers.
11 |
12 | Parameters
13 | ----------
14 | k_size : int
15 | Kernel size
16 | out_channels : int
17 | Output features size
18 |
19 | Returns
20 | -------
21 | nn.Module :
22 | Conv2d layer
23 |
24 | """
25 | l = nn.Conv2d(
26 |         in_channels=3 if out_channels==32 else out_channels//2,  # 3 RGB channels for the first conv, half the feature depth otherwise
27 | out_channels=out_channels,
28 | kernel_size=k_size,
29 | bias=False,
30 | padding=2
31 | )
32 | nn.init.xavier_uniform_(l.weight, gain=nn.init.calculate_gain('leaky_relu'))
33 | return l
34 |
35 |
36 | def init_deconv(out_channels: int, k_size: int = 5) -> nn.Module:
37 | """ Init deconv layers.
38 |
39 | Parameters
40 | ----------
41 | k_size : int
42 | Kernel size
43 |     out_channels : int
44 |         Number of input channels (the layer outputs out_channels//2, or 3 for the last deconv)
45 |
46 | Returns
47 | -------
48 | nn.Module :
49 | ConvTranspose2d layer
50 |
51 | """
52 | l = nn.ConvTranspose2d(
53 | in_channels=out_channels,
54 | out_channels=3 if out_channels==32 else out_channels//2,
55 | kernel_size=k_size,
56 | bias=False,
57 | padding=2
58 | )
59 | nn.init.xavier_uniform_(l.weight, gain=nn.init.calculate_gain('leaky_relu'))
60 | return l
61 |
62 | def init_bn(num_features: int) -> nn.Module:
63 | """ Init BatchNorm layers.
64 |
65 | Parameters
66 | ----------
67 | num_features : int
68 | Number of input features
69 |
70 | """
71 | return nn.BatchNorm2d(num_features=num_features, eps=1e-04, affine=False)
72 |
73 |
74 | class BaseNet(nn.Module):
75 | """Base class for all neural networks.
76 |
77 | """
78 | def __init__(self):
79 | super(BaseNet, self).__init__()
80 |
81 | # init Logger to print model infos
82 | self.logger = logging.getLogger(self.__class__.__name__)
83 |
84 | # List of input/output features depths for the convolutional layers
85 | self.output_features_sizes = [32, 64, 128]
86 |
87 | def summary(self) -> None:
88 | """Network summary.
89 |
90 | """
91 | net_parameters = filter(lambda p: p.requires_grad, self.parameters())
92 | params = sum([np.prod(p.size()) for p in net_parameters])
93 | self.logger.info('Trainable parameters: {}'.format(params))
94 | self.logger.info(self)
95 |
96 |
97 | class CIFAR10_Encoder(BaseNet):
98 |     """Encoder network.
99 |
100 | """
101 | def __init__(self, code_length: int):
102 |         """Init encoder.
103 |
104 | Parameters
105 | ----------
106 | code_length : int
107 | Latent code size
108 |
109 | """
110 | super(CIFAR10_Encoder, self).__init__()
111 |
112 | # Init Conv layers
113 | self.conv1, self.conv2, self.conv3 = [init_conv(out_channels) for out_channels in self.output_features_sizes]
114 |
115 | # Init BN layers
116 | self.bnd1, self.bnd2, self.bnd3 = [init_bn(num_features) for num_features in self.output_features_sizes]
117 |
118 | # Init all other layers
119 | self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
120 | self.fc1 = nn.Linear(in_features=128 * 4 * 4, out_features=code_length, bias=False)
121 |
122 | def forward(self, x: torch.Tensor) -> torch.Tensor:
123 | x = self.conv1(x)
124 | x = self.pool(F.leaky_relu(self.bnd1(x)))
125 | x = self.conv2(x)
126 | x = self.pool(F.leaky_relu(self.bnd2(x)))
127 | x = self.conv3(x)
128 | x = self.pool(F.leaky_relu(self.bnd3(x)))
129 | x = self.fc1(x.view(x.size(0), -1))
130 | return x
131 |
132 |
133 | class CIFAR10_Decoder(BaseNet):
134 | """Full Decoder network.
135 |
136 | """
137 | def __init__(self, code_length: int):
138 | """Init decoder.
139 |
140 | Parameters
141 | ----------
142 | code_length : int
143 | Latent code size
144 |
145 | """
146 | super(CIFAR10_Decoder, self).__init__()
147 |
148 | self.rep_dim = code_length
149 |
150 | self.bn1d = nn.BatchNorm1d(self.rep_dim, eps=1e-04, affine=False)
151 |
152 | # Build the Decoder
153 | self.deconv1 = nn.ConvTranspose2d(int(self.rep_dim / (4 * 4)), 128, 5, bias=False, padding=2)
154 | self.deconv2, self.deconv3, self.deconv4 = [init_deconv(out_channels) for out_channels in self.output_features_sizes[::-1]]
155 |
156 | # Init BN layers
157 | self.bnd4, self.bnd5, self.bnd6 = [init_bn(num_features) for num_features in self.output_features_sizes[::-1]]
158 |
159 | def forward(self, x: torch.Tensor) -> torch.Tensor:
160 | x = self.bn1d(x)
161 | x = x.view(x.size(0), int(self.rep_dim / (4 * 4)), 4, 4)
162 | x = F.leaky_relu(x)
163 | x = self.deconv1(x)
164 | x = F.interpolate(F.leaky_relu(self.bnd4(x)), scale_factor=2)
165 | x = self.deconv2(x)
166 | x = F.interpolate(F.leaky_relu(self.bnd5(x)), scale_factor=2)
167 | x = self.deconv3(x)
168 | x = F.interpolate(F.leaky_relu(self.bnd6(x)), scale_factor=2)
169 | x = self.deconv4(x)
170 | x = torch.sigmoid(x)
171 | return x
172 |
173 |
174 | class CIFAR10_Autoencoder(BaseNet):
175 | """Full AutoEncoder network.
176 |
177 | """
178 | def __init__(self, code_length: int = 128):
179 | """Init the AutoEncoder
180 |
181 | Parameters
182 | ----------
183 | code_length : int
184 | Latent code size
185 |
186 | """
187 | super().__init__()
188 |
189 | # Build the Encoder
190 | self.encoder = CIFAR10_Encoder(code_length=code_length)
191 | self.bn1d = nn.BatchNorm1d(num_features=code_length, eps=1e-04, affine=False)
192 |
193 | # Build the Decoder
194 | self.decoder = CIFAR10_Decoder(code_length=code_length)
195 |
196 | def forward(self, x: torch.Tensor) -> torch.Tensor:
197 | z = self.encoder(x)
198 | z = self.bn1d(z)
199 | return self.decoder(z)
200 |
--------------------------------------------------------------------------------
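A shape sanity check for the autoencoder above (a sketch, not a test shipped with the repository): CIFAR-10 inputs are 3x32x32 and the decoder reconstructs the same resolution.

```
import torch
from models.cifar10_model import CIFAR10_Autoencoder

ae = CIFAR10_Autoencoder(code_length=128).eval()  # eval(): BatchNorm uses running stats
with torch.no_grad():
    x = torch.rand(4, 3, 32, 32)  # fake batch of 4 images in [0, 1]
    x_r = ae(x)
print(x_r.shape)  # torch.Size([4, 3, 32, 32]), values in (0, 1) from the final sigmoid
```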
/datasets/cifar10.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 |
4 | import torch
5 | from torch.utils.data import DataLoader, Subset
6 |
7 | from torchvision.datasets import CIFAR10
8 | import torchvision.transforms as transforms
9 |
10 |
11 | def get_target_label_idx(labels: np.ndarray, targets: np.ndarray):
12 | """Get the indices of labels that are included in targets.
13 |
14 | Parameters
15 | ----------
16 | labels : np.array
17 | Array of labels
18 | targets : np.array
19 | Array of target labels
20 |
21 | Returns
22 | ------
23 | List with indices of target labels
24 |
25 | """
26 | return np.argwhere(np.isin(labels, targets)).flatten().tolist()
27 |
28 |
29 | def global_contrast_normalization(x: torch.Tensor, scale: str='l1') -> torch.Tensor:
30 | """Apply global contrast normalization to tensor, i.e. subtract mean across features (pixels) and normalize by scale,
31 | which is either the standard deviation, L1- or L2-norm across features (pixels).
32 | Note this is a *per sample* normalization globally across features (and not across the dataset).
33 |
34 | Parameters
35 | ----------
36 | x : torch.tensor
37 | Data sample
38 |     scale : str
39 |         Norm used for scaling, either 'l1' or 'l2'
40 |
41 | Returns
42 | ------
43 | Normalized features
44 |
45 | """
46 | assert scale in ('l1', 'l2')
47 |
48 | n_features = int(np.prod(x.shape))
49 |
50 | # Evaluate the mean over all features (pixels) per sample
51 | mean = torch.mean(x)
52 | x -= mean
53 |
54 | x_scale = torch.mean(torch.abs(x)) if scale == 'l1' else torch.sqrt(torch.sum(x ** 2)) / n_features
55 |
56 | return x / x_scale
57 |
58 |
59 | class CIFAR10_DataHolder(object):
60 | """CIFAR10 data holder class
61 |
62 | """
63 | def __init__(self, root: str, normal_class=5):
64 | """Init CIFAR10 data holder class
65 |
66 | Parameters
67 | ----------
68 | root : str
69 | Path to root folder of the data
70 | normal_class :
71 | Index of the normal class
72 |
73 | """
74 | self.root = root
75 |
76 | # Total number of classes = 2, i.e., 0: normal, 1: anomalies
77 | self.n_classes = 2
78 |
79 | # Tuple containing the normal classes
80 | self.normal_classes = tuple([normal_class])
81 |
82 | # List of the anomalous classes
83 | self.anomaly_classes = list(range(0, 10))
84 | self.anomaly_classes.remove(normal_class)
85 |
86 | # Init the datasets
87 | self.__init_train_test_datasets(normal_class)
88 |
89 | def __init_train_test_datasets(self, normal_class: int) -> None:
90 | """Init the datasets.
91 |
92 | Parameters
93 | ----------
94 | normal_class : int
95 | The index of the non-anomalous class
96 |
97 | """
98 | # Pre-computed min and max values (after applying GCN) from train data per class
99 | min_max = [(-28.94083453598571, 13.802961825439636),
100 | (-6.681770233365245, 9.158067708230273),
101 | (-34.924463588638204, 14.419298165027628),
102 | (-10.599172931391799, 11.093187820377565),
103 | (-11.945022995801637, 10.628045447867583),
104 | (-9.691969487694928, 8.948326776180823),
105 | (-9.174940012342555, 13.847014686472365),
106 | (-6.876682005899029, 12.282371383343161),
107 | (-15.603507135507172, 15.2464923804279),
108 | (-6.132882973622672, 8.046098172351265)]
109 |
110 | # Define CIFAR-10 preprocessing operations
111 | # 1. GCN with L1 norm
112 | # 2. min-max feature scaling to [0,1]
113 | self.transform = transforms.Compose([
114 | transforms.ToTensor(),
115 | transforms.Lambda(lambda x: global_contrast_normalization(x, scale='l1')),
116 | transforms.Normalize(
117 | [min_max[normal_class][0]] * 3,
118 | [min_max[normal_class][1] - min_max[normal_class][0]] * 3
119 | )
120 | ])
121 |
122 | # Define CIFAR-10 preprocessing operations on the labels,
123 | # i.e., set to 0 all the labels that belong to the anomalous classes
124 | self.target_transform = transforms.Lambda(lambda x: int(x in self.anomaly_classes))
125 |
126 | # Init training set
127 | self.train_set = MyCIFAR10(
128 | root=self.root,
129 | train=True,
130 | download=True,
131 | transform=self.transform,
132 | target_transform=self.target_transform
133 | )
134 |
135 | # Subset the training set by considering normal class images only
136 | train_idx_normal = get_target_label_idx(labels=self.train_set.targets, targets=self.normal_classes)
137 | self.train_set = Subset(self.train_set, train_idx_normal)
138 |
139 | # Init test set
140 | self.test_set = MyCIFAR10(
141 | root=self.root,
142 | train=False,
143 | download=True,
144 | transform=self.transform,
145 | target_transform=self.target_transform
146 | )
147 |
148 | def get_loaders(self, batch_size: int, shuffle_train: bool=True, pin_memory: bool=False, num_workers: int = 0) -> [torch.utils.data.DataLoader, torch.utils.data.DataLoader]:
149 | """Returns CIFAR10 dataloaders
150 |
151 | Parameters
152 | ----------
153 | batch_size : int
154 |             Size of the batches
155 | shuffle_train : bool
156 | If True, shuffles the training dataset
157 | pin_memory : bool
158 |             If True, pin memory
159 | num_workers : int
160 | Number of dataloader workers
161 |
162 |         Returns
163 | -------
164 | loaders : DataLoader
165 | Train and test data loaders
166 |
167 | """
168 | train_loader = DataLoader(
169 | dataset=self.train_set,
170 | batch_size=batch_size,
171 | shuffle=shuffle_train,
172 | pin_memory=pin_memory,
173 | num_workers=num_workers
174 | )
175 | test_loader = DataLoader(
176 | dataset=self.test_set,
177 | batch_size=batch_size,
178 | pin_memory=pin_memory,
179 | num_workers=num_workers
180 | )
181 | return train_loader, test_loader
182 |
183 |
184 | class MyCIFAR10(CIFAR10):
185 | """Torchvision CIFAR10 class with patch of __getitem__ method to also return the index of a data sample.
186 |
187 | """
188 | def __init__(self, *args, **kwargs):
189 | super(MyCIFAR10, self).__init__(*args, **kwargs)
190 |
191 | def __getitem__(self, index):
192 | """Override the original method of the CIFAR10 class.
193 |
194 | Parameters
195 | ----------
196 | index : int
197 | Index
198 |
199 | Returns
200 | -------
201 | triple: (image, target, index) where target is the index of the target class.
202 |
203 | """
204 | img, target = self.data[index], self.targets[index]
205 |
206 | img = Image.fromarray(img)
207 |
208 | if self.transform is not None:
209 | img = self.transform(img)
210 |
211 | if self.target_transform is not None:
212 | target = self.target_transform(target)
213 |
214 | return img, target, index
215 |
--------------------------------------------------------------------------------
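A short sketch of the GCN semantics; note the function modifies its input in place (x -= mean), hence the clone:

```
import torch
from datasets.cifar10 import global_contrast_normalization

x = torch.rand(3, 32, 32) * 10.0
x_gcn = global_contrast_normalization(x.clone(), scale='l1')

print(x_gcn.mean().item())        # ~0: the per-sample mean is subtracted
print(x_gcn.abs().mean().item())  # ~1: features rescaled by the per-sample L1 scale
```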
/datasets/shanghaitech.py:
--------------------------------------------------------------------------------
1 | from glob import glob
2 | from tqdm import tqdm
3 | from time import time
4 | from typing import List, Tuple
5 | from os.path import basename, isdir, join, splitext
6 |
7 | import cv2
8 | import numpy as np
9 | import skimage.io as io
10 |
11 | import torch
12 | from torchvision import transforms
13 | from skimage.transform import resize
14 | from torch.utils.data import Dataset, DataLoader
15 | from torch.utils.data.dataloader import default_collate
16 | from .shanghaitech_test import ShanghaiTechTestHandler
17 |
18 | class ShanghaiTech_DataHolder(object):
19 | """
20 | ShanghaiTech data holder class
21 |
22 | Parameters
23 | ----------
24 | root : str
25 | root folder of ShanghaiTech dataset
26 | clip_length : int
27 | number of frames that form a clip
28 | stride : int
29 | for creating a clip what should be the size of sliding window
30 | """
31 | def __init__(self, root: str, clip_length=16, stride=1):
32 | self.root = root
33 | self.clip_length = clip_length
34 | self.stride = stride
35 | self.shape = (3, clip_length, 256, 512)
36 | self.train_dir = join(root, 'training', 'nobackground_frames_resized')
37 | # Transform
38 | self.transform = transforms.Compose([ToFloatTensor3D(normalize=True)])
39 |
40 |
41 | def get_test_data(self) -> Dataset:
42 | """Load test dataset
43 |
44 | Returns
45 | -------
46 | ShanghaiTech : Dataset
47 | Custom dataset to handle ShanghaiTech data
48 |
49 | """
50 | return ShanghaiTechTestHandler(self.root)
51 |
52 | def get_train_data(self, return_dataset: bool=True):
53 | """Load train dataset
54 |
55 | Parameters
56 | ----------
57 | return_dataset : bool
58 | False for preprocessing purpose only
59 | """
60 |
61 | if return_dataset:
62 | # Load all ids
63 | self.train_ids = self.load_train_ids()
64 | # Create clips with given clip_length and stride
65 | self.train_clips = self.create_clips(self.train_dir, self.train_ids, clip_length=self.clip_length, stride=self.stride, read_target=False)
66 |
67 | return MySHANGHAI(self.train_clips, self.transform, clip_length=self.clip_length)
68 | else:
69 | return
70 |
71 | def get_loaders(self, batch_size: int, shuffle_train: bool=True, pin_memory: bool=False, num_workers: int = 0) -> [DataLoader, DataLoader]:
72 |         """Returns ShanghaiTech dataloaders
73 |
74 | Parameters
75 | ----------
76 |         batch_size : int
77 |             Size of the batches
78 |         shuffle_train : bool
79 |             If True, shuffles the training dataset
80 |         pin_memory : bool
81 |             If True, pin memory
82 | num_workers : int
83 | Number of dataloader workers
84 |
85 | Returns
86 | -------
87 | loaders : DataLoader
88 | Train and test data loaders
89 |
90 | """
91 | train_loader = DataLoader(
92 | dataset=self.get_train_data(return_dataset=True),
93 | batch_size=batch_size,
94 | shuffle=shuffle_train,
95 | pin_memory=pin_memory,
96 | num_workers=num_workers
97 | )
98 | test_loader = DataLoader(
99 | dataset=self.get_test_data(),
100 | batch_size=batch_size,
101 | pin_memory=pin_memory,
102 | num_workers=num_workers
103 | )
104 | return train_loader, test_loader
105 |
106 | def load_train_ids(self):
107 | # type: () -> List[str]
108 | """
109 | Loads the set of all train video ids.
110 | :return: The list of train ids.
111 | """
112 | return sorted([basename(d) for d in glob(join(self.train_dir, '**')) if isdir(d)])
113 |
114 | def create_clips(self, dir_path, ids, clip_length=16, stride=1, read_target=False):
115 |         # type: (str, List[str], int, int, bool) -> np.ndarray
116 | """
117 | Gets frame directory and ids of the directories in the frame dir
118 | Creates clips which consist of number of clip_length at each clip.
119 | Clips are created in a sliding window fashion. Default window slide is 1
120 | but stride controls the window slide
121 | Example: for default parameters first clip is [001.jpg, 002.jpg, ...,016.jpg]
122 | second clip would be [002.jpg, 003.jpg, ..., 017.jpg]
123 |         read_target is currently unused here: train clips carry no ground
124 |         truth, while test targets are handled by ShanghaiTechTestHandler
125 |         :return: clips: numpy array of frame paths with shape
126 |                  (num_clips, clip_length)
127 | """
128 | clips = []
129 | print(f"Creating clips for {dir_path} dataset with length {clip_length}...")
130 | for idx in tqdm(ids):
131 | frames = sorted([x for x in glob(join(dir_path, idx, "*.jpg"))])
132 | num_frames = len(frames)
133 | # Slide the window with stride to collect clips
134 | for window in range(0, num_frames-clip_length+1, stride):
135 | clips.append(frames[window:window+clip_length])
136 | return np.array(clips)
137 |
138 | class MySHANGHAI(Dataset):
139 | def __init__(self, clips, transform=None, clip_length=16):
140 | self.clips = clips
141 | self.transform = transform
142 | self.shape = (3, clip_length, 256, 512)
143 |
144 | def __len__(self):
145 |         return 10000  # fixed epoch length: __getitem__ samples clips at random rather than iterating over all len(self.clips)
146 |
147 | def __getitem__(self, index):
148 | """
149 | Args:
150 | index (int): Index
151 |         Returns:
152 |             tuple: (sample, index) where sample is the transformed clip;
153 |             all training clips are normal, so no target is returned
154 |         """
155 | index_ = torch.randint(0, len(self.clips), size=(1,)).item()
156 | sample = np.stack([np.uint8(io.imread(img_path)) for img_path in self.clips[index_]])
157 | sample = self.transform(sample) if self.transform else sample
158 | return sample, index_
159 |
160 | from scipy.ndimage.morphology import binary_dilation
161 |
162 |
163 | def get_target_label_idx(labels, targets):
164 | """
165 | Get the indices of labels that are included in targets.
166 | :param labels: array of labels
167 | :param targets: list/tuple of target labels
168 | :return: list with indices of target labels
169 | """
170 | return np.argwhere(np.isin(labels, targets)).flatten().tolist()
171 |
172 |
173 | def global_contrast_normalization(x: torch.tensor, scale='l2'):
174 | """
175 | Apply global contrast normalization to tensor, i.e. subtract mean across features (pixels) and normalize by scale,
176 | which is either the standard deviation, L1- or L2-norm across features (pixels).
177 | Note this is a *per sample* normalization globally across features (and not across the dataset).
178 | """
179 |
180 | assert scale in ('l1', 'l2')
181 |
182 | n_features = int(np.prod(x.shape))
183 |
184 | mean = torch.mean(x) # mean over all features (pixels) per sample
185 | x -= mean
186 |
187 | if scale == 'l1':
188 | x_scale = torch.mean(torch.abs(x))
189 |
190 | if scale == 'l2':
191 | x_scale = torch.sqrt(torch.sum(x ** 2)) / n_features
192 |
193 | x /= x_scale
194 |
195 | return x
196 |
197 | class ToFloatTensor3D(object):
198 | """ Convert videos to FloatTensors """
199 | def __init__(self, normalize=True):
200 | self._normalize = normalize
201 |
202 | def __call__(self, sample):
203 |         if isinstance(sample, tuple) and len(sample) == 3:  # (X, Y, background) tuple
204 | X, Y, _ = sample
205 | else:
206 | X = sample
207 |
208 | # swap color axis because
209 | # numpy image: T x H x W x C
210 | X = X.transpose(3, 0, 1, 2)
211 | #Y = Y.transpose(3, 0, 1, 2)
212 |
213 | if self._normalize:
214 | X = X / 255.
215 |
216 | X = np.float32(X)
217 | return torch.from_numpy(X)
218 |
219 | class ToFloatTensor3DMask(object):
220 | """ Convert videos to FloatTensors """
221 | def __init__(self, normalize=True, has_x_mask=True, has_y_mask=True):
222 | self._normalize = normalize
223 | self.has_x_mask = has_x_mask
224 | self.has_y_mask = has_y_mask
225 |
226 | def __call__(self, sample):
227 | X = sample
228 | # swap color axis because
229 | # numpy image: T x H x W x C
230 | X = X.transpose(3, 0, 1, 2)
231 |
232 | X = np.float32(X)
233 |
234 | if self._normalize:
235 | if self.has_x_mask:
236 | X[:-1] = X[:-1] / 255.
237 | else:
238 | X = X / 255.
239 |
240 | return torch.from_numpy(X)
241 |
242 |
243 | class RemoveBackground:
244 |
245 | def __init__(self, threshold: float):
246 | self.threshold = threshold
247 |
248 | def __call__(self, sample: tuple) -> tuple:
249 | X, Y, background = sample
250 |
251 | mask = np.uint8(np.sum(np.abs(np.int32(X) - background), axis=-1) > self.threshold)
252 | mask = np.expand_dims(mask, axis=-1)
253 |
254 | mask = np.stack([binary_dilation(mask_frame, iterations=5) for mask_frame in mask])
255 |
256 | X *= mask
257 |
258 | return X, Y, background
--------------------------------------------------------------------------------
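A toy illustration of the sliding-window logic in create_clips: with 6 frames, clip_length=4, and stride=1, three overlapping clips are produced.

```
frames = ['001.jpg', '002.jpg', '003.jpg', '004.jpg', '005.jpg', '006.jpg']
clip_length, stride = 4, 1

clips = [frames[w:w + clip_length]
         for w in range(0, len(frames) - clip_length + 1, stride)]

for clip in clips:
    print(clip)
# ['001.jpg', '002.jpg', '003.jpg', '004.jpg']
# ['002.jpg', '003.jpg', '004.jpg', '005.jpg']
# ['003.jpg', '004.jpg', '005.jpg', '006.jpg']
```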
/models/mvtec_base_model.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | from operator import mul
3 | from typing import Optional
4 | from functools import reduce
5 |
6 | import torch
7 | import torch.nn as nn
8 | from torch.nn import Module
9 |
10 |
11 | class BaseModule(nn.Module):
12 | """
13 | Implements the basic module.
14 | All other modules inherit from this one
15 | """
16 | def load_w(self, checkpoint_path):
17 | # type: (str) -> None
18 | """
19 | Loads a checkpoint into the state_dict.
20 | :param checkpoint_path: the checkpoint file to be loaded.
21 | """
22 | self.load_state_dict(torch.load(checkpoint_path))
23 |
24 | def __repr__(self):
25 | # type: () -> str
26 | """
27 | String representation
28 | """
29 | good_old = super(BaseModule, self).__repr__()
30 | addition = 'Total number of parameters: {:,}'.format(self.n_parameters)
31 |
32 | return good_old + '\n' + addition
33 |
34 | def __call__(self, *args, **kwargs):
35 | return super(BaseModule, self).__call__(*args, **kwargs)
36 |
37 | @property
38 | def n_parameters(self):
39 | # type: () -> int
40 | """
41 | Number of parameters of the model.
42 | """
43 | n_parameters = 0
44 | for p in self.parameters():
45 | if hasattr(p, 'mask'):
46 | n_parameters += torch.sum(p.mask).item()
47 | else:
48 | n_parameters += reduce(mul, p.shape)
49 | return int(n_parameters)
50 |
51 |
52 | def residual_op(x, functions, bns, activation_fn):
53 |     # type: (torch.Tensor, List[Module], List[Module], Module) -> torch.Tensor
54 | """
55 | Implements a global residual operation.
56 | :param x: the input tensor.
57 | :param functions: a list of functions (nn.Modules).
58 | :param bns: a list of optional batch-norm layers.
59 | :param activation_fn: the activation to be applied.
60 | :return: the output of the residual operation.
61 | """
62 |     assert len(functions) == len(bns) == 3
63 |     f1, f2, f3 = functions
64 |     bn1, bn2, bn3 = bns
65 | 
66 |     assert f1 is not None and f2 is not None
67 |     assert not (f3 is None and bn3 is not None)
68 |
69 | # A-branch
70 | ha = x
71 | ha = f1(ha)
72 | if bn1 is not None:
73 | ha = bn1(ha)
74 | ha = activation_fn(ha)
75 |
76 | ha = f2(ha)
77 | if bn2 is not None:
78 | ha = bn2(ha)
79 |
80 | # B-branch
81 | hb = x
82 | if f3 is not None:
83 | hb = f3(hb)
84 | if bn3 is not None:
85 | hb = bn3(hb)
86 |
87 | # Residual connection
88 | out = ha + hb
89 | return activation_fn(out)
90 |
91 |
92 | class BaseBlock(BaseModule):
93 | """ Base class for all blocks. """
94 | def __init__(self, channel_in, channel_out, activation_fn, use_bn=True, use_bias=False):
95 | # type: (int, int, Module, bool, bool) -> None
96 | """
97 | Class constructor.
98 | :param channel_in: number of input channels.
99 | :param channel_out: number of output channels.
100 | :param activation_fn: activation to be employed.
101 | :param use_bn: whether or not to use batch-norm.
102 | :param use_bias: whether or not to use bias.
103 | """
104 | super(BaseBlock, self).__init__()
105 |
106 | assert not (use_bn and use_bias), 'Using bias=True with batch_normalization is forbidden.'
107 |
108 | self._channel_in = channel_in
109 | self._channel_out = channel_out
110 | self._activation_fn = activation_fn
111 | self._use_bn = use_bn
112 | self._bias = use_bias
113 |
114 | def get_bn(self):
115 | # type: () -> Optional[Module]
116 | """
117 | Returns batch norm layers, if needed.
118 | :return: batch norm layers or None
119 | """
120 | return nn.BatchNorm2d(num_features=self._channel_out) if self._use_bn else None
121 |
122 | def forward(self, x):
123 | """
124 | Abstract forward function. Not implemented.
125 | """
126 | raise NotImplementedError
127 |
128 |
129 | class DownsampleBlock(BaseBlock):
130 | """ Implements a Downsampling block for images (Fig. 1ii). """
131 | def __init__(self, channel_in, channel_out, activation_fn, use_bn=True, use_bias=False):
132 | # type: (int, int, Module, bool, bool) -> None
133 | """
134 | Class constructor.
135 | :param channel_in: number of input channels.
136 | :param channel_out: number of output channels.
137 | :param activation_fn: activation to be employed.
138 | :param use_bn: whether or not to use batch-norm.
139 | :param use_bias: whether or not to use bias.
140 | """
141 | super(DownsampleBlock, self).__init__(channel_in, channel_out, activation_fn, use_bn, use_bias)
142 |
143 | # Convolutions
144 | self.conv1a = nn.Conv2d(in_channels=channel_in, out_channels=channel_out, kernel_size=3,
145 | padding=1, stride=2, bias=use_bias)
146 | self.conv1b = nn.Conv2d(in_channels=channel_out, out_channels=channel_out, kernel_size=3,
147 | padding=1, stride=1, bias=use_bias)
148 | self.conv2a = nn.Conv2d(in_channels=channel_in, out_channels=channel_out, kernel_size=1,
149 | padding=0, stride=2, bias=use_bias)
150 |
151 | # Batch Normalization layers
152 | self.bn1a = self.get_bn()
153 | self.bn1b = self.get_bn()
154 | self.bn2a = self.get_bn()
155 |
156 | def forward(self, x):
157 | # type: (torch.Tensor) -> torch.Tensor
158 | """
159 | Forward propagation.
160 | :param x: the input tensor
161 | :return: the output tensor
162 | """
163 | return residual_op(
164 | x,
165 | functions=[self.conv1a, self.conv1b, self.conv2a],
166 | bns=[self.bn1a, self.bn1b, self.bn2a],
167 | activation_fn=self._activation_fn
168 | )
169 |
170 |
171 | class UpsampleBlock(BaseBlock):
172 |     """ Implements an Upsampling block for images (Fig. 1ii). """
173 | def __init__(self, channel_in, channel_out, activation_fn, use_bn=True, use_bias=False):
174 | # type: (int, int, Module, bool, bool) -> None
175 | """
176 | Class constructor.
177 | :param channel_in: number of input channels.
178 | :param channel_out: number of output channels.
179 | :param activation_fn: activation to be employed.
180 | :param use_bn: whether or not to use batch-norm.
181 | :param use_bias: whether or not to use bias.
182 | """
183 | super(UpsampleBlock, self).__init__(channel_in, channel_out, activation_fn, use_bn, use_bias)
184 |
185 | # Convolutions
186 | self.conv1a = nn.ConvTranspose2d(channel_in, channel_out, kernel_size=5,
187 | padding=2, stride=2, output_padding=1, bias=use_bias)
188 | self.conv1b = nn.Conv2d(in_channels=channel_out, out_channels=channel_out, kernel_size=3,
189 | padding=1, stride=1, bias=use_bias)
190 | self.conv2a = nn.ConvTranspose2d(channel_in, channel_out, kernel_size=1,
191 | padding=0, stride=2, output_padding=1, bias=use_bias)
192 |
193 | # Batch Normalization layers
194 | self.bn1a = self.get_bn()
195 | self.bn1b = self.get_bn()
196 | self.bn2a = self.get_bn()
197 |
198 | def forward(self, x):
199 | # type: (torch.Tensor) -> torch.Tensor
200 | """
201 | Forward propagation.
202 | :param x: the input tensor
203 | :return: the output tensor
204 | """
205 | return residual_op(
206 | x,
207 | functions=[self.conv1a, self.conv1b, self.conv2a],
208 | bns=[self.bn1a, self.bn1b, self.bn2a],
209 | activation_fn=self._activation_fn
210 | )
211 |
212 |
213 | class ResidualBlock(BaseBlock):
214 | """ Implements a Residual block for images (Fig. 1ii). """
215 | def __init__(self, channel_in, channel_out, activation_fn, use_bn=True, use_bias=False):
216 | # type: (int, int, Module, bool, bool) -> None
217 | """
218 | Class constructor.
219 | :param channel_in: number of input channels.
220 | :param channel_out: number of output channels.
221 | :param activation_fn: activation to be employed.
222 | :param use_bn: whether or not to use batch-norm.
223 | :param use_bias: whether or not to use bias.
224 | """
225 | super(ResidualBlock, self).__init__(channel_in, channel_out, activation_fn, use_bn, use_bias)
226 |
227 | # Convolutions
228 | self.conv1 = nn.Conv2d(in_channels=channel_in, out_channels=channel_out, kernel_size=3,
229 | padding=1, stride=1, bias=use_bias)
230 | self.conv2 = nn.Conv2d(in_channels=channel_out, out_channels=channel_out, kernel_size=3,
231 | padding=1, stride=1, bias=use_bias)
232 |
233 | # Batch Normalization layers
234 | self.bn1 = self.get_bn()
235 | self.bn2 = self.get_bn()
236 |
237 | def forward(self, x):
238 | # type: (torch.Tensor) -> torch.Tensor
239 | """
240 | Forward propagation.
241 | :param x: the input tensor
242 | :return: the output tensor
243 | """
244 | return residual_op(
245 | x,
246 | functions=[self.conv1, self.conv2, None],
247 | bns=[self.bn1, self.bn2, None],
248 | activation_fn=self._activation_fn
249 | )
--------------------------------------------------------------------------------
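A shape sketch for the three blocks above (channel sizes are placeholders): DownsampleBlock halves the spatial resolution, ResidualBlock preserves it, and UpsampleBlock doubles it.

```
import torch
import torch.nn as nn
from models.mvtec_base_model import DownsampleBlock, UpsampleBlock, ResidualBlock

act = nn.LeakyReLU()
x = torch.rand(2, 3, 64, 64)

h = DownsampleBlock(channel_in=3, channel_out=32, activation_fn=act)(x)
print(h.shape)  # torch.Size([2, 32, 32, 32])
h = ResidualBlock(channel_in=32, channel_out=32, activation_fn=act)(h)
print(h.shape)  # torch.Size([2, 32, 32, 32])
h = UpsampleBlock(channel_in=32, channel_out=16, activation_fn=act)(h)
print(h.shape)  # torch.Size([2, 16, 64, 64])
```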
/trainers/trainer_shanghaitech.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import logging
4 | import itertools
5 | import numpy as np
6 | from tqdm import tqdm
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | from torch.nn import DataParallel
12 | from torch.optim import Adam, SGD
13 | from torch.optim.lr_scheduler import MultiStepLR
14 | from torch.utils.data.dataloader import DataLoader
15 |
16 | from sklearn.metrics import roc_auc_score
17 |
18 |
19 | def pretrain(ae_net, train_loader, out_dir, tb_writer, device, args):
20 | logger = logging.getLogger()
21 |
22 | ae_net = ae_net.train().to(device)
23 |
24 |
25 | # Set optimizer
26 |     if args.optimizer == 'adam':
27 |         optimizer = Adam(ae_net.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
28 |     else:
29 |         optimizer = SGD(ae_net.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay, momentum=0.9)
30 |
31 | scheduler = MultiStepLR(optimizer, milestones=args.ae_lr_milestones, gamma=0.1)
32 |
33 | ae_epochs = 1 if args.debug else args.ae_epochs
34 | it_t = 0
35 | logger.info("Start Pretraining the autoencoder...")
36 | for epoch in range(ae_epochs):
37 |
38 | recon_loss = 0.0
39 | n_batches = 0
40 | for idx, (data, _) in enumerate(tqdm(train_loader), 1):
41 | if args.debug and idx == 2: break
42 |
43 | data = data.to(device)
44 | optimizer.zero_grad()
45 | x_r = ae_net(data)[0]
46 | recon_loss_ = torch.mean(torch.sum((x_r - data) ** 2, dim=tuple(range(1, x_r.dim()))))
47 | recon_loss_.backward()
48 | optimizer.step()
49 |
50 | recon_loss += recon_loss_.item()
51 | n_batches += 1
52 |
53 | if idx % (len(train_loader)//args.log_frequency) == 0:
54 | logger.info(f"PreTrain at epoch: {epoch+1} [{idx}]/[{len(train_loader)}] ==> Recon Loss: {recon_loss/idx:.4f}")
55 | tb_writer.add_scalar('pretrain/recon_loss', recon_loss/idx, it_t)
56 | it_t += 1
57 |
58 | scheduler.step()
59 | if epoch in args.ae_lr_milestones:
60 | logger.info(' LR scheduler: new learning rate is %g' % float(scheduler.get_lr()[0]))
61 |
62 | ae_net_checkpoint = os.path.join(out_dir, f'ae_ckp_epoch_{epoch}_{time.time()}.pth')
63 | torch.save({'ae_state_dict': ae_net.state_dict()}, ae_net_checkpoint)
64 |
65 | logger.info('Finished pretraining.')
66 | logger.info(f'Saved autoencoder at: {ae_net_checkpoint}')
67 |
68 | return ae_net_checkpoint
69 |
70 |
71 | def train(net, train_loader, out_dir, tb_writer, device, ae_net_checkpoint, args):
72 | logger = logging.getLogger()
73 |
74 | idx_list_enc = {int(i): 1 for i in args.idx_list_enc}
75 |
76 | # Set device for network
77 | net = net.to(device)
78 |
79 | # Set optimizer
80 | if args.optimizer == 'adam':
81 | optimizer = Adam(net.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
82 | else:
83 | optimizer = SGD(net.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay, momentum=0.9)
84 |
85 | # Set learning rate scheduler
86 | scheduler = MultiStepLR(optimizer, milestones=args.lr_milestones, gamma=0.1)
87 |
88 | # Initialize hypersphere center c
89 | logger.info('Evaluating hypersphere centers...')
90 | c, keys = init_center_c(train_loader, net, idx_list_enc, device, args.end_to_end_training, args.debug)
91 | logger.info(f'Keys: {keys}')
92 | logger.info('Done!')
93 |
94 | R = {k: torch.tensor(0.0, device=device) for k in keys}
95 |
96 | # Training
97 | logger.info('Starting training...')
98 | warm_up_n_epochs = args.warm_up_n_epochs
99 | net.train()
100 | it_t = 0
101 |
102 | best_loss = 1e12
103 | epochs = 1 if args.debug else args.epochs
104 | for epoch in range(epochs):
105 | one_class_loss = 0.0
106 | recon_loss = 0.0
107 | objective_loss = 0.0
108 | n_batches = 0
109 | d_from_c = {k: 0 for k in keys}
110 | epoch_start_time = time.time()
111 |
112 | for idx, (data, _) in enumerate(tqdm(train_loader, total=len(train_loader), desc=f"Training epoch: {epoch+1}"), 1):
113 | if args.debug and idx == 2: break
114 |
115 | n_batches += 1
116 | data = data.to(device)
117 |
118 | # Zero the network parameter gradients
119 | optimizer.zero_grad()
120 |
121 | # Update network parameters via backpropagation: forward + backward + optimize
122 | if args.end_to_end_training:
123 | x_r, _, d_lstms = net(data)
124 | recon_loss_ = torch.mean(torch.sum((x_r - data) ** 2, dim=tuple(range(1, x_r.dim()))))
125 | else:
126 | _, d_lstms = net(data)
127 | recon_loss_ = torch.tensor([0.0], device=device)
128 |
129 | dist, one_class_loss_ = eval_ad_loss(d_lstms, c, R, args.nu, args.boundary)
130 | objective_loss_ = one_class_loss_ + recon_loss_
131 |
132 | for k in keys:
133 | d_from_c[k] += torch.mean(dist[k]).item()
134 |
135 | objective_loss_.backward()
136 | optimizer.step()
137 |
138 | one_class_loss += one_class_loss_.item()
139 | recon_loss += recon_loss_.item()
140 | objective_loss += objective_loss_.item()
141 |
142 | if idx % (len(train_loader)//args.log_frequency) == 0:
143 | logger.info(
144 | f"TRAIN at epoch: {epoch} [{idx}]/[{len(train_loader)}] ==> "
145 | f"\n\t\t\t\tReconstr Loss : {recon_loss/n_batches:.4f}"
146 | f"\n\t\t\t\tOne class Loss: {one_class_loss/n_batches:.4f}"
147 | f"\n\t\t\t\tObjective Loss: {objective_loss/n_batches:.4f}"
148 | )
149 | tb_writer.add_scalar('train/recon_loss', recon_loss/n_batches, it_t)
150 | tb_writer.add_scalar('train/one_class_loss', one_class_loss/n_batches, it_t)
151 | tb_writer.add_scalar('train/objective_loss', objective_loss/n_batches, it_t)
152 | for k in keys:
153 | logger.info(
154 | f"[{k}] -- Radius: {R[k]:.4f} - "
155 | f"Dist from sphere centr: {d_from_c[k]/n_batches:.4f}"
156 | )
157 | tb_writer.add_scalar(f'train/radius_{k}', R[k], it_t)
158 | tb_writer.add_scalar(f'train/distance_c_sphere_{k}', d_from_c[k]/n_batches, it_t)
159 | it_t += 1
160 |
161 | # Update hypersphere radius R on mini-batch distances
162 | if (args.boundary == 'soft') and (epoch >= warm_up_n_epochs):
163 | for k in R.keys():
164 | R[k].data = torch.tensor(
165 | np.quantile(np.sqrt(dist[k].clone().data.cpu().numpy()), 1 - args.nu),
166 | device=device
167 | )
168 |
169 | scheduler.step()
170 | if epoch in args.lr_milestones:
171 | logger.info(' LR scheduler: new learning rate is %g' % float(scheduler.get_lr()[0]))
172 |
173 | time_ = time.time() if ae_net_checkpoint is None else ae_net_checkpoint.split('_')[-1].split('.p')[0]
174 | net_checkpoint = os.path.join(out_dir, f'net_ckp_{epoch}_{time_}.pth')
175 | torch.save({
176 | 'net_state_dict': net.state_dict(),
177 | 'R': R,
178 | 'c': c
179 | },
180 | net_checkpoint
181 | )
182 | logger.info(f'Saved model at: {net_checkpoint}')
183 | if objective_loss < best_loss or epoch==0:
184 | best_loss = objective_loss
185 | best_model_checkpoint = os.path.join(out_dir, f'net_ckp_best_model_{time_}.pth')
186 | torch.save({
187 | 'net_state_dict': net.state_dict(),
188 | 'R': R,
189 | 'c': c
190 | },
191 | best_model_checkpoint
192 | )
193 |
194 | logger.info('Finished training.')
195 |
196 | return best_model_checkpoint #net_checkpoint
197 |
198 |
199 | @torch.no_grad()
200 | def init_center_c(train_loader, net, idx_list_enc, device, end_to_end_training, debug, eps=0.1):
201 | """Initialize hypersphere center c as the mean from an initial forward pass on the data."""
202 | n_samples = 0
203 | net.eval()
204 |
205 |     data, _ = next(iter(train_loader))
206 | d_lstms = net(data.to(device))[-1]
207 |
208 | keys = []
209 | c = {}
210 | for en, k in enumerate(list(d_lstms.keys())):
211 | if en in idx_list_enc:
212 | keys.append(k)
213 | c[k] = torch.zeros_like(d_lstms[k][-1], device=device)
214 |
215 |     for idx, (data, _) in enumerate(tqdm(train_loader, desc='init hypersphere centers', total=len(train_loader), leave=False)):
216 | if debug and idx == 2: break
217 | # get the inputs of the batch
218 | n_samples += data.shape[0]
219 | d_lstms = net(data.to(device))[-1]
220 | for k in keys:
221 | c[k] += torch.sum(d_lstms[k], dim=0)
222 |
223 | for k in keys:
224 | c[k] = c[k] / n_samples
225 | # If c_i is too close to 0, set to +-eps. Reason: a zero unit can be trivially matched with zero weights.
226 | c[k][(abs(c[k]) < eps) & (c[k] < 0)] = -eps
227 | c[k][(abs(c[k]) < eps) & (c[k] > 0)] = eps
228 |
229 | return c, keys
230 |
231 |
232 | def eval_ad_loss(d_lstms, c, R, nu, boundary):
233 | dist = {}
234 |     loss = 1  # constant offset; it does not affect the gradients
235 |
236 | for k in c.keys():
237 | dist[k] = torch.sum((d_lstms[k] - c[k].unsqueeze(0)) ** 2, dim=-1)
238 |
239 | if boundary == 'soft':
240 | scores = dist[k] - R[k] ** 2
241 | loss += R[k] ** 2 + (1 / nu) * torch.mean(torch.max(torch.zeros_like(scores), scores))
242 | else:
243 | loss += torch.mean(dist[k])
244 |
245 | return dist, loss
246 |
--------------------------------------------------------------------------------
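A toy walk-through of the soft-boundary objective in eval_ad_loss and the radius update in train, using fake distances (the constant 1 added to the loss is omitted since it does not affect gradients):

```
import numpy as np
import torch

nu = 0.1
dist = torch.tensor([0.5, 1.0, 4.0, 9.0])  # squared distances from the center c
R = torch.tensor(1.5)

# Soft boundary: only points outside the sphere (dist > R^2) are penalized
scores = dist - R ** 2
loss = R ** 2 + (1 / nu) * torch.mean(torch.max(torch.zeros_like(scores), scores))
print(loss.item())  # 2.25 + 10 * mean([0, 0, 1.75, 6.75]) = 23.5

# After warm-up, R is updated so roughly a nu-fraction of points fall outside
R_new = np.quantile(np.sqrt(dist.numpy()), 1 - nu)
print(R_new)  # 2.7
```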
/datasets/mvtec.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import math
4 | import random
5 | import numpy as np
6 | from tqdm import tqdm
7 | from PIL import Image
8 | from os.path import join
9 |
10 | import torch
11 | import torch.nn as nn
12 | from torch.utils.data import Dataset, DataLoader, TensorDataset
13 |
14 | import torchvision.transforms as T
15 | from torchvision.datasets import ImageFolder
16 |
17 |
18 | class MVtecDataset(ImageFolder):
19 |     """Torchvision ImageFolder class with a patched __getitem__ method that adapts targets to the one-class task.
20 |
21 | """
22 | def __init__(self, root: str, transform):
23 | super(MVtecDataset, self).__init__(root=root, transform=transform)
24 |
25 | # Index of the class that corresponds to the folder named 'good'
26 | self.normal_class_idx = self.class_to_idx['good']
27 |
28 | def __getitem__(self, index: int):
29 | data, target = self.samples[index]
30 | def read_image(path):
31 | """Returns the image in RGB
32 |
33 | """
34 | with open(path, 'rb') as f:
35 | img = Image.open(f)
36 | return img.convert('RGB')
37 |
38 | # Convert the target to the 0/1 case
39 | target = 0 if target == self.normal_class_idx else 1
40 | data = self.transform(read_image(data))
41 |
42 | return data, target
43 |
44 |
45 | class CustomTensorDataset(TensorDataset):
46 | """Custom dataset for preprocessed images.
47 |
48 | """
49 | def __init__(self, root: str):
50 | """Init the dataset.
51 |
52 | Parameters
53 | ----------
54 | root : str
55 | Path to data file
56 |
57 | """
58 | # Load data
59 | self.data = torch.from_numpy(np.load(root))
60 |
61 | # Load TensorDataset
62 | super(CustomTensorDataset, self).__init__(self.data)
63 |
64 | def __len__(self):
65 | return self.data.shape[0]
66 |
67 | def __getitem__(self, index):
68 | return self.data[index], 0
69 |
70 |
71 | class MVTec_DataHolder(object):
72 | """MVTec data holder class
73 |
74 | """
75 | def __init__(self, data_path: str, category: str, image_size: int, patch_size: int, rotation_range: tuple, is_texture: bool):
76 | """Init MVTec data holder class
77 |
78 | Parameters
79 | ----------
80 | category : str
81 | Normal class
82 | image_size : int
83 | Side size of the input images
84 | patch_size : int
85 | Side size of the patches (for textures only)
86 | rotation_range : tuple
87 | Min and max angle to rotate images
88 | is_texture : bool
89 | True if the category is texture-type class
90 |
91 | """
92 | self.data_path = data_path
93 | self.category = category
94 | self.image_size = image_size
95 | self.patch_size = patch_size
96 | self.rotation_range = rotation_range
97 | self.is_texture = is_texture
98 |
99 | def get_test_data(self) -> Dataset:
100 | """Load test dataset
101 |
102 | Returns
103 | -------
104 | MVtecDataset : Dataset
105 | Custom dataset to handle MVTec data
106 |
107 | """
108 | return MVtecDataset(
109 | root=join(self.data_path, f'{self.category}/test'),
110 | transform=T.Compose([
111 | T.Resize(self.image_size, interpolation=Image.BILINEAR),
112 | T.ToTensor(),
113 | T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
114 | ])
115 | )
116 |
117 | def get_train_data(self, return_dataset: bool=True):
118 | """Load train dataset
119 |
120 | Parameters
121 | ----------
122 | return_dataset : bool
123 | False for preprocessing purpose only
124 |
125 | """
126 | train_data_dir = join(self.data_path, f'{self.category}/train/')
127 |
128 | # Preprocessed output data path
129 | cache_main_dir = join(self.data_path, f'processed/{self.category}')
130 | os.makedirs(cache_main_dir, exist_ok=True)
131 | cache_file = f'{cache_main_dir}/{self.category}_train_dataset_i-{self.image_size}_p-{self.patch_size}_r-{self.rotation_range[0]}--{self.rotation_range[1]}.npy'
132 |
133 | # Check if preprocessed file already exists
134 | if not os.path.exists(cache_file):
135 |
136 | # Apply random rotation
137 | def augmentation():
138 | """Returns transforms to apply to the data
139 |
140 | """
141 | # For textures rotate and crop without edges
142 | if self.is_texture:
143 | return T.Compose([
144 | T.Resize(self.image_size, interpolation=Image.BILINEAR),
145 | T.Pad(padding=self.image_size//4, padding_mode="reflect"),
146 | T.RandomRotation((self.rotation_range[0], self.rotation_range[1])),
147 | T.CenterCrop(self.image_size),
148 | T.RandomCrop(self.patch_size),
149 | T.ToTensor(),
150 | T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
151 | ])
152 | else:
153 | return T.Compose([
154 | T.Resize(self.image_size, interpolation=Image.BILINEAR),
155 | T.Pad(padding=self.image_size//4, padding_mode="reflect"),
156 | T.RandomRotation((self.rotation_range[0], self.rotation_range[1])),
157 | T.CenterCrop(self.image_size),
158 | T.ToTensor(),
159 | T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
160 | ])
161 |
162 | # Load data and apply transformations
163 | train_dataset = ImageFolder(root=train_data_dir, transform=augmentation())
164 | print(f"Creating cache for dataset: \n{cache_file}")
165 |             # To simulate a larger dataset, replicate the images with random transformations
166 | nb_epochs = 50000 // len(train_dataset.imgs)
167 | data_loader = DataLoader(dataset=train_dataset, batch_size=1024, pin_memory=True)
168 |
169 | for epoch in tqdm(range(nb_epochs), total=nb_epochs, desc=f"Creating cache for: {self.category}"):
170 | if epoch == 0:
171 |                     cache_np = [x.numpy() for x, _ in tqdm(data_loader, total=len(data_loader), desc=f'Caching epoch: {epoch+1}/{nb_epochs}', leave=False)]
172 | else:
173 |                     cache_np.extend([x.numpy() for x, _ in tqdm(data_loader, total=len(data_loader), desc=f'Caching epoch: {epoch+1}/{nb_epochs}', leave=False)])
174 |
175 | cache_np = np.vstack(cache_np)
176 | np.save(cache_file, cache_np)
177 | print(f"Preprocessed images has been saved at: \n{cache_file}")
178 |
179 | if return_dataset:
180 | print(f"Loading dataset from cache: \n{cache_file}")
181 | return CustomTensorDataset(cache_file)
182 | else:
183 | return
184 |
185 | def get_loaders(self, batch_size: int, shuffle_train: bool=True, pin_memory: bool=False, num_workers: int = 0) -> [DataLoader, DataLoader]:
186 | """Returns MVtec dataloaders
187 |
188 | Parameters
189 | ----------
190 | batch_size : int
191 |             Size of the batches
192 | shuffle_train : bool
193 | If True, shuffles the training dataset
194 | pin_memory : bool
195 |             If True, pins memory
196 | num_workers : int
197 | Number of dataloader workers
198 |
199 | Returns
200 | -------
201 | loaders : DataLoader
202 | Train and test data loaders
203 |
204 | """
205 | train_loader = DataLoader(
206 | dataset=self.get_train_data(return_dataset=True),
207 | batch_size=batch_size,
208 | shuffle=shuffle_train,
209 | pin_memory=pin_memory,
210 | num_workers=num_workers
211 | )
212 | test_loader = DataLoader(
213 | dataset=self.get_test_data(),
214 | batch_size=batch_size,
215 | pin_memory=pin_memory,
216 | num_workers=num_workers
217 | )
218 | return train_loader, test_loader
219 |
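A minimal usage sketch for the holder above; `'./mvtec'` and the hyper-parameters are illustrative placeholders, not values prescribed by this repository:

```python
# Hypothetical example: loaders for the 'bottle' category (an object class,
# hence is_texture=False and patch_size unused).
holder = MVTec_DataHolder(
    data_path='./mvtec',        # assumed dataset root
    category='bottle',
    image_size=128,
    patch_size=-1,
    rotation_range=(-45, 45),
    is_texture=False,
)
train_loader, test_loader = holder.get_loaders(batch_size=32, shuffle_train=True)
```

Note that the first call builds the preprocessed cache on disk, so it can take a while.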
220 |
221 | if __name__ == '__main__':
222 |     """To speed up the training phase we can preprocess the training images and save them as a numpy array.
223 |
224 | """
225 | textures = tuple(['carpet', 'grid', 'leather', 'tile', 'wood'])
226 | objects_1 = tuple(['bottle', 'hazelnut', 'metal_nut', 'screw'])
227 | objects_2 = tuple(['capsule', 'toothbrush', 'cable', 'pill', 'transistor', 'zipper'])
228 |
229 | classes = list(textures)
230 |     classes.extend(list(objects_1))
231 |     classes.extend(list(objects_2))
232 |
233 | for category in classes:
234 | if category in textures:
235 | args = dict(
236 | category=category,
237 | image_size=512,
238 | patch_size=64,
239 | rotation_range=(0, 45),
240 |                 is_texture=True
241 | )
242 | elif category in objects_1:
243 | args = dict(
244 | category=category,
245 | image_size=128,
246 | patch_size=-1,
247 | rotation_range=(-45, 45),
248 |                 is_texture=False
249 | )
250 | else:
251 | args = dict(
252 | category=category,
253 | image_size=128,
254 | patch_size=-1,
255 | rotation_range=(0, 0),
256 |                 is_texture=False
257 | )
258 |
259 |         MVTec_DataHolder(data_path='./mvtec', **args).get_train_data(return_dataset=False)  # './mvtec' is an assumed dataset root
260 |
--------------------------------------------------------------------------------
/models/shanghaitech_base_model.py:
--------------------------------------------------------------------------------
1 | from functools import reduce
2 | from operator import mul
3 |
4 | import torch
5 | import torch.nn as nn
6 |
7 | class BaseModule(nn.Module):
8 | """
9 | Implements the basic module.
10 | All other modules inherit from this one
11 | """
12 | def load_w(self, checkpoint_path):
13 | # type: (str) -> None
14 | """
15 | Loads a checkpoint into the state_dict.
16 | :param checkpoint_path: the checkpoint file to be loaded.
17 | """
18 | self.load_state_dict(torch.load(checkpoint_path))
19 |
20 | def __repr__(self):
21 | # type: () -> str
22 | """
23 | String representation
24 | """
25 | good_old = super(BaseModule, self).__repr__()
26 | addition = 'Total number of parameters: {:,}'.format(self.n_parameters)
27 |
28 | # return good_old + '\n' + addition
29 | return good_old
30 |
31 | def __call__(self, *args, **kwargs):
32 | return super(BaseModule, self).__call__(*args, **kwargs)
33 |
34 | @property
35 | def n_parameters(self):
36 | # type: () -> int
37 | """
38 | Number of parameters of the model.
39 | """
40 | n_parameters = 0
41 | for p in self.parameters():
42 | if hasattr(p, 'mask'):
43 | n_parameters += torch.sum(p.mask).item()
44 | else:
45 | n_parameters += reduce(mul, p.shape)
46 | return int(n_parameters)
47 |
48 | class MaskedConv3d(BaseModule, nn.Conv3d):
49 | """
50 | Implements a Masked Convolution 3D.
51 | This is a 3D Convolution that cannot access future frames.
52 | """
53 | def __init__(self, *args, **kwargs):
54 | super(MaskedConv3d, self).__init__(*args, **kwargs)
55 |
56 | self.register_buffer('mask', self.weight.data.clone())
57 | _, _, kT, kH, kW = self.weight.size()
58 | self.mask.fill_(1)
59 | self.mask[:, :, kT // 2 + 1:] = 0
60 |
61 | def forward(self, x):
62 | # type: (torch.Tensor) -> torch.Tensor
63 | """
64 | Performs the forward pass.
65 | :param x: the input tensor.
66 | :return: the output tensor as result of the convolution.
67 | """
68 | self.weight.data *= self.mask
69 | return super(MaskedConv3d, self).forward(x)
70 |
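The mask only zeroes the kernel taps that would look at future frames; a quick standalone sketch to visualize it:

```python
import torch  # needed only for this standalone check

# With kernel_size=3 the temporal taps are (t-1, t, t+1); the mask keeps the
# first two (past and present) and zeroes the future one.
conv = MaskedConv3d(in_channels=1, out_channels=1, kernel_size=3, padding=1)
print(conv.mask[0, 0, :, 0, 0])  # -> tensor([1., 1., 0.])
```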
71 | class TemporallySharedFullyConnection(BaseModule):
72 | """
73 | Implements a temporally-shared fully connection.
74 | Processes a time series of feature vectors and performs
75 | the same linear projection to all of them.
76 | """
77 | def __init__(self, in_features, out_features, bias=True):
78 | # type: (int, int, bool) -> None
79 | """
80 | Class constructor.
81 | :param in_features: number of input features.
82 | :param out_features: number of output features.
83 | :param bias: whether or not to add bias.
84 | """
85 | super(TemporallySharedFullyConnection, self).__init__()
86 |
87 | self.in_features = in_features
88 | self.out_features = out_features
89 | self.bias = bias
90 |
91 | # the layer to be applied at each timestep
92 | self.linear = nn.Linear(in_features=in_features, out_features=out_features, bias=bias)
93 |
94 | def forward(self, x):
95 | # type: (torch.Tensor) -> torch.Tensor
96 | """
97 | Forward function.
98 | :param x: layer input. Has shape=(batchsize, seq_len, in_features).
99 | :return: layer output. Has shape=(batchsize, seq_len, out_features)
100 | """
101 | b, t, d = x.size()
102 |
103 | output = []
104 | for i in range(0, t):
105 | # apply dense layer
106 | output.append(self.linear(x[:, i, :]))
107 | output = torch.stack(output, 1)
108 |
109 | return output
110 |
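Since `nn.Linear` already applies to the last dimension of any (B, T, D) tensor, the per-timestep loop above is numerically equivalent to a single call on the full sequence; a quick check under that assumption:

```python
import torch

tsfc = TemporallySharedFullyConnection(in_features=8, out_features=4)
x = torch.randn(2, 5, 8)                 # (batch, seq_len, in_features)
# The loop-and-stack forward matches applying the shared layer directly.
assert torch.allclose(tsfc(x), tsfc.linear(x), atol=1e-6)
```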
111 | def residual_op(x, functions, bns, activation_fn):
112 | # type: (torch.Tensor, List[Module, Module, Module], List[Module, Module, Module], Module) -> torch.Tensor
113 | """
114 | Implements a global residual operation.
115 | :param x: the input tensor.
116 | :param functions: a list of functions (nn.Modules).
117 | :param bns: a list of optional batch-norm layers.
118 | :param activation_fn: the activation to be applied.
119 | :return: the output of the residual operation.
120 | """
121 |     assert len(functions) == len(bns) == 3
122 |     f1, f2, f3 = functions
123 |     bn1, bn2, bn3 = bns
124 |
125 | assert f1 is not None and f2 is not None
126 | assert not (f3 is None and bn3 is not None)
127 |
128 | # A-branch
129 | ha = x
130 | ha = f1(ha)
131 | if bn1 is not None:
132 | ha = bn1(ha)
133 | ha = activation_fn(ha)
134 |
135 | ha = f2(ha)
136 | if bn2 is not None:
137 | ha = bn2(ha)
138 |
139 | # B-branch
140 | hb = x
141 | if f3 is not None:
142 | hb = f3(hb)
143 | if bn3 is not None:
144 | hb = bn3(hb)
145 |
146 | # Residual connection
147 | out = ha + hb
148 | return activation_fn(out)
149 |
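In formulas, with sigma the activation and the batch-norms optional, `residual_op` computes:

```latex
\mathrm{out} \;=\; \sigma\Big(\underbrace{\mathrm{bn}_2\big(f_2\big(\sigma(\mathrm{bn}_1(f_1(x)))\big)\big)}_{\text{A-branch}} \;+\; \underbrace{\mathrm{bn}_3(f_3(x))}_{\text{B-branch, identity when } f_3 \text{ is None}}\Big)
```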
150 |
151 | class BaseBlock(BaseModule):
152 | """ Base class for all blocks. """
153 | def __init__(self, channel_in, channel_out, activation_fn, use_bn=True, use_bias=True):
154 | # type: (int, int, Module, bool, bool) -> None
155 | """
156 | Class constructor.
157 | :param channel_in: number of input channels.
158 | :param channel_out: number of output channels.
159 | :param activation_fn: activation to be employed.
160 | :param use_bn: whether or not to use batch-norm.
161 | :param use_bias: whether or not to use bias.
162 | """
163 | super(BaseBlock, self).__init__()
164 |
165 | assert not (use_bn and use_bias), 'Using bias=True with batch_normalization is forbidden.'
166 |
167 | self._channel_in = channel_in
168 | self._channel_out = channel_out
169 | self._activation_fn = activation_fn
170 | self._use_bn = use_bn
171 | self._bias = use_bias
172 |
173 | def get_bn(self):
174 | # type: () -> Optional[Module]
175 | """
176 | Returns batch norm layers, if needed.
177 | :return: batch norm layers or None
178 | """
179 | return nn.BatchNorm3d(num_features=self._channel_out) if self._use_bn else None
180 |
181 | def forward(self, x):
182 | """
183 | Abstract forward function. Not implemented.
184 | """
185 | raise NotImplementedError
186 |
187 |
188 | class DownsampleBlock(BaseBlock):
189 | """ Implements a Downsampling block for videos (Fig. 1ii). """
190 | def __init__(self, channel_in, channel_out, activation_fn, stride, use_bn=True, use_bias=False):
191 | # type: (int, int, Module, Tuple[int, int, int], bool, bool) -> None
192 | """
193 | Class constructor.
194 | :param channel_in: number of input channels.
195 | :param channel_out: number of output channels.
196 | :param activation_fn: activation to be employed.
197 | :param stride: the stride to be applied to downsample feature maps.
198 | :param use_bn: whether or not to use batch-norm.
199 | :param use_bias: whether or not to use bias.
200 | """
201 | super(DownsampleBlock, self).__init__(channel_in, channel_out, activation_fn, use_bn, use_bias)
202 | self.stride = stride
203 |
204 | # Convolutions
205 | self.conv1a = MaskedConv3d(in_channels=channel_in, out_channels=channel_out, kernel_size=3,
206 | padding=1, stride=stride, bias=use_bias)
207 | self.conv1b = MaskedConv3d(in_channels=channel_out, out_channels=channel_out, kernel_size=3,
208 | padding=1, stride=1, bias=use_bias)
209 | self.conv2a = nn.Conv3d(in_channels=channel_in, out_channels=channel_out, kernel_size=1,
210 | padding=0, stride=stride, bias=use_bias)
211 |
212 | # Batch Normalization layers
213 | self.bn1a = self.get_bn()
214 | self.bn1b = self.get_bn()
215 | self.bn2a = self.get_bn()
216 |
217 | def forward(self, x):
218 | # type: (torch.Tensor) -> torch.Tensor
219 | """
220 | Forward propagation.
221 | :param x: the input tensor
222 | :return: the output tensor
223 | """
224 | return residual_op(
225 | x,
226 | functions=[self.conv1a, self.conv1b, self.conv2a],
227 | bns=[self.bn1a, self.bn1b, self.bn2a],
228 | activation_fn=self._activation_fn
229 | )
230 |
231 |
232 | class UpsampleBlock(BaseBlock):
233 | """ Implements a Upsampling block for videos (Fig. 1ii). """
234 | def __init__(self, channel_in, channel_out, activation_fn, stride, output_padding, use_bn=True, use_bias=False):
235 | # type: (int, int, Module, Tuple[int, int, int], Tuple[int, int, int], bool, bool) -> None
236 | """
237 | Class constructor.
238 | :param channel_in: number of input channels.
239 | :param channel_out: number of output channels.
240 | :param activation_fn: activation to be employed.
241 | :param stride: the stride to be applied to upsample feature maps.
242 |         :param output_padding: the output padding to be applied to the upsampled feature maps.
243 | :param use_bn: whether or not to use batch-norm.
244 | :param use_bias: whether or not to use bias.
245 | """
246 | super(UpsampleBlock, self).__init__(channel_in, channel_out, activation_fn, use_bn, use_bias)
247 | self.stride = stride
248 | self.output_padding = output_padding
249 |
250 | # Convolutions
251 | self.conv1a = nn.ConvTranspose3d(channel_in, channel_out, kernel_size=5,
252 | padding=2, stride=stride, output_padding=output_padding, bias=use_bias)
253 | self.conv1b = nn.Conv3d(in_channels=channel_out, out_channels=channel_out, kernel_size=3,
254 | padding=1, stride=1, bias=use_bias)
255 | self.conv2a = nn.ConvTranspose3d(channel_in, channel_out, kernel_size=5,
256 | padding=2, stride=stride, output_padding=output_padding, bias=use_bias)
257 |
258 | # Batch Normalization layers
259 | self.bn1a = self.get_bn()
260 | self.bn1b = self.get_bn()
261 | self.bn2a = self.get_bn()
262 |
263 | def forward(self, x):
264 | # type: (torch.Tensor) -> torch.Tensor
265 | """
266 | Forward propagation.
267 | :param x: the input tensor
268 | :return: the output tensor
269 | """
270 | return residual_op(
271 | x,
272 | functions=[self.conv1a, self.conv1b, self.conv2a],
273 | bns=[self.bn1a, self.bn1b, self.bn2a],
274 | activation_fn=self._activation_fn
275 | )
276 |
277 |
--------------------------------------------------------------------------------
/main_cifar10.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import random
4 | import logging
5 | import argparse
6 | import numpy as np
7 |
8 | import torch
9 |
10 | from tensorboardX import SummaryWriter
11 |
12 | from datasets.data_manager import DataManager
13 | from trainers.train_cifar10 import pretrain, train, test
14 | from utils import set_seeds, get_out_dir, purge_params
15 | from models.cifar10_model import CIFAR10_Autoencoder, CIFAR10_Encoder
16 |
17 |
18 | def main(args):
19 |
20 |     # If the layer list is not specified, then use only the last layer to detect anomalies
21 | if len(args.idx_list_enc) == 0 and args.train:
22 | args.idx_list_enc = [3]
23 |
24 | ## Init logger & print training/warm-up summary
25 | logging.basicConfig(
26 | level=logging.INFO,
27 | format="%(asctime)s | %(message)s",
28 | handlers=[
29 | logging.FileHandler('./training.log'),
30 | logging.StreamHandler()
31 | ])
32 |
33 | logger = logging.getLogger()
34 |
35 | if args.train or args.pretrain:
36 | logger.info(
37 | "Start run with params:\n"
38 | f"\n\t\t\t\tPretrain model : {args.pretrain}"
39 | f"\n\t\t\t\tTrain model : {args.train}"
40 | f"\n\t\t\t\tTest model : {args.test}"
41 | f"\n\t\t\t\tBoundary : {args.boundary}"
42 | f"\n\t\t\t\tNormal class : {args.normal_class}"
43 | f"\n\t\t\t\tBatch size : {args.batch_size}\n"
44 | f"\n\t\t\t\tPretrain epochs : {args.ae_epochs}"
45 | f"\n\t\t\t\tAE-Learning rate : {args.ae_learning_rate}"
46 | f"\n\t\t\t\tAE-milestones : {args.ae_lr_milestones}"
47 | f"\n\t\t\t\tAE-Weight decay : {args.ae_weight_decay}\n"
48 | f"\n\t\t\t\tTrain epochs : {args.epochs}"
49 | f"\n\t\t\t\tLearning rate : {args.learning_rate}"
50 | f"\n\t\t\t\tMilestones : {args.lr_milestones}"
51 | f"\n\t\t\t\tWeight decay : {args.weight_decay}\n"
52 | f"\n\t\t\t\tCode length : {args.code_length}"
53 | f"\n\t\t\t\tNu : {args.nu}"
54 | f"\n\t\t\t\tEncoder list : {args.idx_list_enc}\n"
55 | )
56 | else:
57 | if args.model_ckp is None:
58 | logger.info("CANNOT TEST MODEL WITHOUT A VALID CHECKPOINT")
59 | sys.exit(0)
60 |
61 | args.normal_class = int(args.model_ckp.split('/')[-2].split('-')[2].split('_')[-1])
62 |
63 | # Set seed
64 | set_seeds(args.seed)
65 |
66 | # Get the device
67 | device = "cuda" if torch.cuda.is_available() else "cpu"
68 |
69 | # Init DataHolder class
70 | data_holder = DataManager(
71 | dataset_name='cifar10',
72 | data_path=args.data_path,
73 | normal_class=args.normal_class,
74 | only_test=args.test
75 | ).get_data_holder()
76 |
77 | # Load data
78 | train_loader, test_loader = data_holder.get_loaders(
79 | batch_size=args.batch_size,
80 | shuffle_train=True,
81 | pin_memory=device=="cuda",
82 | num_workers=args.n_workers
83 | )
84 |
85 | ### PRETRAIN the full AutoEncoder
86 | ae_net_cehckpoint = None
87 | if args.pretrain:
88 | out_dir, tmp = get_out_dir(args, pretrain=True, aelr=None, dset_name='cifar10')
89 | tb_writer = SummaryWriter(os.path.join(args.output_path, 'cifar10', str(args.normal_class), 'tb_runs/pretrain', tmp))
90 |
91 | # Init AutoEncoder
92 | ae_net = CIFAR10_Autoencoder(args.code_length)
93 |
94 | # Start pretraining
95 |         logging.info('Start training the full AutoEncoder')
96 | ae_net_cehckpoint = pretrain(
97 | ae_net=ae_net,
98 | train_loader=train_loader,
99 | out_dir=out_dir,
100 | tb_writer=tb_writer,
101 | device=device,
102 | ae_learning_rate=args.ae_learning_rate,
103 | ae_weight_decay=args.ae_weight_decay,
104 | ae_lr_milestones=args.ae_lr_milestones,
105 | ae_epochs=args.ae_epochs
106 | )
107 | logging.info('AutoEncoder trained!!!')
108 |
109 | tb_writer.close()
110 |
111 | ### TRAIN the Encoder
112 | net_cehckpoint = None
113 | if args.train:
114 | if ae_net_cehckpoint is None:
115 | if args.model_ckp is None:
116 | logger.info("CANNOT TRAIN MODEL WITHOUT A VALID CHECKPOINT")
117 | sys.exit(0)
118 | ae_net_cehckpoint = args.model_ckp
119 | aelr = float(ae_net_cehckpoint.split('/')[-2].split('-')[4].split('_')[-1])
120 | out_dir, tmp = get_out_dir(args, pretrain=False, aelr=aelr)
121 |
122 | tb_writer = SummaryWriter(os.path.join(args.output_path, 'cifar10', str(args.normal_class), 'tb_runs/train', tmp))
123 |
124 | # Init Encoder
125 | encoder_net = CIFAR10_Encoder(args.code_length)
126 |
127 |         # Load Encoder parameters from the pretrained full AutoEncoder
128 |         purge_params(encoder_net=encoder_net, ae_net_cehckpoint=ae_net_cehckpoint)
129 |
130 | # Start training
131 | net_cehckpoint = train(
132 | net=encoder_net,
133 | train_loader=train_loader,
134 | out_dir=out_dir,
135 | tb_writer=tb_writer,
136 | device=device,
137 | ae_net_cehckpoint=ae_net_cehckpoint,
138 | idx_list_enc=args.idx_list_enc,
139 | learning_rate=args.learning_rate,
140 | weight_decay=args.weight_decay,
141 | lr_milestones=args.lr_milestones,
142 | epochs=args.epochs,
143 | nu=args.nu,
144 | boundary=args.boundary,
145 | debug=args.debug
146 | )
147 |
148 | tb_writer.close()
149 |
150 | ### TEST the Encoder
151 | if args.test:
152 | if net_cehckpoint is None:
153 | net_cehckpoint = args.model_ckp
154 | # Init Encoder
155 | net = CIFAR10_Encoder(args.code_length)
156 | st_dict = torch.load(net_cehckpoint)
157 | net.load_state_dict(st_dict['net_state_dict'])
158 |
159 | logger.info(f"Loaded model from: {net_cehckpoint}")
160 |
161 | if args.debug:
162 | idx_list_enc = args.idx_list_enc
163 | boundary = args.boundary
164 | else:
165 | idx_list_enc = [int(i) for i in net_cehckpoint.split('/')[-2].split('-')[-1].split('_')[-1].split('.')]
166 | boundary = net_cehckpoint.split('/')[-2].split('-')[-3].split('_')[-1]
167 |
168 | logger.info(
169 | f"Start test with params"
170 | f"\n\t\t\t\tCode length : {args.code_length}"
171 | f"\n\t\t\t\tEnc layer list : {idx_list_enc}"
172 | f"\n\t\t\t\tBoundary : {boundary}"
173 | f"\n\t\t\t\tNormal class : {args.normal_class}"
174 | )
175 |
176 | # Start test
177 | test(net=net, test_loader=test_loader, R=st_dict['R'], c=st_dict['c'], device=device, idx_list_enc=idx_list_enc, boundary=boundary)
178 |
179 |
180 | if __name__ == '__main__':
181 | parser = argparse.ArgumentParser('AD')
182 | ## General config
183 | parser.add_argument('-s', '--seed', type=int, default=-1, help='Random seed (default: -1)')
184 | parser.add_argument('--n_workers', type=int, default=8, help='Number of workers for data loading. 0 means that the data will be loaded in the main process. (default: 8)')
185 | parser.add_argument('--output_path', default='./output')
186 | ## Model config
187 | parser.add_argument('-zl', '--code-length', default=32, type=int, help='Code length (default: 32)')
188 | parser.add_argument('-ck', '--model-ckp', help='Model checkpoint')
189 | ## Optimizer config
190 | parser.add_argument('-alr', '--ae-learning-rate', type=float, default=1.e-4, help='Warm up learning rate (default: 1.e-4)')
191 | parser.add_argument('-lr', '--learning-rate', type=float, default=1.e-4, help='Learning rate (default: 1.e-4)')
192 |     parser.add_argument('-awd', '--ae-weight-decay', type=float, default=0.5e-6, help='Warm up weight decay (default: 0.5e-6)')
193 |     parser.add_argument('-wd', '--weight-decay', type=float, default=0.5e-6, help='Weight decay (default: 0.5e-6)')
194 | parser.add_argument('-aml', '--ae-lr-milestones', type=int, nargs='+', default=[], help='Pretrain milestone')
195 | parser.add_argument('-ml', '--lr-milestones', type=int, nargs='+', default=[], help='Training milestone')
196 | ## Data
197 | parser.add_argument('-dp', '--data-path', default='./cifar10', help='Dataset main path')
198 | parser.add_argument('-nc', '--normal-class', type=int, default=5, help='Normal Class (default: 5)')
199 | ## Training config
200 | parser.add_argument('-we', '--warm_up_n_epochs', type=int, default=10, help='Warm up epochs (default: 10)')
201 | parser.add_argument('--use-selectors', action="store_true", help='Use features selector (default: False)')
202 | parser.add_argument('-tbc', '--train-best-conf', action="store_true", help='Train best configurations (default: False)')
203 | parser.add_argument('-db', '--debug', action="store_true", help='Debug (default: False)')
204 | parser.add_argument('-bs', '--batch-size', type=int, default=256, help='Batch size (default: 256)')
205 | parser.add_argument('-bd', '--boundary', choices=("hard", "soft"), default="soft", help='Boundary (default: soft)')
206 | parser.add_argument('-ptr', '--pretrain', action="store_true", help='Pretrain model (default: False)')
207 | parser.add_argument('-tr', '--train', action="store_true", help='Train model (default: False)')
208 | parser.add_argument('-tt', '--test', action="store_true", help='Test model (default: False)')
209 | parser.add_argument('-ile', '--idx-list-enc', type=int, nargs='+', default=[], help='List of indices of model encoder')
210 | parser.add_argument('-e', '--epochs', type=int, default=1, help='Training epochs (default: 1)')
211 |     parser.add_argument('-ae', '--ae-epochs', type=int, default=1, help='Warm up epochs (default: 1)')
212 | parser.add_argument('-nu', '--nu', type=float, default=0.1)
213 | args = parser.parse_args()
214 |
215 | main(args)
216 |
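An illustrative invocation using the flags defined above (paths and hyper-parameter values are placeholders, not recommended settings):

```
python main_cifar10.py --pretrain --train --test \
    -nc 5 -zl 32 -bs 256 -ae 100 -e 100 \
    -alr 1e-4 -lr 1e-4 -bd soft -ile 3 \
    -dp ./cifar10 --output_path ./output
```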
--------------------------------------------------------------------------------
/models/mvtec_model.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from operator import mul
3 | from typing import Tuple
4 | from functools import reduce
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 | from .mvtec_base_model import BaseModule, DownsampleBlock, ResidualBlock, UpsampleBlock
11 |
12 |
13 | CHANNELS = [32, 64, 128]
14 |
15 |
16 | def init_conv_blocks(channel_in: int, channel_out: int, activation_fn: nn.Module) -> nn.Module:
17 | """ Init convolutional layers.
18 |
19 | Parameters
20 | ----------
21 |     channel_in : int
22 |         Number of input channels
23 |     channel_out : int
24 |         Number of output channels
25 |
26 | """
27 | return DownsampleBlock(channel_in=channel_in, channel_out=channel_out, activation_fn=activation_fn)
28 |
29 |
30 | class Selector(nn.Module):
31 | """Selector module
32 |
33 | """
34 | def __init__(self, code_length: int, idx: int):
35 |         """Init Selector architecture
36 |
37 |         Parameters
38 |         ----------
39 |         code_length : int
40 |             Latent code size
41 |         idx : int
42 |             Layer idx
43 |
44 |         """
45 |         super().__init__()
46 | # List of depths of features maps
47 | sizes = [CHANNELS[0], CHANNELS[0], CHANNELS[1], CHANNELS[2], CHANNELS[2]*2, CHANNELS[2]*2, code_length]
48 |
49 | # Hidden FC output size
50 | mid_features_size = 256
51 |
52 | # Last FC output size
53 | out_features = 128
54 |
55 | # Choose a different Selector architecture
56 | # depending on which layer it attaches
57 | if idx < 5:
58 | self.fc = nn.Sequential(
59 | nn.AdaptiveMaxPool2d(output_size=8),
60 | nn.Conv2d(in_channels=sizes[idx], out_channels=1, kernel_size=1),
61 | nn.Flatten(),
62 | nn.Linear(in_features=8**2, out_features=mid_features_size, bias=True),
63 | nn.BatchNorm1d(mid_features_size),
64 | nn.ReLU(),
65 | nn.Linear(in_features=mid_features_size, out_features=out_features, bias=True)
66 | )
67 | else:
68 | self.fc = nn.Sequential(
69 | nn.Flatten(),
70 | nn.Linear(in_features=sizes[idx], out_features=mid_features_size, bias=True),
71 | nn.BatchNorm1d(mid_features_size),
72 | nn.ReLU(),
73 | nn.Linear(in_features=mid_features_size, out_features=out_features, bias=True)
74 | )
75 |
76 | def forward(self, x: torch.Tensor) -> torch.Tensor:
77 | return self.fc(x)
78 |
79 |
80 | class MVTec_Encoder(BaseModule):
81 | """MVtec Encoder network
82 |
83 | """
84 | def __init__(self, input_shape: torch.Tensor, code_length: int, idx_list_enc: list, use_selectors: bool):
85 | """Init Encoder network
86 |
87 | Parameters
88 | ----------
89 | input_shape : torch.Tensor
90 | Input data shape
91 | code_length : int
92 | Latent code size
93 | idx_list_enc : list
94 | List of layers' idx to use for the AD task
95 | use_selectors : bool
96 | True (False) if the model has (not) to use Selectors modules
97 |
98 | """
99 | super().__init__()
100 |
101 | self.idx_list_enc = idx_list_enc
102 | self.use_selectors = use_selectors
103 |
104 | # Single input data shape
105 | c, h, w = input_shape
106 |
107 | # Activation function
108 | self.activation_fn = nn.LeakyReLU()
109 |
110 | # Init convolutional blocks
111 | self.conv = nn.Conv2d(in_channels=c, out_channels=32, kernel_size=3, bias=False)
112 | self.res = ResidualBlock(channel_in=32, channel_out=32, activation_fn=self.activation_fn)
113 | self.dwn1, self.dwn2, self.dwn3 = [init_conv_blocks(channel_in=ch, channel_out=ch*2, activation_fn=self.activation_fn) for ch in CHANNELS]
114 |
115 | # Depth of the last features map
116 | self.last_depth = CHANNELS[2]*2
117 |
118 | # Shape of the last features map
119 | self.deepest_shape = (self.last_depth, h // 8, w // 8)
120 |
121 | # init FC layers
122 | self.fc1 = nn.Linear(in_features=reduce(mul, self.deepest_shape), out_features=self.last_depth)
123 | self.bn = nn.BatchNorm1d(num_features=self.last_depth)
124 | self.fc2 = nn.Linear(in_features=self.last_depth, out_features=code_length)
125 |
126 | ## Init features selector models
127 | if self.use_selectors:
128 | self.selectors = nn.ModuleList([Selector(code_length=code_length, idx=idx) for idx in range(7)])
129 | self.selectors.append(Selector(code_length=code_length, idx=6))
130 |
131 | def get_depths_info(self) -> [int, int]:
132 | """
133 | Returns
134 | ------
135 | self.last_depth : int
136 | Depth of the last features map
137 |         self.deepest_shape : tuple
138 | Shape of the last features map
139 |
140 | """
141 | return self.last_depth, self.deepest_shape
142 |
143 | def set_idx_list_enc(self, idx_list_enc: list) -> None:
144 |         """Set the list of layers from which to extract the features.
145 |         It is used when initializing the hypersphere centers: the first
146 |         time the centroids are created, we compute them for all layers,
147 |         regardless of which ones are later used for the AD task.
148 |
149 | Parameters
150 | ----------
151 | idx_list_enc : list
152 | List of layers indices
153 |
154 | """
155 | self.idx_list_enc = idx_list_enc
156 |
157 | def forward(self, x: torch.Tensor) -> torch.Tensor:
158 | o1 = self.conv(x)
159 | o2 = self.res(self.activation_fn(o1))
160 | o3 = self.dwn1(o2)
161 | o4 = self.dwn2(o3)
162 | o5 = self.dwn3(o4)
163 | o7 = self.activation_fn(
164 | self.bn(
165 | self.fc1(
166 | o5.view(len(o5), -1)
167 | )
168 | )
169 | ) # FC -> BN -> LeakyReLU
170 | o8 = self.fc2(o7)
171 |         z = torch.sigmoid(o8)
172 |
173 | outputs = [o1, o2, o3, o4, o5, o7, o8, z]
174 |
175 | if len(self.idx_list_enc) != 0:
176 | # If we are pretraining the full AutoEncoder we don't need any of this and we set self.idx_list_enc = []
177 |
178 | if self.use_selectors:
179 | tuple_o = [self.selectors[idx](tt) for idx, tt in enumerate(outputs) if idx in self.idx_list_enc]
180 |
181 | else:
182 | # If we don't use selector, apply simple transformations to reduce the size of the feature maps
183 | tuple_o = []
184 |
185 | for idx, tt in enumerate(outputs):
186 | if idx not in self.idx_list_enc: continue
187 |
188 | if tt.ndimension() > 2:
189 | tuple_o.append(F.avg_pool2d(tt, tt.shape[-2:]).squeeze())
190 |
191 | else:
192 | tuple_o.append(tt.squeeze())
193 |
194 | return list(zip([f'0{idx}' for idx in self.idx_list_enc], tuple_o))
195 |
196 | else: # It means that we are pretraining the full AutoEncoder
197 | return z
198 |
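A sketch of the encoder's two output modes; the 128x128 input side and the code length below are assumptions chosen for the example, not values fixed by the model:

```python
import torch

enc = MVTec_Encoder(input_shape=(3, 128, 128), code_length=64,
                    idx_list_enc=[], use_selectors=False)
enc.eval()
x = torch.randn(4, 3, 128, 128)

z = enc(x)                    # pretraining mode: latent codes, shape (4, 64)
enc.set_idx_list_enc([6])     # switch to the anomaly-detection mode
feats = enc(x)                # [('06', tensor of shape (4, 64))]
```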
199 |
200 | class MVTec_Decoder(BaseModule):
201 | """MVTec Decoder network
202 |
203 | """
204 |     def __init__(self, code_length: int, deepest_shape: tuple, last_depth: int, output_shape: torch.Tensor):
205 | """Init MVtec Decoder network
206 |
207 | Parameters
208 | ----------
209 | code_length : int
210 | Latent code size
211 |         deepest_shape : tuple
212 |             Shape of the deepest encoder features map
213 |         output_shape : torch.Tensor
214 |             Shape of the reconstructed output (same as the input)
215 |
216 | """
217 | super().__init__()
218 |
219 | self.code_length = code_length
220 | self.deepest_shape = deepest_shape
221 | self.output_shape = output_shape
222 |
223 | # Decoder activation function
224 | activation_fn = nn.LeakyReLU()
225 |
226 | # FC network
227 | self.fc = nn.Sequential(
228 | nn.Linear(in_features=code_length, out_features=last_depth),
229 | nn.BatchNorm1d(num_features=last_depth),
230 | activation_fn,
231 | nn.Linear(in_features=last_depth, out_features=reduce(mul, deepest_shape)),
232 | nn.BatchNorm1d(num_features=reduce(mul, deepest_shape)),
233 | activation_fn
234 | )
235 |
236 | # (Transposed) Convolutional network
237 | self.conv = nn.Sequential(
238 | UpsampleBlock(channel_in=CHANNELS[2]*2, channel_out=CHANNELS[2], activation_fn=activation_fn),
239 | UpsampleBlock(channel_in=CHANNELS[1]*2, channel_out=CHANNELS[1], activation_fn=activation_fn),
240 | UpsampleBlock(channel_in=CHANNELS[0]*2, channel_out=CHANNELS[0], activation_fn=activation_fn),
241 | ResidualBlock(channel_in=CHANNELS[0], channel_out=CHANNELS[0], activation_fn=activation_fn),
242 | nn.Conv2d(in_channels=CHANNELS[0], out_channels=3, kernel_size=1, bias=False)
243 | )
244 |
245 | def forward(self, x: torch.Tensor) -> torch.Tensor:
246 | h = self.fc(x)
247 | h = h.view(len(h), *self.deepest_shape)
248 | return self.conv(h)
249 |
250 |
251 | class MVTecNet_AutoEncoder(BaseModule):
252 | """Full MVTecNet_AutoEncoder network
253 |
254 | """
255 |     def __init__(self, input_shape: tuple, code_length: int, use_selectors: bool):
256 | """Init Full AutoEncoder
257 |
258 | Parameters
259 | ----------
260 |         input_shape : tuple
261 | Shape of input data
262 | code_length : int
263 | Latent code size
264 | use_selectors : bool
265 | True (False) if the model has (not) to use Selectors modules
266 |
267 | """
268 | super().__init__()
269 |
270 | # Shape of input data needed by the Decoder
271 | self.input_shape = input_shape
272 |
273 | # Build Encoder
274 | self.encoder = MVTec_Encoder(
275 | input_shape=input_shape,
276 | code_length=code_length,
277 | idx_list_enc=[],
278 | use_selectors=use_selectors
279 | )
280 |
281 | last_depth, deepest_shape = self.encoder.get_depths_info()
282 |
283 | # Build Decoder
284 | self.decoder = MVTec_Decoder(
285 | code_length=code_length,
286 | deepest_shape=deepest_shape,
287 | last_depth=last_depth,
288 | output_shape=input_shape
289 | )
290 |
291 | def forward(self, x: torch.Tensor) -> torch.Tensor:
292 | z = self.encoder(x)
293 | x_r = self.decoder(z)
294 | x_r = x_r.view(-1, *self.input_shape)
295 | return x_r
296 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import logging
4 | import numpy as np
5 | from tqdm import tqdm
6 | from PIL import Image
7 | from os.path import join
8 |
9 | import torch
10 | import torchvision.transforms as T
11 | from torch.utils.data import DataLoader
12 | from torchvision.datasets import ImageFolder
13 |
14 | from models.mvtec_model import MVTec_Encoder
15 |
16 |
17 | def get_out_dir(args, pretrain: bool, aelr: float, dset_name: str="cifar10", training_strategy: str=None) -> [str, str]:
18 | """Creates training output dir
19 |
20 | Parameters
21 | ----------
22 |
23 | args :
24 | Arguments
25 | pretrain : bool
26 | True if pretrain the model
27 | aelr : float
28 | Full AutoEncoder learning rate
29 | dset_name : str
30 | Dataset name
31 | training_strategy : str
32 |         Training strategy tag (currently unused by this function)
33 |
34 | Returns
35 | -------
36 | out_dir : str
37 | Path to output folder
38 | tmp : str
39 | String containing infos about the current experiment setup
40 |
41 | """
42 | if dset_name == "ShanghaiTech":
43 | if pretrain:
44 | tmp = (f"pretrain-mn_{dset_name}-cl_{args.code_length}-lr_{args.ae_learning_rate}")
45 | out_dir = os.path.join(args.output_path, dset_name, 'pretrain', tmp)
46 | else:
47 | tmp = (
48 | f"train-mn_{dset_name}-cl_{args.code_length}-bs_{args.batch_size}-nu_{args.nu}-lr_{args.learning_rate}-"
49 | f"bd_{args.boundary}-sl_{args.use_selectors}-ile_{'.'.join(map(str, args.idx_list_enc))}-lstm_{args.load_lstm}-"
50 | f"bidir_{args.bidirectional}-hs_{args.hidden_size}-nl_{args.num_layers}-dp_{args.dropout}"
51 | )
52 | out_dir = os.path.join(args.output_path, dset_name, 'train', tmp)
53 | if args.end_to_end_training:
54 | out_dir = os.path.join(args.output_path, dset_name, 'train_end_to_end', tmp)
55 | else:
56 | if pretrain:
57 | tmp = (f"pretrain-mn_{dset_name}-nc_{args.normal_class}-cl_{args.code_length}-lr_{args.ae_learning_rate}-awd_{args.ae_weight_decay}")
58 | out_dir = os.path.join(args.output_path, dset_name, str(args.normal_class), 'pretrain', tmp)
59 |
60 | else:
61 | tmp = (
62 | f"train-mn_{dset_name}-nc_{args.normal_class}-cl_{args.code_length}-bs_{args.batch_size}-nu_{args.nu}-lr_{args.learning_rate}-"
63 | f"wd_{args.weight_decay}-bd_{args.boundary}-alr_{aelr}-sl_{args.use_selectors}-ep_{args.epochs}-ile_{'.'.join(map(str, args.idx_list_enc))}"
64 | )
65 | out_dir = os.path.join(args.output_path, dset_name, str(args.normal_class), 'train', tmp)
66 |
67 | if not os.path.exists(out_dir):
68 | os.makedirs(out_dir)
69 |
70 | return out_dir, tmp
71 |
72 |
73 | def set_seeds(seed: int) -> None:
74 | """Set all seeds.
75 |
76 | Parameters
77 | ----------
78 | seed : int
79 | Seed
80 |
81 | """
82 | # Set the seed only if the user specified it
83 | if seed != -1:
84 | random.seed(seed)
85 | np.random.seed(seed)
86 | torch.manual_seed(seed)
87 |
88 |
89 | def purge_params(encoder_net, ae_net_cehckpoint: str) -> None:
90 |     """Load the Encoder's pretrained weights from the full AutoEncoder checkpoint.
91 |     After the pretraining phase we don't need the full AutoEncoder parameters; we only need the Encoder
92 |
93 | Parameters
94 | ----------
95 | encoder_net :
96 | The Encoder network
97 | ae_net_cehckpoint : str
98 | Path to full AutoEncoder checkpoint
99 |
100 | """
101 | # Load the full AutoEncoder checkpoint dict
102 | ae_net_dict = torch.load(ae_net_cehckpoint, map_location=lambda storage, loc: storage)['ae_state_dict']
103 |
104 | # Load encoder weight from autoencoder
105 | net_dict = encoder_net.state_dict()
106 |
107 | # Filter out decoder network keys
108 | st_dict = {k: v for k, v in ae_net_dict.items() if k in net_dict}
109 |
110 | # Overwrite values in the existing state_dict
111 | net_dict.update(st_dict)
112 |
113 | # Load the new state_dict
114 | encoder_net.load_state_dict(net_dict)
115 |
116 |
117 | def load_mvtec_model_from_checkpoint(input_shape: tuple, code_length: int, idx_list_enc: list, use_selectors: bool, net_cehckpoint: str, purge_ae_params: bool = False) -> torch.nn.Module:
118 |     """Load the Encoder network, optionally purging Decoder parameters from a full AutoEncoder checkpoint.
119 |
120 | Parameters
121 | ----------
122 | input_shape : tuple
123 | Input data shape
124 | code_length : int
125 | Latent code size
126 | idx_list_enc : list
127 |         List of indices of the layers from which to extract features
128 | use_selectors : bool
129 | True if the model has to use Selector modules
130 | net_cehckpoint : str
131 | Path to model checkpoint
132 | purge_ae_params : bool
133 | True if the checkpoint is relative to an AutoEncoder
134 |
135 | Returns
136 | -------
137 | encoder_net : torch.nn.Module
138 | The Encoder network
139 |
140 | """
141 | logger = logging.getLogger()
142 |
143 | encoder_net = MVTec_Encoder(
144 | input_shape=input_shape,
145 | code_length=code_length,
146 | idx_list_enc=idx_list_enc,
147 | use_selectors=use_selectors
148 | )
149 |
150 | if purge_ae_params:
151 |
152 |         # Load Encoder parameters from the pretrained full AutoEncoder
153 | logger.info(f"Loading encoder from: {net_cehckpoint}")
154 | purge_params(encoder_net=encoder_net, ae_net_cehckpoint=net_cehckpoint)
155 | else:
156 |
157 | st_dict = torch.load(net_cehckpoint)
158 | encoder_net.load_state_dict(st_dict['net_state_dict'])
159 | logger.info(f"Loaded model from: {net_cehckpoint}")
160 |
161 | return encoder_net
162 |
163 | def extract_arguments_from_checkpoint(net_checkpoint: str):
164 |     """Takes the file path of a checkpoint and parses its folder name to extract the training parameters and
165 |     architectural specifications of the model.
166 |
167 | Parameters
168 | ----------
169 | net_checkpoint : file path of the checkpoint (str)
170 |
171 | Returns
172 | -------
173 | code_length = latent code size (int)
174 | batch_size = batch_size (int)
175 | boundary = soft or hard boundary (str)
176 | use_selectors = if selectors used it is true, otherwise false (bool)
177 | idx_list_enc = indexes of the exploited layers (list of integers)
178 | load_lstm = boolean to show whether lstm used (bool)
179 | hidden_size = hidden size of the lstm (int)
180 | num_layers = number of layers of the lstm (int)
181 | dropout = dropout probability (float)
182 | bidirectional = is lstm bi-directional or not (bool)
183 | dataset_name = name of the dataset (str)
184 | train_type = is it end-to-end, train, or pretrain (str)
185 | """
186 |
187 | code_length = int(net_checkpoint.split(os.sep)[-2].split('-')[2].split('_')[-1])
188 | batch_size = int(net_checkpoint.split(os.sep)[-2].split('-')[3].split('_')[-1])
189 | boundary = net_checkpoint.split(os.sep)[-2].split('-')[6].split('_')[-1]
190 | use_selectors = net_checkpoint.split(os.sep)[-2].split('-')[7].split('_')[-1] == "True"
191 | idx_list_enc = [int(i) for i in net_checkpoint.split(os.sep)[-2].split('-')[8].split('_')[-1].split('.')]
192 | load_lstm = net_checkpoint.split(os.sep)[-2].split('-')[9].split('_')[-1] == "True"
193 | hidden_size = int(net_checkpoint.split(os.sep)[-2].split('-')[11].split('_')[-1])
194 | num_layers = int(net_checkpoint.split(os.sep)[-2].split('-')[12].split('_')[-1])
195 | dropout = float(net_checkpoint.split(os.sep)[-2].split('-')[13].split('_')[-1])
196 | bidirectional = net_checkpoint.split(os.sep)[-2].split('-')[10].split('_')[-1] == "True"
197 | dataset_name = net_checkpoint.split(os.sep)[-4]
198 | train_type = net_checkpoint.split(os.sep)[-3]
199 | return code_length, batch_size, boundary, use_selectors, idx_list_enc, load_lstm, hidden_size, num_layers, dropout, bidirectional, dataset_name, train_type
200 |
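A hypothetical example: the path below is fabricated to match `get_out_dir`'s ShanghaiTech naming template, and the snippet assumes the platform path separator is `/`:

```python
ckp = ('./output/ShanghaiTech/train/'
       'train-mn_ShanghaiTech-cl_32-bs_4-nu_0.1-lr_0.0001-bd_soft-sl_False-'
       'ile_5.6-lstm_True-bidir_False-hs_100-nl_1-dp_0.0/net_ckp_best_model.pth')
print(extract_arguments_from_checkpoint(ckp)[:5])
# -> (32, 4, 'soft', False, [5, 6])
```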
201 | def eval_spheres_centers(train_loader: DataLoader, encoder_net: torch.nn.Module, ae_net_cehckpoint: str, use_selectors: bool, device:str, debug: bool) -> dict:
202 | """Eval the centers of the hyperspheres at each chosen layer.
203 |
204 | Parameters
205 | ----------
206 | train_loader : DataLoader
207 |         DataLoader for training data
208 | encoder_net : torch.nn.Module
209 | Encoder network
210 | ae_net_cehckpoint : str
211 | Checkpoint of the full AutoEncoder
212 | use_selectors : bool
213 | True if we want to use selector models
214 | device : str
215 | Device on which run the computations
216 | debug : bool
217 | Activate debug mode
218 |
219 | Returns
220 | -------
221 | dict : dictionary
222 | Dictionary with k='layer name'; v='features vector representing hypersphere center'
223 |
224 | """
225 | logger = logging.getLogger()
226 |
227 | centers_files = ae_net_cehckpoint[:-4]+f'_w_centers_{use_selectors}.pth'
228 |
229 | # If centers are found, then load and return
230 | if os.path.exists(centers_files):
231 |
232 | logger.info("Found hyperspheres centers")
233 | ae_net_ckp = torch.load(centers_files, map_location=lambda storage, loc: storage)
234 |
235 | centers = {k: v.to(device) for k, v in ae_net_ckp['centers'].items()}
236 | else:
237 |
238 | logger.info("Hyperspheres centers not found... evaluating...")
239 | centers_ = init_center_c(train_loader=train_loader, encoder_net=encoder_net, device=device, debug=debug)
240 |
241 | logger.info("Hyperspheres centers evaluated!!!")
242 | new_ckp = ae_net_cehckpoint.split('.pth')[0]+f'_w_centers_{use_selectors}.pth'
243 |
244 |         logger.info(f"Saving new AE dict at: {new_ckp}!!!")
245 | centers = {k: v for k, v in centers_.items()}
246 |
247 | torch.save({
248 | 'ae_state_dict': torch.load(ae_net_cehckpoint)['ae_state_dict'],
249 | 'centers': centers
250 | }, new_ckp)
251 |
252 | return centers
253 |
254 |
255 | @torch.no_grad()
256 | def init_center_c(train_loader: DataLoader, encoder_net: torch.nn.Module, device: str, debug: bool, eps: float=0.1) -> dict:
257 | """Initialize hypersphere center as the mean from an initial forward pass on the data.
258 |
259 | Parameters
260 | ----------
261 |     train_loader : DataLoader
262 |     encoder_net : torch.nn.Module
263 |     debug : bool
264 |     eps : float
265 |
266 | Returns
267 | -------
268 | dictionary : dict
269 |         Dictionary with k='layer name'; v='center features'
270 |
271 | """
272 | n_samples = 0
273 |
274 | encoder_net.eval().to(device)
275 |
276 |     for idx, (data, _) in enumerate(tqdm(train_loader, desc='Init hyperspheres centers', total=len(train_loader), leave=False)):
277 | if debug and idx == 5: break
278 |
279 | data = data.to(device)
280 | n_samples += data.shape[0]
281 |
282 | zipped = encoder_net(data)
283 |
284 | if idx == 0:
285 | c = {item[0]: torch.zeros_like(item[1][-1], device=device) for item in zipped}
286 |
287 | for item in zipped:
288 | c[item[0]] += torch.sum(item[1], dim=0)
289 |
290 | for k in c.keys():
291 | c[k] = c[k] / n_samples
292 |
293 | # If c_i is too close to 0, set to +-eps. Reason: a zero unit can be trivially matched with zero weights.
294 | c[k][(abs(c[k]) < eps) & (c[k] < 0)] = -eps
295 | c[k][(abs(c[k]) < eps) & (c[k] > 0)] = eps
296 |
297 | return c
298 |
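In formulas, for each selected layer `k` the center is the feature mean over the N training samples, with near-zero coordinates pushed away from the origin (exactly-zero entries are left untouched by the code):

```latex
c_k = \frac{1}{N}\sum_{i=1}^{N}\phi_k(x_i), \qquad
c_{k,j} \leftarrow \operatorname{sign}(c_{k,j})\,\varepsilon \quad \text{if } 0 < |c_{k,j}| < \varepsilon
```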
--------------------------------------------------------------------------------
/models/shanghaitech_model.py:
--------------------------------------------------------------------------------
1 | from .shanghaitech_base_model import BaseModule, DownsampleBlock, UpsampleBlock, TemporallySharedFullyConnection, MaskedConv3d
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | class Selector(BaseModule):
7 | def __init__(self, code_length, idx):
8 | super(Selector, self).__init__()
9 | """
10 | sizes = [[ch, time , h, w], ...]
11 | """
12 | self.idx = idx
13 | self.sizes = [
14 | [8, 16, 128, 256],
15 | [16, 16, 64, 128],
16 | [32, 8, 32, 64],
17 | [64, 8, 16, 32],
18 | [64, 4, 8, 16]
19 | ]
20 | mid_features_size = 256
21 | #self.adaptive = nn.AdaptiveMaxPool3d(output_size=(None,16,16))
22 | self.cv1 = nn.Sequential(
23 | nn.Conv3d(in_channels=self.sizes[idx][0], out_channels=self.sizes[idx][0]*2, kernel_size=3, padding=1,stride=(1,2,2)),
24 | nn.BatchNorm3d(num_features=self.sizes[idx][0]*2),
25 | nn.ReLU(),
26 | nn.Conv3d(in_channels=self.sizes[idx][0]*2, out_channels=self.sizes[idx][0]*4, kernel_size=3, padding=1),
27 | nn.BatchNorm3d(num_features=self.sizes[idx][0]*4),
28 | nn.ReLU(),
29 | nn.Conv3d(in_channels=self.sizes[idx][0]*4, out_channels=self.sizes[idx][0]*4, kernel_size=1)
30 | )
31 |
32 | self.fc = nn.Sequential(
33 | TemporallySharedFullyConnection(in_features=(self.sizes[self.idx][0]*4 * self.sizes[self.idx][2]//2 * self.sizes[self.idx][3]//2), out_features=mid_features_size, bias=True),
34 | nn.BatchNorm1d((self.sizes[self.idx][1])),
35 | nn.ReLU(),
36 | TemporallySharedFullyConnection(in_features=mid_features_size, out_features=self.sizes[self.idx][0], bias=True)
37 | )
38 |
39 |
40 | def forward(self, x):
41 | x = self.cv1(x)
42 | _, t, _, _ = self.sizes[self.idx]
43 | x = torch.transpose(x, 1, 2).contiguous()
44 | x = x.view(-1, t, (self.sizes[self.idx][0]*4 * self.sizes[self.idx][2]//2 * self.sizes[self.idx][3]//2))
45 | x = self.fc(x)
46 | return x
47 |
48 | def build_lstm(input_size, hidden_size, num_layers, dropout, bidirectional):
49 | return nn.LSTM(
50 | input_size=input_size,
51 | hidden_size=hidden_size,
52 | num_layers=num_layers,
53 | bias=True,
54 | batch_first=True,
55 | dropout=dropout,
56 | bidirectional=bidirectional
57 | )
58 |
59 |
60 | class ShanghaiTechEncoder(BaseModule):
61 | """
62 | ShanghaiTech model encoder.
63 | """
64 | def __init__(self, input_shape, code_length, load_lstm, hidden_size, num_layers, dropout, bidirectional, use_selectors):
65 | # type: (Tuple[int, int, int, int], int) -> None
66 | """
67 | Class constructor:
68 |         :param input_shape: the shape of ShanghaiTech input clips.
69 | :param code_length: the dimensionality of latent vectors.
70 | """
71 | super(ShanghaiTechEncoder, self).__init__()
72 |
73 | self.input_shape = input_shape
74 | self.code_length = code_length
75 | self.load_lstm = load_lstm
76 | self.use_selectors = use_selectors
77 |
78 | c, t, h, w = input_shape
79 |
80 | activation_fn = nn.LeakyReLU()
81 |
82 | # Convolutional network
83 | #self.conv = nn.Sequential(
84 | self.conv_1 = DownsampleBlock(channel_in=c, channel_out=8, activation_fn=activation_fn, stride=(1, 2, 2))
85 | self.conv_2 = DownsampleBlock(channel_in=8, channel_out=16, activation_fn=activation_fn, stride=(1, 2, 2))
86 | self.conv_3 = DownsampleBlock(channel_in=16, channel_out=32, activation_fn=activation_fn, stride=(2, 2, 2))
87 | self.conv_4 = DownsampleBlock(channel_in=32, channel_out=64, activation_fn=activation_fn, stride=(1, 2, 2))
88 | self.conv_5 = DownsampleBlock(channel_in=64, channel_out=64, activation_fn=activation_fn, stride=(2, 2, 2))
89 | if load_lstm:
90 | self.lstm_1 = build_lstm(8, hidden_size, num_layers, dropout, bidirectional)
91 | self.lstm_2 = build_lstm(16, hidden_size, num_layers, dropout, bidirectional)
92 | self.lstm_3 = build_lstm(32, hidden_size, num_layers, dropout, bidirectional)
93 | self.lstm_4 = build_lstm(64, hidden_size, num_layers, dropout, bidirectional)
94 | self.lstm_5 = build_lstm(64, hidden_size, num_layers, dropout, bidirectional)
95 | #)
96 |
97 | ## Features selector models (MLPs)
98 | self.sel1 = Selector(self.code_length, 0)
99 | self.sel2 = Selector(self.code_length, 1)
100 | self.sel3 = Selector(self.code_length, 2)
101 | self.sel4 = Selector(self.code_length, 3)
102 | self.sel5 = Selector(self.code_length, 4)
103 |
104 |
105 | self.deepest_shape = (64, t // 4, h // 32, w // 32)
106 |
107 | # FC network
108 | dc, dt, dh, dw = self.deepest_shape
109 | #self.tdl = nn.Sequential(
110 | self.tdl_1 = TemporallySharedFullyConnection(in_features=(dc * dh * dw), out_features=512)
111 | self.tanh = nn.Tanh()
112 | self.tdl_2 = TemporallySharedFullyConnection(in_features=512, out_features=code_length)
113 | self.sigmoid = nn.Sigmoid()
114 | if load_lstm:
115 | self.lstm_tdl_1 = build_lstm(512, hidden_size, num_layers, dropout, bidirectional)
116 | self.lstm_tdl_2 = build_lstm(code_length, hidden_size, num_layers, dropout, bidirectional)
117 | #)
118 |
119 | def forward(self, x):
120 |         # type: (torch.Tensor) -> torch.Tensor
121 | """
122 | Forward propagation.
123 | :param x: the input batch of patches.
124 | :return: the batch of latent vectors.
125 | """
126 | h = x
127 | #h = self.conv(h)
128 | o1 = self.conv_1(h)
129 | o2 = self.conv_2(o1)
130 | o3 = self.conv_3(o2)
131 | o4 = self.conv_4(o3)
132 | o5 = self.conv_5(o4)
133 |
134 | # Reshape for fully connected sub-network (flatten)
135 | c, t, height, width = self.deepest_shape
136 | h = torch.transpose(o5, 1, 2).contiguous()
137 | h = h.view(-1, t, (c * height * width))
138 | #o = self.tdl(h)
139 | o_tdl_1 = self.tdl_1(h)
140 | o_tdl_1_t = self.tanh(o_tdl_1)
141 | o_tdl_2 = self.tdl_2(o_tdl_1_t)
142 | o_tdl_2_s = self.sigmoid(o_tdl_2)
143 |
144 | if self.load_lstm:
145 |
146 | def shape_lstm_input(o):
147 | # batch, channel, height, width
148 | o = o.permute(0, 2, 1, 3, 4)
149 | kernel_size = (1, o.shape[-2], o.shape[-1])
150 | o = F.avg_pool3d(o, kernel_size).squeeze() if o.ndimension() > 3 else o
151 | # batch, time, channel
152 | return o if o.ndim > 2 else o.unsqueeze(0)
153 | if self.use_selectors:
154 | o1_lstm, _ = self.lstm_1(self.sel1(o1))
155 | o2_lstm, _ = self.lstm_2(self.sel2(o2))
156 | o3_lstm, _ = self.lstm_3(self.sel3(o3))
157 | o4_lstm, _ = self.lstm_4(self.sel4(o4))
158 | o5_lstm, _ = self.lstm_5(self.sel5(o5))
159 | else:
160 | o1_lstm, _ = self.lstm_1(shape_lstm_input(o1))
161 | o2_lstm, _ = self.lstm_2(shape_lstm_input(o2))
162 | o3_lstm, _ = self.lstm_3(shape_lstm_input(o3))
163 | o4_lstm, _ = self.lstm_4(shape_lstm_input(o4))
164 | o5_lstm, _ = self.lstm_5(shape_lstm_input(o5))
165 |
166 | o1_tdl_lstm, _ = self.lstm_tdl_1(o_tdl_1_t)
167 | o2_tdl_lstm, _ = self.lstm_tdl_2(o_tdl_2_s)
168 |
169 | conv_lstms = [o1_lstm[:, -1], o2_lstm[:, -1], o3_lstm[:, -1], o4_lstm[:, -1], o5_lstm[:, -1]]
170 | tdl_lstms = [o1_tdl_lstm[:, -1], o2_tdl_lstm[:, -1]]
171 |
172 | d_lstms = dict(zip([f"conv_lstm_o_{i}" for i in range(len(conv_lstms))], conv_lstms))
173 | d_lstms.update(dict(zip([f"tdl_lstm_o_{i}" for i in range(len(tdl_lstms))], tdl_lstms)))
174 | return o_tdl_2_s, d_lstms
175 |
176 | else:
177 | return o_tdl_2_s
178 |
179 |
180 | class ShanghaiTechDecoder(BaseModule):
181 | """
182 | ShanghaiTech model decoder.
183 | """
184 | def __init__(self, code_length, deepest_shape, output_shape):
185 | # type: (int, Tuple[int, int, int, int], Tuple[int, int, int, int]) -> None
186 | """
187 | Class constructor.
188 | :param code_length: the dimensionality of latent vectors.
189 | :param deepest_shape: the dimensionality of the encoder's deepest convolutional map.
190 |         :param output_shape: the shape of ShanghaiTech input clips.
191 | """
192 | super(ShanghaiTechDecoder, self).__init__()
193 |
194 | self.code_length = code_length
195 | self.deepest_shape = deepest_shape
196 | self.output_shape = output_shape
197 |
198 | dc, dt, dh, dw = deepest_shape
199 |
200 | activation_fn = nn.LeakyReLU()
201 |
202 | # FC network
203 | self.tdl = nn.Sequential(
204 | TemporallySharedFullyConnection(in_features=code_length, out_features=512),
205 | nn.Tanh(),
206 | TemporallySharedFullyConnection(in_features=512, out_features=(dc * dh * dw)),
207 | activation_fn
208 | )
209 |
210 | # Convolutional network
211 | self.conv = nn.Sequential(
212 | UpsampleBlock(channel_in=dc, channel_out=64,
213 | activation_fn=activation_fn, stride=(2, 2, 2), output_padding=(1, 1, 1)),
214 | UpsampleBlock(channel_in=64, channel_out=32,
215 | activation_fn=activation_fn, stride=(1, 2, 2), output_padding=(0, 1, 1)),
216 | UpsampleBlock(channel_in=32, channel_out=16,
217 | activation_fn=activation_fn, stride=(2, 2, 2), output_padding=(1, 1, 1)),
218 | UpsampleBlock(channel_in=16, channel_out=8,
219 | activation_fn=activation_fn, stride=(1, 2, 2), output_padding=(0, 1, 1)),
220 | UpsampleBlock(channel_in=8, channel_out=8,
221 | activation_fn=activation_fn, stride=(1, 2, 2), output_padding=(0, 1, 1)),
222 | nn.Conv3d(in_channels=8, out_channels=output_shape[0], kernel_size=1)
223 | )
224 |
225 | def forward(self, x):
226 |         # type: (torch.Tensor) -> torch.Tensor
227 | """
228 | Forward propagation.
229 | :param x: the batch of latent vectors.
230 | :return: the batch of reconstructions.
231 | """
232 | h = x
233 | h = self.tdl(h)
234 |
235 | # Reshape to encoder's deepest convolutional shape
236 | h = torch.transpose(h, 1, 2).contiguous()
237 | h = h.view(len(h), *self.deepest_shape)
238 |
239 | h = self.conv(h)
240 | o = h
241 |
242 | return o
243 |
244 |
245 | class ShanghaiTech(BaseModule):
246 | """
247 | Model for ShanghaiTech video anomaly detection.
248 | """
249 | def __init__(self, input_shape, code_length, load_lstm=False, hidden_size=100, num_layers=1, dropout=0.0, bidirectional=False, use_selectors=False):
250 | # type: (Tuple[int, int, int, int], int, int) -> None
251 | """
252 | Class constructor.
253 |         :param input_shape: the shape of ShanghaiTech input clips.
254 |         :param code_length: the dimensionality of latent vectors.
255 |         :param load_lstm: whether to attach LSTM heads to the encoder features.
256 | """
257 | super(ShanghaiTech, self).__init__()
258 |
259 | self.input_shape = input_shape
260 | self.code_length = code_length
261 | self.load_lstm = load_lstm
262 | # Build encoder
263 | self.encoder = ShanghaiTechEncoder(
264 | input_shape=input_shape,
265 | code_length=code_length,
266 | load_lstm=load_lstm,
267 | hidden_size=hidden_size,
268 | num_layers=num_layers,
269 | dropout=dropout,
270 | bidirectional=bidirectional,
271 | use_selectors=use_selectors
272 | )
273 |
274 | # Build decoder
275 | self.decoder = ShanghaiTechDecoder(
276 | code_length=code_length,
277 | deepest_shape=self.encoder.deepest_shape,
278 | output_shape=input_shape
279 | )
280 |
281 | def forward(self, x):
282 | # type: (torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
283 | """
284 | Forward propagation.
285 | :param x: the input batch of patches.
286 |         :return: reconstructions and latent vectors and, when load_lstm is True, also the dictionary of LSTM features.
287 | """
288 | h = x
289 |
290 | # Produce representations
291 | if self.load_lstm:
292 | z, d_lstms = self.encoder(h)
293 | else:
294 | z = self.encoder(h)
295 |
296 | # Reconstruct x
297 | x_r = self.decoder(z)
298 | x_r = x_r.view(-1, *self.input_shape)
299 | if self.load_lstm:
300 | return x_r, z, d_lstms
301 | else:
302 | return x_r, z
--------------------------------------------------------------------------------
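A minimal usage sketch for the model above (hypothetical batch and code length; assumes the repository root is on PYTHONPATH and that input_shape matches the dataset shape (3, 16, 256, 512)):

    import torch
    from models.shanghaitech_model import ShanghaiTech

    net = ShanghaiTech(input_shape=(3, 16, 256, 512), code_length=2048)
    clip = torch.randn(1, 3, 16, 256, 512)  # (batch, channels, frames, height, width)
    x_r, z = net(clip)                      # with load_lstm=True a third value, d_lstms, is returned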
/trainers/train_cifar10.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import logging
4 | import numpy as np
5 | from tqdm import tqdm
6 |
7 | import torch
8 | import torch.nn.functional as F
9 | from torch.optim import Adam, SGD
10 | from torch.optim.lr_scheduler import MultiStepLR
11 | from torch.utils.data.dataloader import DataLoader
12 |
13 | from tensorboardX import SummaryWriter
14 |
15 | from sklearn.metrics import roc_auc_score
16 |
17 |
18 | def pretrain(ae_net: torch.nn.Module, train_loader: DataLoader, out_dir: str, tb_writer: SummaryWriter, device: str, ae_learning_rate: float, ae_weight_decay: float, ae_lr_milestones: list, ae_epochs: int) -> str:
19 | """Train the full AutoEncoder network.
20 |
21 | Parameters
22 | ----------
23 | ae_net : torch.nn.Module
24 | AutoEncoder network
25 | train_loader : DataLoader
26 |         Data loader
27 | out_dir : str
28 | Path to checkpoint dir
29 | tb_writer : SummaryWriter
30 | Writer on tensorboard
31 | device : str
32 | Device
33 | ae_learning_rate : float
34 | AutoEncoder learning rate
35 | ae_weight_decay : float
36 | Weight decay
37 | ae_lr_milestones : list
38 | Epochs at which drop the learning rate
39 | ae_epochs: int
40 | Number of training epochs
41 |
42 | Returns
43 | -------
44 | ae_net_cehckpoint : str
45 | Path to model checkpoint
46 |
47 | """
48 | logger = logging.getLogger()
49 |
50 | ae_net = ae_net.train().to(device)
51 |
52 | optimizer = Adam(ae_net.parameters(), lr=ae_learning_rate, weight_decay=ae_weight_decay)
53 | scheduler = MultiStepLR(optimizer, milestones=ae_lr_milestones, gamma=0.1)
54 |
55 | for epoch in range(ae_epochs):
56 | loss_epoch = 0.0
57 | n_batches = 0
58 |
59 | for (data, _, _) in train_loader:
60 | data = data.to(device)
61 |
62 | optimizer.zero_grad()
63 |
64 | outputs = ae_net(data)
65 |
66 | scores = torch.sum((outputs - data) ** 2, dim=tuple(range(1, outputs.dim())))
67 |
68 | loss = torch.mean(scores)
69 | loss.backward()
70 |
71 | optimizer.step()
72 |
73 | loss_epoch += loss.item()
74 | n_batches += 1
75 |
76 | scheduler.step()
77 | if epoch in ae_lr_milestones:
78 | logger.info(' LR scheduler: new learning rate is %g' % float(scheduler.get_lr()[0]))
79 |
80 | logger.info(f"PreTrain at epoch: {epoch+1} ==> Recon Loss: {loss_epoch/len(train_loader):.4f}")
81 | tb_writer.add_scalar('pretrain/recon_loss', loss_epoch/len(train_loader), epoch+1)
82 |
83 | logger.info('Finished pretraining.')
84 |
85 | ae_net_cehckpoint = os.path.join(out_dir, f'ae_ckp_{time.time()}.pth')
86 | torch.save({'ae_state_dict': ae_net.state_dict()}, ae_net_cehckpoint)
87 | logger.info(f'Saved autoencoder at: {ae_net_cehckpoint}')
88 |
89 | return ae_net_cehckpoint
90 |
91 |
92 | def train(net: torch.nn.Module, train_loader: DataLoader, out_dir: str, tb_writer: SummaryWriter, device: str, ae_net_cehckpoint: str, idx_list_enc: list, learning_rate: float, weight_decay: float, lr_milestones: list, epochs: int, nu: float, boundary: str, debug: bool) -> str:
93 | """Train the Encoder network on the one class task.
94 |
95 | Parameters
96 | ----------
97 | net : torch.nn.Module
98 | Encoder network
99 | train_loader : DataLoader
100 |         Data loader
101 | out_dir : str
102 | Path to checkpoint dir
103 | tb_writer : SummaryWriter
104 | Writer on tensorboard
105 | device : str
106 | Device
107 | ae_net_cehckpoint : str
108 | Path to autoencoder checkpoint
109 | idx_list_enc : list
110 | List of indexes of layers from which extract features
111 | learning_rate : float
112 | AutoEncoder learning rate
113 | weight_decay : float
114 | Weight decay
115 | lr_milestones : list
116 | Epochs at which drop the learning rate
117 | epochs: int
118 | Number of training epochs
119 | nu : float
120 | Value of the trade-off parameter
121 | boundary : str
122 | Type of boundary
123 | debug: bool
124 | If True, enable debug mode
125 |
126 | Returns
127 | -------
128 | net_cehckpoint : str
129 | Path to model checkpoint
130 |
131 | """
132 | logger = logging.getLogger()
133 |
134 | net.train().to(device)
135 |
136 | # Hook model's layers
137 | feat_d = {}
138 | hooks = hook_model(idx_list_enc=idx_list_enc, model=net, dataset_name="cifar10", feat_d=feat_d)
139 |
140 | optimizer = Adam(net.parameters(), lr=learning_rate, weight_decay=weight_decay)
141 | scheduler = MultiStepLR(optimizer, milestones=lr_milestones, gamma=0.1)
142 |
143 | # Initialize hypersphere center c
144 | logger.info('Initializing center c...')
145 | c = init_center_c(feat_d=feat_d, train_loader=train_loader, net=net, device=device)
146 | logger.info('Center c initialized.')
147 |
148 | R = {k: torch.tensor(0.0, device=device) for k in c.keys()}
149 |
150 | logger.info('Start training...')
151 | warm_up_n_epochs = 10
152 |
153 | for epoch in range(epochs):
154 | loss_epoch = 0.0
155 | n_batches = 0
156 | d_from_c = {}
157 |
158 | for (data, _, _) in train_loader:
159 | data = data.to(device)
160 |
161 | # Update network parameters via backpropagation: forward + backward + optimize
162 | _ = net(data)
163 |
164 | dist, loss = eval_ad_loss(feat_d=feat_d, c=c, R=R, nu=nu, boundary=boundary)
165 |
166 | for k in dist.keys():
167 | if k not in d_from_c:
168 | d_from_c[k] = 0
169 | d_from_c[k] += torch.mean(dist[k]).item()
170 |
171 | optimizer.zero_grad()
172 | loss.backward()
173 | optimizer.step()
174 |
175 | # Update hypersphere radius R on mini-batch distances
176 | # only after the warm up epochs
177 | if (boundary == 'soft') and (epoch >= warm_up_n_epochs):
178 | for k in R.keys():
179 | R[k].data = torch.tensor(
180 | np.quantile(np.sqrt(dist[k].clone().data.cpu().numpy()), 1 - nu),
181 | device=device
182 | )
183 |
184 | loss_epoch += loss.item()
185 | n_batches += 1
186 |
187 | scheduler.step()
188 | if epoch in lr_milestones:
189 | logger.info(' LR scheduler: new learning rate is %g' % float(scheduler.get_lr()[0]))
190 |
191 | # log epoch statistics
192 | logger.info(f"TRAIN at epoch: {epoch} ==> Objective Loss: {loss_epoch/n_batches:.4f}")
193 | tb_writer.add_scalar('train/objective_loss', loss_epoch/n_batches, epoch)
194 | for en, k in enumerate(d_from_c.keys()):
195 | logger.info(
196 | f"[{k}] -- Radius: {R[k]:.4f} - "
197 |                 f"Dist from sphere center: {d_from_c[k]/n_batches:.4f}"
198 | )
199 | tb_writer.add_scalar(f'train/radius_{idx_list_enc[en]}', R[k], epoch)
200 | tb_writer.add_scalar(f'train/distance_c_sphere_{idx_list_enc[en]}', d_from_c[k]/n_batches, epoch)
201 |
202 | logger.info('Finished training!!')
203 |
204 | [h.remove() for h in hooks]
205 |
206 | time_ = ae_net_cehckpoint.split('_')[-1].split('.p')[0]
207 | net_cehckpoint = os.path.join(out_dir, f'net_ckp_{time_}.pth')
208 | if debug:
209 | net_cehckpoint = './test_net_ckp.pth'
210 | torch.save({
211 | 'net_state_dict': net.state_dict(),
212 | 'R': R,
213 | 'c': c
214 | },
215 | net_cehckpoint
216 | )
217 | logger.info(f'Saved model at: {net_cehckpoint}')
218 |
219 | return net_cehckpoint
220 |
221 |
222 | def test(net: torch.nn.Module, test_loader: DataLoader, R: dict, c: dict, device: str, idx_list_enc: list, boundary: str) -> float:
223 | """Test the Encoder network.
224 |
225 | Parameters
226 | ----------
227 | net : torch.nn.Module
228 | Encoder network
229 | test_loader : DataLoader
230 |         Data loader
231 | R : dict
232 | Dictionary containing the values of the radiuses for each layer
233 | c : dict
234 | Dictionary containing the values of the hyperspheres' center for each layer
235 | device : str
236 | Device
237 | idx_list_enc : list
238 | List of indexes of layers from which extract features
239 | boundary : str
240 | Type of boundary
241 |         Either 'soft' or 'hard'
242 | 
243 |
244 | Returns
245 | -------
246 | test_auc : float
247 | AUC
248 |
249 | """
250 | logger = logging.getLogger()
251 |
252 | # Hook model's layers
253 | feat_d = {}
254 | hooks = hook_model(idx_list_enc=idx_list_enc, model=net, dataset_name="cifar10", feat_d=feat_d)
255 |
256 | # Testing
257 |     logger.info('Start testing...')
258 | idx_label_score = []
259 | net.eval().to(device)
260 | with torch.no_grad():
261 | for data in test_loader:
262 | inputs, labels, idx = data
263 | inputs = inputs.to(device)
264 |
265 | _ = net(inputs)
266 |
267 | scores = get_scores(feat_d=feat_d, c=c, R=R, device=device, boundary=boundary)
268 |
269 | # Save triples of (idx, label, score) in a list
270 | idx_label_score += list(zip(idx.cpu().data.numpy().tolist(),
271 | labels.cpu().data.numpy().tolist(),
272 | scores.cpu().data.numpy().tolist()))
273 |
274 | [h.remove() for h in hooks]
275 |
276 | # Compute AUC
277 | _, labels, scores = zip(*idx_label_score)
278 | labels = np.array(labels)
279 | scores = np.array(scores)
280 | test_auc = roc_auc_score(labels, scores)
281 | logger.info('Test set AUC: {:.2f}%'.format(100. * test_auc))
282 |
283 | logger.info('Finished testing!!')
284 |
285 | return 100. * test_auc
286 |
287 |
288 | def hook_model(idx_list_enc: list, model: torch.nn.Module, dataset_name: str, feat_d: dict) -> list:
289 | """Create hooks for model's layers.
290 |
291 | Parameters
292 | ----------
293 | idx_list_enc : list
294 | List of indexes of layers from which extract features
295 | model : torch.nn.Module
296 | Encoder network
297 | dataset_name : str
298 | Name of the dataset
299 | feat_d : dict
300 | Dictionary containing features
301 |
302 | Returns
303 | -------
304 |     hooks : list
305 |         Registered forward hooks
306 | """
307 | if dataset_name == 'mnist':
308 |
309 | blocks_ = [model.conv1, model.conv2, model.fc1]
310 | else:
311 |
312 | blocks_ = [model.conv1, model.conv2, model.conv3, model.fc1]
313 |
314 |     use_subset = isinstance(idx_list_enc, list) and len(idx_list_enc) != 0
315 |     assert (not use_subset) or len(idx_list_enc) <= len(blocks_), f"Too many indices for decoder: {idx_list_enc} - for {len(blocks_)} blocks"
316 |     blocks = [blocks_[idx] for idx in idx_list_enc] if use_subset else blocks_
317 |
318 | blocks_idx = dict(zip(blocks, map('{:02d}'.format, range(len(blocks)))))
319 |
320 | def hook_func(module, input, output):
321 | block_num = blocks_idx[module]
322 | extracted = output
323 | if extracted.ndimension() > 2:
324 | extracted = F.avg_pool2d(extracted, extracted.shape[-2:])
325 | feat_d[block_num] = extracted.squeeze()
326 |
327 | return [b.register_forward_hook(hook_func) for b in blocks_idx]
328 |
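# Usage sketch for the hooks above (illustrative): once registered, every forward
# pass silently fills `feat_d` with one pooled feature vector per hooked block.
#
#     feat_d = {}
#     hooks = hook_model(idx_list_enc=[0, 1], model=net, dataset_name="cifar10", feat_d=feat_d)
#     _ = net(batch)               # feat_d now maps '00', '01' -> pooled activations
#     [h.remove() for h in hooks]  # detach the hooks when done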
329 |
330 | @torch.no_grad()
331 | def init_center_c(feat_d: dict, train_loader: DataLoader, net: torch.nn.Module, device: str, eps: float = 0.1) -> dict:
332 | """Initialize hyperspheres' center c as the mean from an initial forward pass on the data.
333 |
334 | Parameters
335 | ----------
336 | feat_d : dict
337 | Dictionary containing features
338 | train_loader : DataLoader
339 | Training data loader
340 | net : torch.nn.Module
341 | Encoder network
342 | device : str
343 | Device
344 | eps : float = 0.1
345 | If a center is too close to 0, set to +-eps
346 | Returns
347 | -------
348 | c : dict
349 | hyperspheres' center
350 |
351 | """
352 | n_samples = 0
353 |
354 | net.eval()
355 |
356 |     for idx, (data, _, _) in enumerate(tqdm(train_loader, desc='init hypersphere centers', total=len(train_loader), leave=False)):
357 | data = data.to(device)
358 |
359 | outputs = net(data)
360 | n_samples += outputs.shape[0]
361 |
362 | if idx == 0:
363 | c = {k: torch.zeros_like(feat_d[k][-1], device=device) for k in feat_d.keys()}
364 |
365 | for k in feat_d.keys():
366 | c[k] += torch.sum(feat_d[k], dim=0)
367 |
368 | for k in c.keys():
369 | c[k] = c[k] / n_samples
370 | # If c_i is too close to 0, set to +-eps. Reason: a zero unit can be trivially matched with zero weights.
371 | c[k][(abs(c[k]) < eps) & (c[k] < 0)] = -eps
372 | c[k][(abs(c[k]) < eps) & (c[k] > 0)] = eps
373 |
374 | return c
375 |
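# Toy illustration of the eps-clamping above (made-up values, eps = 0.1):
#     c_k = tensor([ 0.05, -0.02, 0.30])  ->  tensor([ 0.10, -0.10, 0.30])
# Entries equal to exactly zero are left untouched, since neither mask matches them.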
376 |
377 | def eval_ad_loss(feat_d: dict, c: dict, R: dict, nu: float, boundary: str) -> [dict, torch.Tensor]:
378 | """Eval training loss.
379 |
380 | Parameters
381 | ----------
382 | feat_d : dict
383 | Dictionary containing features
384 | c : dict
385 | Dictionary hyperspheres' center
386 | R : dict
387 | Dictionary hyperspheres' radius
388 | nu : float
389 | Value of the trade-off parameter
390 | boundary : str
391 | Type of boundary
392 |
393 | Returns
394 | -------
395 | dist : dict
396 | Dictionary containing the average distance of the features vectors from the hypersphere's center for each layer
397 | loss : torch.Tensor
398 | Loss value
399 | """
400 | dist = {}
401 | loss = 1
402 |
403 | for k in feat_d.keys():
404 |
405 | dist[k] = torch.sum((feat_d[k] - c[k].unsqueeze(0)) ** 2, dim=1)
406 | if boundary == 'soft':
407 |
408 | scores = dist[k] - R[k] ** 2
409 | loss += R[k] ** 2 + (1 / nu) * torch.mean(torch.max(torch.zeros_like(scores), scores))
410 | else:
411 |
412 | loss += torch.mean(dist[k])
413 |
414 | return dist, loss
415 |
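# The loss above starts from a constant 1 (which leaves gradients unaffected) and,
# with the soft boundary, accumulates the Deep SVDD objective per hooked layer k:
#     loss_k = R_k^2 + (1 / nu) * mean(max(0, ||f(x) - c_k||^2 - R_k^2))
# while the hard boundary simply minimizes the mean squared distance from c_k.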
416 |
417 | def get_scores(feat_d: dict, c: dict, R: dict, device: str, boundary: str) -> torch.Tensor:
418 | """Eval anomaly score.
419 |
420 | Parameters
421 | ----------
422 | feat_d : dict
423 | Dictionary containing features
424 | c : dict
425 | Dictionary hyperspheres' center
426 | R : dict
427 | Dictionary hyperspheres' radius
428 | device : str
429 | Device
430 | boundary : str
431 | Type of boundary
432 |
433 | Returns
434 | -------
435 |     scores : torch.Tensor
436 |         Anomaly score for each sample
437 |
438 | """
439 | dist, _ = eval_ad_loss(feat_d, c, R, 1, boundary)
440 | shape = dist[list(dist.keys())[0]].shape[0]
441 | scores = torch.zeros((shape,), device=device)
442 |
443 | for k in dist.keys():
444 | if boundary == 'soft':
445 |
446 | scores += dist[k] - R[k] ** 2
447 | else:
448 |
449 | scores += dist[k]
450 |
451 | return scores/len(list(dist.keys()))
452 |
--------------------------------------------------------------------------------
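A stand-alone sketch of the soft-boundary radius update used in train() above (toy distances; nu = 0.1 is made up):

    import numpy as np
    import torch

    nu = 0.1
    dist = torch.rand(128) * 4.0  # squared distances of one mini-batch from the center
    # R becomes the (1 - nu)-quantile of the distances, so roughly a fraction nu
    # of the batch ends up outside the hypersphere.
    R = torch.tensor(np.quantile(np.sqrt(dist.numpy()), 1 - nu))
    print(f"updated radius: {R.item():.3f}")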
/trainers/trainer_mvtec.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 | import logging
5 | import numpy as np
6 | from tqdm import tqdm
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | from torch.optim import Adam, SGD
12 | from torch.utils.data import DataLoader
13 | from torch.optim.lr_scheduler import MultiStepLR
14 |
15 | from tensorboardX import SummaryWriter
16 | from sklearn.metrics import roc_curve, roc_auc_score, auc
17 |
18 |
19 | def pretrain(ae_net: nn.Module, train_loader: DataLoader, out_dir: str, tb_writer: SummaryWriter, device: str, ae_learning_rate: float, ae_weight_decay: float, ae_lr_milestones: list, ae_epochs: int, log_frequency: int, batch_accumulation: int, debug: bool) -> str:
20 | """Train the full AutoEncoder network.
21 |
22 | Parameters
23 | ----------
24 | ae_net : nn.Module
25 | AutoEncoder network
26 | train_loader : DataLoader
27 | Data loader
28 | out_dir : str
29 | Path to checkpoint dir
30 | tb_writer : SummaryWriter
31 | Writer on tensorboard
32 | device : str
33 | Device
34 | ae_learning_rate : float
35 | AutoEncoder learning rate
36 | ae_weight_decay : float
37 | Weight decay
38 | ae_lr_milestones : list
39 | Epochs at which drop the learning rate
40 | ae_epochs: int
41 | Number of training epochs
42 |     log_frequency : int
43 |         Number of iterations after which logs are shown
44 |     batch_accumulation : int
45 |         Number of iterations over which gradients are accumulated (-1 to disable)
46 | debug : bool
47 |         Only use the first 5 batches
48 |
49 | Returns
50 | -------
51 | ae_net_cehckpoint : str
52 | Path to model checkpoint
53 |
54 | """
55 | logger = logging.getLogger()
56 |
57 | ae_net = ae_net.train().to(device)
58 |
59 | optimizer = Adam(ae_net.parameters(), lr=ae_learning_rate, weight_decay=ae_weight_decay)
60 | scheduler = MultiStepLR(optimizer, milestones=ae_lr_milestones, gamma=0.1)
61 |
62 | # Independent index to save stats on tensorboard
63 | kk = 1
64 |
65 | # Counter for the batch accumulation steps
66 | j_ba_steps = 0
67 |
68 | for epoch in range(ae_epochs):
69 | loss_epoch = 0.0
70 | n_batches = 0
71 | optimizer.zero_grad()
72 |
73 | for idx, (data, _) in enumerate(tqdm(train_loader, total=len(train_loader), leave=False)):
74 | if debug and idx == 5: break
75 |
76 | data = data.to(device)
77 |
78 | x_r = ae_net(data)
79 |
80 | scores = torch.sum((x_r - data) ** 2, dim=tuple(range(1, x_r.dim())))
81 |
82 | loss = torch.mean(scores)
83 | loss.backward()
84 |
85 | j_ba_steps += 1
86 | if batch_accumulation != -1:
87 | if j_ba_steps % batch_accumulation == 0:
88 | optimizer.step()
89 | optimizer.zero_grad()
90 | j_ba_steps = 0
91 | else:
92 | optimizer.step()
93 | optimizer.zero_grad()
94 |
95 | # Sanity check
96 | if np.isnan(loss.item()):
97 |                 logger.info("Found nan values in loss")
98 | sys.exit(0)
99 |
100 | loss_epoch += loss.item()
101 | n_batches += 1
102 |
103 | if idx != 0 and idx % ((len(train_loader)//log_frequency)+1) == 0:
104 | logger.info(f"PreTrain at epoch: {epoch+1} ([{idx}]/[{len(train_loader)}]) ==> Recon Loss: {loss_epoch/idx:.4f}")
105 | tb_writer.add_scalar('pretrain/recon_loss', loss_epoch/idx, kk)
106 | kk += 1
107 |
108 | scheduler.step()
109 | if epoch in ae_lr_milestones:
110 | logger.info(' LR scheduler: new learning rate is %g' % float(scheduler.get_lr()[0]))
111 |
112 | ae_net_cehckpoint = os.path.join(out_dir, f'best_ae_ckp.pth')
113 | torch.save({'ae_state_dict': ae_net.state_dict()}, ae_net_cehckpoint)
114 | logger.info(f'Saved best autoencoder so far at: {ae_net_cehckpoint}')
115 |
116 | logger.info('Finished pretraining.')
117 |
118 | return ae_net_cehckpoint
119 |
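# Note on the accumulation logic above: with batch_accumulation = N, gradients from
# N consecutive backward() calls are summed before optimizer.step() runs, so the
# effective batch size is N * batch_size; with batch_accumulation = -1 every batch
# triggers its own optimizer step, as in a plain training loop.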
120 |
121 | def train(net: torch.nn.Module, train_loader: DataLoader, centers: dict, out_dir: str, tb_writer: SummaryWriter, device: str, learning_rate: float, weight_decay: float, lr_milestones: list, epochs: int, nu: float, boundary: str, batch_accumulation: int, warm_up_n_epochs: int, log_frequency: int, debug: bool) -> str:
122 | """Train the Encoder network on the one class task.
123 |
124 | Parameters
125 | ----------
126 | net : nn.Module
127 | Encoder network
128 | train_loader : DataLoader
129 |         Data loader
130 | centers : dict
131 | Dictionary containing hyperspheres' center at each layer
132 | out_dir : str
133 | Path to checkpoint dir
134 | tb_writer : SummaryWriter
135 | Writer on tensorboard
136 | device : str
137 | Device
138 | learning_rate : float
139 | AutoEncoder learning rate
140 | weight_decay : float
141 | Weight decay
142 | lr_milestones : list
143 | Epochs at which drop the learning rate
144 | epochs: int
145 | Number of training epochs
146 | nu : float
147 | Value of the trade-off parameter
148 | boundary : str
149 | Type of boundary
150 |     batch_accumulation : int
151 |         Number of iterations over which gradients are accumulated (-1 to disable)
152 |     warm_up_n_epochs : int
153 |         Number of warm-up epochs before updating the hyperspheres' radius R
154 |     log_frequency: int
155 |         Number of iterations after which logs are shown
156 | debug : bool
157 |         Only use the first 5 batches
158 |
159 | Returns
160 | -------
161 | net_cehckpoint : str
162 | Path to model checkpoint
163 |
164 | """
165 | logger = logging.getLogger()
166 |
167 | optimizer = Adam(net.parameters(), lr=learning_rate, weight_decay=weight_decay)
168 | scheduler = MultiStepLR(optimizer, milestones=lr_milestones, gamma=0.1)
169 |
170 | # Init spheres' radius
171 | R = {k: torch.tensor(0.0, device=device) for k in centers.keys()}
172 |
173 | # Training
174 | logger.info('Start training...')
175 | kk = 1
176 | net.train().to(device)
177 | best_loss = 1.e6
178 |
179 | for epoch in range(epochs):
180 | j = 0
181 | loss_epoch = 0.0
182 | n_batches = 0
183 | d_from_c = {}
184 | optimizer.zero_grad()
185 |
186 | for idx, (data, _) in enumerate(tqdm(train_loader, total=len(train_loader), leave=False)):
187 | if debug and idx == 5: break
188 |
189 | data = data.to(device)
190 |
191 | zipped = net(data)
192 |
193 | dist, loss = eval_ad_loss(zipped=zipped, c=centers, R=R, nu=nu, boundary=boundary)
194 |
195 | for k in dist.keys():
196 | if k not in d_from_c:
197 | d_from_c[k] = 0
198 | d_from_c[k] += torch.mean(dist[k]).item()
199 |
200 | loss.backward()
201 | j += 1
202 | if batch_accumulation != -1:
203 | if j == batch_accumulation:
204 | j = 0
205 | optimizer.step()
206 | optimizer.zero_grad()
207 | else:
208 | optimizer.step()
209 | optimizer.zero_grad()
210 |
211 | # Update hypersphere radius R on mini-batch distances
212 | if (boundary == 'soft') and (epoch >= warm_up_n_epochs):
213 | # R.data = torch.tensor(get_radius(dist, nu), device=device)
214 | for k in R.keys():
215 | R[k].data = torch.tensor(
216 | np.quantile(np.sqrt(dist[k].clone().data.cpu().numpy()), 1 - nu),
217 | device=device
218 | )
219 |
220 | loss_epoch += loss.item()
221 | n_batches += 1
222 |
223 | if np.isnan(loss.item()):
224 |                 logger.info("Found nan values in loss")
225 | sys.exit(0)
226 |
227 | if idx != 0 and idx % ((len(train_loader)//log_frequency)+1) == 0:
228 | # log epoch statistics
229 | logger.info(f"TRAIN at epoch: {epoch+1} ([{idx}]/[{len(train_loader)}]) ==> Objective Loss: {loss_epoch/idx:.4f}")
230 | tb_writer.add_scalar('train/objective_loss', loss_epoch/idx, kk)
231 | for _, k in enumerate(d_from_c.keys()):
232 | logger.info(
233 | f"[{k}] -- Radius: {R[k]:.4f} - "
234 |                         f"Dist from sphere center: {d_from_c[k]/idx:.4f}"
235 | )
236 | tb_writer.add_scalar(f'train/radius_{k}', R[k], kk)
237 | tb_writer.add_scalar(f'train/distance_c_sphere_{k}', d_from_c[k]/idx, kk)
238 | kk += 1
239 |
240 | scheduler.step()
241 | if epoch in lr_milestones:
242 | logger.info(' LR scheduler: new learning rate is %g' % float(scheduler.get_lr()[0]))
243 |
244 | if (loss_epoch/len(train_loader)) <= best_loss:
245 | net_cehckpoint = os.path.join(out_dir, f'best_oc_model.pth')
246 | best_loss = (loss_epoch/len(train_loader))
247 | torch.save({
248 | 'net_state_dict': net.state_dict(),
249 | 'R': R,
250 | 'c': centers
251 | },
252 | net_cehckpoint
253 | )
254 | logger.info(f'Saved best model so far at: {net_cehckpoint}')
255 |
256 | logger.info('Finished training!!')
257 |
258 | return net_cehckpoint
259 |
260 |
261 | def test(normal_class: str, is_texture: bool, net: nn.Module, test_loader: DataLoader, R: dict, c: dict, device: str, boundary: str, debug: bool) -> [float, float]:
262 | """Test the Encoder network.
263 |
264 | Parameters
265 | ----------
266 | normal_class : str
267 | Name of the class under test
268 | is_texture : bool
269 | True if the input data belong to a texture-type class
270 | net : nn.Module
271 | Encoder network
272 | test_loader : DataLoader
273 |         Data loader
274 | R : dict
275 | Dictionary containing the values of the radiuses for each layer
276 | c : dict
277 | Dictionary containing the values of the hyperspheres' center for each layer
278 | device : str
279 | Device
280 | boundary : str
281 | Type of boundary
282 | debug : bool
283 |         Only use the first 5 batches
284 |
285 | Returns
286 | -------
287 | test_auc : float
288 | AUC
289 | balanced_accuracy : float
290 | Maximum Balanced Accuracy
291 |
292 | """
293 | logger = logging.getLogger()
294 |
295 | # Testing
296 | logger.info('Start testing...')
297 |
298 | idx_label_score = []
299 |
300 | net.eval().to(device)
301 |
302 | with torch.no_grad():
303 | for idx, (data, labels) in enumerate(tqdm(test_loader, total=len(test_loader), desc=f"Testing class: {normal_class}", leave=False)):
304 | if debug and idx == 5: break
305 |
306 | data = data.to(device)
307 |
308 | if is_texture:
309 |                 ## Split each texture image into 64x64 patches ==> the anomaly score is max{score(patches)}
310 | _, _, h, w = data.shape
311 | assert h == w, "Height and Width are different!!!"
312 | patch_size = 64
313 |
314 | patches = [
315 | data[:, :, h_:h_+patch_size, w_:w_+patch_size]
316 | for h_ in range(0, h, patch_size)
317 | for w_ in range(0, w, patch_size)
318 | ]
319 |
320 | patches = torch.stack(patches, dim=1) # shape = (b_size, nb_patches, ch, h, w)
321 |
322 | scores = torch.stack([
323 | get_scores(
324 | zipped=net(batch),
325 | c=c, R=R,
326 | device=device,
327 | boundary=boundary,
328 | is_texture=is_texture)
329 | for batch in patches
330 | ]) # batch.shape = (nb_patches, ch, h, w)
331 |
332 | else:
333 | scores = get_scores(zipped=net(data), c=c, R=R, device=device, boundary=boundary, is_texture=is_texture)
334 |
335 | idx_label_score += list(
336 | zip(
337 | labels.cpu().data.numpy().tolist(),
338 | scores.cpu().data.numpy().tolist()
339 | )
340 | )
341 |
342 | # Compute AUC
343 | labels, scores = zip(*idx_label_score)
344 | labels = np.array(labels)
345 | scores = np.array(scores)
346 |
347 | #test_auc = roc_auc_score(labels, scores)
348 | fpr, tpr, _ = roc_curve(labels, scores)
349 | balanced_accuracy = np.max((tpr + (1 - fpr)) / 2)
350 | auroc = auc(fpr, tpr)
351 | logger.info(f'Test set results ===> AUC: {auroc:.4f} --- maxB: {balanced_accuracy:.4f}')
352 | logger.info('Finished testing!!')
353 |
354 | return auroc, balanced_accuracy
355 |
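# Shape sketch for the texture branch above (sizes are illustrative): a 256x256
# image with patch_size = 64 yields a 4x4 grid of patches, so `patches` has shape
# (batch, 16, ch, 64, 64); each group of patches is scored separately and
# get_scores(..., is_texture=True) keeps the maximum over them.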
356 |
357 | def eval_ad_loss(zipped: list, c: dict, R: dict, nu: float, boundary: str) -> [dict, torch.Tensor]:
358 |     """Evaluate encoder loss in the one class setting.
359 |
360 | Parameters
361 | ----------
362 |     zipped : list
363 |         Sequence of (layer key, features) tuples output by the network
364 | c : dict
365 | Dictionary of layers centroids
366 | R : dict
367 | Dictionary of layers radiuses
368 | nu : float
369 | Trade-off parameters
370 | boundary: str
371 | Type of boundary
372 |
373 | Returns
374 | -------
375 | dist : dict
376 | Dictionary containing the sum of the distances of the features from the hyperspheres center at each layer
377 | loss : torch.Tensor
378 |         Training loss
379 |
380 | """
381 | dist = {}
382 |
383 | loss = 1
384 |
385 | for (k, v) in zipped:
386 | dist[k] = torch.sum((v - c[k].unsqueeze(0)) ** 2, dim=1)
387 |
388 | if boundary == 'soft':
389 | scores = dist[k] - R[k] ** 2
390 | loss += R[k] ** 2 + (1 / nu) * torch.mean(torch.max(torch.zeros_like(scores), scores))
391 |
392 | else:
393 | loss += torch.mean(dist[k])
394 |
395 | return dist, loss
396 |
397 |
398 | def get_scores(zipped: list, c: dict, R: dict, device: str, boundary: str, is_texture: bool) -> torch.Tensor:
399 | """Evaluate anomaly score.
400 |
401 | Parameters
402 | ----------
403 |     zipped : list
404 |         Sequence of (layer key, features) tuples output by the network
405 | c : dict
406 | Dictionary of layers centroids
407 | R : dict
408 | Dictionary of layers radiuses
409 | device : str
410 | Device on which run the computation
411 | boundary: str
412 | Type of boundary
413 | is_texture : bool
414 | True if images belong to texture-type classes
415 |
416 | Returns
417 | -------
418 |     scores : torch.Tensor
419 |         Anomaly score for each image (for textures, the max over the image's patches)
420 |
421 | """
422 | dist = {item[0]: torch.norm(item[1] - c[item[0]].unsqueeze(0), dim=1) for item in zipped}
423 |
424 | shape = dist[list(dist.keys())[0]].shape[0]
425 | scores = torch.zeros((shape,), device=device)
426 |
427 | for k in dist.keys():
428 |
429 | if boundary == 'soft':
430 | scores += dist[k] - R[k] # R[k] is a number not a vector
431 |
432 | else:
433 | scores += dist[k]
434 |
435 | return scores.max()/len(list(dist.keys())) if is_texture else scores/len(list(dist.keys()))
436 |
--------------------------------------------------------------------------------
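A stand-alone sketch of the metrics computed in test() above (toy labels and scores):

    import numpy as np
    from sklearn.metrics import roc_curve, auc

    labels = np.array([0, 0, 1, 1, 1])
    scores = np.array([0.10, 0.40, 0.35, 0.80, 0.90])

    fpr, tpr, _ = roc_curve(labels, scores)
    balanced_accuracy = np.max((tpr + (1 - fpr)) / 2)  # max over thresholds of (TPR + TNR) / 2
    print(f"AUC: {auc(fpr, tpr):.4f} --- maxB: {balanced_accuracy:.4f}")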
/main_shanghaitech.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import random
4 | import logging
5 | import argparse
6 | import numpy as np
7 | import matplotlib.pyplot as plt
8 | from torchvision.utils import make_grid
9 |
10 | import torch
11 | import torch.nn as nn
12 |
13 |
14 | from tensorboardX import SummaryWriter
15 |
16 | from models.shanghaitech_model import ShanghaiTech, ShanghaiTechEncoder, ShanghaiTechDecoder
17 |
18 | from datasets.data_manager import DataManager
19 | from datasets.shanghaitech_test import VideoAnomalyDetectionResultHelper
20 |
21 | from trainers.trainer_shanghaitech import pretrain, train
22 | from utils import set_seeds, get_out_dir, eval_spheres_centers, load_mvtec_model_from_checkpoint, extract_arguments_from_checkpoint
23 |
24 |
25 | def main(args):
26 | # Set seed
27 | set_seeds(args.seed)
28 |
29 | # Get the device
30 | device = "cuda" if torch.cuda.is_available() else "cpu"
31 |
32 | if args.disable_logging:
33 | logging.disable(level=logging.INFO)
34 |
35 |
36 | ## Init logger & print training/warm-up summary
37 | logging.basicConfig(
38 | level=logging.INFO,
39 | format="%(asctime)s | %(message)s",
40 | handlers=[
41 | logging.FileHandler('./training.log'),
42 | logging.StreamHandler()
43 | ])
44 | logger = logging.getLogger()
45 |
46 |
47 | if args.train or args.pretrain or args.end_to_end_training:
48 | # If the list of layers from which extract the features is empty, then use the last one (after the sigmoid)
49 | if len(args.idx_list_enc) == 0: args.idx_list_enc = [6]
50 |
51 | logger.info(
52 | "Start run with params:\n"
53 | f"\n\t\t\t\tEnd to end training : {args.end_to_end_training}"
54 | f"\n\t\t\t\tPretrain model : {args.pretrain}"
55 | f"\n\t\t\t\tTrain model : {args.train}"
56 | f"\n\t\t\t\tTest model : {args.test}"
57 | f"\n\t\t\t\tBatch size : {args.batch_size}\n"
58 | f"\n\t\t\t\tAutoEncoder Pretraining"
59 | f"\n\t\t\t\tPretrain epochs : {args.ae_epochs}"
60 | f"\n\t\t\t\tAE-Learning rate : {args.ae_learning_rate}"
61 | f"\n\t\t\t\tAE-milestones : {args.ae_lr_milestones}"
62 | f"\n\t\t\t\tAE-Weight decay : {args.ae_weight_decay}\n"
63 | f"\n\t\t\t\tEncoder Training"
64 | f"\n\t\t\t\tClip length : {args.clip_length}"
65 | f"\n\t\t\t\tBoundary : {args.boundary}"
66 | f"\n\t\t\t\tTrain epochs : {args.epochs}"
67 | f"\n\t\t\t\tLearning rate : {args.learning_rate}"
68 | f"\n\t\t\t\tMilestones : {args.lr_milestones}"
69 | f"\n\t\t\t\tUse selectors : {args.use_selectors}"
70 | f"\n\t\t\t\tWeight decay : {args.weight_decay}"
71 | f"\n\t\t\t\tCode length : {args.code_length}"
72 | f"\n\t\t\t\tNu : {args.nu}"
73 | f"\n\t\t\t\tEncoder list : {args.idx_list_enc}\n"
74 | f"\n\t\t\t\tLSTMs"
75 | f"\n\t\t\t\tLoad LSTMs : {args.load_lstm}"
76 | f"\n\t\t\t\tBidirectional : {args.bidirectional}"
77 | f"\n\t\t\t\tHidden size : {args.hidden_size}"
78 | f"\n\t\t\t\tNumber of layers : {args.num_layers}"
79 | f"\n\t\t\t\tDropout prob : {args.dropout}\n"
80 | )
81 | else:
82 | if args.model_ckp is None:
83 | logger.info("CANNOT TEST MODEL WITHOUT A VALID CHECKPOINT")
84 | sys.exit(0)
85 |
86 |
87 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
88 |
89 | # Init DataHolder class
90 | data_holder = DataManager(
91 | dataset_name='ShanghaiTech',
92 | data_path=args.data_path,
93 | normal_class=None,
94 | only_test=args.test
95 | ).get_data_holder()
96 |
97 | # Load data
98 | train_loader, _ = data_holder.get_loaders(
99 | batch_size=args.batch_size,
100 | shuffle_train=True,
101 | pin_memory=device=="cuda",
102 | num_workers=args.n_workers
103 | )
104 | # Print data infos
105 | only_test = args.test and not args.train and not args.pretrain
106 | logger.info("Dataset info:")
107 | logger.info(
108 | "\n"
109 | f"\n\t\t\t\tBatch size : {args.batch_size}"
110 | )
111 | if not only_test:
112 | logger.info(
113 | f"TRAIN:"
114 | f"\n\t\t\t\tNumber of clips : {len(train_loader.dataset)}"
115 | f"\n\t\t\t\tNumber of batches : {len(train_loader.dataset)//args.batch_size}"
116 | )
117 |
118 |
119 | ########################################################################################
120 | ####### Train the AUTOENCODER on the RECONSTRUCTION task and then train only the #######
121 | ########################## ENCODER on the ONE CLASS OBJECTIVE ##########################
122 | ########################################################################################
123 | ae_net_checkpoint = None
124 | if args.pretrain and not args.end_to_end_training:
125 | out_dir, tmp = get_out_dir(args, pretrain=True, aelr=None, dset_name='ShanghaiTech')
126 |
127 | tb_writer = SummaryWriter(os.path.join(args.output_path, "ShanghaiTech", 'tb_runs_pretrain', tmp))
128 | # Init AutoEncoder
129 |         ae_net = ShanghaiTech(data_holder.shape, args.code_length, use_selectors=args.use_selectors)
130 | ### PRETRAIN
131 | ae_net_checkpoint = pretrain(ae_net, train_loader, out_dir, tb_writer, device, args)
132 | tb_writer.close()
133 |
134 | net_checkpoint = None
135 |
136 | if args.train and not args.end_to_end_training:
137 | if ae_net_checkpoint is None:
138 | if args.model_ckp is None:
139 | logger.info("CANNOT TRAIN MODEL WITHOUT A VALID CHECKPOINT")
140 | sys.exit(0)
141 | ae_net_checkpoint = args.model_ckp
142 |
143 |         aelr = float(ae_net_checkpoint.split('/')[-2].split('-')[4].split('_')[-1])
144 |
145 | out_dir, tmp = get_out_dir(args, pretrain=False, aelr=aelr, dset_name='ShanghaiTech')
146 | tb_writer = SummaryWriter(os.path.join(args.output_path, "ShanghaiTech", 'tb_runs_train', tmp))
147 |
148 | # Init Encoder
149 | net = ShanghaiTechEncoder(data_holder.shape, args.code_length, args.load_lstm, args.hidden_size, args.num_layers, args.dropout, args.bidirectional, args.use_selectors)
150 |
151 | # Load encoder weight from autoencoder
152 | net_dict = net.state_dict()
153 | logger.info(f"Loading encoder from: {ae_net_checkpoint}")
154 | ae_net_dict = torch.load(ae_net_checkpoint, map_location=lambda storage, loc: storage)['ae_state_dict']
155 |
156 | # Filter out decoder network keys
157 | st_dict = {k: v for k, v in ae_net_dict.items() if k in net_dict}
158 | # Overwrite values in the existing state_dict
159 | net_dict.update(st_dict)
160 | # Load the new state_dict
161 | net.load_state_dict(net_dict)
162 |
163 | ### TRAIN
164 | net_checkpoint = train(net, train_loader, out_dir, tb_writer, device, ae_net_checkpoint, args)
165 | tb_writer.close()
166 |
167 | ########################################################################################
168 | ########################################################################################
169 |
170 | ########################################################################################
171 | ################### Train the AUTOENCODER on the combined objective: ###################
172 | ############################## RECONSTRUCTION + ONE CLASS ##############################
173 | ########################################################################################
174 | if args.end_to_end_training:
175 |
176 | out_dir, tmp = get_out_dir(args, pretrain=False, aelr=int(args.learning_rate), dset_name='ShanghaiTech')
177 |
178 |
179 | tb_writer = SummaryWriter(os.path.join(args.output_path, "ShanghaiTech", 'tb_runs_train_end_to_end', tmp))
180 | # Init AutoEncoder
181 | ae_net = ShanghaiTech(data_holder.shape, args.code_length, args.load_lstm, args.hidden_size, args.num_layers, args.dropout, args.bidirectional, args.use_selectors)
182 | ### End to end TRAIN
183 | net_checkpoint = train(ae_net, train_loader, out_dir, tb_writer, device, None, args)
184 | tb_writer.close()
185 | ########################################################################################
186 | ########################################################################################
187 |
188 | ########################################################################################
189 | ###################################### Model test ######################################
190 | ########################################################################################
191 | if args.test:
192 | if net_checkpoint is None:
193 | net_checkpoint = args.model_ckp
194 |
195 | code_length, batch_size, boundary, use_selectors, idx_list_enc, \
196 | load_lstm, hidden_size, num_layers, dropout, bidirectional, \
197 | dataset_name, train_type = extract_arguments_from_checkpoint(net_checkpoint)
198 |
199 | # Init dataset
200 | dataset = data_holder.get_test_data()
201 | if train_type == "train_end_to_end":
202 | # Init Autoencoder
203 | net = ShanghaiTech(data_holder.shape, args.code_length, load_lstm, hidden_size, num_layers, dropout, bidirectional, use_selectors)
204 | else:
205 | # Init Encoder ONLY
206 | net = ShanghaiTechEncoder(dataset.shape, code_length, load_lstm, hidden_size, num_layers, dropout, bidirectional, use_selectors)
207 | st_dict = torch.load(net_checkpoint)
208 |
209 | net.load_state_dict(st_dict['net_state_dict'])
210 | logger.info(f"Loaded model from: {net_checkpoint}")
211 | logger.info(
212 | f"Start test with params:"
213 | f"\n\t\t\t\tDataset : {dataset_name}"
214 | f"\n\t\t\t\tCode length : {code_length}"
215 | f"\n\t\t\t\tEnc layer list : {idx_list_enc}"
216 | f"\n\t\t\t\tBoundary : {boundary}"
217 | f"\n\t\t\t\tUse Selectors : {use_selectors}"
218 | f"\n\t\t\t\tBatch size : {batch_size}"
219 | f"\n\t\t\t\tN workers : {args.n_workers}"
220 | f"\n\t\t\t\tLoad LSTMs : {load_lstm}"
221 | f"\n\t\t\t\tHidden size : {hidden_size}"
222 | f"\n\t\t\t\tNum layers : {num_layers}"
223 | f"\n\t\t\t\tBidirectional : {bidirectional}"
224 | f"\n\t\t\t\tDropout prob : {dropout}"
225 | )
226 |
227 |         # Initialize test helper for processing each video separately
228 | # It prints the result to the loaded checkpoint directory
229 | helper = VideoAnomalyDetectionResultHelper(
230 | dataset=dataset,
231 | model=net,
232 | c=st_dict['c'],
233 | R=st_dict['R'],
234 | boundary=boundary,
235 | device=device,
236 |             end_to_end_training=(train_type == "train_end_to_end"),
237 | debug=args.debug,
238 |             output_file=os.path.join(os.path.dirname(net_checkpoint), "shanghaitech_test_results.txt")
239 | )
240 | ### TEST
241 | helper.test_video_anomaly_detection()
242 | print("Test finished")
243 | ########################################################################################
244 | ########################################################################################
245 |
246 |
247 | if __name__ == '__main__':
248 |
249 | parser = argparse.ArgumentParser('AD')
250 | ## General config
251 | parser.add_argument('-s', '--seed', type=int, default=-1, help='Random seed (default: -1)')
252 | parser.add_argument('--n_workers', type=int, default=8, help='Number of workers for data loading. 0 means that the data will be loaded in the main process. (default=8)')
253 | parser.add_argument('--output_path', default='./output')
254 | parser.add_argument('-lf', '--log-frequency', type=int, default=5, help='Log frequency (default: 5)')
255 |     parser.add_argument('-dl', '--disable-logging', action="store_true", help='Disable logging (default: False)')
256 | parser.add_argument('-db', '--debug', action='store_true', help='Debug mode (default: False)')
257 | ## Model config
258 | parser.add_argument('-zl', '--code-length', default=2048, type=int, help='Code length (default: 2048)')
259 | parser.add_argument('-ck', '--model-ckp', help='Model checkpoint')
260 | ## Optimizer config
261 | parser.add_argument('-opt', '--optimizer', choices=('adam', 'sgd'), default='adam', help='Optimizer (default: adam)')
262 | parser.add_argument('-alr', '--ae-learning-rate', type=float, default=1.e-4, help='Warm up learning rate (default: 1.e-4)')
263 | parser.add_argument('-lr', '--learning-rate', type=float, default=1.e-4, help='Learning rate (default: 1.e-4)')
264 |     parser.add_argument('-awd', '--ae-weight-decay', type=float, default=0.5e-6, help='Warm up weight decay (default: 0.5e-6)')
265 |     parser.add_argument('-wd', '--weight-decay', type=float, default=0.5e-6, help='Weight decay (default: 0.5e-6)')
266 | parser.add_argument('-aml', '--ae-lr-milestones', type=int, nargs='+', default=[], help='Pretrain milestone')
267 | parser.add_argument('-ml', '--lr-milestones', type=int, nargs='+', default=[], help='Training milestone')
268 | ## Data
269 | parser.add_argument('-dp', '--data-path', default='./ShanghaiTech', help='Dataset main path')
270 | parser.add_argument('-cl', '--clip-length', type=int, default=16, help='Clip length (default: 16)')
271 | ## Training config
272 | # LSTMs
273 | parser.add_argument('-ll', '--load-lstm', action="store_true", help='Load LSTMs (default: False)')
274 | parser.add_argument('-bdl', '--bidirectional', action="store_true", help='Bidirectional LSTMs (default: False)')
275 | parser.add_argument('-hs', '--hidden-size', type=int, default=100, help='Hidden size (default: 100)')
276 | parser.add_argument('-nl', '--num-layers', type=int, default=1, help='Number of LSTMs layers (default: 1)')
277 | parser.add_argument('-drp', '--dropout', type=float, default=0.0, help='Dropout probability (default: 0.0)')
278 | # Autoencoder
279 | parser.add_argument('-ee', '--end-to-end-training', action="store_true", help='End-to-End training of the autoencoder (default: False)')
280 | parser.add_argument('-we', '--warm_up_n_epochs', type=int, default=5, help='Warm up epochs (default: 5)')
281 | parser.add_argument('-use','--use-selectors', action="store_true", help='Use features selector (default: False)')
282 | parser.add_argument('-ba', '--batch-accumulation', type=int, default=-1, help='Batch accumulation (default: -1, i.e., None)')
283 | parser.add_argument('-ptr', '--pretrain', action="store_true", help='Pretrain model (default: False)')
284 | parser.add_argument('-tr', '--train', action="store_true", help='Train model (default: False)')
285 | parser.add_argument('-tt', '--test', action="store_true", help='Test model (default: False)')
286 | parser.add_argument('-tbc', '--train-best-conf', action="store_true", help='Train best configurations (default: False)')
287 | parser.add_argument('-bs', '--batch-size', type=int, default=4, help='Batch size (default: 4)')
288 | parser.add_argument('-bd', '--boundary', choices=("hard", "soft"), default="soft", help='Boundary (default: soft)')
289 | parser.add_argument('-ile', '--idx-list-enc', type=int, nargs='+', default=[], help='List of indices of model encoder')
290 | parser.add_argument('-e', '--epochs', type=int, default=1, help='Training epochs (default: 1)')
291 |     parser.add_argument('-ae', '--ae-epochs', type=int, default=1, help='Warm up epochs (default: 1)')
292 | parser.add_argument('-nu', '--nu', type=float, default=0.1)
293 |
294 | args = parser.parse_args()
295 | main(args)
--------------------------------------------------------------------------------
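Example invocations of the script above (hypothetical paths, run directory and hyper-parameters; every flag is defined in the argument parser):

    # Pretrain the autoencoder, then train the encoder on the one-class objective
    python main_shanghaitech.py --pretrain --train --data-path ./ShanghaiTech --code-length 1024 --ae-epochs 30 --epochs 30 --batch-size 4 --boundary soft

    # Test from a saved checkpoint
    python main_shanghaitech.py --test --model-ckp ./output/ShanghaiTech/<run-dir>/checkpoint.pth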
/datasets/shanghaitech_test.py:
--------------------------------------------------------------------------------
1 | from glob import glob
2 | from os.path import basename
3 | from os.path import isdir
4 | from os.path import join
5 | from typing import List
6 | from typing import Tuple
7 |
8 |
9 |
10 |
11 | import cv2
12 | import torch
13 | import numpy as np
14 | import skimage.io as io
15 | from tqdm import tqdm
16 |
17 | from sklearn.metrics import roc_auc_score
18 | from torch.utils.data import DataLoader
19 | from prettytable import PrettyTable
20 | from skimage.transform import resize
21 | from torchvision import transforms
22 | from torch.utils.data.dataloader import default_collate
23 | from .base import VideoAnomalyDetectionDataset, ToFloatTensor3D, ToFloatTensor3DMask
24 |
25 | class ShanghaiTechTestHandler(VideoAnomalyDetectionDataset):
26 |
27 | def __init__(self, path):
28 | # type: (str) -> None
29 | """
30 | Class constructor.
31 | :param path: The folder in which ShanghaiTech is stored.
32 | """
33 | super(ShanghaiTechTestHandler, self).__init__()
34 | self.path = path
35 | # Test directory
36 | self.test_dir = join(path, 'testing')
37 | # Transform
38 | self.transform = transforms.Compose([ToFloatTensor3D(normalize=True)])
39 | # Load all test ids
40 | self.test_ids = self.load_test_ids()
41 | # Other utilities
42 | self.cur_len = 0
43 | self.cur_video_id = None
44 | self.cur_video_frames = None
45 | self.cur_video_gt = None
46 |
47 | def load_test_ids(self):
48 | # type: () -> List[str]
49 | """
50 | Loads the set of all test video ids.
51 | :return: The list of test ids.
52 | """
53 | return sorted([basename(d) for d in glob(join(self.test_dir, 'nobackground_frames_resized', '**')) if isdir(d)])
54 |
55 | def load_test_sequence_frames(self, video_id):
56 | # type: (str) -> np.ndarray
57 | """
58 | Loads a test video in memory.
59 | :param video_id: the id of the test video to be loaded
60 | :return: the video in a np.ndarray, with shape (n_frames, h, w, c).
61 | """
62 | c, t, h, w = self.shape
63 | sequence_dir = join(self.test_dir, 'nobackground_frames_resized', video_id)
64 | img_list = sorted(glob(join(sequence_dir, '*.jpg')))
65 | #print(f"Creating clips for {sequence_dir} dataset with length {t}...")
66 | return np.stack([np.uint8(io.imread(img_path)) for img_path in img_list])
67 |
68 | def load_test_sequence_gt(self, video_id):
69 | # type: (str) -> np.ndarray
70 | """
71 | Loads the groundtruth of a test video in memory.
72 | :param video_id: the id of the test video for which the groundtruth has to be loaded.
73 | :return: the groundtruth of the video in a np.ndarray, with shape (n_frames,).
74 | """
75 | clip_gt = np.load(join(self.test_dir, 'test_frame_mask', f'{video_id}.npy'))
76 | return clip_gt
77 |
78 | def test(self, video_id):
79 | # type: (str) -> None
80 | """
81 | Sets the dataset in test mode.
82 | :param video_id: the id of the video to test.
83 | """
84 | c, t, h, w = self.shape
85 | self.cur_video_id = video_id
86 | self.cur_video_frames = self.load_test_sequence_frames(video_id)
87 | self.cur_video_gt = self.load_test_sequence_gt(video_id)
88 | self.cur_len = len(self.cur_video_frames) - t + 1
89 |
90 | @property
91 | def shape(self):
92 | # type: () -> Tuple[int, int, int, int]
93 | """
94 | Returns the shape of examples being fed to the model.
95 | """
96 | return 3, 16, 256, 512
97 |
98 | @property
99 | def test_videos(self):
100 | # type: () -> List[str]
101 | """
102 | Returns all available test videos.
103 | """
104 | return self.test_ids
105 |
106 | def __len__(self):
107 | # type: () -> int
108 | """
109 | Returns the number of examples.
110 | """
111 | return self.cur_len
112 |
113 | def __getitem__(self, i):
114 | # type: (int) -> Tuple[torch.Tensor, torch.Tensor]
115 | """
116 | Provides the i-th example.
117 | """
118 | c, t, h, w = self.shape
119 | clip = self.cur_video_frames[i:i+t]
120 | sample = clip
121 | # Apply transform
122 | if self.transform:
123 | sample = self.transform(sample)
124 | return sample
125 |
126 | @property
127 | def collate_fn(self):
128 | """
129 | Returns a function that decides how to merge a list of examples in a batch.
130 | """
131 | return default_collate
132 |
133 | def __repr__(self):
134 | return f'ShanghaiTech (video id = {self.cur_video_id})'
135 |
136 | def get_target_label_idx(labels, targets):
137 | """
138 | Get the indices of labels that are included in targets.
139 | :param labels: array of labels
140 | :param targets: list/tuple of target labels
141 | :return: list with indices of target labels
142 | """
143 | return np.argwhere(np.isin(labels, targets)).flatten().tolist()
144 |
145 |
146 | def global_contrast_normalization(x: torch.tensor, scale='l2'):
147 | """
148 | Apply global contrast normalization to tensor, i.e. subtract mean across features (pixels) and normalize by scale,
149 | which is either the standard deviation, L1- or L2-norm across features (pixels).
150 | Note this is a *per sample* normalization globally across features (and not across the dataset).
151 | """
152 |
153 | assert scale in ('l1', 'l2')
154 |
155 | n_features = int(np.prod(x.shape))
156 |
157 | mean = torch.mean(x) # mean over all features (pixels) per sample
158 | x -= mean
159 |
160 | if scale == 'l1':
161 | x_scale = torch.mean(torch.abs(x))
162 |
163 | if scale == 'l2':
164 | x_scale = torch.sqrt(torch.sum(x ** 2)) / n_features
165 |
166 | x /= x_scale
167 |
168 | return x
169 |
170 |
171 | class ResultsAccumulator:
172 | """
173 | Accumulates results in a buffer for a sliding window
174 | results computation. Employed to get frame-level scores
175 | from clip-level scores.
176 |     "In order to recover the anomaly score of each
177 |     frame, we compute the mean score of all clips in which it
178 |     appears."
179 | """
180 | def __init__(self, nb_frames_per_clip):
181 | # type: (int) -> None
182 | """
183 | Class constructor.
184 | :param nb_frames_per_clip: the number of frames each clip holds.
185 | """
186 |
187 |         # These buffers rotate.
188 | self._buffer = np.zeros(shape=(nb_frames_per_clip,), dtype=np.float32)
189 | self._counts = np.zeros(shape=(nb_frames_per_clip,))
190 |
191 | def push(self, score):
192 | # type: (float) -> None
193 | """
194 | Pushes the score of a clip into the buffer.
195 | :param score: the score of a clip
196 | """
197 |
198 | # Update buffer and counts
199 | self._buffer += score
200 | self._counts += 1
201 |
202 | def get_next(self):
203 | # type: () -> float
204 | """
205 | Gets the next frame (the first in the buffer) score,
206 | computed as the mean of the clips in which it appeared,
207 | and rolls the buffers.
208 | :return: the averaged score of the frame exiting the buffer.
209 | """
210 |
211 | # Return first in buffer
212 | ret = self._buffer[0] / self._counts[0]
213 |
214 | # Roll time backwards
215 | self._buffer = np.roll(self._buffer, shift=-1)
216 | self._counts = np.roll(self._counts, shift=-1)
217 |
218 | # Zero out final frame (next to be filled)
219 | self._buffer[-1] = 0
220 | self._counts[-1] = 0
221 |
222 | return ret
223 |
224 | @property
225 | def results_left(self):
226 | # type: () -> np.int32
227 | """
228 | Returns the number of frames still in the buffer.
229 | """
230 | return np.sum(self._counts != 0).astype(np.int32)
231 |
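# Worked example for the accumulator above (nb_frames_per_clip = 3):
#     acc = ResultsAccumulator(nb_frames_per_clip=3)
#     acc.push(1.0); acc.get_next()  # frame 0 was covered by 1 clip  -> 1.0
#     acc.push(2.0); acc.get_next()  # frame 1 was covered by 2 clips -> (1.0 + 2.0) / 2 = 1.5
# Each frame's score is the mean of the scores of all clips that contained it.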
232 |
233 | class VideoAnomalyDetectionResultHelper(object):
234 | """
235 | Performs tests for video anomaly detection datasets (UCSD Ped2 or Shanghaitech).
236 | """
237 |
238 | def __init__(self, dataset, model, c, R, boundary, device, end_to_end_training, debug, output_file):
239 |         # type: (VideoAnomalyDetectionDataset, BaseModule, dict, dict, str, str, bool, bool, str) -> None
240 | """
241 | Class constructor.
242 | :param dataset: dataset class.
243 | :param model: pytorch model to evaluate.
244 | :param output_file: text file where to save results.
245 | """
246 | self.dataset = dataset
247 | self.model = model
248 | self.hc = c
249 | self.keys = list(c.keys())
250 | self.R = R
251 | self.boundary = boundary
252 | self.device = device
253 | self.end_to_end_training = end_to_end_training
254 | self.debug = debug
255 | self.output_file = output_file
256 |
257 | def _get_scores(self, d_lstm):
258 | # Eval novelty scores
259 | dist = {k: torch.sum((d_lstm[k] - self.hc[k].unsqueeze(0)) ** 2, dim=1) for k in self.keys}
260 | scores = {k: torch.zeros((dist[k].shape[0],), device=self.device) for k in self.keys}
261 | overall_score = torch.zeros((dist[self.keys[0]].shape[0],), device=self.device)
262 | for k in self.keys:
263 | if self.boundary == 'soft':
264 | scores[k] += dist[k] - self.R[k] ** 2
265 | overall_score += dist[k] - self.R[k] ** 2
266 | else:
267 | scores[k] += dist[k]
268 | overall_score += dist[k]
269 | scores = {k: scores[k]/len(self.keys) for k in self.keys}
270 | return scores, overall_score/len(self.keys)
271 |
272 | @torch.no_grad()
273 | def test_video_anomaly_detection(self):
274 | # type: () -> None
275 | """
276 | Actually performs tests.
277 | """
278 | self.model.eval().to(self.device)
279 |
280 | c, t, h, w = self.dataset.raw_shape
281 |
282 | # Prepare a table to show results
283 | vad_table = self.empty_table
284 |
285 | # Set up container for anomaly scores from all test videos
286 | ## oc: one class
287 | ## rc: reconstruction
288 | ## as: overall anomaly score
289 | global_oc = []
290 | global_rc = []
291 | global_as = []
292 | global_as_by_layer = {k: [] for k in self.keys}
293 | global_y = []
294 | global_y_by_layer = {k: [] for k in self.keys}
295 |
296 | # Get accumulators
297 | results_accumulator_rc = ResultsAccumulator(nb_frames_per_clip=t)
298 | results_accumulator_oc = ResultsAccumulator(nb_frames_per_clip=t)
299 | results_accumulator_oc_by_layer = {k: ResultsAccumulator(nb_frames_per_clip=t) for k in self.keys}
300 | print(self.dataset.test_videos)
301 | # Start iteration over test videos
302 | for cl_idx, video_id in tqdm(enumerate(self.dataset.test_videos), total=len(self.dataset.test_videos), desc="Test on Video"):
303 | # Run the test
304 | self.dataset.test(video_id)
305 | loader = DataLoader(self.dataset, collate_fn=self.dataset.collate_fn)
306 |
307 | # Build score containers
308 | sample_rc = np.zeros(shape=(len(loader) + t - 1,))
309 | sample_oc = np.zeros(shape=(len(loader) + t - 1,))
310 | sample_oc_by_layer = {k: np.zeros(shape=(len(loader) + t - 1,)) for k in self.keys}
311 | sample_y = self.dataset.load_test_sequence_gt(video_id)
312 |
313 | for i, x in tqdm(enumerate(loader), total=len(loader), desc=f'Computing scores for {self.dataset}', leave=False):
314 | # x.shape = [1, 3, 16, 256, 512]
315 | x = x.to(self.device)
316 |
317 | if self.end_to_end_training:
318 | x_r, _, d_lstm = self.model(x)
319 | recon_loss = torch.mean(torch.sum((x_r - x) ** 2, dim=tuple(range(1, x_r.dim()))))
320 | else:
321 | _, d_lstm = self.model(x)
322 | recon_loss = torch.tensor([0.0])
323 |
324 | # Eval one class score for current clip
325 | oc_loss_by_layer, oc_overall_loss = self._get_scores(d_lstm)
326 |
327 | # Feed results accumulators
328 | results_accumulator_rc.push(recon_loss.item())
329 | sample_rc[i] = results_accumulator_rc.get_next()
330 | results_accumulator_oc.push(oc_overall_loss.item())
331 | sample_oc[i] = results_accumulator_oc.get_next()
332 |
333 | for k in self.keys:
334 | if k != "tdl_lstm_o_0" and k != "tdl_lstm_o_1":
335 | results_accumulator_oc_by_layer[k].push(oc_loss_by_layer[k].item())
336 | sample_oc_by_layer[k][i] = results_accumulator_oc_by_layer[k].get_next()
337 |
338 | # Get last results layer by layer
339 | for k in self.keys:
340 | if k != "tdl_lstm_o_0" and k != "tdl_lstm_o_1":
341 | while results_accumulator_oc_by_layer[k].results_left != 0:
342 | index = (- results_accumulator_oc_by_layer[k].results_left)
343 | sample_oc_by_layer[k][index] = results_accumulator_oc_by_layer[k].get_next()
344 |
345 | min_, max_ = sample_oc_by_layer[k].min(), sample_oc_by_layer[k].max()
346 |
347 | # Computes the normalized novelty score given likelihood scores, reconstruction scores
348 | # and normalization coefficients (Eq. 9-10).
349 | sample_ns = (sample_oc_by_layer[k] - min_) / (max_ - min_)
350 |
351 | # Update global scores (used for global metrics)
352 | global_as_by_layer[k].append(sample_ns)
353 | global_y_by_layer[k].append(sample_y)
354 |
355 | try:
356 | # Compute AUROC for this video
357 | this_video_metrics = [
358 | roc_auc_score(sample_y, sample_ns), # anomaly score == one class metric
359 | 0.,
360 | 0.
361 | ]
362 | #vad_table.add_row([k] + [video_id] + this_video_metrics)
363 | except ValueError:
364 | # This happens for sequences in which all frames are abnormal
365 | # Skipping this row in the table (the sequence will still count for global metrics)
366 | continue
367 |
368 | # Get last results
369 | while results_accumulator_oc.results_left != 0:
370 | index = (- results_accumulator_oc.results_left)
371 | sample_oc[index] = results_accumulator_oc.get_next()
372 | sample_rc[index] = results_accumulator_rc.get_next()
373 |
374 | min_oc, max_oc, min_rc, max_rc = sample_oc.min(), sample_oc.max(), sample_rc.min(), sample_rc.max()
375 |
376 | # Computes the normalized novelty score given likelihood scores, reconstruction scores
377 | # and normalization coefficients (Eq. 9-10).
378 | sample_oc = (sample_oc - min_oc) / (max_oc - min_oc)
379 | sample_rc = (sample_rc - min_rc) / (max_rc - min_rc) if (max_rc - min_rc) > 0 else np.zeros_like(sample_rc)
380 | sample_as = sample_oc + sample_rc
381 |
382 | # Update global scores (used for global metrics)
383 | global_oc.append(sample_oc)
384 | global_rc.append(sample_rc)
385 | global_as.append(sample_as)
386 | global_y.append(sample_y)
387 |
388 | try:
389 | # Compute AUROC for this video
390 | this_video_metrics = [
391 | roc_auc_score(sample_y, sample_oc), # one class metric
392 | roc_auc_score(sample_y, sample_rc), # reconstruction metric
393 | roc_auc_score(sample_y, sample_as) # anomaly score
394 | ]
395 | #vad_table.add_row(['Overall'] + [video_id] + this_video_metrics)
396 | except ValueError:
397 | # This happens for sequences in which all frames are abnormal
398 | # Skipping this row in the table (the sequence will still count for global metrics)
399 | continue
400 |
401 | if self.debug: break
402 |
403 | # Compute global AUROC and print table
404 | for k in self.keys:
405 | if k != "tdl_lstm_o_0" and k != "tdl_lstm_o_1":
406 | global_as_by_layer[k] = np.concatenate(global_as_by_layer[k])
407 | global_y_by_layer[k] = np.concatenate(global_y_by_layer[k])
408 | global_metrics = [
409 | roc_auc_score(global_y_by_layer[k], global_as_by_layer[k]), # anomaly score == one class metric
410 | 0.,
411 | 0.
412 | ]
413 | vad_table.add_row([k] + ['avg'] + global_metrics)
414 |
415 | # Compute global AUROC and print table
416 | global_oc = np.concatenate(global_oc)
417 | global_rc = np.concatenate(global_rc)
418 | global_as = np.concatenate(global_as)
419 | global_y = np.concatenate(global_y)
420 | global_metrics = [
421 | roc_auc_score(global_y, global_oc), # one class metric
422 | roc_auc_score(global_y, global_rc), # reconstruction metric
423 | roc_auc_score(global_y, global_as) # anomaly score
424 | ]
425 |
426 | vad_table.add_row(['Overall'] + ['avg'] + global_metrics)
427 | print(vad_table)
428 |
429 | # Save table
430 | with open(self.output_file, mode='w') as f:
431 | f.write(str(vad_table))
432 |
433 | @property
434 | def empty_table(self):
435 | # type: () -> PrettyTable
436 | """
437 | Sets up a nice ascii-art table to hold results.
438 | This table is suitable for the video anomaly detection setting.
439 | :return: table to be filled with auroc metrics.
440 | """
441 | table = PrettyTable()
442 | table.field_names = ['Layer key', 'VIDEO-ID', 'OC metric', 'Recon metric', 'AUROC-AS']
443 | table.float_format = '0.3'
444 | return table
445 |
--------------------------------------------------------------------------------
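The score post-processing in the trainer above boils down to a per-video min-max
rescaling of each score stream before summing (Eq. 9-10 in the paper). Below is a
minimal NumPy sketch of that step, with illustrative names that are not part of the
original file; the zero-range guard that the trainer applies only to the
reconstruction stream is applied to both streams here:

    import numpy as np

    def combine_scores(sample_oc: np.ndarray, sample_rc: np.ndarray) -> np.ndarray:
        # Min-max normalize the one-class and reconstruction scores of a
        # single video, then sum them into the final anomaly score (Eq. 9-10).
        oc_rng = sample_oc.max() - sample_oc.min()
        rc_rng = sample_rc.max() - sample_rc.min()
        oc = (sample_oc - sample_oc.min()) / oc_rng if oc_rng > 0 else np.zeros_like(sample_oc)
        rc = (sample_rc - sample_rc.min()) / rc_rng if rc_rng > 0 else np.zeros_like(sample_rc)
        return oc + rc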
/main_mvtec.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import glob
4 | import random
5 | import logging
6 | import argparse
7 | import numpy as np
8 | import pandas as pd
9 | from tqdm import tqdm
10 | from os import makedirs
11 | from os.path import exists
12 | from prettytable import PrettyTable
13 |
14 | import torch
15 | from torch.utils.data import DataLoader
16 |
17 | from tensorboardX import SummaryWriter
18 |
19 | from datasets.data_manager import DataManager
20 | from models.mvtec_model import MVTecNet_AutoEncoder
21 | from trainers.trainer_mvtec import pretrain, train, test
22 | from utils import set_seeds, get_out_dir, eval_spheres_centers, load_mvtec_model_from_checkpoint
23 |
24 |
25 | def test_models(test_loader: DataLoader, net_cehckpoint: str, tables: tuple, out_df: pd.DataFrame, is_texture: bool, input_shape: tuple, idx_list_enc: list, boundary: str, normal_class: str, use_selectors: bool, device: str, debug: bool):
26 | """Test a single model.
27 |
28 | Parameters
29 | ----------
30 | test_loader : DataLoader
31 | Test data loader
32 | net_cehckpoint : str
33 | Path to model checkpoint
34 | tables : tuple
35 |         Tuple containing the PrettyTables for the soft- and hard-boundary results
36 | out_df : DataFrame
37 | Output dataframe
38 | is_texture : bool
39 |         True if we are dealing with a texture-type class
40 | input_shape : tuple
41 | Shape of the input data
42 | idx_list_enc : list
43 |         List containing the indices of the layers from which to extract features
44 | boundary : str
45 | Type of boundary
46 | normal_class : str
47 | Name of the normal class
48 | use_selectors : bool
49 | True if we want to use Selector modules
50 | device : str
51 | Device to be used
52 | debug : bool
53 | Activate debug mode
54 |
55 | Returns
56 | -------
57 | out_df : DataFrame
58 | Dataframe containing the test results
59 |
60 | """
61 | logger = logging.getLogger()
62 |
63 | if not os.path.exists(net_cehckpoint):
64 |         logger.warning(f"File not found at: {net_cehckpoint}")
65 | return out_df
66 |
67 | # Get latent code size from checkpoint name
68 | code_length = int(net_cehckpoint.split('/')[-2].split('-')[3].split('_')[-1])
69 |
70 |     if net_cehckpoint.split('/')[-2].split('-')[-1].split('_')[-1].split('.')[0] == '':
71 |         idx_list_enc = [7]
72 |     else:
73 |         idx_list_enc = [int(i) for i in net_cehckpoint.split('/')[-2].split('-')[-1].split('_')[-1].split('.')]
74 | boundary = net_cehckpoint.split('/')[-2].split('-')[9].split('_')[-1]
75 | normal_class = net_cehckpoint.split('/')[-2].split('-')[2].split('_')[-1]
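    # Illustration (hypothetical folder name; the real layout is produced when the
    # checkpoint is saved during training): a path ending in
    # '...-class_carpet-code_64-...-boundary_soft-...-idx_5.6.7/model.pth'
    # would yield normal_class='carpet', code_length=64, boundary='soft'
    # and idx_list_enc=[5, 6, 7].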
76 |
77 | logger.info(
78 |         f"Start test with params:"
79 | f"\n\t\t\t\tCode length : {code_length}"
80 | f"\n\t\t\t\tEnc layer list : {idx_list_enc}"
81 | f"\n\t\t\t\tBoundary : {boundary}"
82 | f"\n\t\t\t\tObject class : {normal_class}"
83 | )
84 |
85 | # Init Encoder
86 | net = load_mvtec_model_from_checkpoint(
87 | input_shape=input_shape,
88 | code_length=code_length,
89 | idx_list_enc=idx_list_enc,
90 | use_selectors=use_selectors,
91 | net_cehckpoint=net_cehckpoint
92 | )
93 |
94 | st_dict = torch.load(net_cehckpoint)
95 | net.load_state_dict(st_dict['net_state_dict'])
96 |
97 | ### TEST
98 | test_auc, test_b_acc = test(
99 | normal_class=normal_class,
100 | is_texture=is_texture,
101 | net=net,
102 | test_loader=test_loader,
103 | R=st_dict['R'],
104 | c=st_dict['c'],
105 | device=device,
106 | boundary=boundary,
107 | debug=debug
108 | )
109 |
110 | table = tables[0] if boundary == 'soft' else tables[1]
111 | table.add_row([
112 | net_cehckpoint.split('/')[-2],
113 | code_length,
114 | idx_list_enc,
115 | net_cehckpoint.split('/')[-2].split('-')[7].split('_')[-1]+'-'+net_cehckpoint.split('/')[-2].split('-')[8],
116 | normal_class,
117 | boundary,
118 | net_cehckpoint.split('/')[-2].split('-')[4].split('_')[-1],
119 | net_cehckpoint.split('/')[-2].split('-')[5].split('_')[-1],
120 | test_auc,
121 | test_b_acc
122 | ])
123 |
124 | out_df = out_df.append(dict(
125 | path=net_cehckpoint.split('/')[-2],
126 | code_length=code_length,
127 | enc_l_list=idx_list_enc,
128 | weight_decay=net_cehckpoint.split('/')[-2].split('-')[7].split('_')[-1]+'-'+net_cehckpoint.split('/')[-2].split('-')[8],
129 | object_class=normal_class,
130 | boundary=boundary,
131 | batch_size=net_cehckpoint.split('/')[-2].split('-')[4].split('_')[-1],
132 | nu=net_cehckpoint.split('/')[-2].split('-')[5].split('_')[-1],
133 | auc=test_auc,
134 | balanced_acc=test_b_acc
135 | ),
136 | ignore_index=True
137 | )
138 |
139 | return out_df
140 |
141 |
142 | def main(args):
143 | # Set seed
144 | set_seeds(args.seed)
145 |
146 | # Get the device
147 | device = "cuda" if torch.cuda.is_available() else "cpu"
148 |
149 | if args.disable_logging:
150 | logging.disable(level=logging.INFO)
151 |
152 | ## Init logger & print training/warm-up summary
153 | logging.basicConfig(
154 | level=logging.INFO,
155 | format="%(asctime)s | %(message)s",
156 | handlers=[
157 | logging.FileHandler('./training.log'),
158 | logging.StreamHandler()
159 | ])
160 | logger = logging.getLogger()
161 |
162 | if args.train or args.pretrain:
163 |         # If the list of layers from which to extract the features is empty, then use the last one (after the sigmoid)
164 | if len(args.idx_list_enc) == 0: args.idx_list_enc = [7]
165 |
166 | logger.info(
167 | "Start run with params:"
168 | f"\n\t\t\t\tPretrain model : {args.pretrain}"
169 | f"\n\t\t\t\tTrain model : {args.train}"
170 | f"\n\t\t\t\tTest model : {args.test}"
171 | f"\n\t\t\t\tBoundary : {args.boundary}"
172 | f"\n\t\t\t\tPretrain epochs : {args.ae_epochs}"
173 | f"\n\t\t\t\tAE-Learning rate : {args.ae_learning_rate}"
174 | f"\n\t\t\t\tAE-milestones : {args.ae_lr_milestones}"
175 | f"\n\t\t\t\tAE-Weight decay : {args.ae_weight_decay}\n"
176 | f"\n\t\t\t\tTrain epochs : {args.epochs}"
177 |             f"\n\t\t\t\tBatch size : {args.batch_size}"
178 | f"\n\t\t\t\tBatch acc. : {args.batch_accumulation}"
179 | f"\n\t\t\t\tWarm up epochs : {args.warm_up_n_epochs}"
180 | f"\n\t\t\t\tLearning rate : {args.learning_rate}"
181 | f"\n\t\t\t\tMilestones : {args.lr_milestones}"
182 | f"\n\t\t\t\tUse selectors : {args.use_selectors}"
183 | f"\n\t\t\t\tWeight decay : {args.weight_decay}\n"
184 | f"\n\t\t\t\tCode length : {args.code_length}"
185 | f"\n\t\t\t\tNu : {args.nu}"
186 | f"\n\t\t\t\tEncoder list : {args.idx_list_enc}\n"
187 | f"\n\t\t\t\tTest metric : {args.metric}"
188 | )
189 | else:
190 | if args.model_ckp is None:
191 | logger.info("CANNOT TEST MODEL WITHOUT A VALID CHECKPOINT")
192 | sys.exit(0)
193 |
194 | if args.debug:
195 | args.normal_class = 'carpet'
196 |
197 | else:
198 |
199 | if os.path.isfile(args.model_ckp):
200 | args.normal_class = args.model_ckp.split('/')[-2].split('-')[2].split('_')[-1]
201 |
202 | else:
203 | args.normal_class = args.model_ckp.split('/')[-3]
204 |
205 | # Init DataHolder class
206 | data_holder = DataManager(
207 | dataset_name='MVTec_Anomaly',
208 | data_path=args.data_path,
209 | normal_class=args.normal_class,
210 | only_test=args.test
211 | ).get_data_holder()
212 |
213 | # Load data
214 | train_loader, test_loader = data_holder.get_loaders(
215 | batch_size=args.batch_size,
216 | shuffle_train=True,
217 | pin_memory=device=="cuda",
218 | num_workers=args.n_workers
219 | )
220 |
221 |     # Print data info
222 | only_test = args.test and not args.train and not args.pretrain
223 | logger.info("Dataset info:")
224 | logger.info(
225 | "\n"
226 | f"\n\t\t\t\tNormal class : {args.normal_class}"
227 | f"\n\t\t\t\tBatch size : {args.batch_size}"
228 | )
229 | if not only_test:
230 | logger.info(
231 | f"TRAIN:"
232 | f"\n\t\t\t\tNumber of images : {len(train_loader.dataset)}"
233 | f"\n\t\t\t\tNumber of batches : {len(train_loader.dataset)//args.batch_size}"
234 | )
235 | logger.info(
236 | f"TEST:"
237 | f"\n\t\t\t\tNumber of images : {len(test_loader.dataset)}"
238 | )
239 |
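    # Texture-type categories are processed at a lower input resolution than object-type ones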
240 |     is_texture = args.normal_class in ("carpet", "grid", "leather", "tile", "wood")
241 | input_shape = (3, 64, 64) if is_texture else (3, 128, 128)
242 |
243 | ### PRETRAIN the full AutoEncoder
244 | ae_net_cehckpoint = None
245 | if args.pretrain:
246 |
247 | pretrain_out_dir, tmp = get_out_dir(args, pretrain=True, aelr=None, dset_name='mvtec')
248 | pretrain_tb_writer = SummaryWriter(os.path.join(args.output_path, 'mvtec', str(args.normal_class), 'tb_runs/pretrain', tmp))
249 |
250 | # Init AutoEncoder
251 | ae_net = MVTecNet_AutoEncoder(input_shape=input_shape, code_length=args.code_length, use_selectors=args.use_selectors)
252 |
253 | # Start pretraining
254 | ae_net_cehckpoint = pretrain(
255 | ae_net=ae_net,
256 | train_loader=train_loader,
257 | out_dir=pretrain_out_dir,
258 | tb_writer=pretrain_tb_writer,
259 | device=device,
260 | ae_learning_rate=args.ae_learning_rate,
261 | ae_weight_decay=args.ae_weight_decay,
262 | ae_lr_milestones=args.ae_lr_milestones,
263 | ae_epochs=args.ae_epochs,
264 | log_frequency=args.log_frequency,
265 | batch_accumulation=args.batch_accumulation,
266 | debug=args.debug
267 | )
268 |
269 | pretrain_tb_writer.close()
270 |
271 | ### TRAIN the Encoder
272 | net_cehckpoint = None
273 | if args.train:
274 | if ae_net_cehckpoint is None:
275 | if args.model_ckp is None:
276 | logger.info("CANNOT TRAIN MODEL WITHOUT A VALID CHECKPOINT")
277 | sys.exit(0)
278 |
279 | ae_net_cehckpoint = args.model_ckp
280 |
281 | aelr = float(ae_net_cehckpoint.split('/')[-2].split('-')[4].split('_')[-1])
282 |
283 | train_out_dir, tmp = get_out_dir(args, pretrain=False, aelr=aelr, dset_name='mvtec')
284 | train_tb_writer = SummaryWriter(os.path.join(args.output_path, 'mvtec', str(args.normal_class), 'tb_runs/train', tmp))
285 |
286 | # Init the Encoder network
287 | encoder_net = load_mvtec_model_from_checkpoint(
288 | input_shape=input_shape,
289 | code_length=args.code_length,
290 | idx_list_enc=args.idx_list_enc,
291 | use_selectors=args.use_selectors,
292 | net_cehckpoint=ae_net_cehckpoint,
293 | purge_ae_params=True
294 | )
295 |
296 |         ## Eval/Load hypersphere centers
297 | encoder_net.set_idx_list_enc(range(8))
298 | centers = eval_spheres_centers(train_loader=train_loader, encoder_net=encoder_net, ae_net_cehckpoint=ae_net_cehckpoint, use_selectors=args.use_selectors, device=device, debug=args.debug)
299 | encoder_net.set_idx_list_enc(args.idx_list_enc)
300 |
301 | # Start training
302 | net_cehckpoint = train(
303 | net=encoder_net,
304 | train_loader=train_loader,
305 | centers=centers,
306 | out_dir=train_out_dir,
307 | tb_writer=train_tb_writer,
308 | device=device,
309 | learning_rate=args.learning_rate,
310 | weight_decay=args.weight_decay,
311 | lr_milestones=args.lr_milestones,
312 | epochs=args.epochs,
313 | nu=args.nu,
314 | boundary=args.boundary,
315 | batch_accumulation=args.batch_accumulation,
316 | warm_up_n_epochs=args.warm_up_n_epochs,
317 | log_frequency=args.log_frequency,
318 | debug=args.debug
319 | )
320 |
321 | train_tb_writer.close()
322 |
323 | ### TEST the Encoder
324 | if args.test:
325 | if net_cehckpoint is None:
326 | net_cehckpoint = args.model_ckp
327 |
328 |         # Init table to print results on shell
329 |         # If we only test a single model, one of the two tables will remain empty
330 |         # If all the model checkpoints are in one folder, both tables are filled automatically
331 | table_s = PrettyTable()
332 | table_s.field_names = ['Path', 'Code length', 'Enc layer list', 'weight decay', 'Object class', 'Boundary', 'batch size', 'nu', 'AUC', 'Balanced acc']
333 | table_s.float_format = '0.3'
334 | table_h = PrettyTable()
335 | table_h.field_names = ['Path', 'Code length', 'Enc layer list', 'weight decay', 'Object class', 'Boundary', 'batch size', 'nu', 'AUC', 'Balanced acc']
336 | table_h.float_format = '0.3'
337 |
338 | # Init dataframe to store results
339 | out_df = pd.DataFrame()
340 |
341 | is_file = os.path.isfile(net_cehckpoint)
342 | if is_file:
343 | out_df = test_models(
344 | test_loader=test_loader,
345 | net_cehckpoint=net_cehckpoint,
346 | tables=(table_s, table_h),
347 | out_df=out_df,
348 | is_texture=is_texture,
349 | input_shape=input_shape,
350 | idx_list_enc=args.idx_list_enc,
351 | boundary=args.boundary,
352 | normal_class=args.normal_class,
353 | use_selectors=args.use_selectors,
354 | device=device,
355 | debug=args.debug
356 | )
357 | else:
358 |             for model_ckp in tqdm(os.listdir(net_cehckpoint), desc="Running on models"):
359 | out_df = test_models(
360 | test_loader=test_loader,
361 | net_cehckpoint=os.path.join(net_cehckpoint, model_ckp, 'best_oc_model_model.pth'),
362 | tables=(table_s, table_h),
363 | out_df=out_df,
364 | is_texture=is_texture,
365 |                     input_shape=input_shape, idx_list_enc=args.idx_list_enc,
366 | boundary=args.boundary,
367 | normal_class=args.normal_class,
368 | use_selectors=args.use_selectors,
369 | device=device,
370 | debug=args.debug
371 | )
372 |
373 | print(table_s)
374 | print(table_h)
375 |
376 | b_path = "./output/mvtec_test_results/test_csv"
377 | if not exists(b_path):
378 | makedirs(b_path)
379 |
380 | normal_class = net_cehckpoint.split('/')[-4]
381 | ff = glob.glob(os.path.join(b_path, f'*{normal_class}*'))
382 | if len(ff) == 0:
383 | csv_out_name = os.path.join(b_path, f"test-results-{normal_class}_0.csv")
384 |
385 | else:
386 | ff.sort()
387 | version = int(ff[-1].split('_')[-1].split('.')[0]) + 1
388 |             logger.info(f"Already found csv file for {normal_class} with latest version: {version-1} ==> creating new csv file with version: {version}")
389 | csv_out_name = os.path.join(b_path, f"test-results-{normal_class}_{version}.csv")
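            # e.g. with 'test-results-carpet_1.csv' already on disk, this run writes 'test-results-carpet_2.csv'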
390 |
391 | out_df.to_csv(csv_out_name)
392 |
393 |
394 | if __name__ == '__main__':
395 | parser = argparse.ArgumentParser('AD')
396 | ## General config
397 | parser.add_argument('-s', '--seed', type=int, default=-1, help='Random seed (default: -1)')
398 | parser.add_argument('--n_workers', type=int, default=8, help='Number of workers for data loading. 0 means that the data will be loaded in the main process. (default: 8)')
399 | parser.add_argument('--output_path', default='./output')
400 | parser.add_argument('-lf', '--log-frequency', type=int, default=5, help='Log frequency (default: 5)')
401 |     parser.add_argument('-dl', '--disable-logging', action="store_true", help='Disable logging (default: False)')
402 | parser.add_argument('-db', '--debug', action="store_true", help='Activate debug mode, i.e., only use the first three batches (default: False)')
403 | ## Model config
404 | parser.add_argument('-zl', '--code-length', default=64, type=int, help='Code length (default: 64)')
405 | parser.add_argument('-ck', '--model-ckp', help='Model checkpoint')
406 | ## Optimizer config
407 | parser.add_argument('-alr', '--ae-learning-rate', type=float, default=1.e-4, help='Warm up learning rate (default: 1.e-4)')
408 | parser.add_argument('-lr', '--learning-rate', type=float, default=1.e-4, help='Learning rate (default: 1.e-4)')
409 |     parser.add_argument('-awd', '--ae-weight-decay', type=float, default=0.5e-6, help='Warm-up weight decay (default: 0.5e-6)')
410 |     parser.add_argument('-wd', '--weight-decay', type=float, default=0.5e-6, help='Weight decay (default: 0.5e-6)')
411 |     parser.add_argument('-aml', '--ae-lr-milestones', type=int, nargs='+', default=[], help='Pretrain learning rate milestones')
412 |     parser.add_argument('-ml', '--lr-milestones', type=int, nargs='+', default=[], help='Training learning rate milestones')
413 | ## Data
414 | parser.add_argument('-dp', '--data-path', default='./MVTec_Anomaly', help='Dataset main path')
415 | parser.add_argument('-nc', '--normal-class', choices=('bottle', 'capsule', 'grid', 'leather', 'metal_nut', 'screw', 'toothbrush', 'wood', 'cable', 'carpet', 'hazelnut', 'pill', 'tile', 'transistor', 'zipper'), default='cable', help='Category (default: cable)')
416 | ## Training config
417 | parser.add_argument('-we', '--warm_up_n_epochs', type=int, default=5, help='Warm up epochs (default: 5)')
418 |     parser.add_argument('--use-selectors', action="store_true", help='Use feature selector modules (default: False)')
419 | parser.add_argument('-ba', '--batch-accumulation', type=int, default=-1, help='Batch accumulation (default: -1, i.e., None)')
420 | parser.add_argument('-ptr', '--pretrain', action="store_true", help='Pretrain model (default: False)')
421 | parser.add_argument('-tr', '--train', action="store_true", help='Train model (default: False)')
422 | parser.add_argument('-tt', '--test', action="store_true", help='Test model (default: False)')
423 | parser.add_argument('-tbc', '--train-best-conf', action="store_true", help='Train best configurations (default: False)')
424 | parser.add_argument('-bs', '--batch-size', type=int, default=128, help='Batch size (default: 128)')
425 | parser.add_argument('-bd', '--boundary', choices=("hard", "soft"), default="soft", help='Boundary (default: soft)')
426 | parser.add_argument('-ile', '--idx-list-enc', type=int, nargs='+', default=[], help='List of indices of model encoder')
427 | parser.add_argument('-e', '--epochs', type=int, default=1, help='Training epochs (default: 1)')
428 |     parser.add_argument('-ae', '--ae-epochs', type=int, default=1, help='Warm-up epochs (default: 1)')
429 |     parser.add_argument('-nu', '--nu', type=float, default=0.1, help='Trade-off parameter nu (default: 0.1)')
430 | ## Test config
431 | parser.add_argument('-mt', '--metric', choices=(1, 2), type=int, default=2, help="Metric to evaluate norms (default: 2, i.e., L2)")
432 | args = parser.parse_args()
433 |
434 | main(args)
435 |
--------------------------------------------------------------------------------
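For reference, a typical end-to-end invocation of main_mvtec.py combining the
flags defined above (paths and values are illustrative only):

    python main_mvtec.py --pretrain --train --test \
        --data-path ./MVTec_Anomaly --normal-class carpet \
        --code-length 64 --batch-size 128 --boundary soft \
        --idx-list-enc 5 6 7 --ae-epochs 30 --epochs 100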