├── .gitattributes ├── LICENSE ├── README.md ├── code ├── datasets │ ├── adecuate_BRATS.py │ ├── adecuate_PhysionNet_ICH.py │ ├── datasets.py │ └── utils.py ├── evaluation │ ├── metrics.py │ ├── readme │ └── utils.py ├── main.py ├── methods │ ├── losses │ │ ├── losses.py │ │ └── readme │ ├── train.py │ └── trainers │ │ ├── AMCons.py │ │ ├── ae.py │ │ ├── anoVAEGAN.py │ │ ├── fanoGAN.py │ │ ├── gradCAMCons.py │ │ ├── gradCons.py │ │ ├── histEqualization.py │ │ └── vae.py └── models │ └── models.py ├── data ├── MICCAI_BraTS_2019_Data_Training │ └── README.md └── PhysioNet-ICH │ ├── README.md │ └── readme └── images └── visualizations.png /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Julio Silva-Rodríguez 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Constrained Unsupervised Anomaly Segmentation of Brain Lesions 2 | 3 | This repository contains code for unsupervised anomaly segmentation in brain lesions. Specifically, the implemented methods aim to constrain the optimization process to force a VAE to homogenize the activations produced in normal samples. 4 | 5 | If you find these methods useful for your research, please consider citing: 6 | 7 | **J. Silva-Rodríguez, V. Naranjo and J. Dolz, "Looking at the whole picture: constrained unsupervised anomaly segmentation", in British Machine Vision Conference (BMVC), 2021.** [(paper)](https://www.bmvc2021-virtualconference.com/assets/papers/1011.pdf)[(conference)](https://www.bmvc2021-virtualconference.com/conference/papers/paper_1011.html) 8 | 9 | **J. Silva-Rodríguez, V. Naranjo and J. Dolz, "Constrained unsupervised anomaly segmentation", Medical Image Analysis, vol. 80, p. 102526, 2022.** [(paper)](https://www.sciencedirect.com/science/article/pii/S1361841522001736) 10 | 11 | ## GRADCAMCons: looking at the whole picture via size constraints 12 | 13 | ``` 14 | python main.py --dir_out ../data/results/gradCAMCons/ --method gradCAMCons --learning_rate 0.00001 --wkl 1 --wae 1000 --t 10 15 | ``` 16 | 17 | ## AMCons: entropy maximization on activation maps 18 | 19 | ``` 20 | python main.py --dir_out ../data/results/AMCon/ --method camCons --learning_rate 0.0001 --wkl 10 --wH 0.1 21 | ``` 22 | 23 | ## Visualizations 24 | 25 |

26 | 27 |
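As a rough illustration of the idea named in the AMCons heading above (entropy maximization on activation maps), the sketch below computes the Shannon entropy of a softmax-normalized activation map with NumPy. The function name `activation_map_entropy`, the softmax normalization and the toy 8x8 maps are assumptions made for this example; it is not taken from the repository's code and may differ from the loss actually used for training.

```python
import numpy as np


def activation_map_entropy(activation, eps=1e-12):
    """Shannon entropy of a 2-D activation map (hypothetical helper).

    The map is flattened and turned into a probability distribution with a
    softmax over all spatial positions; this normalization is an assumption.
    """
    a = np.asarray(activation, dtype=np.float64).ravel()
    a = a - a.max()                      # shift for numerical stability
    p = np.exp(a) / np.exp(a).sum()      # softmax over spatial positions
    return float(-(p * np.log(p + eps)).sum())


# Toy inputs: a homogeneous map and a map with one dominant activation
flat = np.zeros((8, 8))
peaked = np.zeros((8, 8))
peaked[3, 3] = 50.0

print(activation_map_entropy(flat))    # close to log(64), the maximum
print(activation_map_entropy(peaked))  # close to 0
```

A homogeneous map gives an entropy near log(64), the maximum for 64 positions, while a single dominant activation drives it toward zero. Maximizing this quantity on normal samples therefore pushes the activations toward a homogeneous map.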

28 | 29 | ## Contact 30 | For further questions or details, please reach out directly to Julio Silva-Rodríguez (jusiro95@gmail.com). 31 | -------------------------------------------------------------------------------- /code/datasets/adecuate_BRATS.py: -------------------------------------------------------------------------------- 1 | import nibabel as nib 2 | import os 3 | import numpy as np 4 | import random 5 | import argparse 6 | import cv2 7 | 8 | np.random.seed(42) 9 | random.seed(42) 10 | 11 | 12 | def adecuate_BRATS(args): 13 | 14 | dir_dataset = args.dir_dataset 15 | dir_out = args.dir_out 16 | scan = args.scan 17 | nSlices = args.nSlices 18 | 19 | partitions = ['train', 'val', 'test'] 20 | Ncases = np.array([271, 32, 32]) 21 | 22 | if not os.path.isdir(dir_out): 23 | os.mkdir(dir_out) 24 | if not os.path.isdir(dir_out + '/' + scan + '/'): 25 | os.mkdir(dir_out + '/' + scan + '/') 26 | 27 | cases_LGG = os.listdir(dir_dataset + 'LGG/') 28 | cases_LGG = [dir_dataset + 'LGG/' + iCase for iCase in cases_LGG if iCase != '.DS_Store'] 29 | 30 | cases_HGG = os.listdir(dir_dataset + 'HGG/') 31 | cases_HGG = [dir_dataset + 'HGG/' + iCase for iCase in cases_HGG if iCase != '.DS_Store'] 32 | 33 | cases = cases_LGG + cases_HGG 34 | 35 | random.shuffle(cases) 36 | 37 | for iPartition in np.arange(0, len(partitions)): 38 | 39 | if not os.path.isdir(dir_out + '/' + scan + '/' + partitions[iPartition]): 40 | os.mkdir(dir_out + '/' + scan + '/' + partitions[iPartition]) 41 | if not os.path.isdir(dir_out + '/' + scan + '/' + partitions[iPartition] + '/benign'): 42 | os.mkdir(dir_out + '/' + scan + '/' + partitions[iPartition] + '/benign') 43 | if not os.path.isdir(dir_out + '/' + scan + '/' + partitions[iPartition] + '/malign'): 44 | os.mkdir(dir_out + '/' + scan + '/' + partitions[iPartition] + '/malign') 45 | if not os.path.isdir(dir_out + '/' + scan + '/' + partitions[iPartition] + '/ground_truth'): 46 | os.mkdir(dir_out + '/' + scan + '/' + partitions[iPartition] + 
'/ground_truth') 47 | 48 | cases_partition = cases[np.sum(Ncases[:iPartition]):np.sum(Ncases[:iPartition+1])] 49 | 50 | c = 0 51 | for iCase in cases_partition: 52 | c += 1 53 | print(str(c) + '/' + str(len(cases_partition))) 54 | 55 | img_path = iCase + '/' + iCase.split('/')[-1] + '_' + scan + '.nii.gz' 56 | mask_path = iCase + '/' + iCase.split('/')[-1] + '_seg.nii.gz' 57 | 58 | img = nib.load(img_path) 59 | img = (img.get_fdata())[:, :, :] 60 | img = (img/img.max())*255 61 | img = img.astype(np.uint8) 62 | 63 | mask = nib.load(mask_path) 64 | mask = (mask.get_fdata()) 65 | mask[mask > 0] = 255 66 | mask = mask.astype(np.uint8) 67 | 68 | for iSlice in np.arange(round(img.shape[-1]/2) - nSlices, round(img.shape[-1]/2) + nSlices): 69 | filename = iCase.split('/')[-1] + '_' + str(iSlice) + '.jpg' 70 | 71 | i_image = img[:, :, iSlice] 72 | i_mask = mask[:, :, iSlice] 73 | 74 | if np.any(i_mask == 255): 75 | label = 'malign' 76 | cv2.imwrite(dir_out + '/' + scan + '/' + partitions[iPartition] + '/ground_truth/' + filename, i_mask) 77 | else: 78 | label = 'benign' 79 | 80 | cv2.imwrite(dir_out + '/' + scan + '/' + partitions[iPartition] + '/' + label + '/' + filename, i_image) 81 | 82 | 83 | if __name__ == '__main__': 84 | parser = argparse.ArgumentParser() 85 | parser.add_argument("--dir_dataset", default='../data/MICCAI_BraTS_2019_Data_Training/', type=str) 86 | parser.add_argument("--dir_out", default='../data/BRATS_10slices/', type=str) 87 | parser.add_argument("--scan", default='flair', type=str) 88 | parser.add_argument("--nSlices", default=5, type=int) 89 | 90 | args = parser.parse_args() 91 | adecuate_BRATS(args) 92 | 93 | 94 | -------------------------------------------------------------------------------- /code/datasets/adecuate_PhysionNet_ICH.py: -------------------------------------------------------------------------------- 1 | import nibabel as nib 2 | import os 3 | import numpy as np 4 | import random 5 | import argparse 6 | import cv2 7 | from scipy 
import ndimage 8 | from skimage import measure 9 | from matplotlib import pyplot as plt 10 | import SimpleITK as sitk 11 | 12 | 13 | def volume_registration(fixed_image, moving_image, mask=None): 14 | 15 | fixed_image = sitk.GetImageFromArray(fixed_image) 16 | moving_image = sitk.GetImageFromArray(moving_image) 17 | # Initial transformation 18 | ''' 19 | transform_to_displacment_field_filter = sitk.TransformToDisplacementFieldFilter() 20 | transform_to_displacment_field_filter.SetReferenceImage(fixed_image) 21 | initial_transform = sitk.DisplacementFieldTransform( 22 | transform_to_displacment_field_filter.Execute(sitk.Transform(2, sitk.sitkIdentity))) 23 | ''' 24 | initial_transform = sitk.CenteredTransformInitializer(fixed_image, 25 | moving_image, 26 | sitk.Euler3DTransform(), 27 | sitk.CenteredTransformInitializerFilter.GEOMETRY) 28 | 29 | registration_method = sitk.ImageRegistrationMethod() 30 | # Similarity metric settings. 31 | #registration_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50) 32 | registration_method.SetMetricAsCorrelation() 33 | registration_method.SetMetricSamplingStrategy(registration_method.RANDOM) 34 | registration_method.SetMetricSamplingPercentage(0.01) 35 | 36 | registration_method.SetInterpolator(sitk.sitkLinear) 37 | # Optimizer settings. 38 | registration_method.SetOptimizerAsGradientDescent(learningRate=1.0, numberOfIterations=100, 39 | convergenceMinimumValue=1e-6, convergenceWindowSize=10) 40 | registration_method.SetOptimizerScalesFromPhysicalShift() 41 | # Setup for the multi-resolution framework. 42 | registration_method.SetShrinkFactorsPerLevel(shrinkFactors=[4, 2, 1]) 43 | registration_method.SetSmoothingSigmasPerLevel(smoothingSigmas=[2, 1, 0]) 44 | registration_method.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn() 45 | # Don't optimize in-place, we would possibly like to run this cell multiple times. 
46 | registration_method.SetInitialTransform(initial_transform, inPlace=False) 47 | 48 | # Apply transformation 49 | final_transform = registration_method.Execute(sitk.Cast(fixed_image, sitk.sitkFloat32), 50 | sitk.Cast(moving_image, sitk.sitkFloat32)) 51 | 52 | moving_resampled = sitk.Resample(moving_image, fixed_image, final_transform, sitk.sitkLinear, 0.0, 53 | moving_image.GetPixelID()) 54 | out = sitk.GetArrayFromImage(moving_resampled) 55 | 56 | if mask is None: 57 | return out 58 | else: # Apply transformation to gt mask 59 | mask_mov = sitk.GetImageFromArray(mask) 60 | moving_resampled = sitk.Resample(mask_mov, fixed_image, final_transform, sitk.sitkNearestNeighbor, 61 | 0, mask_mov.GetPixelID()) 62 | out_mask = sitk.GetArrayFromImage(moving_resampled) 63 | 64 | return out, out_mask 65 | 66 | 67 | def preprocess_vol_ICH(vol, mask): 68 | w_level = 40 69 | w_width = 120 70 | 71 | # Intensity normalization (CT windowing with level 40 and width 120) 72 | 73 | vol = (vol - (w_level - (w_width / 2))) * (255 / (w_width)) 74 | vol[vol < 0] = 0 75 | vol[vol > 255] = 255 76 | 77 | # Get tissue mask 78 | 79 | tissue_mask = np.ones(vol.shape) 80 | tissue_mask[vol == 0] = 0 81 | tissue_mask[vol == 255] = 0 82 | tissue_mask = ndimage.binary_opening(tissue_mask, structure=np.ones((10, 10, 1))).astype(tissue_mask.dtype) 83 | tissue_mask = ndimage.binary_erosion(tissue_mask, structure=np.ones((5, 5, 1))).astype(tissue_mask.dtype) 84 | 85 | # Keep the largest object in each CT slice 86 | for iSlice in np.arange(0, vol.shape[-1]): 87 | tissue_mask_i = tissue_mask[:, :, iSlice] 88 | 89 | if np.max(tissue_mask_i) > 0: 90 | 91 | labels = measure.label(tissue_mask_i) 92 | props = measure.regionprops(labels) 93 | 94 | areas = [i_prop.area for i_prop in props] 95 | labels_given = [i_prop.label for i_prop in props] 96 | idx_areas = np.argsort(areas) 97 | 98 | if np.mean(tissue_mask_i[labels == (labels_given[idx_areas[-1]])]) != 0: 99 | label = labels_given[idx_areas[-1]] 100 | else: 101 | label = labels_given[idx_areas[-2]] 102 | 103 | 
tissue_mask_i = labels == (label) 104 | tissue_mask[:, :, iSlice] = tissue_mask_i 105 | 106 | vol = vol * tissue_mask 107 | mask = mask * tissue_mask 108 | 109 | return vol, mask 110 | 111 | np.random.seed(42) 112 | random.seed(42) 113 | 114 | dir_dataset = '../data/PhysioNet-ICH/' 115 | dir_out = '../data/PhysioNetICH_5slices_registered_new/' 116 | scan = 'CT' 117 | nSlices = 5 118 | partitions = ['train', 'test'] 119 | Ncases = np.array([50, 25]) 120 | 121 | if not os.path.isdir(dir_out): 122 | os.mkdir(dir_out) 123 | if not os.path.isdir(dir_out + '/' + scan + '/'): 124 | os.mkdir(dir_out + '/' + scan + '/') 125 | 126 | cases = os.listdir(dir_dataset + 'ct_scans/') 127 | cases = [dir_dataset + 'ct_scans/' + iCase for iCase in cases if iCase != '.DS_Store'] 128 | 129 | random.shuffle(cases) 130 | 131 | for iPartition in np.arange(0, len(partitions)): 132 | 133 | if not os.path.isdir(dir_out + '/' + scan + '/' + partitions[iPartition]): 134 | os.mkdir(dir_out + '/' + scan + '/' + partitions[iPartition]) 135 | if not os.path.isdir(dir_out + '/' + scan + '/' + partitions[iPartition] + '/benign'): 136 | os.mkdir(dir_out + '/' + scan + '/' + partitions[iPartition] + '/benign') 137 | if not os.path.isdir(dir_out + '/' + scan + '/' + partitions[iPartition] + '/malign'): 138 | os.mkdir(dir_out + '/' + scan + '/' + partitions[iPartition] + '/malign') 139 | if not os.path.isdir(dir_out + '/' + scan + '/' + partitions[iPartition] + '/ground_truth'): 140 | os.mkdir(dir_out + '/' + scan + '/' + partitions[iPartition] + '/ground_truth') 141 | 142 | cases_partition = cases[np.sum(Ncases[:iPartition]):np.sum(Ncases[:iPartition+1])] 143 | 144 | # Load volume reference 145 | vol_ref = nib.load('../data/PhysioNet-ICH/ct_scans/071.nii') 146 | vol_ref = (vol_ref.get_fdata()) 147 | mask_ref = nib.load('../data/PhysioNet-ICH/masks/071.nii') 148 | mask_ref = (mask_ref.get_fdata()) 149 | mask_ref[mask_ref > 0] = 255 150 | vol_ref, mask_ref = preprocess_vol_ICH(vol_ref, mask_ref) 151 | 
152 | c = 0 153 | for iCase in cases: 154 | c += 1 155 | 156 | img_path = iCase 157 | mask_path = iCase.replace('ct_scans', 'masks') 158 | 159 | # Load volume and mask 160 | img = nib.load(img_path) 161 | img = (img.get_fdata())[:, :, :] 162 | mask = nib.load(mask_path) 163 | mask = (mask.get_fdata()) 164 | mask[mask > 0] = 255 165 | 166 | print(str(c) + '/' + str(len(cases)) + ' || ' + 'slices: ' + str(img.shape[-1])) 167 | 168 | # Preprocess 169 | img, mask = preprocess_vol_ICH(img, mask) 170 | 171 | img = img[:, :, round(img.shape[-1] / 2) - nSlices+5:round(img.shape[-1] / 2) + nSlices+5] 172 | mask = mask[:, :, round(mask.shape[-1] / 2) - nSlices+5:round(mask.shape[-1] / 2) + nSlices+5] 173 | 174 | if np.max(mask) == 255: 175 | part = 'test' 176 | else: 177 | part = 'train' 178 | 179 | mask = mask.astype(np.uint8) 180 | 181 | for iSlice in np.arange(0, nSlices*2): 182 | filename = iCase.split('/')[-1].split('.')[0] + '_' + str(iSlice) + '.jpg' 183 | 184 | i_image = img[:, :, iSlice] 185 | i_mask = mask[:, :, iSlice] 186 | 187 | if np.any(i_mask == 255): 188 | label = 'malign' 189 | cv2.imwrite(dir_out + '/' + scan + '/' + part + '/ground_truth/' + filename, i_mask) 190 | else: 191 | label = 'benign' 192 | 193 | cv2.imwrite(dir_out + '/' + scan + '/' + part + '/' + label + '/' + filename, i_image) 194 | -------------------------------------------------------------------------------- /code/datasets/datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | from datasets.utils import * 4 | 5 | 6 | # ------------------------------------------ 7 | # BRATS dataset 8 | class TestDataset(object): 9 | 10 | def __init__(self, dir_dataset, item, partition, input_shape=(1, 224, 224), channel_first=True, norm='max', 11 | histogram_matching=True, filter_volumes=False): 12 | # Init properties 13 | self.dir_dataset = dir_dataset 14 | self.dir_datasets = dir_dataset 15 | self.item = item 16 | 
self.partition = partition 17 | self.input_shape = input_shape 18 | self.channel_first = channel_first 19 | self.norm = norm 20 | self.histogram_matching = histogram_matching 21 | self.nchannels = self.input_shape[0] 22 | self.ref_image = {} 23 | self.filter_volumes = filter_volumes 24 | if 'BRATS' in dir_dataset: 25 | for iModality in np.arange(0, len(self.item)): 26 | x = Image.open( 27 | '../data/BRATS_5slices/' + self.item[iModality] + '/train/benign/' + 'BraTS19_CBICA_AWV_1_77.jpg') 28 | x = np.asarray(x) 29 | self.ref_image[iModality] = x[40:-40, 40:-40] 30 | 31 | # Select all files in partition 32 | self.images = [] 33 | for subdirs in ['benign', 'malign']: 34 | for images in os.listdir(self.dir_dataset + self.item[0] + '/' + self.partition + '/' + subdirs + '/'): 35 | if 'Thumbs.db' not in images: 36 | self.images.append(self.dir_dataset + 'modality' + '/' + self.partition + '/' + subdirs + '/' + images) 37 | 38 | # Get number of patients (volumes), and number of slices 39 | if 'BRATS' in dir_dataset: 40 | patients = [image.split('/')[-1][:-7] for image in self.images] 41 | else: 42 | patients = [image.split('/')[-1][:-6] for image in self.images] 43 | self.unique_patients = np.unique(patients) 44 | slices_per_volume = len(patients) // len(self.unique_patients) 45 | 46 | # Load images and masks 47 | self.X = np.zeros((len(self.unique_patients), len(self.item), slices_per_volume, self.nchannels, self.input_shape[1], self.input_shape[2])) 48 | self.M = np.zeros((len(self.unique_patients), slices_per_volume, self.nchannels, self.input_shape[1], self.input_shape[2])) 49 | self.Y = np.zeros((len(self.unique_patients), slices_per_volume)) 50 | for iPatient in np.arange(0, len(self.unique_patients)): 51 | 52 | slices_patient = list(np.sort([iSlice for iSlice in self.images if self.unique_patients[iPatient] in iSlice])) 53 | 54 | if 'ICH' in dir_dataset: 55 | indexes = np.array([int(id[-5]) for id in slices_patient]) 56 | idx = np.array(np.argsort(indexes)) 57 | 
slices_patient = [slices_patient[i] for i in idx] 58 | 59 | for iSlice in np.arange(0, slices_per_volume): 60 | 61 | for iModality in np.arange(0, len(self.item)): 62 | 63 | # Load image 64 | x = Image.open(slices_patient[iSlice].replace('modality', self.item[iModality])) 65 | x = np.asarray(x) 66 | 67 | # Normalization 68 | if 'BRATS' in dir_dataset: 69 | x = image_normalization(x, self.input_shape[-1], norm=self.norm, channels=self.nchannels, 70 | histogram_matching=self.histogram_matching, reference_image=self.ref_image[iModality], 71 | mask=False, channel_first=True) 72 | else: 73 | x = image_normalization(x, self.input_shape[-1], norm=self.norm, channels=self.nchannels, 74 | histogram_matching=False, 75 | mask=False, channel_first=True) 76 | self.X[iPatient, iModality, iSlice, :, :, :] = x 77 | 78 | # Load mask 79 | if 'malign' in slices_patient[iSlice]: 80 | mask_id = slices_patient[iSlice].replace('malign', 'ground_truth').replace('modality', self.item[iModality]) 81 | 82 | m = Image.open(mask_id) 83 | m = np.asarray(m) 84 | 85 | # Normalization 86 | m = image_normalization(m, self.input_shape[-1], norm=self.norm, channels=1, 87 | histogram_matching=False, 88 | reference_image=None, 89 | mask=True, channel_first=True) 90 | self.M[iPatient, iSlice, :, :, :] = m 91 | self.Y[iPatient, iSlice] = 1 92 | 93 | if self.filter_volumes: 94 | idx = np.squeeze(np.argwhere(np.sum(self.M, (1, 2, 3, 4)) / (slices_per_volume*self.input_shape[1]*self.input_shape[2]) > 0.001)) 95 | self.X = self.X[idx, :, :, :, :, :] 96 | self.M = self.M[idx, :, :, :, :] 97 | self.Y = self.Y[idx, :] 98 | self.images = list(np.array(self.images)[idx]) 99 | self.unique_patients = list(np.array(self.unique_patients)[idx]) 100 | 101 | if len(self.item) == 1: 102 | self.X = self.X[:, 0, :, :, :, :] 103 | 104 | 105 | class MultiModalityDataset(object): 106 | 107 | def __init__(self, dir_datasets, modalities, input_shape=(3, 512, 512), channel_first=True, norm='max', 108 | hist_match=True, 
weak_supervision=False): 109 | 110 | 'Internal states initialization' 111 | self.dir_datasets = dir_datasets 112 | self.modalities = modalities 113 | self.input_shape = input_shape 114 | self.channel_first = channel_first 115 | self.norm = norm 116 | self.nChannels = input_shape[0] 117 | self.hist_match = hist_match 118 | self.weak_supervision = weak_supervision 119 | self.ref_image = {} 120 | if 'BRATS' in dir_datasets: 121 | for iModality in np.arange(0, len(modalities)): 122 | x = Image.open('../data/BRATS_5slices/' + modalities[iModality] + '/train/benign/' + 'BraTS19_CBICA_AWV_1_77.jpg') 123 | x = np.asarray(x) 124 | self.ref_image[iModality] = x 125 | self.train_images = [] 126 | self.test_images = [] 127 | 128 | # Paths for normal (train) and anomalous (test) data 129 | name_normal = '/train/benign/' 130 | name_anomaly = '/test/malign/' 131 | 132 | # Get train images 133 | train_images = os.listdir(dir_datasets + modalities[0] + name_normal) 134 | # Remove other files 135 | train_images = [train_images[i] for i in range(train_images.__len__()) if train_images[i] != 'Thumbs.db'] 136 | for iImage in train_images: 137 | self.train_images.append(dir_datasets + 'modality' + name_normal + iImage) 138 | 139 | # Get test images 140 | test_images = os.listdir(dir_datasets + modalities[0] + name_anomaly) 141 | # Remove other files 142 | test_images = [test_images[i] for i in range(test_images.__len__()) if test_images[i] != 'Thumbs.db'] 143 | for iImage in test_images: 144 | self.test_images.append(dir_datasets + 'modality' + name_anomaly + iImage) 145 | 146 | self.train_indexes = np.arange(0, len(self.train_images)) 147 | self.test_indexes = np.arange(0, len(self.test_images)) + len(self.train_images) 148 | self.images = self.train_images + self.test_images 149 | 150 | # Pre-allocate images 151 | self.X = np.zeros((len(self.images), len(self.modalities), input_shape[0], input_shape[1], input_shape[2]), dtype=np.float32) 152 | self.M = np.zeros((len(self.images), 1, input_shape[1], 
input_shape[2]), dtype=np.float32) 153 | self.Y = np.zeros((len(self.images), 2), dtype=np.float32) 154 | 155 | # Load, and normalize images 156 | print('[INFO]: Loading training images...') 157 | for i in np.arange(len(self.images)): 158 | for iModality in np.arange(0, len(self.modalities)): 159 | print(str(i) + '/' + str(len(self.images)), end='\r') 160 | 161 | # Load image 162 | x = Image.open(self.images[i].replace('modality', modalities[iModality])) 163 | x = np.asarray(x) 164 | 165 | # Normalization 166 | if 'BRATS' in dir_datasets: 167 | x = image_normalization(x, self.input_shape[-1], norm=self.norm, channels=self.nChannels, 168 | histogram_matching=self.hist_match, reference_image=self.ref_image[iModality], 169 | mask=False, channel_first=True) 170 | else: 171 | x = image_normalization(x, self.input_shape[-1], norm=self.norm, channels=self.nChannels, 172 | histogram_matching=False, 173 | mask=False, channel_first=True) 174 | self.X[i, iModality, :, :, :] = x 175 | 176 | if 'benign' in self.images[i]: 177 | self.Y[i, :] = np.array([1, 0]) 178 | else: 179 | self.Y[i, :] = np.array([0, 1]) 180 | 181 | mask_id = self.images[i].replace('malign', 'ground_truth').replace('modality', modalities[iModality]) 182 | 183 | y = Image.open(mask_id) 184 | y = np.asarray(y) 185 | if len(y.shape) == 3: 186 | y = y[:, :, 0] 187 | # Normalization 188 | y = image_normalization(y, self.input_shape[-1], norm=self.norm, channels=1, 189 | histogram_matching=False, 190 | reference_image=None, 191 | mask=True, channel_first=True) 192 | self.M[i, :, :, :] = y 193 | 194 | print('[INFO]: Images loaded') 195 | 196 | def __len__(self): 197 | 'Denotes the total number of samples' 198 | return len(self.train_indexes) 199 | 200 | def __getitem__(self, index): 201 | 'Generates one sample of data' 202 | 203 | x = self.X[index, :, :, :, :] 204 | y = self.Y[index, :] 205 | 206 | if len(self.modalities) == 1: 207 | x = x[0, :, :, :] 208 | 209 | return x, y 210 | 211 | # 
------------------------------------------ 212 | # MVTEC dataset 213 | 214 | 215 | class MVTECDataset(object): 216 | 217 | def __init__(self, dir_datasets, modalities, input_shape=(3, 512, 512), channel_first=True, norm='max', 218 | weak_supervision=False, partition='train'): 219 | 220 | 'Internal states initialization' 221 | self.dir_datasets = dir_datasets 222 | self.modalities = modalities 223 | self.input_shape = input_shape 224 | self.channel_first = channel_first 225 | self.norm = norm 226 | self.nChannels = input_shape[0] 227 | self.weak_supervision = weak_supervision 228 | self.images = [] 229 | 230 | # Get images 231 | categories = os.listdir(dir_datasets + modalities[0] + '/' + partition + '/') 232 | for i_category in categories: 233 | for iFile in os.listdir(dir_datasets + modalities[0] + '/' + partition + '/' + i_category + '/'): 234 | self.images.append(dir_datasets + modalities[0] + '/' + partition + '/' + i_category + '/' + iFile) 235 | 236 | # Remove other files 237 | self.images = [self.images[i] for i in range(self.images.__len__()) if 'Thumbs.db' not in self.images[i]] 238 | 239 | # Pre-allocate images 240 | self.X = np.zeros((len(self.images), input_shape[0], input_shape[1], input_shape[2]), dtype=np.float32) 241 | self.M = np.zeros((len(self.images), 1, input_shape[1], input_shape[2]), dtype=np.float32) 242 | self.Y = np.zeros((len(self.images), 1), dtype=np.float32) 243 | 244 | # Load, and normalize images 245 | print('[INFO]: Loading training images...') 246 | for i in np.arange(len(self.images)): 247 | print(str(i) + '/' + str(len(self.images)), end='\r') 248 | 249 | # Load image 250 | x = Image.open(self.images[i]) 251 | x = np.asarray(x) 252 | 253 | # Normalization 254 | x = image_normalization(x, self.input_shape[-1], norm=self.norm, channels=self.nChannels, 255 | histogram_matching=False, 256 | mask=False, channel_first=True) 257 | self.X[i, :, :, :] = x 258 | 259 | if 'good' in self.images[i]: 260 | self.Y[i, :] = np.array([0]) 261 | 
else: 262 | self.Y[i, :] = np.array([1]) 263 | 264 | mask_id = self.images[i].replace(partition, 'ground_truth').replace('.png', '_mask.png') 265 | 266 | y = Image.open(mask_id) 267 | y = np.asarray(y) 268 | if len(y.shape) == 3: 269 | y = y[:, :, 0] 270 | # Normalization 271 | y = image_normalization(y, self.input_shape[-1], norm=self.norm, channels=1, 272 | histogram_matching=False, 273 | reference_image=None, 274 | mask=True, channel_first=True) 275 | self.M[i, :, :, :] = y 276 | 277 | print('[INFO]: Images loaded') 278 | if partition == 'train': 279 | self.train_indexes = np.arange(0, len(self.images)) 280 | 281 | def __len__(self): 282 | 'Denotes the total number of samples' 283 | return len(self.images) 284 | 285 | def __getitem__(self, index): 286 | 'Generates one sample of data' 287 | 288 | x = self.X[index, :, :, :] 289 | y = self.Y[index, :] 290 | 291 | return x, y 292 | 293 | 294 | # ------------------------------------------ 295 | # Data generator 296 | 297 | class WSALDataGenerator(object): 298 | 299 | def __init__(self, dataset, partition, batch_size=16, shuffle=False): 300 | 301 | 'Internal states initialization' 302 | self.dataset = dataset 303 | self.batch_size = batch_size 304 | self.shuffle = shuffle 305 | self.partition = partition 306 | 307 | if self.partition == 'train': 308 | self.indexes = self.dataset.train_indexes.copy() 309 | elif self.partition == 'test': 310 | self.indexes = self.dataset.test_indexes.copy() 311 | 312 | if self.dataset.weak_supervision: 313 | self.indexes_abnormal = self.dataset.test_indexes.copy() 314 | self._idx_abnormal = 0 315 | self.batch_size_anomaly = np.min(np.array((self.batch_size, len(self.dataset.test_indexes)))) 316 | 317 | self._idx = 0 318 | self._reset() 319 | 320 | def __len__(self): 321 | 322 | N = len(self.indexes) 323 | b = self.batch_size 324 | return N // b 325 | 326 | def __iter__(self): 327 | 328 | return self 329 | 330 | def __next__(self): 331 | 332 | # If dataset is completed, stop iterator 333 
| if self._idx + self.batch_size >= len(self.indexes): 334 | self._reset() 335 | raise StopIteration() 336 | 337 | if self.dataset.weak_supervision: 338 | if self._idx_abnormal + self.batch_size_anomaly >= len(self.indexes_abnormal): 339 | self._idx_abnormal = 0 340 | 341 | # Load images and include into the batch 342 | X, Y = [], [] 343 | for i in range(self._idx, self._idx + self.batch_size): 344 | x, y = self.dataset.__getitem__(self.indexes[i]) 345 | X.append(x) 346 | Y.append(y) 347 | # Update index iterator 348 | self._idx += self.batch_size 349 | 350 | if self.dataset.weak_supervision: 351 | Xa, Ya = [], [] 352 | for i in range(self._idx_abnormal, self._idx_abnormal + self.batch_size_anomaly): 353 | xa, ya = self.dataset.__getitem__(self.indexes_abnormal[i]) 354 | Xa.append(xa) 355 | Ya.append(ya) 356 | # Update index iterator 357 | self._idx_abnormal += self.batch_size_anomaly 358 | 359 | return np.array(X).astype('float32'), np.array(Y).astype('float32'),\ 360 | np.array(Xa).astype('float32'), np.array(Ya).astype('float32') 361 | else: 362 | return np.array(X).astype('float32'), np.array(Y).astype('float32'),\ 363 | None, None 364 | 365 | def _reset(self): 366 | 367 | if self.shuffle: 368 | random.shuffle(self.indexes) 369 | self._idx = 0 370 | 371 | 372 | -------------------------------------------------------------------------------- /code/datasets/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import cv2 4 | 5 | from skimage import exposure 6 | from matplotlib import pyplot as plt 7 | import imutils 8 | 9 | 10 | def image_normalization(x, shape, norm='max', channels=3, histogram_matching=False, reference_image=None, 11 | mask=False, channel_first=True): 12 | 13 | # Histogram matching to reference image 14 | if histogram_matching: 15 | x_norm = exposure.match_histograms(x, reference_image) 16 | x_norm[x == 0] = 0 17 | x = x_norm 18 | 19 | # image resize 20 | x = 
imutils.resize(x, height=shape) 21 | #x = resize_image_canvas(x, shape) 22 | 23 | # Grayscale image -- add channel dimension 24 | if len(x.shape) < 3: 25 | x = np.expand_dims(x, -1) 26 | 27 | if mask: 28 | x = (x > 200) 29 | 30 | # channel first 31 | if channel_first: 32 | x = np.transpose(x, (2, 0, 1)) 33 | if not mask: 34 | if norm == 'max': 35 | x = x / 255.0 36 | elif norm == 'zscore': 37 | x = (x - 127.5) / 127.5 38 | 39 | # numeric type (astype returns a new array, so the result must be assigned) 40 | x = x.astype('float32') 41 | return x 42 | 43 | 44 | def plot_image(x, y=None, denorm_intensity=False, channel_first=True, norm='zscore'): 45 | if len(x.shape) < 3: 46 | x = np.expand_dims(x, 0) 47 | # channel first 48 | if channel_first: 49 | x = np.transpose(x, (1, 2, 0)) 50 | if denorm_intensity: 51 | if norm == 'zscore': # `norm` is passed as an argument; a module-level function has no `self` 52 | x = (x*127.5) + 127.5 53 | x = x.astype(int) 54 | 55 | plt.imshow(x) 56 | 57 | if y is not None: 58 | y = np.expand_dims(y[0, :, :], -1) 59 | plt.imshow(y, cmap='jet', alpha=0.1) 60 | 61 | plt.axis('off') 62 | plt.show() 63 | 64 | 65 | def augment_input_batch(batch): 66 | masks = np.zeros(batch.shape) 67 | 68 | for i in np.arange(0, batch.shape[0]): 69 | (batch[i, :, :, :], masks[i, :, :, :]) = augment_input_context(batch[i, :, :, :]) 70 | 71 | return batch, masks 72 | 73 | 74 | def augment_input_context(x): 75 | im = x.copy() 76 | mask = np.zeros(im.shape) 77 | 78 | # Randomize anomaly size 79 | w = random.randint(0, im.shape[2] // 10) 80 | 81 | # Random anomaly center 82 | xx = random.randint(0, im.shape[2] - w) 83 | yy = random.randint(0, im.shape[1] - w) 84 | 85 | # Get intensity 86 | i = np.percentile(im, 99) + np.std(im) 87 | 88 | # Insert anomaly 89 | im[:, xx-w:xx+w, yy-w:yy+w] = 0 90 | mask[:, xx-w:xx+w, yy-w:yy+w] = 1 91 | 92 | # keep only skull 93 | im[x==0] = 0 94 | mask[x == 0] = 0 95 | 96 | return im, mask 97 | 98 | 99 | def resize_image_canvas(image, input_shape, pcn_eat_hz=0.): 100 | 101 | # cut a bit on the sides 102 | if pcn_eat_hz > 0: 103 | h, w = image.shape 104 | px_eat_hz = 
int(pcn_eat_hz * w) 105 | image = image[:, px_eat_hz:w - px_eat_hz] 106 | 107 | h, w = image.shape 108 | ratio_h = input_shape[1] / h 109 | ratio_w = input_shape[2] / w 110 | 111 | img_res = np.zeros((input_shape[1], input_shape[2]), dtype=image.dtype) 112 | 113 | if ratio_w > ratio_h: 114 | img_res_h = imutils.resize(image, height=input_shape[1]) 115 | left_margin = (input_shape[2] - img_res_h.shape[1]) // 2 116 | img_res[:, left_margin:left_margin + img_res_h.shape[1]] = img_res_h 117 | else: 118 | img_res_w = imutils.resize(image, width=input_shape[2]) 119 | top_margin = (input_shape[1] - img_res_w.shape[0]) // 2 120 | img_res[top_margin:top_margin + img_res_w.shape[0], :] = img_res_w 121 | 122 | return img_res 123 | 124 | -------------------------------------------------------------------------------- /code/evaluation/metrics.py: -------------------------------------------------------------------------------- 1 | from sklearn.metrics import precision_recall_curve 2 | from sklearn.metrics import f1_score 3 | from sklearn.metrics import auc 4 | import numpy as np 5 | 6 | 7 | def dice(true_mask, pred_mask, non_seg_score=1.0): 8 | 9 | assert true_mask.shape == pred_mask.shape 10 | 11 | true_mask = np.asarray(true_mask).astype(bool) # np.bool was removed in NumPy 1.24; use the builtin bool 12 | pred_mask = np.asarray(pred_mask).astype(bool) 13 | 14 | # If both segmentations are all zero, the dice will be 1. (Developer decision) 15 | im_sum = true_mask.sum() + pred_mask.sum() 16 | if im_sum == 0: 17 | return non_seg_score 18 | 19 | # Compute Dice coefficient 20 | intersection = np.logical_and(true_mask, pred_mask) 21 | return 2. 
* intersection.sum() / im_sum 22 | 23 | 24 | def au_prc(true_mask, pred_mask): 25 | 26 | # Calculate pr curve and its area 27 | precision, recall, threshold = precision_recall_curve(true_mask, pred_mask) 28 | au_prc = auc(recall, precision) 29 | 30 | # Search the optimum point and obtain threshold via f1 score 31 | f1 = 2 * (precision * recall) / (precision + recall) 32 | f1[np.isnan(f1)] = 0 33 | 34 | th = threshold[np.argmax(f1)] 35 | 36 | return au_prc, th -------------------------------------------------------------------------------- /code/evaluation/readme: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /code/evaluation/utils.py: -------------------------------------------------------------------------------- 1 | import sklearn.metrics 2 | import imutils 3 | import cv2 4 | import os 5 | import matplotlib.pyplot as plt 6 | 7 | from evaluation.metrics import * 8 | 9 | 10 | def inference_dataset(method, dataset): 11 | 12 | # Take references and inputs from testing dataset 13 | X = dataset.X 14 | M = dataset.M 15 | Y = dataset.Y 16 | 17 | if 'BRATS' in dataset.dir_datasets or 'PhysioNet' in dataset.dir_datasets: # Inference is volume-wise 18 | 19 | if len(X.shape) > 5: 20 | X = X[:, 0, :, :, :, :] 21 | 22 | if len(M.shape) < 5: 23 | M = np.expand_dims(M, 1) 24 | 25 | # Init variables 26 | (p, s, c, h, w) = X.shape # maps dimensions 27 | Mhat = np.zeros(M.shape) # Predicted segmentation maps 28 | Xhat = np.zeros(X.shape) # Reconstructed images 29 | Scores = np.zeros((p, s)) 30 | 31 | for iVolume in np.arange(0, p): 32 | for iSlice in np.arange(0, s): 33 | # Take image 34 | x = X[iVolume, iSlice, :, :, :] 35 | # Predict anomaly map and score 36 | score, mhat, xhat = method.predict_score(x) 37 | 38 | Mhat[iVolume, iSlice, :, :, :] = mhat 39 | Xhat[iVolume, iSlice, :, :, :] = xhat 40 | Scores[iVolume, iSlice] = score 41 | 42 | elif 'MVTEC' in 
dataset.dir_datasets: # Inference is image-wise 43 | 44 | # Init variables 45 | (cases, c, h, w) = X.shape # maps dimensions 46 | Mhat = np.zeros(M.shape) # Predicted segmentation maps 47 | Xhat = np.zeros(X.shape) # Reconstructed images 48 | Scores = np.zeros((cases, 1)) 49 | 50 | for iCase in np.arange(0, cases): 51 | # Take image 52 | x = X[iCase, :, :, :] 53 | # Predict anomaly map and score 54 | score, mhat, xhat = method.predict_score(x) 55 | 56 | Mhat[iCase, :, :, :] = mhat 57 | Xhat[iCase, :, :, :] = xhat 58 | Scores[iCase, :] = score 59 | 60 | return Y, Scores, M, Mhat, X, Xhat 61 | 62 | 63 | def evaluate_anomaly_detection(y, scores, dir_out='', range=[-1, 1], tit='cosine similarity', bins=50, th=None): 64 | 65 | scores = np.ravel(scores) 66 | y = np.ravel(y) 67 | 68 | auroc = sklearn.metrics.roc_auc_score(y, scores) # au_roc 69 | auprc, th_op = au_prc(y, scores) # au_prc 70 | 71 | if th is None: 72 | th = th_op 73 | 74 | if dir_out != '': 75 | plt.hist(np.ravel(scores)[np.ravel(y) == 1], bins=bins, range=range, fc=[0.7, 0, 0, 0.5]) 76 | plt.hist(np.ravel(scores)[np.ravel(y) == 0], bins=bins, range=range, fc=[0, 0, 0.7, 0.5]) 77 | plt.legend(['Anomaly', 'Normal']) 78 | plt.xlabel(tit) 79 | 80 | plt.savefig(dir_out + 'anomaly_detection.png') 81 | plt.close('all') 82 | 83 | return auroc, auprc, th 84 | 85 | 86 | def evaluate_anomaly_localization(dataset, save_maps=False, dir_out='', filter_volumes=True, th=None): 87 | 88 | print('[INFO]: Testing...') 89 | if save_maps: 90 | if not os.path.isdir(dir_out + 'masks_predicted/'): 91 | os.mkdir(dir_out + 'masks_predicted/') 92 | if not os.path.isdir(dir_out + 'masks_reference'): 93 | os.mkdir(dir_out + 'masks_reference/') 94 | if not os.path.isdir(dir_out + 'xhat_predicted'): 95 | os.mkdir(dir_out + 'xhat_predicted/') 96 | 97 | # Get references and predictions from dataset 98 | M = dataset.M 99 | Mhat = dataset.Mhat 100 | X = dataset.X 101 | Xhat = dataset.Xhat 102 | 103 | if 'BRATS' in dataset.dir_datasets or 
'PhysioNet' in dataset.dir_datasets: # Inference is volume-wise 104 | 105 | if len(X.shape) > 5: 106 | X = X[:, 0, :, :, :, :] 107 | 108 | if filter_volumes: # Filter volumes without annotations 109 | idx = np.squeeze(np.argwhere(np.sum(M, (1, 2, 3, 4)) / (M.shape[1] * M.shape[-1] * M.shape[-1]) > 0.001)) 110 | X = X[idx, :, :, :, :] 111 | Xhat = Xhat[idx, :, :, :, :] 112 | M = M[idx, :, :, :, :] 113 | Mhat = Mhat[idx, :, :, :, :] 114 | unique_patients = list(np.array(dataset.unique_patients)[idx]) 115 | 116 | # Obtain overall metrics and optimum point threshold 117 | AU_ROC = sklearn.metrics.roc_auc_score(M.flatten() == 1, Mhat.flatten()) # au_roc 118 | AU_PRC, th_op = au_prc(M.flatten() == 1, Mhat.flatten()) # au_prc 119 | 120 | if th is None: 121 | th = th_op 122 | 123 | # Apply threshold 124 | DICE = dice(M.flatten() == 1, (Mhat > th).flatten()) # Dice 125 | IoU = sklearn.metrics.jaccard_score(M.flatten() == 1, (Mhat > th).flatten()) # IoU 126 | 127 | if 'BRATS' in dataset.dir_datasets or 'PhysioNet' in dataset.dir_datasets: # Inference is volume-wise 128 | 129 | # Once the threshold is obtained calculate volume-level metrics and plot results 130 | patient_dice = [] 131 | (p, s, c, h, w) = X.shape # maps dimensions 132 | for iVolume in np.arange(0, p): 133 | patient_dice.append(dice(M[iVolume, :, :, :, :].flatten(), (Mhat[iVolume, :, :, :, :] > th).flatten())) 134 | 135 | if save_maps and dir_out!= '': # Save slices' masks 136 | for iSlice in np.arange(0, s): 137 | id = unique_patients[iVolume] + '_' + str(iSlice) + '.jpg' 138 | 139 | # Obtain heatmaps for predicted and reference 140 | m_i = imutils.rotate_bound(np.uint8(M[iVolume, iSlice, 0, :, :] * 255), 270) 141 | heatmap_m = cv2.applyColorMap(m_i, cv2.COLORMAP_JET) 142 | mhat_i = imutils.rotate_bound(np.uint8((Mhat[iVolume, iSlice, 0, :, :] > th) * 255), 270) 143 | heatmap_mhat = cv2.applyColorMap(mhat_i, cv2.COLORMAP_JET) 144 | heatmap_mhat = heatmap_mhat * ( 145 | 
np.expand_dims(imutils.rotate_bound(Mhat[iVolume, iSlice, 0, :, :], 270), -1) > 0) 146 | 147 | # Move grayscale image to three channels 148 | xh = cv2.cvtColor(np.uint8(np.squeeze(X[iVolume, iSlice, :, :, :]) * 255), cv2.COLOR_GRAY2RGB) 149 | xh = imutils.rotate_bound(xh, 270) 150 | 151 | # Combine original image and masks 152 | fin_mask = cv2.addWeighted(xh, 0.7, heatmap_m, 0.3, 0) 153 | fin_predicted = cv2.addWeighted(xh, 0.7, heatmap_mhat, 0.3, 0) 154 | 155 | fin_predicted = mhat_i 156 | fin_mask = m_i 157 | 158 | cv2.imwrite(dir_out + 'masks_predicted/' + id, fin_predicted) 159 | cv2.imwrite(dir_out + 'masks_reference/' + id, fin_mask) 160 | cv2.imwrite(dir_out + 'xhat_predicted/' + id, np.uint8(Xhat[iVolume, iSlice, 0, :, :] * 255)) 161 | 162 | DICE_mu = np.mean(patient_dice) 163 | DICE_std = np.std(patient_dice) 164 | 165 | if 'MVTEC' in dataset.dir_datasets: # Inference is image-wise 166 | 167 | # Once the threshold is obtained calculate volume-level metrics and plot results 168 | case_dice = [] 169 | (cases, c, h, w) = X.shape # maps dimensions 170 | for iCase in np.arange(0, cases): 171 | if 'good' not in dataset.images[iCase]: 172 | case_dice.append(dice(M[iCase, :, :, :].flatten(), (Mhat[iCase, :, :, :] > th).flatten())) 173 | 174 | if save_maps and dir_out != '': # Save slices' masks 175 | id = dataset.images[iCase].replace('.png', '.jpg').split('/')[-2] + '_' + \ 176 | dataset.images[iCase].replace('.png', '.jpg').split('/')[-1] 177 | 178 | # Obtain heatmaps for predicted and reference 179 | m_i = np.uint8(M[iCase, 0, :, :] * 255) 180 | heatmap_m = cv2.applyColorMap(m_i, cv2.COLORMAP_JET) 181 | mhat_i = np.uint8((Mhat[iCase, 0, :, :] > th) * 255) 182 | heatmap_mhat = cv2.applyColorMap(mhat_i, cv2.COLORMAP_JET) 183 | 184 | # Move grayscale image to three channels 185 | xh = np.uint8(np.squeeze(X[iCase, :, :, :]) * 255) 186 | xh = np.transpose(xh, (1, 2, 0)) 187 | 188 | # Combine original image and masks 189 | fin_mask = cv2.addWeighted(xh, 0.7, 
heatmap_m, 0.3, 0) 190 | fin_predicted = cv2.addWeighted(xh, 0.7, heatmap_mhat, 0.3, 0) 191 | 192 | cv2.imwrite(dir_out + 'masks_predicted/' + id, fin_predicted) 193 | cv2.imwrite(dir_out + 'masks_reference/' + id, fin_mask) 194 | cv2.imwrite(dir_out + 'xhat_predicted/' + id, np.uint8(Xhat[iCase, 0, :, :] * 255)) 195 | 196 | DICE_mu = np.mean(case_dice) 197 | DICE_std = np.std(case_dice) 198 | 199 | metrics = {'AU_ROC': AU_ROC, 'AU_PRC': AU_PRC, 'DICE': DICE, 'IoU': IoU, 200 | 'DICE_mu': DICE_mu, 'DICE_std': DICE_std} 201 | 202 | return metrics, th -------------------------------------------------------------------------------- /code/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | import numpy as np 5 | 6 | from datasets.datasets import TestDataset, MultiModalityDataset, WSALDataGenerator, MVTECDataset 7 | from methods.train import AnomalyDetectorTrainer 8 | from evaluation.utils import * 9 | 10 | 11 | def main(args): 12 | 13 | exp = {"dir_datasets": args.dir_datasets, "dir_out": args.dir_out, "load_weigths": args.load_weigths, 14 | "epochs": args.epochs, "item": args.item, "method": args.method, "input_shape": args.input_shape, 15 | "batch_size": args.batch_size, "lr": args.learning_rate, "zdim": args.zdim, "images_on_ram": True, 16 | "wkl": args.wkl, "wr": 1, "wadv": args.wadv, "wae": args.wae, "epochs_to_test": 50, "dense": args.dense, 17 | "channel_first": True, "normalization_cam": args.normalization_cam, "avg_grads": True, "t": args.t, 18 | "context": args.context, "level_cams": args.level_cams, "p_activation_cam": args.p_activation_cam, 19 | "bayesian": args.bayesian, "loss_reconstruction": "bce", 20 | "expansion_loss_penalty": args.expansion_loss_penalty, "restoration": args.restoration, 21 | "n_blocks": args.n_blocks, "alpha_entropy": args.wH} 22 | 23 | if not os.path.isdir('../data/results/'): 24 | os.mkdir('../data/results/') 25 | 26 | metrics = [] 27 | for 
iteration in [0, 1, 2]: 28 | 29 | if 'BRATS' in exp['dir_datasets'] or 'PhysioNet' in exp['dir_datasets']: 30 | 31 | # Set test 32 | test_dataset = TestDataset(exp['dir_datasets'], item=exp['item'], partition='val', 33 | input_shape=exp['input_shape'], 34 | channel_first=True, norm='max', histogram_matching=True) 35 | 36 | # Set train data loader 37 | dataset = MultiModalityDataset(exp['dir_datasets'], exp['item'], input_shape=exp['input_shape'], 38 | channel_first=exp['channel_first'], norm='max', hist_match=True) 39 | 40 | train_generator = WSALDataGenerator(dataset, partition='train', batch_size=exp['batch_size'], shuffle=True) 41 | 42 | elif 'MVTEC' in exp['dir_datasets']: 43 | 44 | # Set test 45 | test_dataset = MVTECDataset(exp['dir_datasets'], exp['item'], input_shape=exp['input_shape'], 46 | channel_first=exp['channel_first'], norm='max', 47 | partition='test') 48 | 49 | # Set train data loader 50 | dataset = MVTECDataset(exp['dir_datasets'], exp['item'], input_shape=exp['input_shape'], 51 | channel_first=exp['channel_first'], norm='max', 52 | partition='train') 53 | 54 | train_generator = WSALDataGenerator(dataset, partition='train', batch_size=exp['batch_size'], shuffle=True) 55 | 56 | # Set trainer and train model 57 | trainer = AnomalyDetectorTrainer(exp['dir_out'], exp['method'], item=exp['item'], zdim=exp['zdim'], 58 | dense=exp['dense'], n_blocks=exp['n_blocks'], 59 | lr=exp['lr'], input_shape=exp['input_shape'], load_weigths=exp['load_weigths'], 60 | epochs_to_test=exp['epochs_to_test'], context=exp['context'], 61 | bayesian=exp['bayesian'], restoration=exp['restoration'], 62 | loss_reconstruction=exp['loss_reconstruction'], 63 | level_cams=exp['level_cams'], 64 | expansion_loss_penalty=exp['expansion_loss_penalty'], 65 | iteration=iteration, 66 | alpha_kl=exp['wkl'], alpha_entropy=exp["alpha_entropy"], 67 | p_activation_cam=exp["p_activation_cam"], 68 | alpha_ae=exp["wae"], t=exp["t"]) 69 | 70 | # Save experiment setup 71 | with open(exp['dir_out'] 
+ 'setup.json', 'w') as fp: 72 | json.dump(exp, fp) 73 | 74 | if not args.only_test: 75 | # Train 76 | trainer.train(train_generator, exp['epochs'], test_dataset) 77 | else: 78 | 79 | trainer.method.train_generator = train_generator 80 | trainer.method.dataset_test = test_dataset 81 | 82 | thresholod_with_percentile = False 83 | thresholod_with_valsubset = False 84 | if thresholod_with_percentile: 85 | # Predictions on normal dataset 86 | Y_t, Scores_t, M_t, Mhat_t, X_t, Xhat_t = inference_dataset(trainer.method, dataset) 87 | th = np.percentile(np.ravel(Mhat_t), 99) 88 | elif thresholod_with_valsubset: 89 | val_dataset = TestDataset(exp['dir_datasets'], item=exp['item'], partition='val', 90 | input_shape=exp['input_shape'], 91 | channel_first=True, norm='max', histogram_matching=True) 92 | # Make predictions 93 | _, Scores, _, Mhat, _, Xhat = inference_dataset(trainer.method, val_dataset) 94 | # Input to dataset 95 | val_dataset.Scores = Scores 96 | val_dataset.Mhat = Mhat 97 | val_dataset.Xhat = Xhat 98 | # Get threshold 99 | _, th = evaluate_anomaly_localization(val_dataset, save_maps=False, 100 | dir_out=trainer.method.dir_results, 101 | th=None) 102 | else: 103 | th = None 104 | 105 | # Make predictions 106 | Y, Scores, M, Mhat, X, Xhat = inference_dataset(trainer.method, test_dataset) 107 | 108 | # Input to dataset 109 | test_dataset.Scores = Scores 110 | test_dataset.Mhat = Mhat 111 | test_dataset.Xhat = Xhat 112 | 113 | metrics_i, th_i = evaluate_anomaly_localization(test_dataset, save_maps=False, 114 | dir_out=trainer.method.dir_results, 115 | th=th) 116 | print(metrics_i) 117 | trainer.method.metrics = metrics_i 118 | 119 | # Save overall metrics 120 | metrics.append(list(trainer.method.metrics.values())) 121 | 122 | # Compute average performance and save performance in dictionary 123 | metrics = np.array(metrics) 124 | metrics_mu = np.mean(metrics, 0) 125 | metrics_std = np.std(metrics, 0) 126 | 127 | labels = list(trainer.method.metrics.keys()) 128 | 
metrics_mu = {labels[i]: metrics_mu[i] for i in range(0, len(labels))} 129 | metrics_std = {labels[i]: metrics_std[i] for i in range(0, len(labels))} 130 | 131 | with open(exp['dir_out'] + exp['item'][0] + '/' + 'metrics_avg_val.json', 'w') as fp: 132 | json.dump(metrics_mu, fp) 133 | with open(exp['dir_out'] + exp['item'][0] + '/' + 'metrics_std_val.json', 'w') as fp: 134 | json.dump(metrics_std, fp) 135 | 136 | 137 | if __name__ == '__main__': 138 | parser = argparse.ArgumentParser() 139 | # Settings 140 | parser.add_argument("--dir_datasets", default="../data/BRATS_5slices/", type=str) 141 | parser.add_argument("--dir_out", default="../data/gradCAMCons/tests/", type=str) 142 | parser.add_argument("--method", default="gradCAMCons", type=str) 143 | parser.add_argument("--item", default=["flair"], type=str, nargs="+") 144 | parser.add_argument("--load_weigths", default=False, type=bool) 145 | parser.add_argument("--only_test", default=False, type=bool) 146 | # Hyper-params training 147 | parser.add_argument("--input_shape", default=[1, 224, 224], type=int, nargs=3) 148 | parser.add_argument("--learning_rate", default=1e-5, type=float) 149 | parser.add_argument("--epochs", default=300, type=int) 150 | parser.add_argument("--batch_size", default=8, type=int) 151 | # Auto-encoder architecture 152 | parser.add_argument("--zdim", default=32, type=int) 153 | parser.add_argument("--dense", default=True, type=bool) 154 | parser.add_argument("--n_blocks", default=4, type=int) 155 | # Residual-based inference options 156 | parser.add_argument("--restoration", default=False, type=bool) 157 | parser.add_argument("--bayesian", default=False, type=bool) 158 | parser.add_argument("--context", default=False, type=bool) 159 | # Settings with variational AE 160 | parser.add_argument("--wkl", default=1, type=float) 161 | # Settings with discriminator 162 | parser.add_argument("--wadv", default=0., type=float) 163 | # GradCAMCons 164 | parser.add_argument("--wae", default=1e4, type=float) 
165 | parser.add_argument("--p_activation_cam", default=1e-2, type=float) 166 | parser.add_argument("--expansion_loss_penalty", default="log_barrier", type=str) 167 | parser.add_argument("--t", default=10, type=int) 168 | parser.add_argument("--normalization_cam", default="sigm", type=str) 169 | # AMCons 170 | parser.add_argument("--wH", default=0., type=float) 171 | parser.add_argument("--level_cams", default=-4, type=int) 172 | 173 | args = parser.parse_args() 174 | main(args) 175 | -------------------------------------------------------------------------------- /code/methods/losses/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def kl_loss(mu, logvar): 6 | 7 | kl_divergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) 8 | 9 | return kl_divergence 10 | 11 | 12 | def log_barrier(z, t=5): 13 | 14 | # Only one value 15 | if z.shape[0] == 1: 16 | 17 | if z <= - 1 / t ** 2: 18 | log_barrier_loss = - torch.log(-z) / t 19 | else: 20 | log_barrier_loss = t * z + -np.log(1 / (t ** 2)) / t + 1 / t 21 | 22 | # Constrain over multiple values 23 | else: 24 | log_barrier_loss = torch.tensor(0).cuda().float() 25 | for i in np.arange(0, z.shape[0]): 26 | zi = z[i, 0] 27 | if zi <= - 1 / t ** 2: 28 | log_barrier_loss += - torch.log(-zi) / t 29 | else: 30 | log_barrier_loss += t * zi + -np.log(1 / (t ** 2)) / t + 1 / t 31 | 32 | return log_barrier_loss 33 | -------------------------------------------------------------------------------- /code/methods/losses/readme: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /code/methods/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | 4 | # Import different methods 5 | from methods.trainers.ae import * 6 | from methods.trainers.vae 
import * 7 | from methods.trainers.anoVAEGAN import * 8 | from methods.trainers.fanoGAN import * 9 | from methods.trainers.AMCons import * 10 | from methods.trainers.gradCAMCons import * 11 | from methods.trainers.histEqualization import * 12 | 13 | torch.autograd.set_detect_anomaly(True) 14 | 15 | 16 | class AnomalyDetectorTrainer: 17 | def __init__(self, dir_out, method, item=['flair'], zdim=32, dense=True, variational=False, n_blocks=5, lr=1*1e-4, 18 | input_shape=(1, 224, 224), load_weigths=False, epochs_to_test=10, context=False, bayesian=False, 19 | restoration=False, loss_reconstruction='bce', iteration=0, level_cams=-4, 20 | alpha_kl=10, alpha_entropy=0., expansion_loss_penalty='log_barrier', p_activation_cam=0.2, t=25, 21 | alpha_ae=10): 22 | 23 | # Init input variables 24 | self.dir_out = dir_out 25 | self.method = method 26 | self.item = item 27 | self.zdim = zdim 28 | self.dense = dense 29 | self.variational = variational 30 | self.n_blocks = n_blocks 31 | self.input_shape = input_shape 32 | self.load_weights = load_weigths 33 | self.epochs_to_test = epochs_to_test 34 | self.context = context 35 | self.bayesian = bayesian 36 | self.restoration = restoration 37 | self.loss_reconstruction = loss_reconstruction 38 | self.lr = lr 39 | self.level_cams = level_cams 40 | self.expansion_loss_penalty = expansion_loss_penalty 41 | self.alpha_kl = alpha_kl 42 | self.alpha_entropy = alpha_entropy 43 | self.p_activation_cam = p_activation_cam 44 | self.t = t 45 | self.alpha_ae = alpha_ae 46 | 47 | # Prepare results folders 48 | self.dir_results = dir_out + item[0] + '/iteration_' + str(iteration) + str('/') 49 | if not os.path.isdir(dir_out): 50 | os.mkdir(dir_out) 51 | if not os.path.isdir(dir_out + item[0] + '/'): 52 | os.mkdir(dir_out + item[0] + '/') 53 | if not os.path.isdir(self.dir_results): 54 | os.mkdir(self.dir_results) 55 | 56 | # Create trainer 57 | if self.method == 'ae': 58 | self.method = AnomalyDetectorAE(self.dir_results, item=self.item, 
zdim=self.zdim, lr=self.lr, 59 | input_shape=self.input_shape, epochs_to_test=self.epochs_to_test, 60 | load_weigths=self.load_weights, n_blocks=self.n_blocks, dense=self.dense, 61 | context=self.context, bayesian=self.bayesian, 62 | loss_reconstruction=self.loss_reconstruction, restoration=self.restoration) 63 | elif self.method == 'vae': 64 | self.method = AnomalyDetectorVAE(self.dir_results, item=self.item, zdim=self.zdim, lr=self.lr, 65 | input_shape=self.input_shape, epochs_to_test=self.epochs_to_test, 66 | load_weigths=self.load_weights, n_blocks=self.n_blocks, 67 | dense=self.dense, 68 | context=self.context, bayesian=self.bayesian, 69 | loss_reconstruction=self.loss_reconstruction, 70 | restoration=self.restoration, 71 | alpha_kl=self.alpha_kl) 72 | elif self.method == 'anoVAEGAN': 73 | self.method = AnomalyDetectorAnoVAEGAN(self.dir_results, item=self.item, zdim=self.zdim, lr=self.lr, 74 | input_shape=self.input_shape, epochs_to_test=self.epochs_to_test, 75 | load_weigths=self.load_weights, n_blocks=self.n_blocks, 76 | dense=self.dense, 77 | context=self.context, bayesian=self.bayesian, 78 | loss_reconstruction=self.loss_reconstruction, 79 | restoration=self.restoration, 80 | alpha_kl=self.alpha_kl) 81 | elif self.method == 'fanoGAN': 82 | self.method = AnomalyDetectorFanoGAN(self.dir_results, item=self.item, zdim=self.zdim, lr=self.lr, 83 | input_shape=self.input_shape, epochs_to_test=self.epochs_to_test, 84 | load_weigths=self.load_weights, n_blocks=self.n_blocks, 85 | dense=self.dense, 86 | context=self.context, bayesian=self.bayesian, 87 | loss_reconstruction=self.loss_reconstruction, 88 | restoration=self.restoration) 89 | elif self.method == 'gradCAMCons': 90 | self.method = AnomalyDetectorGradCamCons(self.dir_results, item=self.item, zdim=self.zdim, lr=self.lr, 91 | input_shape=self.input_shape, epochs_to_test=self.epochs_to_test, 92 | load_weigths=self.load_weights, n_blocks=self.n_blocks, 93 | dense=self.dense, 
loss_reconstruction=self.loss_reconstruction, 94 | pre_training_epochs=50, level_cams=self.level_cams, t=self.t, 95 | p_activation_cam=self.p_activation_cam, 96 | expansion_loss_penalty='log_barrier', alpha_ae=self.alpha_ae, 97 | alpha_kl=self.alpha_kl) 98 | elif self.method == 'camCons': 99 | self.method = AnomalyDetectorAMCons(self.dir_results, item=self.item, zdim=self.zdim, lr=self.lr, 100 | input_shape=self.input_shape, epochs_to_test=self.epochs_to_test, 101 | load_weigths=self.load_weights, n_blocks=self.n_blocks, 102 | dense=self.dense, loss_reconstruction=self.loss_reconstruction, 103 | pre_training_epochs=0, level_cams=self.level_cams, 104 | alpha_entropy=self.alpha_entropy, 105 | alpha_kl=self.alpha_kl) 106 | elif self.method == 'histEqualization': 107 | self.method = AnomalyDetectorHistEqualization(self.dir_results, item=self.item) 108 | 109 | else: 110 | print('Incorrectly specified method... ', end='\n') 111 | 112 | def train(self, train_generator, epochs, dataset_test): 113 | self.method.train(train_generator, epochs, dataset_test) -------------------------------------------------------------------------------- /code/methods/trainers/AMCons.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | import datetime 5 | import kornia 6 | import json 7 | import torch 8 | 9 | import pandas as pd 10 | import matplotlib.pyplot as plt 11 | 12 | from scipy import ndimage 13 | from timeit import default_timer as timer 14 | from models.models import Encoder, Decoder 15 | from evaluation.utils import * 16 | from methods.losses.losses import kl_loss 17 | from methods.losses.losses import log_barrier 18 | from sklearn.metrics import accuracy_score, f1_score 19 | 20 | 21 | class AnomalyDetectorAMCons: 22 | def __init__(self, dir_results, item=['flair'], zdim=32, lr=1*1e-4, input_shape=(1, 224, 224), epochs_to_test=25, 23 | load_weigths=False, n_blocks=5, dense=True, loss_reconstruction='bce', 
alpha_kl=1, 24 | pre_training_epochs=0, level_cams=-4, alpha_entropy=1, gap=False): 25 | 26 | # Init input variables 27 | self.dir_results = dir_results 28 | self.item = item 29 | self.zdim = zdim 30 | self.lr = lr 31 | self.input_shape = input_shape 32 | self.epochs_to_test = epochs_to_test 33 | self.load_weigths = load_weigths 34 | self.n_blocks = n_blocks 35 | self.dense = dense 36 | self.loss_reconstruction = loss_reconstruction 37 | self.alpha_kl = alpha_kl 38 | self.pre_training_epochs = pre_training_epochs 39 | self.level_cams = level_cams 40 | self.alpha_entropy = alpha_entropy 41 | self.gap = gap 42 | 43 | # Init network 44 | self.E = Encoder(fin=self.input_shape[0], zdim=self.zdim, dense=self.dense, n_blocks=self.n_blocks, 45 | spatial_dim=self.input_shape[1]//2**self.n_blocks, variational=True, gap=gap) 46 | self.Dec = Decoder(fin=self.zdim, nf0=self.E.backbone.nfeats//2, n_channels=self.input_shape[0], 47 | dense=self.dense, n_blocks=self.n_blocks, spatial_dim=self.input_shape[1]//2**self.n_blocks, 48 | gap=gap) 49 | 50 | if torch.cuda.is_available(): 51 | self.E.cuda() 52 | self.Dec.cuda() 53 | 54 | if self.load_weigths: 55 | self.E.load_state_dict(torch.load(self.dir_results + 'encoder_weights.pth')) 56 | self.Dec.load_state_dict(torch.load(self.dir_results + 'decoder_weights.pth')) 57 | 58 | # Set parameters 59 | self.params = list(self.E.parameters()) + list(self.Dec.parameters()) 60 | 61 | # Set losses 62 | if self.loss_reconstruction == 'l2': 63 | self.Lr = torch.nn.MSELoss(reduction='sum') 64 | elif self.loss_reconstruction == 'bce': 65 | self.Lr = torch.nn.BCEWithLogitsLoss(reduction='sum') 66 | 67 | self.Lkl = kl_loss 68 | 69 | # Set optimizers 70 | self.opt = torch.optim.Adam(self.params, lr=self.lr) 71 | 72 | # Init additional variables and objects 73 | self.epochs = 0. 74 | self.iterations = 0. 75 | self.init_time = 0. 76 | self.lr_iteration = 0. 77 | self.lr_epoch = 0. 78 | self.kl_iteration = 0. 79 | self.kl_epoch = 0. 
80 | self.H_iteration = 0. 81 | self.H_epoch = 0. 82 | self.i_epoch = 0. 83 | self.train_generator = [] 84 | self.dataset_test = [] 85 | self.metrics = {} 86 | self.aucroc_lc = [] 87 | self.auprc_lc = [] 88 | self.auroc_det = [] 89 | self.lr_lc = [] 90 | self.lkl_lc = [] 91 | self.lae_lc = [] 92 | self.H_lc = [] 93 | self.auroc_det_lc = [] 94 | self.refCam = 0. 95 | 96 | def train(self, train_generator, epochs, test_dataset): 97 | self.epochs = epochs 98 | self.init_time = timer() 99 | self.train_generator = train_generator 100 | self.dataset_test = test_dataset 101 | self.iterations = len(self.train_generator) 102 | 103 | # Loop over epochs 104 | for self.i_epoch in range(self.epochs): 105 | # init epoch losses 106 | self.lr_epoch = 0 107 | self.kl_epoch = 0. 108 | self.H_epoch = 0. 109 | 110 | # Loop over training dataset 111 | for self.i_iteration, (x_n, y_n, x_a, y_a) in enumerate(self.train_generator): 112 | #p = q 113 | 114 | # brain mask 115 | if 'BRATS' in train_generator.dataset.dir_datasets or\ 116 | 'PhysioNet' in train_generator.dataset.dir_datasets: 117 | x_mask = 1 - np.mean((x_n == 0).astype(int), 0) 118 | if 'BRATS' in train_generator.dataset.dir_datasets: 119 | x_mask = ndimage.binary_erosion(x_mask, structure=np.ones((1, 6, 6))).astype(x_mask.dtype) 120 | elif 'MVTEC' in train_generator.dataset.dir_datasets: 121 | x_mask = np.zeros((1, 224, 224)) 122 | x_mask[:, 14:-14, 14:-14] = 1 123 | 124 | # Move tensors to gpu 125 | x_n = torch.tensor(x_n).cuda().float() 126 | 127 | # Obtain latent space from normal sample via encoder 128 | z, z_mu, z_logvar, allF = self.E(x_n) 129 | 130 | # Obtain reconstructed images through decoder 131 | xhat, _ = self.Dec(z) 132 | if self.loss_reconstruction == 'l2': 133 | xhat = torch.sigmoid(xhat) 134 | 135 | # Calculate criterion 136 | self.lr_iteration = self.Lr(xhat, x_n) / (self.train_generator.batch_size) # Reconstruction loss 137 | self.kl_iteration = self.Lkl(mu=z_mu, logvar=z_logvar) / 
(self.train_generator.batch_size) # kl loss 138 | 139 | # Init overall losses 140 | L = self.lr_iteration + self.alpha_kl * self.kl_iteration 141 | 142 | # ---- Compute Attention Homogeneization loss via Entropy 143 | 144 | am = torch.mean(allF[self.level_cams], 1) 145 | 146 | # Restore original shape 147 | am = torch.nn.functional.interpolate(am.unsqueeze(1), 148 | size=(self.input_shape[-1], self.input_shape[-1]), 149 | mode='bilinear', 150 | align_corners=True) 151 | am = am.view((am.shape[0], -1)) 152 | 153 | # Prepare mask with brain 154 | if 'BRATS' in train_generator.dataset.dir_datasets or\ 155 | 'MVTEC' in train_generator.dataset.dir_datasets or\ 156 | 'PhysioNet' in train_generator.dataset.dir_datasets: 157 | 158 | x_mask = np.ravel(x_mask) 159 | x_mask = torch.tensor(np.array(np.argwhere(x_mask > 0.5))).cuda().squeeze() 160 | am = torch.index_select(am, dim=1, index=x_mask) 161 | 162 | # Probabilities 163 | p = torch.nn.functional.softmax(am.view((am.shape[0], -1)), dim=-1) 164 | # Mean entropy 165 | self.H_iteration = torch.mean(-torch.sum(p * torch.log(p + 1e-12), dim=(-1))) 166 | 167 | if self.i_epoch > self.pre_training_epochs: 168 | 169 | if self.alpha_entropy > 0: 170 | # Entropy Maximization 171 | L += - self.alpha_entropy * self.H_iteration 172 | 173 | # Update weights 174 | L.backward() # Backward 175 | self.opt.step() # Update weights 176 | self.opt.zero_grad() # Clear gradients 177 | 178 | """ 179 | ON ITERATION/EPOCH END PROCESS 180 | """ 181 | 182 | # Display losses per iteration 183 | self.display_losses(on_epoch_end=False) 184 | 185 | # Update epoch's losses 186 | self.lr_epoch += self.lr_iteration.cpu().detach().numpy() / len(self.train_generator) 187 | self.kl_epoch += self.kl_iteration.cpu().detach().numpy() / len(self.train_generator) 188 | self.H_epoch += self.H_iteration.cpu().detach().numpy() / len(self.train_generator) 189 | 190 | # Epoch-end processes 191 | self.on_epoch_end() 192 | 193 | def on_epoch_end(self): 194 | 195 | # 
Display losses 196 | self.display_losses(on_epoch_end=True) 197 | 198 | # Update learning curves 199 | self.lr_lc.append(self.lr_epoch) 200 | self.lkl_lc.append(self.kl_epoch) 201 | self.H_lc.append(self.H_epoch) 202 | 203 | # Each x epochs, test models and plot learning curves 204 | if (self.i_epoch + 1) % self.epochs_to_test == 0: 205 | # Save weights 206 | torch.save(self.E.state_dict(), self.dir_results + 'encoder_weights.pth') 207 | torch.save(self.Dec.state_dict(), self.dir_results + 'decoder_weights.pth') 208 | 209 | # Evaluate 210 | if self.i_epoch > (self.pre_training_epochs - 50): 211 | # Make predictions 212 | Y, Scores, M, Mhat, X, Xhat = inference_dataset(self, self.dataset_test) 213 | 214 | # Input to dataset 215 | self.dataset_test.Scores = Scores 216 | self.dataset_test.Mhat = Mhat 217 | self.dataset_test.Xhat = Xhat 218 | 219 | # Evaluate anomaly detection 220 | auroc_det, auprc_det, th_det = evaluate_anomaly_detection(self.dataset_test.Y, self.dataset_test.Scores, 221 | dir_out=self.dir_results, 222 | range=[np.min(Scores)-np.std(Scores), np.max(Scores)+np.std(Scores)], 223 | tit='kl') 224 | acc = accuracy_score(np.ravel(Y), np.ravel((Scores > th_det)).astype('int')) 225 | fs = f1_score(np.ravel(Y), np.ravel((Scores > th_det)).astype('int')) 226 | 227 | metrics_detection = {'auroc_det': auroc_det, 'auprc_det': auprc_det, 'th_det': th_det, 'acc_det': acc, 228 | 'fs_det': fs} 229 | print(metrics_detection) 230 | 231 | # Evaluate anomaly localization 232 | metrics, th = evaluate_anomaly_localization(self.dataset_test, save_maps=True, dir_out=self.dir_results) 233 | self.metrics = metrics 234 | 235 | # Save metrics as dict 236 | with open(self.dir_results + 'metrics.json', 'w') as fp: 237 | json.dump(metrics, fp) 238 | print(metrics) 239 | 240 | # Plot learning curve 241 | self.plot_learning_curves() 242 | 243 | # Save learning curves as dataframe 244 | self.aucroc_lc.append(metrics['AU_ROC']) 245 | self.auprc_lc.append(metrics['AU_PRC']) 246 | 
self.auroc_det_lc.append(auroc_det) 247 | history = pd.DataFrame(list(zip(self.lr_lc, self.lkl_lc, self.H_lc, self.aucroc_lc, self.auprc_lc, self.auroc_det_lc)), 248 | columns=['Lrec', 'Lkl', 'H', 'AUCROC', 'AUPRC', 'AUROC_det']) 249 | history.to_csv(self.dir_results + 'lc_on_direct.csv') 250 | 251 | else: 252 | self.aucroc_lc.append(0) 253 | self.auprc_lc.append(0) 254 | self.auroc_det_lc.append(0) 255 | 256 | def predict_score(self, x): 257 | self.E.eval() 258 | self.Dec.eval() 259 | 260 | # brain mask 261 | if 'BRATS' in self.train_generator.dataset.dir_datasets or 'PhysioNet' in self.train_generator.dataset.dir_datasets: 262 | x_mask = 1 - (x == 0).astype(int) 263 | if 'BRATS' in self.train_generator.dataset.dir_datasets: 264 | x_mask = ndimage.binary_erosion(x_mask, structure=np.ones((1, 6, 6))).astype(x_mask.dtype) 265 | else: 266 | x_mask = ndimage.binary_erosion(x_mask, structure=np.ones((1, 3, 3))).astype(x_mask.dtype) 267 | elif 'MVTEC' in self.train_generator.dataset.dir_datasets: 268 | x_mask = np.zeros((1, x.shape[-1], x.shape[-1])) 269 | x_mask[:, 14:-14, 14:-14] = 1 270 | 271 | # Get reconstruction error map 272 | z, z_mu, z_logvar, f = self.E(torch.tensor(x).cuda().float().unsqueeze(0)) 273 | xhat = torch.sigmoid(self.Dec(z)[0]).squeeze().detach().cpu().numpy() 274 | 275 | am = torch.mean(f[self.level_cams], 1) 276 | # Restore original shape 277 | mhat = torch.nn.functional.interpolate(am.unsqueeze(0), size=(self.input_shape[-1], self.input_shape[-1]), 278 | mode='bilinear', align_corners=True).squeeze().detach().cpu().numpy() 279 | 280 | # brain mask - Keep only brain region 281 | if 'BRATS' in self.train_generator.dataset.dir_datasets or \ 282 | 'PhysioNet' in self.train_generator.dataset.dir_datasets or \ 283 | 'MVTEC' in self.train_generator.dataset.dir_datasets: 284 | mhat[x_mask[0, :, :] == 0] = 0 285 | 286 | # Get outputs 287 | anomaly_map = mhat 288 | # brain mask - Keep only brain region 289 | if 'BRATS' in 
self.train_generator.dataset.dir_datasets or \ 290 | 'PhysioNet' in self.train_generator.dataset.dir_datasets or \ 291 | 'MVTEC' in self.train_generator.dataset.dir_datasets: 292 | score = np.std(anomaly_map[x_mask[0, :, :] == 1]) 293 | else: 294 | score = np.std(anomaly_map) 295 | 296 | self.E.train() 297 | self.Dec.train() 298 | return score, anomaly_map, xhat 299 | 300 | def display_losses(self, on_epoch_end=False): 301 | 302 | # Init info display 303 | info = "[INFO] Epoch {}/{} -- Step {}/{}: ".format(self.i_epoch + 1, self.epochs, 304 | self.i_iteration + 1, self.iterations) 305 | # Prepare values to show 306 | if on_epoch_end: 307 | lr = self.lr_epoch 308 | lkl = self.kl_epoch 309 | lH = self.H_epoch 310 | 311 | end = '\n' 312 | else: 313 | lr = self.lr_iteration 314 | lkl = self.kl_iteration 315 | lH = self.H_iteration 316 | 317 | end = '\r' 318 | 319 | # Init losses display 320 | info += "Reconstruction={:.4f} || KL={:.4f} || H={:.4f}".format(lr, lkl, lH) 321 | if self.train_generator.dataset.weak_supervision: 322 | info += " || H_a={:.4f}".format(lH_a) 323 | # Print losses 324 | et = str(datetime.timedelta(seconds=timer() - self.init_time)) 325 | print(info + ', ET=' + et, end=end) 326 | 327 | def plot_learning_curves(self): 328 | def plot_subplot(axes, x, y, y_axis): 329 | axes.grid() 330 | axes.plot(x, y, 'o-') 331 | axes.set_ylabel(y_axis) 332 | 333 | fig, axes = plt.subplots(2, 2, figsize=(20, 15)) 334 | plot_subplot(axes[0, 0], np.arange(self.i_epoch + 1) + 1, np.array(self.lr_lc), "Reconstruc loss") 335 | plot_subplot(axes[0, 1], np.arange(self.i_epoch + 1) + 1, np.array(self.lkl_lc), "KL loss") 336 | plot_subplot(axes[1, 0], np.arange(self.i_epoch + 1) + 1, np.array(self.H_lc), "H") 337 | plt.savefig(self.dir_results + 'learning_curve.png') 338 | plt.close() 339 | 340 | -------------------------------------------------------------------------------- /code/methods/trainers/ae.py: 
-------------------------------------------------------------------------------- 1 | import datetime 2 | import kornia 3 | import json 4 | import torch 5 | 6 | import pandas as pd 7 | import matplotlib.pyplot as plt 8 | 9 | from scipy import ndimage 10 | from timeit import default_timer as timer 11 | from datasets.utils import augment_input_batch 12 | from models.models import Encoder, Decoder 13 | from evaluation.utils import * 14 | 15 | 16 | class AnomalyDetectorAE: 17 | def __init__(self, dir_results, item=['flair'], zdim=32, lr=1*1e-4, input_shape=(1, 224, 224), epochs_to_test=25, 18 | load_weigths=False, n_blocks=5, dense=True, context=False, bayesian=False, 19 | loss_reconstruction='bce', restoration=False): 20 | 21 | # Init input variables 22 | self.dir_results = dir_results 23 | self.item = item 24 | self.zdim = zdim 25 | self.lr = lr 26 | self.input_shape = input_shape 27 | self.epochs_to_test = epochs_to_test 28 | self.load_weigths = load_weigths 29 | self.n_blocks = n_blocks 30 | self.dense = dense 31 | self.context = context 32 | self.bayesian = bayesian 33 | self.loss_reconstruction = loss_reconstruction 34 | self.restoration = restoration 35 | 36 | # Init network 37 | self.E = Encoder(fin=self.input_shape[0], zdim=self.zdim, dense=self.dense, n_blocks=self.n_blocks, 38 | spatial_dim=self.input_shape[1]//2**self.n_blocks) 39 | self.Dec = Decoder(fin=self.zdim, nf0=self.E.backbone.nfeats//2, n_channels=self.input_shape[0], 40 | dense=self.dense, n_blocks=self.n_blocks, spatial_dim=self.input_shape[1]//2**self.n_blocks) 41 | 42 | if torch.cuda.is_available(): 43 | self.E.cuda() 44 | self.Dec.cuda() 45 | 46 | if self.load_weigths: 47 | self.E.load_state_dict(torch.load(self.dir_results + '/encoder_weights.pth')) 48 | self.Dec.load_state_dict(torch.load(self.dir_results + '/decoder_weights.pth')) 49 | 50 | # Set parameters 51 | self.params = list(self.E.parameters()) + list(self.Dec.parameters()) 52 | 53 | # Set losses 54 | if self.loss_reconstruction == 
'l2': 55 | self.Lr = torch.nn.MSELoss(reduction='sum') 56 | elif self.loss_reconstruction == 'bce': 57 | self.Lr = torch.nn.BCEWithLogitsLoss(reduction='sum') 58 | 59 | # Set optimizers 60 | self.opt = torch.optim.Adam(self.params, lr=self.lr) 61 | 62 | # Init additional variables and objects 63 | self.epochs = 0. 64 | self.i_iteration = 0. 65 | self.iterations = 0. 66 | self.init_time = 0. 67 | self.lr_iteration = 0. 68 | self.lr_epoch = 0. 69 | self.i_epoch = 0. 70 | self.train_generator = [] 71 | self.dataset_test = [] 72 | self.metrics = {} 73 | self.aucroc_lc = [] 74 | self.auprc_lc = [] 75 | self.lr_lc = [] 76 | 77 | def train(self, train_generator, epochs, dataset_test): 78 | self.epochs = epochs 79 | self.init_time = timer() 80 | self.train_generator = train_generator 81 | self.dataset_test = dataset_test 82 | self.iterations = len(self.train_generator) 83 | 84 | # Loop over epochs 85 | for self.i_epoch in range(self.epochs): 86 | # init epoch losses 87 | self.lr_epoch = 0 88 | 89 | # Loop over training dataset 90 | for self.i_iteration, (x_n, y_n, _, _) in enumerate(self.train_generator): 91 | 92 | if self.context: # if context option, data augmentation to apply context 93 | (x_n_context, _) = augment_input_batch(x_n.copy()) 94 | x_n_context = torch.tensor(x_n_context).cuda().float() 95 | 96 | # Move tensors to gpu 97 | x_n = torch.tensor(x_n).cuda().float() 98 | 99 | # Obtain latent space from normal sample via encoder 100 | if not self.context: 101 | z, _, _, _ = self.E(x_n) 102 | else: 103 | z, _, _, _ = self.E(x_n_context) 104 | 105 | # Obtain reconstructed images through decoder 106 | xhat, _ = self.Dec(z) 107 | if self.loss_reconstruction == 'l2': 108 | xhat = torch.sigmoid(xhat) 109 | 110 | # Calculate criterion 111 | self.lr_iteration = self.Lr(xhat, x_n) / (self.train_generator.batch_size * self.input_shape[1] * 112 | self.input_shape[2]) # Reconstruction loss 113 | 114 | # Init overall losses 115 | L = self.lr_iteration 116 | 117 | # Update 
weights 118 | L.backward() # Backward 119 | self.opt.step() # Update weights 120 | self.opt.zero_grad() # Clear gradients 121 | 122 | """ 123 | ON ITERATION/EPOCH END PROCESS 124 | """ 125 | 126 | # Display losses per iteration 127 | self.display_losses(on_epoch_end=False) 128 | 129 | # Update epoch's losses 130 | self.lr_epoch += self.lr_iteration.cpu().detach().numpy() / len(self.train_generator) 131 | 132 | # Epoch-end processes 133 | self.on_epoch_end() 134 | 135 | def on_epoch_end(self): 136 | 137 | # Display losses 138 | self.display_losses(on_epoch_end=True) 139 | 140 | # Update learning curves 141 | self.lr_lc.append(self.lr_epoch) 142 | 143 | # Each x epochs, test models and plot learning curves 144 | if (self.i_epoch + 1) % self.epochs_to_test == 0: 145 | # Save weights 146 | torch.save(self.E.state_dict(), self.dir_results + 'encoder_weights.pth') 147 | torch.save(self.Dec.state_dict(), self.dir_results + 'decoder_weights.pth') 148 | 149 | # Make predictions 150 | Y, Scores, M, Mhat, X, Xhat = inference_dataset(self, self.dataset_test) 151 | 152 | # Input to dataset 153 | self.dataset_test.Scores = Scores 154 | self.dataset_test.Mhat = Mhat 155 | self.dataset_test.Xhat = Xhat 156 | 157 | # Evaluate 158 | metrics, th = evaluate_anomaly_localization(self.dataset_test, save_maps=True, dir_out=self.dir_results) 159 | self.metrics = metrics 160 | 161 | # Save metrics as dict 162 | with open(self.dir_results + 'metrics.json', 'w') as fp: 163 | json.dump(metrics, fp) 164 | print(metrics) 165 | 166 | # Plot learning curve 167 | self.plot_learning_curves() 168 | 169 | # Save learning curves as dataframe 170 | self.aucroc_lc.append(metrics['AU_ROC']) 171 | self.auprc_lc.append(metrics['AU_PRC']) 172 | history = pd.DataFrame(list(zip(self.lr_lc, self.aucroc_lc, self.auprc_lc)), 173 | columns=['Lrec', 'AUCROC', 'AUPRC']) 174 | history.to_csv(self.dir_results + 'lc_on_direct.csv') 175 | 176 | else: 177 | self.aucroc_lc.append(0) 178 | self.auprc_lc.append(0) 179 | 
180 | def predict_score(self, x): 181 | self.E.eval() 182 | self.Dec.eval() 183 | 184 | # brain mask 185 | if 'BRATS' in self.train_generator.dataset.dir_datasets or 'PhysioNet' in self.train_generator.dataset.dir_datasets: 186 | x_mask = 1 - (x == 0).astype(int) # builtin int: the np.int alias was removed in NumPy 1.24 187 | if 'BRATS' in self.train_generator.dataset.dir_datasets: 188 | x_mask = ndimage.binary_erosion(x_mask, structure=np.ones((1, 10, 10))).astype(x_mask.dtype) 189 | else: 190 | x_mask = ndimage.binary_erosion(x_mask, structure=np.ones((1, 3, 3))).astype(x_mask.dtype) 191 | elif 'MVTEC' in self.train_generator.dataset.dir_datasets: 192 | x_mask = np.zeros((1, x.shape[-1], x.shape[-1])) 193 | x_mask[:, 14:-14, 14:-14] = 1 194 | 195 | # Get reconstruction error map 196 | if self.restoration: # restoration reconstruction 197 | mhat, xhat = self.restoration_reconstruction(x) 198 | elif self.bayesian: # bayesian reconstruction 199 | mhat, xhat = self.bayesian_reconstruction(x) 200 | else: 201 | # Network forward 202 | z, z_mu, z_logvar, f = self.E(torch.tensor(x).cuda().float().unsqueeze(0)) 203 | xhat = np.squeeze(torch.sigmoid(self.Dec(z)[0]).cpu().detach().numpy()) 204 | # Compute anomaly map 205 | #mhat = np.squeeze(np.abs(x - xhat)) 206 | mhat = np.squeeze(x - xhat) 207 | 208 | # Keep only brain region 209 | mhat[x_mask[0, :, :] == 0] = 0 210 | 211 | # Get outputs 212 | anomaly_map = mhat 213 | score = np.mean(anomaly_map) 214 | 215 | self.E.train() 216 | self.Dec.train() 217 | return score, anomaly_map, xhat 218 | 219 | def bayesian_reconstruction(self, x): 220 | 221 | N = 100 222 | p_dropout = 0.20 223 | mhat = np.zeros((self.input_shape[1], self.input_shape[2])) 224 | 225 | # Network forward 226 | z, z_mu, z_logvar, f = self.E(torch.tensor(x).cuda().float().unsqueeze(0)) 227 | xhat = self.Dec(torch.nn.Dropout(p_dropout)(z))[0].cpu().detach().numpy() 228 | 229 | for i in np.arange(N): 230 | if z_mu is None: # apply dropout to z 231 | mhat += np.squeeze(np.abs(np.squeeze(torch.sigmoid( 232 | 
self.Dec(torch.nn.Dropout(p_dropout)(z))[0]).cpu().detach().numpy()) - x)) / N 233 | else: # sample z 234 | mhat += np.squeeze(np.abs(np.squeeze(torch.sigmoid( 235 | self.Dec(self.E.reparameterize(z_mu, z_logvar))[0]).cpu().detach().numpy()) - x)) / N 236 | return mhat, xhat 237 | 238 | def restoration_reconstruction(self, x): 239 | N = 300 240 | step = 1 * 1e-3 241 | x_rest = torch.tensor(x).cuda().float().unsqueeze(0) 242 | 243 | for i in np.arange(N): 244 | # Forward 245 | x_rest.requires_grad = True 246 | z, z_mu, z_logvar, f = self.E(x_rest) 247 | xhat = self.Dec(z)[0] 248 | 249 | # Compute loss 250 | lr = kornia.losses.total_variation(torch.tensor(x).cuda().float().unsqueeze(0) - torch.sigmoid(xhat)) 251 | L = lr / (self.input_shape[1] * self.input_shape[2]) # normalize by image area (self.input is never defined) 252 | 253 | # Get gradients 254 | gradients = torch.autograd.grad(L, x_rest, grad_outputs=None, retain_graph=True, 255 | create_graph=True, 256 | only_inputs=True, allow_unused=True)[0] 257 | 258 | x_rest = x_rest - gradients * step 259 | x_rest = x_rest.clone().detach() 260 | xhat = np.squeeze(x_rest.cpu().numpy()) 261 | 262 | # Compute difference 263 | mhat = np.squeeze(np.abs(x - xhat)) 264 | 265 | return mhat, xhat 266 | 267 | def display_losses(self, on_epoch_end=False): 268 | 269 | # Init info display 270 | info = "[INFO] Epoch {}/{} -- Step {}/{}: ".format(self.i_epoch + 1, self.epochs, 271 | self.i_iteration + 1, self.iterations) 272 | # Prepare values to show 273 | if on_epoch_end: 274 | lr = self.lr_epoch 275 | end = '\n' 276 | else: 277 | lr = self.lr_iteration 278 | end = '\r' 279 | # Init losses display 280 | info += "Reconstruction={:.4f}".format(lr) 281 | # Print losses 282 | et = str(datetime.timedelta(seconds=timer() - self.init_time)) 283 | print(info + ', ET=' + et, end=end) 284 | 285 | def plot_learning_curves(self): 286 | def plot_subplot(axes, x, y, y_axis): 287 | axes.grid() 288 | axes.plot(x, y, 'o-') 289 | axes.set_ylabel(y_axis) 290 | 291 | fig, axes = plt.subplots(1, 1, figsize=(20, 15)) 292 
| plot_subplot(axes, np.arange(self.i_epoch + 1) + 1, np.array(self.lr_lc), "Reconstruc loss") 293 | plt.savefig(self.dir_results + 'learning_curve.png') 294 | plt.close() -------------------------------------------------------------------------------- /code/methods/trainers/anoVAEGAN.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import kornia 3 | import json 4 | import torch 5 | 6 | import pandas as pd 7 | import matplotlib.pyplot as plt 8 | import torch.nn.functional as F 9 | 10 | from scipy import ndimage 11 | from timeit import default_timer as timer 12 | from models.models import Encoder, Decoder, Discriminator 13 | from evaluation.utils import * 14 | from methods.losses.losses import kl_loss 15 | from datasets.utils import augment_input_batch 16 | 17 | 18 | class AnomalyDetectorAnoVAEGAN: 19 | def __init__(self, dir_results, item=['flair'], zdim=32, lr=1*1e-4, input_shape=(1, 224, 224), epochs_to_test=25, 20 | load_weigths=False, n_blocks=5, dense=True, context=False, bayesian=False, 21 | loss_reconstruction='bce', restoration=False, alpha_kl=1, alpha_adversial=1): 22 | 23 | # Init input variables 24 | self.dir_results = dir_results 25 | self.item = item 26 | self.zdim = zdim 27 | self.lr = lr 28 | self.input_shape = input_shape 29 | self.epochs_to_test = epochs_to_test 30 | self.load_weigths = load_weigths 31 | self.n_blocks = n_blocks 32 | self.dense = dense 33 | self.context = context 34 | self.bayesian = bayesian 35 | self.loss_reconstruction = loss_reconstruction 36 | self.restoration = restoration 37 | self.alpha_kl = alpha_kl 38 | self.alpha_adversial = alpha_adversial 39 | self.normal_label = 1 40 | self.generated_label = 0 41 | 42 | # Init network 43 | self.E = Encoder(fin=self.input_shape[0], zdim=self.zdim, dense=self.dense, n_blocks=self.n_blocks, 44 | spatial_dim=self.input_shape[1]//2**self.n_blocks, variational=True, gap=False) 45 | self.Dec = Decoder(fin=self.zdim, 
nf0=self.E.backbone.nfeats//2, n_channels=self.input_shape[0], 46 | dense=self.dense, n_blocks=self.n_blocks, spatial_dim=self.input_shape[1]//2**self.n_blocks, 47 | gap=False) 48 | self.Disc = Discriminator(fin=self.E.backbone.nfeats//2**self.n_blocks, n_channels=input_shape[0], 49 | n_blocks=self.n_blocks) 50 | 51 | if torch.cuda.is_available(): 52 | self.E.cuda() 53 | self.Dec.cuda() 54 | self.Disc.cuda() 55 | 56 | if self.load_weigths: 57 | self.E.load_state_dict(torch.load(self.dir_results + '/encoder_weights.pth')) 58 | self.Dec.load_state_dict(torch.load(self.dir_results + '/decoder_weights.pth')) 59 | 60 | # Set parameters 61 | self.params = list(self.E.parameters()) + list(self.Dec.parameters()) 62 | 63 | # Set losses 64 | if self.loss_reconstruction == 'l2': 65 | self.Lr = torch.nn.MSELoss(reduction='sum') 66 | elif self.loss_reconstruction == 'bce': 67 | self.Lr = torch.nn.BCEWithLogitsLoss(reduction='sum') 68 | self.DLoss = torch.nn.BCEWithLogitsLoss() 69 | self.Lkl = kl_loss 70 | 71 | # Set optimizers 72 | self.opt = torch.optim.Adam(self.params, lr=self.lr) 73 | self.opt_Disc = torch.optim.Adam(self.Disc.parameters(), lr=self.lr * 10) 74 | 75 | # Init additional variables and objects 76 | self.epochs = 0. 77 | self.iterations = 0. 78 | self.init_time = 0. 79 | self.lr_iteration = 0. 80 | self.lr_epoch = 0. 81 | self.kl_iteration = 0. 82 | self.kl_epoch = 0. 83 | self.i_epoch = 0. 84 | self.train_generator = [] 85 | self.dataset_test = [] 86 | self.metrics = {} 87 | self.aucroc_lc = [] 88 | self.auprc_lc = [] 89 | self.lr_lc = [] 90 | self.lkl_lc = [] 91 | self.ldisc_epoch = 0. 92 | self.ldisc_iteration = 0. 93 | self.ladv_epoch = 0. 94 | self.ladv_iteration = 0. 
95 | self.ladv_lc = [] 96 | self.ldisc_lc = [] 97 | 98 | def train(self, train_generator, epochs, dataset_test): 99 | self.epochs = epochs 100 | self.init_time = timer() 101 | self.train_generator = train_generator 102 | self.dataset_test = dataset_test 103 | self.iterations = len(self.train_generator) 104 | 105 | # Loop over epochs 106 | for self.i_epoch in range(self.epochs): 107 | # init epoch losses 108 | self.lr_epoch = 0. 109 | self.kl_epoch = self.ladv_epoch = self.ldisc_epoch = 0. # reset every per-epoch accumulator, not only KL 110 | 111 | # Loop over training dataset 112 | for self.i_iteration, (x_n, y_n, _, _) in enumerate(self.train_generator): 113 | 114 | if self.context: # if context option, data augmentation to apply context 115 | (x_n_context, _) = augment_input_batch(x_n.copy()) 116 | x_n_context = torch.tensor(x_n_context).cuda().float() 117 | 118 | # Move tensors to gpu 119 | x_n = torch.tensor(x_n).cuda().float() 120 | 121 | # Obtain latent space from normal sample via encoder 122 | if not self.context: 123 | z, z_mu, z_logvar, _ = self.E(x_n) 124 | else: 125 | z, z_mu, z_logvar, _ = self.E(x_n_context) 126 | 127 | # Obtain reconstructed images through decoder 128 | xhat, _ = self.Dec(z) 129 | if self.loss_reconstruction == 'l2': 130 | xhat = torch.sigmoid(xhat) 131 | 132 | # Forward discriminator 133 | d_x, _ = self.Disc(x_n) 134 | d_xhat, _ = self.Disc(torch.sigmoid(xhat)) 135 | 136 | # Discriminator labels 137 | d_x_true = torch.tensor(self.normal_label * np.ones((self.train_generator.batch_size, 1))).cuda().float() 138 | d_f_false = torch.tensor(self.normal_label * np.ones((self.train_generator.batch_size, 1))).cuda().float() 139 | d_xhat_true = torch.tensor(self.generated_label * np.ones((self.train_generator.batch_size, 1))).cuda().float() 140 | 141 | # ------------D training------------------ 142 | self.ldisc_iteration = 0.5 * F.binary_cross_entropy(d_x, d_x_true) + 0.5 * F.binary_cross_entropy(d_xhat, d_xhat_true) 143 | 144 | self.opt_Disc.zero_grad() 145 | self.ldisc_iteration.backward(retain_graph=True) 146 | 
self.opt_Disc.step() 147 | 148 | # ------------Encoder and Decoder training------------------ 149 | # Discriminator prediction 150 | d_xhat, _ = self.Disc(torch.sigmoid(xhat)) 151 | 152 | # Calculate criterion 153 | self.lr_iteration = self.Lr(xhat, x_n) / self.train_generator.batch_size # Reconstruction loss 154 | self.kl_iteration = self.Lkl(mu=z_mu, logvar=z_logvar) # kl loss (averaged per spatial feature) 155 | self.ladv_iteration = F.binary_cross_entropy(d_xhat, d_f_false) # Adversial loss 156 | 157 | # Init overall losses 158 | L = self.lr_iteration + self.alpha_kl * self.kl_iteration + self.alpha_adversial * self.ladv_iteration 159 | 160 | # Update weights 161 | L.backward() # Backward 162 | self.opt.step() # Update weights 163 | self.opt.zero_grad() # Clear gradients 164 | 165 | """ 166 | ON ITERATION/EPOCH END PROCESS 167 | """ 168 | 169 | # Display losses per iteration 170 | self.display_losses(on_epoch_end=False) 171 | 172 | # Update epoch's losses 173 | self.lr_epoch += self.lr_iteration.cpu().detach().numpy() / len(self.train_generator) 174 | self.kl_epoch += self.kl_iteration.cpu().detach().numpy() / len(self.train_generator) 175 | self.ladv_epoch += self.ladv_iteration.cpu().detach().numpy() / len(self.train_generator) 176 | self.ldisc_epoch += self.ldisc_iteration.cpu().detach().numpy() / len(self.train_generator) 177 | 178 | # Epoch-end processes 179 | self.on_epoch_end() 180 | 181 | def on_epoch_end(self): 182 | 183 | # Display losses 184 | self.display_losses(on_epoch_end=True) 185 | 186 | # Update learning curves 187 | self.lr_lc.append(self.lr_epoch) 188 | self.lkl_lc.append(self.kl_epoch) 189 | self.ladv_lc.append(self.ladv_epoch) 190 | self.ldisc_lc.append(self.ldisc_epoch) 191 | 192 | # Each x epochs, test models and plot learning curves 193 | if (self.i_epoch + 1) % self.epochs_to_test == 0: 194 | # Save weights 195 | torch.save(self.E.state_dict(), self.dir_results + 'encoder_weights.pth') 196 | torch.save(self.Dec.state_dict(), 
self.dir_results + 'decoder_weights.pth') 197 | 198 | # Make predictions 199 | Y, Scores, M, Mhat, X, Xhat = inference_dataset(self, self.dataset_test) 200 | 201 | # Input to dataset 202 | self.dataset_test.Scores = Scores 203 | self.dataset_test.Mhat = Mhat 204 | self.dataset_test.Xhat = Xhat 205 | 206 | # Evaluate 207 | metrics, th = evaluate_anomaly_localization(self.dataset_test, save_maps=True, dir_out=self.dir_results) 208 | self.metrics = metrics 209 | 210 | # Save metrics as dict 211 | with open(self.dir_results + 'metrics.json', 'w') as fp: 212 | json.dump(metrics, fp) 213 | print(metrics) 214 | 215 | # Plot learning curve 216 | self.plot_learning_curves() 217 | 218 | # Save learning curves as dataframe 219 | self.aucroc_lc.append(metrics['AU_ROC']) 220 | self.auprc_lc.append(metrics['AU_PRC']) 221 | history = pd.DataFrame(list(zip(self.lr_lc, self.lkl_lc, self.aucroc_lc, self.auprc_lc)), 222 | columns=['Lrec', 'Lkl', 'AUCROC', 'AUPRC']) 223 | history.to_csv(self.dir_results + 'lc_on_direct.csv') 224 | 225 | else: 226 | self.aucroc_lc.append(0) 227 | self.auprc_lc.append(0) 228 | 229 | def predict_score(self, x): 230 | self.E.eval() 231 | self.Dec.eval() 232 | 233 | # Prepare brain eroded mask 234 | if 'BRATS' in self.train_generator.dataset.dir_datasets or 'PhysioNet' in self.train_generator.dataset.dir_datasets: 235 | x_mask = 1 - (x == 0).astype(int) # builtin int: the np.int alias was removed in NumPy 1.24 236 | if 'BRATS' in self.train_generator.dataset.dir_datasets: 237 | x_mask = ndimage.binary_erosion(x_mask, structure=np.ones((1, 10, 10))).astype(x_mask.dtype) 238 | else: 239 | x_mask = ndimage.binary_erosion(x_mask, structure=np.ones((1, 3, 3))).astype(x_mask.dtype) 240 | elif 'MVTEC' in self.train_generator.dataset.dir_datasets: 241 | x_mask = np.zeros((1, x.shape[-1], x.shape[-1])) 242 | x_mask[:, 14:-14, 14:-14] = 1 243 | 244 | # Get reconstruction error map 245 | if self.restoration: # restoration reconstruction 246 | mhat, xhat = self.restoration_reconstruction(x) 247 | elif self.bayesian: # 
bayesian reconstruction 248 | mhat, xhat = self.bayesian_reconstruction(x) 249 | else: 250 | # Network forward 251 | z, z_mu, z_logvar, f = self.E(torch.tensor(x).cuda().float().unsqueeze(0)) 252 | xhat = np.squeeze(torch.sigmoid(self.Dec(z)[0]).cpu().detach().numpy()) 253 | # Compute anomaly map 254 | mhat = np.squeeze(np.abs(x - xhat)) 255 | 256 | # Keep only brain region 257 | mhat[x_mask[0, :, :] == 0] = 0 258 | 259 | # Get outputs 260 | anomaly_map = mhat 261 | score = np.mean(anomaly_map) 262 | 263 | self.E.train() 264 | self.Dec.train() 265 | return score, anomaly_map, xhat 266 | 267 | def bayesian_reconstruction(self, x): 268 | 269 | N = 100 270 | p_dropout = 0.20 271 | mhat = np.zeros((self.input_shape[1], self.input_shape[2])) 272 | 273 | # Network forward 274 | z, z_mu, z_logvar, f = self.E(torch.tensor(x).cuda().float().unsqueeze(0)) 275 | xhat = self.Dec(torch.nn.Dropout(p_dropout)(z))[0].cpu().detach().numpy() 276 | 277 | for i in np.arange(N): 278 | if z_mu is None: # apply dropout to z 279 | mhat += np.squeeze(np.abs(np.squeeze(torch.sigmoid( 280 | self.Dec(torch.nn.Dropout(p_dropout)(z))[0]).cpu().detach().numpy()) - x)) / N 281 | else: # sample z 282 | mhat += np.squeeze(np.abs(np.squeeze(torch.sigmoid( 283 | self.Dec(self.E.reparameterize(z_mu, z_logvar))[0]).cpu().detach().numpy()) - x)) / N 284 | return mhat, xhat 285 | 286 | def restoration_reconstruction(self, x): 287 | N = 300 288 | step = 1 * 1e-3 289 | x_rest = torch.tensor(x).cuda().float().unsqueeze(0) 290 | 291 | for i in np.arange(N): 292 | # Forward 293 | x_rest.requires_grad = True 294 | z, z_mu, z_logvar, f = self.E(x_rest) 295 | xhat = self.Dec(z)[0] 296 | 297 | # Compute loss 298 | lr = kornia.losses.total_variation(torch.tensor(x).cuda().float().unsqueeze(0) - torch.sigmoid(xhat)) 299 | L = lr / (self.input_shape[1] * self.input_shape[2]) 300 | 301 | # Get gradients 302 | gradients = torch.autograd.grad(L, x_rest, grad_outputs=None, retain_graph=True, 303 | create_graph=True, 304 
| only_inputs=True, allow_unused=True)[0] 305 | 306 | x_rest = x_rest - gradients * step 307 | x_rest = x_rest.clone().detach() 308 | xhat = np.squeeze(x_rest.cpu().numpy()) 309 | 310 | # Compute difference 311 | mhat = np.squeeze(np.abs(x - xhat)) 312 | 313 | return mhat, xhat 314 | 315 | def display_losses(self, on_epoch_end=False): 316 | 317 | # Init info display 318 | info = "[INFO] Epoch {}/{} -- Step {}/{}: ".format(self.i_epoch + 1, self.epochs, 319 | self.i_iteration + 1, self.iterations) 320 | # Prepare values to show 321 | if on_epoch_end: 322 | lr = self.lr_epoch 323 | ladv = self.ladv_epoch 324 | ldisc = self.ldisc_epoch 325 | lkl = self.kl_epoch 326 | end = '\n' 327 | else: 328 | lr = self.lr_iteration 329 | ladv = self.ladv_iteration 330 | ldisc = self.ldisc_iteration 331 | lkl = self.kl_iteration 332 | end = '\r' 333 | # Init losses display 334 | info += "Reconstruction={:.4f} || KL={:.4f} || ladv={:.4f} || ldisc={:.4f}".format(lr, lkl, ladv, ldisc) 335 | # Print losses 336 | et = str(datetime.timedelta(seconds=timer() - self.init_time)) 337 | print(info + ', ET=' + et, end=end) 338 | 339 | def plot_learning_curves(self): 340 | def plot_subplot(axes, x, y, y_axis): 341 | axes.grid() 342 | axes.plot(x, y, 'o-') 343 | axes.set_ylabel(y_axis) 344 | 345 | fig, axes = plt.subplots(2, 2, figsize=(20, 15)) 346 | plot_subplot(axes[0, 0], np.arange(self.i_epoch + 1) + 1, np.array(self.lr_lc), "Reconstruc loss") 347 | plot_subplot(axes[0, 1], np.arange(self.i_epoch + 1) + 1, np.array(self.lkl_lc), "KL loss") 348 | plot_subplot(axes[1, 0], np.arange(self.i_epoch + 1) + 1, np.array(self.ladv_lc), "Generator loss") 349 | plot_subplot(axes[1, 1], np.arange(self.i_epoch + 1) + 1, np.array(self.ldisc_lc), "Discriminator loss") 350 | plt.savefig(self.dir_results + 'learning_curve.png') 351 | plt.close() -------------------------------------------------------------------------------- /code/methods/trainers/fanoGAN.py: 
-------------------------------------------------------------------------------- 1 | import datetime 2 | import kornia 3 | import json 4 | import torch 5 | 6 | import pandas as pd 7 | import matplotlib.pyplot as plt 8 | import torch.nn.functional as F 9 | 10 | from scipy import ndimage 11 | from timeit import default_timer as timer 12 | from models.models import Encoder, Decoder, Discriminator 13 | from evaluation.utils import * 14 | from methods.losses.losses import kl_loss 15 | from datasets.utils import augment_input_batch 16 | 17 | 18 | class AnomalyDetectorFanoGAN: 19 | def __init__(self, dir_results, item=['flair'], zdim=32, lr=1*1e-4, input_shape=(1, 224, 224), epochs_to_test=25, 20 | load_weigths=False, n_blocks=4, dense=True, context=False, bayesian=False, 21 | loss_reconstruction='bce', restoration=False, alpha_adversial=1): 22 | 23 | # Init input variables 24 | self.dir_results = dir_results 25 | self.item = item 26 | self.zdim = zdim 27 | self.lr = lr 28 | self.input_shape = input_shape 29 | self.epochs_to_test = epochs_to_test 30 | self.load_weigths = load_weigths 31 | self.n_blocks = n_blocks 32 | self.dense = dense 33 | self.context = context 34 | self.bayesian = bayesian 35 | self.loss_reconstruction = loss_reconstruction 36 | self.restoration = restoration 37 | self.alpha_adversial = alpha_adversial 38 | self.normal_label = 1 39 | self.generated_label = 0 40 | 41 | # Init network 42 | self.E = Encoder(fin=self.input_shape[0], zdim=self.zdim, dense=self.dense, n_blocks=self.n_blocks, 43 | spatial_dim=self.input_shape[1]//2**self.n_blocks, variational=False, gap=False) 44 | self.Dec = Decoder(fin=self.zdim, nf0=self.E.backbone.nfeats//2, n_channels=self.input_shape[0], 45 | dense=self.dense, n_blocks=self.n_blocks, spatial_dim=self.input_shape[1]//2**self.n_blocks, 46 | gap=False) 47 | self.Disc = Discriminator(fin=self.E.backbone.nfeats//2**self.n_blocks, n_channels=input_shape[0], 48 | n_blocks=self.n_blocks) 49 | 50 | if torch.cuda.is_available(): 
51 | self.E.cuda() 52 | self.Dec.cuda() 53 | self.Disc.cuda() 54 | 55 | if self.load_weigths: 56 | self.E.load_state_dict(torch.load(self.dir_results + '/encoder_weights.pth')) 57 | self.Dec.load_state_dict(torch.load(self.dir_results + '/decoder_weights.pth')) 58 | 59 | # Set parameters 60 | self.params = list(self.E.parameters()) + list(self.Dec.parameters()) 61 | 62 | # Set losses 63 | if self.loss_reconstruction == 'l2': 64 | self.Lr = torch.nn.MSELoss(reduction='sum') 65 | elif self.loss_reconstruction == 'bce': 66 | self.Lr = torch.nn.BCEWithLogitsLoss(reduction='sum') 67 | self.DLoss = torch.nn.BCEWithLogitsLoss() 68 | 69 | # Set optimizers 70 | self.opt = torch.optim.Adam(self.params, lr=self.lr) 71 | self.opt_Disc = torch.optim.Adam(self.Disc.parameters(), lr=self.lr * 10) 72 | 73 | # Init additional variables and objects 74 | self.epochs = 0. 75 | self.iterations = 0. 76 | self.init_time = 0. 77 | self.lr_iteration = 0. 78 | self.lr_epoch = 0. 79 | self.i_epoch = 0. 80 | self.train_generator = [] 81 | self.dataset_test = [] 82 | self.metrics = {} 83 | self.aucroc_lc = [] 84 | self.auprc_lc = [] 85 | self.lr_lc = [] 86 | self.ldisc_epoch = 0. 87 | self.ldisc_iteration = 0. 88 | self.ladv_epoch = 0. 89 | self.ladv_iteration = 0. 90 | self.ladv_lc = [] 91 | self.ldisc_lc = [] 92 | 93 | def train(self, train_generator, epochs, dataset_test): 94 | self.epochs = epochs 95 | self.init_time = timer() 96 | self.train_generator = train_generator 97 | self.dataset_test = dataset_test 98 | self.iterations = len(self.train_generator) 99 | 100 | # Loop over epochs 101 | for self.i_epoch in range(self.epochs): 102 | # init epoch losses 103 | self.lr_epoch = 0. 104 | self.ldisc_epoch = 0. 105 | self.ladv_epoch = 0. 
106 | 107 | # Loop over training dataset 108 | for self.i_iteration, (x_n, y_n, _, _) in enumerate(self.train_generator): 109 | 110 | if self.context: # if context option, data augmentation to apply context 111 | (x_n_context, _) = augment_input_batch(x_n.copy()) 112 | x_n_context = torch.tensor(x_n_context).cuda().float() 113 | 114 | # Move tensors to gpu 115 | x_n = torch.tensor(x_n).cuda().float() 116 | 117 | # Obtain latent space from normal sample via encoder 118 | if not self.context: 119 | z, _, _, _ = self.E(x_n) 120 | else: 121 | z, _, _, _ = self.E(x_n_context) 122 | 123 | # Obtain reconstructed images through decoder 124 | xhat, _ = self.Dec(z) 125 | if self.loss_reconstruction == 'l2': 126 | xhat = torch.sigmoid(xhat) 127 | 128 | # Forward discriminator 129 | d_x, _ = self.Disc(x_n) 130 | d_xhat, _ = self.Disc(torch.sigmoid(xhat)) 131 | 132 | # Discriminator labels 133 | d_x_true = torch.tensor(self.normal_label * np.ones((self.train_generator.batch_size, 1))).cuda().float() 134 | d_f_false = torch.tensor(self.normal_label * np.ones((self.train_generator.batch_size, 1))).cuda().float() 135 | d_xhat_true = torch.tensor(self.generated_label * np.ones((self.train_generator.batch_size, 1))).cuda().float() 136 | 137 | # ------------D training------------------ 138 | self.ldisc_iteration = 0.5 * F.binary_cross_entropy(d_x, d_x_true) + 0.5 * F.binary_cross_entropy(d_xhat, d_xhat_true) 139 | 140 | self.opt_Disc.zero_grad() 141 | self.ldisc_iteration.backward(retain_graph=True) 142 | self.opt_Disc.step() 143 | 144 | # ------------Encoder and Decoder training------------------ 145 | # Discriminator prediction 146 | _, f_x = self.Disc(x_n) 147 | d_xhat, f_xhat = self.Disc(torch.sigmoid(xhat)) 148 | 149 | # Calculate criterion 150 | self.lr_iteration = self.Lr(xhat, x_n) / self.train_generator.batch_size # Reconstruction loss 151 | self.ladv_iteration = torch.mean(torch.pow(f_x[-1] - f_xhat[-1], 2)) # Feature matching loss 152 | 153 | # Init overall losses 154 | L 
= self.lr_iteration + self.alpha_adversial * self.ladv_iteration 155 | 156 | # Update weights 157 | L.backward() # Backward 158 | self.opt.step() # Update weights 159 | self.opt.zero_grad() # Clear gradients 160 | 161 | """ 162 | ON ITERATION/EPOCH END PROCESS 163 | """ 164 | 165 | # Display losses per iteration 166 | self.display_losses(on_epoch_end=False) 167 | 168 | # Update epoch's losses 169 | self.lr_epoch += self.lr_iteration.cpu().detach().numpy() / len(self.train_generator) 170 | self.ladv_epoch += self.ladv_iteration.cpu().detach().numpy() / len(self.train_generator) 171 | self.ldisc_epoch += self.ldisc_iteration.cpu().detach().numpy() / len(self.train_generator) 172 | 173 | # Epoch-end processes 174 | self.on_epoch_end() 175 | 176 | def on_epoch_end(self): 177 | 178 | # Display losses 179 | self.display_losses(on_epoch_end=True) 180 | 181 | # Update learning curves 182 | self.lr_lc.append(self.lr_epoch) 183 | self.ladv_lc.append(self.ladv_epoch) 184 | self.ldisc_lc.append(self.ldisc_epoch) 185 | 186 | # Each x epochs, test models and plot learning curves 187 | if (self.i_epoch + 1) % self.epochs_to_test == 0: 188 | # Save weights 189 | torch.save(self.E.state_dict(), self.dir_results + 'encoder_weights.pth') 190 | torch.save(self.Dec.state_dict(), self.dir_results + 'decoder_weights.pth') 191 | 192 | # Make predictions 193 | Y, Scores, M, Mhat, X, Xhat = inference_dataset(self, self.dataset_test) 194 | 195 | # Input to dataset 196 | self.dataset_test.Scores = Scores 197 | self.dataset_test.Mhat = Mhat 198 | self.dataset_test.Xhat = Xhat 199 | 200 | # Evaluate 201 | metrics, th = evaluate_anomaly_localization(self.dataset_test, save_maps=True, dir_out=self.dir_results) 202 | self.metrics = metrics 203 | 204 | # Save metrics as dict 205 | with open(self.dir_results + 'metrics.json', 'w') as fp: 206 | json.dump(metrics, fp) 207 | print(metrics) 208 | 209 | # Plot learning curve 210 | self.plot_learning_curves() 211 | 212 | # Save learning curves as 
dataframe 213 | self.aucroc_lc.append(metrics['AU_ROC']) 214 | self.auprc_lc.append(metrics['AU_PRC']) 215 | history = pd.DataFrame(list(zip(self.lr_lc, self.ladv_lc, self.aucroc_lc, self.auprc_lc)), 216 | columns=['Lrec', 'Ladv', 'AUCROC', 'AUPRC']) 217 | history.to_csv(self.dir_results + 'lc_on_direct.csv') 218 | 219 | else: 220 | self.aucroc_lc.append(0) 221 | self.auprc_lc.append(0) 222 | 223 | def predict_score(self, x): 224 | self.E.eval() 225 | self.Dec.eval() 226 | 227 | # Prepare brain eroded mask 228 | if 'BRATS' in self.train_generator.dataset.dir_datasets or 'PhysioNet' in self.train_generator.dataset.dir_datasets: 229 | x_mask = 1 - (x == 0).astype(int) 230 | if 'BRATS' in self.train_generator.dataset.dir_datasets: 231 | x_mask = ndimage.binary_erosion(x_mask, structure=np.ones((1, 10, 10))).astype(x_mask.dtype) 232 | else: 233 | x_mask = ndimage.binary_erosion(x_mask, structure=np.ones((1, 3, 3))).astype(x_mask.dtype) 234 | elif 'MVTEC' in self.train_generator.dataset.dir_datasets: 235 | x_mask = np.zeros((1, x.shape[-1], x.shape[-1])) 236 | x_mask[:, 14:-14, 14:-14] = 1 237 | 238 | # Get reconstruction error map 239 | if self.restoration: # restoration reconstruction 240 | mhat, xhat = self.restoration_reconstruction(x) 241 | elif self.bayesian: # bayesian reconstruction 242 | mhat, xhat = self.bayesian_reconstruction(x) 243 | else: 244 | # Network forward 245 | z, z_mu, z_logvar, f = self.E(torch.tensor(x).cuda().float().unsqueeze(0)) 246 | xhat = np.squeeze(torch.sigmoid(self.Dec(z)[0]).cpu().detach().numpy()) 247 | # Compute anomaly map 248 | mhat = np.squeeze(np.abs(x - xhat)) 249 | 250 | # Keep only brain region 251 | mhat[x_mask[0, :, :] == 0] = 0 252 | 253 | # Get outputs 254 | anomaly_map = mhat 255 | score = np.mean(anomaly_map) 256 | 257 | self.E.train() 258 | self.Dec.train() 259 | return score, anomaly_map, xhat 260 | 261 | def bayesian_reconstruction(self, x): 262 | 263 | N = 100 264 | p_dropout = 0.20 265 | mhat = 
np.zeros((self.input_shape[1], self.input_shape[2])) 266 | 267 | # Network forward 268 | z, z_mu, z_logvar, f = self.E(torch.tensor(x).cuda().float().unsqueeze(0)) 269 | xhat = self.Dec(torch.nn.Dropout(p_dropout)(z))[0].cpu().detach().numpy() 270 | 271 | for i in np.arange(N): 272 | if z_mu is None: # apply dropout to z 273 | mhat += np.squeeze(np.abs(np.squeeze(torch.sigmoid( 274 | self.Dec(torch.nn.Dropout(p_dropout)(z))[0]).cpu().detach().numpy()) - x)) / N 275 | else: # sample z 276 | mhat += np.squeeze(np.abs(np.squeeze(torch.sigmoid( 277 | self.Dec(self.E.reparameterize(z_mu, z_logvar))[0]).cpu().detach().numpy()) - x)) / N 278 | return mhat, xhat 279 | 280 | def restoration_reconstruction(self, x): 281 | N = 300 282 | step = 1 * 1e-3 283 | x_rest = torch.tensor(x).cuda().float().unsqueeze(0) 284 | 285 | for i in np.arange(N): 286 | # Forward 287 | x_rest.requires_grad = True 288 | z, z_mu, z_logvar, f = self.E(x_rest) 289 | xhat = self.Dec(z)[0] 290 | 291 | # Compute loss 292 | lr = kornia.losses.total_variation(torch.tensor(x).cuda().float().unsqueeze(0) - torch.sigmoid(xhat)) 293 | L = lr / (self.input_shape[1] * self.input_shape[2]) 294 | 295 | # Get gradients 296 | gradients = torch.autograd.grad(L, x_rest, grad_outputs=None, retain_graph=True, 297 | create_graph=True, 298 | only_inputs=True, allow_unused=True)[0] 299 | 300 | x_rest = x_rest - gradients * step 301 | x_rest = x_rest.clone().detach() 302 | xhat = np.squeeze(x_rest.cpu().numpy()) 303 | 304 | # Compute difference 305 | mhat = np.squeeze(np.abs(x - xhat)) 306 | 307 | return mhat, xhat 308 | 309 | def display_losses(self, on_epoch_end=False): 310 | 311 | # Init info display 312 | info = "[INFO] Epoch {}/{} -- Step {}/{}: ".format(self.i_epoch + 1, self.epochs, 313 | self.i_iteration + 1, self.iterations) 314 | # Prepare values to show 315 | if on_epoch_end: 316 | lr = self.lr_epoch 317 | ladv = self.ladv_epoch 318 | ldisc = self.ldisc_epoch 319 | end = '\n' 320 | else: 321 | lr = 
self.lr_iteration 322 | ladv = self.ladv_iteration 323 | ldisc = self.ldisc_iteration 324 | end = '\r' 325 | # Init losses display 326 | info += "Reconstruction={:.4f} || ladv={:.4f} || ldisc={:.4f}".format(lr, ladv, ldisc) 327 | # Print losses 328 | et = str(datetime.timedelta(seconds=timer() - self.init_time)) 329 | print(info + ', ET=' + et, end=end) 330 | 331 | def plot_learning_curves(self): 332 | def plot_subplot(axes, x, y, y_axis): 333 | axes.grid() 334 | axes.plot(x, y, 'o-') 335 | axes.set_ylabel(y_axis) 336 | 337 | fig, axes = plt.subplots(2, 2, figsize=(20, 15)) 338 | plot_subplot(axes[0, 0], np.arange(self.i_epoch + 1) + 1, np.array(self.lr_lc), "Reconstruc loss") 339 | plot_subplot(axes[1, 0], np.arange(self.i_epoch + 1) + 1, np.array(self.ladv_lc), "Generator loss") 340 | plot_subplot(axes[0, 1], np.arange(self.i_epoch + 1) + 1, np.array(self.ldisc_lc), "Discriminator loss") 341 | plt.savefig(self.dir_results + 'learning_curve.png') 342 | plt.close() -------------------------------------------------------------------------------- /code/methods/trainers/gradCAMCons.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | import datetime 5 | import kornia 6 | import json 7 | import torch 8 | 9 | import pandas as pd 10 | import matplotlib.pyplot as plt 11 | 12 | from scipy import ndimage 13 | from timeit import default_timer as timer 14 | from models.models import Encoder, Decoder 15 | from evaluation.utils import * 16 | from methods.losses.losses import kl_loss 17 | from methods.losses.losses import log_barrier 18 | from sklearn.metrics import accuracy_score, f1_score 19 | 20 | 21 | class AnomalyDetectorGradCamCons: 22 | def __init__(self, dir_results, item=['flair'], zdim=32, lr=1 * 1e-4, input_shape=(1, 224, 224), epochs_to_test=25, 23 | load_weigths=False, n_blocks=5, dense=True, loss_reconstruction='bce', alpha_ae=0, alpha_kl=1, 24 | pre_training_epochs=0, level_cams=-4, 
t=25, p_activation_cam=0.2, 25 | expansion_loss_penalty='l2', avg_grads=True, gap=False): 26 | 27 | # Init input variables 28 | self.dir_results = dir_results 29 | self.item = item 30 | self.zdim = zdim 31 | self.lr = lr 32 | self.input_shape = input_shape 33 | self.epochs_to_test = epochs_to_test 34 | self.load_weigths = load_weigths 35 | self.n_blocks = n_blocks 36 | self.dense = dense 37 | self.loss_reconstruction = loss_reconstruction 38 | self.alpha_kl = alpha_kl 39 | self.pre_training_epochs = pre_training_epochs 40 | self.level_cams = level_cams 41 | self.t = t 42 | self.p_activation_cam = p_activation_cam 43 | self.expansion_loss_penalty = expansion_loss_penalty 44 | self.alpha_ae = alpha_ae 45 | self.avg_grads = avg_grads 46 | self.gap = gap 47 | self.scheduler = False 48 | 49 | # Init network 50 | self.E = Encoder(fin=self.input_shape[0], zdim=self.zdim, dense=self.dense, n_blocks=self.n_blocks, 51 | spatial_dim=self.input_shape[1] // 2 ** self.n_blocks, variational=True, gap=gap) 52 | self.Dec = Decoder(fin=self.zdim, nf0=self.E.backbone.nfeats // 2, n_channels=self.input_shape[0], 53 | dense=self.dense, n_blocks=self.n_blocks, 54 | spatial_dim=self.input_shape[1] // 2 ** self.n_blocks, 55 | gap=gap) 56 | 57 | if torch.cuda.is_available(): 58 | self.E.cuda() 59 | self.Dec.cuda() 60 | 61 | if self.load_weigths: 62 | self.E.load_state_dict(torch.load(self.dir_results + 'encoder_weights.pth')) 63 | self.Dec.load_state_dict(torch.load(self.dir_results + 'decoder_weights.pth')) 64 | 65 | # Set parameters 66 | self.params = list(self.E.parameters()) + list(self.Dec.parameters()) 67 | 68 | # Set losses 69 | if self.loss_reconstruction == 'l2': 70 | self.Lr = torch.nn.MSELoss(reduction='sum') 71 | elif self.loss_reconstruction == 'bce': 72 | self.Lr = torch.nn.BCEWithLogitsLoss(reduction='sum') 73 | 74 | self.Lkl = kl_loss 75 | 76 | # Set optimizers 77 | self.opt = torch.optim.Adam(self.params, lr=self.lr) 78 | 79 | # Init additional variables and objects 80 | 
self.epochs = 0. 81 | self.iterations = 0. 82 | self.init_time = 0. 83 | self.lr_iteration = 0. 84 | self.lr_epoch = 0. 85 | self.kl_iteration = 0. 86 | self.kl_epoch = 0. 87 | self.i_epoch = 0. 88 | self.lae_iteration = 0. 89 | self.lae_epoch = 0. 90 | self.train_generator = [] 91 | self.dataset_test = [] 92 | self.metrics = {} 93 | self.aucroc_lc = [] 94 | self.auprc_lc = [] 95 | self.auroc_det = [] 96 | self.lr_lc = [] 97 | self.lkl_lc = [] 98 | self.lae_lc = [] 99 | self.auroc_det_lc = [] 100 | self.refCam = 0. 101 | 102 | def train(self, train_generator, epochs, test_dataset): 103 | self.epochs = epochs 104 | self.init_time = timer() 105 | self.train_generator = train_generator 106 | self.dataset_test = test_dataset 107 | self.iterations = len(self.train_generator) 108 | 109 | # Loop over epochs 110 | for self.i_epoch in range(self.epochs): 111 | # init epoch losses 112 | self.lr_epoch = 0 113 | self.kl_epoch = 0. 114 | self.lae_epoch = 0. 115 | 116 | if self.scheduler and self.expansion_loss_penalty == 'log_barrier' and self.i_epoch > self.pre_training_epochs: 117 | self.t = self.t * 1.01 118 | print(self.t, end='\n') 119 | 120 | # Loop over training dataset 121 | for self.i_iteration, (x_n, y_n, x_a, y_a) in enumerate(self.train_generator): 122 | # p = q 123 | 124 | # Move tensors to gpu 125 | x_n = torch.tensor(x_n).cuda().float() 126 | 127 | # Obtain latent space from normal sample via encoder 128 | z, z_mu, z_logvar, allF = self.E(x_n) 129 | 130 | # Obtain reconstructed images through decoder 131 | xhat, _ = self.Dec(z) 132 | if self.loss_reconstruction == 'l2': 133 | xhat = torch.sigmoid(xhat) 134 | 135 | # Calculate criterion 136 | self.lr_iteration = self.Lr(xhat, x_n) / (self.train_generator.batch_size) # Reconstruction loss 137 | self.kl_iteration = self.Lkl(mu=z_mu, logvar=z_logvar) / (self.train_generator.batch_size) # kl loss 138 | 139 | # Init overall losses 140 | L = self.lr_iteration + self.alpha_kl * self.kl_iteration 141 | 142 | # ---- 
Compute Attention expansion loss 143 | 144 | # Compute grad-cams 145 | gcam = grad_cam(allF[self.level_cams], torch.sum(z_mu), normalization='sigm', 146 | avg_grads=True) 147 | # Restore original shape 148 | gcam = torch.nn.functional.interpolate(gcam.unsqueeze(1), 149 | size=(self.input_shape[-1], self.input_shape[-1]), 150 | mode='bilinear', 151 | align_corners=True).squeeze() 152 | self.lae_iteration = torch.mean(gcam) 153 | 154 | if self.i_epoch > self.pre_training_epochs: 155 | # Compute attention expansion loss 156 | if self.expansion_loss_penalty == 'l1': # L1 157 | lae = torch.mean(torch.abs(-torch.mean(gcam, (-1)) + 1 - self.p_activation_cam)) 158 | elif self.expansion_loss_penalty == 'l2': # L2 159 | lae = torch.mean(torch.sqrt(torch.pow(-torch.mean(gcam, (-1)) + 1 - self.p_activation_cam, 2))) 160 | elif self.expansion_loss_penalty == 'log_barrier': 161 | z = -torch.mean(gcam, (1, 2)).unsqueeze(-1) + 1 162 | lae = log_barrier(z - self.p_activation_cam, t=self.t) / self.train_generator.batch_size 163 | 164 | # Update overall losses 165 | L += self.alpha_ae * lae.squeeze() 166 | 167 | # Update weights 168 | L.backward() # Backward 169 | self.opt.step() # Update weights 170 | self.opt.zero_grad() # Clear gradients 171 | 172 | """ 173 | ON ITERATION/EPOCH END PROCESS 174 | """ 175 | 176 | # Display losses per iteration 177 | self.display_losses(on_epoch_end=False) 178 | 179 | # Update epoch's losses 180 | self.lr_epoch += self.lr_iteration.cpu().detach().numpy() / len(self.train_generator) 181 | self.kl_epoch += self.kl_iteration.cpu().detach().numpy() / len(self.train_generator) 182 | self.lae_epoch += self.lae_iteration.cpu().detach().numpy() / len(self.train_generator) 183 | 184 | # Epoch-end processes 185 | self.on_epoch_end() 186 | 187 | def on_epoch_end(self): 188 | 189 | # Display losses 190 | self.display_losses(on_epoch_end=True) 191 | 192 | # Update learning curves 193 | self.lr_lc.append(self.lr_epoch) 194 | self.lkl_lc.append(self.kl_epoch) 195 | 
self.lae_lc.append(self.lae_epoch) 196 | 197 | # Each x epochs, test models and plot learning curves 198 | if (self.i_epoch + 1) % self.epochs_to_test == 0: 199 | # Save weights 200 | torch.save(self.E.state_dict(), self.dir_results + 'encoder_weights.pth') 201 | torch.save(self.Dec.state_dict(), self.dir_results + 'decoder_weights.pth') 202 | 203 | # Evaluate 204 | if self.i_epoch > (0): 205 | # Make predictions 206 | Y, Scores, M, Mhat, X, Xhat = inference_dataset(self, self.dataset_test) 207 | 208 | # Input to dataset 209 | self.dataset_test.Scores = Scores 210 | self.dataset_test.Mhat = Mhat 211 | self.dataset_test.Xhat = Xhat 212 | 213 | # Evaluate anomaly detection 214 | auroc_det, auprc_det, th_det = evaluate_anomaly_detection(self.dataset_test.Y, self.dataset_test.Scores, 215 | dir_out=self.dir_results, 216 | range=[np.min(Scores) - np.std(Scores), 217 | np.max(Scores) + np.std(Scores)], 218 | tit='kl') 219 | acc = accuracy_score(np.ravel(Y), np.ravel((Scores > th_det)).astype('int')) 220 | fs = f1_score(np.ravel(Y), np.ravel((Scores > th_det)).astype('int')) 221 | 222 | metrics_detection = {'auroc_det': auroc_det, 'auprc_det': auprc_det, 'th_det': th_det, 'acc_det': acc, 223 | 'fs_det': fs} 224 | print(metrics_detection) 225 | 226 | # Evaluate anomaly localization 227 | metrics, th = evaluate_anomaly_localization(self.dataset_test, save_maps=True, dir_out=self.dir_results) 228 | self.metrics = metrics 229 | 230 | # Save metrics as dict 231 | with open(self.dir_results + 'metrics.json', 'w') as fp: 232 | json.dump(metrics, fp) 233 | print(metrics) 234 | 235 | # Plot learning curve 236 | self.plot_learning_curves() 237 | 238 | # Save learning curves as dataframe 239 | self.aucroc_lc.append(metrics['AU_ROC']) 240 | self.auprc_lc.append(metrics['AU_PRC']) 241 | self.auroc_det_lc.append(auroc_det) 242 | history = pd.DataFrame( 243 | list(zip(self.lr_lc, self.lkl_lc, self.lae_lc, self.aucroc_lc, self.auprc_lc, self.auroc_det_lc)), 244 | columns=['Lrec', 'Lkl', 
'lae', 'AUCROC', 'AUPRC', 'AUROC_det']) 245 | history.to_csv(self.dir_results + 'lc_on_direct.csv') 246 | 247 | else: 248 | self.aucroc_lc.append(0) 249 | self.auprc_lc.append(0) 250 | self.auroc_det_lc.append(0) 251 | 252 | def predict_score(self, x): 253 | self.E.eval() 254 | self.Dec.eval() 255 | 256 | # brain mask 257 | if 'BRATS' in self.train_generator.dataset.dir_datasets or 'PhysioNet' in self.train_generator.dataset.dir_datasets: 258 | x_mask = 1 - (x == 0).astype(int) 259 | if 'BRATS' in self.train_generator.dataset.dir_datasets: 260 | x_mask = ndimage.binary_erosion(x_mask, structure=np.ones((1, 6, 6))).astype(x_mask.dtype) 261 | else: 262 | x_mask = ndimage.binary_erosion(x_mask, structure=np.ones((1, 3, 3))).astype(x_mask.dtype) 263 | elif 'MVTEC' in self.train_generator.dataset.dir_datasets: 264 | x_mask = np.zeros((1, x.shape[-1], x.shape[-1])) 265 | x_mask[:, 14:-14, 14:-14] = 1 266 | 267 | # Get reconstruction error map 268 | z, z_mu, z_logvar, f = self.E(torch.tensor(x).cuda().float().unsqueeze(0)) 269 | xhat = torch.sigmoid(self.Dec(z)[0]).squeeze().detach().cpu().numpy() 270 | 271 | # Compute grad-cams 272 | gcam = grad_cam(f[self.level_cams], torch.sum(z_mu), normalization='min_max', 273 | avg_grads=self.avg_grads) 274 | 275 | # Restore original shape 276 | mhat = torch.nn.functional.interpolate(gcam.unsqueeze(0), size=(self.input_shape[-1], self.input_shape[-1]), 277 | mode='bilinear', align_corners=True).squeeze().detach().cpu().numpy() 278 | 279 | # brain mask - Keep only brain region 280 | if 'BRATS' in self.train_generator.dataset.dir_datasets or \ 281 | 'PhysioNet' in self.train_generator.dataset.dir_datasets or \ 282 | 'MVTEC' in self.train_generator.dataset.dir_datasets: 283 | mhat[x_mask[0, :, :] == 0] = 0 284 | 285 | # Get outputs 286 | anomaly_map = mhat 287 | 288 | # brain mask - Keep only brain region 289 | if 'BRATS' in self.train_generator.dataset.dir_datasets or \ 290 | 'PhysioNet' in 
self.train_generator.dataset.dir_datasets or \ 291 | 'MVTEC' in self.train_generator.dataset.dir_datasets: 292 | score = np.std(anomaly_map[x_mask[0, :, :] == 1]) 293 | else: 294 | score = np.std(anomaly_map) 295 | 296 | self.E.train() 297 | self.Dec.train() 298 | return score, anomaly_map, xhat 299 | 300 | def display_losses(self, on_epoch_end=False): 301 | 302 | # Init info display 303 | info = "[INFO] Epoch {}/{} -- Step {}/{}: ".format(self.i_epoch + 1, self.epochs, 304 | self.i_iteration + 1, self.iterations) 305 | # Prepare values to show 306 | if on_epoch_end: 307 | lr = self.lr_epoch 308 | lkl = self.kl_epoch 309 | lae = self.lae_epoch 310 | end = '\n' 311 | else: 312 | lr = self.lr_iteration 313 | lkl = self.kl_iteration 314 | lae = self.lae_iteration 315 | end = '\r' 316 | 317 | # Init losses display 318 | info += "Reconstruction={:.4f} || KL={:.4f} || Lae={:.8f} ".format(lr, lkl, lae) 319 | 320 | # Print losses 321 | et = str(datetime.timedelta(seconds=timer() - self.init_time)) 322 | print(info + ', ET=' + et, end=end) 323 | 324 | def plot_learning_curves(self): 325 | def plot_subplot(axes, x, y, y_axis): 326 | axes.grid() 327 | axes.plot(x, y, 'o-') 328 | axes.set_ylabel(y_axis) 329 | 330 | fig, axes = plt.subplots(2, 2, figsize=(20, 15)) 331 | plot_subplot(axes[0, 0], np.arange(self.i_epoch + 1) + 1, np.array(self.lr_lc), "Reconstruc loss") 332 | plot_subplot(axes[0, 1], np.arange(self.i_epoch + 1) + 1, np.array(self.lkl_lc), "KL loss") 333 | plot_subplot(axes[1, 0], np.arange(self.i_epoch + 1) + 1, np.array(self.lae_lc), "AE loss") 334 | plt.savefig(self.dir_results + 'learning_curve.png') 335 | plt.close() 336 | 337 | 338 | def grad_cam(activations, output, normalization='relu_min_max', avg_grads=False, norm_grads=False): 339 | def normalize(grads): 340 | l2_norm = torch.sqrt(torch.mean(torch.pow(grads, 2))) + 1e-5 341 | return grads * torch.pow(l2_norm, -1) 342 | 343 | # Obtain gradients 344 | gradients = torch.autograd.grad(output, activations, 
grad_outputs=None, retain_graph=True, create_graph=True, 345 | only_inputs=True, allow_unused=True)[0] 346 | 347 | # Normalize gradients 348 | if norm_grads: 349 | gradients = normalize(gradients) 350 | 351 | # pool the gradients across the channels 352 | if avg_grads: 353 | gradients = torch.mean(gradients, dim=[2, 3]) 354 | # gradients = torch.nn.functional.softmax(gradients) 355 | gradients = gradients.unsqueeze(-1).unsqueeze(-1) 356 | 357 | # weight activation maps 358 | ''' 359 | if 'relu' in normalization: 360 | GCAM = torch.sum(torch.relu(gradients * activations), 1) 361 | else: 362 | GCAM = gradients * activations 363 | if 'abs' in normalization: 364 | GCAM = torch.abs(GCAM) 365 | GCAM = torch.sum(GCAM, 1) 366 | ''' 367 | GCAM = torch.mean(activations, 1) 368 | 369 | # Normalize CAM 370 | if 'sigm' in normalization: 371 | GCAM = torch.sigmoid(GCAM) 372 | if 'min' in normalization: 373 | norm_value = torch.min(torch.max(GCAM, -1)[0], -1)[0].unsqueeze(-1).unsqueeze(-1) + 1e-3 374 | GCAM = GCAM - norm_value 375 | if 'max' in normalization: 376 | norm_value = torch.max(torch.max(GCAM, -1)[0], -1)[0].unsqueeze(-1).unsqueeze(-1) + 1e-3 377 | GCAM = GCAM * norm_value.pow(-1) 378 | if 'tanh' in normalization: 379 | GCAM = torch.tanh(GCAM) 380 | if 'clamp' in normalization: 381 | GCAM = GCAM.clamp(max=1) 382 | 383 | return GCAM 384 | 385 | -------------------------------------------------------------------------------- /code/methods/trainers/gradCons.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import kornia 3 | import json 4 | import torch 5 | 6 | import pandas as pd 7 | import matplotlib.pyplot as plt 8 | 9 | from scipy import ndimage 10 | from timeit import default_timer as timer 11 | from datasets.utils import augment_input_batch 12 | from models.models import Encoder, Decoder, GradConCAEEncoder, GradConCAEDecoder 13 | from evaluation.utils import * 14 | from sklearn.metrics import accuracy_score, 
f1_score 15 | from methods.losses.losses import kl_loss 16 | 17 | 18 | class AnomalyDetectorGradCons: 19 | def __init__(self, dir_results, item=['flair'], zdim=32, lr=1*1e-4, input_shape=(1, 224, 224), epochs_to_test=25, 20 | load_weigths=False, n_blocks=5, dense=True, loss_reconstruction='bce', alpha_gradloss=1, 21 | n_target_filters=1, alpha_kl=10, variational=True): 22 | 23 | # Init input variables 24 | self.dir_results = dir_results 25 | self.item = item 26 | self.zdim = zdim 27 | self.lr = lr 28 | self.input_shape = input_shape 29 | self.epochs_to_test = epochs_to_test 30 | self.load_weigths = load_weigths 31 | self.n_blocks = n_blocks 32 | self.dense = dense 33 | self.loss_reconstruction = loss_reconstruction 34 | self.alpha_gradloss = alpha_gradloss 35 | self.n_target_filters = n_target_filters 36 | self.alpha_kl = alpha_kl 37 | self.variational = variational 38 | 39 | # Init network 40 | self.E = Encoder(fin=self.input_shape[0], zdim=self.zdim, dense=self.dense, n_blocks=self.n_blocks, 41 | spatial_dim=self.input_shape[1]//2**self.n_blocks, variational=self.variational) 42 | self.Dec = Decoder(fin=self.zdim, nf0=self.E.backbone.nfeats//2, n_channels=self.input_shape[0], 43 | dense=self.dense, n_blocks=self.n_blocks, spatial_dim=self.input_shape[1]//2**self.n_blocks) 44 | ''' 45 | # Init network 46 | self.E = GradConCAEEncoder(fin=self.input_shape[0], zdim=self.zdim, dense=self.dense, n_blocks=self.n_blocks, 47 | spatial_dim=self.input_shape[1]//2**self.n_blocks) 48 | self.Dec = GradConCAEDecoder(fin=self.zdim, n_channels=self.input_shape[0], 49 | dense=self.dense, n_blocks=self.n_blocks, spatial_dim=self.input_shape[1]//2**self.n_blocks) 50 | ''' 51 | 52 | if torch.cuda.is_available(): 53 | self.E.cuda() 54 | self.Dec.cuda() 55 | 56 | if self.load_weigths: 57 | self.E.load_state_dict(torch.load(self.dir_results + '/encoder_weights.pth')) 58 | self.Dec.load_state_dict(torch.load(self.dir_results + '/decoder_weights.pth')) 59 | 60 | # Set parameters 61 | 
self.params = list(self.E.parameters()) + list(self.Dec.parameters()) 62 | 63 | # Set losses 64 | if self.loss_reconstruction == 'l2': 65 | self.Lr = torch.nn.MSELoss(reduction='sum') 66 | elif self.loss_reconstruction == 'bce': 67 | self.Lr = torch.nn.BCEWithLogitsLoss(reduction='sum') 68 | self.Lkl = kl_loss 69 | 70 | # Set optimizers 71 | self.opt = torch.optim.Adam(self.params, lr=self.lr) 72 | 73 | # Init additional variables and objects 74 | self.epochs = 0. 75 | self.i_iteration = 0. 76 | self.iterations = 0. 77 | self.init_time = 0. 78 | self.lr_iteration = 0. 79 | self.lr_epoch = 0. 80 | self.lgrad_iteration = 0. 81 | self.lgrad_epoch = 0. 82 | self.i_epoch = 0. 83 | self.kl_iteration = 0. 84 | self.kl_epoch = 0. 85 | self.train_generator = [] 86 | self.dataset_test = [] 87 | self.metrics = {} 88 | self.aucroc_lc = [] 89 | self.auprc_lc = [] 90 | self.lr_lc = [] 91 | self.lgrad_lc = [] 92 | 93 | def train(self, train_generator, epochs, dataset_test): 94 | self.epochs = epochs 95 | self.init_time = timer() 96 | self.train_generator = train_generator 97 | self.dataset_test = dataset_test 98 | self.iterations = len(self.train_generator) 99 | # Init gradients module 100 | self.ref_grad = self.initialise() 101 | self.k = 0 102 | 103 | # Loop over epochs 104 | for self.i_epoch in range(self.epochs): 105 | # init epoch losses 106 | self.lr_epoch = 0 107 | self.lgrad_epoch = 0. 108 | self.kl_epoch = 0. 
109 | 110 | # Loop over training dataset 111 | for self.i_iteration, (x_n, y_n, _, _) in enumerate(self.train_generator): 112 | 113 | # Move tensors to gpu 114 | x_n = torch.tensor(x_n).cuda().float() 115 | 116 | # Obtain latent space from normal sample via encoder 117 | z, z_mu, z_logvar, _ = self.E(x_n) 118 | 119 | # Obtain reconstructed images through decoder 120 | xhat, _ = self.Dec(z) 121 | 122 | # Calculate criterion 123 | self.lr_iteration = self.Lr(xhat, x_n) / (self.train_generator.batch_size) # Reconstruction loss 124 | self.kl_iteration = self.alpha_kl * self.Lkl(mu=z_mu, logvar=z_logvar) / (self.train_generator.batch_size) # kl loss 125 | 126 | # Calculate gradient loss 127 | # self.kl_iteration.backward(create_graph=True, retain_graph=True) 128 | # self.lr_iteration.backward(create_graph=True, retain_graph=True) 129 | 130 | grad_loss = 0. 131 | i = 0 132 | for module in self.iterlist(): 133 | if isinstance(module, torch.nn.Conv2d): 134 | wrt = module.weight 135 | #target_grad = wrt.grad 136 | target_grad = torch.autograd.grad(self.kl_iteration, wrt, create_graph=True, retain_graph=True)[0] 137 | if self.k > 0: 138 | grad_loss += -1 * torch.nn.functional.cosine_similarity(target_grad.view(-1, 1), self.ref_grad[i].view(-1, 1) / self.k, dim=0).squeeze() 139 | self.ref_grad[i] += target_grad.detach() 140 | i += 1 141 | if self.k == 0: 142 | self.lgrad_iteration = torch.tensor(1.).cuda().float() 143 | else: 144 | self.lgrad_iteration = grad_loss / i # Average over layers 145 | 146 | # Get overall losses - we already computed gradients from Lr, so it is not necessary to do so again 147 | # L = self.lr_iteration + self.lgrad_iteration * self.alpha_gradloss 148 | L = self.lr_iteration + self.lgrad_iteration * self.alpha_gradloss 149 | 150 | L.backward() # Backward 151 | 152 | # Update the reference gradient 153 | i = 0 154 | for module in self.iterlist(): 155 | if isinstance(module, torch.nn.Conv2d): 156 | self.ref_grad[i] += module.weight.grad 157 | i += 1 158 | 159 | 
# Update weights 160 | self.opt.step() # Update weights 161 | self.opt.zero_grad() # Clear gradients 162 | 163 | # Update k counter 164 | self.k += 1 165 | 166 | """ 167 | ON ITERATION/EPOCH END PROCESS 168 | """ 169 | 170 | # Display losses per iteration 171 | self.display_losses(on_epoch_end=False) 172 | 173 | # Update epoch's losses 174 | self.lr_epoch += self.lr_iteration.cpu().detach().numpy() / len(self.train_generator) 175 | self.lgrad_epoch += self.lgrad_iteration.cpu().detach().numpy() / len(self.train_generator) 176 | 177 | # Epoch-end processes 178 | self.on_epoch_end() 179 | 180 | def on_epoch_end(self): 181 | 182 | # Display losses 183 | self.display_losses(on_epoch_end=True) 184 | 185 | # Update learning curves 186 | self.lr_lc.append(self.lr_epoch) 187 | self.lgrad_lc.append(self.lgrad_epoch) 188 | 189 | # Each x epochs, test models and plot learning curves 190 | if (self.i_epoch + 1) % self.epochs_to_test == 0: 191 | # Save weights 192 | torch.save(self.E.state_dict(), self.dir_results + 'encoder_weights.pth') 193 | torch.save(self.Dec.state_dict(), self.dir_results + 'decoder_weights.pth') 194 | 195 | # Make predictions 196 | Y, Scores, M, Mhat, X, Xhat = inference_dataset(self, self.dataset_test) 197 | 198 | # Input to dataset 199 | self.dataset_test.Scores = Scores 200 | self.dataset_test.Mhat = Mhat 201 | self.dataset_test.Xhat = Xhat 202 | 203 | # Evaluate anomaly detection 204 | auroc_det, auprc_det, th_det = evaluate_anomaly_detection(self.dataset_test.Y, self.dataset_test.Scores, 205 | dir_out=self.dir_results) 206 | acc = accuracy_score(np.ravel(Y), np.ravel((Scores > th_det)).astype('int')) 207 | fs = f1_score(np.ravel(Y), np.ravel((Scores > th_det)).astype('int')) 208 | 209 | metrics_detection = {'auroc_det': auroc_det, 'auprc_det': auprc_det, 'th_det': th_det, 'acc_det': acc, 210 | 'fs_det': fs} 211 | print(metrics_detection) 212 | 213 | # Evaluate anomaly localization 214 | metrics, th = evaluate_anomaly_localization(self.dataset_test, 
save_maps=True, dir_out=self.dir_results) 215 | self.metrics = metrics 216 | 217 | # Save metrics as dict 218 | with open(self.dir_results + 'metrics.json', 'w') as fp: 219 | json.dump(metrics, fp) 220 | print(metrics) 221 | 222 | # Plot learning curve 223 | self.plot_learning_curves() 224 | 225 | # Save learning curves as dataframe 226 | self.aucroc_lc.append(metrics['AU_ROC']) 227 | self.auprc_lc.append(metrics['AU_PRC']) 228 | history = pd.DataFrame(list(zip(self.lr_lc, self.lgrad_lc, self.aucroc_lc, self.auprc_lc)), 229 | columns=['Lrec', 'LgradCons', 'AUCROC', 'AUPRC']) 230 | history.to_csv(self.dir_results + 'lc_on_direct.csv') 231 | 232 | else: 233 | self.aucroc_lc.append(0) 234 | self.auprc_lc.append(0) 235 | 236 | def predict_score(self, x): 237 | self.E.eval() 238 | self.Dec.eval() 239 | 240 | # Prepare brain eroded mask 241 | x_mask = 1 - (x == 0).astype(int) 242 | x_mask = ndimage.binary_erosion(x_mask, structure=np.ones((1, 6, 6))).astype(x_mask.dtype) 243 | 244 | # Get reconstruction error map 245 | x_n = torch.tensor(x).cuda().float().unsqueeze(0) 246 | z, z_mu, z_logvar, _ = self.E(x_n) 247 | xhat = self.Dec(z)[0] 248 | 249 | # Calculate criterion 250 | lr_sample = self.Lr(xhat, x_n) / (self.input_shape[1] * self.input_shape[2]) # Reconstruction loss 251 | kl_sample = self.alpha_kl * self.Lkl(mu=z_mu, logvar=z_logvar) # kl loss 252 | 253 | # Calculate gradient loss for anomaly score 254 | #lr_sample.backward(create_graph=True, retain_graph=True) 255 | 256 | score = 0. 
257 | i = 0 258 | for module in self.iterlist(): 259 | if isinstance(module, torch.nn.Conv2d): 260 | wrt = module.weight 261 | #target_grad = wrt.grad 262 | target_grad = torch.autograd.grad(kl_sample, wrt, create_graph=True, retain_graph=True)[0] 263 | if self.k > 0: 264 | score += 1 * torch.nn.functional.cosine_similarity(target_grad.view(-1, 1), 265 | self.ref_grad[i].view(-1, 1) / self.k, 266 | dim=0).squeeze() 267 | i += 1 268 | 269 | score = - score.cpu().detach().numpy() / i # Average over layers 270 | 271 | # Compute anomaly map 272 | xhat = torch.sigmoid(xhat).cpu().detach().numpy() 273 | mhat = np.squeeze(np.abs(x - xhat)) 274 | 275 | # Keep only brain region 276 | mhat[x_mask[0, :, :] == 0] = 0 277 | 278 | # Get outputs 279 | anomaly_map = mhat 280 | 281 | self.opt.zero_grad() # Clear gradients 282 | self.E.train() 283 | self.Dec.train() 284 | return score, anomaly_map, xhat 285 | 286 | def display_losses(self, on_epoch_end=False): 287 | 288 | # Init info display 289 | info = "[INFO] Epoch {}/{} -- Step {}/{}: ".format(self.i_epoch + 1, self.epochs, 290 | self.i_iteration + 1, self.iterations) 291 | # Prepare values to show 292 | if on_epoch_end: 293 | lr = self.lr_epoch 294 | lgradCons = self.lgrad_epoch 295 | end = '\n' 296 | else: 297 | lr = self.lr_iteration 298 | lgradCons = self.lgrad_iteration 299 | end = '\r' 300 | # Init losses display 301 | info += "Reconstruction={:.4f} || gradCons={:.4f}".format(lr, lgradCons) 302 | # Print losses 303 | et = str(datetime.timedelta(seconds=timer() - self.init_time)) 304 | print(info + ',ET=' + et, end=end) 305 | 306 | def plot_learning_curves(self): 307 | def plot_subplot(axes, x, y, y_axis): 308 | axes.grid() 309 | axes.plot(x, y, 'o-') 310 | axes.set_ylabel(y_axis) 311 | 312 | fig, axes = plt.subplots(1, 2, figsize=(20, 15)) 313 | plot_subplot(axes[0], np.arange(self.i_epoch + 1) + 1, np.array(self.lr_lc), "Reconstruc loss") 314 | plot_subplot(axes[1], np.arange(self.i_epoch + 1) + 1, 
np.array(self.lgrad_lc), "gradCons loss") 315 | plt.savefig(self.dir_results + 'learning_curve.png') 316 | plt.close() 317 | 318 | def initialise(self): 319 | 320 | ref_grad = [] 321 | n = 0 322 | for i in np.arange(0, len(list(self.E.children()))): # For each block 323 | for module in list(list(self.E.children())[i].modules()): # For each layer in the block 324 | if isinstance(module, torch.nn.Conv2d): 325 | # Get weight 326 | wrt = module.weight 327 | # Init 0s tensor 328 | grad_init_module = torch.zeros(wrt.shape).cuda().float() 329 | # Incorporate into list 330 | ref_grad.append(grad_init_module) 331 | 332 | n += 1 333 | if n == self.n_target_filters: 334 | break 335 | 336 | return ref_grad 337 | 338 | def iterlist(self): 339 | ''' 340 | layers = [] 341 | n = 0 342 | for i in np.arange(0, len(list(self.E.children()))): # For each block 343 | for module in list(list(self.E.children())[i].modules()): # For each layer in the block 344 | if isinstance(module, torch.nn.Conv2d): 345 | # Get layer 346 | layers.append(module) 347 | 348 | n += 1 349 | if n == self.n_target_filters: 350 | break 351 | if n == self.n_target_filters: 352 | break 353 | 354 | return layers 355 | ''' 356 | 357 | layers = [] 358 | n = 0 359 | for i in np.arange(0, len(list(self.E.children()))): # For each block 360 | 361 | for module in list(list(self.E.children())[i].modules()): # For each layer in the block 362 | if isinstance(module, torch.nn.Conv2d): 363 | # Get layer 364 | layers.append(module) 365 | 366 | n += 1 367 | if n == self.n_target_filters: 368 | break 369 | if n == self.n_target_filters: 370 | break 371 | 372 | return layers -------------------------------------------------------------------------------- /code/methods/trainers/histEqualization.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import kornia 3 | import json 4 | import torch 5 | 6 | import pandas as pd 7 | import matplotlib.pyplot as plt 8 | 9 | from scipy 
import ndimage 10 | from timeit import default_timer as timer 11 | from models.models import Encoder, Decoder 12 | from evaluation.utils import * 13 | from methods.losses.losses import kl_loss 14 | from datasets.utils import augment_input_batch 15 | from skimage.exposure import equalize_hist 16 | 17 | 18 | class AnomalyDetectorHistEqualization: 19 | def __init__(self, dir_results, item=['flair']): 20 | 21 | # Init input variables 22 | self.dir_results = dir_results 23 | self.item = item 24 | self.train_generator = [] 25 | self.dataset_test = [] 26 | 27 | def train(self, train_generator, epochs, dataset_test): 28 | self.train_generator = train_generator 29 | self.dataset_test = dataset_test 30 | print('No training for Histogram matching method.', end='\n') 31 | 32 | # Make predictions 33 | Y, Scores, M, Mhat, X, Xhat = inference_dataset(self, self.dataset_test) 34 | 35 | # Input to dataset 36 | self.dataset_test.Scores = Scores 37 | self.dataset_test.Mhat = Mhat 38 | self.dataset_test.Xhat = Xhat 39 | 40 | # Evaluate 41 | metrics, th = evaluate_anomaly_localization(self.dataset_test, save_maps=True, dir_out=self.dir_results) 42 | self.metrics = metrics 43 | 44 | # Save metrics as dict 45 | with open(self.dir_results + 'metrics.json', 'w') as fp: 46 | json.dump(metrics, fp) 47 | print(metrics) 48 | 49 | def predict_score(self, x): 50 | def equalize_img(img): 51 | """ 52 | Perform histogram equalization on the given image. 
53 | """ 54 | # Create equalization mask 55 | mask = np.zeros_like(img) 56 | mask[img > 0] = 1 57 | 58 | # Equalize 59 | img = img*255 60 | img = equalize_hist(img.astype(np.int64), nbins=256, mask=mask) 61 | 62 | # Ensure the background stays at 0 63 | img *= mask 64 | img *= (1/255) 65 | 66 | return img 67 | 68 | # Prepare eroded brain mask 69 | if 'BRATS' in self.train_generator.dataset.dir_datasets or 'PhysioNet' in self.train_generator.dataset.dir_datasets: 70 | x_mask = 1 - (x == 0).astype(int) # builtin int: the np.int alias was removed in NumPy 1.24 71 | if 'BRATS' in self.train_generator.dataset.dir_datasets: 72 | x_mask = ndimage.binary_erosion(x_mask, structure=np.ones((1, 10, 10))).astype(x_mask.dtype) 73 | else: 74 | x_mask = ndimage.binary_erosion(x_mask, structure=np.ones((1, 3, 3))).astype(x_mask.dtype) 75 | elif 'MVTEC' in self.train_generator.dataset.dir_datasets: 76 | x_mask = np.zeros((1, x.shape[-1], x.shape[-1])) 77 | x_mask[:, 14:-14, 14:-14] = 1 78 | 79 | # Use the equalized image as the anomaly map 80 | mhat = equalize_img(x) 81 | mhat = mhat[0, :, :] 82 | 83 | # Keep only brain region 84 | mhat[x_mask[0, :, :] == 0] = 0 85 | 86 | # Get outputs 87 | anomaly_map = mhat 88 | score = np.mean(anomaly_map) 89 | 90 | return score, anomaly_map, x 91 | 92 | 93 | -------------------------------------------------------------------------------- /code/methods/trainers/vae.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import kornia 3 | import json 4 | import torch 5 | 6 | import pandas as pd 7 | import matplotlib.pyplot as plt 8 | 9 | from scipy import ndimage 10 | from timeit import default_timer as timer 11 | from models.models import Encoder, Decoder 12 | from evaluation.utils import * 13 | from methods.losses.losses import kl_loss 14 | from datasets.utils import augment_input_batch 15 | 16 | 17 | class AnomalyDetectorVAE: 18 | def __init__(self, dir_results, item=['flair'], zdim=32, lr=1*1e-4, input_shape=(1, 224, 224), epochs_to_test=25, 19 |
load_weigths=False, n_blocks=5, dense=True, context=False, bayesian=False, 20 | loss_reconstruction='bce', restoration=False, alpha_kl=0.1): 21 | 22 | # Init input variables 23 | self.dir_results = dir_results 24 | self.item = item 25 | self.zdim = zdim 26 | self.lr = lr 27 | self.input_shape = input_shape 28 | self.epochs_to_test = epochs_to_test 29 | self.load_weigths = load_weigths 30 | self.n_blocks = n_blocks 31 | self.dense = dense 32 | self.context = context 33 | self.bayesian = bayesian 34 | self.loss_reconstruction = loss_reconstruction 35 | self.restoration = restoration 36 | self.alpha_kl = alpha_kl 37 | 38 | # Init network 39 | self.E = Encoder(fin=self.input_shape[0], zdim=self.zdim, dense=self.dense, n_blocks=self.n_blocks, 40 | spatial_dim=self.input_shape[1]//2**self.n_blocks, variational=True, gap=False) 41 | self.Dec = Decoder(fin=self.zdim, nf0=self.E.backbone.nfeats//2, n_channels=self.input_shape[0], 42 | dense=self.dense, n_blocks=self.n_blocks, spatial_dim=self.input_shape[1]//2**self.n_blocks, 43 | gap=False) 44 | 45 | if torch.cuda.is_available(): 46 | self.E.cuda() 47 | self.Dec.cuda() 48 | 49 | if self.load_weigths: 50 | self.E.load_state_dict(torch.load(self.dir_results + '/encoder_weights.pth')) 51 | self.Dec.load_state_dict(torch.load(self.dir_results + '/decoder_weights.pth')) 52 | 53 | # Set parameters 54 | self.params = list(self.E.parameters()) + list(self.Dec.parameters()) 55 | 56 | # Set losses 57 | if self.loss_reconstruction == 'l2': 58 | self.Lr = torch.nn.MSELoss(reduction='sum') 59 | elif self.loss_reconstruction == 'bce': 60 | self.Lr = torch.nn.BCEWithLogitsLoss(reduction='sum') 61 | 62 | self.Lkl = kl_loss 63 | 64 | # Set optimizers 65 | self.opt = torch.optim.Adam(self.params, lr=self.lr) 66 | 67 | # Init additional variables and objects 68 | self.epochs = 0. 69 | self.iterations = 0. 70 | self.init_time = 0. 71 | self.lr_iteration = 0. 72 | self.lr_epoch = 0. 73 | self.kl_iteration = 0. 74 | self.kl_epoch = 0. 
75 | self.i_epoch = 0. 76 | self.train_generator = [] 77 | self.dataset_test = [] 78 | self.metrics = {} 79 | self.aucroc_lc = [] 80 | self.auprc_lc = [] 81 | self.lr_lc = [] 82 | self.lkl_lc = [] 83 | 84 | def train(self, train_generator, epochs, dataset_test): 85 | self.epochs = epochs 86 | self.init_time = timer() 87 | self.train_generator = train_generator 88 | self.dataset_test = dataset_test 89 | self.iterations = len(self.train_generator) 90 | 91 | # Loop over epochs 92 | for self.i_epoch in range(self.epochs): 93 | # init epoch losses 94 | self.lr_epoch = 0 95 | self.kl_epoch = 0. 96 | 97 | # Loop over training dataset 98 | for self.i_iteration, (x_n, y_n, _, _) in enumerate(self.train_generator): 99 | 100 | if self.context: # if context option, data augmentation to apply context 101 | (x_n_context, _) = augment_input_batch(x_n.copy()) 102 | x_n_context = torch.tensor(x_n_context).cuda().float() 103 | 104 | # Move tensors to gpu 105 | x_n = torch.tensor(x_n).cuda().float() 106 | 107 | # Obtain latent space from normal sample via encoder 108 | if not self.context: 109 | z, z_mu, z_logvar, _ = self.E(x_n) 110 | else: 111 | z, z_mu, z_logvar, _ = self.E(x_n_context) 112 | 113 | # Obtain reconstructed images through decoder 114 | xhat, _ = self.Dec(z) 115 | if self.loss_reconstruction == 'l2': 116 | xhat = torch.sigmoid(xhat) 117 | 118 | # Calculate criterion 119 | self.lr_iteration = self.Lr(xhat, x_n) / self.train_generator.batch_size # Reconstruction loss 120 | self.kl_iteration = self.Lkl(mu=z_mu, logvar=z_logvar) # kl loss (averaged per spatial feature) 121 | 122 | # Init overall losses 123 | L = self.lr_iteration + self.alpha_kl * self.kl_iteration 124 | 125 | # Update weights 126 | L.backward() # Backward 127 | self.opt.step() # Update weights 128 | self.opt.zero_grad() # Clear gradients 129 | 130 | """ 131 | ON ITERATION/EPOCH END PROCESS 132 | """ 133 | 134 | # Display losses per iteration 135 | self.display_losses(on_epoch_end=False) 136 | 137 | # 
Update epoch's losses 138 | self.lr_epoch += self.lr_iteration.cpu().detach().numpy() / len(self.train_generator) 139 | self.kl_epoch += self.kl_iteration.cpu().detach().numpy() / len(self.train_generator) 140 | 141 | # Epoch-end processes 142 | self.on_epoch_end() 143 | 144 | def on_epoch_end(self): 145 | 146 | # Display losses 147 | self.display_losses(on_epoch_end=True) 148 | 149 | # Update learning curves 150 | self.lr_lc.append(self.lr_epoch) 151 | self.lkl_lc.append(self.kl_epoch) 152 | 153 | # Every `epochs_to_test` epochs, test the models and plot learning curves 154 | if (self.i_epoch + 1) % self.epochs_to_test == 0: 155 | # Save weights 156 | torch.save(self.E.state_dict(), self.dir_results + 'encoder_weights.pth') 157 | torch.save(self.Dec.state_dict(), self.dir_results + 'decoder_weights.pth') 158 | 159 | # Make predictions 160 | Y, Scores, M, Mhat, X, Xhat = inference_dataset(self, self.dataset_test) 161 | 162 | # Attach predictions to the dataset 163 | self.dataset_test.Scores = Scores 164 | self.dataset_test.Mhat = Mhat 165 | self.dataset_test.Xhat = Xhat 166 | 167 | # Evaluate 168 | metrics, th = evaluate_anomaly_localization(self.dataset_test, save_maps=True, dir_out=self.dir_results) 169 | self.metrics = metrics 170 | 171 | # Save metrics as dict 172 | with open(self.dir_results + 'metrics.json', 'w') as fp: 173 | json.dump(metrics, fp) 174 | print(metrics) 175 | 176 | # Plot learning curve 177 | self.plot_learning_curves() 178 | 179 | # Save learning curves as dataframe (the 'Lkl' column takes the KL curve, not the reconstruction curve) 180 | self.aucroc_lc.append(metrics['AU_ROC']) 181 | self.auprc_lc.append(metrics['AU_PRC']) 182 | history = pd.DataFrame(list(zip(self.lr_lc, self.lkl_lc, self.aucroc_lc, self.auprc_lc)), 183 | columns=['Lrec', 'Lkl', 'AUCROC', 'AUPRC']) 184 | history.to_csv(self.dir_results + 'lc_on_direct.csv') 185 | 186 | else: 187 | self.aucroc_lc.append(0) 188 | self.auprc_lc.append(0) 189 | 190 | def predict_score(self, x): 191 | self.E.eval() 192 | self.Dec.eval() 193 | 194 | # Prepare eroded brain mask 195 | if 'BRATS' in
self.train_generator.dataset.dir_datasets or 'PhysioNet' in self.train_generator.dataset.dir_datasets: 196 | x_mask = 1 - (x == 0).astype(int) # builtin int: the np.int alias was removed in NumPy 1.24 197 | if 'BRATS' in self.train_generator.dataset.dir_datasets: 198 | x_mask = ndimage.binary_erosion(x_mask, structure=np.ones((1, 10, 10))).astype(x_mask.dtype) 199 | else: 200 | x_mask = ndimage.binary_erosion(x_mask, structure=np.ones((1, 3, 3))).astype(x_mask.dtype) 201 | elif 'MVTEC' in self.train_generator.dataset.dir_datasets: 202 | x_mask = np.zeros((1, x.shape[-1], x.shape[-1])) 203 | x_mask[:, 14:-14, 14:-14] = 1 204 | 205 | # Get reconstruction error map 206 | if self.restoration: # restoration reconstruction 207 | mhat, xhat = self.restoration_reconstruction(x) 208 | elif self.bayesian: # bayesian reconstruction 209 | mhat, xhat = self.bayesian_reconstruction(x) 210 | else: 211 | # Network forward 212 | z, z_mu, z_logvar, f = self.E(torch.tensor(x).cuda().float().unsqueeze(0)) 213 | xhat = np.squeeze(torch.sigmoid(self.Dec(z)[0]).cpu().detach().numpy()) 214 | # Compute anomaly map 215 | mhat = np.squeeze(np.abs(x - xhat)) 216 | # mhat = np.squeeze((x - xhat)) 217 | 218 | # Keep only brain region 219 | mhat[x_mask[0, :, :] == 0] = 0 220 | 221 | # Get outputs 222 | anomaly_map = mhat 223 | score = np.mean(anomaly_map) 224 | 225 | self.E.train() 226 | self.Dec.train() 227 | return score, anomaly_map, xhat 228 | 229 | def bayesian_reconstruction(self, x): 230 | 231 | N = 100 232 | p_dropout = 0.20 233 | mhat = np.zeros((self.input_shape[1], self.input_shape[2])) 234 | 235 | # Network forward 236 | z, z_mu, z_logvar, f = self.E(torch.tensor(x).cuda().float().unsqueeze(0)) 237 | xhat = self.Dec(torch.nn.Dropout(p_dropout)(z))[0].cpu().detach().numpy() 238 | 239 | for i in np.arange(N): 240 | if z_mu is None: # apply dropout to z 241 | mhat += np.squeeze(np.abs(np.squeeze(torch.sigmoid( 242 | self.Dec(torch.nn.Dropout(p_dropout)(z))[0]).cpu().detach().numpy()) - x)) / N 243 | else: # sample z 244 | mhat +=
np.squeeze(np.abs(np.squeeze(torch.sigmoid( 245 | self.Dec(self.E.reparameterize(z_mu, z_logvar))[0]).cpu().detach().numpy()) - x)) / N 246 | return mhat, xhat 247 | 248 | def restoration_reconstruction(self, x): 249 | N = 300 250 | step = 1 * 1e-3 251 | x_rest = torch.tensor(x).cuda().float().unsqueeze(0) 252 | 253 | for i in np.arange(N): 254 | # Forward 255 | x_rest.requires_grad = True 256 | z, z_mu, z_logvar, f = self.E(x_rest) 257 | xhat = self.Dec(z)[0] 258 | 259 | # Compute loss 260 | lr = kornia.losses.total_variation(torch.tensor(x).cuda().float().unsqueeze(0) - torch.sigmoid(xhat)) 261 | L = lr / (self.input_shape[1] * self.input_shape[2]) 262 | 263 | # Get gradients 264 | gradients = torch.autograd.grad(L, x_rest, grad_outputs=None, retain_graph=True, 265 | create_graph=True, 266 | only_inputs=True, allow_unused=True)[0] 267 | 268 | x_rest = x_rest - gradients * step 269 | x_rest = x_rest.clone().detach() 270 | xhat = np.squeeze(x_rest.cpu().numpy()) 271 | 272 | # Compute difference 273 | mhat = np.squeeze(np.abs(x - xhat)) 274 | 275 | return mhat, xhat 276 | 277 | def display_losses(self, on_epoch_end=False): 278 | 279 | # Init info display 280 | info = "[INFO] Epoch {}/{} -- Step {}/{}: ".format(self.i_epoch + 1, self.epochs, 281 | self.i_iteration + 1, self.iterations) 282 | # Prepare values to show 283 | if on_epoch_end: 284 | lr = self.lr_epoch 285 | lkl = self.kl_epoch 286 | end = '\n' 287 | else: 288 | lr = self.lr_iteration 289 | lkl = self.kl_iteration 290 | end = '\r' 291 | # Init losses display 292 | info += "Reconstruction={:.4f} || KL={:.4f}".format(lr, lkl) 293 | # Print losses 294 | et = str(datetime.timedelta(seconds=timer() - self.init_time)) 295 | print(info + ', ET=' + et, end=end) 296 | 297 | def plot_learning_curves(self): 298 | def plot_subplot(axes, x, y, y_axis): 299 | axes.grid() 300 | axes.plot(x, y, 'o-') 301 | axes.set_ylabel(y_axis) 302 | 303 | fig, axes = plt.subplots(1, 2, figsize=(20, 15)) 304 | plot_subplot(axes[0], 
np.arange(self.i_epoch + 1) + 1, np.array(self.lr_lc), "Reconstruc loss") 305 | plot_subplot(axes[1], np.arange(self.i_epoch + 1) + 1, np.array(self.lkl_lc), "KL loss") 306 | plt.savefig(self.dir_results + 'learning_curve.png') 307 | plt.close() -------------------------------------------------------------------------------- /code/models/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import numpy as np 4 | 5 | # ------------------ 6 | # RESIDUAL MODELS 7 | 8 | 9 | class Resnet(torch.nn.Module): 10 | def __init__(self, in_channels, n_blocks=4): 11 | super(Resnet, self).__init__() 12 | self.n_blocks = n_blocks 13 | self.nfeats = 512 // (2**(4-n_blocks)) 14 | 15 | self.input = torch.nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=(7, 7), stride=(2, 2), 16 | padding=(3, 3), bias=False) 17 | resnet18_model = torchvision.models.resnet18(pretrained=False) 18 | self.resnet = torch.nn.Sequential(*(list(resnet18_model.children())[i+4] for i in range(0, self.n_blocks))) 19 | 20 | # placeholder for the gradients 21 | self.gradients = None 22 | 23 | def forward(self, x): 24 | x = self.input(x) 25 | F = [] 26 | for iBlock in range(0, self.n_blocks): 27 | x = list(self.resnet.children())[iBlock](x) 28 | F.append(x) 29 | 30 | return x, F 31 | 32 | 33 | class Encoder(torch.nn.Module): 34 | def __init__(self, fin=1, zdim=128, dense=False, variational=False, n_blocks=4, spatial_dim=7, 35 | gap=False): 36 | super(Encoder, self).__init__() 37 | self.fin = fin 38 | self.zdim = zdim 39 | self.dense = dense 40 | self.n_blocks = n_blocks 41 | self.gap = gap 42 | self.variational = variational 43 | 44 | # 1) Feature extraction 45 | self.backbone = Resnet(in_channels=self.fin, n_blocks=self.n_blocks) 46 | # 2) Latent space (dense or spatial) 47 | if self.dense: # dense 48 | if gap: 49 | if self.variational: 50 | self.mu = torch.nn.Conv2d(self.backbone.nfeats, zdim, (1, 1)) 51 | 
self.log_var = torch.nn.Conv2d(self.backbone.nfeats, zdim, (1, 1)) 52 | else: 53 | self.z = torch.nn.Conv2d(self.backbone.nfeats, zdim, (1, 1)) 54 | else: 55 | if self.variational: 56 | self.mu = torch.nn.Linear(self.backbone.nfeats*spatial_dim**2, zdim) 57 | self.log_var = torch.nn.Linear(self.backbone.nfeats*spatial_dim**2, zdim) 58 | else: 59 | self.z = torch.nn.Linear(self.backbone.nfeats * spatial_dim ** 2, zdim) 60 | else: # spatial 61 | if self.variational: 62 | self.mu = torch.nn.Conv2d(self.backbone.nfeats, zdim, (1, 1)) 63 | self.log_var = torch.nn.Conv2d(self.backbone.nfeats, zdim, (1, 1)) 64 | else: 65 | self.z = torch.nn.Conv2d(self.backbone.nfeats, zdim, (1, 1)) 66 | 67 | def reparameterize(self, mu, log_var): 68 | std = torch.exp(0.5 * log_var) # standard deviation 69 | eps = torch.randn_like(std) 70 | 71 | sample = mu + (eps * std) # sampling 72 | return sample 73 | 74 | def forward(self, x): 75 | 76 | # 1) Feature extraction 77 | x, allF = self.backbone(x) 78 | 79 | if self.dense and not self.gap: 80 | x = torch.nn.Flatten()(x) 81 | 82 | if self.dense and self.gap: 83 | x = torch.nn.functional.adaptive_avg_pool2d(x, 1) 84 | x = x.view(x.size(0), self.backbone.nfeats, 1, 1) 85 | 86 | # 2) Latent space 87 | if self.variational: 88 | # get `mu` and `log_var` 89 | z_mu = self.mu(x) 90 | z_logvar = self.log_var(x) 91 | # get the latent vector through reparameterization 92 | z = self.reparameterize(z_mu, z_logvar) 93 | else: 94 | z = self.z(x) 95 | z_mu, z_logvar = None, None 96 | 97 | return z, z_mu, z_logvar, allF 98 | 99 | 100 | class Decoder(torch.nn.Module): 101 | 102 | def __init__(self, fin=256, nf0=128, n_channels=1, dense=False, n_blocks=4, spatial_dim=7, 103 | gap=False): 104 | super(Decoder, self).__init__() 105 | self.n_blocks = n_blocks 106 | self.dense = dense 107 | self.spatial_dim = spatial_dim 108 | self.fin = fin 109 | self.gap = gap 110 | self.nf0 = nf0 111 | 112 | if self.dense and not self.gap: 113 | self.dense_layer = 
torch.nn.Sequential(torch.nn.Linear(fin, nf0*spatial_dim**2)) 114 | if not dense: 115 | self.dense_layer = torch.nn.Sequential(torch.nn.Conv2d(fin, nf0, (1, 1))) 116 | 117 | # Set number of input and output channels 118 | n_filters_in = [nf0//2**(i) for i in range(0, self.n_blocks + 1)] 119 | n_filters_out = [nf0//2**(i+1) for i in range(0, self.n_blocks)] + [n_channels] 120 | 121 | self.blocks = torch.nn.ModuleList() 122 | for i in np.arange(0, self.n_blocks): 123 | self.blocks.append(torch.nn.Sequential(BasicBlock(n_filters_in[i], n_filters_out[i], downsample=True), 124 | BasicBlock(n_filters_out[i], n_filters_out[i]))) 125 | self.out = torch.nn.Conv2d(n_filters_in[-1], n_filters_out[-1], kernel_size=(3, 3), padding=(1, 1)) 126 | 127 | self.n_filters_in = n_filters_in 128 | self.n_filters_out = n_filters_out 129 | 130 | def forward(self, x): 131 | 132 | if self.dense and not self.gap: 133 | x = self.dense_layer(x) 134 | x = torch.nn.Unflatten(-1, (self.nf0, self.spatial_dim, self.spatial_dim))(x) 135 | 136 | if self.dense and self.gap: 137 | x = torch.nn.functional.interpolate(x, scale_factor=self.spatial_dim) 138 | 139 | if not self.dense: 140 | x = self.dense_layer(x) 141 | 142 | for i in np.arange(0, self.n_blocks): 143 | x = self.blocks[i](x) 144 | f = x 145 | out = self.out(f) 146 | 147 | return out, f 148 | 149 | 150 | class ResBlock(torch.nn.Module): 151 | 152 | def __init__(self, fin, fout): 153 | super(ResBlock, self).__init__() 154 | self.conv_straight_1 = torch.nn.Conv2d(fin, fout, kernel_size=(3, 3), padding=(1, 1)) 155 | self.bn_1 = torch.nn.BatchNorm2d(fout) 156 | self.conv_straight_2 = torch.nn.Conv2d(fout, fout, kernel_size=(3, 3), padding=(1, 1)) 157 | self.bn_2 = torch.nn.BatchNorm2d(fout) 158 | self.conv_skip = torch.nn.Conv2d(fin, fout, kernel_size=(3, 3), padding=(1, 1)) 159 | self.upsampling = torch.nn.Upsample(scale_factor=(2, 2)) 160 | self.relu = torch.nn.ReLU() 161 | 162 | def forward(self, x): 163 | 164 | x_st = self.upsampling(x) 165 | 
x_st = self.conv_straight_1(x_st) 166 | x_st = self.relu(x_st) 167 | x_st = self.bn_1(x_st) 168 | x_st = self.conv_straight_2(x_st) 169 | x_st = self.relu(x_st) 170 | x_st = self.bn_2(x_st) 171 | 172 | x_sk = self.upsampling(x) 173 | x_sk = self.conv_skip(x_sk) 174 | 175 | out = x_sk + x_st 176 | 177 | return out 178 | 179 | 180 | class BasicBlock(torch.nn.Module): 181 | 182 | def __init__(self, inplanes=32, planes=64, stride=1, downsample=False, bn=True): 183 | super().__init__() 184 | norm_layer = torch.nn.BatchNorm2d 185 | self.conv1 = torch.nn.Conv2d(inplanes, planes, kernel_size=(3, 3), padding=(1, 1)) 186 | self.bn1 = norm_layer(planes) 187 | self.relu = torch.nn.ReLU(inplace=True) 188 | self.conv2 = torch.nn.Conv2d(planes, planes, kernel_size=(3, 3), padding=(1, 1)) 189 | self.bn2 = norm_layer(planes) 190 | self.downsample = downsample 191 | if self.downsample: 192 | self.downsample_layer_conv = torch.nn.Sequential(torch.nn.Conv2d(inplanes, planes, kernel_size=(1, 1))) 193 | self.downsample_layer = torch.nn.Upsample(scale_factor=(2, 2)) 194 | self.downsample_layer_bn = norm_layer(planes) 195 | self.stride = stride 196 | self.bn = bn 197 | 198 | def forward(self, x): 199 | identity = x 200 | 201 | out = self.conv1(x) 202 | if self.downsample: 203 | out = self.downsample_layer(out) 204 | if self.bn: 205 | out = self.bn1(out) 206 | out = self.relu(out) 207 | 208 | out = self.conv2(out) 209 | if self.bn: 210 | out = self.bn2(out) 211 | 212 | if self.downsample: 213 | identity = self.downsample_layer_conv(identity) 214 | identity = self.downsample_layer(identity) 215 | if self.bn: 216 | identity = self.downsample_layer_bn(identity) 217 | 218 | out += identity 219 | out = self.relu(out) 220 | 221 | return out 222 | 223 | 224 | class ConvBlock(torch.nn.Module): 225 | 226 | def __init__(self, fin, fout): 227 | super(ConvBlock, self).__init__() 228 | 229 | self.conv = torch.nn.Conv2d(fin, fout, kernel_size=4, stride=2, padding=1, bias=False) 230 | self.bn = 
torch.nn.BatchNorm2d(fout) 231 | self.act = torch.nn.LeakyReLU(0.2, inplace=False) 232 | 233 | def forward(self, x): 234 | 235 | out = self.act(self.bn(self.conv(x))) 236 | 237 | return out 238 | 239 | 240 | class DeconvBlock(torch.nn.Module): 241 | 242 | def __init__(self, fin, fout): 243 | super(DeconvBlock, self).__init__() 244 | 245 | self.conv = torch.nn.Conv2d(fin, fout, kernel_size=4, stride=2, padding=1, bias=False) 246 | self.bn = torch.nn.BatchNorm2d(fout) 247 | self.act = torch.nn.LeakyReLU(0.2, inplace=False) 248 | 249 | def forward(self, x): 250 | 251 | out = self.act(self.bn(self.conv(x))) 252 | 253 | return out 254 | 255 | 256 | class Discriminator(torch.nn.Module): 257 | def __init__(self, fin=32, n_channels=1, n_blocks=4, type='DCGAN'): 258 | super(Discriminator, self).__init__() 259 | 260 | # Number of feature extractor blocks 261 | self.n_blocks = n_blocks 262 | # Set number of input and output channels 263 | n_filters_in = [n_channels] + [fin*(2**i) for i in range(0, self.n_blocks)] 264 | n_filters_out = [fin*(2**i) for i in range(0, self.n_blocks+1)] 265 | # Number of output features 266 | self.nFeats = n_filters_out[-1] 267 | # Prepare blocks: 268 | self.blocks = torch.nn.ModuleList() 269 | for i in np.arange(0, self.n_blocks + 1): 270 | if type == 'DCGAN': 271 | self.blocks.append(ConvBlock(n_filters_in[i], n_filters_out[i])) 272 | # Output for binary clasification 273 | self.pool_feats = torch.nn.AdaptiveAvgPool2d((7, 7)) 274 | self.out = torch.nn.Sequential(torch.nn.AdaptiveAvgPool2d(1), 275 | torch.nn.Flatten(), 276 | torch.nn.Linear(self.nFeats, 1), 277 | torch.nn.Sigmoid()) 278 | 279 | def forward(self, x): 280 | 281 | for i in np.arange(0, self.n_blocks+1): 282 | x = self.blocks[i](x) 283 | f = self.pool_feats(x) 284 | 285 | out = self.out(f) 286 | 287 | return out, f 288 | 289 | 290 | # ------------------ 291 | # SEQUENTIAL MODELS 292 | 293 | class GradConCAEEncoder(torch.nn.Module): 294 | def __init__(self, fin=1, zdim=128, 
dense=False, variational=False, n_blocks=4, spatial_dim=7, 295 | gap=False): 296 | super(GradConCAEEncoder, self).__init__() 297 | self.fin = fin 298 | self.zdim = zdim 299 | self.dense = dense 300 | self.n_blocks = n_blocks 301 | self.gap = gap 302 | self.variational = variational 303 | 304 | # 1) Feature extraction 305 | self.backbone = torch.nn.Sequential(GradConCAEDownBlock(1, 32), 306 | GradConCAEDownBlock(32, 32), 307 | GradConCAEDownBlock(32, 64), 308 | GradConCAEDownBlock(64, 64)) 309 | # 2) Latent space (dense or spatial) 310 | if self.dense: # dense 311 | if gap: 312 | if self.variational: 313 | self.mu = torch.nn.Conv2d(64, zdim, (1, 1)) 314 | self.log_var = torch.nn.Conv2d(64, zdim, (1, 1)) 315 | else: 316 | self.z = torch.nn.Conv2d(64, zdim, (1, 1)) 317 | else: 318 | if self.variational: 319 | self.mu = torch.nn.Linear(64*spatial_dim**2, zdim) 320 | self.log_var = torch.nn.Linear(64*spatial_dim**2, zdim) 321 | else: 322 | self.z = torch.nn.Linear(64 * spatial_dim ** 2, zdim) 323 | else: # spatial 324 | if self.variational: 325 | self.mu = torch.nn.Conv2d(64, zdim, (1, 1)) 326 | self.log_var = torch.nn.Conv2d(64, zdim, (1, 1)) 327 | else: 328 | self.z = torch.nn.Conv2d(64, zdim, (1, 1)) 329 | 330 | def reparameterize(self, mu, log_var): 331 | std = torch.exp(0.5 * log_var) # standard deviation 332 | eps = torch.randn_like(std) 333 | 334 | sample = mu + (eps * std) # sampling 335 | return sample 336 | 337 | def forward(self, x): 338 | 339 | # 1) Feature extraction 340 | x = self.backbone(x) 341 | 342 | if self.dense and not self.gap: 343 | x = torch.nn.Flatten()(x) 344 | 345 | if self.dense and self.gap: 346 | x = torch.nn.functional.adaptive_avg_pool2d(x, 1) 347 | x = x.view(x.size(0), 64, 1, 1) # 64 = channels of the last down block (a Sequential backbone has no `nfeats`) 348 | 349 | # 2) Latent space 350 | if self.variational: 351 | # get `mu` and `log_var` 352 | z_mu = self.mu(x) 353 | z_logvar = self.log_var(x) 354 | # get the latent vector through reparameterization 355 | z = self.reparameterize(z_mu,
z_logvar) 356 | else: 357 | z = self.z(x) 358 | z_mu, z_logvar = None, None 359 | 360 | return z, z_mu, z_logvar, None 361 | 362 | 363 | class GradConCAEDecoder(torch.nn.Module): 364 | 365 | def __init__(self, fin=256, nf0=128, n_channels=1, dense=False, n_blocks=4, spatial_dim=7, 366 | gap=False): 367 | super(GradConCAEDecoder, self).__init__() 368 | self.n_blocks = n_blocks 369 | self.dense = dense 370 | self.spatial_dim = spatial_dim 371 | self.fin = fin 372 | self.gap = gap 373 | 374 | if self.dense and not self.gap: 375 | self.dense_layer = torch.nn.Sequential(torch.nn.Linear(fin, 64*spatial_dim**2)) 376 | if not dense: 377 | self.dense_layer = torch.nn.Sequential(torch.nn.Conv2d(fin, 64, (1, 1))) 378 | 379 | # Set number of input and output channels 380 | n_filters_in = [64, 64, 32, 32] 381 | n_filters_out = [64, 32, 32] + [n_channels] 382 | 383 | self.blocks = torch.nn.ModuleList() 384 | for i in np.arange(0, 4): 385 | if i == 0: 386 | self.blocks.append(torch.nn.Sequential(GradConCAEUpBlock(n_filters_in[i], n_filters_out[i], padding=2))) 387 | else: 388 | self.blocks.append(torch.nn.Sequential(GradConCAEUpBlock(n_filters_in[i], n_filters_out[i], padding=1))) 389 | 390 | self.n_filters_in = n_filters_in 391 | self.n_filters_out = n_filters_out 392 | 393 | def forward(self, x): 394 | 395 | if self.dense and not self.gap: 396 | x = self.dense_layer(x) 397 | x = torch.nn.Unflatten(-1, (64, self.spatial_dim, self.spatial_dim))(x) # dense_layer outputs 64*spatial_dim**2 features 398 | 399 | if self.dense and self.gap: 400 | x = torch.nn.functional.interpolate(x, scale_factor=self.spatial_dim) 401 | 402 | if not self.dense: 403 | x = self.dense_layer(x) 404 | 405 | for i in np.arange(0, 4): 406 | x = self.blocks[i](x) 407 | f = x 408 | out = x 409 | 410 | return out, f 411 | 412 | 413 | class GradConCAEDownBlock(torch.nn.Module): 414 | def __init__(self, in_channel, out_channel, stride=2, padding=2): 415 | super(GradConCAEDownBlock, self).__init__() 416 | 417 | self.conv1 = torch.nn.Conv2d(in_channel,
out_channel, (4, 4), stride=stride, padding=padding) 418 | self.act1 = torch.nn.ReLU() 419 | 420 | def forward(self, x): 421 | 422 | x = self.conv1(x) 423 | x = self.act1(x) 424 | 425 | return x 426 | 427 | 428 | class GradConCAEUpBlock(torch.nn.Module): 429 | def __init__(self, in_channel, out_channel, stride=2, padding=2): 430 | super(GradConCAEUpBlock, self).__init__() 431 | 432 | self.conv1 = torch.nn.ConvTranspose2d(in_channel, out_channel, (4, 4), stride=stride, padding=padding) 433 | self.act1 = torch.nn.ReLU() 434 | 435 | def forward(self, x): 436 | x = self.conv1(x) 437 | x = self.act1(x) 438 | 439 | return x -------------------------------------------------------------------------------- /data/MICCAI_BraTS_2019_Data_Training/README.md: -------------------------------------------------------------------------------- 1 | ### Experiments on the BRATS dataset 2 | 3 | To reproduce the experiments carried out on the BRATS dataset, download the data from the following link: [BRATS](https://drive.google.com/file/d/1NgHMcIcfVGcoAYWd0ABI6AEZCkpFpvJ8/view?usp=sharing). Then prepare the MRI volumes and produce the data splits with the following command: 4 | 5 | ``` 6 | python adecuate_BRATS.py --dir_datasets ../data/MICCAI_BraTS_2019_Data_Training/ --dir_out ../data/BRATS_5slices/ --scan flair --nSlices 5 7 | ``` 8 | 9 | For more information on the data description and usage requirements, please refer to the original authors at the following [LINK](https://www.med.upenn.edu/cbica/brats2020/data.html).
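
As a rough illustration of what an `--nSlices 5` style selection could mean, the sketch below picks a block of consecutive axial slices centred in a volume. This is an assumption made for illustration only: the helper name `central_slice_indices` is hypothetical, and the actual selection logic lives in `adecuate_BRATS.py`, which may well differ.

```python
def central_slice_indices(n_total, n_slices):
    """Indices of `n_slices` consecutive slices centred in a volume of `n_total` slices.

    Hypothetical helper for illustration; not taken from adecuate_BRATS.py.
    """
    if not 0 < n_slices <= n_total:
        raise ValueError("n_slices must be between 1 and n_total")
    start = (n_total - n_slices) // 2  # first index of the centred block
    return list(range(start, start + n_slices))


if __name__ == "__main__":
    # 155 is the axial depth of a standard BraTS volume (240 x 240 x 155).
    print(central_slice_indices(155, 5))  # -> [75, 76, 77, 78, 79]
```

Keeping only a few central slices per volume is one simple way to bias a 2D dataset towards slices that contain brain tissue; whether the repository does exactly this should be checked against the script itself.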
10 | -------------------------------------------------------------------------------- /data/PhysioNet-ICH/README.md: -------------------------------------------------------------------------------- 1 | ### Experiments on the PhysioNet-ICH dataset 2 | 3 | To reproduce the experiments carried out on the PhysioNet-ICH dataset, download the data from the following link: [PhysioNet-ICH](https://physionet.org/content/ct-ich/1.3.1/). Then prepare the CT scans and produce the data splits with the following command: 4 | 5 | ``` 6 | python adecuate_PhysionNet_ICH.py 7 | ``` 8 | 9 | For more information on the data description and usage requirements, please refer to the original authors at the following [LINK](https://physionet.org/content/ct-ich/1.3.1/). 10 | -------------------------------------------------------------------------------- /data/PhysioNet-ICH/readme: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /images/visualizations.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jusiro/constrained_anomaly_segmentation/c5c963c47ffc924b8611824c09fdfabfc92a48e1/images/visualizations.png --------------------------------------------------------------------------------