├── models ├── __init__.py ├── cifar10_model.py ├── mvtec_base_model.py ├── shanghaitech_base_model.py ├── mvtec_model.py └── shanghaitech_model.py ├── datasets ├── __init__.py ├── data_manager.py ├── base.py ├── cifar10.py ├── shanghaitech.py ├── mvtec.py └── shanghaitech_test.py ├── trainers ├── __init__.py ├── trainer_shanghaitech.py ├── train_cifar10.py └── trainer_mvtec.py ├── .gitignore ├── images ├── tb1.png ├── tb2.png ├── tb3.png └── mocca.png ├── requirements.txt ├── README.md ├── test_multiple_models_cifar10.py ├── main_cifar10.py ├── utils.py ├── main_shanghaitech.py └── main_mvtec.py /models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /trainers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | bin/* 2 | *.sh 3 | output 4 | *.pyc 5 | *.log 6 | *__pycache__* -------------------------------------------------------------------------------- /images/tb1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fvmassoli/mocca-anomaly-detection/HEAD/images/tb1.png -------------------------------------------------------------------------------- /images/tb2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fvmassoli/mocca-anomaly-detection/HEAD/images/tb2.png -------------------------------------------------------------------------------- /images/tb3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fvmassoli/mocca-anomaly-detection/HEAD/images/tb3.png -------------------------------------------------------------------------------- /images/mocca.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fvmassoli/mocca-anomaly-detection/HEAD/images/mocca.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.19.5 2 | opencv-python==4.5.1.48 3 | pandas==1.0.4 4 | Pillow==7.2.0 5 | pudb==2019.2 6 | python==3.6.9 7 | scikit-build==0.11.1 8 | scikit-image==0.14.2 9 | scikit-learn==0.23.2 10 | sklearn==0.0 11 | tensorboard==1.14.0 12 | tensorboardX==2.0 13 | torch==1.4.0 14 | torchvision==0.4.2 15 | tqdm==4.56.0 16 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MOCCA: Multi-Layer One-Class ClassificAtion for Anomaly Detection 2 | 3 | 4 |

5 | <img src="images/mocca.png" alt="MOCCA"> 6 |

7 | 8 | 9 | 10 | This repository contains the code for the paper "[MOCCA: Multi-Layer One-Class ClassificAtion for Anomaly Detection](https://ieeexplore.ieee.org/document/9640579)" by Fabio Valerio Massoli (ISTI - CNR), Fabrizio Falchi (ISTI - CNR), Alperen Kantarci (ITU), Şeymanur Akti (ITU), Hazim Kemal Ekenel (ITU), Giuseppe Amato (ISTI - CNR). 11 | 12 | The paper presents a new anomaly detection technique based on a layer-wise paradigm that exploits the feature maps generated at different depths of a Deep Learning model. 13 | 14 | The paper has been accepted for publication in the [IEEE Transactions on Neural Networks and Learning Systems, Special Issue on Deep Learning for Anomaly Detection](https://ieeexplore.ieee.org/document/9640579). 15 | 16 | DOI: [10.1109/TNNLS.2021.3130074](https://doi.org/10.1109/TNNLS.2021.3130074). 17 | 18 | **Please note:** 19 | We are researchers, not a software company, and have no personnel devoted to documenting and maintaining this research code. Therefore, this code is offered "AS IS". Exact reproduction of the numbers in the paper depends on exact reproduction of many factors, including the version of all software dependencies and the choice of underlying hardware (GPU model, etc.). You should therefore expect to need to re-tune your hyperparameters slightly for your new setup. 20 | 21 | 22 | ## How to run the code 23 | 24 | Before running the code, make sure that your system has the required packages installed. You can have a look at the [requirements.txt](https://github.com/fvmassoli/mocca-anomaly-detection/blob/main/requirements.txt) file. 25 | 26 | 27 | Minimal usage (CIFAR10): 28 | 29 | ``` 30 | python main_cifar10.py -ptr -tr -tt -zl 128 -nc <normal-class-index> -dp <path-to-data> 31 | ``` 32 | 33 | Minimal usage (MVTec): 34 | 35 | ``` 36 | python main_mvtec.py -ptr -tr -tt -zl 128 -nc <normal-class> -dp <path-to-data> --use-selector 37 | ``` 38 | 39 | Minimal usage (ShanghaiTech): 40 | 41 | ``` 42 | python main_shanghaitech.py -dp <path-to-data> -ee -tt -zl 1024 -ll -use 43 | ``` 44 | 45 | ## Reference 46 | For all the details about the training procedure and the experimental results, please have a look at the [paper](https://arxiv.org/abs/2012.12111). 47 | 48 | To cite our work, please use the following form: 49 | 50 | ``` 51 | @article{massoli2021mocca, 52 | title={MOCCA: Multilayer One-Class Classification for Anomaly Detection}, 53 | author={Massoli, Fabio Valerio and Falchi, Fabrizio and Kantarci, Alperen and Akti, {\c{S}}eymanur and Ekenel, Hazim Kemal and Amato, Giuseppe}, 54 | journal={IEEE Transactions on Neural Networks and Learning Systems}, 55 | year={2021}, 56 | publisher={IEEE} 57 | } 58 | ``` 59 | 60 | ## Contacts 61 | If you have any questions about our work, please contact [Dr. Fabio Valerio Massoli](mailto:fabio.massoli@isti.cnr.it). 62 | 63 | Have fun! :-D
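For a programmatic entry point, here is a minimal sketch — assuming the repository root is on `PYTHONPATH`; the path and class index are illustrative placeholders — of how the `DataManager` (defined in `datasets/data_manager.py` below) hands out the dataset-specific train/test loaders used by the main scripts:

```
from datasets.data_manager import DataManager

# dataset_name is one of 'cifar10', 'ShanghaiTech', 'MVTec_Anomaly';
# data_path must point to an existing dataset root on disk
data_manager = DataManager(dataset_name='cifar10', data_path='./data', normal_class=5)
data_holder = data_manager.get_data_holder()
train_loader, test_loader = data_holder.get_loaders(batch_size=200, shuffle_train=True)
```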
64 | -------------------------------------------------------------------------------- /datasets/data_manager.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import logging 4 | 5 | from .mvtec import MVTec_DataHolder 6 | from .cifar10 import CIFAR10_DataHolder 7 | from .shanghaitech import ShanghaiTech_DataHolder 8 | 9 | 10 | AVAILABLE_DATASETS = ('cifar10', 'ShanghaiTech', 'MVTec_Anomaly') 11 | 12 | 13 | class DataManager(object): 14 | """Class to manage and load data 15 | 16 | """ 17 | def __init__(self, dataset_name: str, data_path: str, normal_class: int, clip_length: int=16, only_test: bool=False): 18 | """Init the DataManager 19 | 20 | Parameters 21 | ---------- 22 | dataset_name : str 23 | Name of the dataset 24 | data_path : str 25 | Path to the dataset 26 | normal_class : int 27 | Index of the normal class (for MVTec, the category name) 28 | clip_length: int 29 | Number of video frames in each clip (ShanghaiTech only) 30 | only_test : bool 31 | True if we are in test mode, False otherwise 32 | 33 | """ 34 | self.dataset_name = dataset_name 35 | self.data_path = data_path 36 | self.normal_class = normal_class 37 | self.clip_length = clip_length 38 | self.only_test = only_test 39 | 40 | # Immediately check if the data are available 41 | self.__check_dataset() 42 | 43 | def __check_dataset(self) -> None: 44 | """Checks if the required dataset is available 45 | 46 | """ 47 | assert self.dataset_name in AVAILABLE_DATASETS, f"{self.dataset_name} dataset is not available" 48 | assert os.path.exists(self.data_path), f"{self.dataset_name} dataset is available but not found at: \n{self.data_path}" 49 | 50 | def get_data_holder(self): 51 | """Returns the data holder for the required dataset 52 | 53 | Returns 54 | ------- 55 | data_holder : CIFAR10_DataHolder, ShanghaiTech_DataHolder, or MVTec_DataHolder 56 | Data holder for the requested dataset 57 | 58 | """ 59 | if self.dataset_name == 'cifar10': 60 | return CIFAR10_DataHolder(root=self.data_path, normal_class=self.normal_class) 61 | 62 | if self.dataset_name == 'ShanghaiTech': 63 | return ShanghaiTech_DataHolder(root=self.data_path, clip_length=self.clip_length) 64 | 65 | if self.dataset_name == 'MVTec_Anomaly': 66 | texture_classes = tuple(["carpet", "grid", "leather", "tile", "wood"]) 67 | object_classes = tuple(["bottle", "hazelnut", "metal_nut", "screw"]) 68 | # object_classes2 = tuple(["capsule", "toothbrush", "cable", "pill", "transistor", "zipper"]) 69 | 70 | # check if the selected class is texture-type 71 | is_texture = self.normal_class in texture_classes 72 | if is_texture: 73 | image_size = 512 74 | patch_size = 64 75 | rotation_range = (0, 45) 76 | else: 77 | patch_size = 1 78 | image_size = 128 79 | # For the object-type classes not listed in object_classes, rotated samples are 80 | # themselves anomalies, thus we must not apply rotations as data augmentation 81 | rotation_range = (-45, 45) if self.normal_class in object_classes else (0, 0) 82 | 83 | return MVTec_DataHolder( 84 | data_path=self.data_path, 85 | category=self.normal_class, 86 | image_size=image_size, 87 | patch_size=patch_size, 88 | rotation_range=rotation_range, 89 | is_texture=is_texture 90 | ) 91 | -------------------------------------------------------------------------------- /test_multiple_models_cifar10.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import random 4 | import logging 5 | import argparse 6 | import numpy as np 7 | from tqdm import tqdm 8 | 9 | import torch 10 | 11 | from trainer_svdd import test 12
| from datasets.main import load_dataset 13 | from models.deep_svdd.deep_svdd_mnist import MNIST_LeNet_Autoencoder, MNIST_LeNet 14 | from models.deep_svdd.deep_svdd_cifar10 import CIFAR10_LeNet_Autoencoder, CIFAR10_LeNet 15 | 16 | 17 | parser = argparse.ArgumentParser('AD') 18 | ## General config 19 | parser.add_argument('--n_jobs_dataloader', type=int, default=0, help='Number of workers for data loading. 0 means that the data will be loaded in the main process.') 20 | ## Model config 21 | parser.add_argument('-zl', '--code-length', default=32, type=int, help='Code length (default: 32)') 22 | parser.add_argument('-ck', '--model-ckp', help='Model checkpoint') 23 | ## Data 24 | parser.add_argument('-ax', '--aux-data-filename', default='/media/fabiovalerio/datasets/ti_500K_pseudo_labeled.pickle', help='Path to unlabelled data') 25 | parser.add_argument('-dn', '--dataset-name', choices=('mnist', 'cifar10'), default='mnist', help='Dataset (default: mnist)') 26 | parser.add_argument('-ul', '--unlabelled-data', action="store_true", help='Use unlabelled data (default: False)') 27 | parser.add_argument('-aux', '--unl-data-path', default="/media/fabiovalerio/datasets/ti_500K_pseudo_labeled.pickle", help='Path to unlabelled data') 28 | ## Training config 29 | parser.add_argument('-bs', '--batch-size', type=int, default=200, help='Batch size (default: 200)') 30 | parser.add_argument('-bd', '--boundary', choices=("hard", "soft"), default="soft", help='Boundary (default: soft)') 31 | parser.add_argument('-ile', '--idx-list-enc', type=int, nargs='+', default=[], help='List of indices of model encoder') 32 | args = parser.parse_args() 33 | 34 | 35 | # Get data base path 36 | _user = os.environ['USER'] 37 | if _user == 'fabiovalerio': 38 | data_path = '/media/fabiovalerio/datasets' 39 | elif _user == 'fabiom': 40 | data_path = '/mnt/datone/datasets/' 41 | else: 42 | raise NotImplementedError('Username %s not configured' % _user) 43 | 44 | 45 | def main(): 46 | cuda_available = torch.cuda.is_available() 47 | device = torch.device('cuda' if cuda_available else 'cpu') 48 | 49 | boundary = args.model_ckp.split('/')[-1].split('-')[-3].split('_')[-1] 50 | normal_class = int(args.model_ckp.split('/')[-1].split('-')[2].split('_')[-1]) 51 | if len(args.idx_list_enc) == 0: 52 | idx_list_enc = [int(i) for i in args.model_ckp.split('/')[-1].split('-')[-1].split('_')[-1].split('.')] 53 | else: 54 | idx_list_enc = args.idx_list_enc 55 | 56 | # LOAD DATA 57 | dataset = load_dataset(args.dataset_name, data_path, normal_class, args.unlabelled_data, args.unl_data_path) 58 | 59 | print( 60 | f"Start test with params" 61 | f"\n\t\t\t\tCode length : {args.code_length}" 62 | f"\n\t\t\t\tEnc layer list : {idx_list_enc}" 63 | f"\n\t\t\t\tBoundary : {boundary}" 64 | f"\n\t\t\t\tNormal class : {normal_class}" 65 | ) 66 | 67 | test_auc = [] 68 | main_model_ckp_dir = args.model_ckp 69 | for m_ckp in tqdm(os.listdir(main_model_ckp_dir), total=len(os.listdir(main_model_ckp_dir)), leave=False): 70 | net_checkpoint = os.path.join(main_model_ckp_dir, m_ckp) 71 | 72 | # Load model 73 | net = MNIST_LeNet(args.code_length) if args.dataset_name == 'mnist' else CIFAR10_LeNet(args.code_length) 74 | st_dict = torch.load(net_checkpoint) 75 | net.load_state_dict(st_dict['net_state_dict']) 76 | 77 | # TEST 78 | test_auc_ = test(net, dataset, st_dict['R'], st_dict['c'], device, idx_list_enc, boundary, args) 79 | del net, st_dict 80 | 81 | test_auc.append(test_auc_) 82 | 83 | test_auc = np.asarray(test_auc) 84 | test_auc_m, test_auc_s =
test_auc.mean(), test_auc.std() 85 | print("[") 86 | for tau in test_auc: 87 | print(tau, ", ") 88 | print("]") 89 | print(test_auc) 90 | print(rf"{test_auc_m:.2f} $\pm$ {test_auc_s:.2f}") 91 | 92 | 93 | if __name__ == '__main__': 94 | main() 95 | -------------------------------------------------------------------------------- /datasets/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta 2 | from abc import abstractmethod 3 | from scipy.ndimage.morphology import binary_dilation # used by RemoveBackground at the bottom of this file 4 | import torch 5 | import numpy as np 6 | from torch.utils.data import Dataset 7 | 8 | 9 | class DatasetBase(Dataset): 10 | """ 11 | Base class for all datasets. 12 | """ 13 | __metaclass__ = ABCMeta 14 | 15 | @abstractmethod 16 | def test(self, *args): 17 | """ 18 | Sets the dataset in test mode. 19 | """ 20 | pass 21 | 22 | @property 23 | @abstractmethod 24 | def shape(self): 25 | """ 26 | Returns the shape of examples. 27 | """ 28 | pass 29 | 30 | @abstractmethod 31 | def __len__(self): 32 | """ 33 | Returns the number of examples. 34 | """ 35 | pass 36 | 37 | @abstractmethod 38 | def __getitem__(self, i): 39 | """ 40 | Provides the i-th example. 41 | """ 42 | pass 43 | 44 | 45 | class OneClassDataset(DatasetBase): 46 | """ 47 | Base class for all one-class classification datasets. 48 | """ 49 | __metaclass__ = ABCMeta 50 | 51 | @abstractmethod 52 | def val(self, *args): 53 | """ 54 | Sets the dataset in validation mode. 55 | """ 56 | pass 57 | 58 | @property 59 | @abstractmethod 60 | def test_classes(self): 61 | """ 62 | Returns all possible test classes. 63 | """ 64 | pass 65 | 66 | 67 | class VideoAnomalyDetectionDataset(DatasetBase): 68 | """ 69 | Base class for all video anomaly detection datasets. 70 | """ 71 | __metaclass__ = ABCMeta 72 | 73 | @property 74 | @abstractmethod 75 | def test_videos(self): 76 | """ 77 | Returns all test video ids. 78 | """ 79 | pass 80 | 81 | 82 | @abstractmethod 83 | def __len__(self): 84 | """ 85 | Returns the number of examples. 86 | """ 87 | pass 88 | 89 | @property 90 | def raw_shape(self): 91 | """ 92 | Workaround! 93 | """ 94 | return self.shape 95 | 96 | @abstractmethod 97 | def __getitem__(self, i): 98 | """ 99 | Provides the i-th example. 100 | """ 101 | pass 102 | 103 | @abstractmethod 104 | def load_test_sequence_gt(self, video_id): 105 | # type: (str) -> np.ndarray 106 | """ 107 | Loads the groundtruth of a test video in memory. 108 | :param video_id: the id of the test video for which the groundtruth has to be loaded. 109 | :return: the groundtruth of the video in a np.ndarray, with shape (n_frames,). 110 | """ 111 | pass 112 | 113 | @property 114 | @abstractmethod 115 | def collate_fn(self): 116 | """ 117 | Returns a function that decides how to merge a list of examples in a batch. 118 | """ 119 | pass 120 | 121 | 122 | class ToFloatTensor3D(object): 123 | """ Convert videos to FloatTensors """ 124 | def __init__(self, normalize=True): 125 | self._normalize = normalize 126 | 127 | def __call__(self, sample): 128 | if len(sample) == 3: 129 | X, Y, _ = sample 130 | else: 131 | X = sample 132 | 133 | # swap color axis because 134 | # numpy image: T x H x W x C 135 | X = X.transpose(3, 0, 1, 2) 136 | #Y = Y.transpose(3, 0, 1, 2) 137 | 138 | if self._normalize: 139 | X = X / 255.
140 | 141 | X = np.float32(X) 142 | return torch.from_numpy(X) 143 | 144 | class ToFloatTensor3DMask(object): 145 | """ Convert videos to FloatTensors """ 146 | def __init__(self, normalize=True, has_x_mask=True, has_y_mask=True): 147 | self._normalize = normalize 148 | self.has_x_mask = has_x_mask 149 | self.has_y_mask = has_y_mask 150 | 151 | def __call__(self, sample): 152 | X = sample 153 | # swap color axis because 154 | # numpy image: T x H x W x C 155 | X = X.transpose(3, 0, 1, 2) 156 | 157 | X = np.float32(X) 158 | 159 | if self._normalize: 160 | if self.has_x_mask: 161 | X[:-1] = X[:-1] / 255. 162 | else: 163 | X = X / 255. 164 | 165 | return torch.from_numpy(X) 166 | 167 | 168 | class RemoveBackground: 169 | 170 | def __init__(self, threshold: float): 171 | self.threshold = threshold 172 | 173 | def __call__(self, sample: tuple) -> tuple: 174 | X, Y, background = sample 175 | 176 | mask = np.uint8(np.sum(np.abs(np.int32(X) - background), axis=-1) > self.threshold) 177 | mask = np.expand_dims(mask, axis=-1) 178 | 179 | mask = np.stack([binary_dilation(mask_frame, iterations=5) for mask_frame in mask]) 180 | 181 | X *= mask 182 | 183 | return X, Y, background -------------------------------------------------------------------------------- /models/cifar10_model.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import numpy as np 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | def init_conv(out_channels: int, k_size: int = 5) -> nn.Module: 10 | """ Init convolutional layers. 11 | 12 | Parameters 13 | ---------- 14 | k_size : int 15 | Kernel size 16 | out_channels : int 17 | Output features size 18 | 19 | Returns 20 | ------- 21 | nn.Module : 22 | Conv2d layer 23 | 24 | """ 25 | l = nn.Conv2d( 26 | in_channels=3 if out_channels==32 else out_channels//2, 27 | out_channels=out_channels, 28 | kernel_size=k_size, 29 | bias=False, 30 | padding=2 31 | ) 32 | nn.init.xavier_uniform_(l.weight, gain=nn.init.calculate_gain('leaky_relu')) 33 | return l 34 | 35 | 36 | def init_deconv(out_channels: int, k_size: int = 5) -> nn.Module: 37 | """ Init deconv layers. 38 | 39 | Parameters 40 | ---------- 41 | k_size : int 42 | Kernel size 43 | out_channels : int 44 | Input features size 45 | 46 | Returns 47 | ------- 48 | nn.Module : 49 | ConvTranspose2d layer 50 | 51 | """ 52 | l = nn.ConvTranspose2d( 53 | in_channels=out_channels, 54 | out_channels=3 if out_channels==32 else out_channels//2, 55 | kernel_size=k_size, 56 | bias=False, 57 | padding=2 58 | ) 59 | nn.init.xavier_uniform_(l.weight, gain=nn.init.calculate_gain('leaky_relu')) 60 | return l 61 | 62 | def init_bn(num_features: int) -> nn.Module: 63 | """ Init BatchNorm layers. 64 | 65 | Parameters 66 | ---------- 67 | num_features : int 68 | Number of input features 69 | 70 | """ 71 | return nn.BatchNorm2d(num_features=num_features, eps=1e-04, affine=False) 72 | 73 | 74 | class BaseNet(nn.Module): 75 | """Base class for all neural networks. 76 | 77 | """ 78 | def __init__(self): 79 | super(BaseNet, self).__init__() 80 | 81 | # init Logger to print model infos 82 | self.logger = logging.getLogger(self.__class__.__name__) 83 | 84 | # List of input/output features depths for the convolutional layers 85 | self.output_features_sizes = [32, 64, 128] 86 | 87 | def summary(self) -> None: 88 | """Network summary. 
89 | 90 | """ 91 | net_parameters = filter(lambda p: p.requires_grad, self.parameters()) 92 | params = sum([np.prod(p.size()) for p in net_parameters]) 93 | self.logger.info('Trainable parameters: {}'.format(params)) 94 | self.logger.info(self) 95 | 96 | 97 | class CIFAR10_Encoder(BaseNet): 98 | """Encoder network. 99 | 100 | """ 101 | def __init__(self, code_length: int): 102 | """Init encoder. 103 | 104 | Parameters 105 | ---------- 106 | code_length : int 107 | Latent code size 108 | 109 | """ 110 | super(CIFAR10_Encoder, self).__init__() 111 | 112 | # Init Conv layers 113 | self.conv1, self.conv2, self.conv3 = [init_conv(out_channels) for out_channels in self.output_features_sizes] 114 | 115 | # Init BN layers 116 | self.bnd1, self.bnd2, self.bnd3 = [init_bn(num_features) for num_features in self.output_features_sizes] 117 | 118 | # Init all other layers 119 | self.pool = nn.MaxPool2d(kernel_size=2, stride=2) 120 | self.fc1 = nn.Linear(in_features=128 * 4 * 4, out_features=code_length, bias=False) 121 | 122 | def forward(self, x: torch.Tensor) -> torch.Tensor: 123 | x = self.conv1(x) 124 | x = self.pool(F.leaky_relu(self.bnd1(x))) 125 | x = self.conv2(x) 126 | x = self.pool(F.leaky_relu(self.bnd2(x))) 127 | x = self.conv3(x) 128 | x = self.pool(F.leaky_relu(self.bnd3(x))) 129 | x = self.fc1(x.view(x.size(0), -1)) 130 | return x 131 | 132 | 133 | class CIFAR10_Decoder(BaseNet): 134 | """Decoder network. 135 | 136 | """ 137 | def __init__(self, code_length: int): 138 | """Init decoder. 139 | 140 | Parameters 141 | ---------- 142 | code_length : int 143 | Latent code size 144 | 145 | """ 146 | super(CIFAR10_Decoder, self).__init__() 147 | 148 | self.rep_dim = code_length 149 | 150 | self.bn1d = nn.BatchNorm1d(self.rep_dim, eps=1e-04, affine=False) 151 | 152 | # Build the Decoder 153 | self.deconv1 = nn.ConvTranspose2d(int(self.rep_dim / (4 * 4)), 128, 5, bias=False, padding=2) 154 | self.deconv2, self.deconv3, self.deconv4 = [init_deconv(out_channels) for out_channels in self.output_features_sizes[::-1]] 155 | 156 | # Init BN layers 157 | self.bnd4, self.bnd5, self.bnd6 = [init_bn(num_features) for num_features in self.output_features_sizes[::-1]] 158 | 159 | def forward(self, x: torch.Tensor) -> torch.Tensor: 160 | x = self.bn1d(x) 161 | x = x.view(x.size(0), int(self.rep_dim / (4 * 4)), 4, 4) 162 | x = F.leaky_relu(x) 163 | x = self.deconv1(x) 164 | x = F.interpolate(F.leaky_relu(self.bnd4(x)), scale_factor=2) 165 | x = self.deconv2(x) 166 | x = F.interpolate(F.leaky_relu(self.bnd5(x)), scale_factor=2) 167 | x = self.deconv3(x) 168 | x = F.interpolate(F.leaky_relu(self.bnd6(x)), scale_factor=2) 169 | x = self.deconv4(x) 170 | x = torch.sigmoid(x) 171 | return x 172 | 173 | 174 | class CIFAR10_Autoencoder(BaseNet): 175 | """Full AutoEncoder network.
176 | 177 | """ 178 | def __init__(self, code_length: int = 128): 179 | """Init the AutoEncoder 180 | 181 | Parameters 182 | ---------- 183 | code_length : int 184 | Latent code size 185 | 186 | """ 187 | super().__init__() 188 | 189 | # Build the Encoder 190 | self.encoder = CIFAR10_Encoder(code_length=code_length) 191 | self.bn1d = nn.BatchNorm1d(num_features=code_length, eps=1e-04, affine=False) 192 | 193 | # Build the Decoder 194 | self.decoder = CIFAR10_Decoder(code_length=code_length) 195 | 196 | def forward(self, x: torch.Tensor) -> torch.Tensor: 197 | z = self.encoder(x) 198 | z = self.bn1d(z) 199 | return self.decoder(z) 200 | -------------------------------------------------------------------------------- /datasets/cifar10.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | 4 | import torch 5 | from torch.utils.data import DataLoader, Subset 6 | 7 | from torchvision.datasets import CIFAR10 8 | import torchvision.transforms as transforms 9 | 10 | 11 | def get_target_label_idx(labels: np.array, targets: np.array): 12 | """Get the indices of labels that are included in targets. 13 | 14 | Parameters 15 | ---------- 16 | labels : np.array 17 | Array of labels 18 | targets : np.array 19 | Array of target labels 20 | 21 | Returns 22 | ------ 23 | List with indices of target labels 24 | 25 | """ 26 | return np.argwhere(np.isin(labels, targets)).flatten().tolist() 27 | 28 | 29 | def global_contrast_normalization(x: torch.tensor, scale: str='l1') -> torch.Tensor: 30 | """Apply global contrast normalization to tensor, i.e. subtract mean across features (pixels) and normalize by scale, 31 | which is either the standard deviation, L1- or L2-norm across features (pixels). 32 | Note this is a *per sample* normalization globally across features (and not across the dataset). 33 | 34 | Parameters 35 | ---------- 36 | x : torch.tensor 37 | Data sample 38 | scale : str 39 | Scale 40 | 41 | Returns 42 | ------ 43 | Normalized features 44 | 45 | """ 46 | assert scale in ('l1', 'l2') 47 | 48 | n_features = int(np.prod(x.shape)) 49 | 50 | # Evaluate the mean over all features (pixels) per sample 51 | mean = torch.mean(x) 52 | x -= mean 53 | 54 | x_scale = torch.mean(torch.abs(x)) if scale == 'l1' else torch.sqrt(torch.sum(x ** 2)) / n_features 55 | 56 | return x / x_scale 57 | 58 | 59 | class CIFAR10_DataHolder(object): 60 | """CIFAR10 data holder class 61 | 62 | """ 63 | def __init__(self, root: str, normal_class=5): 64 | """Init CIFAR10 data holder class 65 | 66 | Parameters 67 | ---------- 68 | root : str 69 | Path to root folder of the data 70 | normal_class : 71 | Index of the normal class 72 | 73 | """ 74 | self.root = root 75 | 76 | # Total number of classes = 2, i.e., 0: normal, 1: anomalies 77 | self.n_classes = 2 78 | 79 | # Tuple containing the normal classes 80 | self.normal_classes = tuple([normal_class]) 81 | 82 | # List of the anomalous classes 83 | self.anomaly_classes = list(range(0, 10)) 84 | self.anomaly_classes.remove(normal_class) 85 | 86 | # Init the datasets 87 | self.__init_train_test_datasets(normal_class) 88 | 89 | def __init_train_test_datasets(self, normal_class: int) -> None: 90 | """Init the datasets. 
91 | 92 | Parameters 93 | ---------- 94 | normal_class : int 95 | The index of the non-anomalous class 96 | 97 | """ 98 | # Pre-computed min and max values (after applying GCN) from train data per class 99 | min_max = [(-28.94083453598571, 13.802961825439636), 100 | (-6.681770233365245, 9.158067708230273), 101 | (-34.924463588638204, 14.419298165027628), 102 | (-10.599172931391799, 11.093187820377565), 103 | (-11.945022995801637, 10.628045447867583), 104 | (-9.691969487694928, 8.948326776180823), 105 | (-9.174940012342555, 13.847014686472365), 106 | (-6.876682005899029, 12.282371383343161), 107 | (-15.603507135507172, 15.2464923804279), 108 | (-6.132882973622672, 8.046098172351265)] 109 | 110 | # Define CIFAR-10 preprocessing operations 111 | # 1. GCN with L1 norm 112 | # 2. min-max feature scaling to [0,1] 113 | self.transform = transforms.Compose([ 114 | transforms.ToTensor(), 115 | transforms.Lambda(lambda x: global_contrast_normalization(x, scale='l1')), 116 | transforms.Normalize( 117 | [min_max[normal_class][0]] * 3, 118 | [min_max[normal_class][1] - min_max[normal_class][0]] * 3 119 | ) 120 | ]) 121 | 122 | # Define CIFAR-10 preprocessing operations on the labels, 123 | # i.e., set to 1 all the labels that belong to the anomalous classes (0 = normal) 124 | self.target_transform = transforms.Lambda(lambda x: int(x in self.anomaly_classes)) 125 | 126 | # Init training set 127 | self.train_set = MyCIFAR10( 128 | root=self.root, 129 | train=True, 130 | download=True, 131 | transform=self.transform, 132 | target_transform=self.target_transform 133 | ) 134 | 135 | # Subset the training set by considering normal class images only 136 | train_idx_normal = get_target_label_idx(labels=self.train_set.targets, targets=self.normal_classes) 137 | self.train_set = Subset(self.train_set, train_idx_normal) 138 | 139 | # Init test set 140 | self.test_set = MyCIFAR10( 141 | root=self.root, 142 | train=False, 143 | download=True, 144 | transform=self.transform, 145 | target_transform=self.target_transform 146 | ) 147 | 148 | def get_loaders(self, batch_size: int, shuffle_train: bool=True, pin_memory: bool=False, num_workers: int = 0) -> [torch.utils.data.DataLoader, torch.utils.data.DataLoader]: 149 | """Returns CIFAR10 dataloaders 150 | 151 | Parameters 152 | ---------- 153 | batch_size : int 154 | Size of the batches 155 | shuffle_train : bool 156 | If True, shuffles the training dataset 157 | pin_memory : bool 158 | If True, pin memory 159 | num_workers : int 160 | Number of dataloader workers 161 | 162 | Returns 163 | ------- 164 | loaders : DataLoader 165 | Train and test data loaders 166 | 167 | """ 168 | train_loader = DataLoader( 169 | dataset=self.train_set, 170 | batch_size=batch_size, 171 | shuffle=shuffle_train, 172 | pin_memory=pin_memory, 173 | num_workers=num_workers 174 | ) 175 | test_loader = DataLoader( 176 | dataset=self.test_set, 177 | batch_size=batch_size, 178 | pin_memory=pin_memory, 179 | num_workers=num_workers 180 | ) 181 | return train_loader, test_loader 182 | 183 | 184 | class MyCIFAR10(CIFAR10): 185 | """Torchvision CIFAR10 class with a patched __getitem__ method that also returns the index of the data sample. 186 | 187 | """ 188 | def __init__(self, *args, **kwargs): 189 | super(MyCIFAR10, self).__init__(*args, **kwargs) 190 | 191 | def __getitem__(self, index): 192 | """Override the original method of the CIFAR10 class.
193 | 194 | Parameters 195 | ---------- 196 | index : int 197 | Index 198 | 199 | Returns 200 | ------- 201 | triple: (image, target, index) where target is the index of the target class. 202 | 203 | """ 204 | img, target = self.data[index], self.targets[index] 205 | 206 | img = Image.fromarray(img) 207 | 208 | if self.transform is not None: 209 | img = self.transform(img) 210 | 211 | if self.target_transform is not None: 212 | target = self.target_transform(target) 213 | 214 | return img, target, index 215 | -------------------------------------------------------------------------------- /datasets/shanghaitech.py: -------------------------------------------------------------------------------- 1 | from glob import glob 2 | from tqdm import tqdm 3 | from time import time 4 | from typing import List, Tuple 5 | from os.path import basename, isdir, join, splitext 6 | 7 | import cv2 8 | import numpy as np 9 | import skimage.io as io 10 | 11 | import torch 12 | from torchvision import transforms 13 | from skimage.transform import resize 14 | from torch.utils.data import Dataset, DataLoader 15 | from torch.utils.data.dataloader import default_collate 16 | from .shanghaitech_test import ShanghaiTechTestHandler 17 | 18 | class ShanghaiTech_DataHolder(object): 19 | """ 20 | ShanghaiTech data holder class 21 | 22 | Parameters 23 | ---------- 24 | root : str 25 | root folder of ShanghaiTech dataset 26 | clip_length : int 27 | number of frames that form a clip 28 | stride : int 29 | size of the sliding window used when creating clips 30 | """ 31 | def __init__(self, root: str, clip_length=16, stride=1): 32 | self.root = root 33 | self.clip_length = clip_length 34 | self.stride = stride 35 | self.shape = (3, clip_length, 256, 512) 36 | self.train_dir = join(root, 'training', 'nobackground_frames_resized') 37 | # Transform 38 | self.transform = transforms.Compose([ToFloatTensor3D(normalize=True)]) 39 | 40 | 41 | def get_test_data(self) -> Dataset: 42 | """Load test dataset 43 | 44 | Returns 45 | ------- 46 | ShanghaiTech : Dataset 47 | Custom dataset to handle ShanghaiTech data 48 | 49 | """ 50 | return ShanghaiTechTestHandler(self.root) 51 | 52 | def get_train_data(self, return_dataset: bool=True): 53 | """Load train dataset 54 | 55 | Parameters 56 | ---------- 57 | return_dataset : bool 58 | False for preprocessing purpose only 59 | """ 60 | 61 | if return_dataset: 62 | # Load all ids 63 | self.train_ids = self.load_train_ids() 64 | # Create clips with given clip_length and stride 65 | self.train_clips = self.create_clips(self.train_dir, self.train_ids, clip_length=self.clip_length, stride=self.stride, read_target=False) 66 | 67 | return MySHANGHAI(self.train_clips, self.transform, clip_length=self.clip_length) 68 | else: 69 | return 70 | 71 | def get_loaders(self, batch_size: int, shuffle_train: bool=True, pin_memory: bool=False, num_workers: int = 0) -> [DataLoader, DataLoader]: 72 | """Returns ShanghaiTech dataloaders 73 | 74 | Parameters 75 | ---------- 76 | batch_size : int 77 | Size of the batches 78 | shuffle_train : bool 79 | If True, shuffles the training dataset 80 | pin_memory : bool 81 | If True, pin memory 82 | num_workers : int 83 | Number of dataloader workers 84 | 85 | Returns 86 | ------- 87 | loaders : DataLoader 88 | Train and test data loaders 89 | 90 | """ 91 | train_loader = DataLoader( 92 | dataset=self.get_train_data(return_dataset=True), 93 | batch_size=batch_size, 94 | shuffle=shuffle_train, 95 | pin_memory=pin_memory, 96 | num_workers=num_workers 97 | ) 98 | test_loader
= DataLoader( 99 | dataset=self.get_test_data(), 100 | batch_size=batch_size, 101 | pin_memory=pin_memory, 102 | num_workers=num_workers 103 | ) 104 | return train_loader, test_loader 105 | 106 | def load_train_ids(self): 107 | # type: () -> List[str] 108 | """ 109 | Loads the set of all train video ids. 110 | :return: The list of train ids. 111 | """ 112 | return sorted([basename(d) for d in glob(join(self.train_dir, '**')) if isdir(d)]) 113 | 114 | def create_clips(self, dir_path, ids, clip_length=16, stride=1, read_target=False): 115 | # type: (str, int, int, bool) 116 | """ 117 | Gets frame directory and ids of the directories in the frame dir 118 | Creates clips which consist of number of clip_length at each clip. 119 | Clips are created in a sliding window fashion. Default window slide is 1 120 | but stride controls the window slide 121 | Example: for default parameters first clip is [001.jpg, 002.jpg, ...,016.jpg] 122 | second clip would be [002.jpg, 003.jpg, ..., 017.jpg] 123 | If read_target is True then it will try to read from test directory 124 | If read_target is False then it will populate the array with all zeros 125 | :return: clips:: numpy array with (num_clips,clip_length) shape 126 | ground_truths:: numpy array with (num_clips,clip_length) shape 127 | """ 128 | clips = [] 129 | print(f"Creating clips for {dir_path} dataset with length {clip_length}...") 130 | for idx in tqdm(ids): 131 | frames = sorted([x for x in glob(join(dir_path, idx, "*.jpg"))]) 132 | num_frames = len(frames) 133 | # Slide the window with stride to collect clips 134 | for window in range(0, num_frames-clip_length+1, stride): 135 | clips.append(frames[window:window+clip_length]) 136 | return np.array(clips) 137 | 138 | class MySHANGHAI(Dataset): 139 | def __init__(self, clips, transform=None, clip_length=16): 140 | self.clips = clips 141 | self.transform = transform 142 | self.shape = (3, clip_length, 256, 512) 143 | 144 | def __len__(self): 145 | return 10000 # len(self.clips) 146 | 147 | def __getitem__(self, index): 148 | """ 149 | Args: 150 | index (int): Index 151 | Returns: 152 | triple: (image, target, index) where target is index of the target class. 153 | targets are all 0 target 154 | """ 155 | index_ = torch.randint(0, len(self.clips), size=(1,)).item() 156 | sample = np.stack([np.uint8(io.imread(img_path)) for img_path in self.clips[index_]]) 157 | sample = self.transform(sample) if self.transform else sample 158 | return sample, index_ 159 | 160 | from scipy.ndimage.morphology import binary_dilation 161 | 162 | 163 | def get_target_label_idx(labels, targets): 164 | """ 165 | Get the indices of labels that are included in targets. 166 | :param labels: array of labels 167 | :param targets: list/tuple of target labels 168 | :return: list with indices of target labels 169 | """ 170 | return np.argwhere(np.isin(labels, targets)).flatten().tolist() 171 | 172 | 173 | def global_contrast_normalization(x: torch.tensor, scale='l2'): 174 | """ 175 | Apply global contrast normalization to tensor, i.e. subtract mean across features (pixels) and normalize by scale, 176 | which is either the standard deviation, L1- or L2-norm across features (pixels). 177 | Note this is a *per sample* normalization globally across features (and not across the dataset). 
178 | """ 179 | 180 | assert scale in ('l1', 'l2') 181 | 182 | n_features = int(np.prod(x.shape)) 183 | 184 | mean = torch.mean(x) # mean over all features (pixels) per sample 185 | x -= mean 186 | 187 | if scale == 'l1': 188 | x_scale = torch.mean(torch.abs(x)) 189 | 190 | if scale == 'l2': 191 | x_scale = torch.sqrt(torch.sum(x ** 2)) / n_features 192 | 193 | x /= x_scale 194 | 195 | return x 196 | 197 | class ToFloatTensor3D(object): 198 | """ Convert videos to FloatTensors """ 199 | def __init__(self, normalize=True): 200 | self._normalize = normalize 201 | 202 | def __call__(self, sample): 203 | if len(sample) == 3: 204 | X, Y, _ = sample 205 | else: 206 | X = sample 207 | 208 | # swap color axis because 209 | # numpy image: T x H x W x C 210 | X = X.transpose(3, 0, 1, 2) 211 | #Y = Y.transpose(3, 0, 1, 2) 212 | 213 | if self._normalize: 214 | X = X / 255. 215 | 216 | X = np.float32(X) 217 | return torch.from_numpy(X) 218 | 219 | class ToFloatTensor3DMask(object): 220 | """ Convert videos to FloatTensors """ 221 | def __init__(self, normalize=True, has_x_mask=True, has_y_mask=True): 222 | self._normalize = normalize 223 | self.has_x_mask = has_x_mask 224 | self.has_y_mask = has_y_mask 225 | 226 | def __call__(self, sample): 227 | X = sample 228 | # swap color axis because 229 | # numpy image: T x H x W x C 230 | X = X.transpose(3, 0, 1, 2) 231 | 232 | X = np.float32(X) 233 | 234 | if self._normalize: 235 | if self.has_x_mask: 236 | X[:-1] = X[:-1] / 255. 237 | else: 238 | X = X / 255. 239 | 240 | return torch.from_numpy(X) 241 | 242 | 243 | class RemoveBackground: 244 | 245 | def __init__(self, threshold: float): 246 | self.threshold = threshold 247 | 248 | def __call__(self, sample: tuple) -> tuple: 249 | X, Y, background = sample 250 | 251 | mask = np.uint8(np.sum(np.abs(np.int32(X) - background), axis=-1) > self.threshold) 252 | mask = np.expand_dims(mask, axis=-1) 253 | 254 | mask = np.stack([binary_dilation(mask_frame, iterations=5) for mask_frame in mask]) 255 | 256 | X *= mask 257 | 258 | return X, Y, background -------------------------------------------------------------------------------- /models/mvtec_base_model.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from operator import mul 3 | from typing import Optional 4 | from functools import reduce 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torch.nn import Module 9 | 10 | 11 | class BaseModule(nn.Module): 12 | """ 13 | Implements the basic module. 14 | All other modules inherit from this one 15 | """ 16 | def load_w(self, checkpoint_path): 17 | # type: (str) -> None 18 | """ 19 | Loads a checkpoint into the state_dict. 20 | :param checkpoint_path: the checkpoint file to be loaded. 21 | """ 22 | self.load_state_dict(torch.load(checkpoint_path)) 23 | 24 | def __repr__(self): 25 | # type: () -> str 26 | """ 27 | String representation 28 | """ 29 | good_old = super(BaseModule, self).__repr__() 30 | addition = 'Total number of parameters: {:,}'.format(self.n_parameters) 31 | 32 | return good_old + '\n' + addition 33 | 34 | def __call__(self, *args, **kwargs): 35 | return super(BaseModule, self).__call__(*args, **kwargs) 36 | 37 | @property 38 | def n_parameters(self): 39 | # type: () -> int 40 | """ 41 | Number of parameters of the model. 
42 | """ 43 | n_parameters = 0 44 | for p in self.parameters(): 45 | if hasattr(p, 'mask'): 46 | n_parameters += torch.sum(p.mask).item() 47 | else: 48 | n_parameters += reduce(mul, p.shape) 49 | return int(n_parameters) 50 | 51 | 52 | def residual_op(x, functions, bns, activation_fn): 53 | # type: (torch.Tensor, List[Module, Module, Module], List[Module, Module, Module], Module) -> torch.Tensor 54 | """ 55 | Implements a global residual operation. 56 | :param x: the input tensor. 57 | :param functions: a list of functions (nn.Modules). 58 | :param bns: a list of optional batch-norm layers. 59 | :param activation_fn: the activation to be applied. 60 | :return: the output of the residual operation. 61 | """ 62 | f1, f2, f3 = functions 63 | bn1, bn2, bn3 = bns 64 | 65 | assert len(functions) == len(bns) == 3 66 | assert f1 is not None and f2 is not None 67 | assert not (f3 is None and bn3 is not None) 68 | 69 | # A-branch 70 | ha = x 71 | ha = f1(ha) 72 | if bn1 is not None: 73 | ha = bn1(ha) 74 | ha = activation_fn(ha) 75 | 76 | ha = f2(ha) 77 | if bn2 is not None: 78 | ha = bn2(ha) 79 | 80 | # B-branch 81 | hb = x 82 | if f3 is not None: 83 | hb = f3(hb) 84 | if bn3 is not None: 85 | hb = bn3(hb) 86 | 87 | # Residual connection 88 | out = ha + hb 89 | return activation_fn(out) 90 | 91 | 92 | class BaseBlock(BaseModule): 93 | """ Base class for all blocks. """ 94 | def __init__(self, channel_in, channel_out, activation_fn, use_bn=True, use_bias=False): 95 | # type: (int, int, Module, bool, bool) -> None 96 | """ 97 | Class constructor. 98 | :param channel_in: number of input channels. 99 | :param channel_out: number of output channels. 100 | :param activation_fn: activation to be employed. 101 | :param use_bn: whether or not to use batch-norm. 102 | :param use_bias: whether or not to use bias. 103 | """ 104 | super(BaseBlock, self).__init__() 105 | 106 | assert not (use_bn and use_bias), 'Using bias=True with batch_normalization is forbidden.' 107 | 108 | self._channel_in = channel_in 109 | self._channel_out = channel_out 110 | self._activation_fn = activation_fn 111 | self._use_bn = use_bn 112 | self._bias = use_bias 113 | 114 | def get_bn(self): 115 | # type: () -> Optional[Module] 116 | """ 117 | Returns batch norm layers, if needed. 118 | :return: batch norm layers or None 119 | """ 120 | return nn.BatchNorm2d(num_features=self._channel_out) if self._use_bn else None 121 | 122 | def forward(self, x): 123 | """ 124 | Abstract forward function. Not implemented. 125 | """ 126 | raise NotImplementedError 127 | 128 | 129 | class DownsampleBlock(BaseBlock): 130 | """ Implements a Downsampling block for images (Fig. 1ii). """ 131 | def __init__(self, channel_in, channel_out, activation_fn, use_bn=True, use_bias=False): 132 | # type: (int, int, Module, bool, bool) -> None 133 | """ 134 | Class constructor. 135 | :param channel_in: number of input channels. 136 | :param channel_out: number of output channels. 137 | :param activation_fn: activation to be employed. 138 | :param use_bn: whether or not to use batch-norm. 139 | :param use_bias: whether or not to use bias. 
140 | """ 141 | super(DownsampleBlock, self).__init__(channel_in, channel_out, activation_fn, use_bn, use_bias) 142 | 143 | # Convolutions 144 | self.conv1a = nn.Conv2d(in_channels=channel_in, out_channels=channel_out, kernel_size=3, 145 | padding=1, stride=2, bias=use_bias) 146 | self.conv1b = nn.Conv2d(in_channels=channel_out, out_channels=channel_out, kernel_size=3, 147 | padding=1, stride=1, bias=use_bias) 148 | self.conv2a = nn.Conv2d(in_channels=channel_in, out_channels=channel_out, kernel_size=1, 149 | padding=0, stride=2, bias=use_bias) 150 | 151 | # Batch Normalization layers 152 | self.bn1a = self.get_bn() 153 | self.bn1b = self.get_bn() 154 | self.bn2a = self.get_bn() 155 | 156 | def forward(self, x): 157 | # type: (torch.Tensor) -> torch.Tensor 158 | """ 159 | Forward propagation. 160 | :param x: the input tensor 161 | :return: the output tensor 162 | """ 163 | return residual_op( 164 | x, 165 | functions=[self.conv1a, self.conv1b, self.conv2a], 166 | bns=[self.bn1a, self.bn1b, self.bn2a], 167 | activation_fn=self._activation_fn 168 | ) 169 | 170 | 171 | class UpsampleBlock(BaseBlock): 172 | """ Implements a Upsampling block for images (Fig. 1ii). """ 173 | def __init__(self, channel_in, channel_out, activation_fn, use_bn=True, use_bias=False): 174 | # type: (int, int, Module, bool, bool) -> None 175 | """ 176 | Class constructor. 177 | :param channel_in: number of input channels. 178 | :param channel_out: number of output channels. 179 | :param activation_fn: activation to be employed. 180 | :param use_bn: whether or not to use batch-norm. 181 | :param use_bias: whether or not to use bias. 182 | """ 183 | super(UpsampleBlock, self).__init__(channel_in, channel_out, activation_fn, use_bn, use_bias) 184 | 185 | # Convolutions 186 | self.conv1a = nn.ConvTranspose2d(channel_in, channel_out, kernel_size=5, 187 | padding=2, stride=2, output_padding=1, bias=use_bias) 188 | self.conv1b = nn.Conv2d(in_channels=channel_out, out_channels=channel_out, kernel_size=3, 189 | padding=1, stride=1, bias=use_bias) 190 | self.conv2a = nn.ConvTranspose2d(channel_in, channel_out, kernel_size=1, 191 | padding=0, stride=2, output_padding=1, bias=use_bias) 192 | 193 | # Batch Normalization layers 194 | self.bn1a = self.get_bn() 195 | self.bn1b = self.get_bn() 196 | self.bn2a = self.get_bn() 197 | 198 | def forward(self, x): 199 | # type: (torch.Tensor) -> torch.Tensor 200 | """ 201 | Forward propagation. 202 | :param x: the input tensor 203 | :return: the output tensor 204 | """ 205 | return residual_op( 206 | x, 207 | functions=[self.conv1a, self.conv1b, self.conv2a], 208 | bns=[self.bn1a, self.bn1b, self.bn2a], 209 | activation_fn=self._activation_fn 210 | ) 211 | 212 | 213 | class ResidualBlock(BaseBlock): 214 | """ Implements a Residual block for images (Fig. 1ii). """ 215 | def __init__(self, channel_in, channel_out, activation_fn, use_bn=True, use_bias=False): 216 | # type: (int, int, Module, bool, bool) -> None 217 | """ 218 | Class constructor. 219 | :param channel_in: number of input channels. 220 | :param channel_out: number of output channels. 221 | :param activation_fn: activation to be employed. 222 | :param use_bn: whether or not to use batch-norm. 223 | :param use_bias: whether or not to use bias. 
224 | """ 225 | super(ResidualBlock, self).__init__(channel_in, channel_out, activation_fn, use_bn, use_bias) 226 | 227 | # Convolutions 228 | self.conv1 = nn.Conv2d(in_channels=channel_in, out_channels=channel_out, kernel_size=3, 229 | padding=1, stride=1, bias=use_bias) 230 | self.conv2 = nn.Conv2d(in_channels=channel_out, out_channels=channel_out, kernel_size=3, 231 | padding=1, stride=1, bias=use_bias) 232 | 233 | # Batch Normalization layers 234 | self.bn1 = self.get_bn() 235 | self.bn2 = self.get_bn() 236 | 237 | def forward(self, x): 238 | # type: (torch.Tensor) -> torch.Tensor 239 | """ 240 | Forward propagation. 241 | :param x: the input tensor 242 | :return: the output tensor 243 | """ 244 | return residual_op( 245 | x, 246 | functions=[self.conv1, self.conv2, None], 247 | bns=[self.bn1, self.bn2, None], 248 | activation_fn=self._activation_fn 249 | ) -------------------------------------------------------------------------------- /trainers/trainer_shanghaitech.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import logging 4 | import itertools 5 | import numpy as np 6 | from tqdm import tqdm 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from torch.nn import DataParallel 12 | from torch.optim import Adam, SGD 13 | from torch.optim.lr_scheduler import MultiStepLR 14 | from torch.utils.data.dataloader import DataLoader 15 | 16 | from sklearn.metrics import roc_auc_score 17 | 18 | 19 | def pretrain(ae_net, train_loader, out_dir, tb_writer, device, args): 20 | logger = logging.getLogger() 21 | 22 | ae_net = ae_net.train().to(device) 23 | 24 | 25 | # Set optimizer 26 | if args.optimizer == 'adam': 27 | optimizer = Adam(ae_net.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay) 28 | else: 29 | optimizer = SGD(ae_net.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay, momentum=0.9) 30 | 31 | scheduler = MultiStepLR(optimizer, milestones=args.ae_lr_milestones, gamma=0.1) 32 | 33 | ae_epochs = 1 if args.debug else args.ae_epochs 34 | it_t = 0 35 | logger.info("Start Pretraining the autoencoder...") 36 | for epoch in range(ae_epochs): 37 | 38 | recon_loss = 0.0 39 | n_batches = 0 40 | for idx, (data, _) in enumerate(tqdm(train_loader), 1): 41 | if args.debug and idx == 2: break 42 | 43 | data = data.to(device) 44 | optimizer.zero_grad() 45 | x_r = ae_net(data)[0] 46 | recon_loss_ = torch.mean(torch.sum((x_r - data) ** 2, dim=tuple(range(1, x_r.dim())))) 47 | recon_loss_.backward() 48 | optimizer.step() 49 | 50 | recon_loss += recon_loss_.item() 51 | n_batches += 1 52 | 53 | if idx % (len(train_loader)//args.log_frequency) == 0: 54 | logger.info(f"PreTrain at epoch: {epoch+1} [{idx}]/[{len(train_loader)}] ==> Recon Loss: {recon_loss/idx:.4f}") 55 | tb_writer.add_scalar('pretrain/recon_loss', recon_loss/idx, it_t) 56 | it_t += 1 57 | 58 | scheduler.step() 59 | if epoch in args.ae_lr_milestones: 60 | logger.info(' LR scheduler: new learning rate is %g' % float(scheduler.get_lr()[0])) 61 | 62 | ae_net_checkpoint = os.path.join(out_dir, f'ae_ckp_epoch_{epoch}_{time.time()}.pth') 63 | torch.save({'ae_state_dict': ae_net.state_dict()}, ae_net_checkpoint) 64 | 65 | logger.info('Finished pretraining.') 66 | logger.info(f'Saved autoencoder at: {ae_net_checkpoint}') 67 | 68 | return ae_net_checkpoint 69 | 70 | 71 | def train(net, train_loader, out_dir, tb_writer, device, ae_net_checkpoint, args): 72 | logger = logging.getLogger() 73 | 74 | idx_list_enc =
{int(i): 1 for i in args.idx_list_enc} 75 | 76 | # Set device for network 77 | net = net.to(device) 78 | 79 | # Set optimizer 80 | if args.optimizer == 'adam': 81 | optimizer = Adam(net.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay) 82 | else: 83 | optimizer = SGD(net.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay, momentum=0.9) 84 | 85 | # Set learning rate scheduler 86 | scheduler = MultiStepLR(optimizer, milestones=args.lr_milestones, gamma=0.1) 87 | 88 | # Initialize hypersphere center c 89 | logger.info('Evaluating hypersphere centers...') 90 | c, keys = init_center_c(train_loader, net, idx_list_enc, device, args.end_to_end_training, args.debug) 91 | logger.info(f'Keys: {keys}') 92 | logger.info('Done!') 93 | 94 | R = {k: torch.tensor(0.0, device=device) for k in keys} 95 | 96 | # Training 97 | logger.info('Starting training...') 98 | warm_up_n_epochs = args.warm_up_n_epochs 99 | net.train() 100 | it_t = 0 101 | 102 | best_loss = 1e12 103 | epochs = 1 if args.debug else args.epochs 104 | for epoch in range(epochs): 105 | one_class_loss = 0.0 106 | recon_loss = 0.0 107 | objective_loss = 0.0 108 | n_batches = 0 109 | d_from_c = {k: 0 for k in keys} 110 | epoch_start_time = time.time() 111 | 112 | for idx, (data, _) in enumerate(tqdm(train_loader, total=len(train_loader), desc=f"Training epoch: {epoch+1}"), 1): 113 | if args.debug and idx == 2: break 114 | 115 | n_batches += 1 116 | data = data.to(device) 117 | 118 | # Zero the network parameter gradients 119 | optimizer.zero_grad() 120 | 121 | # Update network parameters via backpropagation: forward + backward + optimize 122 | if args.end_to_end_training: 123 | x_r, _, d_lstms = net(data) 124 | recon_loss_ = torch.mean(torch.sum((x_r - data) ** 2, dim=tuple(range(1, x_r.dim())))) 125 | else: 126 | _, d_lstms = net(data) 127 | recon_loss_ = torch.tensor([0.0], device=device) 128 | 129 | dist, one_class_loss_ = eval_ad_loss(d_lstms, c, R, args.nu, args.boundary) 130 | objective_loss_ = one_class_loss_ + recon_loss_ 131 | 132 | for k in keys: 133 | d_from_c[k] += torch.mean(dist[k]).item() 134 | 135 | objective_loss_.backward() 136 | optimizer.step() 137 | 138 | one_class_loss += one_class_loss_.item() 139 | recon_loss += recon_loss_.item() 140 | objective_loss += objective_loss_.item() 141 | 142 | if idx % (len(train_loader)//args.log_frequency) == 0: 143 | logger.info( 144 | f"TRAIN at epoch: {epoch} [{idx}]/[{len(train_loader)}] ==> " 145 | f"\n\t\t\t\tReconstr Loss : {recon_loss/n_batches:.4f}" 146 | f"\n\t\t\t\tOne class Loss: {one_class_loss/n_batches:.4f}" 147 | f"\n\t\t\t\tObjective Loss: {objective_loss/n_batches:.4f}" 148 | ) 149 | tb_writer.add_scalar('train/recon_loss', recon_loss/n_batches, it_t) 150 | tb_writer.add_scalar('train/one_class_loss', one_class_loss/n_batches, it_t) 151 | tb_writer.add_scalar('train/objective_loss', objective_loss/n_batches, it_t) 152 | for k in keys: 153 | logger.info( 154 | f"[{k}] -- Radius: {R[k]:.4f} - " 155 | f"Dist from sphere centr: {d_from_c[k]/n_batches:.4f}" 156 | ) 157 | tb_writer.add_scalar(f'train/radius_{k}', R[k], it_t) 158 | tb_writer.add_scalar(f'train/distance_c_sphere_{k}', d_from_c[k]/n_batches, it_t) 159 | it_t += 1 160 | 161 | # Update hypersphere radius R on mini-batch distances 162 | if (args.boundary == 'soft') and (epoch >= warm_up_n_epochs): 163 | for k in R.keys(): 164 | R[k].data = torch.tensor( 165 | np.quantile(np.sqrt(dist[k].clone().data.cpu().numpy()), 1 - args.nu), 166 | device=device 167 | ) 168 | 169 | 
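# Descriptive note on the radius update above: with the soft boundary, R[k] is
# set to the (1 - args.nu)-quantile of the batch distances sqrt(dist[k]), so
# roughly a fraction nu of the samples falls outside each layer's hypersphere
# (e.g. nu = 0.1 places R at the 90th percentile). Together with the per-layer
# term R^2 + (1/nu) * mean(max(0, dist - R^2)) computed in eval_ad_loss below,
# this implements the soft-boundary Deep SVDD objective at multiple depths.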
scheduler.step() 170 | if epoch in args.lr_milestones: 171 | logger.info(' LR scheduler: new learning rate is %g' % float(scheduler.get_lr()[0])) 172 | 173 | time_ = time.time() if ae_net_checkpoint is None else ae_net_checkpoint.split('_')[-1].split('.p')[0] 174 | net_checkpoint = os.path.join(out_dir, f'net_ckp_{epoch}_{time_}.pth') 175 | torch.save({ 176 | 'net_state_dict': net.state_dict(), 177 | 'R': R, 178 | 'c': c 179 | }, 180 | net_checkpoint 181 | ) 182 | logger.info(f'Saved model at: {net_checkpoint}') 183 | if objective_loss < best_loss or epoch==0: 184 | best_loss = objective_loss 185 | best_model_checkpoint = os.path.join(out_dir, f'net_ckp_best_model_{time_}.pth') 186 | torch.save({ 187 | 'net_state_dict': net.state_dict(), 188 | 'R': R, 189 | 'c': c 190 | }, 191 | best_model_checkpoint 192 | ) 193 | 194 | logger.info('Finished training.') 195 | 196 | return best_model_checkpoint #net_checkpoint 197 | 198 | 199 | @torch.no_grad() 200 | def init_center_c(train_loader, net, idx_list_enc, device, end_to_end_training, debug, eps=0.1): 201 | """Initialize hypersphere center c as the mean from an initial forward pass on the data.""" 202 | n_samples = 0 203 | net.eval() 204 | 205 | data, _ = next(iter(train_loader)) 206 | d_lstms = net(data.to(device))[-1] 207 | 208 | keys = [] 209 | c = {} 210 | for en, k in enumerate(list(d_lstms.keys())): 211 | if en in idx_list_enc: 212 | keys.append(k) 213 | c[k] = torch.zeros_like(d_lstms[k][-1], device=device) 214 | 215 | for idx, (data, _) in enumerate(tqdm(train_loader, desc='init hypersphere centers', total=len(train_loader), leave=False)): 216 | if debug and idx == 2: break 217 | # get the inputs of the batch 218 | n_samples += data.shape[0] 219 | d_lstms = net(data.to(device))[-1] 220 | for k in keys: 221 | c[k] += torch.sum(d_lstms[k], dim=0) 222 | 223 | for k in keys: 224 | c[k] = c[k] / n_samples 225 | # If c_i is too close to 0, set to +-eps. Reason: a zero unit can be trivially matched with zero weights. 226 | c[k][(abs(c[k]) < eps) & (c[k] < 0)] = -eps 227 | c[k][(abs(c[k]) < eps) & (c[k] > 0)] = eps 228 | 229 | return c, keys 230 | 231 | 232 | def eval_ad_loss(d_lstms, c, R, nu, boundary): 233 | dist = {} 234 | loss = 1 235 | 236 | for k in c.keys(): 237 | dist[k] = torch.sum((d_lstms[k] - c[k].unsqueeze(0)) ** 2, dim=-1) 238 | 239 | if boundary == 'soft': 240 | scores = dist[k] - R[k] ** 2 241 | loss += R[k] ** 2 + (1 / nu) * torch.mean(torch.max(torch.zeros_like(scores), scores)) 242 | else: 243 | loss += torch.mean(dist[k]) 244 | 245 | return dist, loss 246 | -------------------------------------------------------------------------------- /datasets/mvtec.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import math 4 | import random 5 | import numpy as np 6 | from tqdm import tqdm 7 | from PIL import Image 8 | from os.path import join 9 | 10 | import torch 11 | import torch.nn as nn 12 | from torch.utils.data import Dataset, DataLoader, TensorDataset 13 | 14 | import torchvision.transforms as T 15 | from torchvision.datasets import ImageFolder 16 | 17 | 18 | class MVtecDataset(ImageFolder): 19 | """Torchvision ImageFolder class with a patched __getitem__ method that adjusts the targets according to the task.
    """
    def __init__(self, root: str, transform):
        super(MVtecDataset, self).__init__(root=root, transform=transform)

        # Index of the class that corresponds to the folder named 'good'
        self.normal_class_idx = self.class_to_idx['good']

    def __getitem__(self, index: int):
        data, target = self.samples[index]

        def read_image(path):
            """Returns the image in RGB

            """
            with open(path, 'rb') as f:
                img = Image.open(f)
                return img.convert('RGB')

        # Convert the target to the 0/1 case
        target = 0 if target == self.normal_class_idx else 1
        data = self.transform(read_image(data))

        return data, target


class CustomTensorDataset(TensorDataset):
    """Custom dataset for preprocessed images.

    """
    def __init__(self, root: str):
        """Init the dataset.

        Parameters
        ----------
        root : str
            Path to data file

        """
        # Load data
        self.data = torch.from_numpy(np.load(root))

        # Load TensorDataset
        super(CustomTensorDataset, self).__init__(self.data)

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, index):
        return self.data[index], 0


class MVTec_DataHolder(object):
    """MVTec data holder class

    """
    def __init__(self, data_path: str, category: str, image_size: int, patch_size: int, rotation_range: tuple, is_texture: bool):
        """Init MVTec data holder class

        Parameters
        ----------
        data_path : str
            Path to the root folder of the MVTec data
        category : str
            Normal class
        image_size : int
            Side size of the input images
        patch_size : int
            Side size of the patches (for textures only)
        rotation_range : tuple
            Min and max angle to rotate images
        is_texture : bool
            True if the category is a texture-type class

        """
        self.data_path = data_path
        self.category = category
        self.image_size = image_size
        self.patch_size = patch_size
        self.rotation_range = rotation_range
        self.is_texture = is_texture

    def get_test_data(self) -> Dataset:
        """Load test dataset

        Returns
        -------
        MVtecDataset : Dataset
            Custom dataset to handle MVTec data

        """
        return MVtecDataset(
            root=join(self.data_path, f'{self.category}/test'),
            transform=T.Compose([
                T.Resize(self.image_size, interpolation=Image.BILINEAR),
                T.ToTensor(),
                T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
            ])
        )

    def get_train_data(self, return_dataset: bool=True):
        """Load train dataset

        Parameters
        ----------
        return_dataset : bool
            False for preprocessing purposes only

        """
        train_data_dir = join(self.data_path, f'{self.category}/train/')

        # Preprocessed output data path
        cache_main_dir = join(self.data_path, f'processed/{self.category}')
        os.makedirs(cache_main_dir, exist_ok=True)
        cache_file = f'{cache_main_dir}/{self.category}_train_dataset_i-{self.image_size}_p-{self.patch_size}_r-{self.rotation_range[0]}--{self.rotation_range[1]}.npy'

        # Check if the preprocessed file already exists
        if not os.path.exists(cache_file):

            # Apply random rotation
            def augmentation():
                """Returns the transforms to apply to the data

                """
                # For textures, rotate and crop away the edges
                if self.is_texture:
                    return T.Compose([
                        T.Resize(self.image_size, interpolation=Image.BILINEAR),
                        T.Pad(padding=self.image_size//4, padding_mode="reflect"),
                        T.RandomRotation((self.rotation_range[0], self.rotation_range[1])),
                        T.CenterCrop(self.image_size),
                        T.RandomCrop(self.patch_size),
                        T.ToTensor(),
                        T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
                    ])
                else:
                    return T.Compose([
                        T.Resize(self.image_size, interpolation=Image.BILINEAR),
                        T.Pad(padding=self.image_size//4, padding_mode="reflect"),
                        T.RandomRotation((self.rotation_range[0], self.rotation_range[1])),
                        T.CenterCrop(self.image_size),
                        T.ToTensor(),
                        T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
                    ])

            # Load data and apply transformations
            train_dataset = ImageFolder(root=train_data_dir, transform=augmentation())
            print(f"Creating cache for dataset: \n{cache_file}")
            # To simulate a larger dataset, replicate the images with random transformations
            nb_epochs = 50000 // len(train_dataset.imgs)
            data_loader = DataLoader(dataset=train_dataset, batch_size=1024, pin_memory=True)

            for epoch in tqdm(range(nb_epochs), total=nb_epochs, desc=f"Creating cache for: {self.category}"):
                if epoch == 0:
                    cache_np = [x.numpy() for x, _ in tqdm(data_loader, total=len(data_loader), desc=f'Caching epoch: {epoch+1}/{nb_epochs}', leave=False)]
                else:
                    cache_np.extend([x.numpy() for x, _ in tqdm(data_loader, total=len(data_loader), desc=f'Caching epoch: {epoch+1}/{nb_epochs}', leave=False)])

            cache_np = np.vstack(cache_np)
            np.save(cache_file, cache_np)
            print(f"Preprocessed images have been saved at: \n{cache_file}")

        if return_dataset:
            print(f"Loading dataset from cache: \n{cache_file}")
            return CustomTensorDataset(cache_file)
        else:
            return

    def get_loaders(self, batch_size: int, shuffle_train: bool=True, pin_memory: bool=False, num_workers: int = 0) -> [DataLoader, DataLoader]:
        """Returns MVtec dataloaders

        Parameters
        ----------
        batch_size : int
            Size of each batch
        shuffle_train : bool
            If True, shuffles the training dataset
        pin_memory : bool
            If True, pin memory
        num_workers : int
            Number of dataloader workers

        Returns
        -------
        loaders : DataLoader
            Train and test data loaders

        """
        train_loader = DataLoader(
            dataset=self.get_train_data(return_dataset=True),
            batch_size=batch_size,
            shuffle=shuffle_train,
            pin_memory=pin_memory,
            num_workers=num_workers
        )
        test_loader = DataLoader(
            dataset=self.get_test_data(),
            batch_size=batch_size,
            pin_memory=pin_memory,
            num_workers=num_workers
        )
        return train_loader, test_loader


if __name__ == '__main__':
    """To speed up the train phase we can preprocess the training images and save them as numpy arrays.

    """
    textures = tuple(['carpet', 'grid', 'leather', 'tile', 'wood'])
    objects_1 = tuple(['bottle', 'hazelnut', 'metal_nut', 'screw'])
    objects_2 = tuple(['capsule', 'toothbrush', 'cable', 'pill', 'transistor', 'zipper'])

    classes = list(textures)
    classes.extend(list(objects_1))
    classes.extend(list(objects_2))

    # NOTE: adjust this path so that it points at the root of the MVTec data
    data_path = './mvtec'

    for category in classes:
        if category in textures:
            args = dict(
                data_path=data_path,
                category=category,
                image_size=512,
                patch_size=64,
                rotation_range=(0, 45),
                is_texture=True
            )
        elif category in objects_1:
            args = dict(
                data_path=data_path,
                category=category,
                image_size=128,
                patch_size=-1,
                rotation_range=(-45, 45),
                is_texture=False
            )
        else:
            args = dict(
                data_path=data_path,
                category=category,
                image_size=128,
                patch_size=-1,
                rotation_range=(0, 0),
                is_texture=False
            )

        MVTec_DataHolder(**args).get_train_data(return_dataset=False)
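

# --- Illustrative sketch, not part of the original pipeline --------------------
# Typical consumption of MVTec_DataHolder from a training script. The dataset is
# assumed to live under ./mvtec/<category>/{train,test} (a hypothetical path,
# adjust as needed); the first call builds the .npy cache for the train split.
def _example_get_loaders():
    holder = MVTec_DataHolder(
        data_path='./mvtec',          # hypothetical dataset location
        category='bottle',
        image_size=128,
        patch_size=-1,                # ignored for non-texture categories
        rotation_range=(-45, 45),
        is_texture=False
    )
    train_loader, test_loader = holder.get_loaders(batch_size=16, shuffle_train=True)
    images, _ = next(iter(train_loader))   # -> torch.Size([16, 3, 128, 128])
    return images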


--------------------------------------------------------------------------------
/models/shanghaitech_base_model.py:
--------------------------------------------------------------------------------
from functools import reduce
from operator import mul

import torch
import torch.nn as nn


class BaseModule(nn.Module):
    """
    Implements the basic module.
    All other modules inherit from this one
    """
    def load_w(self, checkpoint_path):
        # type: (str) -> None
        """
        Loads a checkpoint into the state_dict.
        :param checkpoint_path: the checkpoint file to be loaded.
        """
        self.load_state_dict(torch.load(checkpoint_path))

    def __repr__(self):
        # type: () -> str
        """
        String representation
        """
        good_old = super(BaseModule, self).__repr__()
        addition = 'Total number of parameters: {:,}'.format(self.n_parameters)

        # return good_old + '\n' + addition
        return good_old

    def __call__(self, *args, **kwargs):
        return super(BaseModule, self).__call__(*args, **kwargs)

    @property
    def n_parameters(self):
        # type: () -> int
        """
        Number of parameters of the model.
        """
        n_parameters = 0
        for p in self.parameters():
            if hasattr(p, 'mask'):
                n_parameters += torch.sum(p.mask).item()
            else:
                n_parameters += reduce(mul, p.shape)
        return int(n_parameters)


class MaskedConv3d(BaseModule, nn.Conv3d):
    """
    Implements a Masked Convolution 3D.
    This is a 3D Convolution that cannot access future frames.
    """
    def __init__(self, *args, **kwargs):
        super(MaskedConv3d, self).__init__(*args, **kwargs)

        self.register_buffer('mask', self.weight.data.clone())
        _, _, kT, kH, kW = self.weight.size()
        self.mask.fill_(1)
        self.mask[:, :, kT // 2 + 1:] = 0

    def forward(self, x):
        # type: (torch.Tensor) -> torch.Tensor
        """
        Performs the forward pass.
        :param x: the input tensor.
        :return: the output tensor as result of the convolution.
        """
        self.weight.data *= self.mask
        return super(MaskedConv3d, self).forward(x)
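

# --- Illustrative sketch, not part of the original pipeline --------------------
# Because the kernel positions after the central frame are zeroed, the output of
# MaskedConv3d at time t never depends on frames later than t. A minimal check:
def _example_masked_conv3d_is_causal():
    torch.manual_seed(0)
    conv = MaskedConv3d(in_channels=1, out_channels=1, kernel_size=3, padding=1)
    x = torch.randn(1, 1, 8, 4, 4)           # (batch, channels, time, h, w)
    y_ref = conv(x)
    x_corrupted = x.clone()
    x_corrupted[:, :, -1] += 100.0            # perturb only the last (future) frame
    y_alt = conv(x_corrupted)
    # all time steps except the last are unaffected by the future perturbation
    assert torch.allclose(y_ref[:, :, :-1], y_alt[:, :, :-1])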
""" 190 | def __init__(self, channel_in, channel_out, activation_fn, stride, use_bn=True, use_bias=False): 191 | # type: (int, int, Module, Tuple[int, int, int], bool, bool) -> None 192 | """ 193 | Class constructor. 194 | :param channel_in: number of input channels. 195 | :param channel_out: number of output channels. 196 | :param activation_fn: activation to be employed. 197 | :param stride: the stride to be applied to downsample feature maps. 198 | :param use_bn: whether or not to use batch-norm. 199 | :param use_bias: whether or not to use bias. 200 | """ 201 | super(DownsampleBlock, self).__init__(channel_in, channel_out, activation_fn, use_bn, use_bias) 202 | self.stride = stride 203 | 204 | # Convolutions 205 | self.conv1a = MaskedConv3d(in_channels=channel_in, out_channels=channel_out, kernel_size=3, 206 | padding=1, stride=stride, bias=use_bias) 207 | self.conv1b = MaskedConv3d(in_channels=channel_out, out_channels=channel_out, kernel_size=3, 208 | padding=1, stride=1, bias=use_bias) 209 | self.conv2a = nn.Conv3d(in_channels=channel_in, out_channels=channel_out, kernel_size=1, 210 | padding=0, stride=stride, bias=use_bias) 211 | 212 | # Batch Normalization layers 213 | self.bn1a = self.get_bn() 214 | self.bn1b = self.get_bn() 215 | self.bn2a = self.get_bn() 216 | 217 | def forward(self, x): 218 | # type: (torch.Tensor) -> torch.Tensor 219 | """ 220 | Forward propagation. 221 | :param x: the input tensor 222 | :return: the output tensor 223 | """ 224 | return residual_op( 225 | x, 226 | functions=[self.conv1a, self.conv1b, self.conv2a], 227 | bns=[self.bn1a, self.bn1b, self.bn2a], 228 | activation_fn=self._activation_fn 229 | ) 230 | 231 | 232 | class UpsampleBlock(BaseBlock): 233 | """ Implements a Upsampling block for videos (Fig. 1ii). """ 234 | def __init__(self, channel_in, channel_out, activation_fn, stride, output_padding, use_bn=True, use_bias=False): 235 | # type: (int, int, Module, Tuple[int, int, int], Tuple[int, int, int], bool, bool) -> None 236 | """ 237 | Class constructor. 238 | :param channel_in: number of input channels. 239 | :param channel_out: number of output channels. 240 | :param activation_fn: activation to be employed. 241 | :param stride: the stride to be applied to upsample feature maps. 242 | :param output_padding: the padding to be added applied output feature maps. 243 | :param use_bn: whether or not to use batch-norm. 244 | :param use_bias: whether or not to use bias. 245 | """ 246 | super(UpsampleBlock, self).__init__(channel_in, channel_out, activation_fn, use_bn, use_bias) 247 | self.stride = stride 248 | self.output_padding = output_padding 249 | 250 | # Convolutions 251 | self.conv1a = nn.ConvTranspose3d(channel_in, channel_out, kernel_size=5, 252 | padding=2, stride=stride, output_padding=output_padding, bias=use_bias) 253 | self.conv1b = nn.Conv3d(in_channels=channel_out, out_channels=channel_out, kernel_size=3, 254 | padding=1, stride=1, bias=use_bias) 255 | self.conv2a = nn.ConvTranspose3d(channel_in, channel_out, kernel_size=5, 256 | padding=2, stride=stride, output_padding=output_padding, bias=use_bias) 257 | 258 | # Batch Normalization layers 259 | self.bn1a = self.get_bn() 260 | self.bn1b = self.get_bn() 261 | self.bn2a = self.get_bn() 262 | 263 | def forward(self, x): 264 | # type: (torch.Tensor) -> torch.Tensor 265 | """ 266 | Forward propagation. 


def residual_op(x, functions, bns, activation_fn):
    # type: (torch.Tensor, List[Module], List[Module], Module) -> torch.Tensor
    """
    Implements a global residual operation.
    :param x: the input tensor.
    :param functions: a list of functions (nn.Modules).
    :param bns: a list of optional batch-norm layers.
    :param activation_fn: the activation to be applied.
    :return: the output of the residual operation.
    """
    # Validate the inputs before unpacking them
    assert len(functions) == len(bns) == 3

    f1, f2, f3 = functions
    bn1, bn2, bn3 = bns

    assert f1 is not None and f2 is not None
    assert not (f3 is None and bn3 is not None)

    # A-branch
    ha = x
    ha = f1(ha)
    if bn1 is not None:
        ha = bn1(ha)
    ha = activation_fn(ha)

    ha = f2(ha)
    if bn2 is not None:
        ha = bn2(ha)

    # B-branch
    hb = x
    if f3 is not None:
        hb = f3(hb)
    if bn3 is not None:
        hb = bn3(hb)

    # Residual connection
    out = ha + hb
    return activation_fn(out)


class BaseBlock(BaseModule):
    """ Base class for all blocks. """
    def __init__(self, channel_in, channel_out, activation_fn, use_bn=True, use_bias=True):
        # type: (int, int, Module, bool, bool) -> None
        """
        Class constructor.
        :param channel_in: number of input channels.
        :param channel_out: number of output channels.
        :param activation_fn: activation to be employed.
        :param use_bn: whether or not to use batch-norm.
        :param use_bias: whether or not to use bias.
        """
        super(BaseBlock, self).__init__()

        assert not (use_bn and use_bias), 'Using bias=True with batch_normalization is forbidden.'

        self._channel_in = channel_in
        self._channel_out = channel_out
        self._activation_fn = activation_fn
        self._use_bn = use_bn
        self._bias = use_bias

    def get_bn(self):
        # type: () -> Optional[Module]
        """
        Returns batch norm layers, if needed.
        :return: batch norm layers or None
        """
        return nn.BatchNorm3d(num_features=self._channel_out) if self._use_bn else None

    def forward(self, x):
        """
        Abstract forward function. Not implemented.
        """
        raise NotImplementedError


class DownsampleBlock(BaseBlock):
    """ Implements a Downsampling block for videos (Fig. 1ii). """
    def __init__(self, channel_in, channel_out, activation_fn, stride, use_bn=True, use_bias=False):
        # type: (int, int, Module, Tuple[int, int, int], bool, bool) -> None
        """
        Class constructor.
        :param channel_in: number of input channels.
        :param channel_out: number of output channels.
        :param activation_fn: activation to be employed.
        :param stride: the stride to be applied to downsample feature maps.
        :param use_bn: whether or not to use batch-norm.
        :param use_bias: whether or not to use bias.
        """
        super(DownsampleBlock, self).__init__(channel_in, channel_out, activation_fn, use_bn, use_bias)
        self.stride = stride

        # Convolutions
        self.conv1a = MaskedConv3d(in_channels=channel_in, out_channels=channel_out, kernel_size=3,
                                   padding=1, stride=stride, bias=use_bias)
        self.conv1b = MaskedConv3d(in_channels=channel_out, out_channels=channel_out, kernel_size=3,
                                   padding=1, stride=1, bias=use_bias)
        self.conv2a = nn.Conv3d(in_channels=channel_in, out_channels=channel_out, kernel_size=1,
                                padding=0, stride=stride, bias=use_bias)

        # Batch Normalization layers
        self.bn1a = self.get_bn()
        self.bn1b = self.get_bn()
        self.bn2a = self.get_bn()

    def forward(self, x):
        # type: (torch.Tensor) -> torch.Tensor
        """
        Forward propagation.
        :param x: the input tensor
        :return: the output tensor
        """
        return residual_op(
            x,
            functions=[self.conv1a, self.conv1b, self.conv2a],
            bns=[self.bn1a, self.bn1b, self.bn2a],
            activation_fn=self._activation_fn
        )


class UpsampleBlock(BaseBlock):
    """ Implements an Upsampling block for videos (Fig. 1ii). """
    def __init__(self, channel_in, channel_out, activation_fn, stride, output_padding, use_bn=True, use_bias=False):
        # type: (int, int, Module, Tuple[int, int, int], Tuple[int, int, int], bool, bool) -> None
        """
        Class constructor.
        :param channel_in: number of input channels.
        :param channel_out: number of output channels.
        :param activation_fn: activation to be employed.
        :param stride: the stride to be applied to upsample feature maps.
        :param output_padding: the output padding applied to the upsampled feature maps.
        :param use_bn: whether or not to use batch-norm.
        :param use_bias: whether or not to use bias.
        """
        super(UpsampleBlock, self).__init__(channel_in, channel_out, activation_fn, use_bn, use_bias)
        self.stride = stride
        self.output_padding = output_padding

        # Convolutions
        self.conv1a = nn.ConvTranspose3d(channel_in, channel_out, kernel_size=5,
                                         padding=2, stride=stride, output_padding=output_padding, bias=use_bias)
        self.conv1b = nn.Conv3d(in_channels=channel_out, out_channels=channel_out, kernel_size=3,
                                padding=1, stride=1, bias=use_bias)
        self.conv2a = nn.ConvTranspose3d(channel_in, channel_out, kernel_size=5,
                                         padding=2, stride=stride, output_padding=output_padding, bias=use_bias)

        # Batch Normalization layers
        self.bn1a = self.get_bn()
        self.bn1b = self.get_bn()
        self.bn2a = self.get_bn()

    def forward(self, x):
        # type: (torch.Tensor) -> torch.Tensor
        """
        Forward propagation.
        :param x: the input tensor
        :return: the output tensor
        """
        return residual_op(
            x,
            functions=[self.conv1a, self.conv1b, self.conv2a],
            bns=[self.bn1a, self.bn1b, self.bn2a],
            activation_fn=self._activation_fn
        )
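

# --- Illustrative sketch, not part of the original pipeline --------------------
# A DownsampleBlock shrinks each spatio-temporal dimension according to `stride`,
# while an UpsampleBlock with matching stride/output_padding inverts it; a toy
# round-trip on a small clip:
def _example_down_up_roundtrip():
    down = DownsampleBlock(channel_in=3, channel_out=8,
                           activation_fn=nn.LeakyReLU(), stride=(1, 2, 2))
    up = UpsampleBlock(channel_in=8, channel_out=3,
                       activation_fn=nn.LeakyReLU(), stride=(1, 2, 2),
                       output_padding=(0, 1, 1))
    x = torch.randn(2, 3, 4, 16, 16)          # (batch, channels, time, h, w)
    h = down(x)
    assert h.shape == (2, 8, 4, 8, 8)          # spatial dims halved, time kept
    assert up(h).shape == x.shape              # the upsample block restores them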


--------------------------------------------------------------------------------
/main_cifar10.py:
--------------------------------------------------------------------------------
import os
import sys
import random
import logging
import argparse
import numpy as np

import torch

from tensorboardX import SummaryWriter

from datasets.data_manager import DataManager
from trainers.train_cifar10 import pretrain, train, test
from utils import set_seeds, get_out_dir, purge_params
from models.cifar10_model import CIFAR10_Autoencoder, CIFAR10_Encoder


def main(args):

    # If the layer list is not specified, then use only the last layer to detect anomalies
    if len(args.idx_list_enc) == 0 and args.train:
        args.idx_list_enc = [3]

    ## Init logger & print training/warm-up summary
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s | %(message)s",
        handlers=[
            logging.FileHandler('./training.log'),
            logging.StreamHandler()
        ])

    logger = logging.getLogger()

    if args.train or args.pretrain:
        logger.info(
            "Start run with params:\n"
            f"\n\t\t\t\tPretrain model   : {args.pretrain}"
            f"\n\t\t\t\tTrain model      : {args.train}"
            f"\n\t\t\t\tTest model       : {args.test}"
            f"\n\t\t\t\tBoundary         : {args.boundary}"
            f"\n\t\t\t\tNormal class     : {args.normal_class}"
            f"\n\t\t\t\tBatch size       : {args.batch_size}\n"
            f"\n\t\t\t\tPretrain epochs  : {args.ae_epochs}"
            f"\n\t\t\t\tAE-Learning rate : {args.ae_learning_rate}"
            f"\n\t\t\t\tAE-milestones    : {args.ae_lr_milestones}"
            f"\n\t\t\t\tAE-Weight decay  : {args.ae_weight_decay}\n"
            f"\n\t\t\t\tTrain epochs     : {args.epochs}"
            f"\n\t\t\t\tLearning rate    : {args.learning_rate}"
            f"\n\t\t\t\tMilestones       : {args.lr_milestones}"
            f"\n\t\t\t\tWeight decay     : {args.weight_decay}\n"
            f"\n\t\t\t\tCode length      : {args.code_length}"
            f"\n\t\t\t\tNu               : {args.nu}"
            f"\n\t\t\t\tEncoder list     : {args.idx_list_enc}\n"
        )
    else:
        if args.model_ckp is None:
            logger.info("CANNOT TEST MODEL WITHOUT A VALID CHECKPOINT")
            sys.exit(0)

        args.normal_class = int(args.model_ckp.split('/')[-2].split('-')[2].split('_')[-1])

    # Set seed
    set_seeds(args.seed)

    # Get the device
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Init DataHolder class
    data_holder = DataManager(
        dataset_name='cifar10',
        data_path=args.data_path,
        normal_class=args.normal_class,
        only_test=args.test
    ).get_data_holder()

    # Load data
    train_loader, test_loader = data_holder.get_loaders(
        batch_size=args.batch_size,
        shuffle_train=True,
        pin_memory=device=="cuda",
        num_workers=args.n_workers
    )

    ### PRETRAIN the full AutoEncoder
    ae_net_cehckpoint = None
    if args.pretrain:
        out_dir, tmp = get_out_dir(args, pretrain=True, aelr=None, dset_name='cifar10')
        tb_writer = SummaryWriter(os.path.join(args.output_path, 'cifar10', str(args.normal_class), 'tb_runs/pretrain', tmp))

        # Init AutoEncoder
        ae_net = CIFAR10_Autoencoder(args.code_length)

        # Start pretraining
        logging.info('Start training the full AutoEncoder')
        ae_net_cehckpoint = pretrain(
            ae_net=ae_net,
            train_loader=train_loader,
            out_dir=out_dir,
            tb_writer=tb_writer,
            device=device,
            ae_learning_rate=args.ae_learning_rate,
            ae_weight_decay=args.ae_weight_decay,
            ae_lr_milestones=args.ae_lr_milestones,
            ae_epochs=args.ae_epochs
        )
        logging.info('AutoEncoder trained!!!')

        tb_writer.close()

    ### TRAIN the Encoder
    net_cehckpoint = None
    if args.train:
        if ae_net_cehckpoint is None:
            if args.model_ckp is None:
                logger.info("CANNOT TRAIN MODEL WITHOUT A VALID CHECKPOINT")
                sys.exit(0)
            ae_net_cehckpoint = args.model_ckp
        aelr = float(ae_net_cehckpoint.split('/')[-2].split('-')[4].split('_')[-1])
        out_dir, tmp = get_out_dir(args, pretrain=False, aelr=aelr)

        tb_writer = SummaryWriter(os.path.join(args.output_path, 'cifar10', str(args.normal_class), 'tb_runs/train', tmp))

        # Init Encoder
        encoder_net = CIFAR10_Encoder(args.code_length)

        # Load Encoder parameters from the pretrained full AutoEncoder
        purge_params(encoder_net=encoder_net, ae_net_cehckpoint=ae_net_cehckpoint)

        # Start training
        net_cehckpoint = train(
            net=encoder_net,
            train_loader=train_loader,
            out_dir=out_dir,
            tb_writer=tb_writer,
            device=device,
            ae_net_cehckpoint=ae_net_cehckpoint,
            idx_list_enc=args.idx_list_enc,
            learning_rate=args.learning_rate,
            weight_decay=args.weight_decay,
            lr_milestones=args.lr_milestones,
            epochs=args.epochs,
            nu=args.nu,
            boundary=args.boundary,
            debug=args.debug
        )

        tb_writer.close()

    ### TEST the Encoder
    if args.test:
        if net_cehckpoint is None:
            net_cehckpoint = args.model_ckp

        # Init Encoder
        net = CIFAR10_Encoder(args.code_length)
        st_dict = torch.load(net_cehckpoint)
        net.load_state_dict(st_dict['net_state_dict'])

        logger.info(f"Loaded model from: {net_cehckpoint}")

        if args.debug:
            idx_list_enc = args.idx_list_enc
            boundary = args.boundary
        else:
            idx_list_enc = [int(i) for i in net_cehckpoint.split('/')[-2].split('-')[-1].split('_')[-1].split('.')]
            boundary = net_cehckpoint.split('/')[-2].split('-')[-3].split('_')[-1]

        logger.info(
            f"Start test with params:"
            f"\n\t\t\t\tCode length    : {args.code_length}"
            f"\n\t\t\t\tEnc layer list : {idx_list_enc}"
            f"\n\t\t\t\tBoundary       : {boundary}"
            f"\n\t\t\t\tNormal class   : {args.normal_class}"
        )

        # Start test
        test(net=net, test_loader=test_loader, R=st_dict['R'], c=st_dict['c'], device=device, idx_list_enc=idx_list_enc, boundary=boundary)


if __name__ == '__main__':
    parser = argparse.ArgumentParser('AD')
    ## General config
    parser.add_argument('-s', '--seed', type=int, default=-1, help='Random seed (default: -1)')
    parser.add_argument('--n_workers', type=int, default=8, help='Number of workers for data loading. 0 means that the data will be loaded in the main process. (default: 8)')
    parser.add_argument('--output_path', default='./output')
    ## Model config
    parser.add_argument('-zl', '--code-length', default=32, type=int, help='Code length (default: 32)')
    parser.add_argument('-ck', '--model-ckp', help='Model checkpoint')
    ## Optimizer config
    parser.add_argument('-alr', '--ae-learning-rate', type=float, default=1.e-4, help='Warm up learning rate (default: 1.e-4)')
    parser.add_argument('-lr', '--learning-rate', type=float, default=1.e-4, help='Learning rate (default: 1.e-4)')
    parser.add_argument('-awd', '--ae-weight-decay', type=float, default=0.5e-6, help='Warm up weight decay (default: 0.5e-6)')
    parser.add_argument('-wd', '--weight-decay', type=float, default=0.5e-6, help='Weight decay (default: 0.5e-6)')
    parser.add_argument('-aml', '--ae-lr-milestones', type=int, nargs='+', default=[], help='Pretrain milestones')
    parser.add_argument('-ml', '--lr-milestones', type=int, nargs='+', default=[], help='Training milestones')
    ## Data
    parser.add_argument('-dp', '--data-path', default='./cifar10', help='Dataset main path')
    parser.add_argument('-nc', '--normal-class', type=int, default=5, help='Normal Class (default: 5)')
    ## Training config
    parser.add_argument('-we', '--warm_up_n_epochs', type=int, default=10, help='Warm up epochs (default: 10)')
    parser.add_argument('--use-selectors', action="store_true", help='Use features selector (default: False)')
    parser.add_argument('-tbc', '--train-best-conf', action="store_true", help='Train best configurations (default: False)')
    parser.add_argument('-db', '--debug', action="store_true", help='Debug (default: False)')
    parser.add_argument('-bs', '--batch-size', type=int, default=256, help='Batch size (default: 256)')
    parser.add_argument('-bd', '--boundary', choices=("hard", "soft"), default="soft", help='Boundary (default: soft)')
    parser.add_argument('-ptr', '--pretrain', action="store_true", help='Pretrain model (default: False)')
    parser.add_argument('-tr', '--train', action="store_true", help='Train model (default: False)')
    parser.add_argument('-tt', '--test', action="store_true", help='Test model (default: False)')
    parser.add_argument('-ile', '--idx-list-enc', type=int, nargs='+', default=[], help='List of indices of model encoder')
    parser.add_argument('-e', '--epochs', type=int, default=1, help='Training epochs (default: 1)')
    parser.add_argument('-ae', '--ae-epochs', type=int, default=1, help='Warm up epochs (default: 1)')
    parser.add_argument('-nu', '--nu', type=float, default=0.1)
    args = parser.parse_args()

    main(args)


--------------------------------------------------------------------------------
/models/mvtec_model.py:
--------------------------------------------------------------------------------
import sys
from operator import mul
from typing import Tuple
from functools import reduce

import torch
import torch.nn as nn
import torch.nn.functional as F

from .mvtec_base_model import BaseModule, DownsampleBlock, ResidualBlock, UpsampleBlock


CHANNELS = [32, 64, 128]


def init_conv_blocks(channel_in: int, channel_out: int, activation_fn: nn.Module) -> nn.Module:
    """Init convolutional blocks.

    Parameters
    ----------
    channel_in : int
        Number of input channels
    channel_out : int
        Number of output channels
    activation_fn : nn.Module
        Activation function

    """
    return DownsampleBlock(channel_in=channel_in, channel_out=channel_out, activation_fn=activation_fn)


class Selector(nn.Module):
    """Selector module

    """
    def __init__(self, code_length: int, idx: int):
        """Init Selector architecture

        Parameters
        ----------
        code_length : int
            Latent code size
        idx : int
            Layer idx

        """
        super().__init__()

        # List of depths of features maps
        sizes = [CHANNELS[0], CHANNELS[0], CHANNELS[1], CHANNELS[2], CHANNELS[2]*2, CHANNELS[2]*2, code_length]

        # Hidden FC output size
        mid_features_size = 256

        # Last FC output size
        out_features = 128

        # Choose a different Selector architecture
        # depending on which layer it attaches
        if idx < 5:
            self.fc = nn.Sequential(
                nn.AdaptiveMaxPool2d(output_size=8),
                nn.Conv2d(in_channels=sizes[idx], out_channels=1, kernel_size=1),
                nn.Flatten(),
                nn.Linear(in_features=8**2, out_features=mid_features_size, bias=True),
                nn.BatchNorm1d(mid_features_size),
                nn.ReLU(),
                nn.Linear(in_features=mid_features_size, out_features=out_features, bias=True)
            )
        else:
            self.fc = nn.Sequential(
                nn.Flatten(),
                nn.Linear(in_features=sizes[idx], out_features=mid_features_size, bias=True),
                nn.BatchNorm1d(mid_features_size),
                nn.ReLU(),
                nn.Linear(in_features=mid_features_size, out_features=out_features, bias=True)
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(x)


class MVTec_Encoder(BaseModule):
    """MVTec Encoder network

    """
    def __init__(self, input_shape: torch.Tensor, code_length: int, idx_list_enc: list, use_selectors: bool):
        """Init Encoder network

        Parameters
        ----------
        input_shape : torch.Tensor
            Input data shape
        code_length : int
            Latent code size
        idx_list_enc : list
            List of layers' idx to use for the AD task
        use_selectors : bool
            True (False) if the model has (not) to use Selectors modules

        """
        super().__init__()

        self.idx_list_enc = idx_list_enc
        self.use_selectors = use_selectors

        # Single input data shape
        c, h, w = input_shape

        # Activation function
        self.activation_fn = nn.LeakyReLU()

        # Init convolutional blocks
        self.conv = nn.Conv2d(in_channels=c, out_channels=32, kernel_size=3, bias=False)
        self.res = ResidualBlock(channel_in=32, channel_out=32, activation_fn=self.activation_fn)
        self.dwn1, self.dwn2, self.dwn3 = [init_conv_blocks(channel_in=ch, channel_out=ch*2, activation_fn=self.activation_fn) for ch in CHANNELS]

        # Depth of the last features map
        self.last_depth = CHANNELS[2]*2

        # Shape of the last features map
        self.deepest_shape = (self.last_depth, h // 8, w // 8)

        # Init FC layers
        self.fc1 = nn.Linear(in_features=reduce(mul, self.deepest_shape), out_features=self.last_depth)
        self.bn = nn.BatchNorm1d(num_features=self.last_depth)
        self.fc2 = nn.Linear(in_features=self.last_depth, out_features=code_length)

        ## Init features selector models
        if self.use_selectors:
            self.selectors = nn.ModuleList([Selector(code_length=code_length, idx=idx) for idx in range(7)])
            # The encoder exposes eight outputs (o1..o5, o7, o8, z); the latent code z
            # reuses the same Selector architecture as the last FC output (idx=6)
            self.selectors.append(Selector(code_length=code_length, idx=6))

    def get_depths_info(self) -> [int, tuple]:
        """
        Returns
        ------
        self.last_depth : int
            Depth of the last features map
        self.deepest_shape : tuple
            Shape of the last features map

        """
        return self.last_depth, self.deepest_shape

    def set_idx_list_enc(self, idx_list_enc: list) -> None:
        """Set the list of layers from which to extract the features.
        It is used to initialize the hypersphere centers so that,
        independently of which layers we are considering, the first
        time we create the centroids we do it for all the layers.

        Parameters
        ----------
        idx_list_enc : list
            List of layers indices

        """
        self.idx_list_enc = idx_list_enc

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        o1 = self.conv(x)
        o2 = self.res(self.activation_fn(o1))
        o3 = self.dwn1(o2)
        o4 = self.dwn2(o3)
        o5 = self.dwn3(o4)
        o7 = self.activation_fn(
            self.bn(
                self.fc1(
                    o5.view(len(o5), -1)
                )
            )
        )  # FC -> BN -> LeakyReLU
        o8 = self.fc2(o7)
        z = nn.Sigmoid()(o8)

        outputs = [o1, o2, o3, o4, o5, o7, o8, z]

        if len(self.idx_list_enc) != 0:
            # If we are pretraining the full AutoEncoder we don't need any of this and we set self.idx_list_enc = []
            if self.use_selectors:
                tuple_o = [self.selectors[idx](tt) for idx, tt in enumerate(outputs) if idx in self.idx_list_enc]
            else:
                # If we don't use selectors, apply simple transformations to reduce the size of the feature maps
                tuple_o = []

                for idx, tt in enumerate(outputs):
                    if idx not in self.idx_list_enc: continue

                    if tt.ndimension() > 2:
                        tuple_o.append(F.avg_pool2d(tt, tt.shape[-2:]).squeeze())
                    else:
                        tuple_o.append(tt.squeeze())

            return list(zip([f'0{idx}' for idx in self.idx_list_enc], tuple_o))

        else:  # It means that we are pretraining the full AutoEncoder
            return z
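

# --- Illustrative sketch, not part of the original pipeline --------------------
# With a non-empty `idx_list_enc` the encoder returns (layer-name, features)
# pairs, one per selected depth: this is the layer-wise view MOCCA builds on.
# The exact spatial sizes depend on the blocks in mvtec_base_model (not shown in
# this dump), so the sketch only prints the resulting shapes.
def _example_layerwise_features():
    enc = MVTec_Encoder(input_shape=(3, 128, 128), code_length=64,
                        idx_list_enc=[2, 4, 7], use_selectors=False)
    enc.eval()                                    # keep BatchNorm deterministic
    pairs = enc(torch.randn(2, 3, 128, 128))
    for name, feats in pairs:
        print(name, feats.shape)                  # e.g. '04 torch.Size([2, 256])'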


class MVTec_Decoder(BaseModule):
    """MVTec Decoder network

    """
    def __init__(self, code_length: int, deepest_shape: tuple, last_depth: int, output_shape: torch.Tensor):
        """Init MVTec Decoder network

        Parameters
        ----------
        code_length : int
            Latent code size
        deepest_shape : tuple
            Shape of the last encoder features map
        last_depth : int
            Depth of the last encoder features map
        output_shape : torch.Tensor
            Input data shape

        """
        super().__init__()

        self.code_length = code_length
        self.deepest_shape = deepest_shape
        self.output_shape = output_shape

        # Decoder activation function
        activation_fn = nn.LeakyReLU()

        # FC network
        self.fc = nn.Sequential(
            nn.Linear(in_features=code_length, out_features=last_depth),
            nn.BatchNorm1d(num_features=last_depth),
            activation_fn,
            nn.Linear(in_features=last_depth, out_features=reduce(mul, deepest_shape)),
            nn.BatchNorm1d(num_features=reduce(mul, deepest_shape)),
            activation_fn
        )

        # (Transposed) Convolutional network
        self.conv = nn.Sequential(
            UpsampleBlock(channel_in=CHANNELS[2]*2, channel_out=CHANNELS[2], activation_fn=activation_fn),
            UpsampleBlock(channel_in=CHANNELS[1]*2, channel_out=CHANNELS[1], activation_fn=activation_fn),
            UpsampleBlock(channel_in=CHANNELS[0]*2, channel_out=CHANNELS[0], activation_fn=activation_fn),
            ResidualBlock(channel_in=CHANNELS[0], channel_out=CHANNELS[0], activation_fn=activation_fn),
            nn.Conv2d(in_channels=CHANNELS[0], out_channels=3, kernel_size=1, bias=False)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.fc(x)
        h = h.view(len(h), *self.deepest_shape)
        return self.conv(h)


class MVTecNet_AutoEncoder(BaseModule):
    """Full MVTecNet AutoEncoder network

    """
    def __init__(self, input_shape: tuple, code_length: int, use_selectors: bool):
        """Init the full AutoEncoder

        Parameters
        ----------
        input_shape : tuple
            Shape of input data
        code_length : int
            Latent code size
        use_selectors : bool
            True (False) if the model has (not) to use Selectors modules

        """
        super().__init__()

        # Shape of input data needed by the Decoder
        self.input_shape = input_shape

        # Build Encoder
        self.encoder = MVTec_Encoder(
            input_shape=input_shape,
            code_length=code_length,
            idx_list_enc=[],
            use_selectors=use_selectors
        )

        last_depth, deepest_shape = self.encoder.get_depths_info()

        # Build Decoder
        self.decoder = MVTec_Decoder(
            code_length=code_length,
            deepest_shape=deepest_shape,
            last_depth=last_depth,
            output_shape=input_shape
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        z = self.encoder(x)
        x_r = self.decoder(z)
        x_r = x_r.view(-1, *self.input_shape)
        return x_r
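

# --- Illustrative sketch, not part of the original pipeline --------------------
# A quick sanity check of the encode/decode round trip; it assumes the blocks in
# mvtec_base_model (not shown in this dump) upsample back to the input resolution.
def _example_autoencoder_roundtrip():
    ae = MVTecNet_AutoEncoder(input_shape=(3, 128, 128), code_length=64, use_selectors=False)
    ae.eval()
    x_r = ae(torch.randn(2, 3, 128, 128))
    print(x_r.shape)                    # expected: torch.Size([2, 3, 128, 128])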


--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import os
import random
import logging
import numpy as np
from tqdm import tqdm
from PIL import Image
from os.path import join

import torch
import torchvision.transforms as T
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder

from models.mvtec_model import MVTec_Encoder


def get_out_dir(args, pretrain: bool, aelr: float, dset_name: str="cifar10", training_strategy: str=None) -> [str, str]:
    """Creates the training output dir

    Parameters
    ----------
    args :
        Arguments
    pretrain : bool
        True if pretraining the model
    aelr : float
        Full AutoEncoder learning rate
    dset_name : str
        Dataset name
    training_strategy : str
        Training strategy identifier (currently unused)

    Returns
    -------
    out_dir : str
        Path to the output folder
    tmp : str
        String containing info about the current experiment setup

    """
    if dset_name == "ShanghaiTech":
        if pretrain:
            tmp = (f"pretrain-mn_{dset_name}-cl_{args.code_length}-lr_{args.ae_learning_rate}")
            out_dir = os.path.join(args.output_path, dset_name, 'pretrain', tmp)
        else:
            tmp = (
                f"train-mn_{dset_name}-cl_{args.code_length}-bs_{args.batch_size}-nu_{args.nu}-lr_{args.learning_rate}-"
                f"bd_{args.boundary}-sl_{args.use_selectors}-ile_{'.'.join(map(str, args.idx_list_enc))}-lstm_{args.load_lstm}-"
                f"bidir_{args.bidirectional}-hs_{args.hidden_size}-nl_{args.num_layers}-dp_{args.dropout}"
            )
            out_dir = os.path.join(args.output_path, dset_name, 'train', tmp)
            if args.end_to_end_training:
                out_dir = os.path.join(args.output_path, dset_name, 'train_end_to_end', tmp)
    else:
        if pretrain:
            tmp = (f"pretrain-mn_{dset_name}-nc_{args.normal_class}-cl_{args.code_length}-lr_{args.ae_learning_rate}-awd_{args.ae_weight_decay}")
            out_dir = os.path.join(args.output_path, dset_name, str(args.normal_class), 'pretrain', tmp)
        else:
            tmp = (
                f"train-mn_{dset_name}-nc_{args.normal_class}-cl_{args.code_length}-bs_{args.batch_size}-nu_{args.nu}-lr_{args.learning_rate}-"
                f"wd_{args.weight_decay}-bd_{args.boundary}-alr_{aelr}-sl_{args.use_selectors}-ep_{args.epochs}-ile_{'.'.join(map(str, args.idx_list_enc))}"
            )
            out_dir = os.path.join(args.output_path, dset_name, str(args.normal_class), 'train', tmp)

    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    return out_dir, tmp


def set_seeds(seed: int) -> None:
    """Set all seeds.

    Parameters
    ----------
    seed : int
        Seed

    """
    # Set the seed only if the user specified it
    if seed != -1:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)


def purge_params(encoder_net, ae_net_cehckpoint: str) -> None:
    """Load Encoder pretrained weights from the full AutoEncoder.
    After the pretraining phase, we don't need the full AutoEncoder parameters, we only need the Encoder

    Parameters
    ----------
    encoder_net :
        The Encoder network
    ae_net_cehckpoint : str
        Path to full AutoEncoder checkpoint

    """
    # Load the full AutoEncoder checkpoint dict
    ae_net_dict = torch.load(ae_net_cehckpoint, map_location=lambda storage, loc: storage)['ae_state_dict']

    # Load encoder weight from autoencoder
    net_dict = encoder_net.state_dict()

    # Filter out decoder network keys
    st_dict = {k: v for k, v in ae_net_dict.items() if k in net_dict}

    # Overwrite values in the existing state_dict
    net_dict.update(st_dict)

    # Load the new state_dict
    encoder_net.load_state_dict(net_dict)


def load_mvtec_model_from_checkpoint(input_shape: tuple, code_length: int, idx_list_enc: list, use_selectors: bool, net_cehckpoint: str, purge_ae_params: bool = False) -> torch.nn.Module:
    """Load an Encoder (or full AutoEncoder) checkpoint.

    Parameters
    ----------
    input_shape : tuple
        Input data shape
    code_length : int
        Latent code size
    idx_list_enc : list
        List of indexes of layers from which to extract features
    use_selectors : bool
        True if the model has to use Selector modules
    net_cehckpoint : str
        Path to model checkpoint
    purge_ae_params : bool
        True if the checkpoint is relative to an AutoEncoder

    Returns
    -------
    encoder_net : torch.nn.Module
        The Encoder network

    """
    logger = logging.getLogger()

    encoder_net = MVTec_Encoder(
        input_shape=input_shape,
        code_length=code_length,
        idx_list_enc=idx_list_enc,
        use_selectors=use_selectors
    )

    if purge_ae_params:
        # Load Encoder parameters from the pretrained full AutoEncoder
        logger.info(f"Loading encoder from: {net_cehckpoint}")
        purge_params(encoder_net=encoder_net, ae_net_cehckpoint=net_cehckpoint)
    else:
        st_dict = torch.load(net_cehckpoint)
        encoder_net.load_state_dict(st_dict['net_state_dict'])
        logger.info(f"Loaded model from: {net_cehckpoint}")

    return encoder_net


def extract_arguments_from_checkpoint(net_checkpoint: str):
    """Takes the file path of a checkpoint and parses the checkpoint name to extract
    training parameters and architectural specifications of the model.

    Parameters
    ----------
    net_checkpoint : str
        File path of the checkpoint

    Returns
    -------
    code_length : latent code size (int)
    batch_size : batch size (int)
    boundary : soft or hard boundary (str)
    use_selectors : True if selectors were used, False otherwise (bool)
    idx_list_enc : indexes of the exploited layers (list of integers)
    load_lstm : whether or not the LSTM was used (bool)
    hidden_size : hidden size of the LSTM (int)
    num_layers : number of layers of the LSTM (int)
    dropout : dropout probability (float)
    bidirectional : True if the LSTM is bidirectional (bool)
    dataset_name : name of the dataset (str)
    train_type : end-to-end, train, or pretrain (str)

    """
    code_length = int(net_checkpoint.split(os.sep)[-2].split('-')[2].split('_')[-1])
    batch_size = int(net_checkpoint.split(os.sep)[-2].split('-')[3].split('_')[-1])
    boundary = net_checkpoint.split(os.sep)[-2].split('-')[6].split('_')[-1]
    use_selectors = net_checkpoint.split(os.sep)[-2].split('-')[7].split('_')[-1] == "True"
    idx_list_enc = [int(i) for i in net_checkpoint.split(os.sep)[-2].split('-')[8].split('_')[-1].split('.')]
    load_lstm = net_checkpoint.split(os.sep)[-2].split('-')[9].split('_')[-1] == "True"
    hidden_size = int(net_checkpoint.split(os.sep)[-2].split('-')[11].split('_')[-1])
    num_layers = int(net_checkpoint.split(os.sep)[-2].split('-')[12].split('_')[-1])
    dropout = float(net_checkpoint.split(os.sep)[-2].split('-')[13].split('_')[-1])
    bidirectional = net_checkpoint.split(os.sep)[-2].split('-')[10].split('_')[-1] == "True"
    dataset_name = net_checkpoint.split(os.sep)[-4]
    train_type = net_checkpoint.split(os.sep)[-3]
    return code_length, batch_size, boundary, use_selectors, idx_list_enc, load_lstm, hidden_size, num_layers, dropout, bidirectional, dataset_name, train_type
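

# --- Illustrative sketch, not part of the original pipeline --------------------
# The checkpoint directory name is exactly the `tmp` string built by get_out_dir,
# so the parser above simply inverts that naming scheme. All values below are
# hypothetical, chosen only to show the mapping.
def _example_parse_checkpoint_name():
    ckp = os.path.join(
        'output', 'ShanghaiTech', 'train',
        'train-mn_ShanghaiTech-cl_1024-bs_4-nu_0.1-lr_0.0001-bd_soft-sl_True-'
        'ile_3.4.5.6-lstm_True-bidir_False-hs_100-nl_1-dp_0.3',
        'net_ckp.pth'
    )
    (code_length, batch_size, boundary, use_selectors, idx_list_enc, load_lstm,
     hidden_size, num_layers, dropout, bidirectional,
     dataset_name, train_type) = extract_arguments_from_checkpoint(ckp)
    print(code_length, boundary, idx_list_enc)    # -> 1024 soft [3, 4, 5, 6]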


def eval_spheres_centers(train_loader: DataLoader, encoder_net: torch.nn.Module, ae_net_cehckpoint: str, use_selectors: bool, device: str, debug: bool) -> dict:
    """Evaluates the centers of the hyperspheres at each chosen layer.

    Parameters
    ----------
    train_loader : DataLoader
        DataLoader for training data
    encoder_net : torch.nn.Module
        Encoder network
    ae_net_cehckpoint : str
        Checkpoint of the full AutoEncoder
    use_selectors : bool
        True if we want to use selector models
    device : str
        Device on which to run the computations
    debug : bool
        Activate debug mode

    Returns
    -------
    dict : dictionary
        Dictionary with k='layer name'; v='features vector representing the hypersphere center'

    """
    logger = logging.getLogger()

    centers_files = ae_net_cehckpoint[:-4]+f'_w_centers_{use_selectors}.pth'

    # If the centers are found, then load and return them
    if os.path.exists(centers_files):
        logger.info("Found hyperspheres centers")
        ae_net_ckp = torch.load(centers_files, map_location=lambda storage, loc: storage)

        centers = {k: v.to(device) for k, v in ae_net_ckp['centers'].items()}
    else:
        logger.info("Hyperspheres centers not found... evaluating...")
        centers_ = init_center_c(train_loader=train_loader, encoder_net=encoder_net, device=device, debug=debug)

        logger.info("Hyperspheres centers evaluated!!!")
        new_ckp = ae_net_cehckpoint.split('.pth')[0]+f'_w_centers_{use_selectors}.pth'

        logger.info(f"New AE dict saved at: {new_ckp}!!!")
        centers = {k: v for k, v in centers_.items()}

        torch.save({
            'ae_state_dict': torch.load(ae_net_cehckpoint)['ae_state_dict'],
            'centers': centers
        }, new_ckp)

    return centers


@torch.no_grad()
def init_center_c(train_loader: DataLoader, encoder_net: torch.nn.Module, device: str, debug: bool, eps: float=0.1) -> dict:
    """Initialize each hypersphere center as the mean from an initial forward pass on the data.

    Parameters
    ----------
    train_loader : DataLoader
        DataLoader for training data
    encoder_net : torch.nn.Module
        Encoder network
    device : str
        Device on which to run the computations
    debug : bool
        Activate debug mode
    eps : float
        Minimum absolute value allowed for a center coordinate

    Returns
    -------
    dictionary : dict
        Dictionary with k='layer name'; v='center features'

    """
    n_samples = 0

    encoder_net.eval().to(device)

    for idx, (data, _) in enumerate(tqdm(train_loader, desc='Init hyperspheres centers', total=len(train_loader), leave=False)):
        if debug and idx == 5: break

        data = data.to(device)
        n_samples += data.shape[0]

        zipped = encoder_net(data)

        if idx == 0:
            c = {item[0]: torch.zeros_like(item[1][-1], device=device) for item in zipped}

        for item in zipped:
            c[item[0]] += torch.sum(item[1], dim=0)

    for k in c.keys():
        c[k] = c[k] / n_samples

        # If c_i is too close to 0, set to +-eps. Reason: a zero unit can be trivially matched with zero weights.
        c[k][(abs(c[k]) < eps) & (c[k] < 0)] = -eps
        c[k][(abs(c[k]) < eps) & (c[k] > 0)] = eps

    return c
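

# --- Illustrative sketch, not part of the original pipeline --------------------
# The eps-clamping above in numbers: center coordinates closer to zero than eps
# are pushed to +-eps, so no coordinate can be trivially matched by a zero weight.
def _example_eps_clamp(eps: float = 0.1):
    c = torch.tensor([0.5, -0.03, 0.02, -0.7])
    c[(abs(c) < eps) & (c < 0)] = -eps
    c[(abs(c) < eps) & (c > 0)] = eps
    print(c)    # tensor([ 0.5000, -0.1000,  0.1000, -0.7000])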


--------------------------------------------------------------------------------
/models/shanghaitech_model.py:
--------------------------------------------------------------------------------
from .shanghaitech_base_model import BaseModule, DownsampleBlock, UpsampleBlock, TemporallySharedFullyConnection, MaskedConv3d

import torch
import torch.nn as nn
import torch.nn.functional as F


class Selector(BaseModule):
    def __init__(self, code_length, idx):
        """
        Class constructor.
        sizes = [[ch, time, h, w], ...]
        """
        super(Selector, self).__init__()

        self.idx = idx
        self.sizes = [
            [8, 16, 128, 256],
            [16, 16, 64, 128],
            [32, 8, 32, 64],
            [64, 8, 16, 32],
            [64, 4, 8, 16]
        ]
        mid_features_size = 256
        # self.adaptive = nn.AdaptiveMaxPool3d(output_size=(None, 16, 16))
        self.cv1 = nn.Sequential(
            nn.Conv3d(in_channels=self.sizes[idx][0], out_channels=self.sizes[idx][0]*2, kernel_size=3, padding=1, stride=(1, 2, 2)),
            nn.BatchNorm3d(num_features=self.sizes[idx][0]*2),
            nn.ReLU(),
            nn.Conv3d(in_channels=self.sizes[idx][0]*2, out_channels=self.sizes[idx][0]*4, kernel_size=3, padding=1),
            nn.BatchNorm3d(num_features=self.sizes[idx][0]*4),
            nn.ReLU(),
            nn.Conv3d(in_channels=self.sizes[idx][0]*4, out_channels=self.sizes[idx][0]*4, kernel_size=1)
        )

        self.fc = nn.Sequential(
            TemporallySharedFullyConnection(in_features=(self.sizes[self.idx][0]*4 * self.sizes[self.idx][2]//2 * self.sizes[self.idx][3]//2), out_features=mid_features_size, bias=True),
            nn.BatchNorm1d((self.sizes[self.idx][1])),
            nn.ReLU(),
            TemporallySharedFullyConnection(in_features=mid_features_size, out_features=self.sizes[self.idx][0], bias=True)
        )

    def forward(self, x):
        x = self.cv1(x)
        _, t, _, _ = self.sizes[self.idx]
        x = torch.transpose(x, 1, 2).contiguous()
        x = x.view(-1, t, (self.sizes[self.idx][0]*4 * self.sizes[self.idx][2]//2 * self.sizes[self.idx][3]//2))
        x = self.fc(x)
        return x


def build_lstm(input_size, hidden_size, num_layers, dropout, bidirectional):
    return nn.LSTM(
        input_size=input_size,
        hidden_size=hidden_size,
        num_layers=num_layers,
        bias=True,
        batch_first=True,
        dropout=dropout,
        bidirectional=bidirectional
    )


class ShanghaiTechEncoder(BaseModule):
    """
    ShanghaiTech model encoder.
    """
    def __init__(self, input_shape, code_length, load_lstm, hidden_size, num_layers, dropout, bidirectional, use_selectors):
        # type: (Tuple[int, int, int, int], int, bool, int, int, float, bool, bool) -> None
        """
        Class constructor:
        :param input_shape: the shape of ShanghaiTech samples.
        :param code_length: the dimensionality of latent vectors.
        :param load_lstm: whether or not to attach LSTMs to the hooked layers.
        :param hidden_size: hidden size of the LSTMs.
        :param num_layers: number of LSTM layers.
        :param dropout: LSTM dropout probability.
        :param bidirectional: whether or not the LSTMs are bidirectional.
        :param use_selectors: whether or not to use the Selector modules.
        """
        super(ShanghaiTechEncoder, self).__init__()

        self.input_shape = input_shape
        self.code_length = code_length
        self.load_lstm = load_lstm
        self.use_selectors = use_selectors

        c, t, h, w = input_shape

        activation_fn = nn.LeakyReLU()

        # Convolutional network
        # self.conv = nn.Sequential(
        self.conv_1 = DownsampleBlock(channel_in=c, channel_out=8, activation_fn=activation_fn, stride=(1, 2, 2))
        self.conv_2 = DownsampleBlock(channel_in=8, channel_out=16, activation_fn=activation_fn, stride=(1, 2, 2))
        self.conv_3 = DownsampleBlock(channel_in=16, channel_out=32, activation_fn=activation_fn, stride=(2, 2, 2))
        self.conv_4 = DownsampleBlock(channel_in=32, channel_out=64, activation_fn=activation_fn, stride=(1, 2, 2))
        self.conv_5 = DownsampleBlock(channel_in=64, channel_out=64, activation_fn=activation_fn, stride=(2, 2, 2))
        if load_lstm:
            self.lstm_1 = build_lstm(8, hidden_size, num_layers, dropout, bidirectional)
            self.lstm_2 = build_lstm(16, hidden_size, num_layers, dropout, bidirectional)
            self.lstm_3 = build_lstm(32, hidden_size, num_layers, dropout, bidirectional)
            self.lstm_4 = build_lstm(64, hidden_size, num_layers, dropout, bidirectional)
            self.lstm_5 = build_lstm(64, hidden_size, num_layers, dropout, bidirectional)
        # )

        ## Features selector models (MLPs)
        self.sel1 = Selector(self.code_length, 0)
        self.sel2 = Selector(self.code_length, 1)
        self.sel3 = Selector(self.code_length, 2)
        self.sel4 = Selector(self.code_length, 3)
        self.sel5 = Selector(self.code_length, 4)

        self.deepest_shape = (64, t // 4, h // 32, w // 32)

        # FC network
        dc, dt, dh, dw = self.deepest_shape
        # self.tdl = nn.Sequential(
        self.tdl_1 = TemporallySharedFullyConnection(in_features=(dc * dh * dw), out_features=512)
        self.tanh = nn.Tanh()
        self.tdl_2 = TemporallySharedFullyConnection(in_features=512, out_features=code_length)
        self.sigmoid = nn.Sigmoid()
        if load_lstm:
            self.lstm_tdl_1 = build_lstm(512, hidden_size, num_layers, dropout, bidirectional)
            self.lstm_tdl_2 = build_lstm(code_length, hidden_size, num_layers, dropout, bidirectional)
        # )

    def forward(self, x):
        # type: (torch.Tensor) -> torch.Tensor
        """
        Forward propagation.
        :param x: the input batch of patches.
        :return: the batch of latent vectors.
        """
        h = x
        # h = self.conv(h)
        o1 = self.conv_1(h)
        o2 = self.conv_2(o1)
        o3 = self.conv_3(o2)
        o4 = self.conv_4(o3)
        o5 = self.conv_5(o4)

        # Reshape for fully connected sub-network (flatten)
        c, t, height, width = self.deepest_shape
        h = torch.transpose(o5, 1, 2).contiguous()
        h = h.view(-1, t, (c * height * width))
        # o = self.tdl(h)
        o_tdl_1 = self.tdl_1(h)
        o_tdl_1_t = self.tanh(o_tdl_1)
        o_tdl_2 = self.tdl_2(o_tdl_1_t)
        o_tdl_2_s = self.sigmoid(o_tdl_2)

        if self.load_lstm:

            def shape_lstm_input(o):
                # (batch, channel, time, height, width) -> (batch, time, channel, height, width)
                o = o.permute(0, 2, 1, 3, 4)
                # average pool over the spatial dimensions, keeping the temporal axis
                kernel_size = (1, o.shape[-2], o.shape[-1])
                o = F.avg_pool3d(o, kernel_size).squeeze() if o.ndimension() > 3 else o
                # (batch, time, channel); restore the batch axis if squeeze() dropped it
                return o if o.ndim > 2 else o.unsqueeze(0)

            if self.use_selectors:
                o1_lstm, _ = self.lstm_1(self.sel1(o1))
                o2_lstm, _ = self.lstm_2(self.sel2(o2))
                o3_lstm, _ = self.lstm_3(self.sel3(o3))
                o4_lstm, _ = self.lstm_4(self.sel4(o4))
                o5_lstm, _ = self.lstm_5(self.sel5(o5))
            else:
                o1_lstm, _ = self.lstm_1(shape_lstm_input(o1))
                o2_lstm, _ = self.lstm_2(shape_lstm_input(o2))
                o3_lstm, _ = self.lstm_3(shape_lstm_input(o3))
                o4_lstm, _ = self.lstm_4(shape_lstm_input(o4))
                o5_lstm, _ = self.lstm_5(shape_lstm_input(o5))

            o1_tdl_lstm, _ = self.lstm_tdl_1(o_tdl_1_t)
            o2_tdl_lstm, _ = self.lstm_tdl_2(o_tdl_2_s)

            conv_lstms = [o1_lstm[:, -1], o2_lstm[:, -1], o3_lstm[:, -1], o4_lstm[:, -1], o5_lstm[:, -1]]
            tdl_lstms = [o1_tdl_lstm[:, -1], o2_tdl_lstm[:, -1]]

            d_lstms = dict(zip([f"conv_lstm_o_{i}" for i in range(len(conv_lstms))], conv_lstms))
            d_lstms.update(dict(zip([f"tdl_lstm_o_{i}" for i in range(len(tdl_lstms))], tdl_lstms)))
            return o_tdl_2_s, d_lstms

        else:
            return o_tdl_2_s


class ShanghaiTechDecoder(BaseModule):
    """
    ShanghaiTech model decoder.
    """
    def __init__(self, code_length, deepest_shape, output_shape):
        # type: (int, Tuple[int, int, int, int], Tuple[int, int, int, int]) -> None
        """
        Class constructor.
        :param code_length: the dimensionality of latent vectors.
        :param deepest_shape: the dimensionality of the encoder's deepest convolutional map.
        :param output_shape: the shape of ShanghaiTech samples.
        """
        super(ShanghaiTechDecoder, self).__init__()

        self.code_length = code_length
        self.deepest_shape = deepest_shape
        self.output_shape = output_shape

        dc, dt, dh, dw = deepest_shape

        activation_fn = nn.LeakyReLU()

        # FC network
        self.tdl = nn.Sequential(
            TemporallySharedFullyConnection(in_features=code_length, out_features=512),
            nn.Tanh(),
            TemporallySharedFullyConnection(in_features=512, out_features=(dc * dh * dw)),
            activation_fn
        )

        # Convolutional network
        self.conv = nn.Sequential(
            UpsampleBlock(channel_in=dc, channel_out=64,
                          activation_fn=activation_fn, stride=(2, 2, 2), output_padding=(1, 1, 1)),
            UpsampleBlock(channel_in=64, channel_out=32,
                          activation_fn=activation_fn, stride=(1, 2, 2), output_padding=(0, 1, 1)),
            UpsampleBlock(channel_in=32, channel_out=16,
                          activation_fn=activation_fn, stride=(2, 2, 2), output_padding=(1, 1, 1)),
            UpsampleBlock(channel_in=16, channel_out=8,
                          activation_fn=activation_fn, stride=(1, 2, 2), output_padding=(0, 1, 1)),
            UpsampleBlock(channel_in=8, channel_out=8,
                          activation_fn=activation_fn, stride=(1, 2, 2), output_padding=(0, 1, 1)),
            nn.Conv3d(in_channels=8, out_channels=output_shape[0], kernel_size=1)
        )

    def forward(self, x):
        # type: (torch.Tensor) -> torch.Tensor
        """
        Forward propagation.
        :param x: the batch of latent vectors.
        :return: the batch of reconstructions.
        """
        h = x
        h = self.tdl(h)

        # Reshape to encoder's deepest convolutional shape
        h = torch.transpose(h, 1, 2).contiguous()
        h = h.view(len(h), *self.deepest_shape)

        h = self.conv(h)
        o = h

        return o


class ShanghaiTech(BaseModule):
    """
    Model for ShanghaiTech video anomaly detection.
    """
    def __init__(self, input_shape, code_length, load_lstm=False, hidden_size=100, num_layers=1, dropout=0.0, bidirectional=False, use_selectors=False):
        # type: (Tuple[int, int, int, int], int, bool, int, int, float, bool, bool) -> None
        """
        Class constructor.
        :param input_shape: the shape of ShanghaiTech samples.
        :param code_length: the dimensionality of latent vectors.
        :param load_lstm: whether or not to attach LSTMs to the hooked layers.
        """
        super(ShanghaiTech, self).__init__()

        self.input_shape = input_shape
        self.code_length = code_length
        self.load_lstm = load_lstm

        # Build encoder
        self.encoder = ShanghaiTechEncoder(
            input_shape=input_shape,
            code_length=code_length,
            load_lstm=load_lstm,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,
            bidirectional=bidirectional,
            use_selectors=use_selectors
        )

        # Build decoder
        self.decoder = ShanghaiTechDecoder(
            code_length=code_length,
            deepest_shape=self.encoder.deepest_shape,
            output_shape=input_shape
        )

    def forward(self, x):
        # type: (torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
        """
        Forward propagation.
        :param x: the input batch of patches.
        :return: a tuple of torch.Tensors holding reconstructions and latent vectors
            (plus the dictionary of LSTM features when load_lstm is True).
        """
        h = x

        # Produce representations
        if self.load_lstm:
            z, d_lstms = self.encoder(h)
        else:
            z = self.encoder(h)

        # Reconstruct x
        x_r = self.decoder(z)
        x_r = x_r.view(-1, *self.input_shape)

        if self.load_lstm:
            return x_r, z, d_lstms
        else:
            return x_r, z
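

# --- Illustrative sketch, not part of the original pipeline --------------------
# A forward pass on a toy-sized clip (the real pipeline uses full ShanghaiTech
# clips); with load_lstm=True the model also returns the per-layer LSTM
# summaries that MOCCA exploits.
def _example_shanghaitech_forward():
    model = ShanghaiTech(input_shape=(3, 8, 64, 64), code_length=512, load_lstm=True)
    model.eval()
    x = torch.randn(1, 3, 8, 64, 64)    # (batch, channels, time, height, width)
    x_r, z, d_lstms = model(x)
    print(x_r.shape)                    # torch.Size([1, 3, 8, 64, 64])
    print(z.shape)                      # torch.Size([1, 2, 512]): one code per temporal slot
    print(sorted(d_lstms))              # conv_lstm_o_0..4 and tdl_lstm_o_0..1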
94 | 
95 | Parameters
96 | ----------
97 | net : torch.nn.Module
98 | Encoder network
99 | train_loader : DataLoader
100 | Data loader
101 | out_dir : str
102 | Path to checkpoint dir
103 | tb_writer : SummaryWriter
104 | Writer on tensorboard
105 | device : str
106 | Device
107 | ae_net_cehckpoint : str
108 | Path to autoencoder checkpoint
109 | idx_list_enc : list
110 | List of indices of the layers from which to extract features
111 | learning_rate : float
112 | Encoder learning rate
113 | weight_decay : float
114 | Weight decay
115 | lr_milestones : list
116 | Epochs at which to drop the learning rate
117 | epochs: int
118 | Number of training epochs
119 | nu : float
120 | Value of the trade-off parameter
121 | boundary : str
122 | Type of boundary
123 | debug: bool
124 | If True, enable debug mode
125 | 
126 | Returns
127 | -------
128 | net_cehckpoint : str
129 | Path to model checkpoint
130 | 
131 | """
132 | logger = logging.getLogger()
133 | 
134 | net.train().to(device)
135 | 
136 | # Hook model's layers
137 | feat_d = {}
138 | hooks = hook_model(idx_list_enc=idx_list_enc, model=net, dataset_name="cifar10", feat_d=feat_d)
139 | 
140 | optimizer = Adam(net.parameters(), lr=learning_rate, weight_decay=weight_decay)
141 | scheduler = MultiStepLR(optimizer, milestones=lr_milestones, gamma=0.1)
142 | 
143 | # Initialize hypersphere center c
144 | logger.info('Initializing center c...')
145 | c = init_center_c(feat_d=feat_d, train_loader=train_loader, net=net, device=device)
146 | logger.info('Center c initialized.')
147 | 
148 | R = {k: torch.tensor(0.0, device=device) for k in c.keys()}
149 | 
150 | logger.info('Start training...')
151 | warm_up_n_epochs = 10
152 | 
153 | for epoch in range(epochs):
154 | loss_epoch = 0.0
155 | n_batches = 0
156 | d_from_c = {}
157 | 
158 | for (data, _, _) in train_loader:
159 | data = data.to(device)
160 | 
161 | # Update network parameters via backpropagation: forward + backward + optimize
162 | _ = net(data)
163 | 
164 | dist, loss = eval_ad_loss(feat_d=feat_d, c=c, R=R, nu=nu, boundary=boundary)
165 | 
166 | for k in dist.keys():
167 | if k not in d_from_c:
168 | d_from_c[k] = 0
169 | d_from_c[k] += torch.mean(dist[k]).item()
170 | 
171 | optimizer.zero_grad()
172 | loss.backward()
173 | optimizer.step()
174 | 
175 | # Update hypersphere radius R on mini-batch distances
176 | # only after the warm up epochs
177 | if (boundary == 'soft') and (epoch >= warm_up_n_epochs):
178 | for k in R.keys():
179 | R[k].data = torch.tensor(
180 | np.quantile(np.sqrt(dist[k].clone().data.cpu().numpy()), 1 - nu),
181 | device=device
182 | )
183 | 
184 | loss_epoch += loss.item()
185 | n_batches += 1
186 | 
187 | scheduler.step()
188 | if epoch in lr_milestones:
189 | logger.info(' LR scheduler: new learning rate is %g' % float(scheduler.get_lr()[0]))
190 | 
191 | # log epoch statistics
192 | logger.info(f"TRAIN at epoch: {epoch} ==> Objective Loss: {loss_epoch/n_batches:.4f}")
193 | tb_writer.add_scalar('train/objective_loss', loss_epoch/n_batches, epoch)
194 | for en, k in enumerate(d_from_c.keys()):
195 | logger.info(
196 | f"[{k}] -- Radius: {R[k]:.4f} - "
197 | f"Dist from sphere center: {d_from_c[k]/n_batches:.4f}"
198 | )
199 | tb_writer.add_scalar(f'train/radius_{idx_list_enc[en]}', R[k], epoch)
200 | tb_writer.add_scalar(f'train/distance_c_sphere_{idx_list_enc[en]}', d_from_c[k]/n_batches, epoch)
201 | 
202 | logger.info('Finished training!!')
203 | 
204 | [h.remove() for h in hooks]
205 | 
206 | time_ = ae_net_cehckpoint.split('_')[-1].split('.p')[0]
207 | net_cehckpoint = os.path.join(out_dir, f'net_ckp_{time_}.pth')
208 | if debug:
209 | net_cehckpoint = './test_net_ckp.pth'
210 | torch.save({
211 | 'net_state_dict': net.state_dict(),
212 | 'R': R,
213 | 'c': c
214 | },
215 | net_cehckpoint
216 | )
217 | logger.info(f'Saved model at: {net_cehckpoint}')
218 | 
219 | return net_cehckpoint
220 | 
221 | 
222 | def test(net: torch.nn.Module, test_loader: DataLoader, R: dict, c: dict, device: str, idx_list_enc: list, boundary: str) -> float:
223 | """Test the Encoder network.
224 | 
225 | Parameters
226 | ----------
227 | net : torch.nn.Module
228 | Encoder network
229 | test_loader : DataLoader
230 | Data loader
231 | R : dict
232 | Dictionary containing the values of the radiuses for each layer
233 | c : dict
234 | Dictionary containing the values of the hyperspheres' center for each layer
235 | device : str
236 | Device
237 | idx_list_enc : list
238 | List of indices of the layers from which to extract features
239 | boundary : str
240 | Type of boundary
241 | 
242 | 
243 | 
244 | Returns
245 | -------
246 | test_auc : float
247 | AUC
248 | 
249 | """
250 | logger = logging.getLogger()
251 | 
252 | # Hook model's layers
253 | feat_d = {}
254 | hooks = hook_model(idx_list_enc=idx_list_enc, model=net, dataset_name="cifar10", feat_d=feat_d)
255 | 
256 | # Testing
257 | logger.info('Start testing...')
258 | idx_label_score = []
259 | net.eval().to(device)
260 | with torch.no_grad():
261 | for data in test_loader:
262 | inputs, labels, idx = data
263 | inputs = inputs.to(device)
264 | 
265 | _ = net(inputs)
266 | 
267 | scores = get_scores(feat_d=feat_d, c=c, R=R, device=device, boundary=boundary)
268 | 
269 | # Save triples of (idx, label, score) in a list
270 | idx_label_score += list(zip(idx.cpu().data.numpy().tolist(),
271 | labels.cpu().data.numpy().tolist(),
272 | scores.cpu().data.numpy().tolist()))
273 | 
274 | [h.remove() for h in hooks]
275 | 
276 | # Compute AUC
277 | _, labels, scores = zip(*idx_label_score)
278 | labels = np.array(labels)
279 | scores = np.array(scores)
280 | test_auc = roc_auc_score(labels, scores)
281 | logger.info('Test set AUC: {:.2f}%'.format(100. * test_auc))
282 | 
283 | logger.info('Finished testing!!')
284 | 
285 | return 100. * test_auc
286 | 
287 | 
288 | def hook_model(idx_list_enc: list, model: torch.nn.Module, dataset_name: str, feat_d: dict) -> list:
289 | """Create hooks for model's layers.
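Each hook writes a spatially average-pooled copy of its block's output into feat_d, keyed by a zero-padded position index ('00', '01', ...). A minimal sketch of the intended use (the 4-index list assumes the CIFAR-10 encoder branch below, which exposes conv1, conv2, conv3 and fc1):

    feat_d = {}
    hooks = hook_model([0, 1, 2, 3], net, "cifar10", feat_d)
    _ = net(batch)               # feat_d now holds one pooled tensor per hooked block
    [h.remove() for h in hooks]  # detach the hooks when done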
290 | 
291 | Parameters
292 | ----------
293 | idx_list_enc : list
294 | List of indices of the layers from which to extract features
295 | model : torch.nn.Module
296 | Encoder network
297 | dataset_name : str
298 | Name of the dataset
299 | feat_d : dict
300 | Dictionary containing features
301 | 
302 | Returns
303 | -------
304 | hooks : list of registered forward hooks
305 | 
306 | """
307 | if dataset_name == 'mnist':
308 | 
309 | blocks_ = [model.conv1, model.conv2, model.fc1]
310 | else:
311 | 
312 | blocks_ = [model.conv1, model.conv2, model.conv3, model.fc1]
313 | 
314 | if not (isinstance(idx_list_enc, list) and len(idx_list_enc) != 0):
315 | raise ValueError(f"idx_list_enc must be a non-empty list of layer indices, got: {idx_list_enc}")
316 | assert len(idx_list_enc) <= len(blocks_), f"Too many indices for decoder: {idx_list_enc} - for {len(blocks_)} blocks"
317 | blocks = [blocks_[idx] for idx in idx_list_enc]
318 | blocks_idx = dict(zip(blocks, map('{:02d}'.format, range(len(blocks)))))
319 | 
320 | def hook_func(module, input, output):
321 | block_num = blocks_idx[module]
322 | extracted = output
323 | if extracted.ndimension() > 2:
324 | extracted = F.avg_pool2d(extracted, extracted.shape[-2:])
325 | feat_d[block_num] = extracted.squeeze()
326 | 
327 | return [b.register_forward_hook(hook_func) for b in blocks]
328 | 
329 | 
330 | @torch.no_grad()
331 | def init_center_c(feat_d: dict, train_loader: DataLoader, net: torch.nn.Module, device: str, eps: float = 0.1) -> dict:
332 | """Initialize hyperspheres' center c as the mean from an initial forward pass on the data.
333 | 
334 | Parameters
335 | ----------
336 | feat_d : dict
337 | Dictionary containing features
338 | train_loader : DataLoader
339 | Training data loader
340 | net : torch.nn.Module
341 | Encoder network
342 | device : str
343 | Device
344 | eps : float = 0.1
345 | If a center is too close to 0, set to +-eps
346 | Returns
347 | -------
348 | c : dict
349 | hyperspheres' center
350 | 
351 | """
352 | n_samples = 0
353 | 
354 | net.eval()
355 | 
356 | for idx, (data, _, _) in enumerate(tqdm(train_loader, desc='init hypersphere centers', total=len(train_loader), leave=False)):
357 | data = data.to(device)
358 | 
359 | outputs = net(data)
360 | n_samples += outputs.shape[0]
361 | 
362 | if idx == 0:
363 | c = {k: torch.zeros_like(feat_d[k][-1], device=device) for k in feat_d.keys()}
364 | 
365 | for k in feat_d.keys():
366 | c[k] += torch.sum(feat_d[k], dim=0)
367 | 
368 | for k in c.keys():
369 | c[k] = c[k] / n_samples
370 | # If c_i is too close to 0, set to +-eps. Reason: a zero unit can be trivially matched with zero weights.
371 | c[k][(abs(c[k]) < eps) & (c[k] < 0)] = -eps
372 | c[k][(abs(c[k]) < eps) & (c[k] > 0)] = eps
373 | 
374 | return c
375 | 
376 | 
377 | def eval_ad_loss(feat_d: dict, c: dict, R: dict, nu: float, boundary: str) -> Tuple[dict, torch.Tensor]:
378 | """Eval training loss.
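For the soft boundary this is, per hooked layer k (a restatement of the code below, with dist_k the squared distances from the center c_k):

    loss = 1 + sum_k [ R_k**2 + (1/nu) * mean(max(0, dist_k - R_k**2)) ]

while the hard boundary simply accumulates the mean squared distances, loss = 1 + sum_k mean(dist_k).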
379 | 
380 | Parameters
381 | ----------
382 | feat_d : dict
383 | Dictionary containing features
384 | c : dict
385 | Dictionary hyperspheres' center
386 | R : dict
387 | Dictionary hyperspheres' radius
388 | nu : float
389 | Value of the trade-off parameter
390 | boundary : str
391 | Type of boundary
392 | 
393 | Returns
394 | -------
395 | dist : dict
396 | Dictionary containing the squared distances of the feature vectors from the hypersphere's center for each layer
397 | loss : torch.Tensor
398 | Loss value
399 | """
400 | dist = {}
401 | loss = 1  # constant offset: it keeps the loss positive and does not affect gradients
402 | 
403 | for k in feat_d.keys():
404 | 
405 | dist[k] = torch.sum((feat_d[k] - c[k].unsqueeze(0)) ** 2, dim=1)
406 | if boundary == 'soft':
407 | 
408 | scores = dist[k] - R[k] ** 2
409 | loss += R[k] ** 2 + (1 / nu) * torch.mean(torch.max(torch.zeros_like(scores), scores))
410 | else:
411 | 
412 | loss += torch.mean(dist[k])
413 | 
414 | return dist, loss
415 | 
416 | 
417 | def get_scores(feat_d: dict, c: dict, R: dict, device: str, boundary: str) -> torch.Tensor:
418 | """Eval anomaly score.
419 | 
420 | Parameters
421 | ----------
422 | feat_d : dict
423 | Dictionary containing features
424 | c : dict
425 | Dictionary hyperspheres' center
426 | R : dict
427 | Dictionary hyperspheres' radius
428 | device : str
429 | Device
430 | boundary : str
431 | Type of boundary
432 | 
433 | Returns
434 | -------
435 | scores : torch.Tensor
436 | Per-sample anomaly scores
437 | 
438 | """
439 | dist, _ = eval_ad_loss(feat_d, c, R, 1, boundary)
440 | shape = dist[list(dist.keys())[0]].shape[0]
441 | scores = torch.zeros((shape,), device=device)
442 | 
443 | for k in dist.keys():
444 | if boundary == 'soft':
445 | 
446 | scores += dist[k] - R[k] ** 2
447 | else:
448 | 
449 | scores += dist[k]
450 | 
451 | return scores/len(list(dist.keys()))
452 | 
--------------------------------------------------------------------------------
/trainers/trainer_mvtec.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 | import logging
5 | import numpy as np
6 | from tqdm import tqdm
7 | from typing import Tuple
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | from torch.optim import Adam, SGD
12 | from torch.utils.data import DataLoader
13 | from torch.optim.lr_scheduler import MultiStepLR
14 | 
15 | from tensorboardX import SummaryWriter
16 | from sklearn.metrics import roc_curve, roc_auc_score, auc
17 | 
18 | 
19 | def pretrain(ae_net: nn.Module, train_loader: DataLoader, out_dir: str, tb_writer: SummaryWriter, device: str, ae_learning_rate: float, ae_weight_decay: float, ae_lr_milestones: list, ae_epochs: int, log_frequency: int, batch_accumulation: int, debug: bool) -> str:
20 | """Train the full AutoEncoder network.
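The per-sample reconstruction score is the squared error summed over every non-batch dimension; an equivalent sketch of the line used in the loop below:

    scores = ((x_r - data) ** 2).flatten(1).sum(-1)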
21 | 
22 | Parameters
23 | ----------
24 | ae_net : nn.Module
25 | AutoEncoder network
26 | train_loader : DataLoader
27 | Data loader
28 | out_dir : str
29 | Path to checkpoint dir
30 | tb_writer : SummaryWriter
31 | Writer on tensorboard
32 | device : str
33 | Device
34 | ae_learning_rate : float
35 | AutoEncoder learning rate
36 | ae_weight_decay : float
37 | Weight decay
38 | ae_lr_milestones : list
39 | Epochs at which to drop the learning rate
40 | ae_epochs: int
41 | Number of training epochs
42 | log_frequency : int
43 | Number of iterations after which to log
44 | batch_accumulation : int
45 | Number of iterations over which gradients are accumulated
46 | debug : bool
47 | If True, only use the first few batches
48 | 
49 | Returns
50 | -------
51 | ae_net_cehckpoint : str
52 | Path to model checkpoint
53 | 
54 | """
55 | logger = logging.getLogger()
56 | 
57 | ae_net = ae_net.train().to(device)
58 | 
59 | optimizer = Adam(ae_net.parameters(), lr=ae_learning_rate, weight_decay=ae_weight_decay)
60 | scheduler = MultiStepLR(optimizer, milestones=ae_lr_milestones, gamma=0.1)
61 | 
62 | # Independent index to save stats on tensorboard
63 | kk = 1
64 | 
65 | # Counter for the batch accumulation steps
66 | j_ba_steps = 0
67 | 
68 | for epoch in range(ae_epochs):
69 | loss_epoch = 0.0
70 | n_batches = 0
71 | optimizer.zero_grad()
72 | 
73 | for idx, (data, _) in enumerate(tqdm(train_loader, total=len(train_loader), leave=False)):
74 | if debug and idx == 5: break
75 | 
76 | data = data.to(device)
77 | 
78 | x_r = ae_net(data)
79 | 
80 | scores = torch.sum((x_r - data) ** 2, dim=tuple(range(1, x_r.dim())))
81 | 
82 | loss = torch.mean(scores)
83 | loss.backward()
84 | 
85 | j_ba_steps += 1
86 | if batch_accumulation != -1:
87 | if j_ba_steps % batch_accumulation == 0:
88 | optimizer.step()
89 | optimizer.zero_grad()
90 | j_ba_steps = 0
91 | else:
92 | optimizer.step()
93 | optimizer.zero_grad()
94 | 
95 | # Sanity check
96 | if np.isnan(loss.item()):
97 | logger.info("Found nan values in the loss")
98 | sys.exit(0)
99 | 
100 | loss_epoch += loss.item()
101 | n_batches += 1
102 | 
103 | if idx != 0 and idx % ((len(train_loader)//log_frequency)+1) == 0:
104 | logger.info(f"PreTrain at epoch: {epoch+1} ([{idx}]/[{len(train_loader)}]) ==> Recon Loss: {loss_epoch/idx:.4f}")
105 | tb_writer.add_scalar('pretrain/recon_loss', loss_epoch/idx, kk)
106 | kk += 1
107 | 
108 | scheduler.step()
109 | if epoch in ae_lr_milestones:
110 | logger.info(' LR scheduler: new learning rate is %g' % float(scheduler.get_lr()[0]))
111 | 
112 | ae_net_cehckpoint = os.path.join(out_dir, 'best_ae_ckp.pth')
113 | torch.save({'ae_state_dict': ae_net.state_dict()}, ae_net_cehckpoint)
114 | logger.info(f'Saved best autoencoder so far at: {ae_net_cehckpoint}')
115 | 
116 | logger.info('Finished pretraining.')
117 | 
118 | return ae_net_cehckpoint
119 | 
120 | 
121 | def train(net: torch.nn.Module, train_loader: DataLoader, centers: dict, out_dir: str, tb_writer: SummaryWriter, device: str, learning_rate: float, weight_decay: float, lr_milestones: list, epochs: int, nu: float, boundary: str, batch_accumulation: int, warm_up_n_epochs: int, log_frequency: int, debug: bool) -> str:
122 | """Train the Encoder network on the one class task.
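When batch_accumulation is set to some k != -1, gradients from k consecutive backward passes are summed before a single optimizer step, emulating a k-times larger batch; an equivalent sketch of the pattern used below:

    loss.backward()
    if (step + 1) % k == 0:
        optimizer.step()
        optimizer.zero_grad()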
123 | 
124 | Parameters
125 | ----------
126 | net : nn.Module
127 | Encoder network
128 | train_loader : DataLoader
129 | Data loader
130 | centers : dict
131 | Dictionary containing hyperspheres' center at each layer
132 | out_dir : str
133 | Path to checkpoint dir
134 | tb_writer : SummaryWriter
135 | Writer on tensorboard
136 | device : str
137 | Device
138 | learning_rate : float
139 | Encoder learning rate
140 | weight_decay : float
141 | Weight decay
142 | lr_milestones : list
143 | Epochs at which to drop the learning rate
144 | epochs: int
145 | Number of training epochs
146 | nu : float
147 | Value of the trade-off parameter
148 | boundary : str
149 | Type of boundary
150 | batch_accumulation : int
151 | Number of iterations over which gradients are accumulated
152 | warm_up_n_epochs : int
153 | Number of warm up epochs before updating the radiuses R
154 | log_frequency: int
155 | Number of iterations after which to log
156 | debug : bool
157 | If True, only use the first few batches
158 | 
159 | Returns
160 | -------
161 | net_cehckpoint : str
162 | Path to model checkpoint
163 | 
164 | """
165 | logger = logging.getLogger()
166 | 
167 | optimizer = Adam(net.parameters(), lr=learning_rate, weight_decay=weight_decay)
168 | scheduler = MultiStepLR(optimizer, milestones=lr_milestones, gamma=0.1)
169 | 
170 | # Init spheres' radius
171 | R = {k: torch.tensor(0.0, device=device) for k in centers.keys()}
172 | 
173 | # Training
174 | logger.info('Start training...')
175 | kk = 1
176 | net.train().to(device)
177 | best_loss = 1.e6
178 | 
179 | for epoch in range(epochs):
180 | j = 0
181 | loss_epoch = 0.0
182 | n_batches = 0
183 | d_from_c = {}
184 | optimizer.zero_grad()
185 | 
186 | for idx, (data, _) in enumerate(tqdm(train_loader, total=len(train_loader), leave=False)):
187 | if debug and idx == 5: break
188 | 
189 | data = data.to(device)
190 | 
191 | zipped = net(data)
192 | 
193 | dist, loss = eval_ad_loss(zipped=zipped, c=centers, R=R, nu=nu, boundary=boundary)
194 | 
195 | for k in dist.keys():
196 | if k not in d_from_c:
197 | d_from_c[k] = 0
198 | d_from_c[k] += torch.mean(dist[k]).item()
199 | 
200 | loss.backward()
201 | j += 1
202 | if batch_accumulation != -1:
203 | if j == batch_accumulation:
204 | j = 0
205 | optimizer.step()
206 | optimizer.zero_grad()
207 | else:
208 | optimizer.step()
209 | optimizer.zero_grad()
210 | 
211 | # Update hypersphere radius R on mini-batch distances
212 | if (boundary == 'soft') and (epoch >= warm_up_n_epochs):
213 | # R.data = torch.tensor(get_radius(dist, nu), device=device)
214 | for k in R.keys():
215 | R[k].data = torch.tensor(
216 | np.quantile(np.sqrt(dist[k].clone().data.cpu().numpy()), 1 - nu),
217 | device=device
218 | )
219 | 
220 | loss_epoch += loss.item()
221 | n_batches += 1
222 | 
223 | if np.isnan(loss.item()):
224 | logger.info("Found nan values in the loss")
225 | sys.exit(0)
226 | 
227 | if idx != 0 and idx % ((len(train_loader)//log_frequency)+1) == 0:
228 | # log epoch statistics
229 | logger.info(f"TRAIN at epoch: {epoch+1} ([{idx}]/[{len(train_loader)}]) ==> Objective Loss: {loss_epoch/idx:.4f}")
230 | tb_writer.add_scalar('train/objective_loss', loss_epoch/idx, kk)
231 | for _, k in enumerate(d_from_c.keys()):
232 | logger.info(
233 | f"[{k}] -- Radius: {R[k]:.4f} - "
234 | f"Dist from sphere center: {d_from_c[k]/idx:.4f}"
235 | )
236 | tb_writer.add_scalar(f'train/radius_{k}', R[k], kk)
237 | tb_writer.add_scalar(f'train/distance_c_sphere_{k}', d_from_c[k]/idx, kk)
238 | kk += 1
239 | 
240 | scheduler.step()
241 | if epoch in lr_milestones:
242 | logger.info(' LR scheduler: new learning rate is %g' % float(scheduler.get_lr()[0]))
243 | 
244 | if (loss_epoch/len(train_loader)) <= best_loss:
245 | net_cehckpoint = os.path.join(out_dir, 'best_oc_model.pth')
246 | best_loss = (loss_epoch/len(train_loader))
247 | torch.save({
248 | 'net_state_dict': net.state_dict(),
249 | 'R': R,
250 | 'c': centers
251 | },
252 | net_cehckpoint
253 | )
254 | logger.info(f'Saved best model so far at: {net_cehckpoint}')
255 | 
256 | logger.info('Finished training!!')
257 | 
258 | return net_cehckpoint
259 | 
260 | 
261 | def test(normal_class: str, is_texture: bool, net: nn.Module, test_loader: DataLoader, R: dict, c: dict, device: str, boundary: str, debug: bool) -> Tuple[float, float]:
262 | """Test the Encoder network.
263 | 
264 | Parameters
265 | ----------
266 | normal_class : str
267 | Name of the class under test
268 | is_texture : bool
269 | True if the input data belong to a texture-type class
270 | net : nn.Module
271 | Encoder network
272 | test_loader : DataLoader
273 | Data loader
274 | R : dict
275 | Dictionary containing the values of the radiuses for each layer
276 | c : dict
277 | Dictionary containing the values of the hyperspheres' center for each layer
278 | device : str
279 | Device
280 | boundary : str
281 | Type of boundary
282 | debug : bool
283 | If True, only use the first few batches
284 | 
285 | Returns
286 | -------
287 | test_auc : float
288 | AUC
289 | balanced_accuracy : float
290 | Maximum Balanced Accuracy
291 | 
292 | """
293 | logger = logging.getLogger()
294 | 
295 | # Testing
296 | logger.info('Start testing...')
297 | 
298 | idx_label_score = []
299 | 
300 | net.eval().to(device)
301 | 
302 | with torch.no_grad():
303 | for idx, (data, labels) in enumerate(tqdm(test_loader, total=len(test_loader), desc=f"Testing class: {normal_class}", leave=False)):
304 | if debug and idx == 5: break
305 | 
306 | data = data.to(device)
307 | 
308 | if is_texture:
309 | ## Split each texture image into non-overlapping 64x64 patches ==> the anomaly score is max{score(patches)}
310 | _, _, h, w = data.shape
311 | assert h == w, "Height and Width are different!!!"
312 | patch_size = 64
313 | 
314 | patches = [
315 | data[:, :, h_:h_+patch_size, w_:w_+patch_size]
316 | for h_ in range(0, h, patch_size)
317 | for w_ in range(0, w, patch_size)
318 | ]
319 | 
320 | patches = torch.stack(patches, dim=1) # shape = (b_size, nb_patches, ch, h, w)
321 | 
322 | scores = torch.stack([
323 | get_scores(
324 | zipped=net(batch),
325 | c=c, R=R,
326 | device=device,
327 | boundary=boundary,
328 | is_texture=is_texture)
329 | for batch in patches
330 | ]) # batch.shape = (nb_patches, ch, h, w)
331 | 
332 | else:
333 | scores = get_scores(zipped=net(data), c=c, R=R, device=device, boundary=boundary, is_texture=is_texture)
334 | 
335 | idx_label_score += list(
336 | zip(
337 | labels.cpu().data.numpy().tolist(),
338 | scores.cpu().data.numpy().tolist()
339 | )
340 | )
341 | 
342 | # Compute AUC
343 | labels, scores = zip(*idx_label_score)
344 | labels = np.array(labels)
345 | scores = np.array(scores)
346 | 
347 | #test_auc = roc_auc_score(labels, scores)
348 | fpr, tpr, _ = roc_curve(labels, scores)
349 | balanced_accuracy = np.max((tpr + (1 - fpr)) / 2)
350 | auroc = auc(fpr, tpr)
351 | logger.info(f'Test set results ===> AUC: {auroc:.4f} --- maxB: {balanced_accuracy:.4f}')
352 | logger.info('Finished testing!!')
353 | 
354 | return auroc, balanced_accuracy
355 | 
356 | 
357 | def eval_ad_loss(zipped: list, c: dict, R: dict, nu: float, boundary: str) -> Tuple[dict, torch.Tensor]:
358 | """Evaluate encoder loss in the one class setting.
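Unlike the CIFAR-10 trainer, which collects features through forward hooks, here the features arrive already paired: the MVTec network's forward pass yields (layer name, feature tensor) tuples, and the loss is otherwise the same soft/hard boundary objective.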
359 | 
360 | Parameters
361 | ----------
362 | zipped : list
363 | List of (layer name, feature tensor) pairs output by the network
364 | c : dict
365 | Dictionary of layers centroids
366 | R : dict
367 | Dictionary of layers radiuses
368 | nu : float
369 | Trade-off parameter
370 | boundary: str
371 | Type of boundary
372 | 
373 | Returns
374 | -------
375 | dist : dict
376 | Dictionary containing, for each layer, the squared distances of the features from the hypersphere center
377 | loss : torch.Tensor
378 | Training loss
379 | 
380 | """
381 | dist = {}
382 | 
383 | loss = 1  # constant offset: it keeps the loss positive and does not affect gradients
384 | 
385 | for (k, v) in zipped:
386 | dist[k] = torch.sum((v - c[k].unsqueeze(0)) ** 2, dim=1)
387 | 
388 | if boundary == 'soft':
389 | scores = dist[k] - R[k] ** 2
390 | loss += R[k] ** 2 + (1 / nu) * torch.mean(torch.max(torch.zeros_like(scores), scores))
391 | 
392 | else:
393 | loss += torch.mean(dist[k])
394 | 
395 | return dist, loss
396 | 
397 | 
398 | def get_scores(zipped: list, c: dict, R: dict, device: str, boundary: str, is_texture: bool) -> torch.Tensor:
399 | """Evaluate anomaly score.
400 | 
401 | Parameters
402 | ----------
403 | zipped : list
404 | List of (layer name, feature tensor) pairs output by the network
405 | c : dict
406 | Dictionary of layers centroids
407 | R : dict
408 | Dictionary of layers radiuses
409 | device : str
410 | Device on which to run the computation
411 | boundary: str
412 | Type of boundary
413 | is_texture : bool
414 | True if images belong to texture-type classes
415 | 
416 | Returns
417 | -------
418 | scores : torch.Tensor
419 | Anomaly score for each image
420 | 
421 | """
422 | dist = {item[0]: torch.norm(item[1] - c[item[0]].unsqueeze(0), dim=1) for item in zipped}
423 | 
424 | shape = dist[list(dist.keys())[0]].shape[0]
425 | scores = torch.zeros((shape,), device=device)
426 | 
427 | for k in dist.keys():
428 | 
429 | if boundary == 'soft':
430 | scores += dist[k] - R[k] # R[k] is a number not a vector
431 | 
432 | else:
433 | scores += dist[k]
434 | 
435 | return scores.max()/len(list(dist.keys())) if is_texture else scores/len(list(dist.keys()))
436 | 
--------------------------------------------------------------------------------
/main_shanghaitech.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import random
4 | import logging
5 | import argparse
6 | import numpy as np
7 | import matplotlib.pyplot as plt
8 | from torchvision.utils import make_grid
9 | 
10 | import torch
11 | import torch.nn as nn
12 | 
13 | 
14 | from tensorboardX import SummaryWriter
15 | 
16 | from models.shanghaitech_model import ShanghaiTech, ShanghaiTechEncoder, ShanghaiTechDecoder
17 | 
18 | from datasets.data_manager import DataManager
19 | from datasets.shanghaitech_test import VideoAnomalyDetectionResultHelper
20 | 
21 | from trainers.trainer_shanghaitech import pretrain, train
22 | from utils import set_seeds, get_out_dir, eval_spheres_centers, load_mvtec_model_from_checkpoint, extract_arguments_from_checkpoint
23 | 
24 | 
25 | def main(args):
26 | # Set seed
27 | set_seeds(args.seed)
28 | 
29 | # Get the device
30 | device = "cuda" if torch.cuda.is_available() else "cpu"
31 | 
32 | if args.disable_logging:
33 | logging.disable(level=logging.INFO)
34 | 
35 | 
36 | ## Init logger & print training/warm-up summary
37 | logging.basicConfig(
38 | level=logging.INFO,
39 | format="%(asctime)s | %(message)s",
40 | handlers=[
41 | logging.FileHandler('./training.log'),
42 | logging.StreamHandler()
43 | ])
44 | logger = logging.getLogger()
45 | 
46 | 
47 | if args.train or args.pretrain or args.end_to_end_training:
48 | # If the list of layers from which to extract the features is empty, then use the last one (after the sigmoid)
49 | if len(args.idx_list_enc) == 0: args.idx_list_enc = [6]
50 | 
51 | logger.info(
52 | "Start run with params:\n"
53 | f"\n\t\t\t\tEnd to end training : {args.end_to_end_training}"
54 | f"\n\t\t\t\tPretrain model : {args.pretrain}"
55 | f"\n\t\t\t\tTrain model : {args.train}"
56 | f"\n\t\t\t\tTest model : {args.test}"
57 | f"\n\t\t\t\tBatch size : {args.batch_size}\n"
58 | f"\n\t\t\t\tAutoEncoder Pretraining"
59 | f"\n\t\t\t\tPretrain epochs : {args.ae_epochs}"
60 | f"\n\t\t\t\tAE-Learning rate : {args.ae_learning_rate}"
61 | f"\n\t\t\t\tAE-milestones : {args.ae_lr_milestones}"
62 | f"\n\t\t\t\tAE-Weight decay : {args.ae_weight_decay}\n"
63 | f"\n\t\t\t\tEncoder Training"
64 | f"\n\t\t\t\tClip length : {args.clip_length}"
65 | f"\n\t\t\t\tBoundary : {args.boundary}"
66 | f"\n\t\t\t\tTrain epochs : {args.epochs}"
67 | f"\n\t\t\t\tLearning rate : {args.learning_rate}"
68 | f"\n\t\t\t\tMilestones : {args.lr_milestones}"
69 | f"\n\t\t\t\tUse selectors : {args.use_selectors}"
70 | f"\n\t\t\t\tWeight decay : {args.weight_decay}"
71 | f"\n\t\t\t\tCode length : {args.code_length}"
72 | f"\n\t\t\t\tNu : {args.nu}"
73 | f"\n\t\t\t\tEncoder list : {args.idx_list_enc}\n"
74 | f"\n\t\t\t\tLSTMs"
75 | f"\n\t\t\t\tLoad LSTMs : {args.load_lstm}"
76 | f"\n\t\t\t\tBidirectional : {args.bidirectional}"
77 | f"\n\t\t\t\tHidden size : {args.hidden_size}"
78 | f"\n\t\t\t\tNumber of layers : {args.num_layers}"
79 | f"\n\t\t\t\tDropout prob : {args.dropout}\n"
80 | )
81 | else:
82 | if args.model_ckp is None:
83 | logger.info("CANNOT TEST MODEL WITHOUT A VALID CHECKPOINT")
84 | sys.exit(0)
85 | 
86 | 
87 | 
88 | 
89 | # Init DataHolder class
90 | data_holder = DataManager(
91 | dataset_name='ShanghaiTech',
92 | data_path=args.data_path,
93 | normal_class=None,
94 | only_test=args.test
95 | ).get_data_holder()
96 | 
97 | # Load data
98 | train_loader, _ = data_holder.get_loaders(
99 | batch_size=args.batch_size,
100 | shuffle_train=True,
101 | pin_memory=device=="cuda",
102 | num_workers=args.n_workers
103 | )
104 | # Print data infos
105 | only_test = args.test and not args.train and not args.pretrain
106 | logger.info("Dataset info:")
107 | logger.info(
108 | "\n"
109 | f"\n\t\t\t\tBatch size : {args.batch_size}"
110 | )
111 | if not only_test:
112 | logger.info(
113 | f"TRAIN:"
114 | f"\n\t\t\t\tNumber of clips : {len(train_loader.dataset)}"
115 | f"\n\t\t\t\tNumber of batches : {len(train_loader.dataset)//args.batch_size}"
116 | )
117 | 
118 | 
119 | ########################################################################################
120 | ####### Train the AUTOENCODER on the RECONSTRUCTION task and then train only the #######
121 | ########################## ENCODER on the ONE CLASS OBJECTIVE ##########################
122 | ########################################################################################
123 | ae_net_checkpoint = None
124 | if args.pretrain and not args.end_to_end_training:
125 | out_dir, tmp = get_out_dir(args, pretrain=True, aelr=None, dset_name='ShanghaiTech')
126 | 
127 | tb_writer = SummaryWriter(os.path.join(args.output_path, "ShanghaiTech", 'tb_runs_pretrain', tmp))
128 | # Init AutoEncoder
129 | ae_net = ShanghaiTech(data_holder.shape, args.code_length, use_selectors=args.use_selectors)
130 | ### PRETRAIN
131 | ae_net_checkpoint = pretrain(ae_net, train_loader, out_dir, tb_writer, device, args)
132 | tb_writer.close()
133 | 
134 | net_checkpoint = None
135 | 
136 | if args.train and not args.end_to_end_training:
137 | if ae_net_checkpoint is None:
138 | if args.model_ckp is None:
139 | logger.info("CANNOT TRAIN MODEL WITHOUT A VALID CHECKPOINT")
140 | sys.exit(0)
141 | ae_net_checkpoint = args.model_ckp
142 | 
143 | aelr = float(ae_net_checkpoint.split('/')[-2].split('-')[4].split('_')[-1])
144 | 
145 | out_dir, tmp = get_out_dir(args, pretrain=False, aelr=aelr, dset_name='ShanghaiTech')
146 | tb_writer = SummaryWriter(os.path.join(args.output_path, "ShanghaiTech", 'tb_runs_train', tmp))
147 | 
148 | # Init Encoder
149 | net = ShanghaiTechEncoder(data_holder.shape, args.code_length, args.load_lstm, args.hidden_size, args.num_layers, args.dropout, args.bidirectional, args.use_selectors)
150 | 
151 | # Load encoder weight from autoencoder
152 | net_dict = net.state_dict()
153 | logger.info(f"Loading encoder from: {ae_net_checkpoint}")
154 | ae_net_dict = torch.load(ae_net_checkpoint, map_location=lambda storage, loc: storage)['ae_state_dict']
155 | 
156 | # Filter out decoder network keys
157 | st_dict = {k: v for k, v in ae_net_dict.items() if k in net_dict}
158 | # Overwrite values in the existing state_dict
159 | net_dict.update(st_dict)
160 | # Load the new state_dict
161 | net.load_state_dict(net_dict)
162 | 
163 | ### TRAIN
164 | net_checkpoint = train(net, train_loader, out_dir, tb_writer, device, ae_net_checkpoint, args)
165 | tb_writer.close()
166 | 
167 | ########################################################################################
168 | ########################################################################################
169 | 
170 | ########################################################################################
171 | ################### Train the AUTOENCODER on the combined objective: ###################
172 | ############################## RECONSTRUCTION + ONE CLASS ##############################
173 | ########################################################################################
174 | if args.end_to_end_training:
175 | 
176 | out_dir, tmp = get_out_dir(args, pretrain=False, aelr=int(args.learning_rate), dset_name='ShanghaiTech')
177 | 
178 | 
179 | tb_writer = SummaryWriter(os.path.join(args.output_path, "ShanghaiTech", 'tb_runs_train_end_to_end', tmp))
180 | # Init AutoEncoder
181 | ae_net = ShanghaiTech(data_holder.shape, args.code_length, args.load_lstm, args.hidden_size, args.num_layers, args.dropout, args.bidirectional, args.use_selectors)
182 | ### End to end TRAIN
183 | net_checkpoint = train(ae_net, train_loader, out_dir, tb_writer, device, None, args)
184 | tb_writer.close()
185 | ########################################################################################
186 | ########################################################################################
187 | 
188 | ########################################################################################
189 | ###################################### Model test ######################################
190 | ########################################################################################
191 | if args.test:
192 | if net_checkpoint is None:
193 | net_checkpoint = args.model_ckp
194 | 
195 | code_length, batch_size, boundary, use_selectors, idx_list_enc, \
196 | load_lstm, hidden_size, num_layers, dropout, bidirectional, \
197 | dataset_name, train_type = extract_arguments_from_checkpoint(net_checkpoint)
198 | 
199 | # Init dataset
200 | dataset = data_holder.get_test_data()
201 | if train_type == "train_end_to_end":
202 | # Init Autoencoder
203 | net = ShanghaiTech(data_holder.shape, args.code_length, load_lstm, hidden_size, num_layers, dropout, bidirectional, use_selectors)
204 | else:
205 | # Init Encoder ONLY
206 | net = ShanghaiTechEncoder(dataset.shape, code_length, load_lstm, hidden_size, num_layers, dropout, bidirectional, use_selectors)
207 | st_dict = torch.load(net_checkpoint)
208 | 
209 | net.load_state_dict(st_dict['net_state_dict'])
210 | logger.info(f"Loaded model from: {net_checkpoint}")
211 | logger.info(
212 | f"Start test with params:"
213 | f"\n\t\t\t\tDataset : {dataset_name}"
214 | f"\n\t\t\t\tCode length : {code_length}"
215 | f"\n\t\t\t\tEnc layer list : {idx_list_enc}"
216 | f"\n\t\t\t\tBoundary : {boundary}"
217 | f"\n\t\t\t\tUse Selectors : {use_selectors}"
218 | f"\n\t\t\t\tBatch size : {batch_size}"
219 | f"\n\t\t\t\tN workers : {args.n_workers}"
220 | f"\n\t\t\t\tLoad LSTMs : {load_lstm}"
221 | f"\n\t\t\t\tHidden size : {hidden_size}"
222 | f"\n\t\t\t\tNum layers : {num_layers}"
223 | f"\n\t\t\t\tBidirectional : {bidirectional}"
224 | f"\n\t\t\t\tDropout prob : {dropout}"
225 | )
226 | 
227 | # Initialize test helper for processing each video separately
228 | # It prints the result to the loaded checkpoint directory
229 | helper = VideoAnomalyDetectionResultHelper(
230 | dataset=dataset,
231 | model=net,
232 | c=st_dict['c'],
233 | R=st_dict['R'],
234 | boundary=boundary,
235 | device=device,
236 | end_to_end_training=(train_type == "train_end_to_end"),
237 | debug=args.debug,
238 | output_file=os.path.join(os.path.dirname(net_checkpoint), "shanghaitech_test_results.txt")
239 | )
240 | ### TEST
241 | helper.test_video_anomaly_detection()
242 | print("Test finished")
243 | ########################################################################################
244 | ########################################################################################
245 | 
246 | 
247 | if __name__ == '__main__':
248 | 
249 | parser = argparse.ArgumentParser('AD')
250 | ## General config
251 | parser.add_argument('-s', '--seed', type=int, default=-1, help='Random seed (default: -1)')
252 | parser.add_argument('--n_workers', type=int, default=8, help='Number of workers for data loading. 0 means that the data will be loaded in the main process. (default=8)')
253 | parser.add_argument('--output_path', default='./output')
254 | parser.add_argument('-lf', '--log-frequency', type=int, default=5, help='Log frequency (default: 5)')
255 | parser.add_argument('-dl', '--disable-logging', action="store_true", help='Disable logging (default: False)')
256 | parser.add_argument('-db', '--debug', action='store_true', help='Debug mode (default: False)')
257 | ## Model config
258 | parser.add_argument('-zl', '--code-length', default=2048, type=int, help='Code length (default: 2048)')
259 | parser.add_argument('-ck', '--model-ckp', help='Model checkpoint')
260 | ## Optimizer config
261 | parser.add_argument('-opt', '--optimizer', choices=('adam', 'sgd'), default='adam', help='Optimizer (default: adam)')
262 | parser.add_argument('-alr', '--ae-learning-rate', type=float, default=1.e-4, help='Warm up learning rate (default: 1.e-4)')
263 | parser.add_argument('-lr', '--learning-rate', type=float, default=1.e-4, help='Learning rate (default: 1.e-4)')
264 | parser.add_argument('-awd', '--ae-weight-decay', type=float, default=0.5e-6, help='Warm up weight decay (default: 0.5e-6)')
265 | parser.add_argument('-wd', '--weight-decay', type=float, default=0.5e-6, help='Weight decay (default: 0.5e-6)')
266 | parser.add_argument('-aml', '--ae-lr-milestones', type=int, nargs='+', default=[], help='Pretrain milestones')
267 | parser.add_argument('-ml', '--lr-milestones', type=int, nargs='+', default=[], help='Training milestones')
268 | ## Data
269 | parser.add_argument('-dp', '--data-path', default='./ShanghaiTech', help='Dataset main path')
270 | parser.add_argument('-cl', '--clip-length', type=int, default=16, help='Clip length (default: 16)')
271 | ## Training config
272 | # LSTMs
273 | parser.add_argument('-ll', '--load-lstm', action="store_true", help='Load LSTMs (default: False)')
274 | parser.add_argument('-bdl', '--bidirectional', action="store_true", help='Bidirectional LSTMs (default: False)')
275 | parser.add_argument('-hs', '--hidden-size', type=int, default=100, help='Hidden size (default: 100)')
276 | parser.add_argument('-nl', '--num-layers', type=int, default=1, help='Number of LSTMs layers (default: 1)')
277 | parser.add_argument('-drp', '--dropout', type=float, default=0.0, help='Dropout probability (default: 0.0)')
278 | # Autoencoder
279 | parser.add_argument('-ee', '--end-to-end-training', action="store_true", help='End-to-End training of the autoencoder (default: False)')
280 | parser.add_argument('-we', '--warm_up_n_epochs', type=int, default=5, help='Warm up epochs (default: 5)')
281 | parser.add_argument('-use', '--use-selectors', action="store_true", help='Use features selector (default: False)')
282 | parser.add_argument('-ba', '--batch-accumulation', type=int, default=-1, help='Batch accumulation (default: -1, i.e., None)')
283 | parser.add_argument('-ptr', '--pretrain', action="store_true", help='Pretrain model (default: False)')
284 | parser.add_argument('-tr', '--train', action="store_true", help='Train model (default: False)')
285 | parser.add_argument('-tt', '--test', action="store_true", help='Test model (default: False)')
286 | parser.add_argument('-tbc', '--train-best-conf', action="store_true", help='Train best configurations (default: False)')
287 | parser.add_argument('-bs', '--batch-size', type=int, default=4, help='Batch size (default: 4)')
288 | parser.add_argument('-bd', '--boundary', choices=("hard", "soft"), default="soft", help='Boundary (default: soft)')
289 | parser.add_argument('-ile', '--idx-list-enc', type=int, nargs='+', default=[], help='List of indices of the encoder layers from which to extract features (default: [])')
290 | parser.add_argument('-e', '--epochs', type=int, default=1, help='Training epochs (default: 1)')
291 | parser.add_argument('-ae', '--ae-epochs', type=int, default=1, help='Warm up epochs (default: 1)')
292 | parser.add_argument('-nu', '--nu', type=float, default=0.1, help='Value of the trade-off parameter nu (default: 0.1)')
293 | 
294 | args = parser.parse_args()
295 | main(args)
--------------------------------------------------------------------------------
/datasets/shanghaitech_test.py:
--------------------------------------------------------------------------------
1 | from glob import glob
2 | from os.path import basename
3 | from os.path import isdir
4 | from os.path import join
5 | from typing import List
6 | from typing import Tuple
7 | 
8 | 
9 | 
10 | 
11 | import cv2
12 | import torch
13 | import numpy as np
14 | import skimage.io as io
15 | from tqdm import tqdm
16 | 
17 | from sklearn.metrics import roc_auc_score
18 | from torch.utils.data import DataLoader
19 | from prettytable import PrettyTable
20 | from skimage.transform import resize
21 | from torchvision import transforms
22 | from torch.utils.data.dataloader import default_collate
23 | from .base import VideoAnomalyDetectionDataset, ToFloatTensor3D, ToFloatTensor3DMask
24 | 
25 | class ShanghaiTechTestHandler(VideoAnomalyDetectionDataset):
26 | 
27 | def __init__(self, path):
28 | # type: (str) -> None
29 | """
30 | Class constructor.
31 | :param path: The folder in which ShanghaiTech is stored.
32 | """
33 | super(ShanghaiTechTestHandler, self).__init__()
34 | self.path = path
35 | # Test directory
36 | self.test_dir = join(path, 'testing')
37 | # Transform
38 | self.transform = transforms.Compose([ToFloatTensor3D(normalize=True)])
39 | # Load all test ids
40 | self.test_ids = self.load_test_ids()
41 | # Other utilities
42 | self.cur_len = 0
43 | self.cur_video_id = None
44 | self.cur_video_frames = None
45 | self.cur_video_gt = None
46 | 
47 | def load_test_ids(self):
48 | # type: () -> List[str]
49 | """
50 | Loads the set of all test video ids.
51 | :return: The list of test ids.
52 | """
53 | return sorted([basename(d) for d in glob(join(self.test_dir, 'nobackground_frames_resized', '**')) if isdir(d)])
54 | 
55 | def load_test_sequence_frames(self, video_id):
56 | # type: (str) -> np.ndarray
57 | """
58 | Loads a test video in memory.
59 | :param video_id: the id of the test video to be loaded
60 | :return: the video in a np.ndarray, with shape (n_frames, h, w, c).
61 | """
62 | c, t, h, w = self.shape
63 | sequence_dir = join(self.test_dir, 'nobackground_frames_resized', video_id)
64 | img_list = sorted(glob(join(sequence_dir, '*.jpg')))
65 | #print(f"Creating clips for {sequence_dir} dataset with length {t}...")
66 | return np.stack([np.uint8(io.imread(img_path)) for img_path in img_list])
67 | 
68 | def load_test_sequence_gt(self, video_id):
69 | # type: (str) -> np.ndarray
70 | """
71 | Loads the groundtruth of a test video in memory.
72 | :param video_id: the id of the test video for which the groundtruth has to be loaded.
73 | :return: the groundtruth of the video in a np.ndarray, with shape (n_frames,).
74 | """
75 | clip_gt = np.load(join(self.test_dir, 'test_frame_mask', f'{video_id}.npy'))
76 | return clip_gt
77 | 
78 | def test(self, video_id):
79 | # type: (str) -> None
80 | """
81 | Sets the dataset in test mode.
82 | :param video_id: the id of the video to test.
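After this call the dataset behaves as a sliding window over the selected video: __len__ returns n_frames - t + 1 and __getitem__(i) yields the clip of t consecutive frames starting at frame i.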
83 | """
84 | c, t, h, w = self.shape
85 | self.cur_video_id = video_id
86 | self.cur_video_frames = self.load_test_sequence_frames(video_id)
87 | self.cur_video_gt = self.load_test_sequence_gt(video_id)
88 | self.cur_len = len(self.cur_video_frames) - t + 1
89 | 
90 | @property
91 | def shape(self):
92 | # type: () -> Tuple[int, int, int, int]
93 | """
94 | Returns the shape of examples being fed to the model.
95 | """
96 | return 3, 16, 256, 512
97 | 
98 | @property
99 | def test_videos(self):
100 | # type: () -> List[str]
101 | """
102 | Returns all available test videos.
103 | """
104 | return self.test_ids
105 | 
106 | def __len__(self):
107 | # type: () -> int
108 | """
109 | Returns the number of examples.
110 | """
111 | return self.cur_len
112 | 
113 | def __getitem__(self, i):
114 | # type: (int) -> torch.Tensor
115 | """
116 | Provides the i-th example.
117 | """
118 | c, t, h, w = self.shape
119 | clip = self.cur_video_frames[i:i+t]
120 | sample = clip
121 | # Apply transform
122 | if self.transform:
123 | sample = self.transform(sample)
124 | return sample
125 | 
126 | @property
127 | def collate_fn(self):
128 | """
129 | Returns a function that decides how to merge a list of examples in a batch.
130 | """
131 | return default_collate
132 | 
133 | def __repr__(self):
134 | return f'ShanghaiTech (video id = {self.cur_video_id})'
135 | 
136 | def get_target_label_idx(labels, targets):
137 | """
138 | Get the indices of labels that are included in targets.
139 | :param labels: array of labels
140 | :param targets: list/tuple of target labels
141 | :return: list with indices of target labels
142 | """
143 | return np.argwhere(np.isin(labels, targets)).flatten().tolist()
144 | 
145 | 
146 | def global_contrast_normalization(x: torch.Tensor, scale='l2'):
147 | """
148 | Apply global contrast normalization to tensor, i.e. subtract mean across features (pixels) and normalize by scale,
149 | which is either the standard deviation, L1- or L2-norm across features (pixels).
150 | Note this is a *per sample* normalization globally across features (and not across the dataset).
151 | """
152 | 
153 | assert scale in ('l1', 'l2')
154 | 
155 | n_features = int(np.prod(x.shape))
156 | 
157 | mean = torch.mean(x) # mean over all features (pixels) per sample
158 | x -= mean
159 | 
160 | if scale == 'l1':
161 | x_scale = torch.mean(torch.abs(x))
162 | 
163 | if scale == 'l2':
164 | x_scale = torch.sqrt(torch.sum(x ** 2)) / n_features
165 | 
166 | x /= x_scale
167 | 
168 | return x
169 | 
170 | 
171 | class ResultsAccumulator:
172 | """
173 | Accumulates results in a buffer for a sliding window
174 | results computation. Employed to get frame-level scores
175 | from clip-level scores.
176 | ` In order to recover the anomaly score of each
177 | frame, we compute the mean score of all clips in which it
178 | appears`
179 | """
180 | def __init__(self, nb_frames_per_clip):
181 | # type: (int) -> None
182 | """
183 | Class constructor.
184 | :param nb_frames_per_clip: the number of frames each clip holds.
185 | """
186 | 
187 | # These buffers rotate.
188 | self._buffer = np.zeros(shape=(nb_frames_per_clip,), dtype=np.float32)
189 | self._counts = np.zeros(shape=(nb_frames_per_clip,))
190 | 
191 | def push(self, score):
192 | # type: (float) -> None
193 | """
194 | Pushes the score of a clip into the buffer.
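Worked example (buffer of size 2): push(1.0) followed by get_next() returns 1.0; push(3.0) followed by get_next() then returns (1.0 + 3.0) / 2 = 2.0, since the second frame belongs to both clips.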
195 | :param score: the score of a clip
196 | """
197 | 
198 | # Update buffer and counts
199 | self._buffer += score
200 | self._counts += 1
201 | 
202 | def get_next(self):
203 | # type: () -> float
204 | """
205 | Gets the next frame (the first in the buffer) score,
206 | computed as the mean of the clips in which it appeared,
207 | and rolls the buffers.
208 | :return: the averaged score of the frame exiting the buffer.
209 | """
210 | 
211 | # Return first in buffer
212 | ret = self._buffer[0] / self._counts[0]
213 | 
214 | # Roll time backwards
215 | self._buffer = np.roll(self._buffer, shift=-1)
216 | self._counts = np.roll(self._counts, shift=-1)
217 | 
218 | # Zero out final frame (next to be filled)
219 | self._buffer[-1] = 0
220 | self._counts[-1] = 0
221 | 
222 | return ret
223 | 
224 | @property
225 | def results_left(self):
226 | # type: () -> np.int32
227 | """
228 | Returns the number of frames still in the buffer.
229 | """
230 | return np.sum(self._counts != 0).astype(np.int32)
231 | 
232 | 
233 | class VideoAnomalyDetectionResultHelper(object):
234 | """
235 | Performs tests for video anomaly detection datasets (UCSD Ped2 or Shanghaitech).
236 | """
237 | 
238 | def __init__(self, dataset, model, c, R, boundary, device, end_to_end_training, debug, output_file):
239 | # type: (VideoAnomalyDetectionDataset, BaseModule, dict, dict, str, str, bool, bool, str) -> None
240 | """
241 | Class constructor.
242 | :param dataset: dataset class.
243 | :param model: pytorch model to evaluate.
244 | :param output_file: text file where to save results.
245 | """
246 | self.dataset = dataset
247 | self.model = model
248 | self.hc = c
249 | self.keys = list(c.keys())
250 | self.R = R
251 | self.boundary = boundary
252 | self.device = device
253 | self.end_to_end_training = end_to_end_training
254 | self.debug = debug
255 | self.output_file = output_file
256 | 
257 | def _get_scores(self, d_lstm):
258 | # Eval novelty scores
259 | dist = {k: torch.sum((d_lstm[k] - self.hc[k].unsqueeze(0)) ** 2, dim=1) for k in self.keys}
260 | scores = {k: torch.zeros((dist[k].shape[0],), device=self.device) for k in self.keys}
261 | overall_score = torch.zeros((dist[self.keys[0]].shape[0],), device=self.device)
262 | for k in self.keys:
263 | if self.boundary == 'soft':
264 | scores[k] += dist[k] - self.R[k] ** 2
265 | overall_score += dist[k] - self.R[k] ** 2
266 | else:
267 | scores[k] += dist[k]
268 | overall_score += dist[k]
269 | scores = {k: scores[k]/len(self.keys) for k in self.keys}
270 | return scores, overall_score/len(self.keys)
271 | 
272 | @torch.no_grad()
273 | def test_video_anomaly_detection(self):
274 | # type: () -> None
275 | """
276 | Actually performs tests.
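For every test video, clip-level scores are first turned into frame-level scores through the ResultsAccumulator sliding windows, then min-max normalized per video (in the spirit of Eq. 9-10), and finally frame-level AUROCs are computed per layer and overall, both per video and globally.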
277 | """ 278 | self.model.eval().to(self.device) 279 | 280 | c, t, h, w = self.dataset.raw_shape 281 | 282 | # Prepare a table to show results 283 | vad_table = self.empty_table 284 | 285 | # Set up container for anomaly scores from all test videos 286 | ## oc: one class 287 | ## rc: reconstruction 288 | ## as: overall anomaly score 289 | global_oc = [] 290 | global_rc = [] 291 | global_as = [] 292 | global_as_by_layer = {k: [] for k in self.keys} 293 | global_y = [] 294 | global_y_by_layer = {k: [] for k in self.keys} 295 | 296 | # Get accumulators 297 | results_accumulator_rc = ResultsAccumulator(nb_frames_per_clip=t) 298 | results_accumulator_oc = ResultsAccumulator(nb_frames_per_clip=t) 299 | results_accumulator_oc_by_layer = {k: ResultsAccumulator(nb_frames_per_clip=t) for k in self.keys} 300 | print(self.dataset.test_videos) 301 | # Start iteration over test videos 302 | for cl_idx, video_id in tqdm(enumerate(self.dataset.test_videos), total=len(self.dataset.test_videos), desc="Test on Video"): 303 | # Run the test 304 | self.dataset.test(video_id) 305 | loader = DataLoader(self.dataset, collate_fn=self.dataset.collate_fn) 306 | 307 | # Build score containers 308 | sample_rc = np.zeros(shape=(len(loader) + t - 1,)) 309 | sample_oc = np.zeros(shape=(len(loader) + t - 1,)) 310 | sample_oc_by_layer = {k: np.zeros(shape=(len(loader) + t - 1,)) for k in self.keys} 311 | sample_y = self.dataset.load_test_sequence_gt(video_id) 312 | 313 | for i, x in tqdm(enumerate(loader), total=len(loader), desc=f'Computing scores for {self.dataset}', leave=False): 314 | # x.shape = [1, 3, 16, 256, 512] 315 | x = x.to(self.device) 316 | 317 | if self.end_to_end_training: 318 | x_r, _, d_lstm = self.model(x) 319 | recon_loss = torch.mean(torch.sum((x_r - x) ** 2, dim=tuple(range(1, x_r.dim())))) 320 | else: 321 | _, d_lstm = self.model(x) 322 | recon_loss = torch.tensor([0.0]) 323 | 324 | # Eval one class score for current clip 325 | oc_loss_by_layer, oc_overall_loss = self._get_scores(d_lstm) 326 | 327 | # Feed results accumulators 328 | results_accumulator_rc.push(recon_loss.item()) 329 | sample_rc[i] = results_accumulator_rc.get_next() 330 | results_accumulator_oc.push(oc_overall_loss.item()) 331 | sample_oc[i] = results_accumulator_oc.get_next() 332 | 333 | for k in self.keys: 334 | if k != "tdl_lstm_o_0" and k != "tdl_lstm_o_1": 335 | results_accumulator_oc_by_layer[k].push(oc_loss_by_layer[k].item()) 336 | sample_oc_by_layer[k][i] = results_accumulator_oc_by_layer[k].get_next() 337 | 338 | # Get last results layer by layer 339 | for k in self.keys: 340 | if k != "tdl_lstm_o_0" and k != "tdl_lstm_o_1": 341 | while results_accumulator_oc_by_layer[k].results_left != 0: 342 | index = (- results_accumulator_oc_by_layer[k].results_left) 343 | sample_oc_by_layer[k][index] = results_accumulator_oc_by_layer[k].get_next() 344 | 345 | min_, max_ = sample_oc_by_layer[k].min(), sample_oc_by_layer[k].max() 346 | 347 | # Computes the normalized novelty score given likelihood scores, reconstruction scores 348 | # and normalization coefficients (Eq. 9-10). 349 | sample_ns = (sample_oc_by_layer[k] - min_) / (max_ - min_) 350 | 351 | # Update global scores (used for global metrics) 352 | global_as_by_layer[k].append(sample_ns) 353 | global_y_by_layer[k].append(sample_y) 354 | 355 | try: 356 | # Compute AUROC for this video 357 | this_video_metrics = [ 358 | roc_auc_score(sample_y, sample_ns), # anomaly score == one class metric 359 | 0., 360 | 0. 
361 | ] 362 | #vad_table.add_row([k] + [video_id] + this_video_metrics) 363 | except ValueError: 364 | # This happens for sequences in which all frames are abnormal 365 | # Skipping this row in the table (the sequence will still count for global metrics) 366 | continue 367 | 368 | # Get last results 369 | while results_accumulator_oc.results_left != 0: 370 | index = (- results_accumulator_oc.results_left) 371 | sample_oc[index] = results_accumulator_oc.get_next() 372 | sample_rc[index] = results_accumulator_rc.get_next() 373 | 374 | min_oc, max_oc, min_rc, max_rc = sample_oc.min(), sample_oc.max(), sample_rc.min(), sample_rc.max() 375 | 376 | # Computes the normalized novelty score given likelihood scores, reconstruction scores 377 | # and normalization coefficients (Eq. 9-10). 378 | sample_oc = (sample_oc - min_oc) / (max_oc - min_oc) 379 | sample_rc = (sample_rc - min_rc) / (max_rc - min_rc) if (max_rc - min_rc) > 0 else np.zeros_like(sample_rc) 380 | sample_as = sample_oc + sample_rc 381 | 382 | # Update global scores (used for global metrics) 383 | global_oc.append(sample_oc) 384 | global_rc.append(sample_rc) 385 | global_as.append(sample_as) 386 | global_y.append(sample_y) 387 | 388 | try: 389 | # Compute AUROC for this video 390 | this_video_metrics = [ 391 | roc_auc_score(sample_y, sample_oc), # one class metric 392 | roc_auc_score(sample_y, sample_rc), # reconstruction metric 393 | roc_auc_score(sample_y, sample_as) # anomaly score 394 | ] 395 | #vad_table.add_row(['Overall'] + [video_id] + this_video_metrics) 396 | except ValueError: 397 | # This happens for sequences in which all frames are abnormal 398 | # Skipping this row in the table (the sequence will still count for global metrics) 399 | continue 400 | 401 | if self.debug: break 402 | 403 | # Compute global AUROC and print table 404 | for k in self.keys: 405 | if k != "tdl_lstm_o_0" and k != "tdl_lstm_o_1": 406 | global_as_by_layer[k] = np.concatenate(global_as_by_layer[k]) 407 | global_y_by_layer[k] = np.concatenate(global_y_by_layer[k]) 408 | global_metrics = [ 409 | roc_auc_score(global_y_by_layer[k], global_as_by_layer[k]), # anomaly score == one class metric 410 | 0., 411 | 0. 412 | ] 413 | vad_table.add_row([k] + ['avg'] + global_metrics) 414 | 415 | # Compute global AUROC and print table 416 | global_oc = np.concatenate(global_oc) 417 | global_rc = np.concatenate(global_rc) 418 | global_as = np.concatenate(global_as) 419 | global_y = np.concatenate(global_y) 420 | global_metrics = [ 421 | roc_auc_score(global_y, global_oc), # one class metric 422 | roc_auc_score(global_y, global_rc), # reconstruction metric 423 | roc_auc_score(global_y, global_as) # anomaly score 424 | ] 425 | 426 | vad_table.add_row(['Overall'] + ['avg'] + global_metrics) 427 | print(vad_table) 428 | 429 | # Save table 430 | with open(self.output_file, mode='w') as f: 431 | f.write(str(vad_table)) 432 | 433 | @property 434 | def empty_table(self): 435 | # type: () -> PrettyTable 436 | """ 437 | Sets up a nice ascii-art table to hold results. 438 | This table is suitable for the video anomaly detection setting. 439 | :return: table to be filled with auroc metrics. 
440 | """
441 | table = PrettyTable()
442 | table.field_names = ['Layer key', 'VIDEO-ID', 'OC metric', 'Recon metric', 'AUROC-AS']
443 | table.float_format = '0.3'
444 | return table
445 | 
--------------------------------------------------------------------------------
/main_mvtec.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import glob
4 | import random
5 | import logging
6 | import argparse
7 | import numpy as np
8 | import pandas as pd
9 | from tqdm import tqdm
10 | from os import makedirs
11 | from os.path import exists
12 | from prettytable import PrettyTable
13 | 
14 | import torch
15 | from torch.utils.data import DataLoader
16 | 
17 | from tensorboardX import SummaryWriter
18 | 
19 | from datasets.data_manager import DataManager
20 | from models.mvtec_model import MVTecNet_AutoEncoder
21 | from trainers.trainer_mvtec import pretrain, train, test
22 | from utils import set_seeds, get_out_dir, eval_spheres_centers, load_mvtec_model_from_checkpoint
23 | 
24 | 
25 | def test_models(test_loader: DataLoader, net_cehckpoint: str, tables: tuple, out_df: pd.DataFrame, is_texture: bool, input_shape: tuple, idx_list_enc: list, boundary: str, normal_class: str, use_selectors: bool, device: str, debug: bool):
26 | """Test a single model.
27 | 
28 | Parameters
29 | ----------
30 | test_loader : DataLoader
31 | Test data loader
32 | net_cehckpoint : str
33 | Path to model checkpoint
34 | tables : tuple
35 | Tuple containing PrettyTable objects for the soft and hard boundaries
36 | out_df : DataFrame
37 | Output dataframe
38 | is_texture : bool
39 | True if we are dealing with a texture-type class
40 | input_shape : tuple
41 | Shape of the input data
42 | idx_list_enc : list
43 | List containing the indices of the layers from which to extract features
44 | boundary : str
45 | Type of boundary
46 | normal_class : str
47 | Name of the normal class
48 | use_selectors : bool
49 | True if we want to use Selector modules
50 | device : str
51 | Device to be used
52 | debug : bool
53 | Activate debug mode
54 | 
55 | Returns
56 | -------
57 | out_df : DataFrame
58 | Dataframe containing the test results
59 | 
60 | """
61 | logger = logging.getLogger()
62 | 
63 | if not os.path.exists(net_cehckpoint):
64 | print(f"File not found at: {net_cehckpoint}")
65 | return out_df
66 | 
67 | # Get latent code size from checkpoint name
68 | code_length = int(net_cehckpoint.split('/')[-2].split('-')[3].split('_')[-1])
69 | 
70 | if net_cehckpoint.split('/')[-2].split('-')[-1].split('_')[-1].split('.')[0] == '':
71 | idx_list_enc = [7]
72 | else:
73 | idx_list_enc = [int(i) for i in net_cehckpoint.split('/')[-2].split('-')[-1].split('_')[-1].split('.')]
74 | boundary = net_cehckpoint.split('/')[-2].split('-')[9].split('_')[-1]
75 | normal_class = net_cehckpoint.split('/')[-2].split('-')[2].split('_')[-1]
76 | 
77 | logger.info(
78 | f"Start test with params"
79 | f"\n\t\t\t\tCode length : {code_length}"
80 | f"\n\t\t\t\tEnc layer list : {idx_list_enc}"
81 | f"\n\t\t\t\tBoundary : {boundary}"
82 | f"\n\t\t\t\tObject class : {normal_class}"
83 | )
84 | 
85 | # Init Encoder
86 | net = load_mvtec_model_from_checkpoint(
87 | input_shape=input_shape,
88 | code_length=code_length,
89 | idx_list_enc=idx_list_enc,
90 | use_selectors=use_selectors,
91 | net_cehckpoint=net_cehckpoint
92 | )
93 | 
94 | st_dict = torch.load(net_cehckpoint)
95 | net.load_state_dict(st_dict['net_state_dict'])
96 | 
97 | ### TEST
98 | test_auc, test_b_acc = test(
99 | normal_class=normal_class,
100 | is_texture=is_texture,
61 |     logger = logging.getLogger()
62 | 
63 |     if not os.path.exists(net_cehckpoint):
64 |         print(f"File not found at: {net_cehckpoint}")
65 |         return out_df
66 | 
67 |     # Get the latent code size from the checkpoint name
68 |     code_length = int(net_cehckpoint.split('/')[-2].split('-')[3].split('_')[-1])
69 | 
70 |     if net_cehckpoint.split('/')[-2].split('-')[-1].split('_')[-1].split('.')[0] == '':
71 |         idx_list_enc = [7]
72 |     else:  # otherwise, parse the encoder-layer indices encoded in the checkpoint name
73 |         idx_list_enc = [int(i) for i in net_cehckpoint.split('/')[-2].split('-')[-1].split('_')[-1].split('.')]
74 |     boundary = net_cehckpoint.split('/')[-2].split('-')[9].split('_')[-1]
75 |     normal_class = net_cehckpoint.split('/')[-2].split('-')[2].split('_')[-1]
76 | 
77 |     logger.info(
78 |         f"Start test with params"
79 |         f"\n\t\t\t\tCode length : {code_length}"
80 |         f"\n\t\t\t\tEnc layer list : {idx_list_enc}"
81 |         f"\n\t\t\t\tBoundary : {boundary}"
82 |         f"\n\t\t\t\tObject class : {normal_class}"
83 |     )
84 | 
85 |     # Init the Encoder
86 |     net = load_mvtec_model_from_checkpoint(
87 |         input_shape=input_shape,
88 |         code_length=code_length,
89 |         idx_list_enc=idx_list_enc,
90 |         use_selectors=use_selectors,
91 |         net_cehckpoint=net_cehckpoint
92 |     )
93 | 
94 |     st_dict = torch.load(net_cehckpoint)
95 |     net.load_state_dict(st_dict['net_state_dict'])
96 | 
97 |     ### TEST
98 |     test_auc, test_b_acc = test(
99 |         normal_class=normal_class,
100 |         is_texture=is_texture,
101 |         net=net,
102 |         test_loader=test_loader,
103 |         R=st_dict['R'],
104 |         c=st_dict['c'],
105 |         device=device,
106 |         boundary=boundary,
107 |         debug=debug
108 |     )
109 | 
110 |     table = tables[0] if boundary == 'soft' else tables[1]
111 |     table.add_row([
112 |         net_cehckpoint.split('/')[-2],
113 |         code_length,
114 |         idx_list_enc,
115 |         net_cehckpoint.split('/')[-2].split('-')[7].split('_')[-1]+'-'+net_cehckpoint.split('/')[-2].split('-')[8],
116 |         normal_class,
117 |         boundary,
118 |         net_cehckpoint.split('/')[-2].split('-')[4].split('_')[-1],
119 |         net_cehckpoint.split('/')[-2].split('-')[5].split('_')[-1],
120 |         test_auc,
121 |         test_b_acc
122 |     ])
123 | 
124 |     out_df = out_df.append(dict(
125 |             path=net_cehckpoint.split('/')[-2],
126 |             code_length=code_length,
127 |             enc_l_list=idx_list_enc,
128 |             weight_decay=net_cehckpoint.split('/')[-2].split('-')[7].split('_')[-1]+'-'+net_cehckpoint.split('/')[-2].split('-')[8],
129 |             object_class=normal_class,
130 |             boundary=boundary,
131 |             batch_size=net_cehckpoint.split('/')[-2].split('-')[4].split('_')[-1],
132 |             nu=net_cehckpoint.split('/')[-2].split('-')[5].split('_')[-1],
133 |             auc=test_auc,
134 |             balanced_acc=test_b_acc
135 |         ),
136 |         ignore_index=True
137 |     )
138 | 
139 |     return out_df
140 | 
141 | 
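Note: `DataFrame.append`, used above, exists in the pandas version pinned in requirements.txt (1.0.4) but was removed in pandas 2.0. If you run this script against a newer pandas, an assumed drop-in replacement for the row accumulation is (`row` being the dict built inline above):

```
import pandas as pd

out_df = pd.concat([out_df, pd.DataFrame([row])], ignore_index=True)
```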
142 | def main(args):
143 |     # Set the seed
144 |     set_seeds(args.seed)
145 | 
146 |     # Get the device
147 |     device = "cuda" if torch.cuda.is_available() else "cpu"
148 | 
149 |     if args.disable_logging:
150 |         logging.disable(level=logging.INFO)
151 | 
152 |     ## Init logger & print training/warm-up summary
153 |     logging.basicConfig(
154 |         level=logging.INFO,
155 |         format="%(asctime)s | %(message)s",
156 |         handlers=[
157 |             logging.FileHandler('./training.log'),
158 |             logging.StreamHandler()
159 |         ])
160 |     logger = logging.getLogger()
161 | 
162 |     if args.train or args.pretrain:
163 |         # If the list of layers from which to extract features is empty, use the last one (after the sigmoid)
164 |         if len(args.idx_list_enc) == 0: args.idx_list_enc = [7]
165 | 
166 |         logger.info(
167 |             "Start run with params:"
168 |             f"\n\t\t\t\tPretrain model : {args.pretrain}"
169 |             f"\n\t\t\t\tTrain model : {args.train}"
170 |             f"\n\t\t\t\tTest model : {args.test}"
171 |             f"\n\t\t\t\tBoundary : {args.boundary}"
172 |             f"\n\t\t\t\tPretrain epochs : {args.ae_epochs}"
173 |             f"\n\t\t\t\tAE-Learning rate : {args.ae_learning_rate}"
174 |             f"\n\t\t\t\tAE-milestones : {args.ae_lr_milestones}"
175 |             f"\n\t\t\t\tAE-Weight decay : {args.ae_weight_decay}\n"
176 |             f"\n\t\t\t\tTrain epochs : {args.epochs}"
177 |             f"\n\t\t\t\tBatch size : {args.batch_size}"
178 |             f"\n\t\t\t\tBatch acc. : {args.batch_accumulation}"
179 |             f"\n\t\t\t\tWarm up epochs : {args.warm_up_n_epochs}"
180 |             f"\n\t\t\t\tLearning rate : {args.learning_rate}"
181 |             f"\n\t\t\t\tMilestones : {args.lr_milestones}"
182 |             f"\n\t\t\t\tUse selectors : {args.use_selectors}"
183 |             f"\n\t\t\t\tWeight decay : {args.weight_decay}\n"
184 |             f"\n\t\t\t\tCode length : {args.code_length}"
185 |             f"\n\t\t\t\tNu : {args.nu}"
186 |             f"\n\t\t\t\tEncoder list : {args.idx_list_enc}\n"
187 |             f"\n\t\t\t\tTest metric : {args.metric}"
188 |         )
189 |     else:
190 |         if args.model_ckp is None:
191 |             logger.info("CANNOT TEST A MODEL WITHOUT A VALID CHECKPOINT")
192 |             sys.exit(0)
193 | 
194 |         if args.debug:
195 |             args.normal_class = 'carpet'
196 | 
197 |         else:
198 | 
199 |             if os.path.isfile(args.model_ckp):
200 |                 args.normal_class = args.model_ckp.split('/')[-2].split('-')[2].split('_')[-1]
201 | 
202 |             else:
203 |                 args.normal_class = args.model_ckp.split('/')[-3]
204 | 
205 |     # Init the DataHolder class
206 |     data_holder = DataManager(
207 |         dataset_name='MVTec_Anomaly',
208 |         data_path=args.data_path,
209 |         normal_class=args.normal_class,
210 |         only_test=args.test
211 |     ).get_data_holder()
212 | 
213 |     # Load the data
214 |     train_loader, test_loader = data_holder.get_loaders(
215 |         batch_size=args.batch_size,
216 |         shuffle_train=True,
217 |         pin_memory=device == "cuda",
218 |         num_workers=args.n_workers
219 |     )
220 | 
221 |     # Print dataset info
222 |     only_test = args.test and not args.train and not args.pretrain
223 |     logger.info("Dataset info:")
224 |     logger.info(
225 |         "\n"
226 |         f"\n\t\t\t\tNormal class : {args.normal_class}"
227 |         f"\n\t\t\t\tBatch size : {args.batch_size}"
228 |     )
229 |     if not only_test:
230 |         logger.info(
231 |             f"TRAIN:"
232 |             f"\n\t\t\t\tNumber of images : {len(train_loader.dataset)}"
233 |             f"\n\t\t\t\tNumber of batches : {len(train_loader.dataset)//args.batch_size}"
234 |         )
235 |     logger.info(
236 |         f"TEST:"
237 |         f"\n\t\t\t\tNumber of images : {len(test_loader.dataset)}"
238 |     )
239 | 
240 |     is_texture = args.normal_class in ("carpet", "grid", "leather", "tile", "wood")
241 |     input_shape = (3, 64, 64) if is_texture else (3, 128, 128)
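The five MVTec texture categories are fed to the network at 64x64, while the object categories use 128x128 inputs. A tiny sanity check of that mapping (the helper name is illustrative only):

```
TEXTURES = ("carpet", "grid", "leather", "tile", "wood")

def mvtec_input_shape(normal_class):
    # Texture categories use 64x64 inputs, object categories 128x128
    return (3, 64, 64) if normal_class in TEXTURES else (3, 128, 128)

assert mvtec_input_shape("carpet") == (3, 64, 64)
assert mvtec_input_shape("cable") == (3, 128, 128)
```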
242 | 
243 |     ### PRETRAIN the full AutoEncoder
244 |     ae_net_cehckpoint = None
245 |     if args.pretrain:
246 | 
247 |         pretrain_out_dir, tmp = get_out_dir(args, pretrain=True, aelr=None, dset_name='mvtec')
248 |         pretrain_tb_writer = SummaryWriter(os.path.join(args.output_path, 'mvtec', str(args.normal_class), 'tb_runs/pretrain', tmp))
249 | 
250 |         # Init the AutoEncoder
251 |         ae_net = MVTecNet_AutoEncoder(input_shape=input_shape, code_length=args.code_length, use_selectors=args.use_selectors)
252 | 
253 |         # Start pretraining
254 |         ae_net_cehckpoint = pretrain(
255 |             ae_net=ae_net,
256 |             train_loader=train_loader,
257 |             out_dir=pretrain_out_dir,
258 |             tb_writer=pretrain_tb_writer,
259 |             device=device,
260 |             ae_learning_rate=args.ae_learning_rate,
261 |             ae_weight_decay=args.ae_weight_decay,
262 |             ae_lr_milestones=args.ae_lr_milestones,
263 |             ae_epochs=args.ae_epochs,
264 |             log_frequency=args.log_frequency,
265 |             batch_accumulation=args.batch_accumulation,
266 |             debug=args.debug
267 |         )
268 | 
269 |         pretrain_tb_writer.close()
270 | 
271 |     ### TRAIN the Encoder
272 |     net_cehckpoint = None
273 |     if args.train:
274 |         if ae_net_cehckpoint is None:
275 |             if args.model_ckp is None:
276 |                 logger.info("CANNOT TRAIN A MODEL WITHOUT A VALID CHECKPOINT")
277 |                 sys.exit(0)
278 | 
279 |             ae_net_cehckpoint = args.model_ckp
280 | 
281 |         aelr = float(ae_net_cehckpoint.split('/')[-2].split('-')[4].split('_')[-1])
282 | 
283 |         train_out_dir, tmp = get_out_dir(args, pretrain=False, aelr=aelr, dset_name='mvtec')
284 |         train_tb_writer = SummaryWriter(os.path.join(args.output_path, 'mvtec', str(args.normal_class), 'tb_runs/train', tmp))
285 | 
286 |         # Init the Encoder network
287 |         encoder_net = load_mvtec_model_from_checkpoint(
288 |             input_shape=input_shape,
289 |             code_length=args.code_length,
290 |             idx_list_enc=args.idx_list_enc,
291 |             use_selectors=args.use_selectors,
292 |             net_cehckpoint=ae_net_cehckpoint,
293 |             purge_ae_params=True
294 |         )
295 | 
296 |         ## Eval/Load the hyperspheres centers
297 |         encoder_net.set_idx_list_enc(range(8))
298 |         centers = eval_spheres_centers(train_loader=train_loader, encoder_net=encoder_net, ae_net_cehckpoint=ae_net_cehckpoint, use_selectors=args.use_selectors, device=device, debug=args.debug)
299 |         encoder_net.set_idx_list_enc(args.idx_list_enc)
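The actual center estimation happens in `eval_spheres_centers` (utils.py, not shown in this section). For context, a minimal sketch of the usual Deep SVDD-style initialization of a single hypersphere center, under the assumption that the encoder returns one feature tensor per batch and the loader yields (image, label) pairs:

```
import torch

@torch.no_grad()
def estimate_center(encoder_net, train_loader, device, eps=0.1):
    # Center = mean feature over the (normal-only) training set
    feats = []
    for x, *_ in train_loader:
        feats.append(encoder_net(x.to(device)).flatten(1))
    c = torch.cat(feats).mean(dim=0)
    # Push near-zero coordinates away from zero to avoid a trivial solution
    c[(c.abs() < eps) & (c < 0)] = -eps
    c[(c.abs() < eps) & (c > 0)] = eps
    return c
```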
300 | 
301 |         # Start training
302 |         net_cehckpoint = train(
303 |             net=encoder_net,
304 |             train_loader=train_loader,
305 |             centers=centers,
306 |             out_dir=train_out_dir,
307 |             tb_writer=train_tb_writer,
308 |             device=device,
309 |             learning_rate=args.learning_rate,
310 |             weight_decay=args.weight_decay,
311 |             lr_milestones=args.lr_milestones,
312 |             epochs=args.epochs,
313 |             nu=args.nu,
314 |             boundary=args.boundary,
315 |             batch_accumulation=args.batch_accumulation,
316 |             warm_up_n_epochs=args.warm_up_n_epochs,
317 |             log_frequency=args.log_frequency,
318 |             debug=args.debug
319 |         )
320 | 
321 |         train_tb_writer.close()
322 | 
323 |     ### TEST the Encoder
324 |     if args.test:
325 |         if net_cehckpoint is None:
326 |             net_cehckpoint = args.model_ckp
327 | 
328 |         # Init the tables used to print the results on the shell
329 |         # If we only test one model at a time, one of the two tables will remain empty
330 |         # If all the model checkpoints are in one folder, both tables will be filled automatically
331 |         table_s = PrettyTable()
332 |         table_s.field_names = ['Path', 'Code length', 'Enc layer list', 'weight decay', 'Object class', 'Boundary', 'batch size', 'nu', 'AUC', 'Balanced acc']
333 |         table_s.float_format = '0.3'
334 |         table_h = PrettyTable()
335 |         table_h.field_names = ['Path', 'Code length', 'Enc layer list', 'weight decay', 'Object class', 'Boundary', 'batch size', 'nu', 'AUC', 'Balanced acc']
336 |         table_h.float_format = '0.3'
337 | 
338 |         # Init the dataframe used to store the results
339 |         out_df = pd.DataFrame()
340 | 
341 |         is_file = os.path.isfile(net_cehckpoint)
342 |         if is_file:
343 |             out_df = test_models(
344 |                 test_loader=test_loader,
345 |                 net_cehckpoint=net_cehckpoint,
346 |                 tables=(table_s, table_h),
347 |                 out_df=out_df,
348 |                 is_texture=is_texture,
349 |                 input_shape=input_shape,
350 |                 idx_list_enc=args.idx_list_enc,
351 |                 boundary=args.boundary,
352 |                 normal_class=args.normal_class,
353 |                 use_selectors=args.use_selectors,
354 |                 device=device,
355 |                 debug=args.debug
356 |             )
357 |         else:
358 |             for model_ckp in tqdm(os.listdir(net_cehckpoint), total=len(os.listdir(net_cehckpoint)), desc="Running on models"):
359 |                 out_df = test_models(
360 |                     test_loader=test_loader,
361 |                     net_cehckpoint=os.path.join(net_cehckpoint, model_ckp, 'best_oc_model_model.pth'),
362 |                     tables=(table_s, table_h),
363 |                     out_df=out_df,
364 |                     is_texture=is_texture,
365 |                     input_shape=input_shape, idx_list_enc=args.idx_list_enc,
366 |                     boundary=args.boundary,
367 |                     normal_class=args.normal_class,
368 |                     use_selectors=args.use_selectors,
369 |                     device=device,
370 |                     debug=args.debug
371 |                 )
372 | 
373 |         print(table_s)
374 |         print(table_h)
375 | 
376 |         b_path = "./output/mvtec_test_results/test_csv"
377 |         if not exists(b_path):
378 |             makedirs(b_path)
379 | 
380 |         normal_class = net_cehckpoint.split('/')[-4]
381 |         ff = glob.glob(os.path.join(b_path, f'*{normal_class}*'))
382 |         if len(ff) == 0:
383 |             csv_out_name = os.path.join(b_path, f"test-results-{normal_class}_0.csv")
384 | 
385 |         else:
386 |             ff.sort()
387 |             version = int(ff[-1].split('_')[-1].split('.')[0]) + 1
388 |             logger.info(f"Already found a csv file for {normal_class} with latest version: {version-1} ==> creating a new csv file with version: {version}")
389 |             csv_out_name = os.path.join(b_path, f"test-results-{normal_class}_{version}.csv")
390 | 
391 |         out_df.to_csv(csv_out_name)
392 | 
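For reference, the test-only branch above accepts either a single checkpoint file or a folder of run directories via `-ck`. A hypothetical invocation on a single checkpoint (paths are placeholders; in test-only mode the object class is parsed from the checkpoint path itself, not from `-nc`):

```
python main_mvtec.py -tt -dp ./MVTec_Anomaly -ck ./output/mvtec/carpet/<run-name>/best_oc_model_model.pth
```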
393 | 
394 | if __name__ == '__main__':
395 |     parser = argparse.ArgumentParser('AD')
396 |     ## General config
397 |     parser.add_argument('-s', '--seed', type=int, default=-1, help='Random seed (default: -1)')
398 |     parser.add_argument('--n_workers', type=int, default=8, help='Number of workers for data loading. 0 means that the data will be loaded in the main process. (default: 8)')
399 |     parser.add_argument('--output_path', default='./output')
400 |     parser.add_argument('-lf', '--log-frequency', type=int, default=5, help='Log frequency (default: 5)')
401 |     parser.add_argument('-dl', '--disable-logging', action="store_true", help='Disable logging (default: False)')
402 |     parser.add_argument('-db', '--debug', action="store_true", help='Activate debug mode, i.e., only use the first three batches (default: False)')
403 |     ## Model config
404 |     parser.add_argument('-zl', '--code-length', default=64, type=int, help='Code length (default: 64)')
405 |     parser.add_argument('-ck', '--model-ckp', help='Model checkpoint')
406 |     ## Optimizer config
407 |     parser.add_argument('-alr', '--ae-learning-rate', type=float, default=1.e-4, help='Warm up learning rate (default: 1.e-4)')
408 |     parser.add_argument('-lr', '--learning-rate', type=float, default=1.e-4, help='Learning rate (default: 1.e-4)')
409 |     parser.add_argument('-awd', '--ae-weight-decay', type=float, default=0.5e-6, help='Warm up weight decay (default: 0.5e-6)')
410 |     parser.add_argument('-wd', '--weight-decay', type=float, default=0.5e-6, help='Weight decay (default: 0.5e-6)')
411 |     parser.add_argument('-aml', '--ae-lr-milestones', type=int, nargs='+', default=[], help='Pretrain LR milestones')
412 |     parser.add_argument('-ml', '--lr-milestones', type=int, nargs='+', default=[], help='Training LR milestones')
413 |     ## Data
414 |     parser.add_argument('-dp', '--data-path', default='./MVTec_Anomaly', help='Dataset main path')
415 |     parser.add_argument('-nc', '--normal-class', choices=('bottle', 'capsule', 'grid', 'leather', 'metal_nut', 'screw', 'toothbrush', 'wood', 'cable', 'carpet', 'hazelnut', 'pill', 'tile', 'transistor', 'zipper'), default='cable', help='Category (default: cable)')
416 |     ## Training config
417 |     parser.add_argument('-we', '--warm_up_n_epochs', type=int, default=5, help='Warm up epochs (default: 5)')
418 |     parser.add_argument('--use-selectors', action="store_true", help='Use feature selectors (default: False)')
419 |     parser.add_argument('-ba', '--batch-accumulation', type=int, default=-1, help='Batch accumulation (default: -1, i.e., None)')
420 |     parser.add_argument('-ptr', '--pretrain', action="store_true", help='Pretrain model (default: False)')
421 |     parser.add_argument('-tr', '--train', action="store_true", help='Train model (default: False)')
422 |     parser.add_argument('-tt', '--test', action="store_true", help='Test model (default: False)')
423 |     parser.add_argument('-tbc', '--train-best-conf', action="store_true", help='Train best configurations (default: False)')
424 |     parser.add_argument('-bs', '--batch-size', type=int, default=128, help='Batch size (default: 128)')
425 |     parser.add_argument('-bd', '--boundary', choices=("hard", "soft"), default="soft", help='Boundary (default: soft)')
426 |     parser.add_argument('-ile', '--idx-list-enc', type=int, nargs='+', default=[], help='Indices of the encoder layers from which to extract features')
427 |     parser.add_argument('-e', '--epochs', type=int, default=1, help='Training epochs (default: 1)')
428 |     parser.add_argument('-ae', '--ae-epochs', type=int, default=1, help='Warm up (pretraining) epochs (default: 1)')
429 |     parser.add_argument('-nu', '--nu', type=float, default=0.1, help='Trade-off parameter nu (default: 0.1)')
430 |     ## Test config
431 |     parser.add_argument('-mt', '--metric', choices=(1, 2), type=int, default=2, help="Metric to evaluate norms (default: 2, i.e., L2)")
432 |     args = parser.parse_args()
433 | 
434 |     main(args)
435 | 
--------------------------------------------------------------------------------