├── README.md
├── __init__.py
├── argument.py
├── configs
│   └── cad.yaml
├── datasets
│   ├── __init__.py
│   ├── joint_mvtec_mtd.py
│   ├── mtd_dataset.py
│   ├── mvtec_dataset.py
│   ├── revdis_mvtec_dataset.py
│   ├── seq_mtd_mvtec.py
│   ├── seq_mvtec.py
│   ├── transforms
│   │   ├── __init__.py
│   │   ├── maskimg.py
│   │   └── trans_cutpaste.py
│   └── utils
│       ├── __init__.py
│       └── make_mtd_ano.py
├── eval.py
├── main.py
├── methods
│   ├── __init__.py
│   ├── agem.py
│   ├── csflow.py
│   ├── cutpaste.py
│   ├── der.py
│   ├── derpp.py
│   ├── dne.py
│   ├── er.py
│   ├── fdr.py
│   ├── panda.py
│   ├── revdis.py
│   └── utils
│       ├── __init__.py
│       └── base_method.py
├── models
│   ├── __init__.py
│   ├── csflow_net.py
│   ├── resnet.py
│   ├── revdis_net.py
│   └── vit.py
├── requirements.txt
└── utils
    ├── __init__.py
    ├── buffer.py
    ├── de_resnet.py
    ├── density.py
    ├── freia_funcs.py
    ├── optimizer.py
    ├── rd_resnet.py
    └── visualization.py

/README.md:
--------------------------------------------------------------------------------
1 | # Continual Anomaly Detection
2 | 
3 | Official code for the ACM MM 2022 paper:
4 | 
5 | **Title:** Towards Continual Adaptation in Industrial Anomaly Detection [[pdf]](https://dl.acm.org/doi/pdf/10.1145/3503161.3548232?casa_token=DjLhJL0kQl8AAAAA:AQyuwCMk4m_bNFtyfFi3YJu-lHa7-EIRrdgztanRKsf5f0535ROUoponI9gAZIrx4_PrUDjta64dNg).
6 | 
7 | 
8 | ## Datasets
9 | To train on the MVTec Anomaly Detection dataset, [download](https://www.mvtec.com/company/research/datasets/mvtec-ad)
10 | the data and extract it. For the additional Magnetic Tile Defects dataset, [download](https://github.com/abin24/Magnetic-tile-defect-datasets.) the data and then run **datasets/utils/make_mtd_ano.py** to convert it into an anomaly-detection layout.
11 | 
12 | ## Environment setup
13 | ```
14 | pip install -r requirements.txt
15 | ```
16 | 
17 | ## Getting the pretrained ViT model
18 | The ViT-B/16 model used in this paper can be downloaded [here](https://storage.googleapis.com/vit_models/sam/ViT-B_16.npz).
19 | 
20 | ## Run
21 | We provide the configuration file for running CAD on multiple benchmarks in `configs`.
22 | 
23 | ```
24 | python main.py --config-file ./configs/cad.yaml --data_dir ../datasets/mvtec --mtd_dir ../datasets/mtd_ano_mask
25 | ```
26 | You can run a different method by modifying the configuration file.
27 | 
28 | 
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vijaylee/Continual_Anomaly_Detection/7be4ea5b4087f292ddc0e332e9e9e0fd6eca0c67/__init__.py
--------------------------------------------------------------------------------
/argument.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import numpy as np
3 | import torch
4 | import random
5 | import yaml
6 | import re
7 | 
8 | class Namespace(object):
9 |     def __init__(self, somedict):
10 |         for key, value in somedict.items():
11 |             assert isinstance(key, str) and re.match("[A-Za-z_-]", key)  # keys must be strings starting with a letter, '_' or '-'
12 |             if isinstance(value, dict):
13 |                 self.__dict__[key] = Namespace(value)
14 |             else:
15 |                 self.__dict__[key] = value
16 | 
17 |     def __getattr__(self, attribute):
18 |         raise AttributeError(
19 |             f"Cannot find {attribute} in namespace. 
Please write {attribute} in your config file(xxx.yaml)!") 20 | 21 | def set_deterministic(seed): 22 | # seed by default is None 23 | if seed is not None: 24 | print(f"Deterministic with seed = {seed}") 25 | random.seed(seed) 26 | np.random.seed(seed) 27 | torch.manual_seed(seed) 28 | torch.cuda.manual_seed(seed) 29 | torch.backends.cudnn.deterministic = True 30 | torch.backends.cudnn.benchmark = False 31 | 32 | def str2bool(v): 33 | return v.lower() in ("yes", "true", "t", "1") 34 | 35 | 36 | def get_args(): 37 | parser = argparse.ArgumentParser() 38 | parser.add_argument('--config-file', default='./configs/cad.yaml', type=str, help="xxx.yaml") 39 | parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu') 40 | parser.add_argument('--data_dir', type=str, default="../datasets/mvtec") 41 | parser.add_argument('--mtd_dir', type=str, default="../datasets/mtd_ano_mask") 42 | parser.add_argument('--save_checkpoint', type=str2bool, default=False, help='save checkpoint or not.') 43 | parser.add_argument('--save_path', type=str, default="./checkpoints") 44 | parser.add_argument('--noise_ratio', type=float, default=0) 45 | parser.add_argument('--seed', type=int, default=42) 46 | args = parser.parse_args() 47 | 48 | print(args) 49 | with open(args.config_file, 'r') as f: 50 | data = yaml.load(f, Loader=yaml.FullLoader) 51 | print(data) 52 | for key, value in Namespace(data).__dict__.items(): 53 | vars(args)[key] = value 54 | 55 | set_deterministic(args.seed) 56 | 57 | return args -------------------------------------------------------------------------------- /configs/cad.yaml: -------------------------------------------------------------------------------- 1 | name: continual anomaly detection 2 | dataset: 3 | name: seq-mvtec # seq-mvtec or seq-mtd-mvtec or joint-mtd-mvtec 4 | image_size: 224 # 224; revdis: 256; csflow: 768 5 | num_workers: 4 6 | data_incre_setting: mul # mul: 3+3+3+3+3; one: 10+1+1+1+1+1 7 | n_classes_per_task: 3 # mul_class_incre: 3; one_class_incre:1. if joint: 15 8 | n_tasks: 5 # mul_class_incre: 5; seq-mtd-mvtec, one_class_incre:6. if joint: 1 9 | dataset_order: 1 # 1, 2, 3 10 | strong_augmentation: True # strong augmentation: cutpaste, maskimg, etc.; weak augmentation: ColorJitter, RandomRotation, etc. 
11 |   random_aug: False
12 | 
13 | 
14 | model:
15 |   name: vit # resnet, vit, net_csflow, net_revdis
16 |   pretrained: True
17 |   method: dne # dne, panda, cutpaste, csflow, revdis, er, der, derpp, fdr, agem, upper
18 |   # dne
19 |   fix_head: True
20 |   with_dne: True
21 |   # der
22 |   with_embeds: True
23 |   buffer_size: 200
24 |   # csflow
25 |   n_feat: 304
26 |   fc_internal: 1024
27 |   n_coupling_blocks: 4
28 |   clamp: 3
29 |   n_scales: 3
30 | 
31 | train:
32 |   optimizer:
33 |     name: adam
34 |     weight_decay: 0.00003 # 0.00003; csflow: 0.00001
35 |     momentum: 0.9
36 |   warmup_epochs: 10
37 |   warmup_lr: 0
38 |   base_lr: 0.0001 # 0.0001; csflow: 0.0002; revdis: 0.005
39 |   final_lr: 0
40 |   num_epochs: 50
41 |   batch_size: 32 # 32; csflow: 16
42 |   test_epochs: 10
43 |   alpha: 0.4
44 |   beta: 0.5
45 |   num_classes: 2
46 | 
47 | eval:
48 |   eval_classifier: density # density, head
49 |   batch_size: 32 # 32; revdis: 1; csflow: 16
50 |   visualization: False
51 | 
52 | 
53 | 
54 | 
55 | 
56 | 
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .seq_mvtec import get_mvtec_dataloaders
2 | from .seq_mtd_mvtec import get_mtd_mvtec_dataloaders, get_joint_mtd_mvtec_dataloaders
3 | 
4 | 
5 | def get_dataloaders(args, t, dataloaders_train, dataloaders_test, learned_tasks, all_test_filenames):
6 |     if args.dataset.name == 'seq-mvtec':
7 |         train_dataloader, dataloaders_train, dataloaders_test, learned_tasks, data_train_nums, all_test_filenames = get_mvtec_dataloaders(
8 |             args, t, dataloaders_train, dataloaders_test, learned_tasks, all_test_filenames)
9 |     elif args.dataset.name == 'seq-mtd-mvtec':
10 |         train_dataloader, dataloaders_train, dataloaders_test, learned_tasks, data_train_nums, all_test_filenames = get_mtd_mvtec_dataloaders(
11 |             args, t, dataloaders_train, dataloaders_test, learned_tasks, all_test_filenames)
12 |     elif args.dataset.name == 'joint-mtd-mvtec':
13 |         train_dataloader, dataloaders_train, dataloaders_test, learned_tasks, data_train_nums, all_test_filenames = get_joint_mtd_mvtec_dataloaders(
14 |             args, dataloaders_train, dataloaders_test, learned_tasks, all_test_filenames)
15 |     return train_dataloader, dataloaders_train, dataloaders_test, learned_tasks, data_train_nums, all_test_filenames
--------------------------------------------------------------------------------
/datasets/joint_mvtec_mtd.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 | from pathlib import Path
3 | from PIL import Image
4 | from joblib import Parallel, delayed
5 | import os
6 | import numpy as np
7 | import pandas as pd
8 | from collections.abc import Iterable
9 | 
10 | def flatten(items, ignore_types=(str, bytes)):
11 |     for x in items:
12 |         if isinstance(x, Iterable) and not isinstance(x, ignore_types):
13 |             yield from flatten(x)
14 |         else:
15 |             yield x
16 | 
17 | class Repeat(Dataset):
18 |     def __init__(self, org_dataset, new_length):
19 |         self.org_dataset = org_dataset
20 |         self.org_length = len(self.org_dataset)
21 |         self.new_length = new_length
22 | 
23 |     def __len__(self):
24 |         return self.new_length
25 | 
26 |     def __getitem__(self, idx):
27 |         return self.org_dataset[idx % self.org_length]
28 | 
29 | 
30 | class MVTecMTDjoint(Dataset):
31 |     """Joint MVTec AD + MTD dataset for the joint (non-continual) baseline."""
32 | 
33 |     def __init__(self, mvtec_dir, mtd_dir, size, transform=None, mode="train"):
34 |         """
35 |         Args:
36 |             mvtec_dir (string): directory of the MVTec AD dataset.
37 |             mtd_dir (string): directory of the MTD dataset.
38 |             transform: transform to apply to the data
39 |             mode: "train" loads training samples, "test" loads test samples; default "train"
40 |         """
41 |         mvtec_classes = ['leather', 'bottle', 'metal_nut',
42 |                          'grid', 'screw', 'zipper',
43 |                          'tile', 'hazelnut', 'toothbrush',
44 |                          'wood', 'transistor', 'pill',
45 |                          'carpet', 'capsule', 'cable']
46 | 
47 |         self.mvtec_dir = Path(mvtec_dir)
48 |         self.mtd_dir = Path(mtd_dir)
49 |         self.task_mvtec_classes = mvtec_classes
50 |         self.transform = transform
51 |         self.mode = mode
52 |         self.size = size
53 |         self.all_imgs = []
54 |         self.all_image_names = []
55 |         # collect image paths; cache resized images when training
56 |         if self.mode == "train":
57 |             self.mtd_image_names = list((self.mtd_dir / "train" / "good").glob("*.jpg"))
58 |             self.all_image_names.append(self.mtd_image_names)
59 |             print("loading MTD images")
60 |             # during training we cache the smaller images for performance reasons (not a good coding style)
61 |             self.mtd_imgs = (Parallel(n_jobs=10)(
62 |                 delayed(lambda file: Image.open(file).resize((size, size)).convert("RGB"))(file) for file in
63 |                 self.mtd_image_names))
64 |             self.all_imgs.append(self.mtd_imgs)
65 |             print(f"loaded MTD : {len(self.mtd_imgs)} images")
66 | 
67 |             for class_name in self.task_mvtec_classes:
68 |                 self.mvtec_image_names = list((self.mvtec_dir / class_name / "train" / "good").glob("*.png"))
69 |                 self.all_image_names.append(self.mvtec_image_names)
70 |                 print("loading MVTec images")
71 |                 # during training we cache the smaller images for performance reasons (not a good coding style)
72 |                 self.mvtec_imgs = (Parallel(n_jobs=10)(
73 |                     delayed(lambda file: Image.open(file).resize((size, size)).convert("RGB"))(file) for file in
74 |                     self.mvtec_image_names))
75 |                 self.all_imgs.append(self.mvtec_imgs)
76 |                 print(f"loaded {class_name} : {len(self.mvtec_imgs)} images")
77 |         else:
78 |             # test mode
79 |             self.mtd_image_names = list((self.mtd_dir / "test").glob(str(Path("*") / "*.jpg")))
80 |             self.mtd_gt_names = list((self.mtd_dir / "gt").glob(str(Path("*") / "*.png")))
81 |             for class_name in self.task_mvtec_classes:
82 |                 self.mvtec_image_names = list((self.mvtec_dir / class_name / "test").glob(str(Path("*") / "*.png")))
83 |                 self.all_image_names.append(self.mvtec_image_names)
84 |         self.all_imgs, self.all_image_names = list(flatten(self.all_imgs)), list(flatten(self.all_image_names))
85 | 
86 | 
87 |     def __len__(self):
88 |         return len(self.all_image_names)
89 | 
90 |     def __getitem__(self, idx):
91 |         if self.mode == "train":
92 |             # img = Image.open(self.image_names[idx])
93 |             # img = img.convert("RGB")
94 |             img = self.all_imgs[idx].copy()
95 |             if self.transform is not None:
96 |                 img = self.transform(img)
97 |             return img
98 |         else:
99 |             filename = self.all_image_names[idx]
100 |             label = filename.parts[-2]
101 |             img = Image.open(filename)
102 |             img = img.resize((self.size, self.size)).convert("RGB")
103 |             if self.transform is not None:
104 |                 img = self.transform(img)
105 |             return img, label != "good"
106 | 
--------------------------------------------------------------------------------
/datasets/mtd_dataset.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 | from pathlib import Path
3 | from PIL import Image
4 | from joblib import Parallel, delayed
5 | import os
6 | import numpy as np
7 | import pandas as pd
8 | from collections.abc import Iterable
9 | 
10 | def flatten(items, ignore_types=(str, bytes)):
11 |     for x in items:
12 |         if isinstance(x, Iterable) and not isinstance(x, ignore_types):
13 |             yield from flatten(x)
14 |         else:
15 |             yield x
16 | 
17 | class Repeat(Dataset):
18 |     def __init__(self, org_dataset, new_length):
19 |         self.org_dataset = org_dataset
20 |         self.org_length = len(self.org_dataset)
21 |         self.new_length = new_length
22 | 
23 |     def __len__(self):
24 |         return self.new_length
25 | 
26 |     def __getitem__(self, idx):
27 |         return self.org_dataset[idx % self.org_length]
28 | 
29 | 
30 | class MTD(Dataset):
31 |     """Magnetic Tile Defects (MTD) dataset."""
32 | 
33 |     def __init__(self, root_dir, size, require_mask=False, transform=None, mode="train"):
34 |         """
35 |         Args:
36 |             root_dir (string): directory of the MTD dataset.
37 |             require_mask (bool): if True, test samples also return the ground-truth mask.
38 |             transform: transform to apply to the data
39 |             mode: "train" loads training samples, "test" loads test samples; default "train"
40 |         """
41 |         self.root_dir = Path(root_dir)
42 |         self.transform = transform
43 |         self.require_mask = require_mask
44 |         self.mode = mode
45 |         self.size = size
46 | 
47 | 
48 |         if self.mode == "train":
49 |             self.image_names = list((self.root_dir / "train" / "good").glob("*.jpg"))
50 |             print("loading MTD images")
51 |             # during training we cache the smaller images for performance reasons (not a good coding style)
52 |             self.imgs = (Parallel(n_jobs=10)(
53 |                 delayed(lambda file: Image.open(file).resize((size, size)).convert("RGB"))(file) for file in
54 |                 self.image_names))
55 |             print(f"loaded training set : {len(self.imgs)} images")
56 |         else:
57 |             # test mode
58 |             self.image_names = list((self.root_dir / "test").glob(str(Path("*") / "*.jpg")))
59 |             self.gt_names = list((self.root_dir / "gt").glob(str(Path("*") / "*.png")))
60 | 
61 |     def __len__(self):
62 |         return len(self.image_names)
63 | 
64 |     def __getitem__(self, idx):
65 |         if self.mode == "train":
66 |             # img = Image.open(self.image_names[idx])
67 |             # img = img.convert("RGB")
68 |             img = self.imgs[idx].copy()
69 |             if self.transform is not None:
70 |                 img = self.transform(img)
71 |             return img
72 |         else:
73 |             filename = self.image_names[idx]
74 |             label = filename.parts[-2]
75 |             img = Image.open(filename)
76 |             img = img.resize((self.size, self.size)).convert("RGB")
77 |             if self.require_mask:
78 |                 gtname = self.gt_names[idx]
79 |                 gt = Image.open(gtname).convert("RGB")
80 |                 if self.transform is not None:
81 |                     img = self.transform(img)
82 |                     gt = self.transform(gt)
83 |                 return img, gt, label != "good"
84 |             else:
85 |                 if self.transform is not None:
86 |                     img = self.transform(img)
87 |                 return img, label != "good"
88 | 
--------------------------------------------------------------------------------
/datasets/mvtec_dataset.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 | from pathlib import Path
3 | from PIL import Image
4 | from joblib import Parallel, delayed
5 | import os
6 | import numpy as np
7 | import pandas as pd
8 | from collections.abc import Iterable
9 | 
10 | def flatten(items, ignore_types=(str, bytes)):
11 |     for x in items:
12 |         if isinstance(x, Iterable) and not isinstance(x, ignore_types):
13 |             yield from flatten(x)
14 |         else:
15 |             yield x
16 | 
17 | class Repeat(Dataset):
18 |     def __init__(self, org_dataset, new_length):
19 |         self.org_dataset = org_dataset
20 |         self.org_length = len(self.org_dataset)
21 |         self.new_length = new_length
22 | 
23 |     def __len__(self):
24 |         return self.new_length
25 | 
26 |     def __getitem__(self, idx):
27 |         return self.org_dataset[idx % self.org_length]
28 | 
29 | 
30 | class MVTecAD(Dataset):
31 |     """MVTec AD dataset."""
32 | 
33 |     def __init__(self, root_dir, task_mvtec_classes, size, transform=None, mode="train"):
34 |         """
35 |         Args:
36 |             root_dir 
(string): Directory with the MVTec AD dataset. 37 | class_name (string): class to load. 38 | transform: Transform to apply to data 39 | mode: "train" loads training samples "test" test samples default "train" 40 | """ 41 | self.root_dir = Path(root_dir) 42 | self.task_mvtec_classes = task_mvtec_classes 43 | self.transform = transform 44 | self.mode = mode 45 | self.size = size 46 | self.all_imgs = [] 47 | self.all_image_names = [] 48 | # find test images 49 | for class_name in self.task_mvtec_classes: 50 | if self.mode == "train": 51 | self.image_names = list((self.root_dir / class_name / "train" / "good").glob("*.png")) 52 | self.all_image_names.append(self.image_names) 53 | print("loading images") 54 | # during training we cache the smaller images for performance reasons (not a good coding style) 55 | self.imgs = (Parallel(n_jobs=10)( 56 | delayed(lambda file: Image.open(file).resize((size, size)).convert("RGB"))(file) for file in 57 | self.image_names)) 58 | self.all_imgs.append(self.imgs) 59 | print(f"loaded {class_name} : {len(self.imgs)} images") 60 | else: 61 | # test mode 62 | self.image_names = list((self.root_dir / class_name / "test").glob(str(Path("*") / "*.png"))) 63 | self.all_image_names.append(self.image_names) 64 | self.all_imgs, self.all_image_names = list(flatten(self.all_imgs)), list(flatten(self.all_image_names)) 65 | 66 | def __len__(self): 67 | return len(self.all_image_names) 68 | 69 | def __getitem__(self, idx): 70 | if self.mode == "train": 71 | # img = Image.open(self.image_names[idx]) 72 | # img = img.convert("RGB") 73 | img = self.all_imgs[idx].copy() 74 | if self.transform is not None: 75 | img = self.transform(img) 76 | return img 77 | else: 78 | filename = self.all_image_names[idx] 79 | label = filename.parts[-2] 80 | img = Image.open(filename) 81 | img = img.resize((self.size, self.size)).convert("RGB") 82 | if self.transform is not None: 83 | img = self.transform(img) 84 | return img, label != "good" 85 | -------------------------------------------------------------------------------- /datasets/revdis_mvtec_dataset.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | from PIL import Image 3 | import os 4 | import torch 5 | import glob 6 | from pathlib import Path 7 | from torch.utils.data import Dataset 8 | 9 | class RevDisTestMVTecDataset(Dataset): 10 | def __init__(self, root_dir, task_mvtec_classes, size): 11 | self.root_dir = root_dir 12 | self.task_mvtec_classes = task_mvtec_classes 13 | self.transform, self.gt_transform = self.get_data_transforms(size, size) 14 | # load dataset 15 | self.img_paths, self.gt_paths, self.labels = self.load_dataset() # self.labels => good : 0, anomaly : 1 16 | 17 | def load_dataset(self): 18 | img_tot_paths = [] 19 | gt_tot_paths = [] 20 | tot_labels = [] 21 | 22 | for class_name in self.task_mvtec_classes: 23 | self.img_path = os.path.join(self.root_dir, class_name, "test") 24 | self.gt_path = os.path.join(self.root_dir, class_name, 'ground_truth') 25 | defect_types = os.listdir(self.img_path) 26 | 27 | for defect_type in defect_types: 28 | if defect_type == 'good': 29 | img_paths = glob.glob(os.path.join(self.img_path, defect_type) + "/*.png") 30 | img_tot_paths.extend(img_paths) 31 | gt_tot_paths.extend([0] * len(img_paths)) 32 | tot_labels.extend([0] * len(img_paths)) 33 | else: 34 | img_paths = glob.glob(os.path.join(self.img_path, defect_type) + "/*.png") 35 | gt_paths = glob.glob(os.path.join(self.gt_path, defect_type) + "/*.png") 36 | 
img_paths.sort() 37 | gt_paths.sort() 38 | img_tot_paths.extend(img_paths) 39 | gt_tot_paths.extend(gt_paths) 40 | tot_labels.extend([1] * len(img_paths)) 41 | 42 | assert len(img_tot_paths) == len(gt_tot_paths), "Something wrong with test and ground truth pair!" 43 | return img_tot_paths, gt_tot_paths, tot_labels 44 | 45 | def get_data_transforms(self, size, isize): 46 | mean_train = [0.485, 0.456, 0.406] 47 | std_train = [0.229, 0.224, 0.225] 48 | data_transforms = transforms.Compose([ 49 | transforms.Resize((size, size)), 50 | transforms.ToTensor(), 51 | transforms.CenterCrop(isize), 52 | # transforms.CenterCrop(args.input_size), 53 | transforms.Normalize(mean=mean_train, 54 | std=std_train)]) 55 | gt_transforms = transforms.Compose([ 56 | transforms.Resize((size, size)), 57 | transforms.CenterCrop(isize), 58 | transforms.ToTensor()]) 59 | return data_transforms, gt_transforms 60 | 61 | def __len__(self): 62 | return len(self.img_paths) 63 | 64 | def __getitem__(self, idx): 65 | img_path, gt, label = self.img_paths[idx], self.gt_paths[idx], self.labels[idx] 66 | img = Image.open(img_path).convert('RGB') 67 | img = self.transform(img) 68 | if gt == 0: 69 | gt = torch.zeros([1, img.size()[-2], img.size()[-2]]) 70 | else: 71 | gt = Image.open(gt) 72 | gt = self.gt_transform(gt) 73 | 74 | assert img.size()[1:] == gt.size()[1:], "image.size != gt.size !!!" 75 | 76 | return img, gt, label -------------------------------------------------------------------------------- /datasets/seq_mtd_mvtec.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | from .transforms import aug_transformation, no_aug_transformation 3 | from .mvtec_dataset import MVTecAD 4 | from .mtd_dataset import MTD 5 | from torch.utils.data import DataLoader 6 | from .revdis_mvtec_dataset import RevDisTestMVTecDataset 7 | from .joint_mvtec_mtd import MVTecMTDjoint 8 | from .utils import get_mvtec_classes 9 | 10 | 11 | def get_mtd_mvtec_dataloaders(args, t, dataloaders_train, dataloaders_test, learned_tasks, all_test_filenames): 12 | train_transform = aug_transformation(args) 13 | test_transform = no_aug_transformation(args) 14 | 15 | if t == 0: 16 | learned_tasks.append('Magnetic-Tile-Defect') 17 | 18 | if args.model.method == 'revdis': 19 | train_data = MTD(args.mtd_dir, args.dataset.image_size, require_mask=False, transform=test_transform) 20 | test_data = MTD(args.mtd_dir, args.dataset.image_size, require_mask=True, transform=test_transform, mode="test") 21 | else: 22 | train_data = MTD(args.mtd_dir, args.dataset.image_size, require_mask=False, transform=train_transform) 23 | test_data = MTD(args.mtd_dir, args.dataset.image_size, require_mask=False, transform=test_transform, mode="test") 24 | all_test_filenames.append(test_data.image_names) 25 | 26 | train_dataloader = DataLoader(train_data, batch_size=args.train.batch_size, shuffle=True, num_workers=args.dataset.num_workers) 27 | dataloaders_train.append(train_dataloader) 28 | dataloader_test = DataLoader(test_data, batch_size=args.eval.batch_size, shuffle=False, num_workers=args.dataset.num_workers) 29 | dataloaders_test.append(dataloader_test) 30 | print('class name: MTD', 'number of training sets:', len(train_data), 31 | 'number of testing sets:', len(test_data)) 32 | else: 33 | mvtec_classes = get_mvtec_classes(args) 34 | t -= 1 35 | N_CLASSES_PER_TASK = args.dataset.n_classes_per_task 36 | i = t * N_CLASSES_PER_TASK 37 | task_mvtec_classes = mvtec_classes[i: i + N_CLASSES_PER_TASK] 38 | 
learned_tasks.append(task_mvtec_classes) 39 | 40 | if args.model.method == 'revdis': 41 | train_data = MVTecAD(args.data_dir, task_mvtec_classes, transform=test_transform, 42 | size=args.dataset.image_size) 43 | test_data = RevDisTestMVTecDataset(args.data_dir, task_mvtec_classes, size=args.dataset.image_size) 44 | else: 45 | train_data = MVTecAD(args.data_dir, task_mvtec_classes, 46 | transform=train_transform, size=args.dataset.image_size) 47 | test_data = MVTecAD(args.data_dir, task_mvtec_classes, args.dataset.image_size, 48 | transform=test_transform, mode="test") 49 | all_test_filenames.append(test_data.all_image_names) 50 | 51 | train_dataloader = DataLoader(train_data, batch_size=args.train.batch_size, shuffle=True, num_workers=args.dataset.num_workers) 52 | dataloaders_train.append(train_dataloader) 53 | dataloader_test = DataLoader(test_data, batch_size=args.eval.batch_size, shuffle=False, num_workers=args.dataset.num_workers) 54 | dataloaders_test.append(dataloader_test) 55 | print('class name:', task_mvtec_classes, 'number of training sets:', len(train_data), 56 | 'number of testing sets:', len(test_data)) 57 | 58 | return train_dataloader, dataloaders_train, dataloaders_test, learned_tasks, len(train_data), all_test_filenames 59 | 60 | 61 | def get_joint_mtd_mvtec_dataloaders(args, dataloaders_train, dataloaders_test, learned_tasks, all_test_filenames): 62 | train_transform = aug_transformation(args) 63 | test_transform = no_aug_transformation(args) 64 | 65 | learned_tasks.append('Magnetic-Tile-Defect and MVTec') 66 | train_data = MVTecMTDjoint(args.data_dir, args.mtd_dir, args.dataset.image_size, transform=train_transform) 67 | test_data = MVTecMTDjoint(args.data_dir, args.mtd_dir, args.dataset.image_size, transform=test_transform, mode="test") 68 | all_test_filenames.append(test_data.all_image_names) 69 | 70 | train_dataloader = DataLoader(train_data, batch_size=args.train.batch_size, shuffle=True, 71 | num_workers=args.dataset.num_workers) 72 | dataloaders_train.append(train_dataloader) 73 | dataloader_test = DataLoader(test_data, batch_size=args.eval.batch_size, shuffle=False, 74 | num_workers=args.dataset.num_workers) 75 | dataloaders_test.append(dataloader_test) 76 | 77 | print('class name: MVTec+MTD', 'number of training sets:', len(train_data), 78 | 'number of testing sets:', len(test_data)) 79 | 80 | return train_dataloader, dataloaders_train, dataloaders_test, learned_tasks, len(train_data), all_test_filenames -------------------------------------------------------------------------------- /datasets/seq_mvtec.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | from .transforms import aug_transformation, no_aug_transformation 3 | from .mvtec_dataset import MVTecAD 4 | from .revdis_mvtec_dataset import RevDisTestMVTecDataset 5 | from torch.utils.data import DataLoader 6 | from .utils import get_mvtec_classes 7 | 8 | 9 | def get_mvtec_dataloaders(args, t, dataloaders_train, dataloaders_test, learned_tasks, all_test_filenames): 10 | mvtec_classes = get_mvtec_classes(args) 11 | 12 | N_CLASSES_PER_TASK = args.dataset.n_classes_per_task 13 | if args.dataset.data_incre_setting == 'one': 14 | # N_CLASSES_PER_TASK = 1 15 | if t == 0: 16 | task_mvtec_classes = mvtec_classes[: 10] 17 | else: 18 | i = 10 + (t - 1) * N_CLASSES_PER_TASK 19 | task_mvtec_classes = mvtec_classes[i: i + N_CLASSES_PER_TASK] 20 | else: 21 | # N_CLASSES_PER_TASK = 3 22 | i = t * N_CLASSES_PER_TASK 23 | task_mvtec_classes = 
mvtec_classes[i: i + N_CLASSES_PER_TASK] 24 | learned_tasks.append(task_mvtec_classes) 25 | 26 | train_transform = aug_transformation(args) 27 | test_transform = no_aug_transformation(args) 28 | 29 | if args.model.method == 'revdis': 30 | train_data = MVTecAD(args.data_dir, task_mvtec_classes, transform=test_transform, size=args.dataset.image_size) 31 | test_data = RevDisTestMVTecDataset(args.data_dir, task_mvtec_classes, size=args.dataset.image_size) 32 | all_test_filenames.append(test_data.img_paths) 33 | else: 34 | train_data = MVTecAD(args.data_dir, task_mvtec_classes, transform=train_transform, size=args.dataset.image_size) 35 | test_data = MVTecAD(args.data_dir, task_mvtec_classes, args.dataset.image_size, transform=test_transform, mode="test") 36 | all_test_filenames.append(test_data.all_image_names) 37 | 38 | train_dataloader = DataLoader(train_data, batch_size=args.train.batch_size, shuffle=True, num_workers=args.dataset.num_workers) 39 | dataloaders_train.append(train_dataloader) 40 | dataloader_test = DataLoader(test_data, batch_size=args.eval.batch_size, shuffle=False, num_workers=args.dataset.num_workers) 41 | dataloaders_test.append(dataloader_test) 42 | print('class name:', task_mvtec_classes, 'number of training sets:', len(train_data), 43 | 'number of testing sets:', len(test_data)) 44 | 45 | return train_dataloader, dataloaders_train, dataloaders_test, learned_tasks, len(train_data), all_test_filenames -------------------------------------------------------------------------------- /datasets/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | from .trans_cutpaste import CutPasteNormal, CutPasteScar, CutPaste3Way 3 | from .maskimg import MaskImg 4 | 5 | 6 | def aug_transformation(args): 7 | data_norm = [[0.485, 0.456, 0.406], [0.229, 0.224, 0.225]] 8 | if args.dataset.strong_augmentation: 9 | after_cutpaste_transform = transforms.Compose([ 10 | transforms.RandomRotation(90), 11 | transforms.ToTensor(), 12 | transforms.Normalize(*data_norm) 13 | ]) 14 | if args.dataset.random_aug: 15 | aug_transformation = transforms.Compose([ 16 | transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1), 17 | transforms.Resize((args.dataset.image_size, args.dataset.image_size)), 18 | transforms.RandomChoice([ 19 | CutPasteNormal(transform=after_cutpaste_transform), 20 | CutPasteScar(transform=after_cutpaste_transform), 21 | MaskImg(args.device, args.dataset.image_size, 0.25, [16, 2], colorJitter=0.1, 22 | transform=after_cutpaste_transform)]) 23 | ]) 24 | else: 25 | aug_transformation = transforms.Compose([ 26 | transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1), 27 | transforms.Resize((args.dataset.image_size, args.dataset.image_size)), 28 | # MaskImg(args.device, args.dataset.image_size, 0.25, [16, 2], colorJitter=0.1, transform=after_cutpaste_transform) 29 | CutPasteNormal(transform=after_cutpaste_transform) 30 | ]) 31 | else: 32 | aug_transformation = transforms.Compose([ 33 | transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1), 34 | transforms.Resize(args.dataset.image_size), 35 | transforms.RandomRotation(90), 36 | transforms.ToTensor(), 37 | transforms.Normalize(*data_norm) 38 | ]) 39 | return aug_transformation 40 | 41 | 42 | def no_aug_transformation(args): 43 | data_norm = [[0.485, 0.456, 0.406], [0.229, 0.224, 0.225]] 44 | no_aug_transformation = transforms.Compose([ 45 | 
transforms.Resize((args.dataset.image_size, args.dataset.image_size)),
46 |         transforms.ToTensor(),
47 |         transforms.Normalize(*data_norm)
48 |     ])
49 |     return no_aug_transformation
--------------------------------------------------------------------------------
/datasets/transforms/maskimg.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from timm.models.layers import to_2tuple
3 | import torch
4 | from PIL import Image
5 | from torchvision import transforms
6 | from einops import rearrange
7 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, \
8 |     IMAGENET_INCEPTION_STD
9 | from torchvision.transforms import ToPILImage
10 | 
11 | 
12 | class RandomMaskingGenerator:
13 |     def __init__(self, input_size, mask_ratio):
14 |         if not isinstance(input_size, tuple):
15 |             input_size = (input_size,) * 2
16 | 
17 |         self.height, self.width = input_size
18 | 
19 |         self.num_patches = self.height * self.width
20 |         self.num_mask = int(mask_ratio * self.num_patches)
21 | 
22 |     def __repr__(self):
23 |         repr_str = "Masks: total patches {}, mask patches {}".format(
24 |             self.num_patches, self.num_mask
25 |         )
26 |         return repr_str
27 | 
28 |     def __call__(self):
29 |         mask = np.hstack([
30 |             np.zeros(self.num_patches - self.num_mask),
31 |             np.ones(self.num_mask),
32 |         ])
33 |         np.random.shuffle(mask)
34 |         return mask  # [196]
35 | 
36 | 
37 | class MaskImg:
38 |     def __init__(self, device, size, mask_ratio, patch_size, colorJitter=0.1, transform=None):
39 |         self.device = device
40 |         self.input_size = size
41 |         self.mask_ratio = mask_ratio
42 |         self.patch_size = patch_size
43 |         # self.patch_size = to_2tuple(patch_size)
44 |         self.window_size = (self.input_size // self.patch_size[0], self.input_size // self.patch_size[1])
45 |         self.masked_position_generator = RandomMaskingGenerator(self.window_size, self.mask_ratio)
46 | 
47 |         self.transform = transform
48 |         if colorJitter is None:
49 |             self.colorJitter = None
50 |         else:
51 |             self.colorJitter = transforms.ColorJitter(brightness=colorJitter,
52 |                                                       contrast=colorJitter,
53 |                                                       saturation=colorJitter,
54 |                                                       hue=colorJitter)
55 | 
56 |     def __call__(self, img):
57 |         oriimg = self.transform(img)
58 | 
59 |         bool_masked_pos = self.masked_position_generator()
60 |         bool_masked_pos = torch.from_numpy(bool_masked_pos)
61 | 
62 |         # imagenet_default_mean_and_std = True
63 |         # mean = IMAGENET_INCEPTION_MEAN if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_MEAN
64 |         # std = IMAGENET_INCEPTION_STD if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_STD
65 |         #
66 |         # # transform = transforms.Compose([
67 |         # #     # transforms.RandomResizedCrop(input_size),
68 |         # #     transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
69 |         # #     transforms.ToTensor(),
70 |         # #     transforms.Normalize(
71 |         # #         mean=torch.tensor(mean),
72 |         # #         std=torch.tensor(std))
73 |         # # ])
74 |         img = self.colorJitter(img)
75 |         img = self.transform(img)
76 |         img = img[None, :]
77 |         bool_masked_pos = bool_masked_pos[None, :]
78 |         # img = img.to(self.device)
79 |         bool_masked_pos = bool_masked_pos.flatten(1).to(torch.bool)
80 | 
81 |         # undo the normalization to recover the original image
82 |         mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN)[None, :, None, None]
83 |         std = torch.as_tensor(IMAGENET_DEFAULT_STD)[None, :, None, None]
84 |         ori_img = img * std + mean  # in [0, 1]
85 | 
86 |         img_squeeze = rearrange(ori_img, 'b c (h p1) (w p2) -> b (h w) (p1 p2) c', p1=self.patch_size[0], p2=self.patch_size[1])
87 |         img_norm = (img_squeeze 
- img_squeeze.mean(dim=-2, keepdim=True)) / ( 88 | img_squeeze.var(dim=-2, unbiased=True, keepdim=True).sqrt() + 1e-6) 89 | img_patch = rearrange(img_norm, 'b n p c -> b n (p c)') 90 | 91 | # make mask 92 | mask = torch.ones_like(img_patch) 93 | mask[bool_masked_pos] = 0 94 | mask = rearrange(mask, 'b n (p c) -> b n p c', c=3) 95 | h = int(self.input_size / self.patch_size[0]) 96 | w = int(self.input_size / self.patch_size[1]) 97 | mask = rearrange(mask, 'b (h w) (p1 p2) c -> b c (h p1) (w p2)', p1=self.patch_size[0], p2=self.patch_size[1], h=h, w=w) 98 | 99 | # save reconstruction img 100 | rec_img = rearrange(img_patch, 'b n (p c) -> b n p c', c=3) 101 | # Notice: To visualize the reconstruction image, we add the predict and the original mean and var of each patch. Issue #40 102 | rec_img = rec_img * (img_squeeze.var(dim=-2, unbiased=True, keepdim=True).sqrt() + 1e-6) + img_squeeze.mean( 103 | dim=-2, 104 | keepdim=True) 105 | rec_img = rearrange(rec_img, 'b (h w) (p1 p2) c -> b c (h p1) (w p2)', p1=self.patch_size[0], p2=self.patch_size[1], h=h, 106 | w=w) 107 | # img = ToPILImage()(rec_img[0, :].clip(0, 0.996)) 108 | 109 | # save random mask img 110 | img_mask = rec_img * mask 111 | img = img_mask[0, :] 112 | # img = ToPILImage()(img_mask[0, :]) 113 | 114 | return oriimg, img 115 | 116 | 117 | 118 | 119 | -------------------------------------------------------------------------------- /datasets/transforms/trans_cutpaste.py: -------------------------------------------------------------------------------- 1 | import random 2 | import math 3 | from torchvision import transforms 4 | import torch 5 | from PIL import ImageFilter, Image 6 | import numpy as np 7 | 8 | 9 | def cut_paste_collate_fn(batch): 10 | # cutPaste return 2 tuples of tuples we convert them into a list of tuples 11 | img_types = list(zip(*batch)) 12 | # print(list(zip(*batch))) 13 | return [torch.stack(imgs) for imgs in img_types] 14 | 15 | 16 | class MyGaussianBlur(ImageFilter.Filter): 17 | name = "GaussianBlur" 18 | 19 | def __init__(self, radius=2, bounds=None): 20 | self.radius = radius 21 | self.bounds = bounds 22 | 23 | def filter(self, image): 24 | if self.bounds: 25 | clips = image.crop(self.bounds).gaussian_blur(self.radius) 26 | image.paste(clips, self.bounds) 27 | return image 28 | else: 29 | return image.gaussian_blur(self.radius) 30 | 31 | class CutPaste1(object): 32 | """Base class for both cutpaste variants with common operations""" 33 | 34 | def __init__(self, colorJitter=0.1, transform=None): 35 | self.transform = transform 36 | if colorJitter is None: 37 | self.colorJitter = None 38 | else: 39 | self.colorJitter = transforms.ColorJitter(brightness=colorJitter, 40 | contrast=colorJitter, 41 | saturation=colorJitter, 42 | hue=colorJitter) 43 | 44 | def __call__(self, img): 45 | # apply transforms to both images 46 | if self.transform: 47 | org_img = self.transform(img) 48 | if self.colorJitter: 49 | img = self.colorJitter(img) 50 | img = self.transform(img) 51 | return org_img, img 52 | 53 | class CutPaste(object): 54 | """Base class for both cutpaste variants with common operations""" 55 | 56 | def __init__(self, colorJitter=0.1, transform=None): 57 | self.transform = transform 58 | 59 | if colorJitter is None: 60 | self.colorJitter = None 61 | else: 62 | self.colorJitter = transforms.ColorJitter(brightness=colorJitter, 63 | contrast=colorJitter, 64 | saturation=colorJitter, 65 | hue=colorJitter) 66 | # self.k = random.randint(0,3) 67 | 68 | def __call__(self, org_img, img): 69 | # apply transforms to both 
images
70 |         if self.transform:
71 |             img = self.transform(img)
72 |             # org_img = np.rot90(org_img, k=self.k)
73 |             # org_img = Image.fromarray(org_img)
74 |             org_img = self.transform(org_img)
75 |         return org_img, img
76 | 
77 | 
78 | class CutPasteNormal(CutPaste):
79 |     """Randomly copy one patch from the image and paste it somewhere else.
80 |     Args:
81 |         area_ratio (list): list with 2 floats for minimum and maximum area to cut out
82 |         aspect_ratio (float): minimum aspect ratio. The ratio is sampled between aspect_ratio and 1/aspect_ratio.
83 |     """
84 | 
85 |     def __init__(self, area_ratio=[0.02, 0.15], aspect_ratio=0.3, **kwargs):
86 |         super(CutPasteNormal, self).__init__(**kwargs)
87 |         self.area_ratio = area_ratio
88 |         self.aspect_ratio = aspect_ratio
89 | 
90 |     def __call__(self, img):
91 |         # TODO: we might want to use the pytorch implementation to calculate the patches from https://pytorch.org/vision/stable/_modules/torchvision/transforms/transforms.html#RandomErasing
92 |         h = img.size[0]
93 |         w = img.size[1]
94 | 
95 |         new_aspect_ratio = random.uniform(0.1, 1)
96 | 
97 |         augmented = img.copy()
98 |         # ratio between area_ratio[0] and area_ratio[1]
99 |         ratio_area = random.uniform(self.area_ratio[0], self.area_ratio[1]) * w * h
100 | 
101 |         # sample in log space
102 |         log_ratio = torch.log(torch.tensor((new_aspect_ratio, 1 / new_aspect_ratio)))
103 |         aspect = torch.exp(
104 |             torch.empty(1).uniform_(log_ratio[0], log_ratio[1])
105 |         ).item()
106 | 
107 |         cut_w = int(round(math.sqrt(ratio_area * aspect)))
108 |         cut_h = int(round(math.sqrt(ratio_area / aspect)))
109 | 
110 |         # one might also want to sample from other images. currently we only sample from the image itself
111 |         from_location_h = int(random.uniform(0, h - cut_h))
112 |         from_location_w = int(random.uniform(0, w - cut_w))
113 | 
114 |         box = [from_location_w, from_location_h, from_location_w + cut_w, from_location_h + cut_h]
115 |         patch = img.crop(box)
116 | 
117 |         # before pasting, we jitter the pixel values in the patch
118 |         if self.colorJitter:
119 |             patch = self.colorJitter(patch)
120 | 
121 |         to_location_h = int(random.uniform(0, h - cut_h))
122 |         to_location_w = int(random.uniform(0, w - cut_w))
123 | 
124 |         insert_box = [to_location_w, to_location_h, to_location_w + cut_w, to_location_h + cut_h]
125 | 
126 |         augmented.paste(patch, insert_box)  # the pasted patch always originates from the same image it is pasted to
127 | 
128 |         return super().__call__(img, augmented)
129 | 
130 | 
131 | class BlurCutPasteNormal(CutPaste):
132 |     """Randomly blur one rectangular patch of the image.
133 |     Args:
134 |         area_ratio (list): list with 2 floats for minimum and maximum area to blur
135 |         aspect_ratio (float): minimum aspect ratio. The ratio is sampled between aspect_ratio and 1/aspect_ratio.
136 |     """
137 | 
138 |     def __init__(self, area_ratio=[0.02, 0.15], aspect_ratio=0.3, **kwargs):
139 |         super(BlurCutPasteNormal, self).__init__(**kwargs)
140 |         self.area_ratio = area_ratio
141 |         self.aspect_ratio = aspect_ratio
142 | 
143 |     def __call__(self, img):
144 |         # TODO: we might want to use the pytorch implementation to calculate the patches from https://pytorch.org/vision/stable/_modules/torchvision/transforms/transforms.html#RandomErasing
145 |         h = img.size[0]
146 |         w = img.size[1]
147 | 
148 |         augmented = img.copy()
149 |         # ratio between area_ratio[0] and area_ratio[1]
150 |         ratio_area = random.uniform(self.area_ratio[0], self.area_ratio[1]) * w * h
151 | 
152 |         # sample in log space
153 |         log_ratio = torch.log(torch.tensor((self.aspect_ratio, 1 / self.aspect_ratio)))
154 |         aspect = torch.exp(
155 |             torch.empty(1).uniform_(log_ratio[0], log_ratio[1])
156 |         ).item()
157 | 
158 |         cut_w = int(round(math.sqrt(ratio_area * aspect)))
159 |         cut_h = int(round(math.sqrt(ratio_area / aspect)))
160 | 
161 |         to_location_h = int(random.uniform(0, h - cut_h))
162 |         to_location_w = int(random.uniform(0, w - cut_w))
163 | 
164 |         insert_box = [to_location_w, to_location_h, to_location_w + cut_w, to_location_h + cut_h]
165 | 
166 |         augmented = augmented.filter(MyGaussianBlur(radius=5, bounds=insert_box))
167 |         return super().__call__(img, augmented)
168 | 
169 | 
170 | class MultiCutPasteNormal(CutPaste):
171 |     """Randomly copy several patches from the image and paste them somewhere else.
172 |     Args:
173 |         area_ratio (list): list with 2 floats for minimum and maximum area to cut out
174 |         aspect_ratio (float): minimum aspect ratio. The ratio is sampled between aspect_ratio and 1/aspect_ratio.
175 |     """
176 | 
177 |     def __init__(self, area_ratio=[0.002, 0.015], aspect_ratio=0.3, **kwargs):
178 |         super(MultiCutPasteNormal, self).__init__(**kwargs)
179 |         self.area_ratio = area_ratio
180 |         self.aspect_ratio = aspect_ratio
181 | 
182 |     def __call__(self, img):
183 |         # TODO: we might want to use the pytorch implementation to calculate the patches from https://pytorch.org/vision/stable/_modules/torchvision/transforms/transforms.html#RandomErasing
184 |         h = img.size[0]
185 |         w = img.size[1]
186 | 
187 |         left, right = random.uniform(0.001, 0.003), random.uniform(0.01, 0.02)
188 |         new_area_ratio = [left, right]
189 |         new_aspect_ratio = random.uniform(0.1, 1)
190 | 
191 |         augmented = img.copy()
192 | 
193 |         num_patch = random.randint(1, 5)
194 |         for _ in range(num_patch):
195 |             # ratio between area_ratio[0] and area_ratio[1]
196 |             ratio_area = random.uniform(new_area_ratio[0], new_area_ratio[1]) * w * h
197 | 
198 |             # sample in log space
199 |             log_ratio = torch.log(torch.tensor((new_aspect_ratio, 1 / new_aspect_ratio)))
200 |             aspect = torch.exp(
201 |                 torch.empty(1).uniform_(log_ratio[0], log_ratio[1])
202 |             ).item()
203 | 
204 |             cut_w = int(round(math.sqrt(ratio_area * aspect)))
205 |             cut_h = int(round(math.sqrt(ratio_area / aspect)))
206 | 
207 |             # one might also want to sample from other images. currently we only sample from the image itself
208 |             from_location_h = int(random.uniform(0, h - cut_h))
209 |             from_location_w = int(random.uniform(0, w - cut_w))
210 | 
211 |             box = [from_location_w, from_location_h, from_location_w + cut_w, from_location_h + cut_h]
212 |             patch = img.crop(box)
213 | 
214 |             # before pasting, we jitter the pixel values in the patch
215 |             if self.colorJitter:
216 |                 patch = self.colorJitter(patch)
217 | 
218 |             to_location_h = int(random.uniform(0, h - cut_h))
219 |             to_location_w = int(random.uniform(0, w - cut_w))
220 | 
221 |             insert_box = [to_location_w, to_location_h, to_location_w + cut_w, to_location_h + cut_h]
222 | 
223 |             augmented.paste(patch,
224 |                             insert_box)  # the pasted patch always originates from the same image it is pasted to
225 | 
226 |         return super().__call__(img, augmented)
227 | 
228 | 
229 | class SwapPatch(CutPaste):
230 |     """Randomly swap two full-width horizontal strips of the image.
231 |     Args:
232 |         area_ratio (list): list with 2 floats for minimum and maximum area to cut out
233 |         aspect_ratio (float): minimum aspect ratio. The ratio is sampled between aspect_ratio and 1/aspect_ratio.
234 |     """
235 | 
236 |     def __init__(self, area_ratio=[0.0002, 0.0015], aspect_ratio=0.3, **kwargs):
237 |         super(SwapPatch, self).__init__(**kwargs)
238 |         self.area_ratio = area_ratio
239 |         self.aspect_ratio = aspect_ratio
240 | 
241 |     def __call__(self, img):
242 |         # TODO: we might want to use the pytorch implementation to calculate the patches from https://pytorch.org/vision/stable/_modules/torchvision/transforms/transforms.html#RandomErasing
243 |         h = img.size[0]
244 |         w = img.size[1]
245 | 
246 |         augmented = img.copy()
247 |         # ratio between area_ratio[0] and area_ratio[1]
248 |         ratio_area = random.uniform(self.area_ratio[0], self.area_ratio[1]) * w * h
249 | 
250 |         # sample in log space
251 |         log_ratio = torch.log(torch.tensor((self.aspect_ratio, 1 / self.aspect_ratio)))
252 |         aspect = torch.exp(
253 |             torch.empty(1).uniform_(log_ratio[0], log_ratio[1])
254 |         ).item()
255 | 
256 |         cut_w = int(round(math.sqrt(ratio_area * aspect)))
257 |         cut_h = int(round(math.sqrt(ratio_area / aspect)))
258 | 
259 |         # one might also want to sample from other images. currently we only sample from the image itself
260 |         from_location_h = int(random.uniform(0, h - cut_h))
261 |         from_location_w = int(random.uniform(0, w - cut_w))
262 | 
263 |         box1 = [0, from_location_h, w, from_location_h + cut_h]
264 |         patch1 = img.crop(box1)
265 | 
266 |         to_location_h = int(random.uniform(0, h - cut_h))
267 |         to_location_w = int(random.uniform(0, w - cut_w))
268 | 
269 |         box2 = [0, to_location_h, w, to_location_h + cut_h]
270 |         patch2 = img.crop(box2)
271 | 
272 |         augmented.paste(patch1, box2)  # swap the two strips between their locations
273 |         augmented.paste(patch2, box1)
274 | 
275 |         return super().__call__(img, augmented)
276 | 
277 | 
278 | class CutPasteScar(CutPaste):
279 |     """Randomly cut one thin, scar-like patch from the image, rotate it and paste it somewhere else.
280 |     Args:
281 |         width (list): width to sample from. List of [min, max]
282 |         height (list): height to sample from. List of [min, max]
283 |         rotation (list): rotation to sample from. List of [min, max]
284 |     """
285 | 
286 |     def __init__(self, width=[2, 16], height=[10, 25], rotation=[-45, 45], **kwargs):
287 |         super(CutPasteScar, self).__init__(**kwargs)
288 |         self.width = width
289 |         self.height = height
290 |         self.rotation = rotation
291 | 
292 |     def __call__(self, img):
293 |         h = img.size[0]
294 |         w = img.size[1]
295 | 
296 |         # cut region
297 |         cut_w = random.uniform(*self.width)
298 |         cut_h = random.uniform(*self.height)
299 | 
300 |         from_location_h = int(random.uniform(0, h - cut_h))
301 |         from_location_w = int(random.uniform(0, w - cut_w))
302 | 
303 |         box = [from_location_w, from_location_h, from_location_w + cut_w, from_location_h + cut_h]
304 |         patch = img.crop(box)
305 | 
306 |         if self.colorJitter:
307 |             patch = self.colorJitter(patch)
308 | 
309 |         # rotate
310 |         rot_deg = random.uniform(*self.rotation)
311 |         patch = patch.convert("RGBA").rotate(rot_deg, expand=True)
312 | 
313 |         # paste
314 |         to_location_h = int(random.uniform(0, h - patch.size[0]))
315 |         to_location_w = int(random.uniform(0, w - patch.size[1]))
316 | 
317 |         mask = patch.split()[-1]
318 |         patch = patch.convert("RGB")
319 | 
320 |         augmented = img.copy()
321 |         augmented.paste(patch, (to_location_w, to_location_h), mask=mask)
322 | 
323 |         return super().__call__(img, augmented)
324 | 
325 | 
326 | class MultiCutPasteScar(CutPaste):
327 |     """Randomly cut several thin, scar-like patches from the image, rotate them and paste them somewhere else.
328 |     Args:
329 |         width (list): width to sample from. List of [min, max]
330 |         height (list): height to sample from. List of [min, max]
331 |         rotation (list): rotation to sample from. List of [min, max]
332 |     """
333 | 
334 |     def __init__(self, width=[2, 16], height=[10, 25], rotation=[-45, 45], **kwargs):
335 |         super(MultiCutPasteScar, self).__init__(**kwargs)
336 |         self.width = width
337 |         self.height = height
338 |         self.rotation = rotation
339 | 
340 |     def __call__(self, img):
341 |         h = img.size[0]
342 |         w = img.size[1]
343 |         augmented = img.copy()
344 | 
345 |         num_patch = random.randint(1, 5)
346 |         for _ in range(num_patch):
347 |             # cut region
348 |             cut_w = random.uniform(*self.width)
349 |             cut_h = random.uniform(*self.height)
350 | 
351 |             from_location_h = int(random.uniform(0, h - cut_h))
352 |             from_location_w = int(random.uniform(0, w - cut_w))
353 | 
354 |             box = [from_location_w, from_location_h, from_location_w + cut_w, from_location_h + cut_h]
355 |             patch = img.crop(box)
356 | 
357 |             if self.colorJitter:
358 |                 patch = self.colorJitter(patch)
359 | 
360 |             # rotate
361 |             rot_deg = random.uniform(*self.rotation)
362 |             patch = patch.convert("RGBA").rotate(rot_deg, expand=True)
363 | 
364 |             # paste
365 |             to_location_h = int(random.uniform(0, h - patch.size[0]))
366 |             to_location_w = int(random.uniform(0, w - patch.size[1]))
367 | 
368 |             mask = patch.split()[-1]
369 |             patch = patch.convert("RGB")
370 | 
371 |             augmented.paste(patch, (to_location_w, to_location_h), mask=mask)
372 | 
373 |         return super().__call__(img, augmented)
374 | 
375 | 
376 | class CutPasteUnion(object):
377 |     def __init__(self, **kwargs):
378 |         self.normal = CutPasteNormal(**kwargs)
379 |         self.scar = CutPasteScar(**kwargs)
380 | 
381 |     def __call__(self, img):
382 |         r = random.uniform(0, 1)
383 |         if r < 0.5:
384 |             return self.normal(img)
385 |         else:
386 |             return self.scar(img)
387 | 
388 | 
389 | class CutPaste3Way(object):
390 |     def __init__(self, **kwargs):
391 |         self.normal = CutPasteNormal(**kwargs)
392 |         self.scar = CutPasteScar(**kwargs)
393 | 
394 |     def __call__(self, img):
395 |         org, cutpaste_normal = self.normal(img)
396 |         _, cutpaste_scar = self.scar(img)
397 | 
398 |         return org, cutpaste_normal, cutpaste_scar
399 | 
400 | 
--------------------------------------------------------------------------------
/datasets/utils/__init__.py:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | def get_mvtec_classes(args):
4 |     if args.dataset.dataset_order == 1:
5 |         mvtec_classes = ['leather', 'bottle', 'metal_nut',
6 |                          'grid', 'screw', 'zipper',
7 |                          'tile', 'hazelnut', 'toothbrush',
8 |                          'wood', 'transistor', 'pill',
9 |                          'carpet', 'capsule', 'cable']
10 |     elif args.dataset.dataset_order == 2:
11 |         mvtec_classes = ['wood', 'transistor', 'pill',
12 |                          'tile', 'hazelnut', 'toothbrush',
13 |                          'leather', 'bottle', 'metal_nut',
14 |                          'carpet', 'capsule', 'cable',
15 |                          'grid', 'screw', 'zipper']
16 |     elif args.dataset.dataset_order == 3:
17 |         mvtec_classes = ['leather', 'grid', 'tile',
18 |                          'bottle', 'toothbrush', 'capsule',
19 |                          'screw', 'pill', 'zipper',
20 |                          'cable', 'metal_nut', 'hazelnut',
21 |                          'wood', 'carpet', 'transistor']
22 |     return mvtec_classes
23 | 
--------------------------------------------------------------------------------
/datasets/utils/make_mtd_ano.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import numpy as np
4 | 
5 | path = '../datasets/Magnetic-tile-defect-datasets./'  # path of the original MTD repository
6 | target_path = '../datasets/mtd_ano_mask/'
7 | if not os.path.exists(target_path):
8 |     os.makedirs(target_path)
9 | 
10 | for ob in os.listdir(path):
11 |     if ob == 'MT_Free':
12 |         dest_train = os.path.join(target_path, 'train', 'good')
13 |         if not os.path.exists(dest_train):
14 |             os.makedirs(dest_train)
15 |         dest_test = os.path.join(target_path, 'test', 'good')
16 |         if not os.path.exists(dest_test):
17 |             os.makedirs(dest_test)
18 |         dest_gt = os.path.join(target_path, 'gt', 'good')
19 |         if not os.path.exists(dest_gt):
20 |             os.makedirs(dest_gt)
21 | 
22 |         ob_path = os.path.join(path, ob, 'Imgs')
23 |         for img in os.listdir(ob_path):
24 |             a = np.random.rand(1)
25 |             if 'jpg' in img:
26 |                 src = os.path.join(ob_path, img)
27 |                 if a < 0.75:
28 |                     dest = os.path.join(dest_train, img)
29 |                     shutil.copy(src, dest)
30 | 
31 |                 else:
32 |                     dest = os.path.join(dest_test, img)
33 |                     shutil.copy(src, dest)
34 |                     src_ext = os.path.splitext(src)  # split off the filename and extension
35 |                     src1 = src_ext[0] + '.png'
36 |                     img_ext = os.path.splitext(img)
37 |                     img1 = img_ext[0] + '.png'
38 |                     src1 = os.path.join(ob_path, img1)
39 |                     dest1 = os.path.join(dest_gt, img1)
40 |                     shutil.copy(src1, dest1)
41 | 
42 |     else:
43 |         now_path = os.path.join(path, ob)
44 |         if os.path.isdir(now_path):
45 |             if not 'git' in ob:
46 |                 dest_test = os.path.join(target_path, 'test', ob)
47 |                 if not os.path.exists(dest_test):
48 |                     os.makedirs(dest_test)
49 | 
50 |                 dest_gt = os.path.join(target_path, 'gt', ob)
51 |                 if not os.path.exists(dest_gt):
52 |                     os.makedirs(dest_gt)
53 | 
54 |                 ob_path = os.path.join(path, ob, 'Imgs')
55 |                 for img in os.listdir(ob_path):
56 |                     if 'jpg' in img:
57 |                         src = os.path.join(ob_path, img)
58 |                         dest = os.path.join(dest_test, img)
59 |                         shutil.copy(src, dest)
60 |                         src_ext = os.path.splitext(src)  # split off the filename and extension
61 |                         src1 = src_ext[0] + '.png'
62 |                         img_ext = os.path.splitext(img)
63 |                         img1 = img_ext[0] + '.png'
64 |                         src1 = os.path.join(ob_path, img1)
65 |                         dest1 = os.path.join(dest_gt, img1)
66 |                         shutil.copy(src1, dest1)
67 | 
68 | 
69 | 
70 | 
--------------------------------------------------------------------------------
/eval.py:
-------------------------------------------------------------------------------- 1 | from sklearn.metrics import roc_curve, auc, roc_auc_score 2 | import torch 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import os 6 | from scipy.ndimage import gaussian_filter 7 | from argument import get_args 8 | from datasets import get_dataloaders 9 | from utils.visualization import plot_tsne, compare_histogram, cal_anomaly_map 10 | 11 | 12 | 13 | def t2np(tensor): 14 | '''pytorch tensor -> numpy array''' 15 | return tensor.cpu().data.numpy() if tensor is not None else None 16 | 17 | 18 | def csflow_eval(args, epoch, dataloaders_test, learned_tasks, net): 19 | all_roc_auc = [] 20 | eval_task_wise_scores, eval_task_wise_labels = [], [] 21 | task_num = 0 22 | for idx, (dataloader_test, learned_task) in enumerate(zip(dataloaders_test, learned_tasks)): 23 | test_z, test_labels = list(), list() 24 | 25 | with torch.no_grad(): 26 | for i, data in enumerate(dataloader_test): 27 | inputs, labels = data 28 | inputs, labels = inputs.to(args.device), labels.to(args.device) 29 | _, z, jac = net(inputs) 30 | z = t2np(z[..., None]) 31 | score = np.mean(z ** 2, axis=(1, 2)) 32 | test_z.append(score) 33 | test_labels.append(t2np(labels)) 34 | 35 | test_labels = np.concatenate(test_labels) 36 | is_anomaly = np.array([0 if l == 0 else 1 for l in test_labels]) 37 | 38 | anomaly_score = np.concatenate(test_z, axis=0) 39 | roc_auc = roc_auc_score(is_anomaly, anomaly_score) 40 | all_roc_auc.append(roc_auc * len(learned_task)) 41 | task_num += len(learned_task) 42 | print('data_type:', learned_task, 'auc:', roc_auc, '**' * 11) 43 | 44 | eval_task_wise_scores.append(anomaly_score) 45 | eval_task_wise_scores_np = np.concatenate(eval_task_wise_scores) 46 | eval_task_wise_labels.append(is_anomaly) 47 | eval_task_wise_labels_np = np.concatenate(eval_task_wise_labels) 48 | print('mean_auc:', np.sum(all_roc_auc) / task_num, '**' * 11) 49 | 50 | if args.eval.visualization: 51 | name = f'{args.model.method}_task{len(learned_tasks)}_epoch{epoch}' 52 | his_save_path = f'./his_results/{args.model.method}{args.model.name}_{args.train.num_epochs}_epochs_seed{args.seed}' 53 | compare_histogram(np.array(eval_task_wise_scores_np), np.array(eval_task_wise_labels_np), start=0, thresh=5, 54 | interval=1, name=name, save_path=his_save_path) 55 | 56 | 57 | def revdis_eval(args, epoch, dataloaders_test, learned_tasks, net): 58 | all_roc_auc = [] 59 | eval_task_wise_scores, eval_task_wise_labels = [], [] 60 | task_num = 0 61 | for idx, (dataloader_test, learned_task) in enumerate(zip(dataloaders_test, learned_tasks)): 62 | gt_list_sp, pr_list_sp = [], [] 63 | with torch.no_grad(): 64 | for img, gt, label in dataloader_test: 65 | img = img.to(args.device) 66 | inputs = net.encoder(img) 67 | outputs = net.decoder(net.bn(inputs)) 68 | anomaly_map, _ = cal_anomaly_map(inputs, outputs, img.shape[-1], amap_mode='a') 69 | anomaly_map = gaussian_filter(anomaly_map, sigma=4) 70 | gt[gt > 0.5] = 1 71 | gt[gt <= 0.5] = 0 72 | gt_list_sp.append(np.max(gt.cpu().numpy().astype(int))) 73 | pr_list_sp.append(np.max(anomaly_map)) 74 | 75 | roc_auc = roc_auc_score(gt_list_sp, pr_list_sp) 76 | all_roc_auc.append(roc_auc * len(learned_task)) 77 | task_num += len(learned_task) 78 | print('data_type:', learned_task, 'auc:', roc_auc, '**' * 11) 79 | 80 | eval_task_wise_scores.append(pr_list_sp) 81 | eval_task_wise_scores_np = np.concatenate(eval_task_wise_scores) 82 | eval_task_wise_labels.append(gt_list_sp) 83 | eval_task_wise_labels_np = 
np.concatenate(eval_task_wise_labels) 84 | print('mean_auc:', np.sum(all_roc_auc) / task_num, '**' * 11) 85 | 86 | if args.eval.visualization: 87 | name = f'{args.model.method}_task{len(learned_tasks)}_epoch{epoch}' 88 | his_save_path = f'./his_results/{args.model.method}{args.model.name}_{args.train.num_epochs}_epochs_seed{args.seed}' 89 | compare_histogram(np.array(eval_task_wise_scores_np), np.array(eval_task_wise_labels_np), thresh=2, interval=1, 90 | name=name, save_path=his_save_path) 91 | 92 | 93 | 94 | def eval_model(args, epoch, dataloaders_test, learned_tasks, net, density): 95 | if args.model.method == 'csflow': 96 | csflow_eval(args, epoch, dataloaders_test, learned_tasks, net) 97 | elif args.model.method == 'revdis': 98 | revdis_eval(args, epoch, dataloaders_test, learned_tasks, net) 99 | else: 100 | all_roc_auc, all_embeds, all_labels = [], [], [] 101 | task_num = 0 102 | for idx, (dataloader_test, learned_task) in enumerate(zip(dataloaders_test, learned_tasks)): 103 | labels, embeds, logits = [], [], [] 104 | with torch.no_grad(): 105 | for x, label in dataloader_test: 106 | logit, embed = net(x.to(args.device)) 107 | _, logit = torch.max(logit, 1) 108 | logits.append(logit.cpu()) 109 | embeds.append(embed.cpu()) 110 | labels.append(label.cpu()) 111 | labels, embeds, logits = torch.cat(labels), torch.cat(embeds), torch.cat(logits) 112 | # L2-normalize embeddings before density scoring 113 | if args.eval.eval_classifier == 'density': 114 | embeds = F.normalize(embeds, p=2, dim=1) # embeds.shape=(2*bs, emb_dim) 115 | distances = density.predict(embeds) # distances.shape=(2*bs) 116 | fpr, tpr, _ = roc_curve(labels, distances) 117 | elif args.eval.eval_classifier == 'head': 118 | fpr, tpr, _ = roc_curve(labels, logits) 119 | roc_auc = auc(fpr, tpr) 120 | all_roc_auc.append(roc_auc * len(learned_task)) 121 | task_num += len(learned_task) 122 | all_embeds.append(embeds) 123 | all_labels.append(labels) 124 | print('data_type:', learned_task, 'auc:', roc_auc, '**' * 11) 125 | 126 | if args.eval.visualization: 127 | name = f'{args.model.method}_task{len(learned_tasks)}_{learned_task[0]}_epoch{epoch}' 128 | his_save_path = f'./his_results/{args.model.method}{args.model.name}_{args.train.num_epochs}e_order{args.data_order}_seed{args.seed}' 129 | tsne_save_path = f'./tsne_results/{args.model.method}{args.model.name}_{args.train.num_epochs}e_order{args.data_order}_seed{args.seed}' 130 | plot_tsne(labels, np.array(embeds), defect_name=name, save_path=tsne_save_path) 131 | # These parameters can be modified based on the visualization effect 132 | start, thresh, interval = 0, 120, 1 133 | compare_histogram(np.array(distances), labels, start=start, 134 | thresh=thresh, interval=interval, 135 | name=name, save_path=his_save_path) 136 | 137 | print('mean_auc:', np.sum(all_roc_auc) / task_num, '**' * 11) 138 | 139 | 140 | if __name__ == "__main__": 141 | os.environ["CUDA_VISIBLE_DEVICES"] = '3' 142 | args = get_args() 143 | dataloaders_train, dataloaders_test, learned_tasks, all_test_filenames = [], [], [], [] 144 | for t in range(args.dataset.n_tasks): 145 | train_dataloader, dataloaders_train, dataloaders_test, learned_tasks, data_train_nums, all_test_filenames = get_dataloaders(args, t, dataloaders_train, dataloaders_test, learned_tasks, all_test_filenames) 146 | 147 | epoch = args.train.num_epochs 148 | net, density = torch.load(f'{args.save_path}/net.pth'), torch.load(f'{args.save_path}/density.pth') 149 | eval_model(args, epoch, dataloaders_test, learned_tasks, net, density) 150 |
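For readers who want the density branch of `eval_model` in isolation, here is a minimal, self-contained sketch: fit a single Gaussian to L2-normalized embeddings of normal training images, score test embeddings by Mahalanobis distance, and compute image-level AUROC. The helpers `fit_gaussian` and `mahalanobis` are illustrative stand-ins for the repo's `GaussianDensityTorch`, and the tensors are random placeholders.

```
import torch
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score

def fit_gaussian(embeds):
    # embeds: (N, D) L2-normalized features of normal ("good") images
    mean = embeds.mean(dim=0)
    diff = embeds - mean
    cov = diff.T @ diff / (embeds.shape[0] - 1)
    cov += 1e-5 * torch.eye(embeds.shape[1])   # regularize before inversion
    return mean, torch.inverse(cov)

def mahalanobis(embeds, mean, inv_cov):
    diff = embeds - mean
    return torch.sqrt(torch.einsum('nd,dk,nk->n', diff, inv_cov, diff))

train_embeds = F.normalize(torch.randn(512, 768), p=2, dim=1)  # stand-in features
test_embeds = F.normalize(torch.randn(64, 768), p=2, dim=1)
test_labels = torch.tensor([0] * 32 + [1] * 32)                # 0 = good, 1 = anomaly

mean, inv_cov = fit_gaussian(train_embeds)
scores = mahalanobis(test_embeds, mean, inv_cov)
print('AUROC:', roc_auc_score(test_labels.numpy(), scores.numpy()))
```

In the repo this corresponds to `density.fit(...)` on training embeddings and `density.predict(...)` at test time; the per-task AUROCs are then weighted by the number of classes in each task when forming `mean_auc`, i.e. mean_auc = sum_i(auc_i * |task_i|) / sum_i |task_i|.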
-------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import torch 3 | from torch import optim 4 | import os 5 | import numpy as np 6 | from eval import eval_model 7 | from argument import get_args 8 | from models import get_net_optimizer_scheduler 9 | from methods import get_model 10 | from datasets import get_dataloaders 11 | from utils.density import GaussianDensityTorch 12 | 13 | 14 | 15 | def get_inputs_labels(data): 16 | if isinstance(data, list): 17 | inputs = [x.to(args.device) for x in data] 18 | labels = torch.arange(len(inputs), device=args.device) 19 | labels = labels.repeat_interleave(inputs[0].size(0)) 20 | inputs = torch.cat(inputs, dim=0) 21 | else: 22 | inputs = data.to(args.device) 23 | labels = torch.zeros(inputs.size(0), device=args.device).long() 24 | return inputs, labels 25 | 26 | 27 | def main(args): 28 | net, optimizer, scheduler = get_net_optimizer_scheduler(args) 29 | density = GaussianDensityTorch() 30 | net.to(args.device) 31 | 32 | model = get_model(args, net, optimizer, scheduler) 33 | 34 | dataloaders_train, dataloaders_test, learned_tasks, all_test_filenames = [], [], [], [] 35 | task_wise_mean, task_wise_cov, task_wise_train_data_nums = [], [], [] 36 | for t in range(args.dataset.n_tasks): 37 | print('---' * 10, f'Task:{t}', '---' * 10) 38 | train_dataloader, dataloaders_train, dataloaders_test, learned_tasks, data_train_nums, all_test_filenames = get_dataloaders(args, t, dataloaders_train, dataloaders_test, learned_tasks, all_test_filenames) 39 | task_wise_train_data_nums.append(data_train_nums) 40 | 41 | extra_para = None 42 | if args.model.method == 'panda': 43 | extra_para = model.get_center(train_dataloader) 44 | 45 | net.train() 46 | for epoch in tqdm(range(args.train.num_epochs)): 47 | one_epoch_embeds = [] 48 | if args.model.method == 'upper': 49 | for dataloader_train in dataloaders_train: 50 | for batch_idx, (data) in enumerate(dataloader_train): 51 | inputs, labels = get_inputs_labels(data) 52 | model(epoch, inputs, labels, one_epoch_embeds, t, extra_para) 53 | else: 54 | for batch_idx, (data) in enumerate(train_dataloader): 55 | inputs, labels = get_inputs_labels(data) 56 | model(epoch, inputs, labels, one_epoch_embeds, t, extra_para) 57 | 58 | if args.train.test_epochs > 0 and (epoch+1) % args.train.test_epochs == 0: 59 | net.eval() 60 | density = model.training_epoch(density, one_epoch_embeds, task_wise_mean, task_wise_cov, task_wise_train_data_nums, t) 61 | eval_model(args, epoch, dataloaders_test, learned_tasks, net, density) 62 | net.train() 63 | 64 | if hasattr(model, 'end_task'): 65 | model.end_task(train_dataloader) 66 | 67 | if args.save_checkpoint: 68 | torch.save(net, f'{args.save_path}/net.pth') 69 | torch.save(density, f'{args.save_path}/density.pth') 70 | 71 | 72 | if __name__ == "__main__": 73 | os.environ["CUDA_VISIBLE_DEVICES"] = '0' 74 | args = get_args() 75 | main(args) 76 | -------------------------------------------------------------------------------- /methods/__init__.py: -------------------------------------------------------------------------------- 1 | from .cutpaste import CutPaste 2 | from .dne import DNE 3 | from .csflow import CSFlow 4 | from .panda import PANDA 5 | from .revdis import RevDis 6 | from .der import DER 7 | from .er import ER 8 | from .derpp import DERpp 9 | from .fdr import FDR 10 | from .agem import AGEM 11 | 12 | 13 | def get_model(args, net, optimizer, 
scheduler): 14 | if args.model.method == 'dne': 15 | model = DNE(args, net, optimizer, scheduler) 16 | args.dataset.strong_augmentation = True 17 | elif args.model.method == 'upper': 18 | model = CutPaste(args, net, optimizer, scheduler) 19 | elif args.model.method == 'cutpaste': 20 | model = CutPaste(args, net, optimizer, scheduler) 21 | args.dataset.strong_augmentation = True 22 | 23 | elif args.model.method == 'csflow': 24 | model = CSFlow(args, net, optimizer, scheduler) 25 | args.dataset.strong_augmentation = False 26 | elif args.model.method == 'panda': 27 | model = PANDA(args, net, optimizer, scheduler) 28 | args.dataset.strong_augmentation = False 29 | elif args.model.method == 'revdis': 30 | model = RevDis(args, net, optimizer, scheduler) 31 | args.dataset.strong_augmentation = False 32 | elif args.model.method == 'er': 33 | model = ER(args, net, optimizer, scheduler) 34 | elif args.model.method == 'der': 35 | model = DER(args, net, optimizer, scheduler) 36 | elif args.model.method == 'derpp': 37 | model = DERpp(args, net, optimizer, scheduler) 38 | elif args.model.method == 'fdr': 39 | model = FDR(args, net, optimizer, scheduler) 40 | elif args.model.method == 'agem': 41 | model = AGEM(args, net, optimizer, scheduler) 42 | return model -------------------------------------------------------------------------------- /methods/agem.py: -------------------------------------------------------------------------------- 1 | from utils.buffer import Buffer 2 | import torch.nn.functional as F 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | from .utils.base_method import BaseMethodwDNE 7 | try: 8 | import quadprog 9 | except ImportError: 10 | print('Warning: GEM and A-GEM cannot be used on Windows (quadprog required)') 11 | 12 | 13 | 14 | def store_grad(params, grads, grad_dims): 15 | """ 16 | This stores parameter gradients of past tasks. 17 | params: callable returning the model parameters 18 | grads: gradients 19 | grad_dims: list with number of parameters per layer 20 | """ 21 | # store the gradients 22 | grads.fill_(0.0) 23 | count = 0 24 | for param in params(): 25 | if param.grad is not None: 26 | begin = 0 if count == 0 else sum(grad_dims[:count]) 27 | end = np.sum(grad_dims[:count + 1]) 28 | grads[begin: end].copy_(param.grad.data.view(-1)) 29 | count += 1 30 | 31 | def overwrite_grad(params, newgrad, grad_dims): 32 | """ 33 | This is used to overwrite the gradients with a new gradient 34 | vector, whenever violations occur.
35 | pp: parameters 36 | newgrad: corrected gradient 37 | grad_dims: list storing number of parameters at each layer 38 | """ 39 | count = 0 40 | for param in params(): 41 | if param.grad is not None: 42 | begin = 0 if count == 0 else sum(grad_dims[:count]) 43 | end = sum(grad_dims[:count + 1]) 44 | this_grad = newgrad[begin: end].contiguous().view( 45 | param.grad.data.size()) 46 | param.grad.data.copy_(this_grad) 47 | count += 1 48 | 49 | 50 | def project(gxy: torch.Tensor, ger: torch.Tensor) -> torch.Tensor: 51 | corr = torch.dot(gxy, ger) / torch.dot(ger, ger) 52 | return gxy - corr * ger 53 | 54 | 55 | class AGEM(BaseMethodwDNE): 56 | def __init__(self, args, net, optimizer, scheduler): 57 | super(AGEM, self).__init__(args, net, optimizer, scheduler) 58 | self.buffer = Buffer(self.args.model.buffer_size, self.args.device) 59 | self.cross_entropy = nn.CrossEntropyLoss() 60 | self.current_task = 0 61 | self.grad_dims = [] 62 | for param in self.parameters(): 63 | self.grad_dims.append(param.data.numel()) 64 | self.grad_xy = torch.Tensor(np.sum(self.grad_dims)).to(self.args.device) 65 | self.grad_er = torch.Tensor(np.sum(self.grad_dims)).to(self.args.device) 66 | 67 | def end_task(self, train_loader): 68 | data = next(iter(train_loader)) 69 | cur_x = [x for x in data] 70 | cur_y = torch.arange(len(cur_x)) 71 | cur_y = cur_y.repeat_interleave(cur_x[0].size(0)) 72 | cur_x = torch.cat(cur_x, dim=0) 73 | self.buffer.add_data( 74 | examples=cur_x.to(self.args.device), 75 | labels=cur_y.to(self.args.device) 76 | ) 77 | 78 | def forward(self, epoch, inputs, labels, one_epoch_embeds, t, *args): 79 | num = self.pre_forward(inputs, t) 80 | 81 | with torch.no_grad(): 82 | noaug_logits, noaug_embeds = self.net(inputs[:num]) 83 | one_epoch_embeds.append(noaug_embeds.cpu()) 84 | 85 | self.optimizer.zero_grad() 86 | logits, embeds = self.net(inputs) 87 | loss = self.cross_entropy(logits, labels) 88 | loss.backward() 89 | 90 | if not self.buffer.is_empty(): 91 | store_grad(self.parameters, self.grad_xy, self.grad_dims) 92 | 93 | buf_inputs, buf_labels = self.buffer.get_data(self.args.train.batch_size) 94 | self.net.zero_grad() 95 | buf_outputs, _ = self.net(buf_inputs) 96 | penalty = self.cross_entropy(buf_outputs, buf_labels) 97 | penalty.backward() 98 | store_grad(self.parameters, self.grad_er, self.grad_dims) 99 | 100 | dot_prod = torch.dot(self.grad_xy, self.grad_er) 101 | if dot_prod.item() < 0: 102 | g_tilde = project(gxy=self.grad_xy, ger=self.grad_er) 103 | overwrite_grad(self.parameters, g_tilde, self.grad_dims) 104 | else: 105 | overwrite_grad(self.parameters, self.grad_xy, self.grad_dims) 106 | 107 | self.optimizer.step() 108 | -------------------------------------------------------------------------------- /methods/csflow.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | import numpy as np 5 | from .utils.base_method import BaseMethod 6 | 7 | 8 | class CSFlow(BaseMethod): 9 | def __init__(self, args, net, optimizer, scheduler): 10 | super(CSFlow, self).__init__(args, net, optimizer, scheduler) 11 | 12 | 13 | def forward(self, epoch, inputs, labels, one_epoch_embeds, *args): 14 | self.optimizer.zero_grad() 15 | embeds, z, log_jac_det = self.net(inputs) 16 | # yy, rev_y, zz = self.net.revward(inputs) 17 | loss = torch.mean(0.5 * torch.sum(z ** 2, dim=(1,)) - log_jac_det) / z.shape[1] 18 | 19 | loss.backward() 20 | self.optimizer.step() 21 | if self.scheduler: 22 | 
self.scheduler.step(epoch) 23 | 24 | -------------------------------------------------------------------------------- /methods/cutpaste.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import torch.nn.functional as F 5 | from .utils.base_method import BaseMethod 6 | 7 | 8 | 9 | class CutPaste(BaseMethod): 10 | def __init__(self, args, net, optimizer, scheduler): 11 | super(CutPaste, self).__init__(args, net, optimizer, scheduler) 12 | self.cross_entropy = nn.CrossEntropyLoss() 13 | 14 | def forward(self, epoch, inputs, labels, one_epoch_embeds, *args): 15 | if self.args.dataset.strong_augmentation: 16 | half_num = int(len(inputs) / 2) 17 | no_strongaug_inputs = inputs[:half_num] 18 | else: 19 | no_strongaug_inputs = inputs 20 | 21 | self.optimizer.zero_grad() 22 | with torch.no_grad(): 23 | noaug_embeds = self.net.forward_features(no_strongaug_inputs) 24 | one_epoch_embeds.append(noaug_embeds.cpu()) 25 | out, _ = self.net(inputs) 26 | loss = self.cross_entropy(out, labels) 27 | loss.backward() 28 | self.optimizer.step() 29 | if self.scheduler: 30 | self.scheduler.step(epoch) 31 | 32 | def training_epoch(self, density, one_epoch_embeds, *args): 33 | if self.args.eval.eval_classifier == 'density': 34 | one_epoch_embeds = torch.cat(one_epoch_embeds) 35 | one_epoch_embeds = F.normalize(one_epoch_embeds, p=2, dim=1) 36 | _, _ = density.fit(one_epoch_embeds) 37 | return density 38 | else: 39 | pass 40 | 41 | 42 | 43 | -------------------------------------------------------------------------------- /methods/der.py: -------------------------------------------------------------------------------- 1 | from utils.buffer import Buffer 2 | import torch.nn.functional as F 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | from .utils.base_method import BaseMethodwDNE 7 | 8 | 9 | class DER(BaseMethodwDNE): 10 | def __init__(self, args, net, optimizer, scheduler): 11 | super(DER, self).__init__(args, net, optimizer, scheduler) 12 | self.buffer = Buffer(self.args.model.buffer_size, self.args.device) 13 | self.cross_entropy = nn.CrossEntropyLoss() 14 | 15 | def forward(self, epoch, inputs, labels, one_epoch_embeds, t, *args): 16 | num = self.pre_forward(inputs, t) 17 | 18 | self.optimizer.zero_grad() 19 | with torch.no_grad(): 20 | noaug_logits, noaug_embeds = self.net(inputs[:num]) 21 | one_epoch_embeds.append(noaug_embeds.cpu()) 22 | 23 | logits, embeds = self.net(inputs) 24 | loss = self.cross_entropy(logits, labels) 25 | if not self.buffer.is_empty(): 26 | if self.args.model.with_embeds: 27 | buf_inputs, buf_embeds = self.buffer.get_data(self.args.train.batch_size) 28 | _, past_embeds = self.net(buf_inputs) 29 | loss += self.args.train.alpha * F.mse_loss(past_embeds, buf_embeds) 30 | else: 31 | buf_inputs, buf_logits = self.buffer.get_data(self.args.train.batch_size) 32 | past_logits, _ = self.net(buf_inputs) 33 | loss += self.args.train.alpha * F.mse_loss(past_logits, buf_logits) 34 | 35 | loss.backward() 36 | self.optimizer.step() 37 | if self.scheduler: 38 | self.scheduler.step(epoch) 39 | 40 | if self.args.model.with_embeds: 41 | self.buffer.add_data(examples=inputs[:num], logits=noaug_embeds.data) 42 | else: 43 | self.buffer.add_data(examples=inputs[:num], logits=noaug_logits.data) 44 | -------------------------------------------------------------------------------- /methods/derpp.py: -------------------------------------------------------------------------------- 1 | from 
utils.buffer import Buffer 2 | import torch.nn.functional as F 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | from .utils.base_method import BaseMethodwDNE 7 | 8 | 9 | class DERpp(BaseMethodwDNE): 10 | def __init__(self, args, net, optimizer, scheduler): 11 | super(DERpp, self).__init__(args, net, optimizer, scheduler) 12 | self.buffer = Buffer(self.args.model.buffer_size, self.args.device) 13 | self.cross_entropy = nn.CrossEntropyLoss() 14 | 15 | def forward(self, epoch, inputs, labels, one_epoch_embeds, t, *args): 16 | num = self.pre_forward(inputs, t) 17 | 18 | self.optimizer.zero_grad() 19 | with torch.no_grad(): 20 | noaug_logits, noaug_embeds = self.net(inputs[:num]) 21 | one_epoch_embeds.append(noaug_embeds.cpu()) 22 | 23 | logits, embeds = self.net(inputs) 24 | loss = self.cross_entropy(logits, labels) 25 | if not self.buffer.is_empty(): 26 | buf_inputs, buf_labels, buf_logits = self.buffer.get_data(self.args.train.batch_size) 27 | past_logits, _ = self.net(buf_inputs) 28 | loss += (self.args.train.alpha * F.mse_loss(past_logits, buf_logits) 29 | + self.args.train.beta * self.cross_entropy(past_logits, buf_labels)) 30 | 31 | loss.backward() 32 | self.optimizer.step() 33 | if self.scheduler: 34 | self.scheduler.step(epoch) 35 | 36 | self.buffer.add_data(examples=inputs[:num], labels=labels[:num], logits=noaug_logits.data) 37 | 38 | 39 | -------------------------------------------------------------------------------- /methods/dne.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import torch.nn.functional as F 5 | from .utils.base_method import BaseMethod 6 | 7 | 8 | 9 | class DNE(BaseMethod): 10 | def __init__(self, args, net, optimizer, scheduler): 11 | super(DNE, self).__init__(args, net, optimizer, scheduler) 12 | self.cross_entropy = nn.CrossEntropyLoss() 13 | 14 | 15 | def forward(self, epoch, inputs, labels, one_epoch_embeds, t, *args): 16 | if self.args.dataset.strong_augmentation: 17 | half_num = int(len(inputs) / 2) 18 | no_strongaug_inputs = inputs[:half_num] 19 | else: 20 | no_strongaug_inputs = inputs 21 | 22 | if self.args.model.fix_head: 23 | if t >= 1: 24 | for param in self.net.head.parameters(): 25 | param.requires_grad = False 26 | 27 | self.optimizer.zero_grad() 28 | with torch.no_grad(): 29 | noaug_embeds = self.net.forward_features(no_strongaug_inputs) 30 | one_epoch_embeds.append(noaug_embeds.cpu()) 31 | out, _ = self.net(inputs) 32 | loss = self.cross_entropy(out, labels) 33 | loss.backward() 34 | self.optimizer.step() 35 | if self.scheduler: 36 | self.scheduler.step(epoch) 37 | 38 | 39 | def training_epoch(self, density, one_epoch_embeds, task_wise_mean, task_wise_cov, task_wise_train_data_nums, t): 40 | if self.args.eval.eval_classifier == 'density': 41 | one_epoch_embeds = torch.cat(one_epoch_embeds) 42 | one_epoch_embeds = F.normalize(one_epoch_embeds, p=2, dim=1) 43 | mean, cov = density.fit(one_epoch_embeds) 44 | 45 | if len(task_wise_mean) < t + 1: 46 | task_wise_mean.append(mean) 47 | task_wise_cov.append(cov) 48 | else: 49 | task_wise_mean[-1] = mean 50 | task_wise_cov[-1] = cov 51 | 52 | task_wise_embeds = [] 53 | for i in range(t + 1): 54 | if i < t: 55 | past_mean, past_cov, past_nums = task_wise_mean[i], task_wise_cov[i], task_wise_train_data_nums[i] 56 | past_embeds = np.random.multivariate_normal(past_mean, past_cov, size=int(past_nums * (1 - self.args.noise_ratio))) 57 | 
task_wise_embeds.append(torch.FloatTensor(past_embeds)) 58 | noise_mean, noise_cov = np.random.rand(past_mean.shape[0]), np.random.rand(past_cov.shape[0], past_cov.shape[1]) 59 | noise = np.random.multivariate_normal(noise_mean, noise_cov, size=int(past_nums * self.args.noise_ratio)) 60 | task_wise_embeds.append(torch.FloatTensor(noise)) 61 | else: 62 | task_wise_embeds.append(one_epoch_embeds) 63 | for_eval_embeds = torch.cat(task_wise_embeds, dim=0) 64 | for_eval_embeds = F.normalize(for_eval_embeds, p=2, dim=1) 65 | _, _ = density.fit(for_eval_embeds) 66 | return density 67 | else: 68 | pass 69 | 70 | 71 | 72 | -------------------------------------------------------------------------------- /methods/er.py: -------------------------------------------------------------------------------- 1 | from utils.buffer import Buffer 2 | import torch.nn.functional as F 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | from .utils.base_method import BaseMethodwDNE 7 | 8 | 9 | class ER(BaseMethodwDNE): 10 | def __init__(self, args, net, optimizer, scheduler): 11 | super(ER, self).__init__(args, net, optimizer, scheduler) 12 | self.buffer = Buffer(self.args.model.buffer_size, self.args.device) 13 | self.cross_entropy = nn.CrossEntropyLoss() 14 | 15 | def forward(self, epoch, inputs, labels, one_epoch_embeds, t, *args): 16 | num = self.pre_forward(inputs, t) 17 | 18 | self.optimizer.zero_grad() 19 | with torch.no_grad(): 20 | noaug_logits, noaug_embeds = self.net(inputs[:num]) 21 | one_epoch_embeds.append(noaug_embeds.cpu()) 22 | 23 | if not self.buffer.is_empty(): 24 | buf_inputs, buf_labels = self.buffer.get_data(self.args.train.batch_size) 25 | inputs = torch.cat((inputs, buf_inputs)) 26 | labels = torch.cat((labels, buf_labels)) 27 | 28 | logits, _ = self.net(inputs) 29 | loss = self.cross_entropy(logits, labels) 30 | 31 | loss.backward() 32 | self.optimizer.step() 33 | if self.scheduler: 34 | self.scheduler.step(epoch) 35 | 36 | self.buffer.add_data(examples=inputs[:num], labels=labels[:num]) 37 | 38 | -------------------------------------------------------------------------------- /methods/fdr.py: -------------------------------------------------------------------------------- 1 | from utils.buffer import Buffer 2 | import torch.nn.functional as F 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | from .utils.base_method import BaseMethodwDNE 7 | 8 | 9 | class FDR(BaseMethodwDNE): 10 | def __init__(self, args, net, optimizer, scheduler): 11 | super(FDR, self).__init__(args, net, optimizer, scheduler) 12 | self.buffer = Buffer(self.args.model.buffer_size, self.args.device) 13 | self.cross_entropy = nn.CrossEntropyLoss() 14 | self.current_task = 0 15 | self.i = 0 16 | self.soft = torch.nn.Softmax(dim=1) 17 | self.logsoft = torch.nn.LogSoftmax(dim=1) 18 | 19 | def end_task(self, train_loader): 20 | self.current_task += 1 21 | examples_per_task = self.args.model.buffer_size // self.current_task 22 | 23 | if self.current_task > 1: 24 | buf_x, buf_log, buf_tl = self.buffer.get_all_data() 25 | self.buffer.empty() 26 | 27 | for ttl in buf_tl.unique(): 28 | idx = (buf_tl == ttl) 29 | ex, log, tasklab = buf_x[idx], buf_log[idx], buf_tl[idx] 30 | first = min(ex.shape[0], examples_per_task) 31 | self.buffer.add_data( 32 | examples=ex[:first], 33 | logits=log[:first], 34 | task_labels=tasklab[:first]) 35 | counter = 0 36 | with torch.no_grad(): 37 | for i, data in enumerate(train_loader): 38 | if isinstance(data, list): 39 | inputs = [x.to(self.args.device) for x in data] 
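# list batches carry one tensor per augmentation view (cf. get_inputs_labels in main.py)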
40 | inputs = torch.cat(inputs, dim=0) 41 | else: 42 | inputs = data.to(self.args.device) 43 | 44 | if self.args.dataset.strong_augmentation: 45 | num = int(len(inputs) / 2) 46 | else: 47 | num = int(len(inputs)) 48 | 49 | not_aug_inputs = inputs[:num] 50 | not_aug_logits, not_aug_embeds = self.net(not_aug_inputs) 51 | if examples_per_task - counter < 0: 52 | break 53 | self.buffer.add_data(examples=not_aug_inputs[:(examples_per_task - counter)], 54 | logits=not_aug_logits.data[:(examples_per_task - counter)], 55 | task_labels=(torch.ones(self.args.train.batch_size) * 56 | (self.current_task - 1))[:(examples_per_task - counter)]) 57 | counter += self.args.train.batch_size 58 | 59 | def forward(self, epoch, inputs, labels, one_epoch_embeds, t, *args): 60 | num = self.pre_forward(inputs, t) 61 | 62 | self.optimizer.zero_grad() 63 | with torch.no_grad(): 64 | noaug_logits, noaug_embeds = self.net(inputs[:num]) 65 | one_epoch_embeds.append(noaug_embeds.cpu()) 66 | 67 | logits, embeds = self.net(inputs) 68 | loss = self.cross_entropy(logits, labels) 69 | loss.backward() 70 | self.optimizer.step() 71 | 72 | if not self.buffer.is_empty(): 73 | self.optimizer.zero_grad() 74 | buf_inputs, buf_logits, _ = self.buffer.get_data(self.args.train.batch_size) 75 | past_logits, _ = self.net(buf_inputs) 76 | loss = torch.norm(self.soft(past_logits) - self.soft(buf_logits), 2, 1).mean() 77 | assert not torch.isnan(loss) 78 | loss.backward() 79 | self.optimizer.step() 80 | 81 | if self.scheduler: 82 | self.scheduler.step(epoch) 83 | -------------------------------------------------------------------------------- /methods/panda.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import torch.nn.functional as F 5 | from .utils.base_method import BaseMethod 6 | 7 | 8 | class CompactnessLoss(nn.Module): 9 | def __init__(self, center): 10 | super(CompactnessLoss, self).__init__() 11 | self.center = center # (768, ) 12 | 13 | def forward(self, inputs): 14 | m = inputs.size(1) 15 | variances = (inputs - self.center).norm(dim=1).pow(2) / m # (32, ) 16 | return variances.mean() 17 | 18 | 19 | # contrastive svdd 20 | class PANDA(BaseMethod): 21 | def __init__(self, args, net, optimizer, scheduler): 22 | super(PANDA, self).__init__(args, net, optimizer, scheduler) 23 | 24 | def get_center(self, train_loader): 25 | self.net.eval() 26 | train_feature_space = [] 27 | with torch.no_grad(): 28 | for imgs in train_loader: 29 | imgs = imgs.to(self.args.device) 30 | features = self.net.forward_features(imgs) 31 | train_feature_space.append(features) 32 | train_feature_space = torch.cat(train_feature_space, dim=0).contiguous().cpu().numpy() 33 | center = torch.FloatTensor(train_feature_space).mean(dim=0) 34 | return center 35 | 36 | def forward(self, epoch, inputs, labels, one_epoch_embeds, t, center): 37 | self.compactness_loss = CompactnessLoss(center.to(self.args.device)) 38 | 39 | self.optimizer.zero_grad() 40 | embeds = self.net.forward_features(inputs) # (32, 768) 41 | one_epoch_embeds.append(embeds.detach().cpu()) 42 | loss = self.compactness_loss(embeds) 43 | 44 | loss.backward() 45 | self.optimizer.step() 46 | if self.scheduler: 47 | self.scheduler.step(epoch) 48 | 49 | def training_epoch(self, density, one_epoch_embeds, task_wise_mean, task_wise_cov, task_wise_train_data_nums, t): 50 | if self.args.eval.eval_classifier == 'density': 51 | one_epoch_embeds = torch.cat(one_epoch_embeds) 52 | one_epoch_embeds = F.normalize(one_epoch_embeds, p=2, dim=1) 53 | _, _ = density.fit(one_epoch_embeds) 54 | return density 55 | else: 56 | pass 57 | 58 | 59 |
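To make the compactness objective concrete, here is a small usage sketch of the `CompactnessLoss` defined in panda.py above; the center and feature tensors are dummy placeholders, with shapes following the `(32, 768)` comments.

```
import torch

center = torch.zeros(768)        # in the repo: mean feature from PANDA.get_center()
feats = torch.randn(32, 768)     # one batch of backbone features

# mean squared L2 distance to the center, scaled by the feature dimension m
m = feats.size(1)
loss = ((feats - center).norm(dim=1) ** 2 / m).mean()
# identical to: CompactnessLoss(center)(feats)
```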
-------------------------------------------------------------------------------- /methods/revdis.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .utils.base_method import BaseMethod 4 | 5 | 6 | class RevDis(BaseMethod): 7 | def __init__(self, args, net, optimizer, scheduler): 8 | super(RevDis, self).__init__(args, net, optimizer, scheduler) 9 | 10 | def loss_function(self, a, b): 11 | cos_loss = torch.nn.CosineSimilarity() 12 | loss = 0 13 | for item in range(len(a)): 14 | loss += torch.mean(1 - cos_loss(a[item].view(a[item].shape[0], -1), 15 | b[item].view(b[item].shape[0], -1))) 16 | return loss 17 | 18 | def forward(self, epoch, inputs, labels, one_epoch_embeds, *args): 19 | self.optimizer.zero_grad() 20 | t_outs, outs = self.net(inputs) 21 | loss = self.loss_function(t_outs, outs) 22 | loss.backward() 23 | self.optimizer.step() 24 | -------------------------------------------------------------------------------- /methods/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vijaylee/Continual_Anomaly_Detection/7be4ea5b4087f292ddc0e332e9e9e0fd6eca0c67/methods/utils/__init__.py -------------------------------------------------------------------------------- /methods/utils/base_method.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import torch.nn.functional as F 5 | 6 | 7 | 8 | class BaseMethod(nn.Module): 9 | def __init__(self, args, net, optimizer, scheduler): 10 | super(BaseMethod, self).__init__() 11 | self.args = args 12 | self.optimizer = optimizer 13 | self.scheduler = scheduler 14 | self.net = net 15 | 16 | def forward(self, *args): 17 | pass 18 | 19 | def training_epoch(self, density, one_epoch_embeds, task_wise_mean, task_wise_cov, task_wise_train_data_nums, t): 20 | pass 21 | 22 | 23 | 24 | class BaseMethodwDNE(nn.Module): 25 | def __init__(self, args, net, optimizer, scheduler): 26 | super(BaseMethodwDNE, self).__init__() 27 | self.args = args 28 | self.optimizer = optimizer 29 | self.scheduler = scheduler 30 | self.net = net 31 | 32 | def pre_forward(self, inputs, t): 33 | if self.args.dataset.strong_augmentation: 34 | num = int(len(inputs) / 2) 35 | else: 36 | num = int(len(inputs)) 37 | 38 | if self.args.model.fix_head: 39 | if t >= 1: 40 | for param in self.net.head.parameters(): 41 | param.requires_grad = False 42 | return num 43 | 44 | def forward(self, *args): 45 | pass 46 | 47 | def training_epoch(self, density, one_epoch_embeds, task_wise_mean, task_wise_cov, task_wise_train_data_nums, t): 48 | if self.args.eval.eval_classifier == 'density': 49 | if self.args.model.with_dne: 50 | one_epoch_embeds = torch.cat(one_epoch_embeds) 51 | one_epoch_embeds = F.normalize(one_epoch_embeds, p=2, dim=1) 52 | mean, cov = density.fit(one_epoch_embeds) 53 | 54 | if len(task_wise_mean) < t + 1: 55 | task_wise_mean.append(mean) 56 | task_wise_cov.append(cov) 57 | else: 58 | task_wise_mean[-1] = mean 59 | task_wise_cov[-1] = cov 60 | 61 | task_wise_embeds = [] 62 | for i in range(t + 1): 63 | if i < t: 64 | past_mean, past_cov, past_nums = task_wise_mean[i], task_wise_cov[i], task_wise_train_data_nums[i] 65 | past_embeds = np.random.multivariate_normal(past_mean, past_cov,
size=past_nums) 66 | task_wise_embeds.append(torch.FloatTensor(past_embeds)) 67 | else: 68 | task_wise_embeds.append(one_epoch_embeds) 69 | for_eval_embeds = torch.cat(task_wise_embeds, dim=0) 70 | for_eval_embeds = F.normalize(for_eval_embeds, p=2, dim=1) 71 | _, _ = density.fit(for_eval_embeds) 72 | return density 73 | else: 74 | one_epoch_embeds = torch.cat(one_epoch_embeds) 75 | one_epoch_embeds = F.normalize(one_epoch_embeds, p=2, dim=1) 76 | _, _ = density.fit(one_epoch_embeds) 77 | return density 78 | else: 79 | pass 80 | 81 | 82 | 83 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet import ResNetModel 2 | from .vit import ViT 3 | from .csflow_net import NetCSFlow 4 | from .revdis_net import NetRevDis 5 | 6 | from torch import nn 7 | from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts 8 | from utils.optimizer import get_optimizer 9 | 10 | 11 | 12 | def get_net_optimizer_scheduler(args): 13 | if args.model.name == 'resnet': 14 | net = ResNetModel(pretrained=args.model.pretrained, num_classes=args.train.num_classes) 15 | optimizer = get_optimizer(args, net) 16 | scheduler = CosineAnnealingWarmRestarts(optimizer, args.train.num_epochs) 17 | elif args.model.name == 'vit': 18 | net = ViT(num_classes=args.train.num_classes) 19 | if args.model.pretrained: 20 | checkpoint_path = './checkpoints/ViT-B_16.npz' 21 | net.load_pretrained(checkpoint_path) 22 | optimizer = get_optimizer(args, net) 23 | scheduler = CosineAnnealingWarmRestarts(optimizer, args.train.num_epochs) 24 | elif args.model.name == 'net_csflow': 25 | net = NetCSFlow(args) 26 | optim_modules = nn.ModuleList() 27 | if args.model.pretrained: 28 | names_to_update = ["density_estimator"] 29 | for name, param in net.named_parameters(): 30 | param.requires_grad_(False) 31 | for name_to_update in names_to_update: 32 | optim_modules.append(getattr(net, name_to_update)) 33 | for name, param in net.named_parameters(): 34 | if name_to_update in name: 35 | param.requires_grad_(True) 36 | optimizer = get_optimizer(args, optim_modules) 37 | scheduler = None 38 | elif args.model.name == 'net_revdis': 39 | net = NetRevDis(args) 40 | optim_modules = nn.ModuleList() 41 | if args.model.pretrained: 42 | names_to_update = ["decoder", "bn"] 43 | for name, param in net.named_parameters(): 44 | param.requires_grad_(False) 45 | for name_to_update in names_to_update: 46 | optim_modules.append(getattr(net, name_to_update)) 47 | for name, param in net.named_parameters(): 48 | if name_to_update in name: 49 | param.requires_grad_(True) 50 | optimizer = get_optimizer(args, optim_modules) 51 | scheduler = None 52 | return net, optimizer, scheduler -------------------------------------------------------------------------------- /models/csflow_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | from efficientnet_pytorch import EfficientNet 5 | import numpy as np 6 | from utils.freia_funcs import * 7 | import torch.hub 8 | 9 | 10 | 11 | class NetCSFlow(nn.Module): 12 | def __init__(self, args): 13 | super(NetCSFlow, self).__init__() 14 | self.args = args 15 | torch.hub.set_dir('./checkpoints') 16 | self.feature_extractor = EfficientNet.from_pretrained('efficientnet-b5') 17 | for param in self.feature_extractor.parameters(): 18 | param.requires_grad = False 19 | self.map_size = 
(self.args.dataset.image_size // 12, self.args.dataset.image_size // 12) 20 | self.kernel_sizes = [3] * (self.args.model.n_coupling_blocks - 1) + [5] 21 | self.density_estimator = self.get_cs_flow_model(input_dim=self.args.model.n_feat) 22 | 23 | def get_cs_flow_model(self, input_dim): 24 | nodes = list() 25 | nodes.append(InputNode(input_dim, self.map_size[0], self.map_size[1], name='input')) 26 | nodes.append(InputNode(input_dim, self.map_size[0] // 2, self.map_size[1] // 2, name='input2')) 27 | nodes.append(InputNode(input_dim, self.map_size[0] // 4, self.map_size[1] // 4, name='input3')) 28 | 29 | for k in range(self.args.model.n_coupling_blocks): 30 | if k == 0: 31 | node_to_permute = [nodes[-3].out0, nodes[-2].out0, nodes[-1].out0] 32 | else: 33 | node_to_permute = [nodes[-1].out0, nodes[-1].out1, nodes[-1].out2] 34 | 35 | nodes.append(Node(node_to_permute, ParallelPermute, {'seed': k}, name=F'permute_{k}')) 36 | nodes.append(Node([nodes[-1].out0, nodes[-1].out1, nodes[-1].out2], parallel_glow_coupling_layer, 37 | {'clamp': self.args.model.clamp, 'F_class': CrossConvolutions, 38 | 'F_args': {'channels_hidden': self.args.model.fc_internal, 39 | 'kernel_size': self.kernel_sizes[k], 'block_no': k}}, 40 | name=F'fc1_{k}')) 41 | 42 | nodes.append(OutputNode([nodes[-1].out0], name='output_end0')) 43 | nodes.append(OutputNode([nodes[-2].out1], name='output_end1')) 44 | nodes.append(OutputNode([nodes[-3].out2], name='output_end2')) 45 | nf = ReversibleGraphNet(nodes, n_jac=3) 46 | return nf 47 | 48 | def eff_ext(self, x, use_layer=36): 49 | x = self.feature_extractor._swish(self.feature_extractor._bn0(self.feature_extractor._conv_stem(x))) 50 | # Blocks 51 | for idx, block in enumerate(self.feature_extractor._blocks): 52 | drop_connect_rate = self.feature_extractor._global_params.drop_connect_rate 53 | if drop_connect_rate: 54 | drop_connect_rate *= float(idx) / len(self.feature_extractor._blocks) # scale drop connect_rate 55 | x = block(x, drop_connect_rate=drop_connect_rate) 56 | if idx == use_layer: 57 | return x 58 | 59 | def forward_features(self, x): 60 | y = list() 61 | for s in range(self.args.model.n_scales): 62 | feat_s = F.interpolate(x, size=( 63 | self.args.dataset.image_size // (2 ** s), self.args.dataset.image_size // (2 ** s))) if s > 0 else x 64 | feat_s = self.eff_ext(feat_s) 65 | y.append(feat_s) 66 | return y 67 | 68 | def forward_logits(self, y): 69 | z, log_jac_det = self.density_estimator(y), self.density_estimator.jacobian(run_forward=False) 70 | z = torch.cat([z[i].reshape(z[i].shape[0], -1) for i in range(len(z))], dim=1) 71 | log_jac_det = sum(log_jac_det) 72 | return z, log_jac_det 73 | 74 | def forward(self, x): 75 | y = self.forward_features(x) # y(16, 512, 24/12/6, 24/12/6) 76 | z, log_jac_det = self.forward_logits(y) # z(16, 229824) 77 | return y, z, log_jac_det -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from torchvision.models import resnet18 4 | 5 | 6 | class ResNetModel(nn.Module): 7 | def __init__(self, pretrained=True, num_classes=2): 8 | super(ResNetModel, self).__init__() 9 | # self.resnet18 = torch.hub.load('pytorch/vision:v0.9.0', 'resnet18', pretrained=pretrained) 10 | self.backbone = resnet18(pretrained=pretrained) 11 | 12 | # create MPL head as seen in the code in: https://github.com/uoguelph-mlrg/Cutout/blob/master/util/cutout.py 13 | # TODO: 
check if this is really the right architecture 14 | last_layer = 512 15 | sequential_layers = [] 16 | head_layers = [512, 512, 128] 17 | for num_neurons in head_layers: 18 | sequential_layers.append(nn.Linear(last_layer, num_neurons)) 19 | sequential_layers.append(nn.BatchNorm1d(num_neurons)) 20 | sequential_layers.append(nn.ReLU(inplace=True)) 21 | last_layer = num_neurons 22 | 23 | # the last layer without activation 24 | head = nn.Sequential( 25 | *sequential_layers 26 | ) 27 | self.backbone.fc = nn.Identity() 28 | self.head = nn.Sequential( 29 | head, 30 | nn.Linear(last_layer, num_classes) 31 | ) 32 | # self.head = head 33 | # self.out = nn.Linear(last_layer, num_classes) 34 | 35 | def forward(self, x): 36 | embeds = self.backbone(x) 37 | # tmp = self.head(embeds) 38 | # logits = self.out(tmp) 39 | logits = self.head(embeds) 40 | return logits, embeds 41 | 42 | def forward_features(self, x): 43 | embeds = self.backbone(x) 44 | return embeds 45 | 46 | def freeze_resnet(self): 47 | # freeze the full resnet18 backbone 48 | for param in self.backbone.parameters(): 49 | param.requires_grad = False 50 | # unfreeze head (backbone.fc is nn.Identity, so the trainable layers live in self.head): 51 | for param in self.head.parameters(): 52 | param.requires_grad = True 53 | 54 | def unfreeze(self): 55 | # unfreeze all: 56 | for param in self.parameters(): 57 | param.requires_grad = True -------------------------------------------------------------------------------- /models/revdis_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import torch.nn.functional as F 5 | from utils.rd_resnet import resnet18, resnet34, resnet50, wide_resnet50_2 6 | from utils.de_resnet import de_resnet18, de_resnet34, de_wide_resnet50_2, de_resnet50 7 | import torch.hub 8 | 9 | 10 | 11 | class NetRevDis(nn.Module): 12 | def __init__(self, args): 13 | super(NetRevDis, self).__init__() 14 | torch.hub.set_dir('./checkpoints') 15 | self.args = args 16 | self.encoder, self.bn = wide_resnet50_2(pretrained=True) 17 | for param in self.encoder.parameters(): 18 | param.requires_grad = False 19 | self.decoder = de_wide_resnet50_2(pretrained=False) 20 | 21 | def forward(self, imgs): 22 | inputs = self.encoder(imgs) 23 | outputs = self.decoder(self.bn(inputs)) 24 | return inputs, outputs -------------------------------------------------------------------------------- /models/vit.py: -------------------------------------------------------------------------------- 1 | import math 2 | import logging 3 | from functools import partial 4 | from collections import OrderedDict 5 | from copy import deepcopy 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import numpy as np 10 | from timm.models.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_ 11 | from timm.models.helpers import build_model_with_cfg, named_apply, adapt_input_conv 12 | 13 | _logger = logging.getLogger(__name__) 14 | 15 | class Attention(nn.Module): 16 | def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): 17 | super().__init__() 18 | assert dim % num_heads == 0, 'dim should be divisible by num_heads' 19 | self.num_heads = num_heads 20 | head_dim = dim // num_heads 21 | self.scale = head_dim ** -0.5 22 | 23 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 24 | self.attn_drop = nn.Dropout(attn_drop) 25 | self.proj = nn.Linear(dim, dim) 26 | self.proj_drop = nn.Dropout(proj_drop) 27 | 28 | def forward(self, x): 29 | B, N, C = x.shape 30 | qkv =
self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 31 | q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) 32 | 33 | attn = (q @ k.transpose(-2, -1)) * self.scale 34 | attn = attn.softmax(dim=-1) 35 | attn = self.attn_drop(attn) 36 | 37 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 38 | x = self.proj(x) 39 | x = self.proj_drop(x) 40 | return x 41 | 42 | 43 | class Block(nn.Module): 44 | 45 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., 46 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 47 | super().__init__() 48 | self.norm1 = norm_layer(dim) 49 | self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) 50 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 51 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 52 | self.norm2 = norm_layer(dim) 53 | mlp_hidden_dim = int(dim * mlp_ratio) 54 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 55 | 56 | def forward(self, x): 57 | x = x + self.drop_path(self.attn(self.norm1(x))) 58 | x = x + self.drop_path(self.mlp(self.norm2(x))) 59 | return x 60 | 61 | class prediction_MLP(nn.Module): 62 | def __init__(self, in_dim=2048, hidden_dim=512, out_dim=2048): # bottleneck structure 63 | super().__init__() 64 | ''' page 3 baseline setting 65 | Prediction MLP. The prediction MLP (h) has BN applied 66 | to its hidden fc layers. Its output fc does not have BN 67 | (ablation in Sec. 4.4) or ReLU. This MLP has 2 layers. 68 | The dimension of h’s input and output (z and p) is d = 2048, 69 | and h’s hidden layer’s dimension is 512, making h a 70 | bottleneck structure (ablation in supplement). 71 | ''' 72 | self.layer1 = nn.Sequential( 73 | nn.Linear(in_dim, hidden_dim), 74 | nn.BatchNorm1d(hidden_dim), 75 | nn.ReLU(inplace=True) 76 | ) 77 | self.layer2 = nn.Linear(hidden_dim, out_dim) 78 | """ 79 | Adding BN to the output of the prediction MLP h does not work 80 | well (Table 3d). We find that this is not about collapsing. 81 | The training is unstable and the loss oscillates. 
82 | """ 83 | 84 | def forward(self, x): 85 | x = self.layer1(x) 86 | x = self.layer2(x) 87 | return x 88 | 89 | class ViT(nn.Module): 90 | """ Vision Transformer 91 | 92 | A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` 93 | - https://arxiv.org/abs/2010.11929 94 | 95 | Includes distillation token & head support for `DeiT: Data-efficient Image Transformers` 96 | - https://arxiv.org/abs/2012.12877 97 | """ 98 | 99 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=2, embed_dim=768, depth=12, 100 | num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, distilled=False, 101 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None, 102 | act_layer=None, weight_init=''): 103 | """ 104 | Args: 105 | img_size (int, tuple): input image size 106 | patch_size (int, tuple): patch size 107 | in_chans (int): number of input channels 108 | num_classes (int): number of classes for classification head 109 | embed_dim (int): embedding dimension 110 | depth (int): depth of transformer 111 | num_heads (int): number of attention heads 112 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim 113 | qkv_bias (bool): enable bias for qkv if True 114 | representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set 115 | distilled (bool): model includes a distillation token and head as in DeiT models 116 | drop_rate (float): dropout rate 117 | attn_drop_rate (float): attention dropout rate 118 | drop_path_rate (float): stochastic depth rate 119 | embed_layer (nn.Module): patch embedding layer 120 | norm_layer: (nn.Module): normalization layer 121 | weight_init: (str): weight init scheme 122 | """ 123 | super().__init__() 124 | self.num_classes = num_classes 125 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 126 | self.num_tokens = 2 if distilled else 1 127 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 128 | act_layer = act_layer or nn.GELU 129 | 130 | self.patch_embed = embed_layer( 131 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 132 | for param in self.patch_embed.parameters(): 133 | param.requires_grad = False 134 | num_patches = self.patch_embed.num_patches 135 | 136 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 137 | self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None 138 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) 139 | self.pos_drop = nn.Dropout(p=drop_rate) 140 | 141 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 142 | self.blocks = nn.Sequential(*[ 143 | Block( 144 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, 145 | attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer) 146 | for i in range(depth)]) 147 | self.norm = norm_layer(embed_dim) 148 | 149 | # Representation layer 150 | if representation_size and not distilled: 151 | self.num_features = representation_size 152 | self.pre_logits = nn.Sequential(OrderedDict([ 153 | ('fc', nn.Linear(embed_dim, representation_size)), 154 | ('act', nn.Tanh()) 155 | ])) 156 | else: 157 | self.pre_logits = nn.Identity() 158 | 159 | 160 | # self.z_head = nn.Sequential( 161 | # nn.Linear(self.num_features, 2048), 162 | # nn.BatchNorm1d(2048), 163 | # nn.ReLU(inplace=True), 164 
| # nn.Linear(2048, 2048), 165 | # nn.BatchNorm1d(2048), 166 | # nn.ReLU(inplace=True), 167 | # nn.Linear(2048, self.num_features), 168 | # nn.BatchNorm1d(self.num_features), 169 | # nn.ReLU(inplace=True), 170 | # ) 171 | 172 | # Classifier head(s) 173 | # last_layer = self.num_features 174 | # sequential_layers = [] 175 | # head_layers = [2048] * 2 + [768] 176 | # for num_neurons in head_layers: 177 | # sequential_layers.append(nn.Linear(last_layer, num_neurons)) 178 | # sequential_layers.append(nn.BatchNorm1d(num_neurons)) 179 | # sequential_layers.append(nn.ReLU(inplace=True)) 180 | # last_layer = num_neurons 181 | # 182 | # # the last layer without activation 183 | # self.head = nn.Sequential( 184 | # *sequential_layers, 185 | # nn.Linear(last_layer, num_classes) 186 | # ) if num_classes > 0 else nn.Identity() 187 | self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() 188 | # self.head = prediction_MLP(in_dim=self.num_features, out_dim=num_classes) if num_classes > 0 else nn.Identity() 189 | 190 | self.head_dist = None 191 | if distilled: 192 | self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() 193 | 194 | self.init_weights(weight_init) 195 | 196 | def init_weights(self, mode=''): 197 | assert mode in ('jax', 'jax_nlhb', 'nlhb', '') 198 | head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. 199 | trunc_normal_(self.pos_embed, std=.02) 200 | if self.dist_token is not None: 201 | trunc_normal_(self.dist_token, std=.02) 202 | if mode.startswith('jax'): 203 | # leave cls token as zeros to match jax impl 204 | named_apply(partial(_init_vit_weights, head_bias=head_bias, jax_impl=True), self) 205 | else: 206 | trunc_normal_(self.cls_token, std=.02) 207 | self.apply(_init_vit_weights) 208 | 209 | def _init_weights(self, m): 210 | # this fn left here for compat with downstream users 211 | _init_vit_weights(m) 212 | 213 | @torch.jit.ignore() 214 | def load_pretrained(self, checkpoint_path, prefix=''): 215 | _load_weights(self, checkpoint_path, prefix) 216 | 217 | @torch.jit.ignore 218 | def no_weight_decay(self): 219 | return {'pos_embed', 'cls_token', 'dist_token'} 220 | 221 | def get_classifier(self): 222 | if self.dist_token is None: 223 | return self.head 224 | else: 225 | return self.head, self.head_dist 226 | 227 | def reset_classifier(self, num_classes, global_pool=''): 228 | self.num_classes = num_classes 229 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 230 | if self.num_tokens == 2: 231 | self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() 232 | 233 | def forward_features(self, x): 234 | x = self.patch_embed(x) 235 | cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks 236 | if self.dist_token is None: 237 | x = torch.cat((cls_token, x), dim=1) 238 | else: 239 | x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1) 240 | x = self.pos_drop(x + self.pos_embed) 241 | x = self.blocks(x) 242 | x = self.norm(x) 243 | if self.dist_token is None: 244 | return self.pre_logits(x[:, 0]) 245 | else: 246 | return x[:, 0], x[:, 1] 247 | 248 | def forward(self, x): 249 | embed = self.forward_features(x) 250 | if self.head_dist is not None: 251 | x, x_dist = self.head(embed[0]), self.head_dist(embed[1]) # x must be a tuple 252 | if self.training and not torch.jit.is_scripting(): 253 | # during inference, return the average of both 
classifier predictions 254 | return x, x_dist, embed 255 | else: 256 | return (x + x_dist) / 2, embed 257 | else: 258 | # z = self.z_head(embed) 259 | x = self.head(embed) 260 | return x, embed 261 | 262 | 263 | def _init_vit_weights(module: nn.Module, name: str = '', head_bias: float = 0., jax_impl: bool = False): 264 | """ ViT weight initialization 265 | * When called without n, head_bias, jax_impl args it will behave exactly the same 266 | as my original init for compatibility with prev hparam / downstream use cases (ie DeiT). 267 | * When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl 268 | """ 269 | if isinstance(module, nn.Linear): 270 | if name.startswith('head'): 271 | nn.init.zeros_(module.weight) 272 | nn.init.constant_(module.bias, head_bias) 273 | elif name.startswith('pre_logits'): 274 | lecun_normal_(module.weight) 275 | nn.init.zeros_(module.bias) 276 | else: 277 | if jax_impl: 278 | nn.init.xavier_uniform_(module.weight) 279 | if module.bias is not None: 280 | if 'mlp' in name: 281 | nn.init.normal_(module.bias, std=1e-6) 282 | else: 283 | nn.init.zeros_(module.bias) 284 | else: 285 | trunc_normal_(module.weight, std=.02) 286 | if module.bias is not None: 287 | nn.init.zeros_(module.bias) 288 | elif jax_impl and isinstance(module, nn.Conv2d): 289 | # NOTE conv was left to pytorch default in my original init 290 | lecun_normal_(module.weight) 291 | if module.bias is not None: 292 | nn.init.zeros_(module.bias) 293 | elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)): 294 | nn.init.zeros_(module.bias) 295 | nn.init.ones_(module.weight) 296 | 297 | 298 | @torch.no_grad() 299 | def _load_weights(model: ViT, checkpoint_path: str, prefix: str = ''): 300 | """ Load weights from .npz checkpoints for official Google Brain Flax implementation 301 | """ 302 | import numpy as np 303 | 304 | def _n2p(w, t=True): 305 | if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: 306 | w = w.flatten() 307 | if t: 308 | if w.ndim == 4: 309 | w = w.transpose([3, 2, 0, 1]) 310 | elif w.ndim == 3: 311 | w = w.transpose([2, 0, 1]) 312 | elif w.ndim == 2: 313 | w = w.transpose([1, 0]) 314 | return torch.from_numpy(w) 315 | 316 | w = np.load(checkpoint_path) 317 | if not prefix and 'opt/target/embedding/kernel' in w: 318 | prefix = 'opt/target/' 319 | 320 | if hasattr(model.patch_embed, 'backbone'): 321 | # hybrid 322 | backbone = model.patch_embed.backbone 323 | stem_only = not hasattr(backbone, 'stem') 324 | stem = backbone if stem_only else backbone.stem 325 | stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel']))) 326 | stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale'])) 327 | stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias'])) 328 | if not stem_only: 329 | for i, stage in enumerate(backbone.stages): 330 | for j, block in enumerate(stage.blocks): 331 | bp = f'{prefix}block{i + 1}/unit{j + 1}/' 332 | for r in range(3): 333 | getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel'])) 334 | getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale'])) 335 | getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias'])) 336 | if block.downsample is not None: 337 | block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel'])) 338 | block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale'])) 339 | block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias'])) 340 | embed_conv_w = _n2p(w[f'{prefix}embedding/kernel']) 341 | 
else: 342 | embed_conv_w = adapt_input_conv( 343 | model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel'])) 344 | model.patch_embed.proj.weight.copy_(embed_conv_w) 345 | model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) 346 | model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) 347 | pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False) 348 | if pos_embed_w.shape != model.pos_embed.shape: 349 | pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights 350 | pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) 351 | model.pos_embed.copy_(pos_embed_w) 352 | model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) 353 | model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) 354 | if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]: 355 | model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) 356 | model.head.bias.copy_(_n2p(w[f'{prefix}head/bias'])) 357 | if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w: 358 | model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel'])) 359 | model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias'])) 360 | for i, block in enumerate(model.blocks.children()): 361 | block_prefix = f'{prefix}Transformer/encoderblock_{i}/' 362 | mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/' 363 | block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) 364 | block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) 365 | block.attn.qkv.weight.copy_(torch.cat([ 366 | _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) 367 | block.attn.qkv.bias.copy_(torch.cat([ 368 | _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) 369 | block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) 370 | block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) 371 | for r in range(2): 372 | getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel'])) 373 | getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias'])) 374 | block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale'])) 375 | block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias'])) 376 | 377 | 378 | def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()): 379 | # Rescale the grid of position embeddings when loading from state_dict. 
Adapted from 380 | # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 381 | _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape) 382 | ntok_new = posemb_new.shape[1] 383 | if num_tokens: 384 | posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:] 385 | ntok_new -= num_tokens 386 | else: 387 | posemb_tok, posemb_grid = posemb[:, :0], posemb[0] 388 | gs_old = int(math.sqrt(len(posemb_grid))) 389 | if not len(gs_new): # backwards compatibility 390 | gs_new = [int(math.sqrt(ntok_new))] * 2 391 | assert len(gs_new) >= 2 392 | _logger.info('Position embedding grid-size from %s to %s', [gs_old, gs_old], gs_new) 393 | posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) 394 | posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bicubic', align_corners=False) 395 | posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1) 396 | posemb = torch.cat([posemb_tok, posemb_grid], dim=1) 397 | return posemb 398 | 399 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | efficientnet_pytorch==0.7.1 2 | einops==0.3.0 3 | imgaug==0.4.0 4 | joblib==0.14.1 5 | matplotlib==3.1.3 6 | numpy==1.18.1 7 | opencv_python==4.2.0.32 8 | pandas==1.1.4 9 | Pillow==9.1.0 10 | PyYAML==6.0 11 | scikit_learn==1.0.2 12 | scipy==1.4.1 13 | timm==0.5.4 14 | torch==1.4.0 15 | torchvision==0.5.0 16 | tqdm==4.61.1 17 | 18 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vijaylee/Continual_Anomaly_Detection/7be4ea5b4087f292ddc0e332e9e9e0fd6eca0c67/utils/__init__.py -------------------------------------------------------------------------------- /utils/buffer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara. 2 | # All rights reserved. 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | import numpy as np 8 | from typing import Tuple 9 | from torchvision import transforms 10 | 11 | 12 | def reservoir(num_seen_examples: int, buffer_size: int) -> int: 13 | """ 14 | Reservoir sampling algorithm. 15 | :param num_seen_examples: the number of seen examples 16 | :param buffer_size: the maximum buffer size 17 | :return: the target index if the current image is sampled, else -1 18 | """ 19 | if num_seen_examples < buffer_size: 20 | return num_seen_examples 21 | 22 | rand = np.random.randint(0, num_seen_examples + 1) 23 | if rand < buffer_size: 24 | return rand 25 | else: 26 | return -1 27 | 28 | 29 | def ring(num_seen_examples: int, buffer_portion_size: int, task: int) -> int: 30 | return num_seen_examples % buffer_portion_size + task * buffer_portion_size 31 | 32 | 33 | class Buffer: 34 | """ 35 | The memory buffer of rehearsal method. 
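`reservoir` above implements standard reservoir sampling: once the buffer is full, the i-th stream element is kept with probability buffer_size / (i + 1), so every element seen so far is equally likely to reside in the buffer. A small standalone simulation (illustrative sizes) that checks this uniformity empirically:

```
import numpy as np
from utils.buffer import reservoir

buffer_size, stream_len, trials = 10, 100, 20000
hits = np.zeros(stream_len)
for _ in range(trials):
    buf = [None] * buffer_size
    for t in range(stream_len):
        idx = reservoir(t, buffer_size)   # t examples were seen before this one
        if idx >= 0:
            buf[idx] = t                  # element t overwrites slot idx
    for kept in buf:
        hits[kept] += 1
print(hits / trials)  # every entry hovers around buffer_size / stream_len = 0.1
```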
36 | """ 37 | def __init__(self, buffer_size, device, n_tasks=None, mode='reservoir'): 38 | assert mode in ['ring', 'reservoir'] 39 | self.buffer_size = buffer_size 40 | self.device = device 41 | self.num_seen_examples = 0 42 | self.functional_index = eval(mode) 43 | if mode == 'ring': 44 | assert n_tasks is not None 45 | self.task_number = n_tasks 46 | self.buffer_portion_size = buffer_size // n_tasks 47 | self.attributes = ['examples', 'labels', 'logits', 'task_labels'] 48 | 49 | def init_tensors(self, examples: torch.Tensor, labels: torch.Tensor, 50 | logits: torch.Tensor, task_labels: torch.Tensor) -> None: 51 | """ 52 | Initializes just the required tensors. 53 | :param examples: tensor containing the images 54 | :param labels: tensor containing the labels 55 | :param logits: tensor containing the outputs of the network 56 | :param task_labels: tensor containing the task labels 57 | """ 58 | for attr_str in self.attributes: 59 | attr = eval(attr_str) 60 | if attr is not None and not hasattr(self, attr_str): 61 | typ = torch.int64 if attr_str.endswith('els') else torch.float32 62 | setattr(self, attr_str, torch.zeros((self.buffer_size, 63 | *attr.shape[1:]), dtype=typ, device=self.device)) 64 | 65 | def add_data(self, examples, labels=None, logits=None, task_labels=None): 66 | """ 67 | Adds the data to the memory buffer according to the reservoir strategy. 68 | :param examples: tensor containing the images 69 | :param labels: tensor containing the labels 70 | :param logits: tensor containing the outputs of the network 71 | :param task_labels: tensor containing the task labels 72 | :return: 73 | """ 74 | if not hasattr(self, 'examples'): 75 | self.init_tensors(examples, labels, logits, task_labels) 76 | 77 | for i in range(examples.shape[0]): 78 | index = reservoir(self.num_seen_examples, self.buffer_size) 79 | self.num_seen_examples += 1 80 | if index >= 0: 81 | self.examples[index] = examples[i].to(self.device) 82 | if labels is not None: 83 | self.labels[index] = labels[i].to(self.device) 84 | if logits is not None: 85 | self.logits[index] = logits[i].to(self.device) 86 | if task_labels is not None: 87 | self.task_labels[index] = task_labels[i].to(self.device) 88 | 89 | def get_data(self, size: int, transform: transforms=None) -> Tuple: 90 | """ 91 | Random samples a batch of size items. 92 | :param size: the number of requested items 93 | :param transform: the transformation to be applied (data augmentation) 94 | :return: 95 | """ 96 | if size > min(self.num_seen_examples, self.examples.shape[0]): 97 | size = min(self.num_seen_examples, self.examples.shape[0]) 98 | 99 | choice = np.random.choice(min(self.num_seen_examples, self.examples.shape[0]), 100 | size=size, replace=False) 101 | if transform is None: transform = lambda x: x 102 | ret_tuple = (torch.stack([transform(ee.cpu()) 103 | for ee in self.examples[choice]]).to(self.device),) 104 | for attr_str in self.attributes[1:]: 105 | if hasattr(self, attr_str): 106 | attr = getattr(self, attr_str) 107 | ret_tuple += (attr[choice],) 108 | 109 | return ret_tuple 110 | 111 | def is_empty(self) -> bool: 112 | """ 113 | Returns true if the buffer is empty, false otherwise. 114 | """ 115 | if self.num_seen_examples == 0: 116 | return True 117 | else: 118 | return False 119 | 120 | def get_all_data(self, transform: transforms=None) -> Tuple: 121 | """ 122 | Return all the items in the memory buffer. 
123 | :param transform: the transformation to be applied (data augmentation) 124 | :return: a tuple with all the items in the memory buffer 125 | """ 126 | if transform is None: transform = lambda x: x 127 | ret_tuple = (torch.stack([transform(ee.cpu()) 128 | for ee in self.examples]).to(self.device),) 129 | for attr_str in self.attributes[1:]: 130 | if hasattr(self, attr_str): 131 | attr = getattr(self, attr_str) 132 | ret_tuple += (attr,) 133 | return ret_tuple 134 | 135 | def empty(self) -> None: 136 | """ 137 | Set all the tensors to None. 138 | """ 139 | for attr_str in self.attributes: 140 | if hasattr(self, attr_str): 141 | delattr(self, attr_str) 142 | self.num_seen_examples = 0 143 | 144 | 145 | -------------------------------------------------------------------------------- /utils/de_resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | import torch.nn as nn 4 | try: 5 | from torch.hub import load_state_dict_from_url 6 | except ImportError: 7 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 8 | from typing import Type, Any, Callable, Union, List, Optional 9 | 10 | 11 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 12 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 13 | 'wide_resnet50_2', 'wide_resnet101_2'] 14 | 15 | 16 | model_urls = { 17 | 'resnet18': 'https://download.pytorch.org/models/resnet18-f37072fd.pth', 18 | 'resnet34': 'https://download.pytorch.org/models/resnet34-b627a593.pth', 19 | 'resnet50': 'https://download.pytorch.org/models/resnet50-0676ba61.pth', 20 | 'resnet101': 'https://download.pytorch.org/models/resnet101-63fe2227.pth', 21 | 'resnet152': 'https://download.pytorch.org/models/resnet152-394f9c45.pth', 22 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 23 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 24 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 25 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 26 | } 27 | 28 | 29 | def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: 30 | """3x3 convolution with padding""" 31 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 32 | padding=dilation, groups=groups, bias=False, dilation=dilation) 33 | 34 | 35 | def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: 36 | """1x1 convolution""" 37 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 38 | 39 | def deconv2x2(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: 40 | """1x1 convolution""" 41 | return nn.ConvTranspose2d(in_planes, out_planes, kernel_size=2, stride=stride, 42 | groups=groups, bias=False, dilation=dilation) 43 | 44 | 45 | class BasicBlock(nn.Module): 46 | expansion: int = 1 47 | 48 | def __init__( 49 | self, 50 | inplanes: int, 51 | planes: int, 52 | stride: int = 1, 53 | upsample: Optional[nn.Module] = None, 54 | groups: int = 1, 55 | base_width: int = 64, 56 | dilation: int = 1, 57 | norm_layer: Optional[Callable[..., nn.Module]] = None 58 | ) -> None: 59 | super(BasicBlock, self).__init__() 60 | if norm_layer is None: 61 | norm_layer = nn.BatchNorm2d 62 | if groups != 1 or base_width != 64: 63 | raise ValueError('BasicBlock only supports groups=1 and 
base_width=64') 64 | if dilation > 1: 65 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 66 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 67 | if stride == 2: 68 | self.conv1 = deconv2x2(inplanes, planes, stride) 69 | else: 70 | self.conv1 = conv3x3(inplanes, planes, stride) 71 | self.bn1 = norm_layer(planes) 72 | self.relu = nn.ReLU(inplace=True) 73 | self.conv2 = conv3x3(planes, planes) 74 | self.bn2 = norm_layer(planes) 75 | self.upsample = upsample 76 | self.stride = stride 77 | 78 | def forward(self, x: Tensor) -> Tensor: 79 | identity = x 80 | 81 | out = self.conv1(x) 82 | out = self.bn1(out) 83 | out = self.relu(out) 84 | 85 | out = self.conv2(out) 86 | out = self.bn2(out) 87 | 88 | if self.upsample is not None: 89 | identity = self.upsample(x) 90 | 91 | out += identity 92 | out = self.relu(out) 93 | 94 | return out 95 | 96 | 97 | class Bottleneck(nn.Module): 98 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 99 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 100 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 101 | # This variant is also known as ResNet V1.5 and improves accuracy according to 102 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 103 | 104 | expansion: int = 4 105 | 106 | def __init__( 107 | self, 108 | inplanes: int, 109 | planes: int, 110 | stride: int = 1, 111 | upsample: Optional[nn.Module] = None, 112 | groups: int = 1, 113 | base_width: int = 64, 114 | dilation: int = 1, 115 | norm_layer: Optional[Callable[..., nn.Module]] = None 116 | ) -> None: 117 | super(Bottleneck, self).__init__() 118 | if norm_layer is None: 119 | norm_layer = nn.BatchNorm2d 120 | width = int(planes * (base_width / 64.)) * groups 121 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 122 | self.conv1 = conv1x1(inplanes, width) 123 | self.bn1 = norm_layer(width) 124 | if stride == 2: 125 | self.conv2 = deconv2x2(width, width, stride, groups, dilation) 126 | else: 127 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 128 | self.bn2 = norm_layer(width) 129 | self.conv3 = conv1x1(width, planes * self.expansion) 130 | self.bn3 = norm_layer(planes * self.expansion) 131 | self.relu = nn.ReLU(inplace=True) 132 | self.upsample = upsample 133 | self.stride = stride 134 | 135 | def forward(self, x: Tensor) -> Tensor: 136 | identity = x 137 | 138 | out = self.conv1(x) 139 | out = self.bn1(out) 140 | out = self.relu(out) 141 | 142 | out = self.conv2(out) 143 | out = self.bn2(out) 144 | out = self.relu(out) 145 | 146 | out = self.conv3(out) 147 | out = self.bn3(out) 148 | 149 | if self.upsample is not None: 150 | identity = self.upsample(x) 151 | 152 | out += identity 153 | out = self.relu(out) 154 | 155 | return out 156 | 157 | 158 | class ResNet(nn.Module): 159 | 160 | def __init__( 161 | self, 162 | block: Type[Union[BasicBlock, Bottleneck]], 163 | layers: List[int], 164 | num_classes: int = 1000, 165 | zero_init_residual: bool = False, 166 | groups: int = 1, 167 | width_per_group: int = 64, 168 | replace_stride_with_dilation: Optional[List[bool]] = None, 169 | norm_layer: Optional[Callable[..., nn.Module]] = None 170 | ) -> None: 171 | super(ResNet, self).__init__() 172 | if norm_layer is None: 173 | norm_layer = nn.BatchNorm2d 174 | self._norm_layer = norm_layer 175 | 176 | self.inplanes = 512 * block.expansion 177 
| self.dilation = 1 178 | if replace_stride_with_dilation is None: 179 | # each element in the tuple indicates if we should replace 180 | # the 2x2 stride with a dilated convolution instead 181 | replace_stride_with_dilation = [False, False, False] 182 | if len(replace_stride_with_dilation) != 3: 183 | raise ValueError("replace_stride_with_dilation should be None " 184 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 185 | self.groups = groups 186 | self.base_width = width_per_group 187 | #self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 188 | # bias=False) 189 | #self.bn1 = norm_layer(self.inplanes) 190 | #self.relu = nn.ReLU(inplace=True) 191 | #self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 192 | self.layer1 = self._make_layer(block, 256, layers[0], stride=2) 193 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 194 | dilate=replace_stride_with_dilation[0]) 195 | self.layer3 = self._make_layer(block, 64, layers[2], stride=2, 196 | dilate=replace_stride_with_dilation[1]) 197 | #self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 198 | # dilate=replace_stride_with_dilation[2]) 199 | #self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 200 | #self.fc = nn.Linear(512 * block.expansion, num_classes) 201 | 202 | for m in self.modules(): 203 | if isinstance(m, nn.Conv2d): 204 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 205 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 206 | nn.init.constant_(m.weight, 1) 207 | nn.init.constant_(m.bias, 0) 208 | 209 | # Zero-initialize the last BN in each residual branch, 210 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 211 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 212 | if zero_init_residual: 213 | for m in self.modules(): 214 | if isinstance(m, Bottleneck): 215 | nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] 216 | elif isinstance(m, BasicBlock): 217 | nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] 218 | 219 | def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int, 220 | stride: int = 1, dilate: bool = False) -> nn.Sequential: 221 | norm_layer = self._norm_layer 222 | upsample = None 223 | previous_dilation = self.dilation 224 | if dilate: 225 | self.dilation *= stride 226 | stride = 1 227 | if stride != 1 or self.inplanes != planes * block.expansion: 228 | upsample = nn.Sequential( 229 | deconv2x2(self.inplanes, planes * block.expansion, stride), 230 | norm_layer(planes * block.expansion), 231 | ) 232 | 233 | layers = [] 234 | layers.append(block(self.inplanes, planes, stride, upsample, self.groups, 235 | self.base_width, previous_dilation, norm_layer)) 236 | self.inplanes = planes * block.expansion 237 | for _ in range(1, blocks): 238 | layers.append(block(self.inplanes, planes, groups=self.groups, 239 | base_width=self.base_width, dilation=self.dilation, 240 | norm_layer=norm_layer)) 241 | 242 | return nn.Sequential(*layers) 243 | 244 | def _forward_impl(self, x: Tensor) -> Tensor: 245 | # See note [TorchScript super()] 246 | #x = self.conv1(x) 247 | #x = self.bn1(x) 248 | #x = self.relu(x) 249 | #x = self.maxpool(x) 250 | 251 | feature_a = self.layer1(x) # 512*8*8->256*16*16 252 | feature_b = self.layer2(feature_a) # 256*16*16->128*32*32 253 | feature_c = self.layer3(feature_b) # 128*32*32->64*64*64 254 | #feature_d = self.layer4(feature_c) # 64*64*64->128*32*32 255 | 256 | 
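The decoder runs the encoder in reverse: wherever the encoder halved the resolution with a stride-2 convolution, `deconv2x2` (a `ConvTranspose2d` with kernel 2 and stride 2) doubles it, which is what the shape comments in `_forward_impl` above trace out. A quick standalone check of the building block:

```
import torch
import torch.nn as nn

# deconv2x2 is ConvTranspose2d(kernel_size=2, stride=2): the output is exactly 2x the input size.
up = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2, bias=False)
x = torch.randn(1, 512, 8, 8)
print(up(x).shape)  # torch.Size([1, 256, 16, 16])
```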
#x = self.avgpool(feature_d) 257 | #x = torch.flatten(x, 1) 258 | #x = self.fc(x) 259 | 260 | return [feature_c, feature_b, feature_a] 261 | 262 | def forward(self, x: Tensor) -> Tensor: 263 | return self._forward_impl(x) 264 | 265 | 266 | def _resnet( 267 | arch: str, 268 | block: Type[Union[BasicBlock, Bottleneck]], 269 | layers: List[int], 270 | pretrained: bool, 271 | progress: bool, 272 | **kwargs: Any 273 | ) -> ResNet: 274 | model = ResNet(block, layers, **kwargs) 275 | if pretrained: 276 | state_dict = load_state_dict_from_url(model_urls[arch], 277 | progress=progress) 278 | #for k,v in list(state_dict.items()): 279 | # if 'layer4' in k or 'fc' in k: 280 | # state_dict.pop(k) 281 | model.load_state_dict(state_dict) 282 | return model 283 | 284 | 285 | def de_resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 286 | r"""ResNet-18 model from 287 | `"Deep Residual Learning for Image Recognition" `_. 288 | Args: 289 | pretrained (bool): If True, returns a model pre-trained on ImageNet 290 | progress (bool): If True, displays a progress bar of the download to stderr 291 | """ 292 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 293 | **kwargs) 294 | 295 | 296 | def de_resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 297 | r"""ResNet-34 model from 298 | `"Deep Residual Learning for Image Recognition" `_. 299 | Args: 300 | pretrained (bool): If True, returns a model pre-trained on ImageNet 301 | progress (bool): If True, displays a progress bar of the download to stderr 302 | """ 303 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 304 | **kwargs) 305 | 306 | 307 | def de_resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 308 | r"""ResNet-50 model from 309 | `"Deep Residual Learning for Image Recognition" `_. 310 | Args: 311 | pretrained (bool): If True, returns a model pre-trained on ImageNet 312 | progress (bool): If True, displays a progress bar of the download to stderr 313 | """ 314 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 315 | **kwargs) 316 | 317 | 318 | def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 319 | r"""ResNet-101 model from 320 | `"Deep Residual Learning for Image Recognition" `_. 321 | Args: 322 | pretrained (bool): If True, returns a model pre-trained on ImageNet 323 | progress (bool): If True, displays a progress bar of the download to stderr 324 | """ 325 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 326 | **kwargs) 327 | 328 | 329 | def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 330 | r"""ResNet-152 model from 331 | `"Deep Residual Learning for Image Recognition" `_. 332 | Args: 333 | pretrained (bool): If True, returns a model pre-trained on ImageNet 334 | progress (bool): If True, displays a progress bar of the download to stderr 335 | """ 336 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 337 | **kwargs) 338 | 339 | 340 | def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 341 | r"""ResNeXt-50 32x4d model from 342 | `"Aggregated Residual Transformation for Deep Neural Networks" `_. 
343 | Args: 344 | pretrained (bool): If True, returns a model pre-trained on ImageNet 345 | progress (bool): If True, displays a progress bar of the download to stderr 346 | """ 347 | kwargs['groups'] = 32 348 | kwargs['width_per_group'] = 4 349 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 350 | pretrained, progress, **kwargs) 351 | 352 | 353 | def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 354 | r"""ResNeXt-101 32x8d model from 355 | `"Aggregated Residual Transformation for Deep Neural Networks" `_. 356 | Args: 357 | pretrained (bool): If True, returns a model pre-trained on ImageNet 358 | progress (bool): If True, displays a progress bar of the download to stderr 359 | """ 360 | kwargs['groups'] = 32 361 | kwargs['width_per_group'] = 8 362 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 363 | pretrained, progress, **kwargs) 364 | 365 | 366 | def de_wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 367 | r"""Wide ResNet-50-2 model from 368 | `"Wide Residual Networks" `_. 369 | The model is the same as ResNet except for the bottleneck number of channels 370 | which is twice larger in every block. The number of channels in outer 1x1 371 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 372 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 373 | Args: 374 | pretrained (bool): If True, returns a model pre-trained on ImageNet 375 | progress (bool): If True, displays a progress bar of the download to stderr 376 | """ 377 | kwargs['width_per_group'] = 64 * 2 378 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], 379 | pretrained, progress, **kwargs) 380 | 381 | 382 | def de_wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 383 | r"""Wide ResNet-101-2 model from 384 | `"Wide Residual Networks" `_. 385 | The model is the same as ResNet except for the bottleneck number of channels 386 | which is twice larger in every block. The number of channels in outer 1x1 387 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 388 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 389 | Args: 390 | pretrained (bool): If True, returns a model pre-trained on ImageNet 391 | progress (bool): If True, displays a progress bar of the download to stderr 392 | """ 393 | kwargs['width_per_group'] = 64 * 2 394 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], 395 | pretrained, progress, **kwargs) -------------------------------------------------------------------------------- /utils/density.py: -------------------------------------------------------------------------------- 1 | from sklearn.covariance import LedoitWolf, empirical_covariance 2 | from sklearn.neighbors import KernelDensity 3 | from sklearn.mixture import GaussianMixture 4 | import torch 5 | 6 | 7 | class Density(object): 8 | def fit(self, embeddings): 9 | raise NotImplementedError 10 | 11 | def predict(self, embeddings): 12 | raise NotImplementedError 13 | 14 | 15 | class GaussianDensityTorch(object): 16 | """Gaussian Density estimation similar to the implementation used by Ripple et al. 17 | The code of Ripple et al. can be found here: https://github.com/ORippler/gaussian-ad-mvtec. 
18 | """ 19 | 20 | def fit(self, embeddings): 21 | self.mean = torch.mean(embeddings, dim=0) 22 | # self.cov = torch.Tensor(empirical_covariance(embeddings - self.mean), device="cpu") 23 | self.cov = torch.Tensor(LedoitWolf().fit(embeddings.cpu()).covariance_, device="cpu") 24 | self.inv_cov = torch.Tensor(LedoitWolf().fit(embeddings.cpu()).precision_, device="cpu") 25 | return self.mean, self.cov 26 | 27 | def predict(self, embeddings): 28 | distances = self.mahalanobis_distance(embeddings, self.mean, self.inv_cov) 29 | return distances 30 | 31 | @staticmethod 32 | def mahalanobis_distance( 33 | values: torch.Tensor, mean: torch.Tensor, inv_covariance: torch.Tensor 34 | ) -> torch.Tensor: 35 | """Compute the batched mahalanobis distance. 36 | values is a batch of feature vectors. 37 | mean is either the mean of the distribution to compare, or a second 38 | batch of feature vectors. 39 | inv_covariance is the inverse covariance of the target distribution. 40 | 41 | from https://github.com/ORippler/gaussian-ad-mvtec/blob/4e85fb5224eee13e8643b684c8ef15ab7d5d016e/src/gaussian/model.py#L308 42 | """ 43 | assert values.dim() == 2 44 | assert 1 <= mean.dim() <= 2 45 | assert len(inv_covariance.shape) == 2 46 | assert values.shape[1] == mean.shape[-1] 47 | assert mean.shape[-1] == inv_covariance.shape[0] 48 | assert inv_covariance.shape[0] == inv_covariance.shape[1] 49 | 50 | if mean.dim() == 1: # Distribution mean. 51 | mean = mean.unsqueeze(0) 52 | x_mu = values - mean # batch x features 53 | # Same as dist = x_mu.t() * inv_covariance * x_mu batch wise 54 | dist = torch.einsum("im,mn,in->i", x_mu, inv_covariance, x_mu) 55 | return dist.sqrt() 56 | 57 | 58 | class GaussianDensitySklearn(): 59 | """Li et al. use sklearn for density estimation. 60 | This implementation uses sklearn KernelDensity module for fitting and predicting. 
61 | """ 62 | 63 | def fit(self, embeddings): 64 | # estimate KDE parameters 65 | # use grid search cross-validation to optimize the bandwidth 66 | self.kde = KernelDensity(kernel='gaussian', bandwidth=1).fit(embeddings) 67 | 68 | def predict(self, embeddings): 69 | scores = self.kde.score_samples(embeddings) 70 | 71 | # invert scores, so they fit to the class labels for the auc calculation 72 | scores = -scores 73 | 74 | return scores 75 | -------------------------------------------------------------------------------- /utils/freia_funcs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.autograd import Variable 4 | import numpy as np 5 | from math import exp 6 | 7 | VERBOSE = False 8 | 9 | class dummy_data: 10 | def __init__(self, *dims): 11 | self.dims = dims 12 | 13 | @property 14 | def shape(self): 15 | return self.dims 16 | 17 | class ParallelPermute(nn.Module): 18 | '''permutes input vector in a random but fixed way''' 19 | 20 | def __init__(self, dims_in, seed): 21 | super(ParallelPermute, self).__init__() 22 | # print('dims in', dims_in) 23 | # exit() 24 | self.n_inputs = len(dims_in) 25 | self.in_channels = [dims_in[i][0] for i in range(self.n_inputs)] 26 | 27 | np.random.seed(seed) 28 | perm, perm_inv = self.get_random_perm(0) 29 | self.perm = [perm] 30 | self.perm_inv = [perm_inv] 31 | 32 | for i in range(1, self.n_inputs): 33 | perm, perm_inv = self.get_random_perm(i) 34 | self.perm.append(perm) 35 | self.perm_inv.append(perm_inv) 36 | 37 | def get_random_perm(self, i): 38 | perm = np.random.permutation(self.in_channels[i]) 39 | perm_inv = np.zeros_like(perm) 40 | for i, p in enumerate(perm): 41 | perm_inv[p] = i 42 | 43 | perm = torch.LongTensor(perm) 44 | perm_inv = torch.LongTensor(perm_inv) 45 | return perm, perm_inv 46 | 47 | def forward(self, x, rev=False): 48 | if not rev: 49 | return [x[i][:, self.perm[i]] for i in range(self.n_inputs)] 50 | else: 51 | return [x[i][:, self.perm_inv[i]] for i in range(self.n_inputs)] 52 | 53 | def jacobian(self, x, rev=False): 54 | # TODO: use batch size, set as nn.Parameter so cuda() works 55 | return [0.] 
* self.n_inputs 56 | 57 | def output_dims(self, input_dims): 58 | return input_dims 59 | 60 | class CrossConvolutions(nn.Module): 61 | '''ResNet transformation, not itself reversible, just used below''' 62 | 63 | def __init__(self, in_channels, channels, channels_hidden=512, 64 | stride=None, kernel_size=3, last_kernel_size=1, leaky_slope=0.1, 65 | batch_norm=False, block_no=0): 66 | super(CrossConvolutions, self).__init__() 67 | if stride: 68 | warnings.warn("Stride doesn't do anything, the argument should be " 69 | "removed", DeprecationWarning) 70 | if not channels_hidden: 71 | channels_hidden = channels 72 | 73 | pad = kernel_size // 2 74 | self.leaky_slope = leaky_slope 75 | pad_mode = 'zeros' 76 | 77 | self.gamma0 = nn.Parameter(torch.zeros(1)) 78 | self.gamma1 = nn.Parameter(torch.zeros(1)) 79 | self.gamma2 = nn.Parameter(torch.zeros(1)) 80 | 81 | self.conv_scale0_0 = nn.Conv2d(in_channels, channels_hidden, 82 | kernel_size=kernel_size, padding=pad, 83 | bias=not batch_norm, padding_mode=pad_mode) 84 | 85 | self.conv_scale1_0 = nn.Conv2d(in_channels, channels_hidden, 86 | kernel_size=kernel_size, padding=pad, 87 | bias=not batch_norm, padding_mode=pad_mode) 88 | self.conv_scale2_0 = nn.Conv2d(in_channels, channels_hidden, 89 | kernel_size=kernel_size, padding=pad, 90 | bias=not batch_norm, padding_mode=pad_mode) 91 | self.conv_scale0_1 = nn.Conv2d(channels_hidden * 1, channels, # 92 | kernel_size=kernel_size, padding=pad, 93 | bias=not batch_norm, padding_mode=pad_mode, dilation=1) 94 | self.conv_scale1_1 = nn.Conv2d(channels_hidden * 1, channels, # 95 | kernel_size=kernel_size, padding=pad * 1, 96 | bias=not batch_norm, padding_mode=pad_mode, dilation=1) 97 | self.conv_scale2_1 = nn.Conv2d(channels_hidden * 1, channels, # 98 | kernel_size=kernel_size, padding=pad, 99 | bias=not batch_norm, padding_mode=pad_mode) 100 | 101 | self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) 102 | 103 | self.up_conv10 = nn.Conv2d(channels_hidden, channels, 104 | kernel_size=kernel_size, padding=pad, bias=True, padding_mode=pad_mode) 105 | 106 | self.up_conv21 = nn.Conv2d(channels_hidden, channels, 107 | kernel_size=kernel_size, padding=pad, bias=True, padding_mode=pad_mode) 108 | 109 | self.down_conv01 = nn.Conv2d(channels_hidden, channels, 110 | kernel_size=kernel_size, padding=pad, 111 | bias=not batch_norm, stride=2, padding_mode=pad_mode, dilation=1) 112 | 113 | self.down_conv12 = nn.Conv2d(channels_hidden, channels, 114 | kernel_size=kernel_size, padding=pad, 115 | bias=not batch_norm, stride=2, padding_mode=pad_mode, dilation=1) 116 | 117 | self.lr = nn.LeakyReLU(self.leaky_slope) 118 | 119 | def forward(self, x0, x1, x2): 120 | out0 = self.conv_scale0_0(x0) 121 | out1 = self.conv_scale1_0(x1) 122 | out2 = self.conv_scale2_0(x2) 123 | 124 | y0 = self.lr(out0) 125 | y1 = self.lr(out1) 126 | y2 = self.lr(out2) 127 | 128 | out0 = self.conv_scale0_1(y0) 129 | out1 = self.conv_scale1_1(y1) 130 | out2 = self.conv_scale2_1(y2) 131 | 132 | y1_up = self.up_conv10(self.upsample(y1)) 133 | y2_up = self.up_conv21(self.upsample(y2)) 134 | 135 | y0_down = self.down_conv01(y0) 136 | y1_down = self.down_conv12(y1) 137 | 138 | out0 = out0 + y1_up 139 | out1 = out1 + y0_down + y2_up 140 | out2 = out2 + y1_down 141 | 142 | out0 = out0 * self.gamma0 143 | out1 = out1 * self.gamma1 144 | out2 = out2 * self.gamma2 145 | return out0, out1, out2 146 | 147 | class parallel_glow_coupling_layer(nn.Module): 148 | def __init__(self, dims_in, F_class=CrossConvolutions, F_args={}, 149 | 
clamp=5.): 150 | super(parallel_glow_coupling_layer, self).__init__() 151 | channels = dims_in[0][0] 152 | self.ndims = len(dims_in[0]) 153 | 154 | self.split_len1 = channels // 2 155 | self.split_len2 = channels - channels // 2 156 | 157 | self.clamp = clamp 158 | 159 | self.max_s = exp(clamp) 160 | self.min_s = exp(-clamp) 161 | 162 | self.s1 = F_class(self.split_len1, self.split_len2 * 2, **F_args) 163 | self.s2 = F_class(self.split_len2, self.split_len1 * 2, **F_args) 164 | 165 | def e(self, s): 166 | if self.clamp > 0: 167 | return torch.exp(self.log_e(s)) 168 | else: 169 | return torch.exp(s) 170 | 171 | def log_e(self, s): 172 | if self.clamp > 0: 173 | return self.clamp * 0.636 * torch.atan(s / self.clamp) 174 | else: 175 | return s 176 | 177 | def forward(self, x, rev=False): 178 | x01, x02 = (x[0].narrow(1, 0, self.split_len1), 179 | x[0].narrow(1, self.split_len1, self.split_len2)) 180 | x11, x12 = (x[1].narrow(1, 0, self.split_len1), 181 | x[1].narrow(1, self.split_len1, self.split_len2)) 182 | x21, x22 = (x[2].narrow(1, 0, self.split_len1), 183 | x[2].narrow(1, self.split_len1, self.split_len2)) 184 | 185 | if not rev: 186 | r02, r12, r22 = self.s2(x02, x12, x22) 187 | 188 | s02, t02 = r02[:, :self.split_len1], r02[:, self.split_len1:] 189 | s12, t12 = r12[:, :self.split_len1], r12[:, self.split_len1:] 190 | s22, t22 = r22[:, :self.split_len1], r22[:, self.split_len1:] 191 | 192 | y01 = self.e(s02) * x01 + t02 193 | y11 = self.e(s12) * x11 + t12 194 | y21 = self.e(s22) * x21 + t22 195 | 196 | r01, r11, r21 = self.s1(y01, y11, y21) 197 | 198 | s01, t01 = r01[:, :self.split_len2], r01[:, self.split_len2:] 199 | s11, t11 = r11[:, :self.split_len2], r11[:, self.split_len2:] 200 | s21, t21 = r21[:, :self.split_len2], r21[:, self.split_len2:] 201 | y02 = self.e(s01) * x02 + t01 202 | y12 = self.e(s11) * x12 + t11 203 | y22 = self.e(s21) * x22 + t21 204 | 205 | else: # names of x and y are swapped! 
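# The coupling is inverted in the opposite order of the forward pass:
# the forward branch computed y01 = e(s02)*x01 + t02 (with s02, t02 from
# s2(x02, ...)) and y02 = e(s01)*x02 + t01 (with s01, t01 from s1(y01, ...)).
# So, given (y01, y02): first recompute (s01, t01) from the y01 halves and
# recover x02 = (y02 - t01) / e(s01); then recompute (s02, t02) from the
# recovered x02 halves and recover x01 = (y01 - t02) / e(s02).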
206 | r01, r11, r21 = self.s1(x01, x11, x21) 207 | 208 | s01, t01 = r01[:, :self.split_len2], r01[:, self.split_len2:] 209 | s11, t11 = r11[:, :self.split_len2], r11[:, self.split_len2:] 210 | s21, t21 = r21[:, :self.split_len2], r21[:, self.split_len2:] 211 | 212 | y02 = (x02 - t01) / self.e(s01) 213 | y12 = (x12 - t11) / self.e(s11) 214 | y22 = (x22 - t21) / self.e(s21) 215 | 216 | r02, r12, r22 = self.s2(y02, y12, y22) 217 | # mirror the forward branch: s2 outputs 2 * split_len1 channels, so split each r at split_len1 218 | s02, t02 = r02[:, :self.split_len1], r02[:, self.split_len1:] 219 | s12, t12 = r12[:, :self.split_len1], r12[:, self.split_len1:] 220 | s22, t22 = r22[:, :self.split_len1], r22[:, self.split_len1:] 221 | 222 | y01 = (x01 - t02) / self.e(s02) 223 | y11 = (x11 - t12) / self.e(s12) 224 | y21 = (x21 - t22) / self.e(s22) 225 | 226 | y0 = torch.cat((y01, y02), 1) 227 | y1 = torch.cat((y11, y12), 1) 228 | y2 = torch.cat((y21, y22), 1) 229 | 230 | y0 = torch.clamp(y0, -1e6, 1e6) 231 | y1 = torch.clamp(y1, -1e6, 1e6) 232 | y2 = torch.clamp(y2, -1e6, 1e6) 233 | 234 | jac0 = torch.sum(self.log_e(s01), dim=(1, 2, 3)) + torch.sum(self.log_e(s02), dim=(1, 2, 3)) 235 | jac1 = torch.sum(self.log_e(s11), dim=(1, 2, 3)) + torch.sum(self.log_e(s12), dim=(1, 2, 3)) 236 | jac2 = torch.sum(self.log_e(s21), dim=(1, 2, 3)) + torch.sum(self.log_e(s22), dim=(1, 2, 3)) 237 | self.last_jac = [jac0, jac1, jac2] 238 | 239 | return [y0, y1, y2] 240 | 241 | def jacobian(self, x, rev=False): 242 | return self.last_jac 243 | 244 | def output_dims(self, input_dims): 245 | return input_dims 246 | 247 | class Node: 248 | '''The Node class represents one transformation in the graph, with an 249 | arbitrary number of in- and outputs.''' 250 | 251 | def __init__(self, inputs, module_type, module_args, name=None): 252 | self.inputs = inputs 253 | self.outputs = [] 254 | self.module_type = module_type 255 | self.module_args = module_args 256 | 257 | self.input_dims, self.module = None, None 258 | self.computed = None 259 | self.computed_rev = None 260 | self.id = None 261 | 262 | if name: 263 | self.name = name 264 | else: 265 | self.name = hex(id(self))[-6:] 266 | for i in range(255): 267 | exec('self.out{0} = (self, {0})'.format(i)) 268 | 269 | def build_modules(self, verbose=VERBOSE): 270 | ''' Returns a list with the dimension of each output of this node, 271 | recursively calling build_modules of the nodes connected to the input. 272 | Use this information to initialize the pytorch nn.Module of this node. 273 | ''' 274 | 275 | if not self.input_dims: # Only do it if this hasn't been computed yet 276 | self.input_dims = [n.build_modules(verbose=verbose)[c] 277 | for n, c in self.inputs] 278 | try: 279 | self.module = self.module_type(self.input_dims, 280 | **self.module_args) 281 | except Exception as e: 282 | print('Error in node %s' % (self.name)) 283 | raise e 284 | 285 | if verbose: 286 | print("Node %s has following input dimensions:" % (self.name)) 287 | for d, (n, c) in zip(self.input_dims, self.inputs): 288 | print("\t Output #%i of node %s:" % (c, n.name), d) 289 | print() 290 | 291 | self.output_dims = self.module.output_dims(self.input_dims) 292 | self.n_outputs = len(self.output_dims) 293 | 294 | return self.output_dims 295 | 296 | def run_forward(self, op_list): 297 | '''Determine the order of operations needed to reach this node. Calls 298 | run_forward of parent nodes recursively.
Each operation is appended to 299 | the global list op_list, in the form (node ID, input variable IDs, 300 | output variable IDs)''' 301 | 302 | if not self.computed: 303 | 304 | # Compute all nodes which provide inputs, filter out the 305 | # channels you need 306 | self.input_vars = [] 307 | for i, (n, c) in enumerate(self.inputs): 308 | self.input_vars.append(n.run_forward(op_list)[c]) 309 | # Register youself as an output in the input node 310 | n.outputs.append((self, i)) 311 | 312 | # All outputs could now be computed 313 | self.computed = [(self.id, i) for i in range(self.n_outputs)] 314 | op_list.append((self.id, self.input_vars, self.computed)) 315 | 316 | # Return the variables you have computed (this happens mulitple times 317 | # without recomputing if called repeatedly) 318 | return self.computed 319 | 320 | def run_backward(self, op_list): 321 | '''See run_forward, this is the same, only for the reverse computation. 322 | Need to call run_forward first, otherwise this function will not 323 | work''' 324 | 325 | assert len(self.outputs) > 0, "Call run_forward first" 326 | if not self.computed_rev: 327 | 328 | # These are the input variables that must be computed first 329 | output_vars = [(self.id, i) for i in range(self.n_outputs)] 330 | 331 | # Recursively compute these 332 | for n, c in self.outputs: 333 | n.run_backward(op_list) 334 | 335 | # The variables that this node computes are the input variables 336 | # from the forward pass 337 | self.computed_rev = self.input_vars 338 | op_list.append((self.id, output_vars, self.computed_rev)) 339 | 340 | return self.computed_rev 341 | 342 | class InputNode(Node): 343 | '''Special type of node that represents the input data of the whole net (or 344 | ouput when running reverse)''' 345 | 346 | def __init__(self, *dims, name='node'): 347 | self.name = name 348 | self.data = dummy_data(*dims) 349 | self.outputs = [] 350 | self.module = None 351 | self.computed_rev = None 352 | self.n_outputs = 1 353 | self.input_vars = [] 354 | self.out0 = (self, 0) 355 | 356 | def build_modules(self, verbose=VERBOSE): 357 | return [self.data.shape] 358 | 359 | def run_forward(self, op_list): 360 | return [(self.id, 0)] 361 | 362 | class OutputNode(Node): 363 | '''Special type of node that represents the output of the whole net (of the 364 | input when running in reverse)''' 365 | 366 | class dummy(nn.Module): 367 | 368 | def __init__(self, *args): 369 | super(OutputNode.dummy, self).__init__() 370 | 371 | def __call__(*args): 372 | return args 373 | 374 | def output_dims(*args): 375 | return args 376 | 377 | def __init__(self, inputs, name='node'): 378 | self.module_type, self.module_args = self.dummy, {} 379 | self.output_dims = [] 380 | self.inputs = inputs 381 | self.input_dims, self.module = None, None 382 | self.computed = None 383 | self.id = None 384 | self.name = name 385 | 386 | for c, inp in enumerate(self.inputs): 387 | inp[0].outputs.append((self, c)) 388 | 389 | def run_backward(self, op_list): 390 | return [(self.id, 0)] 391 | 392 | class ReversibleGraphNet(nn.Module): 393 | '''This class represents the invertible net itself. It is a subclass of 394 | torch.nn.Module and supports the same methods. 
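A minimal sketch of how the node classes above compose: a single `ParallelPermute` between an `InputNode` and an `OutputNode`, run forward and then inverted with `rev=True` (the sizes are illustrative; CS-Flow builds its three-scale coupling graph the same way):

```
import torch
from utils.freia_funcs import (InputNode, Node, OutputNode,
                               ParallelPermute, ReversibleGraphNet)

inp = InputNode(16, 8, 8, name='input')                          # (channels, H, W)
perm = Node([inp.out0], ParallelPermute, {'seed': 0}, name='permute')
out = OutputNode([perm.out0], name='output')
net = ReversibleGraphNet([inp, perm, out])

x = torch.randn(4, 16, 8, 8)
z = net(x)                         # forward pass
x_rec = net(z, rev=True)           # inverse pass
print(torch.allclose(x, x_rec))    # True: a channel permutation is exactly invertible
```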
The forward method has an 395 | additional option 'rev', whith which the net can be computed in reverse.''' 396 | 397 | def __init__(self, node_list, ind_in=None, ind_out=None, verbose=False, n_jac=1): 398 | '''node_list should be a list of all nodes involved, and ind_in, 399 | ind_out are the indexes of the special nodes InputNode and OutputNode 400 | in this list.''' 401 | super(ReversibleGraphNet, self).__init__() 402 | 403 | # Gather lists of input and output nodes 404 | if ind_in is not None: 405 | if isinstance(ind_in, int): 406 | self.ind_in = list([ind_in]) 407 | else: 408 | self.ind_in = ind_in 409 | else: 410 | self.ind_in = [i for i in range(len(node_list)) 411 | if isinstance(node_list[i], InputNode)] 412 | assert len(self.ind_in) > 0, "No input nodes specified." 413 | if ind_out is not None: 414 | if isinstance(ind_out, int): 415 | self.ind_out = list([ind_out]) 416 | else: 417 | self.ind_out = ind_out 418 | else: 419 | self.ind_out = [i for i in range(len(node_list)) 420 | if isinstance(node_list[i], OutputNode)] 421 | assert len(self.ind_out) > 0, "No output nodes specified." 422 | 423 | self.return_vars = [] 424 | self.input_vars = [] 425 | 426 | # Assign each node a unique ID 427 | self.node_list = node_list 428 | for i, n in enumerate(node_list): 429 | n.id = i 430 | 431 | # Recursively build the nodes nn.Modules and determine order of 432 | # operations 433 | ops = [] 434 | for i in self.ind_out: 435 | node_list[i].build_modules(verbose=verbose) 436 | node_list[i].run_forward(ops) 437 | 438 | # create list of Pytorch variables that are used 439 | variables = set() 440 | for o in ops: 441 | variables = variables.union(set(o[1] + o[2])) 442 | self.variables_ind = list(variables) 443 | 444 | self.indexed_ops = self.ops_to_indexed(ops) 445 | 446 | self.module_list = nn.ModuleList([n.module for n in node_list]) 447 | self.variable_list = [Variable(requires_grad=True) for v in variables] 448 | 449 | # Find out the order of operations for reverse calculations 450 | ops_rev = [] 451 | for i in self.ind_in: 452 | node_list[i].run_backward(ops_rev) 453 | self.indexed_ops_rev = self.ops_to_indexed(ops_rev) 454 | self.n_jac = n_jac 455 | 456 | def ops_to_indexed(self, ops): 457 | '''Helper function to translate the list of variables (origin ID, channel), 458 | to variable IDs.''' 459 | result = [] 460 | 461 | for o in ops: 462 | try: 463 | vars_in = [self.variables_ind.index(v) for v in o[1]] 464 | except ValueError: 465 | vars_in = -1 466 | 467 | vars_out = [self.variables_ind.index(v) for v in o[2]] 468 | 469 | # Collect input/output nodes in separate lists, but don't add to 470 | # indexed ops 471 | if o[0] in self.ind_out: 472 | self.return_vars.append(self.variables_ind.index(o[1][0])) 473 | continue 474 | if o[0] in self.ind_in: 475 | self.input_vars.append(self.variables_ind.index(o[1][0])) 476 | continue 477 | 478 | result.append((o[0], vars_in, vars_out)) 479 | 480 | # Sort input/output variables so they correspond to initial node list 481 | # order 482 | self.return_vars.sort(key=lambda i: self.variables_ind[i][0]) 483 | self.input_vars.sort(key=lambda i: self.variables_ind[i][0]) 484 | 485 | return result 486 | 487 | def forward(self, x, rev=False): 488 | '''Forward or backward computation of the whole net.''' 489 | if rev: 490 | use_list = self.indexed_ops_rev 491 | input_vars, output_vars = self.return_vars, self.input_vars 492 | else: 493 | use_list = self.indexed_ops 494 | input_vars, output_vars = self.input_vars, self.return_vars 495 | 496 | if isinstance(x, (list, 
tuple)): 497 | assert len(x) == len(input_vars), ( 498 | f"Got list of {len(x)} input tensors for " 499 | f"{'inverse' if rev else 'forward'} pass, but expected " 500 | f"{len(input_vars)}." 501 | ) 502 | for i in range(len(input_vars)): 503 | self.variable_list[input_vars[i]] = x[i] 504 | else: 505 | assert len(input_vars) == 1, (f"Got single input tensor for " 506 | f"{'inverse' if rev else 'forward'} " 507 | f"pass, but expected list of " 508 | f"{len(input_vars)}.") 509 | self.variable_list[input_vars[0]] = x 510 | 511 | for o in use_list: 512 | try: 513 | results = self.module_list[o[0]]([self.variable_list[i] 514 | for i in o[1]], rev=rev) 515 | except TypeError: 516 | raise RuntimeError("Are you sure all used Nodes are in the " 517 | "Node list?") 518 | for i, r in zip(o[2], results): 519 | self.variable_list[i] = r 520 | # self.variable_list[o[2][0]] = self.variable_list[o[1][0]] 521 | 522 | out = [self.variable_list[output_vars[i]] 523 | for i in range(len(output_vars))] 524 | if len(out) == 1: 525 | return out[0] 526 | else: 527 | return out 528 | 529 | def jacobian(self, x=None, rev=False, run_forward=True): 530 | '''Compute the jacobian determinant of the whole net.''' 531 | jacobian = [0.] * self.n_jac 532 | 533 | if rev: 534 | use_list = self.indexed_ops_rev 535 | else: 536 | use_list = self.indexed_ops 537 | 538 | if run_forward: 539 | if x is None: 540 | raise RuntimeError("You need to provide an input if you want " 541 | "to run a forward pass") 542 | self.forward(x, rev=rev) 543 | 544 | for o in use_list: 545 | try: 546 | node_jac = self.module_list[o[0]].jacobian( 547 | [self.variable_list[i] for i in o[1]], rev=rev 548 | ) 549 | node_jac = [node_jac] if not isinstance(node_jac, list) else node_jac 550 | for i_j, jac in enumerate(node_jac): 551 | jacobian[i_j] += jac 552 | 553 | except TypeError: 554 | raise RuntimeError("Are you sure all used Nodes are in the " 555 | "Node list?") 556 | 557 | return jacobian 558 | 559 | 560 | -------------------------------------------------------------------------------- /utils/optimizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | """ Layer-wise adaptive rate scaling for SGD in PyTorch! """ 5 | from torch.optim.optimizer import Optimizer, required 6 | 7 | class LARS(Optimizer): 8 | r"""Implements layer-wise adaptive rate scaling for SGD. 9 | 10 | Args: 11 | params (iterable): iterable of parameters to optimize or dicts defining 12 | parameter groups 13 | lr (float): base learning rate (\gamma_0) 14 | momentum (float, optional): momentum factor (default: 0) ("m") 15 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 16 | ("\beta") 17 | eta (float, optional): LARS coefficient 18 | max_epoch: maximum training epoch to determine polynomial LR decay. 19 | 20 | Based on Algorithm 1 of the following paper by You, Gitman, and Ginsburg. 
21 | Large Batch Training of Convolutional Networks: 22 | https://arxiv.org/abs/1708.03888 23 | 24 | Example: 25 | >>> optimizer = LARS(model.parameters(), lr=0.1, eta=1e-3) 26 | >>> optimizer.zero_grad() 27 | >>> loss_fn(model(input), target).backward() 28 | >>> optimizer.step() 29 | """ 30 | def __init__(self, params, lr=required, momentum=.9, 31 | weight_decay=.0005, eta=0.001, max_epoch=200): 32 | if lr is not required and lr < 0.0: 33 | raise ValueError("Invalid learning rate: {}".format(lr)) 34 | if momentum < 0.0: 35 | raise ValueError("Invalid momentum value: {}".format(momentum)) 36 | if weight_decay < 0.0: 37 | raise ValueError("Invalid weight_decay value: {}" 38 | .format(weight_decay)) 39 | if eta < 0.0: 40 | raise ValueError("Invalid LARS coefficient value: {}".format(eta)) 41 | 42 | self.epoch = 0 43 | defaults = dict(lr=lr, momentum=momentum, 44 | weight_decay=weight_decay, 45 | eta=eta, max_epoch=max_epoch) 46 | super(LARS, self).__init__(params, defaults) 47 | 48 | def step(self, epoch=None, closure=None): 49 | """Performs a single optimization step. 50 | 51 | Arguments: 52 | closure (callable, optional): A closure that reevaluates the model 53 | and returns the loss. 54 | epoch: current epoch to calculate polynomial LR decay schedule. 55 | if None, uses self.epoch and increments it. 56 | """ 57 | loss = None 58 | if closure is not None: 59 | loss = closure() 60 | 61 | if epoch is None: 62 | epoch = self.epoch 63 | self.epoch += 1 64 | 65 | for group in self.param_groups: 66 | weight_decay = group['weight_decay'] 67 | momentum = group['momentum'] 68 | eta = group['eta'] 69 | lr = group['lr'] 70 | max_epoch = group['max_epoch'] 71 | 72 | for p in group['params']: 73 | if p.grad is None: 74 | continue 75 | 76 | param_state = self.state[p] 77 | d_p = p.grad.data 78 | 79 | weight_norm = torch.norm(p.data) 80 | grad_norm = torch.norm(d_p) 81 | 82 | # Global LR computed on polynomial decay schedule 83 | decay = (1 - float(epoch) / max_epoch) ** 2 84 | global_lr = lr * decay 85 | 86 | # Compute local learning rate for this layer 87 | local_lr = eta * weight_norm / (grad_norm + weight_decay * weight_norm) 88 | 89 | # Update the momentum term 90 | actual_lr = local_lr * global_lr 91 | 92 | if 'momentum_buffer' not in param_state: 93 | buf = param_state['momentum_buffer'] = torch.zeros_like(p.data) 94 | else: 95 | buf = param_state['momentum_buffer'] 96 | buf.mul_(momentum).add_(d_p + weight_decay * p.data, alpha=actual_lr) 97 | p.data.add_(-buf) 98 | 99 | return loss 100 | 101 | class LR_Scheduler(object): 102 | def __init__(self, optimizer, warmup_epochs, warmup_lr, num_epochs, base_lr, final_lr, iter_per_epoch, 103 | constant_predictor_lr=False): 104 | self.base_lr = base_lr 105 | self.constant_predictor_lr = constant_predictor_lr 106 | warmup_iter = iter_per_epoch * warmup_epochs 107 | warmup_lr_schedule = np.linspace(warmup_lr, base_lr, warmup_iter) 108 | decay_iter = iter_per_epoch * (num_epochs - warmup_epochs) 109 | cosine_lr_schedule = final_lr + 0.5 * (base_lr - final_lr) * ( 110 | 1 + np.cos(np.pi * np.arange(decay_iter) / decay_iter)) 111 | 112 | self.lr_schedule = np.concatenate((warmup_lr_schedule, cosine_lr_schedule)) 113 | self.optimizer = optimizer 114 | self.iter = 0 115 | self.current_lr = 0 116 | self.total_iters = num_epochs * iter_per_epoch 117 | 118 | def step(self): 119 | for param_group in self.optimizer.param_groups: 120 | 121 | if self.constant_predictor_lr and param_group['name'] == 'predictor': 122 | param_group['lr'] = self.base_lr 123 | else: 124 | lr 
= param_group['lr'] = self.lr_schedule[self.iter] 125 | 126 | self.iter += 1 127 | self.current_lr = lr 128 | 129 | if self.iter >= self.total_iters: 130 | self.reset() 131 | return lr 132 | 133 | def reset(self): 134 | self.iter = 0 135 | self.current_lr = 0 136 | 137 | def get_lr(self): 138 | return self.current_lr 139 | 140 | 141 | def get_optimizer(args, model): 142 | params = model.parameters() 143 | 144 | if args.train.optimizer.name == 'lars': 145 | optimizer = LARS(params, lr=args.train.base_lr, momentum=args.train.optimizer.momentum, weight_decay=args.train.optimizer.weight_decay) 146 | elif args.train.optimizer.name == 'sgd': 147 | optimizer = torch.optim.SGD(params, lr=args.train.base_lr, momentum=args.train.optimizer.momentum, weight_decay=args.train.optimizer.weight_decay) 148 | elif args.train.optimizer.name == 'adam': 149 | optimizer = torch.optim.Adam(params, lr=args.train.base_lr, eps=0.0004, weight_decay=args.train.optimizer.weight_decay) 150 | # optimizer = torch.optim.Adam([ 151 | # {'params': params}, 152 | # {'params': z_head_params, 'lr': args.train.base_lr * 3}, 153 | # {'params': head_params, 'lr': args.train.base_lr * 0.1} 154 | # ], lr=args.train.base_lr, eps=0.0004, weight_decay=args.train.optimizer.weight_decay) 155 | else: 156 | raise NotImplementedError 157 | return optimizer 158 | 159 | def get_head_optimizer(args, model): 160 | params = model.head.parameters() 161 | 162 | if args.train.optimizer.name == 'lars': 163 | optimizer = LARS(params, lr=args.train.base_lr, momentum=args.train.optimizer.momentum, weight_decay=args.train.optimizer.weight_decay) 164 | elif args.train.optimizer.name == 'sgd': 165 | optimizer = torch.optim.SGD(params, lr=args.train.base_lr, momentum=args.train.optimizer.momentum, weight_decay=args.train.optimizer.weight_decay) 166 | elif args.train.optimizer.name == 'adam': 167 | optimizer = torch.optim.Adam(params, lr=args.train.base_lr * 3, eps=0.0004, weight_decay=args.train.optimizer.weight_decay) 168 | else: 169 | raise NotImplementedError 170 | return optimizer 171 | 172 | -------------------------------------------------------------------------------- /utils/rd_resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | import torch.nn as nn 4 | 5 | try: 6 | from torch.hub import load_state_dict_from_url 7 | except ImportError: 8 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 9 | from typing import Type, Any, Callable, Union, List, Optional 10 | 11 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 12 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 13 | 'wide_resnet50_2', 'wide_resnet101_2'] 14 | 15 | model_urls = { 16 | 'resnet18': 'https://download.pytorch.org/models/resnet18-f37072fd.pth', 17 | 'resnet34': 'https://download.pytorch.org/models/resnet34-b627a593.pth', 18 | 'resnet50': 'https://download.pytorch.org/models/resnet50-0676ba61.pth', 19 | 'resnet101': 'https://download.pytorch.org/models/resnet101-63fe2227.pth', 20 | 'resnet152': 'https://download.pytorch.org/models/resnet152-394f9c45.pth', 21 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 22 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 23 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 24 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 25 | } 26 | 27 | 28 | def 
conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: 29 | """3x3 convolution with padding""" 30 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 31 | padding=dilation, groups=groups, bias=False, dilation=dilation) 32 | 33 | 34 | def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: 35 | """1x1 convolution""" 36 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 37 | 38 | 39 | class BasicBlock(nn.Module): 40 | expansion: int = 1 41 | 42 | def __init__( 43 | self, 44 | inplanes: int, 45 | planes: int, 46 | stride: int = 1, 47 | downsample: Optional[nn.Module] = None, 48 | groups: int = 1, 49 | base_width: int = 64, 50 | dilation: int = 1, 51 | norm_layer: Optional[Callable[..., nn.Module]] = None 52 | ) -> None: 53 | super(BasicBlock, self).__init__() 54 | if norm_layer is None: 55 | norm_layer = nn.BatchNorm2d 56 | if groups != 1 or base_width != 64: 57 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 58 | if dilation > 1: 59 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 60 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 61 | self.conv1 = conv3x3(inplanes, planes, stride) 62 | self.bn1 = norm_layer(planes) 63 | self.relu = nn.ReLU(inplace=True) 64 | self.conv2 = conv3x3(planes, planes) 65 | self.bn2 = norm_layer(planes) 66 | self.downsample = downsample 67 | self.stride = stride 68 | 69 | def forward(self, x: Tensor) -> Tensor: 70 | identity = x 71 | 72 | out = self.conv1(x) 73 | out = self.bn1(out) 74 | out = self.relu(out) 75 | 76 | out = self.conv2(out) 77 | out = self.bn2(out) 78 | 79 | if self.downsample is not None: 80 | identity = self.downsample(x) 81 | 82 | out += identity 83 | out = self.relu(out) 84 | 85 | return out 86 | 87 | 88 | class Bottleneck(nn.Module): 89 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 90 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 91 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 92 | # This variant is also known as ResNet V1.5 and improves accuracy according to 93 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 
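Note that `conv3x3` sets `padding=dilation`, so at stride 1 the spatial size is preserved for any dilation (the effective receptive field grows, the output size does not). A quick standalone check:

```
import torch
from utils.rd_resnet import conv3x3

for dilation in (1, 2, 4):
    conv = conv3x3(64, 64, stride=1, dilation=dilation)
    y = conv(torch.randn(1, 64, 56, 56))
    print(dilation, tuple(y.shape))  # spatial size stays (56, 56) for every dilation
```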
94 | 95 | expansion: int = 4 96 | 97 | def __init__( 98 | self, 99 | inplanes: int, 100 | planes: int, 101 | stride: int = 1, 102 | downsample: Optional[nn.Module] = None, 103 | groups: int = 1, 104 | base_width: int = 64, 105 | dilation: int = 1, 106 | norm_layer: Optional[Callable[..., nn.Module]] = None 107 | ) -> None: 108 | super(Bottleneck, self).__init__() 109 | if norm_layer is None: 110 | norm_layer = nn.BatchNorm2d 111 | width = int(planes * (base_width / 64.)) * groups 112 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 113 | self.conv1 = conv1x1(inplanes, width) 114 | self.bn1 = norm_layer(width) 115 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 116 | self.bn2 = norm_layer(width) 117 | self.conv3 = conv1x1(width, planes * self.expansion) 118 | self.bn3 = norm_layer(planes * self.expansion) 119 | self.relu = nn.ReLU(inplace=True) 120 | self.downsample = downsample 121 | self.stride = stride 122 | 123 | def forward(self, x: Tensor) -> Tensor: 124 | identity = x 125 | 126 | out = self.conv1(x) 127 | out = self.bn1(out) 128 | out = self.relu(out) 129 | 130 | out = self.conv2(out) 131 | out = self.bn2(out) 132 | out = self.relu(out) 133 | 134 | out = self.conv3(out) 135 | out = self.bn3(out) 136 | 137 | if self.downsample is not None: 138 | identity = self.downsample(x) 139 | 140 | out += identity 141 | out = self.relu(out) 142 | 143 | return out 144 | 145 | 146 | class ResNet(nn.Module): 147 | 148 | def __init__( 149 | self, 150 | block: Type[Union[BasicBlock, Bottleneck]], 151 | layers: List[int], 152 | num_classes: int = 1000, 153 | zero_init_residual: bool = False, 154 | groups: int = 1, 155 | width_per_group: int = 64, 156 | replace_stride_with_dilation: Optional[List[bool]] = None, 157 | norm_layer: Optional[Callable[..., nn.Module]] = None 158 | ) -> None: 159 | super(ResNet, self).__init__() 160 | if norm_layer is None: 161 | norm_layer = nn.BatchNorm2d 162 | self._norm_layer = norm_layer 163 | 164 | self.inplanes = 64 165 | self.dilation = 1 166 | if replace_stride_with_dilation is None: 167 | # each element in the tuple indicates if we should replace 168 | # the 2x2 stride with a dilated convolution instead 169 | replace_stride_with_dilation = [False, False, False] 170 | if len(replace_stride_with_dilation) != 3: 171 | raise ValueError("replace_stride_with_dilation should be None " 172 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 173 | self.groups = groups 174 | self.base_width = width_per_group 175 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 176 | bias=False) 177 | self.bn1 = norm_layer(self.inplanes) 178 | self.relu = nn.ReLU(inplace=True) 179 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 180 | self.layer1 = self._make_layer(block, 64, layers[0]) 181 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 182 | dilate=replace_stride_with_dilation[0]) 183 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 184 | dilate=replace_stride_with_dilation[1]) 185 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 186 | dilate=replace_stride_with_dilation[2]) 187 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 188 | self.fc = nn.Linear(512 * block.expansion, num_classes) 189 | 190 | for m in self.modules(): 191 | if isinstance(m, nn.Conv2d): 192 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 193 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 194 | nn.init.constant_(m.weight, 1) 
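For reverse distillation this encoder returns its three intermediate feature maps instead of logits (see `_forward_impl` below). With a `[3, 4, 6, 3]` Bottleneck configuration and a 256x256 input, the feature pyramid comes out as follows (a sketch using the classes in this file):

```
import torch
from utils.rd_resnet import ResNet, Bottleneck

encoder = ResNet(Bottleneck, [3, 4, 6, 3])     # ResNet-50-style encoder
feats = encoder(torch.randn(1, 3, 256, 256))
print([tuple(f.shape) for f in feats])
# [(1, 256, 64, 64), (1, 512, 32, 32), (1, 1024, 16, 16)]
```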
195 | nn.init.constant_(m.bias, 0) 196 | 197 | # Zero-initialize the last BN in each residual branch, 198 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 199 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 200 | if zero_init_residual: 201 | for m in self.modules(): 202 | if isinstance(m, Bottleneck): 203 | nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] 204 | elif isinstance(m, BasicBlock): 205 | nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] 206 | 207 | def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int, 208 | stride: int = 1, dilate: bool = False) -> nn.Sequential: 209 | norm_layer = self._norm_layer 210 | downsample = None 211 | previous_dilation = self.dilation 212 | if dilate: 213 | self.dilation *= stride 214 | stride = 1 215 | if stride != 1 or self.inplanes != planes * block.expansion: 216 | downsample = nn.Sequential( 217 | conv1x1(self.inplanes, planes * block.expansion, stride), 218 | norm_layer(planes * block.expansion), 219 | ) 220 | 221 | layers = [] 222 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 223 | self.base_width, previous_dilation, norm_layer)) 224 | self.inplanes = planes * block.expansion 225 | for _ in range(1, blocks): 226 | layers.append(block(self.inplanes, planes, groups=self.groups, 227 | base_width=self.base_width, dilation=self.dilation, 228 | norm_layer=norm_layer)) 229 | 230 | return nn.Sequential(*layers) 231 | 232 | def _forward_impl(self, x: Tensor) -> List[Tensor]: 233 | # See note [TorchScript super()] 234 | x = self.conv1(x) 235 | x = self.bn1(x) 236 | x = self.relu(x) 237 | x = self.maxpool(x) 238 | 239 | feature_a = self.layer1(x) 240 | feature_b = self.layer2(feature_a) 241 | feature_c = self.layer3(feature_b) 242 | feature_d = self.layer4(feature_c) # computed but unused below; layer4 is kept so pretrained checkpoints load cleanly 243 | 244 | return [feature_a, feature_b, feature_c] # the three intermediate scales consumed by BN_layer 245 | 246 | def forward(self, x: Tensor) -> List[Tensor]: 247 | return self._forward_impl(x) 248 | 249 | 250 | def _resnet( 251 | arch: str, 252 | block: Type[Union[BasicBlock, Bottleneck]], 253 | layers: List[int], 254 | pretrained: bool, 255 | progress: bool, 256 | **kwargs: Any 257 | ) -> ResNet: 258 | model = ResNet(block, layers, **kwargs) 259 | if pretrained: 260 | state_dict = load_state_dict_from_url(model_urls[arch], 261 | progress=progress) 262 | # for k,v in list(state_dict.items()): 263 | # if 'layer4' in k or 'fc' in k: 264 | # state_dict.pop(k) 265 | model.load_state_dict(state_dict) 266 | return model 267 | 268 | 269 | class AttnBasicBlock(nn.Module): 270 | expansion: int = 1 271 | 272 | def __init__( 273 | self, 274 | inplanes: int, 275 | planes: int, 276 | stride: int = 1, 277 | downsample: Optional[nn.Module] = None, 278 | groups: int = 1, 279 | base_width: int = 64, 280 | dilation: int = 1, 281 | norm_layer: Optional[Callable[..., nn.Module]] = None, 282 | attention: bool = True, 283 | ) -> None: 284 | super(AttnBasicBlock, self).__init__() 285 | self.attention = attention 286 | # print("Attention:", self.attention) 287 | if norm_layer is None: 288 | norm_layer = nn.BatchNorm2d 289 | if groups != 1 or base_width != 64: 290 | raise ValueError('AttnBasicBlock only supports groups=1 and base_width=64') 291 | if dilation > 1: 292 | raise NotImplementedError("Dilation > 1 not supported in AttnBasicBlock") 293 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 294 | self.conv1 = conv3x3(inplanes, planes, stride) 295 | self.bn1 =
norm_layer(planes) 296 | self.relu = nn.ReLU(inplace=True) 297 | self.conv2 = conv3x3(planes, planes) 298 | self.bn2 = norm_layer(planes) 299 | # self.cbam = GLEAM(planes, 16) 300 | self.downsample = downsample 301 | self.stride = stride 302 | 303 | def forward(self, x: Tensor) -> Tensor: 304 | # if self.attention: 305 | # x = self.cbam(x) 306 | identity = x 307 | 308 | out = self.conv1(x) 309 | out = self.bn1(out) 310 | out = self.relu(out) 311 | 312 | out = self.conv2(out) 313 | out = self.bn2(out) 314 | 315 | if self.downsample is not None: 316 | identity = self.downsample(x) 317 | 318 | out += identity 319 | out = self.relu(out) 320 | 321 | return out 322 | 323 | 324 | class AttnBottleneck(nn.Module): 325 | expansion: int = 4 326 | 327 | def __init__( 328 | self, 329 | inplanes: int, 330 | planes: int, 331 | stride: int = 1, 332 | downsample: Optional[nn.Module] = None, 333 | groups: int = 1, 334 | base_width: int = 64, 335 | dilation: int = 1, 336 | norm_layer: Optional[Callable[..., nn.Module]] = None, 337 | attention: bool = True, 338 | ) -> None: 339 | super(AttnBottleneck, self).__init__() 340 | self.attention = attention 341 | # print("Attention:",self.attention) 342 | if norm_layer is None: 343 | norm_layer = nn.BatchNorm2d 344 | width = int(planes * (base_width / 64.)) * groups 345 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 346 | self.conv1 = conv1x1(inplanes, width) 347 | self.bn1 = norm_layer(width) 348 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 349 | self.bn2 = norm_layer(width) 350 | self.conv3 = conv1x1(width, planes * self.expansion) 351 | self.bn3 = norm_layer(planes * self.expansion) 352 | self.relu = nn.ReLU(inplace=True) 353 | # self.cbam = GLEAM([int(planes * self.expansion/4), 354 | # int(planes * self.expansion//2), 355 | # planes * self.expansion], 16) 356 | self.downsample = downsample 357 | self.stride = stride 358 | 359 | def forward(self, x: Tensor) -> Tensor: 360 | # if self.attention: 361 | # x = self.cbam(x) 362 | identity = x 363 | 364 | out = self.conv1(x) 365 | out = self.bn1(out) 366 | out = self.relu(out) 367 | 368 | out = self.conv2(out) 369 | out = self.bn2(out) 370 | out = self.relu(out) 371 | 372 | out = self.conv3(out) 373 | out = self.bn3(out) 374 | 375 | if self.downsample is not None: 376 | identity = self.downsample(x) 377 | 378 | out += identity 379 | out = self.relu(out) 380 | 381 | return out 382 | 383 | 384 | class BN_layer(nn.Module): 385 | def __init__(self, 386 | block: Type[Union[BasicBlock, Bottleneck]], 387 | layers: int, 388 | groups: int = 1, 389 | width_per_group: int = 64, 390 | norm_layer: Optional[Callable[..., nn.Module]] = None, 391 | ): 392 | super(BN_layer, self).__init__() 393 | if norm_layer is None: 394 | norm_layer = nn.BatchNorm2d 395 | self._norm_layer = norm_layer 396 | self.groups = groups 397 | self.base_width = width_per_group 398 | self.inplanes = 256 * block.expansion 399 | self.dilation = 1 400 | self.bn_layer = self._make_layer(block, 512, layers, stride=2) 401 | 402 | self.conv1 = conv3x3(64 * block.expansion, 128 * block.expansion, 2) 403 | self.bn1 = norm_layer(128 * block.expansion) 404 | self.relu = nn.ReLU(inplace=True) 405 | self.conv2 = conv3x3(128 * block.expansion, 256 * block.expansion, 2) 406 | self.bn2 = norm_layer(256 * block.expansion) 407 | self.conv3 = conv3x3(128 * block.expansion, 256 * block.expansion, 2) 408 | self.bn3 = norm_layer(256 * block.expansion) 409 | 410 | self.conv4 = conv1x1(1024 * block.expansion, 512 * 
block.expansion, 1) 411 | self.bn4 = norm_layer(512 * block.expansion) # conv4/bn4 are initialized but not used in _forward_impl below 412 | 413 | for m in self.modules(): 414 | if isinstance(m, nn.Conv2d): 415 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 416 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 417 | nn.init.constant_(m.weight, 1) 418 | nn.init.constant_(m.bias, 0) 419 | 420 | def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int, 421 | stride: int = 1, dilate: bool = False) -> nn.Sequential: 422 | norm_layer = self._norm_layer 423 | downsample = None 424 | previous_dilation = self.dilation 425 | if dilate: 426 | self.dilation *= stride 427 | stride = 1 428 | if stride != 1 or self.inplanes != planes * block.expansion: 429 | downsample = nn.Sequential( 430 | conv1x1(self.inplanes * 3, planes * block.expansion, stride), # * 3: three feature scales are concatenated in _forward_impl 431 | norm_layer(planes * block.expansion), 432 | ) 433 | 434 | layers = [] 435 | layers.append(block(self.inplanes * 3, planes, stride, downsample, self.groups, 436 | self.base_width, previous_dilation, norm_layer)) 437 | self.inplanes = planes * block.expansion 438 | for _ in range(1, blocks): 439 | layers.append(block(self.inplanes, planes, groups=self.groups, 440 | base_width=self.base_width, dilation=self.dilation, 441 | norm_layer=norm_layer)) 442 | 443 | return nn.Sequential(*layers) 444 | 445 | def _forward_impl(self, x: List[Tensor]) -> Tensor: 446 | # See note [TorchScript super()] 447 | # x = self.cbam(x) 448 | l1 = self.relu(self.bn2(self.conv2(self.relu(self.bn1(self.conv1(x[0])))))) # project layer1 features down to the layer3 scale 449 | l2 = self.relu(self.bn3(self.conv3(x[1]))) # project layer2 features down to the layer3 scale 450 | feature = torch.cat([l1, l2, x[2]], 1) 451 | output = self.bn_layer(feature) 452 | # x = self.avgpool(feature_d) 453 | # x = torch.flatten(x, 1) 454 | # x = self.fc(x) 455 | 456 | return output.contiguous() 457 | 458 | def forward(self, x: List[Tensor]) -> Tensor: 459 | return self._forward_impl(x) 460 | 461 | 462 | def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 463 | r"""ResNet-18 model from 464 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_. 465 | Args: 466 | pretrained (bool): If True, returns a model pre-trained on ImageNet 467 | progress (bool): If True, displays a progress bar of the download to stderr 468 | """ 469 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 470 | **kwargs), BN_layer(AttnBasicBlock, 2, **kwargs) # returns an (encoder, bottleneck-embedding) pair 471 |
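# Shape walk-through (an illustrative note, assuming a 224x224 input to the
# resnet18 encoder above): layer1/2/3 emit [B, 64, 56, 56], [B, 128, 28, 28]
# and [B, 256, 14, 14]. BN_layer(AttnBasicBlock, 2) projects the first two
# maps to 256 channels at 14x14 resolution, concatenates all three into
# [B, 768, 14, 14], and its bn_layer stage reduces that to [B, 512, 7, 7].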
472 | 473 | def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 474 | r"""ResNet-34 model from 475 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_. 476 | Args: 477 | pretrained (bool): If True, returns a model pre-trained on ImageNet 478 | progress (bool): If True, displays a progress bar of the download to stderr 479 | """ 480 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 481 | **kwargs), BN_layer(AttnBasicBlock, 3, **kwargs) 482 | 483 | 484 | def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 485 | r"""ResNet-50 model from 486 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_. 487 | Args: 488 | pretrained (bool): If True, returns a model pre-trained on ImageNet 489 | progress (bool): If True, displays a progress bar of the download to stderr 490 | """ 491 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 492 | **kwargs), BN_layer(AttnBottleneck, 3, **kwargs) 493 | 494 | 495 | def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 496 | r"""ResNet-101 model from 497 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_. 498 | Args: 499 | pretrained (bool): If True, returns a model pre-trained on ImageNet 500 | progress (bool): If True, displays a progress bar of the download to stderr 501 | """ 502 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 503 | **kwargs), BN_layer(AttnBottleneck, 3, **kwargs) # Bottleneck features (expansion 4) require AttnBottleneck, as in resnet50/152 504 | 505 | 506 | def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 507 | r"""ResNet-152 model from 508 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_. 509 | Args: 510 | pretrained (bool): If True, returns a model pre-trained on ImageNet 511 | progress (bool): If True, displays a progress bar of the download to stderr 512 | """ 513 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 514 | **kwargs), BN_layer(AttnBottleneck, 3, **kwargs) 515 | 516 | 517 | def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 518 | r"""ResNeXt-50 32x4d model from 519 | `"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/abs/1611.05431>`_. 520 | Args: 521 | pretrained (bool): If True, returns a model pre-trained on ImageNet 522 | progress (bool): If True, displays a progress bar of the download to stderr 523 | """ 524 | kwargs['groups'] = 32 525 | kwargs['width_per_group'] = 4 526 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 527 | pretrained, progress, **kwargs) # note: returns the encoder only, without a BN_layer 528 | 529 | 530 | def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 531 | r"""ResNeXt-101 32x8d model from 532 | `"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/abs/1611.05431>`_. 533 | Args: 534 | pretrained (bool): If True, returns a model pre-trained on ImageNet 535 | progress (bool): If True, displays a progress bar of the download to stderr 536 | """ 537 | kwargs['groups'] = 32 538 | kwargs['width_per_group'] = 8 539 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 540 | pretrained, progress, **kwargs) # note: returns the encoder only, without a BN_layer 541 | 542 | 543 | def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 544 | r"""Wide ResNet-50-2 model from 545 | `"Wide Residual Networks" <https://arxiv.org/abs/1605.07146>`_. 546 | The model is the same as ResNet except for the bottleneck number of channels, 547 | which is twice as large in every block. The number of channels in outer 1x1 548 | convolutions is the same, e.g. the last block in ResNet-50 has 2048-512-2048 549 | channels, and in Wide ResNet-50-2 it has 2048-1024-2048. 550 | Args: 551 | pretrained (bool): If True, returns a model pre-trained on ImageNet 552 | progress (bool): If True, displays a progress bar of the download to stderr 553 | """ 554 | kwargs['width_per_group'] = 64 * 2 555 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], 556 | pretrained, progress, **kwargs), BN_layer(AttnBottleneck, 3, **kwargs) 557 | 558 | 559 | def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 560 | r"""Wide ResNet-101-2 model from 561 | `"Wide Residual Networks" <https://arxiv.org/abs/1605.07146>`_.
562 | The model is the same as ResNet except for the bottleneck number of channels, 563 | which is twice as large in every block. The number of channels in outer 1x1 564 | convolutions is the same, e.g. the last block in ResNet-50 has 2048-512-2048 565 | channels, and in Wide ResNet-50-2 it has 2048-1024-2048. 566 | Args: 567 | pretrained (bool): If True, returns a model pre-trained on ImageNet 568 | progress (bool): If True, displays a progress bar of the download to stderr 569 | """ 570 | kwargs['width_per_group'] = 64 * 2 571 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], 572 | pretrained, progress, **kwargs), BN_layer(AttnBottleneck, 3, **kwargs) 573 | 574 | 575 | -------------------------------------------------------------------------------- /utils/visualization.py: -------------------------------------------------------------------------------- 1 | from sklearn.metrics import roc_curve, auc, roc_auc_score 2 | import torch 3 | from torch import nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | from sklearn.manifold import TSNE 7 | from sklearn.utils import shuffle 8 | import matplotlib.pyplot as plt 9 | from typing import Any, Dict, Tuple, Union 10 | import os 11 | from copy import deepcopy 12 | import matplotlib.ticker as ticker 13 | 14 | 15 | 16 | # @staticmethod 17 | def plot_tsne(labels, embeds, defect_name=None, save_path=None, **kwargs: Any): 18 | """t-SNE visualization of test and train embeddings. 19 | Args: 20 | labels (Tensor): labels of test and train 21 | embeds (Tensor): embeds of test and train 22 | defect_name ([str], optional): same as in roc_auc. Defaults to None. 23 | save_path ([str], optional): same as in roc_auc. Defaults to None. 24 | kwargs (Dict[str, Any]): hyperparameters of t-SNE, which change the final result 25 | n_iter (int): > 250, default = 1000 26 | learning_rate (float): (10-1000), default = 100 27 | perplexity (float): (5-50), default = 28 28 | early_exaggeration (float): change it when not converging, default = 12 29 | angle (float): (0.2-0.8), default = 0.3 30 | init (str): "random" or "pca", default = "pca" 31 | """ 32 | tsne = TSNE( 33 | n_components=2, 34 | verbose=1, 35 | n_iter=kwargs.get("n_iter", 1000), 36 | learning_rate=kwargs.get("learning_rate", 100), 37 | perplexity=kwargs.get("perplexity", 28), 38 | early_exaggeration=kwargs.get("early_exaggeration", 12), 39 | angle=kwargs.get("angle", 0.3), 40 | init=kwargs.get("init", "pca"), 41 | ) 42 | embeds, labels = shuffle(embeds, labels) 43 | tsne_results = tsne.fit_transform(embeds) 44 | 45 | cmap = plt.cm.get_cmap("spring") 46 | colors = np.vstack((np.array([[0, 1., 0, 1.]]), cmap([0, 256//3, (2*256)//3]))) # RGBA green for "good", colormap entries for anomalies 47 | legends = ["good", "anomaly"] 48 | (_, ax) = plt.subplots(1) 49 | plt.title(f't-SNE: {defect_name}') 50 | for label in torch.unique(labels): 51 | res = tsne_results[torch.where(labels==label)] 52 | ax.plot(*res.T, marker="*", linestyle="", ms=5, label=legends[label], color=colors[label]) 53 | ax.legend(loc="best") 54 | plt.xticks([]) 55 | plt.yticks([]) 56 | 57 | save_images = save_path if save_path else './tsne_results' 58 | os.makedirs(save_images, exist_ok=True) 59 | image_path = os.path.join(save_images, defect_name+'_tsne.pdf') if defect_name else os.path.join(save_images, 'tsne.pdf') 60 | plt.savefig(image_path) 61 | plt.close() 62 | return 63 | 64 | 65 | def compare_histogram(scores, classes, start=0, thresh=2, interval=1, n_bins=64, name=None, save_path=None): 66 | classes = deepcopy(classes) 67 | classes[classes > 0] = 1 # binarize: every defect type counts as anomalous 68 | scores[scores > thresh] = thresh # clip the right tail so the histogram stays readable 69 | bins = np.linspace(np.min(scores), np.max(scores), n_bins) 70 | scores_norm = scores[classes == 0] 71 | scores_ano = scores[classes == 1] 72 | 73 | plt.clf() 74 | plt.figure(figsize=(7, 5), dpi=120) 75 | 76 | plt.hist(scores_norm, bins, alpha=0.5, density=True, label='non-defects', color='cyan', edgecolor="black") 77 | plt.hist(scores_ano, bins, alpha=0.5, density=True, label='defects', color='crimson', edgecolor="black") 78 | plt.gca().yaxis.set_major_formatter(ticker.FormatStrFormatter('%.2f')) 79 | ticks = np.linspace(start, thresh, interval) 80 | labels = [str(i) for i in ticks[:-1]] + ['>' + str(thresh)] # prepared for custom x-tick labels, but currently unused: the x-ticks are cleared below 81 | 82 | save_images = save_path if save_path else './his_results1' 83 | os.makedirs(save_images, exist_ok=True) 84 | image_path = os.path.join(save_images, name + '_his.pdf') if name else os.path.join(save_images, 'his.pdf') 85 | 86 | plt.yticks(rotation=24) 87 | plt.xlabel(r'$-\log(p(z))$', fontsize=10) 88 | plt.tick_params(labelsize=10) 89 | plt.autoscale() 90 | plt.xticks([], []) 91 | plt.savefig(image_path, bbox_inches='tight', pad_inches=0) 92 | plt.close() # free the figure; repeated calls otherwise accumulate open figures 93 | 94 | 95 | def cal_anomaly_map(fs_list, ft_list, out_size=224, amap_mode='mul'): 96 | if amap_mode == 'mul': 97 | anomaly_map = np.ones([out_size, out_size]) 98 | else: 99 | anomaly_map = np.zeros([out_size, out_size]) 100 | a_map_list = [] 101 | for i in range(len(ft_list)): 102 | fs = fs_list[i] 103 | ft = ft_list[i] 104 | #fs_norm = F.normalize(fs, p=2) 105 | #ft_norm = F.normalize(ft, p=2) 106 | a_map = 1 - F.cosine_similarity(fs, ft) # per-location distance between student and teacher features 107 | a_map = torch.unsqueeze(a_map, dim=1) 108 | a_map = F.interpolate(a_map, size=out_size, mode='bilinear', align_corners=True) 109 | a_map = a_map[0, 0, :, :].to('cpu').detach().numpy() # assumes batch size 1 110 | a_map_list.append(a_map) 111 | if amap_mode == 'mul': 112 | anomaly_map *= a_map 113 | else: 114 | anomaly_map += a_map 115 | return anomaly_map, a_map_list 116 | 117 | --------------------------------------------------------------------------------
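As a quick check of how the pieces above compose, here is a minimal sketch (not a file from the repository) that builds the (encoder, BN_layer) pair returned by `models/resnet.py` and feeds the encoder's three feature scales to `cal_anomaly_map` from `utils/visualization.py`. In the reverse-distillation method the student features `fs_list` would presumably come from the decoder in `utils/de_resnet.py`, which is not shown here, so the teacher features are reused as a stand-in:

```
import torch
from models.resnet import wide_resnet50_2
from utils.visualization import cal_anomaly_map

# wide_resnet50_2 returns a (ResNet encoder, BN_layer) pair.
encoder, bn = wide_resnet50_2(pretrained=False)
encoder.eval()
bn.eval()

x = torch.randn(1, 3, 224, 224)  # batch size 1, as cal_anomaly_map assumes
with torch.no_grad():
    ft_list = encoder(x)         # [layer1, layer2, layer3] feature maps
    embedding = bn(ft_list)      # fused multi-scale embedding, [1, 2048, 7, 7]
    fs_list = [f.clone() for f in ft_list]  # stand-in for decoder outputs

anomaly_map, _ = cal_anomaly_map(fs_list, ft_list, out_size=224)
print(embedding.shape, anomaly_map.shape)  # torch.Size([1, 2048, 7, 7]) (224, 224)
```

With identical student and teacher features the map is all zeros; in actual use, regions where the decoder fails to reconstruct the teacher features score high.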