├── .gitattributes
├── LICENSE
├── README.md
├── code
│   ├── datasets
│   │   ├── adecuate_BRATS.py
│   │   ├── adecuate_PhysionNet_ICH.py
│   │   ├── datasets.py
│   │   └── utils.py
│   ├── evaluation
│   │   ├── metrics.py
│   │   ├── readme
│   │   └── utils.py
│   ├── main.py
│   ├── methods
│   │   ├── losses
│   │   │   ├── losses.py
│   │   │   └── readme
│   │   ├── train.py
│   │   └── trainers
│   │       ├── AMCons.py
│   │       ├── ae.py
│   │       ├── anoVAEGAN.py
│   │       ├── fanoGAN.py
│   │       ├── gradCAMCons.py
│   │       ├── gradCons.py
│   │       ├── histEqualization.py
│   │       └── vae.py
│   └── models
│       └── models.py
├── data
│   ├── MICCAI_BraTS_2019_Data_Training
│   │   └── README.md
│   └── PhysioNet-ICH
│       ├── README.md
│       └── readme
└── images
    └── visualizations.png

/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
3 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Julio Silva-Rodríguez
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Constrained Unsupervised Anomaly Segmentation of Brain Lesions
2 |
3 | This repository contains code for unsupervised anomaly segmentation of brain lesions. Specifically, the implemented methods constrain the optimization process so that a VAE produces homogeneous activations on normal samples.
4 |
5 | If you find these methods useful for your research, please consider citing:
6 |
7 | **J. Silva-Rodríguez, V. Naranjo and J. Dolz, "Looking at the whole picture: constrained unsupervised anomaly segmentation", in British Machine Vision Conference (BMVC), 2021.** [(paper)](https://www.bmvc2021-virtualconference.com/assets/papers/1011.pdf)[(conference)](https://www.bmvc2021-virtualconference.com/conference/papers/paper_1011.html)
8 |
9 | **J. Silva-Rodríguez, V. Naranjo and J. Dolz, "Constrained unsupervised anomaly segmentation", Medical Image Analysis, vol. 80, p. 102526, 2022.** [(paper)](https://www.sciencedirect.com/science/article/pii/S1361841522001736)
10 |
11 | ## GRADCAMCons: looking at the whole picture via size constraints
12 |
13 | ```
14 | python main.py --dir_out ../data/results/gradCAMCons/ --method gradCAMCons --learning_rate 0.00001 --wkl 1 --wae 1000 --t 10
15 | ```
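
The idea of a size constraint can be sketched in a few lines of NumPy. This is only an illustration under assumptions that this README does not confirm: that the activation map lies in [0, 1], that the constraint asks its mean to stay above a target proportion, and that it is enforced with an extended log-barrier whose parameter corresponds to `--t`. The names `extended_log_barrier`, `size_penalty` and `p_min` are hypothetical and do not come from the repository.

```python
import numpy as np


def extended_log_barrier(z, t=10.0):
    """Smooth penalty for the inequality constraint z <= 0 (hypothetical helper)."""
    z = np.asarray(z, dtype=np.float64)
    threshold = -1.0 / t ** 2
    # Clip so np.log stays well-defined on the branch that np.where discards.
    safe_z = np.minimum(z, threshold)
    barrier = -np.log(-safe_z) / t
    # Linear extension, continuous with the barrier at z = threshold.
    linear = t * z - np.log(1.0 / t ** 2) / t + 1.0 / t
    return np.where(z <= threshold, barrier, linear)


def size_penalty(activation_map, p_min=0.9, t=10.0):
    """Penalize maps (assumed in [0, 1]) whose mean activation is below p_min."""
    size = np.asarray(activation_map, dtype=np.float64).mean()
    return float(extended_log_barrier(p_min - size, t))


full_map = np.ones((8, 8))    # activation covers the whole image
empty_map = np.zeros((8, 8))  # activation covers nothing
penalty_full = size_penalty(full_map)
penalty_empty = size_penalty(empty_map)
```

Under these assumptions, a map that covers the whole image receives a small penalty and an empty map a large one, which matches the intent of looking at the whole picture in normal samples. The loss actually used lives in `code/methods/losses/losses.py` and may differ.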
16 |
17 | ## AMCons: entropy maximization on activation maps
18 |
19 | ```
20 | python main.py --dir_out ../data/results/AMCon/ --method camCons --learning_rate 0.0001 --wkl 10 --wH 0.1
21 | ```
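
As a rough illustration of entropy maximization on an activation map, the sketch below computes the Shannon entropy of a map after turning it into a probability distribution. It assumes a softmax over all spatial positions, which is not stated in this README; `activation_entropy` is a hypothetical name, and the normalization used in `AMCons.py` may differ.

```python
import numpy as np


def activation_entropy(activation_map):
    """Shannon entropy of a softmax-normalized activation map (hypothetical helper)."""
    a = np.asarray(activation_map, dtype=np.float64).ravel()
    a = a - a.max()  # subtract the maximum for numerical stability
    p = np.exp(a) / np.exp(a).sum()
    return float(-(p * np.log(p + 1e-12)).sum())


uniform_map = np.zeros((4, 4))  # identical activation everywhere
peaked_map = np.zeros((4, 4))
peaked_map[0, 0] = 10.0         # one dominant activation

h_uniform = activation_entropy(uniform_map)
h_peaked = activation_entropy(peaked_map)
```

The entropy is highest for the uniform map (the logarithm of the number of positions) and drops as the activation concentrates on a few pixels, so maximizing it pushes the map towards homogeneous activations.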
22 |
23 | ## Visualizations
24 |
25 |
26 |
27 |
28 |
29 | ## Contact
30 | For further questions or details, please reach out directly to Julio Silva-Rodríguez (jusiro95@gmail.com).
31 |
--------------------------------------------------------------------------------
/code/datasets/adecuate_BRATS.py:
--------------------------------------------------------------------------------
1 | import nibabel as nib
2 | import os
3 | import numpy as np
4 | import random
5 | import argparse
6 | import cv2
7 |
8 | np.random.seed(42)
9 | random.seed(42)
10 |
11 |
12 | def adecuate_BRATS(args):
13 |
14 | dir_dataset = args.dir_dataset
15 | dir_out = args.dir_out
16 | scan = args.scan
17 | nSlices = args.nSlices
18 |
19 | partitions = ['train', 'val', 'test']
20 | Ncases = np.array([271, 32, 32])
21 |
22 | if not os.path.isdir(dir_out):
23 | os.mkdir(dir_out)
24 | if not os.path.isdir(dir_out + '/' + scan + '/'):
25 | os.mkdir(dir_out + '/' + scan + '/')
26 |
27 | cases_LGG = os.listdir(dir_dataset + 'LGG/')
28 | cases_LGG = [dir_dataset + 'LGG/' + iCase for iCase in cases_LGG if iCase != '.DS_Store']
29 |
30 | cases_HGG = os.listdir(dir_dataset + 'HGG/')
31 | cases_HGG = [dir_dataset + 'HGG/' + iCase for iCase in cases_HGG if iCase != '.DS_Store']
32 |
33 | cases = cases_LGG + cases_HGG
34 |
35 | random.shuffle(cases)
36 |
37 | for iPartition in np.arange(0, len(partitions)):
38 |
39 | if not os.path.isdir(dir_out + '/' + scan + '/' + partitions[iPartition]):
40 | os.mkdir(dir_out + '/' + scan + '/' + partitions[iPartition])
41 | if not os.path.isdir(dir_out + '/' + scan + '/' + partitions[iPartition] + '/benign'):
42 | os.mkdir(dir_out + '/' + scan + '/' + partitions[iPartition] + '/benign')
43 | if not os.path.isdir(dir_out + '/' + scan + '/' + partitions[iPartition] + '/malign'):
44 | os.mkdir(dir_out + '/' + scan + '/' + partitions[iPartition] + '/malign')
45 | if not os.path.isdir(dir_out + '/' + scan + '/' + partitions[iPartition] + '/ground_truth'):
46 | os.mkdir(dir_out + '/' + scan + '/' + partitions[iPartition] + '/ground_truth')
47 |
48 | cases_partition = cases[np.sum(Ncases[:iPartition]):np.sum(Ncases[:iPartition+1])]
49 |
50 | c = 0
51 | for iCase in cases_partition:
52 | c += 1
53 | print(str(c) + '/' + str(len(cases_partition)))
54 |
55 | img_path = iCase + '/' + iCase.split('/')[-1] + '_' + scan + '.nii.gz'
56 | mask_path = iCase + '/' + iCase.split('/')[-1] + '_seg.nii.gz'
57 |
58 | img = nib.load(img_path)
59 | img = (img.get_fdata())[:, :, :]
60 | img = (img/img.max())*255
61 | img = img.astype(np.uint8)
62 |
63 | mask = nib.load(mask_path)
64 | mask = (mask.get_fdata())
65 | mask[mask > 0] = 255
66 | mask = mask.astype(np.uint8)
67 |
68 | for iSlice in np.arange(round(img.shape[-1]/2) - nSlices, round(img.shape[-1]/2) + nSlices):
69 | filename = iCase.split('/')[-1] + '_' + str(iSlice) + '.jpg'
70 |
71 | i_image = img[:, :, iSlice]
72 | i_mask = mask[:, :, iSlice]
73 |
74 | if np.any(i_mask == 255):
75 | label = 'malign'
76 | cv2.imwrite(dir_out + '/' + scan + '/' + partitions[iPartition] + '/ground_truth/' + filename, i_mask)
77 | else:
78 | label = 'benign'
79 |
80 | cv2.imwrite(dir_out + '/' + scan + '/' + partitions[iPartition] + '/' + label + '/' + filename, i_image)
81 |
82 |
83 | if __name__ == '__main__':
84 | parser = argparse.ArgumentParser()
85 | parser.add_argument("--dir_dataset", default='../data/MICCAI_BraTS_2019_Data_Training/', type=str)
86 | parser.add_argument("--dir_out", default='../data/BRATS_10slices/', type=str)
87 | parser.add_argument("--scan", default='flair', type=str)
88 | parser.add_argument("--nSlices", default=5, type=int)
89 |
90 | args = parser.parse_args()
91 | adecuate_BRATS(args)
92 |
93 |
94 |
--------------------------------------------------------------------------------
/code/datasets/adecuate_PhysionNet_ICH.py:
--------------------------------------------------------------------------------
1 | import nibabel as nib
2 | import os
3 | import numpy as np
4 | import random
5 | import argparse
6 | import cv2
7 | from scipy import ndimage
8 | from skimage import measure
9 | from matplotlib import pyplot as plt
10 | import SimpleITK as sitk
11 |
12 |
13 | def volume_registration(fixed_image, moving_image, mask=None):
14 |
15 | fixed_image = sitk.GetImageFromArray(fixed_image)
16 | moving_image = sitk.GetImageFromArray(moving_image)
17 | # Initial transformation
18 | '''
19 | transform_to_displacment_field_filter = sitk.TransformToDisplacementFieldFilter()
20 | transform_to_displacment_field_filter.SetReferenceImage(fixed_image)
21 | initial_transform = sitk.DisplacementFieldTransform(
22 | transform_to_displacment_field_filter.Execute(sitk.Transform(2, sitk.sitkIdentity)))
23 | '''
24 | initial_transform = sitk.CenteredTransformInitializer(fixed_image,
25 | moving_image,
26 | sitk.Euler3DTransform(),
27 | sitk.CenteredTransformInitializerFilter.GEOMETRY)
28 |
29 | registration_method = sitk.ImageRegistrationMethod()
30 | # Similarity metric settings.
31 | #registration_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)
32 | registration_method.SetMetricAsCorrelation()
33 | registration_method.SetMetricSamplingStrategy(registration_method.RANDOM)
34 | registration_method.SetMetricSamplingPercentage(0.01)
35 |
36 | registration_method.SetInterpolator(sitk.sitkLinear)
37 | # Optimizer settings.
38 | registration_method.SetOptimizerAsGradientDescent(learningRate=1.0, numberOfIterations=100,
39 | convergenceMinimumValue=1e-6, convergenceWindowSize=10)
40 | registration_method.SetOptimizerScalesFromPhysicalShift()
41 | # Setup for the multi-resolution framework.
42 | registration_method.SetShrinkFactorsPerLevel(shrinkFactors=[4, 2, 1])
43 | registration_method.SetSmoothingSigmasPerLevel(smoothingSigmas=[2, 1, 0])
44 | registration_method.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()
45 | # Don't optimize in-place, we would possibly like to run this cell multiple times.
46 | registration_method.SetInitialTransform(initial_transform, inPlace=False)
47 |
48 | # Apply transformation
49 | final_transform = registration_method.Execute(sitk.Cast(fixed_image, sitk.sitkFloat32),
50 | sitk.Cast(moving_image, sitk.sitkFloat32))
51 |
52 | moving_resampled = sitk.Resample(moving_image, fixed_image, final_transform, sitk.sitkLinear, 0.0,
53 | moving_image.GetPixelID())
54 | out = sitk.GetArrayFromImage(moving_resampled)
55 |
56 | if mask is None:
57 | return out
58 | else: # Apply transformation to gt mask
59 | mask_mov = sitk.GetImageFromArray(mask)
60 | moving_resampled = sitk.Resample(mask_mov, fixed_image, final_transform, sitk.sitkNearestNeighbor,
61 | 0, mask_mov.GetPixelID())
62 | out_mask = sitk.GetArrayFromImage(moving_resampled)
63 |
64 | return out, out_mask
65 |
66 |
67 | def preprocess_vol_ICH(vol, mask):
68 | w_level = 40
69 | w_width = 120
70 |
71 | # Intensity normalization
72 |
73 | vol = (vol - (w_level - (w_width / 2))) * (255 / (w_width))
74 | vol[vol < 0] = 0
75 | vol[vol > 255] = 255
76 |
77 | # Get tissue mask
78 |
79 | tissue_mask = np.ones(vol.shape)
80 | tissue_mask[vol == 0] = 0
81 | tissue_mask[vol == 255] = 0
82 | tissue_mask = ndimage.binary_opening(tissue_mask, structure=np.ones((10, 10, 1))).astype(tissue_mask.dtype)
83 | tissue_mask = ndimage.binary_erosion(tissue_mask, structure=np.ones((5, 5, 1))).astype(tissue_mask.dtype)
84 |
85 |     # Keep the largest object in the CT slice
86 | for iSlice in np.arange(0, vol.shape[-1]):
87 | tissue_mask_i = tissue_mask[:, :, iSlice]
88 |
89 | if np.max(tissue_mask_i) > 0:
90 |
91 | labels = measure.label(tissue_mask_i)
92 | props = measure.regionprops(labels)
93 |
94 | areas = [i_prop.area for i_prop in props]
95 | labels_given = [i_prop.label for i_prop in props]
96 | idx_areas = np.argsort(areas)
97 |
98 | if np.mean(tissue_mask_i[labels == (labels_given[idx_areas[-1]])]) != 0:
99 | label = labels_given[idx_areas[-1]]
100 | else:
101 | label = labels_given[idx_areas[-2]]
102 |
103 | tissue_mask_i = labels == (label)
104 | tissue_mask[:, :, iSlice] = tissue_mask_i
105 |
106 | vol = vol * tissue_mask
107 | mask = mask * tissue_mask
108 |
109 | return vol, mask
110 |
111 | np.random.seed(42)
112 | random.seed(42)
113 |
114 | dir_dataset = '../data/PhysioNet-ICH/'
115 | dir_out = '../data/PhysioNetICH_5slices_registered_new/'
116 | scan = 'CT'
117 | nSlices = 5
118 | partitions = ['train', 'test']
119 | Ncases = np.array([50, 25])
120 |
121 | if not os.path.isdir(dir_out):
122 | os.mkdir(dir_out)
123 | if not os.path.isdir(dir_out + '/' + scan + '/'):
124 | os.mkdir(dir_out + '/' + scan + '/')
125 |
126 | cases = os.listdir(dir_dataset + 'ct_scans/')
127 | cases = [dir_dataset + 'ct_scans/' + iCase for iCase in cases if iCase != '.DS_Store']
128 |
129 | random.shuffle(cases)
130 |
131 | for iPartition in np.arange(0, len(partitions)):
132 |
133 | if not os.path.isdir(dir_out + '/' + scan + '/' + partitions[iPartition]):
134 | os.mkdir(dir_out + '/' + scan + '/' + partitions[iPartition])
135 | if not os.path.isdir(dir_out + '/' + scan + '/' + partitions[iPartition] + '/benign'):
136 | os.mkdir(dir_out + '/' + scan + '/' + partitions[iPartition] + '/benign')
137 | if not os.path.isdir(dir_out + '/' + scan + '/' + partitions[iPartition] + '/malign'):
138 | os.mkdir(dir_out + '/' + scan + '/' + partitions[iPartition] + '/malign')
139 | if not os.path.isdir(dir_out + '/' + scan + '/' + partitions[iPartition] + '/ground_truth'):
140 | os.mkdir(dir_out + '/' + scan + '/' + partitions[iPartition] + '/ground_truth')
141 |
142 | cases_partition = cases[np.sum(Ncases[:iPartition]):np.sum(Ncases[:iPartition+1])]
143 |
144 | # Load volume reference
145 | vol_ref = nib.load('../data/PhysioNet-ICH/ct_scans/071.nii')
146 | vol_ref = (vol_ref.get_fdata())
147 | mask_ref = nib.load('../data/PhysioNet-ICH/masks/071.nii')
148 | mask_ref = (mask_ref.get_fdata())
149 | mask_ref[mask_ref > 0] = 255
150 | vol_ref, mask_ref = preprocess_vol_ICH(vol_ref, mask_ref)
151 |
152 | c = 0
153 | for iCase in cases:
154 | c += 1
155 |
156 | img_path = iCase
157 | mask_path = iCase.replace('ct_scans', 'masks')
158 |
159 | # Load volume and mask
160 | img = nib.load(img_path)
161 | img = (img.get_fdata())[:, :, :]
162 | mask = nib.load(mask_path)
163 | mask = (mask.get_fdata())
164 | mask[mask > 0] = 255
165 |
166 | print(str(c) + '/' + str(len(cases)) + ' || ' + 'slices: ' + str(img.shape[-1]))
167 |
168 | # Preprocess
169 | img, mask = preprocess_vol_ICH(img, mask)
170 |
171 | img = img[:, :, round(img.shape[-1] / 2) - nSlices+5:round(img.shape[-1] / 2) + nSlices+5]
172 | mask = mask[:, :, round(mask.shape[-1] / 2) - nSlices+5:round(mask.shape[-1] / 2) + nSlices+5]
173 |
174 | if np.max(mask) == 255:
175 | part = 'test'
176 | else:
177 | part = 'train'
178 |
179 | mask = mask.astype(np.uint8)
180 |
181 | for iSlice in np.arange(0, nSlices*2):
182 | filename = iCase.split('/')[-1].split('.')[0] + '_' + str(iSlice) + '.jpg'
183 |
184 | i_image = img[:, :, iSlice]
185 | i_mask = mask[:, :, iSlice]
186 |
187 | if np.any(i_mask == 255):
188 | label = 'malign'
189 | cv2.imwrite(dir_out + '/' + scan + '/' + part + '/ground_truth/' + filename, i_mask)
190 | else:
191 | label = 'benign'
192 |
193 | cv2.imwrite(dir_out + '/' + scan + '/' + part + '/' + label + '/' + filename, i_image)
194 |
--------------------------------------------------------------------------------
/code/datasets/datasets.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | from datasets.utils import *
4 |
5 |
6 | # ------------------------------------------
7 | # BRATS dataset
8 | class TestDataset(object):
9 |
10 | def __init__(self, dir_dataset, item, partition, input_shape=(1, 224, 224), channel_first=True, norm='max',
11 | histogram_matching=True, filter_volumes=False):
12 | # Init properties
13 | self.dir_dataset = dir_dataset
14 | self.dir_datasets = dir_dataset
15 | self.item = item
16 | self.partition = partition
17 | self.input_shape = input_shape
18 | self.channel_first = channel_first
19 | self.norm = norm
20 | self.histogram_matching = histogram_matching
21 | self.nchannels = self.input_shape[0]
22 | self.ref_image = {}
23 | self.filter_volumes = filter_volumes
24 | if 'BRATS' in dir_dataset:
25 | for iModality in np.arange(0, len(self.item)):
26 | x = Image.open(
27 | '../data/BRATS_5slices/' + self.item[iModality] + '/train/benign/' + 'BraTS19_CBICA_AWV_1_77.jpg')
28 | x = np.asarray(x)
29 | self.ref_image[iModality] = x[40:-40, 40:-40]
30 |
31 | # Select all files in partition
32 | self.images = []
33 | for subdirs in ['benign', 'malign']:
34 | for images in os.listdir(self.dir_dataset + self.item[0] + '/' + self.partition + '/' + subdirs + '/'):
35 | if 'Thumbs.db' not in images:
36 | self.images.append(self.dir_dataset + 'modality' + '/' + self.partition + '/' + subdirs + '/' + images)
37 |
38 | # Get number of patients (volumes), and number of slices
39 | if 'BRATS' in dir_dataset:
40 | patients = [image.split('/')[-1][:-7] for image in self.images]
41 | else:
42 | patients = [image.split('/')[-1][:-6] for image in self.images]
43 | self.unique_patients = np.unique(patients)
44 | slices_per_volume = len(patients) // len(self.unique_patients)
45 |
46 | # Load images and masks
47 | self.X = np.zeros((len(self.unique_patients), len(self.item), slices_per_volume, self.nchannels, self.input_shape[1], self.input_shape[2]))
48 | self.M = np.zeros((len(self.unique_patients), slices_per_volume, self.nchannels, self.input_shape[1], self.input_shape[2]))
49 | self.Y = np.zeros((len(self.unique_patients), slices_per_volume))
50 | for iPatient in np.arange(0, len(self.unique_patients)):
51 |
52 | slices_patient = list(np.sort([iSlice for iSlice in self.images if self.unique_patients[iPatient] in iSlice]))
53 |
54 | if 'ICH' in dir_dataset:
55 | indexes = np.array([int(id[-5]) for id in slices_patient])
56 | idx = np.array(np.argsort(indexes))
57 | slices_patient = [slices_patient[i] for i in idx]
58 |
59 | for iSlice in np.arange(0, slices_per_volume):
60 |
61 | for iModality in np.arange(0, len(self.item)):
62 |
63 | # Load image
64 | x = Image.open(slices_patient[iSlice].replace('modality', self.item[iModality]))
65 | x = np.asarray(x)
66 |
67 | # Normalization
68 | if 'BRATS' in dir_dataset:
69 | x = image_normalization(x, self.input_shape[-1], norm=self.norm, channels=self.nchannels,
70 | histogram_matching=self.histogram_matching, reference_image=self.ref_image[iModality],
71 | mask=False, channel_first=True)
72 | else:
73 | x = image_normalization(x, self.input_shape[-1], norm=self.norm, channels=self.nchannels,
74 | histogram_matching=False,
75 | mask=False, channel_first=True)
76 | self.X[iPatient, iModality, iSlice, :, :, :] = x
77 |
78 | # Load mask
79 | if 'malign' in slices_patient[iSlice]:
80 | mask_id = slices_patient[iSlice].replace('malign', 'ground_truth').replace('modality', self.item[iModality])
81 |
82 | m = Image.open(mask_id)
83 | m = np.asarray(m)
84 |
85 | # Normalization
86 | m = image_normalization(m, self.input_shape[-1], norm=self.norm, channels=1,
87 | histogram_matching=False,
88 | reference_image=None,
89 | mask=True, channel_first=True)
90 | self.M[iPatient, iSlice, :, :, :] = m
91 | self.Y[iPatient, iSlice] = 1
92 |
93 | if self.filter_volumes:
94 | idx = np.squeeze(np.argwhere(np.sum(self.M, (1, 2, 3, 4)) / (slices_per_volume*self.input_shape[1]*self.input_shape[2]) > 0.001))
95 | self.X = self.X[idx, :, :, :, :, :]
96 | self.M = self.M[idx, :, :, :, :]
97 | self.Y = self.Y[idx, :]
98 | self.images = list(np.array(self.images)[idx])
99 | self.unique_patients = list(np.array(self.unique_patients)[idx])
100 |
101 | if len(self.item) == 1:
102 | self.X = self.X[:, 0, :, :, :, :]
103 |
104 |
105 | class MultiModalityDataset(object):
106 |
107 | def __init__(self, dir_datasets, modalities, input_shape=(3, 512, 512), channel_first=True, norm='max',
108 | hist_match=True, weak_supervision=False):
109 |
110 | 'Internal states initialization'
111 | self.dir_datasets = dir_datasets
112 | self.modalities = modalities
113 | self.input_shape = input_shape
114 | self.channel_first = channel_first
115 | self.norm = norm
116 | self.nChannels = input_shape[0]
117 | self.hist_match = hist_match
118 | self.weak_supervision = weak_supervision
119 | self.ref_image = {}
120 | if 'BRATS' in dir_datasets:
121 | for iModality in np.arange(0, len(modalities)):
122 | x = Image.open('../data/BRATS_5slices/' + modalities[iModality] + '/train/benign/' + 'BraTS19_CBICA_AWV_1_77.jpg')
123 | x = np.asarray(x)
124 | self.ref_image[iModality] = x
125 | self.train_images = []
126 | self.test_images = []
127 |
128 | # Paths for training data
129 | name_normal = '/train/benign/'
130 | name_anomaly = '/test/malign/'
131 |
132 | # Get train images
133 | train_images = os.listdir(dir_datasets + modalities[0] + name_normal)
134 | # Remove other files
135 | train_images = [train_images[i] for i in range(train_images.__len__()) if train_images[i] != 'Thumbs.db']
136 | for iImage in train_images:
137 | self.train_images.append(dir_datasets + 'modality' + name_normal + iImage)
138 |
139 |         # Get test images
140 | test_images = os.listdir(dir_datasets + modalities[0] + name_anomaly)
141 | # Remove other files
142 | test_images = [test_images[i] for i in range(test_images.__len__()) if test_images[i] != 'Thumbs.db']
143 | for iImage in test_images:
144 | self.test_images.append(dir_datasets + 'modality' + name_anomaly + iImage)
145 |
146 | self.train_indexes = np.arange(0, len(self.train_images))
147 | self.test_indexes = np.arange(0, len(self.test_images)) + len(self.train_images)
148 | self.images = self.train_images + self.test_images
149 |
150 | # Pre-allocate images
151 | self.X = np.zeros((len(self.images), len(self.modalities), input_shape[0], input_shape[1], input_shape[2]), dtype=np.float32)
152 | self.M = np.zeros((len(self.images), 1, input_shape[1], input_shape[2]), dtype=np.float32)
153 | self.Y = np.zeros((len(self.images), 2), dtype=np.float32)
154 |
155 | # Load, and normalize images
156 | print('[INFO]: Loading training images...')
157 | for i in np.arange(len(self.images)):
158 | for iModality in np.arange(0, len(self.modalities)):
159 | print(str(i) + '/' + str(len(self.images)), end='\r')
160 |
161 | # Load image
162 | x = Image.open(self.images[i].replace('modality', modalities[iModality]))
163 | x = np.asarray(x)
164 |
165 | # Normalization
166 | if 'BRATS' in dir_datasets:
167 | x = image_normalization(x, self.input_shape[-1], norm=self.norm, channels=self.nChannels,
168 | histogram_matching=self.hist_match, reference_image=self.ref_image[iModality],
169 | mask=False, channel_first=True)
170 | else:
171 | x = image_normalization(x, self.input_shape[-1], norm=self.norm, channels=self.nChannels,
172 | histogram_matching=False,
173 | mask=False, channel_first=True)
174 | self.X[i, iModality, :, :, :] = x
175 |
176 | if 'benign' in self.images[i]:
177 | self.Y[i, :] = np.array([1, 0])
178 | else:
179 | self.Y[i, :] = np.array([0, 1])
180 |
181 | mask_id = self.images[i].replace('malign', 'ground_truth').replace('modality', modalities[iModality])
182 |
183 | y = Image.open(mask_id)
184 | y = np.asarray(y)
185 | if len(y.shape) == 3:
186 | y = y[:, :, 0]
187 | # Normalization
188 | y = image_normalization(y, self.input_shape[-1], norm=self.norm, channels=1,
189 | histogram_matching=False,
190 | reference_image=None,
191 | mask=True, channel_first=True)
192 | self.M[i, :, :, :] = y
193 |
194 | print('[INFO]: Images loaded')
195 |
196 | def __len__(self):
197 | 'Denotes the total number of samples'
198 | return len(self.train_indexes)
199 |
200 | def __getitem__(self, index):
201 | 'Generates one sample of data'
202 |
203 | x = self.X[index, :, :, :, :]
204 | y = self.Y[index, :]
205 |
206 | if len(self.modalities) == 1:
207 | x = x[0, :, :, :]
208 |
209 | return x, y
210 |
211 | # ------------------------------------------
212 | # MVTEC dataset
213 |
214 |
215 | class MVTECDataset(object):
216 |
217 | def __init__(self, dir_datasets, modalities, input_shape=(3, 512, 512), channel_first=True, norm='max',
218 | weak_supervision=False, partition='train'):
219 |
220 | 'Internal states initialization'
221 | self.dir_datasets = dir_datasets
222 | self.modalities = modalities
223 | self.input_shape = input_shape
224 | self.channel_first = channel_first
225 | self.norm = norm
226 | self.nChannels = input_shape[0]
227 | self.weak_supervision = weak_supervision
228 | self.images = []
229 |
230 | # Get images
231 | categories = os.listdir(dir_datasets + modalities[0] + '/' + partition + '/')
232 | for i_category in categories:
233 | for iFile in os.listdir(dir_datasets + modalities[0] + '/' + partition + '/' + i_category + '/'):
234 | self.images.append(dir_datasets + modalities[0] + '/' + partition + '/' + i_category + '/' + iFile)
235 |
236 | # Remove other files
237 | self.images = [self.images[i] for i in range(self.images.__len__()) if 'Thumbs.db' not in self.images[i]]
238 |
239 | # Pre-allocate images
240 | self.X = np.zeros((len(self.images), input_shape[0], input_shape[1], input_shape[2]), dtype=np.float32)
241 | self.M = np.zeros((len(self.images), 1, input_shape[1], input_shape[2]), dtype=np.float32)
242 | self.Y = np.zeros((len(self.images), 1), dtype=np.float32)
243 |
244 | # Load, and normalize images
245 | print('[INFO]: Loading training images...')
246 | for i in np.arange(len(self.images)):
247 | print(str(i) + '/' + str(len(self.images)), end='\r')
248 |
249 | # Load image
250 | x = Image.open(self.images[i])
251 | x = np.asarray(x)
252 |
253 | # Normalization
254 | x = image_normalization(x, self.input_shape[-1], norm=self.norm, channels=self.nChannels,
255 | histogram_matching=False,
256 | mask=False, channel_first=True)
257 | self.X[i, :, :, :] = x
258 |
259 | if 'good' in self.images[i]:
260 | self.Y[i, :] = np.array([0])
261 | else:
262 | self.Y[i, :] = np.array([1])
263 |
264 | mask_id = self.images[i].replace(partition, 'ground_truth').replace('.png', '_mask.png')
265 |
266 | y = Image.open(mask_id)
267 | y = np.asarray(y)
268 | if len(y.shape) == 3:
269 | y = y[:, :, 0]
270 | # Normalization
271 | y = image_normalization(y, self.input_shape[-1], norm=self.norm, channels=1,
272 | histogram_matching=False,
273 | reference_image=None,
274 | mask=True, channel_first=True)
275 | self.M[i, :, :, :] = y
276 |
277 | print('[INFO]: Images loaded')
278 | if partition == 'train':
279 | self.train_indexes = np.arange(0, len(self.images))
280 |
281 | def __len__(self):
282 | 'Denotes the total number of samples'
283 | return len(self.images)
284 |
285 | def __getitem__(self, index):
286 | 'Generates one sample of data'
287 |
288 | x = self.X[index, :, :, :]
289 | y = self.Y[index, :]
290 |
291 | return x, y
292 |
293 |
294 | # ------------------------------------------
295 | # Data generator
296 |
297 | class WSALDataGenerator(object):
298 |
299 | def __init__(self, dataset, partition, batch_size=16, shuffle=False):
300 |
301 | 'Internal states initialization'
302 | self.dataset = dataset
303 | self.batch_size = batch_size
304 | self.shuffle = shuffle
305 | self.partition = partition
306 |
307 | if self.partition == 'train':
308 | self.indexes = self.dataset.train_indexes.copy()
309 | elif self.partition == 'test':
310 | self.indexes = self.dataset.test_indexes.copy()
311 |
312 | if self.dataset.weak_supervision:
313 | self.indexes_abnormal = self.dataset.test_indexes.copy()
314 | self._idx_abnormal = 0
315 | self.batch_size_anomaly = np.min(np.array((self.batch_size, len(self.dataset.test_indexes))))
316 |
317 | self._idx = 0
318 | self._reset()
319 |
320 | def __len__(self):
321 |
322 | N = len(self.indexes)
323 | b = self.batch_size
324 | return N // b
325 |
326 | def __iter__(self):
327 |
328 | return self
329 |
330 | def __next__(self):
331 |
332 | # If dataset is completed, stop iterator
333 |         if self._idx + self.batch_size > len(self.indexes):
334 | self._reset()
335 | raise StopIteration()
336 |
337 | if self.dataset.weak_supervision:
338 | if self._idx_abnormal + self.batch_size_anomaly >= len(self.indexes_abnormal):
339 | self._idx_abnormal = 0
340 |
341 | # Load images and include into the batch
342 | X, Y = [], []
343 | for i in range(self._idx, self._idx + self.batch_size):
344 | x, y = self.dataset.__getitem__(self.indexes[i])
345 | X.append(x)
346 | Y.append(y)
347 | # Update index iterator
348 | self._idx += self.batch_size
349 |
350 | if self.dataset.weak_supervision:
351 | Xa, Ya = [], []
352 | for i in range(self._idx_abnormal, self._idx_abnormal + self.batch_size_anomaly):
353 | xa, ya = self.dataset.__getitem__(self.indexes_abnormal[i])
354 | Xa.append(xa)
355 | Ya.append(ya)
356 | # Update index iterator
357 | self._idx_abnormal += self.batch_size_anomaly
358 |
359 | return np.array(X).astype('float32'), np.array(Y).astype('float32'),\
360 | np.array(Xa).astype('float32'), np.array(Ya).astype('float32')
361 | else:
362 | return np.array(X).astype('float32'), np.array(Y).astype('float32'),\
363 | None, None
364 |
365 | def _reset(self):
366 |
367 | if self.shuffle:
368 | random.shuffle(self.indexes)
369 | self._idx = 0
370 |
371 |
372 |
--------------------------------------------------------------------------------
/code/datasets/utils.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | import cv2
4 |
5 | from skimage import exposure
6 | from matplotlib import pyplot as plt
7 | import imutils
8 |
9 |
10 | def image_normalization(x, shape, norm='max', channels=3, histogram_matching=False, reference_image=None,
11 | mask=False, channel_first=True):
12 |
13 | # Histogram matching to reference image
14 | if histogram_matching:
15 | x_norm = exposure.match_histograms(x, reference_image)
16 | x_norm[x == 0] = 0
17 | x = x_norm
18 |
19 | # image resize
20 | x = imutils.resize(x, height=shape)
21 | #x = resize_image_canvas(x, shape)
22 |
23 | # Grayscale image -- add channel dimension
24 | if len(x.shape) < 3:
25 | x = np.expand_dims(x, -1)
26 |
27 | if mask:
28 | x = (x > 200)
29 |
30 | # channel first
31 | if channel_first:
32 | x = np.transpose(x, (2, 0, 1))
33 | if not mask:
34 | if norm == 'max':
35 | x = x / 255.0
36 | elif norm == 'zscore':
37 | x = (x - 127.5) / 127.5
38 |
39 | # numeric type
40 |     x = x.astype('float32')
41 | return x
42 |
43 |
44 | def plot_image(x, y=None, denorm_intensity=False, channel_first=True, norm='zscore'):
45 |     if len(x.shape) < 3:
46 |         x = np.expand_dims(x, 0)
47 |     # channel first
48 |     if channel_first:
49 |         x = np.transpose(x, (1, 2, 0))
50 |     if denorm_intensity:
51 |         if norm == 'zscore':  # `norm` is a function argument; there is no `self` in this module-level function
52 | x = (x*127.5) + 127.5
53 | x = x.astype(int)
54 |
55 | plt.imshow(x)
56 |
57 | if y is not None:
58 | y = np.expand_dims(y[0, :, :], -1)
59 | plt.imshow(y, cmap='jet', alpha=0.1)
60 |
61 | plt.axis('off')
62 | plt.show()
63 |
64 |
65 | def augment_input_batch(batch):
66 | masks = np.zeros(batch.shape)
67 |
68 | for i in np.arange(0, batch.shape[0]):
69 | (batch[i, :, :, :], masks[i, :, :, :]) = augment_input_context(batch[i, :, :, :])
70 |
71 | return batch, masks
72 |
73 |
74 | def augment_input_context(x):
75 | im = x.copy()
76 | mask = np.zeros(im.shape)
77 |
78 |     # Randomize anomaly half-size
79 |     w = random.randint(0, im.shape[2] // 10)
80 | 
81 |     # Random center, kept at least w pixels from the border so the slices below never wrap around
82 |     xx = random.randint(w, im.shape[1] - w)
83 |     yy = random.randint(w, im.shape[2] - w)
84 | 
85 |     # Get intensity (currently unused: the anomaly is filled with zeros)
86 |     i = np.percentile(im, 99) + np.std(im)
87 | 
88 |     # Insert anomaly
89 |     im[:, xx-w:xx+w, yy-w:yy+w] = 0
90 |     mask[:, xx-w:xx+w, yy-w:yy+w] = 1
91 |
92 | # keep only skull
93 | im[x==0] = 0
94 | mask[x == 0] = 0
95 |
96 | return im, mask
97 |
98 |
99 | def resize_image_canvas(image, input_shape, pcn_eat_hz=0.):
100 |
101 | # cut a bit on the sides
102 | if pcn_eat_hz > 0:
103 | h, w = image.shape
104 | px_eat_hz = int(pcn_eat_hz * w)
105 | image = image[:, px_eat_hz:w - px_eat_hz]
106 |
107 | h, w = image.shape
108 | ratio_h = input_shape[1] / h
109 | ratio_w = input_shape[2] / w
110 |
111 | img_res = np.zeros((input_shape[1], input_shape[2]), dtype=image.dtype)
112 |
113 | if ratio_w > ratio_h:
114 | img_res_h = imutils.resize(image, height=input_shape[1])
115 | left_margin = (input_shape[2] - img_res_h.shape[1]) // 2
116 | img_res[:, left_margin:left_margin + img_res_h.shape[1]] = img_res_h
117 | else:
118 | img_res_w = imutils.resize(image, width=input_shape[2])
119 | top_margin = (input_shape[1] - img_res_w.shape[0]) // 2
120 | img_res[top_margin:top_margin + img_res_w.shape[0], :] = img_res_w
121 |
122 | return img_res
123 |
124 |
--------------------------------------------------------------------------------
/code/evaluation/metrics.py:
--------------------------------------------------------------------------------
1 | from sklearn.metrics import precision_recall_curve
2 | from sklearn.metrics import f1_score
3 | from sklearn.metrics import auc
4 | import numpy as np
5 |
6 |
7 | def dice(true_mask, pred_mask, non_seg_score=1.0):
8 |
9 | assert true_mask.shape == pred_mask.shape
10 |
11 | true_mask = np.asarray(true_mask).astype(bool)  # np.bool was removed in NumPy 1.24
12 | pred_mask = np.asarray(pred_mask).astype(bool)
13 |
14 | # If both segmentations are all zero, the dice will be 1. (Developer decision)
15 | im_sum = true_mask.sum() + pred_mask.sum()
16 | if im_sum == 0:
17 | return non_seg_score
18 |
19 | # Compute Dice coefficient
20 | intersection = np.logical_and(true_mask, pred_mask)
21 | return 2. * intersection.sum() / im_sum
22 |
23 |
24 | def au_prc(true_mask, pred_mask):
25 |
26 | # Calculate pr curve and its area
27 | precision, recall, threshold = precision_recall_curve(true_mask, pred_mask)
28 | au_prc = auc(recall, precision)
29 |
30 | # Search the optimum point and obtain threshold via f1 score
31 | f1 = 2 * (precision * recall) / (precision + recall)
32 | f1[np.isnan(f1)] = 0
33 |
34 | th = threshold[np.argmax(f1)]
35 |
36 | return au_prc, th
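
A quick sanity check of the Dice definition used in `dice` above, written as a standalone NumPy sketch. The helper name `dice_sketch` and the toy masks are illustrative assumptions, not part of the repository; the built-in `bool` is used for the cast.

```python
import numpy as np


def dice_sketch(true_mask, pred_mask, non_seg_score=1.0):
    """Dice = 2 * |A and B| / (|A| + |B|) for two binary masks."""
    true_mask = np.asarray(true_mask).astype(bool)
    pred_mask = np.asarray(pred_mask).astype(bool)
    total = true_mask.sum() + pred_mask.sum()
    if total == 0:
        # Both masks are empty: return the agreed-upon score
        return non_seg_score
    return 2.0 * np.logical_and(true_mask, pred_mask).sum() / total


# Toy masks: 2 reference pixels, 3 predicted pixels, 2 of them shared
reference = np.array([[1, 1, 0], [0, 0, 0]])
predicted = np.array([[1, 1, 1], [0, 0, 0]])
score = dice_sketch(reference, predicted)  # 2 * 2 / (2 + 3)
```

With these masks the overlap is 2 pixels and the two mask sizes add up to 5, so the score is 2 * 2 / 5.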
--------------------------------------------------------------------------------
/code/evaluation/readme:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/code/evaluation/utils.py:
--------------------------------------------------------------------------------
1 | import sklearn.metrics
2 | import imutils
3 | import cv2
4 | import os
5 | import matplotlib.pyplot as plt
6 |
7 | from evaluation.metrics import *
8 |
9 |
10 | def inference_dataset(method, dataset):
11 |
12 | # Take references and inputs from testing dataset
13 | X = dataset.X
14 | M = dataset.M
15 | Y = dataset.Y
16 |
17 | if 'BRATS' in dataset.dir_datasets or 'PhysioNet' in dataset.dir_datasets: # Inference is volume-wise
18 |
19 | if len(X.shape) > 5:
20 | X = X[:, 0, :, :, :, :]
21 |
22 | if len(M.shape) < 5:
23 | M = np.expand_dims(M, 1)
24 |
25 | # Init variables
26 | (p, s, c, h, w) = X.shape # maps dimensions
27 | Mhat = np.zeros(M.shape) # Predicted segmentation maps
28 | Xhat = np.zeros(X.shape) # Reconstructed images
29 | Scores = np.zeros((p, s))
30 |
31 | for iVolume in np.arange(0, p):
32 | for iSlice in np.arange(0, s):
33 | # Take image
34 | x = X[iVolume, iSlice, :, :, :]
35 | # Predict anomaly map and score
36 | score, mhat, xhat = method.predict_score(x)
37 |
38 | Mhat[iVolume, iSlice, :, :, :] = mhat
39 | Xhat[iVolume, iSlice, :, :, :] = xhat
40 | Scores[iVolume, iSlice] = score
41 |
42 | elif 'MVTEC' in dataset.dir_datasets: # Inference is image-wise
43 |
44 | # Init variables
45 | (cases, c, h, w) = X.shape # maps dimensions
46 | Mhat = np.zeros(M.shape) # Predicted segmentation maps
47 | Xhat = np.zeros(X.shape) # Reconstructed images
48 | Scores = np.zeros((cases, 1))
49 |
50 | for iCase in np.arange(0, cases):
51 | # Take image
52 | x = X[iCase, :, :, :]
53 | # Predict anomaly map and score
54 | score, mhat, xhat = method.predict_score(x)
55 |
56 | Mhat[iCase, :, :, :] = mhat
57 | Xhat[iCase, :, :, :] = xhat
58 | Scores[iCase, :] = score
59 |
60 | return Y, Scores, M, Mhat, X, Xhat
61 |
62 |
63 | def evaluate_anomaly_detection(y, scores, dir_out='', range=[-1, 1], tit='cosine similarity', bins=50, th=None):
64 |
65 | scores = np.ravel(scores)
66 | y = np.ravel(y)
67 |
68 | auroc = sklearn.metrics.roc_auc_score(y, scores) # au_roc
69 | auprc, th_op = au_prc(y, scores) # au_prc
70 |
71 | if th is None:
72 | th = th_op
73 |
74 | if dir_out != '':
75 | plt.hist(np.ravel(scores)[np.ravel(y) == 1], bins=bins, range=range, fc=[0.7, 0, 0, 0.5])
76 | plt.hist(np.ravel(scores)[np.ravel(y) == 0], bins=bins, range=range, fc=[0, 0, 0.7, 0.5])
77 | plt.legend(['Anomaly', 'Normal'])
78 | plt.xlabel(tit)
79 |
80 | plt.savefig(dir_out + 'anomaly_detection.png')
81 | plt.close('all')
82 |
83 | return auroc, auprc, th
84 |
85 |
86 | def evaluate_anomaly_localization(dataset, save_maps=False, dir_out='', filter_volumes=True, th=None):
87 |
88 | print('[INFO]: Testing...')
89 | if save_maps:
90 | if not os.path.isdir(dir_out + 'masks_predicted/'):
91 | os.mkdir(dir_out + 'masks_predicted/')
92 | if not os.path.isdir(dir_out + 'masks_reference'):
93 | os.mkdir(dir_out + 'masks_reference/')
94 | if not os.path.isdir(dir_out + 'xhat_predicted'):
95 | os.mkdir(dir_out + 'xhat_predicted/')
96 |
97 | # Get references and predictions from dataset
98 | M = dataset.M
99 | Mhat = dataset.Mhat
100 | X = dataset.X
101 | Xhat = dataset.Xhat
102 |
103 | if 'BRATS' in dataset.dir_datasets or 'PhysioNet' in dataset.dir_datasets: # Inference is volume-wise
104 |
105 | if len(X.shape) > 5:
106 | X = X[:, 0, :, :, :, :]
107 |
108 | if filter_volumes: # Filter volumes without annotations
109 | idx = np.squeeze(np.argwhere(np.sum(M, (1, 2, 3, 4)) / (M.shape[1] * M.shape[-1] * M.shape[-1]) > 0.001))
110 | X = X[idx, :, :, :, :]
111 | Xhat = Xhat[idx, :, :, :, :]
112 | M = M[idx, :, :, :, :]
113 | Mhat = Mhat[idx, :, :, :, :]
114 | unique_patients = list(np.array(dataset.unique_patients)[idx])
115 |
116 | # Obtain overall metrics and optimum point threshold
117 | AU_ROC = sklearn.metrics.roc_auc_score(M.flatten() == 1, Mhat.flatten()) # au_roc
118 | AU_PRC, th_op = au_prc(M.flatten() == 1, Mhat.flatten()) # au_prc
119 |
120 | if th is None:
121 | th = th_op
122 |
123 | # Apply threshold
124 | DICE = dice(M.flatten() == 1, (Mhat > th).flatten()) # Dice
125 | IoU = sklearn.metrics.jaccard_score(M.flatten() == 1, (Mhat > th).flatten()) # IoU
126 |
127 | if 'BRATS' in dataset.dir_datasets or 'PhysioNet' in dataset.dir_datasets: # Inference is volume-wise
128 |
129 | # Once the threshold is obtained calculate volume-level metrics and plot results
130 | patient_dice = []
131 | (p, s, c, h, w) = X.shape # maps dimensions
132 | for iVolume in np.arange(0, p):
133 | patient_dice.append(dice(M[iVolume, :, :, :, :].flatten(), (Mhat[iVolume, :, :, :, :] > th).flatten()))
134 |
135 | if save_maps and dir_out!= '': # Save slices' masks
136 | for iSlice in np.arange(0, s):
137 | id = unique_patients[iVolume] + '_' + str(iSlice) + '.jpg'
138 |
139 | # Obtain heatmaps for predicted and reference
140 | m_i = imutils.rotate_bound(np.uint8(M[iVolume, iSlice, 0, :, :] * 255), 270)
141 | heatmap_m = cv2.applyColorMap(m_i, cv2.COLORMAP_JET)
142 | mhat_i = imutils.rotate_bound(np.uint8((Mhat[iVolume, iSlice, 0, :, :] > th) * 255), 270)
143 | heatmap_mhat = cv2.applyColorMap(mhat_i, cv2.COLORMAP_JET)
144 | heatmap_mhat = heatmap_mhat * (
145 | np.expand_dims(imutils.rotate_bound(Mhat[iVolume, iSlice, 0, :, :], 270), -1) > 0)
146 |
147 | # Move grayscale image to three channels
148 | xh = cv2.cvtColor(np.uint8(np.squeeze(X[iVolume, iSlice, :, :, :]) * 255), cv2.COLOR_GRAY2RGB)
149 | xh = imutils.rotate_bound(xh, 270)
150 |
151 | # Combine original image and masks
152 | fin_mask = cv2.addWeighted(xh, 0.7, heatmap_m, 0.3, 0)
153 | fin_predicted = cv2.addWeighted(xh, 0.7, heatmap_mhat, 0.3, 0)
154 |
155 | fin_predicted = mhat_i
156 | fin_mask = m_i
157 |
158 | cv2.imwrite(dir_out + 'masks_predicted/' + id, fin_predicted)
159 | cv2.imwrite(dir_out + 'masks_reference/' + id, fin_mask)
160 | cv2.imwrite(dir_out + 'xhat_predicted/' + id, np.uint8(Xhat[iVolume, iSlice, 0, :, :] * 255))
161 |
162 | DICE_mu = np.mean(patient_dice)
163 | DICE_std = np.std(patient_dice)
164 |
165 | if 'MVTEC' in dataset.dir_datasets: # Inference is image-wise
166 |
167 | # Once the threshold is obtained calculate image-level metrics and plot results
168 | case_dice = []
169 | (cases, c, h, w) = X.shape # maps dimensions
170 | for iCase in np.arange(0, cases):
171 | if 'good' not in dataset.images[iCase]:
172 | case_dice.append(dice(M[iCase, :, :, :].flatten(), (Mhat[iCase, :, :, :] > th).flatten()))
173 |
174 | if save_maps and dir_out != '': # Save slices' masks
175 | id = dataset.images[iCase].replace('.png', '.jpg').split('/')[-2] + '_' + \
176 | dataset.images[iCase].replace('.png', '.jpg').split('/')[-1]
177 |
178 | # Obtain heatmaps for predicted and reference
179 | m_i = np.uint8(M[iCase, 0, :, :] * 255)
180 | heatmap_m = cv2.applyColorMap(m_i, cv2.COLORMAP_JET)
181 | mhat_i = np.uint8((Mhat[iCase, 0, :, :] > th) * 255)
182 | heatmap_mhat = cv2.applyColorMap(mhat_i, cv2.COLORMAP_JET)
183 |
184 | # Move grayscale image to three channels
185 | xh = np.uint8(np.squeeze(X[iCase, :, :, :]) * 255)
186 | xh = np.transpose(xh, (1, 2, 0))
187 |
188 | # Combine original image and masks
189 | fin_mask = cv2.addWeighted(xh, 0.7, heatmap_m, 0.3, 0)
190 | fin_predicted = cv2.addWeighted(xh, 0.7, heatmap_mhat, 0.3, 0)
191 |
192 | cv2.imwrite(dir_out + 'masks_predicted/' + id, fin_predicted)
193 | cv2.imwrite(dir_out + 'masks_reference/' + id, fin_mask)
194 | cv2.imwrite(dir_out + 'xhat_predicted/' + id, np.uint8(Xhat[iCase, 0, :, :] * 255))
195 |
196 | DICE_mu = np.mean(case_dice)
197 | DICE_std = np.std(case_dice)
198 |
199 | metrics = {'AU_ROC': AU_ROC, 'AU_PRC': AU_PRC, 'DICE': DICE, 'IoU': IoU,
200 | 'DICE_mu': DICE_mu, 'DICE_std': DICE_std}
201 |
202 | return metrics, th
--------------------------------------------------------------------------------
/code/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 | import numpy as np
5 |
6 | from datasets.datasets import TestDataset, MultiModalityDataset, WSALDataGenerator, MVTECDataset
7 | from methods.train import AnomalyDetectorTrainer
8 | from evaluation.utils import *
9 |
10 |
11 | def main(args):
12 |
13 | exp = {"dir_datasets": args.dir_datasets, "dir_out": args.dir_out, "load_weigths": args.load_weigths,
14 | "epochs": args.epochs, "item": args.item, "method": args.method, "input_shape": args.input_shape,
15 | "batch_size": args.batch_size, "lr": args.learning_rate, "zdim": args.zdim, "images_on_ram": True,
16 | "wkl": args.wkl, "wr": 1, "wadv": args.wadv, "wae": args.wae, "epochs_to_test": 50, "dense": args.dense,
17 | "channel_first": True, "normalization_cam": args.normalization_cam, "avg_grads": True, "t": args.t,
18 | "context": args.context, "level_cams": args.level_cams, "p_activation_cam": args.p_activation_cam,
19 | "bayesian": args.bayesian, "loss_reconstruction": "bce",
20 | "expansion_loss_penalty": args.expansion_loss_penalty, "restoration": args.restoration,
21 | "n_blocks": args.n_blocks, "alpha_entropy": args.wH}
22 |
23 | if not os.path.isdir('../data/results/'):
24 | os.mkdir('../data/results/')
25 |
26 | metrics = []
27 | for iteration in [0, 1, 2]:
28 |
29 | if 'BRATS' in exp['dir_datasets'] or 'PhysioNet' in exp['dir_datasets']:
30 |
31 | # Set test
32 | test_dataset = TestDataset(exp['dir_datasets'], item=exp['item'], partition='val',
33 | input_shape=exp['input_shape'],
34 | channel_first=True, norm='max', histogram_matching=True)
35 |
36 | # Set train data loader
37 | dataset = MultiModalityDataset(exp['dir_datasets'], exp['item'], input_shape=exp['input_shape'],
38 | channel_first=exp['channel_first'], norm='max', hist_match=True)
39 |
40 | train_generator = WSALDataGenerator(dataset, partition='train', batch_size=exp['batch_size'], shuffle=True)
41 |
42 | elif 'MVTEC' in exp['dir_datasets']:
43 |
44 | # Set test
45 | test_dataset = MVTECDataset(exp['dir_datasets'], exp['item'], input_shape=exp['input_shape'],
46 | channel_first=exp['channel_first'], norm='max',
47 | partition='test')
48 |
49 | # Set train data loader
50 | dataset = MVTECDataset(exp['dir_datasets'], exp['item'], input_shape=exp['input_shape'],
51 | channel_first=exp['channel_first'], norm='max',
52 | partition='train')
53 |
54 | train_generator = WSALDataGenerator(dataset, partition='train', batch_size=exp['batch_size'], shuffle=True)
55 |
56 | # Set trainer and train model
57 | trainer = AnomalyDetectorTrainer(exp['dir_out'], exp['method'], item=exp['item'], zdim=exp['zdim'],
58 | dense=exp['dense'], n_blocks=exp['n_blocks'],
59 | lr=exp['lr'], input_shape=exp['input_shape'], load_weigths=exp['load_weigths'],
60 | epochs_to_test=exp['epochs_to_test'], context=exp['context'],
61 | bayesian=exp['bayesian'], restoration=exp['restoration'],
62 | loss_reconstruction=exp['loss_reconstruction'],
63 | level_cams=exp['level_cams'],
64 | expansion_loss_penalty=exp['expansion_loss_penalty'],
65 | iteration=iteration,
66 | alpha_kl=exp['wkl'], alpha_entropy=exp["alpha_entropy"],
67 | p_activation_cam=exp["p_activation_cam"],
68 | alpha_ae=exp["wae"], t=exp["t"])
69 |
70 | # Save experiment setup
71 | with open(exp['dir_out'] + 'setup.json', 'w') as fp:
72 | json.dump(exp, fp)
73 |
74 | if not args.only_test:
75 | # Train
76 | trainer.train(train_generator, exp['epochs'], test_dataset)
77 | else:
78 |
79 | trainer.method.train_generator = train_generator
80 | trainer.method.dataset_test = test_dataset
81 |
82 | threshold_with_percentile = False
83 | threshold_with_valsubset = False
84 | if threshold_with_percentile:
85 | # Predictions on normal dataset
86 | Y_t, Scores_t, M_t, Mhat_t, X_t, Xhat_t = inference_dataset(trainer.method, dataset)
87 | th = np.percentile(np.ravel(Mhat_t), 99)
88 | elif threshold_with_valsubset:
89 | val_dataset = TestDataset(exp['dir_datasets'], item=exp['item'], partition='val',
90 | input_shape=exp['input_shape'],
91 | channel_first=True, norm='max', histogram_matching=True)
92 | # Make predictions
93 | _, Scores, _, Mhat, _, Xhat = inference_dataset(trainer.method, val_dataset)
94 | # Input to dataset
95 | val_dataset.Scores = Scores
96 | val_dataset.Mhat = Mhat
97 | val_dataset.Xhat = Xhat
98 | # Get threshold
99 | _, th = evaluate_anomaly_localization(val_dataset, save_maps=False,
100 | dir_out=trainer.method.dir_results,
101 | th=None)
102 | else:
103 | th = None
104 |
105 | # Make predictions
106 | Y, Scores, M, Mhat, X, Xhat = inference_dataset(trainer.method, test_dataset)
107 |
108 | # Input to dataset
109 | test_dataset.Scores = Scores
110 | test_dataset.Mhat = Mhat
111 | test_dataset.Xhat = Xhat
112 |
113 | metrics_i, th_i = evaluate_anomaly_localization(test_dataset, save_maps=False,
114 | dir_out=trainer.method.dir_results,
115 | th=th)
116 | print(metrics_i)
117 | trainer.method.metrics = metrics_i
118 |
119 | # Save overall metrics
120 | metrics.append(list(trainer.method.metrics.values()))
121 |
122 | # Compute average performance and save performance in dictionary
123 | metrics = np.array(metrics)
124 | metrics_mu = np.mean(metrics, 0)
125 | metrics_std = np.std(metrics, 0)
126 |
127 | labels = list(trainer.method.metrics.keys())
128 | metrics_mu = {labels[i]: metrics_mu[i] for i in range(0, len(labels))}
129 | metrics_std = {labels[i]: metrics_std[i] for i in range(0, len(labels))}
130 |
131 | with open(exp['dir_out'] + exp['item'][0] + '/' + 'metrics_avg_val.json', 'w') as fp:
132 | json.dump(metrics_mu, fp)
133 | with open(exp['dir_out'] + exp['item'][0] + '/' + 'metrics_std_val.json', 'w') as fp:
134 | json.dump(metrics_std, fp)
135 |
136 |
137 | if __name__ == '__main__':
138 | parser = argparse.ArgumentParser()
139 | # Settings
140 | parser.add_argument("--dir_datasets", default="../data/BRATS_5slices/", type=str)
141 | parser.add_argument("--dir_out", default="../data/gradCAMCons/tests/", type=str)
142 | parser.add_argument("--method", default="gradCAMCons", type=str)
143 | parser.add_argument("--item", default=["flair"], type=str, nargs="+")  # type=list would split each value into characters
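144 | parser.add_argument("--load_weigths", default=False, type=lambda s: s.lower() in ("true", "1", "yes"))  # bool("False") is True, so parse the string explicitly
145 | parser.add_argument("--only_test", default=False, type=lambda s: s.lower() in ("true", "1", "yes"))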
146 | # Hyper-params training
147 | parser.add_argument("--input_shape", default=[1, 224, 224], type=int, nargs=3)  # e.g. --input_shape 1 224 224
148 | parser.add_argument("--learning_rate", default=1e-5, type=float)
149 | parser.add_argument("--epochs", default=300, type=int)
150 | parser.add_argument("--batch_size", default=8, type=int)
151 | # Auto-encoder architecture
152 | parser.add_argument("--zdim", default=32, type=int)
153 | parser.add_argument("--dense", default=True, type=lambda s: s.lower() in ("true", "1", "yes"))  # bool("False") is True, so parse the string explicitly
154 | parser.add_argument("--n_blocks", default=4, type=int)
155 | # Residual-based inference options
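156 | parser.add_argument("--restoration", default=False, type=lambda s: s.lower() in ("true", "1", "yes"))  # bool("False") is True, so parse the string explicitly
157 | parser.add_argument("--bayesian", default=False, type=lambda s: s.lower() in ("true", "1", "yes"))
158 | parser.add_argument("--context", default=False, type=lambda s: s.lower() in ("true", "1", "yes"))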
159 | # Settings with variational AE
160 | parser.add_argument("--wkl", default=1, type=float)
161 | # Settings with discriminator
162 | parser.add_argument("--wadv", default=0., type=float)
163 | # GradCAMCons
164 | parser.add_argument("--wae", default=1e4, type=float)
165 | parser.add_argument("--p_activation_cam", default=1e-2, type=float)
166 | parser.add_argument("--expansion_loss_penalty", default="log_barrier", type=str)
167 | parser.add_argument("--t", default=10, type=int)
168 | parser.add_argument("--normalization_cam", default="sigm", type=str)
169 | # AMCons
170 | parser.add_argument("--wH", default=0., type=float)
171 | parser.add_argument("--level_cams", default=-4, type=int)  # used as a list index, so it must be an integer
172 |
173 | args = parser.parse_args()
174 | main(args)
175 |
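
argparse applies `type` to the raw command-line string, and `bool("False")` evaluates to `True`, so an option declared with `type=bool` cannot be switched off from the command line. A minimal sketch of an explicit converter follows; the helper name `str2bool` and the accepted spellings are assumptions for illustration, not part of this repository.

```python
import argparse


def str2bool(value):
    """Convert a command-line string such as 'true' or '0' into a bool."""
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ("true", "t", "yes", "y", "1"):
        return True
    if text in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean, got %r" % value)


parser = argparse.ArgumentParser()
parser.add_argument("--only_test", default=False, type=str2bool)
parser.add_argument("--dense", default=True, type=str2bool)

# Parse an explicit argument list instead of sys.argv
args = parser.parse_args(["--only_test", "True", "--dense", "false"])
```

Non-string defaults are not passed through `type`, so `default=False` and `default=True` keep their values when the option is omitted.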
--------------------------------------------------------------------------------
/code/methods/losses/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | def kl_loss(mu, logvar):
6 |
7 | kl_divergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
8 |
9 | return kl_divergence
10 |
11 |
12 | def log_barrier(z, t=5):
13 |
14 | # Only one value
15 | if z.shape[0] == 1:
16 |
17 | if z <= - 1 / t ** 2:
18 | log_barrier_loss = - torch.log(-z) / t
19 | else:
20 | log_barrier_loss = t * z + -np.log(1 / (t ** 2)) / t + 1 / t
21 |
22 | # Constrain over multiple values
23 | else:
24 | log_barrier_loss = torch.tensor(0., device=z.device)  # follow the device of z instead of requiring CUDA
25 | for i in np.arange(0, z.shape[0]):
26 | zi = z[i, 0]
27 | if zi <= - 1 / t ** 2:
28 | log_barrier_loss += - torch.log(-zi) / t
29 | else:
30 | log_barrier_loss += t * zi + -np.log(1 / (t ** 2)) / t + 1 / t
31 |
32 | return log_barrier_loss
33 |
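
The two branches of `log_barrier` above follow the extended log-barrier form: a logarithmic penalty while `z <= -1 / t**2`, and a linear extension beyond that point so the loss stays finite when the constraint is violated. A plain-Python sketch of the scalar case, useful for checking that both branches meet at the switch point (the function name `log_barrier_scalar` is an assumption for illustration, and this version works on floats rather than tensors):

```python
import math


def log_barrier_scalar(z, t=5):
    """Extended log-barrier for a single float z (constraint: z <= 0)."""
    if z <= -1.0 / t ** 2:
        # Standard log-barrier, defined only for z < 0
        return -math.log(-z) / t
    # Linear extension, continuous with the branch above at z = -1 / t**2
    return t * z - math.log(1.0 / t ** 2) / t + 1.0 / t


t = 5
switch = -1.0 / t ** 2
left = log_barrier_scalar(switch, t)          # logarithmic branch
right = log_barrier_scalar(switch + 1e-9, t)  # linear branch
```

At `z = -1 / t**2` the linear branch reduces to `-log(1 / t**2) / t`, which is exactly the value of the logarithmic branch, so `left` and `right` differ only by the tiny step taken past the switch point.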
--------------------------------------------------------------------------------
/code/methods/losses/readme:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/code/methods/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 |
4 | # Import different methods
5 | from methods.trainers.ae import *
6 | from methods.trainers.vae import *
7 | from methods.trainers.anoVAEGAN import *
8 | from methods.trainers.fanoGAN import *
9 | from methods.trainers.AMCons import *
10 | from methods.trainers.gradCAMCons import *
11 | from methods.trainers.histEqualization import *
12 |
13 | torch.autograd.set_detect_anomaly(True)
14 |
15 |
16 | class AnomalyDetectorTrainer:
17 | def __init__(self, dir_out, method, item=['flair'], zdim=32, dense=True, variational=False, n_blocks=5, lr=1*1e-4,
18 | input_shape=(1, 224, 224), load_weigths=False, epochs_to_test=10, context=False, bayesian=False,
19 | restoration=False, loss_reconstruction='bce', iteration=0, level_cams=-4,
20 | alpha_kl=10, alpha_entropy=0., expansion_loss_penalty='log_barrier', p_activation_cam=0.2, t=25,
21 | alpha_ae=10):
22 |
23 | # Init input variables
24 | self.dir_out = dir_out
25 | self.method = method
26 | self.item = item
27 | self.zdim = zdim
28 | self.dense = dense
29 | self.variational = variational
30 | self.n_blocks = n_blocks
31 | self.input_shape = input_shape
32 | self.load_weights = load_weigths
33 | self.epochs_to_test = epochs_to_test
34 | self.context = context
35 | self.bayesian = bayesian
36 | self.restoration = restoration
37 | self.loss_reconstruction = loss_reconstruction
38 | self.lr = lr
39 | self.level_cams = level_cams
40 | self.expansion_loss_penalty = expansion_loss_penalty
41 | self.alpha_kl = alpha_kl
42 | self.alpha_entropy = alpha_entropy
43 | self.p_activation_cam = p_activation_cam
44 | self.t = t
45 | self.alpha_ae = alpha_ae
46 |
47 | # Prepare results folders
48 | self.dir_results = dir_out + item[0] + '/iteration_' + str(iteration) + str('/')
49 | if not os.path.isdir(dir_out):
50 | os.mkdir(dir_out)
51 | if not os.path.isdir(dir_out + item[0] + '/'):
52 | os.mkdir(dir_out + item[0] + '/')
53 | if not os.path.isdir(self.dir_results):
54 | os.mkdir(self.dir_results)
55 |
56 | # Create trainer
57 | if self.method == 'ae':
58 | self.method = AnomalyDetectorAE(self.dir_results, item=self.item, zdim=self.zdim, lr=self.lr,
59 | input_shape=self.input_shape, epochs_to_test=self.epochs_to_test,
60 | load_weigths=self.load_weights, n_blocks=self.n_blocks, dense=self.dense,
61 | context=self.context, bayesian=self.bayesian,
62 | loss_reconstruction=self.loss_reconstruction, restoration=self.restoration)
63 | elif self.method == 'vae':
64 | self.method = AnomalyDetectorVAE(self.dir_results, item=self.item, zdim=self.zdim, lr=self.lr,
65 | input_shape=self.input_shape, epochs_to_test=self.epochs_to_test,
66 | load_weigths=self.load_weights, n_blocks=self.n_blocks,
67 | dense=self.dense,
68 | context=self.context, bayesian=self.bayesian,
69 | loss_reconstruction=self.loss_reconstruction,
70 | restoration=self.restoration,
71 | alpha_kl=self.alpha_kl)
72 | elif self.method == 'anoVAEGAN':
73 | self.method = AnomalyDetectorAnoVAEGAN(self.dir_results, item=self.item, zdim=self.zdim, lr=self.lr,
74 | input_shape=self.input_shape, epochs_to_test=self.epochs_to_test,
75 | load_weigths=self.load_weights, n_blocks=self.n_blocks,
76 | dense=self.dense,
77 | context=self.context, bayesian=self.bayesian,
78 | loss_reconstruction=self.loss_reconstruction,
79 | restoration=self.restoration,
80 | alpha_kl=self.alpha_kl)
81 | elif self.method == 'fanoGAN':
82 | self.method = AnomalyDetectorFanoGAN(self.dir_results, item=self.item, zdim=self.zdim, lr=self.lr,
83 | input_shape=self.input_shape, epochs_to_test=self.epochs_to_test,
84 | load_weigths=self.load_weights, n_blocks=self.n_blocks,
85 | dense=self.dense,
86 | context=self.context, bayesian=self.bayesian,
87 | loss_reconstruction=self.loss_reconstruction,
88 | restoration=self.restoration)
89 | elif self.method == 'gradCAMCons':
90 | self.method = AnomalyDetectorGradCamCons(self.dir_results, item=self.item, zdim=self.zdim, lr=self.lr,
91 | input_shape=self.input_shape, epochs_to_test=self.epochs_to_test,
92 | load_weigths=self.load_weights, n_blocks=self.n_blocks,
93 | dense=self.dense, loss_reconstruction=self.loss_reconstruction,
94 | pre_training_epochs=50, level_cams=self.level_cams, t=self.t,
95 | p_activation_cam=self.p_activation_cam,
96 | expansion_loss_penalty=self.expansion_loss_penalty, alpha_ae=self.alpha_ae,
97 | alpha_kl=self.alpha_kl)
98 | elif self.method == 'camCons':
99 | self.method = AnomalyDetectorAMCons(self.dir_results, item=self.item, zdim=self.zdim, lr=self.lr,
100 | input_shape=self.input_shape, epochs_to_test=self.epochs_to_test,
101 | load_weigths=self.load_weights, n_blocks=self.n_blocks,
102 | dense=self.dense, loss_reconstruction=self.loss_reconstruction,
103 | pre_training_epochs=0, level_cams=self.level_cams,
104 | alpha_entropy=self.alpha_entropy,
105 | alpha_kl=self.alpha_kl)
106 | elif self.method == 'histEqualization':
107 | self.method = AnomalyDetectorHistEqualization(self.dir_results, item=self.item)
108 |
109 | else:
110 | raise ValueError('Incorrectly specified method: ' + str(self.method))
111 |
112 | def train(self, train_generator, epochs, dataset_test):
113 | self.method.train(train_generator, epochs, dataset_test)
--------------------------------------------------------------------------------
/code/methods/trainers/AMCons.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 | import datetime
5 | import kornia
6 | import json
7 | import torch
8 |
9 | import pandas as pd
10 | import matplotlib.pyplot as plt
11 |
12 | from scipy import ndimage
13 | from timeit import default_timer as timer
14 | from models.models import Encoder, Decoder
15 | from evaluation.utils import *
16 | from methods.losses.losses import kl_loss
17 | from methods.losses.losses import log_barrier
18 | from sklearn.metrics import accuracy_score, f1_score
19 |
20 |
21 | class AnomalyDetectorAMCons:
22 | def __init__(self, dir_results, item=['flair'], zdim=32, lr=1*1e-4, input_shape=(1, 224, 224), epochs_to_test=25,
23 | load_weigths=False, n_blocks=5, dense=True, loss_reconstruction='bce', alpha_kl=1,
24 | pre_training_epochs=0, level_cams=-4, alpha_entropy=1, gap=False):
25 |
26 | # Init input variables
27 | self.dir_results = dir_results
28 | self.item = item
29 | self.zdim = zdim
30 | self.lr = lr
31 | self.input_shape = input_shape
32 | self.epochs_to_test = epochs_to_test
33 | self.load_weigths = load_weigths
34 | self.n_blocks = n_blocks
35 | self.dense = dense
36 | self.loss_reconstruction = loss_reconstruction
37 | self.alpha_kl = alpha_kl
38 | self.pre_training_epochs = pre_training_epochs
39 | self.level_cams = level_cams
40 | self.alpha_entropy = alpha_entropy
41 | self.gap = gap
42 |
43 | # Init network
44 | self.E = Encoder(fin=self.input_shape[0], zdim=self.zdim, dense=self.dense, n_blocks=self.n_blocks,
45 | spatial_dim=self.input_shape[1]//2**self.n_blocks, variational=True, gap=gap)
46 | self.Dec = Decoder(fin=self.zdim, nf0=self.E.backbone.nfeats//2, n_channels=self.input_shape[0],
47 | dense=self.dense, n_blocks=self.n_blocks, spatial_dim=self.input_shape[1]//2**self.n_blocks,
48 | gap=gap)
49 |
50 | if torch.cuda.is_available():
51 | self.E.cuda()
52 | self.Dec.cuda()
53 |
54 | if self.load_weigths:
55 | self.E.load_state_dict(torch.load(self.dir_results + 'encoder_weights.pth'))
56 | self.Dec.load_state_dict(torch.load(self.dir_results + 'decoder_weights.pth'))
57 |
58 | # Set parameters
59 | self.params = list(self.E.parameters()) + list(self.Dec.parameters())
60 |
61 | # Set losses
62 | if self.loss_reconstruction == 'l2':
63 | self.Lr = torch.nn.MSELoss(reduction='sum')
64 | elif self.loss_reconstruction == 'bce':
65 | self.Lr = torch.nn.BCEWithLogitsLoss(reduction='sum')
66 |
67 | self.Lkl = kl_loss
68 |
69 | # Set optimizers
70 | self.opt = torch.optim.Adam(self.params, lr=self.lr)
71 |
72 | # Init additional variables and objects
73 | self.epochs = 0.
74 | self.iterations = 0.
75 | self.init_time = 0.
76 | self.lr_iteration = 0.
77 | self.lr_epoch = 0.
78 | self.kl_iteration = 0.
79 | self.kl_epoch = 0.
80 | self.H_iteration = 0.
81 | self.H_epoch = 0.
82 | self.i_epoch = 0.
83 | self.train_generator = []
84 | self.dataset_test = []
85 | self.metrics = {}
86 | self.aucroc_lc = []
87 | self.auprc_lc = []
88 | self.auroc_det = []
89 | self.lr_lc = []
90 | self.lkl_lc = []
91 | self.lae_lc = []
92 | self.H_lc = []
93 | self.auroc_det_lc = []
94 | self.refCam = 0.
95 |
96 | def train(self, train_generator, epochs, test_dataset):
97 | self.epochs = epochs
98 | self.init_time = timer()
99 | self.train_generator = train_generator
100 | self.dataset_test = test_dataset
101 | self.iterations = len(self.train_generator)
102 |
103 | # Loop over epochs
104 | for self.i_epoch in range(self.epochs):
105 | # init epoch losses
106 | self.lr_epoch = 0
107 | self.kl_epoch = 0.
108 | self.H_epoch = 0.
109 |
110 | # Loop over training dataset
111 | for self.i_iteration, (x_n, y_n, x_a, y_a) in enumerate(self.train_generator):
112 | #p = q
113 |
114 | # brain mask
115 | if 'BRATS' in train_generator.dataset.dir_datasets or\
116 | 'PhysioNet' in train_generator.dataset.dir_datasets:
117 | x_mask = 1 - np.mean((x_n == 0).astype(int), 0)  # np.int was removed in NumPy 1.24
118 | if 'BRATS' in train_generator.dataset.dir_datasets:
119 | x_mask = ndimage.binary_erosion(x_mask, structure=np.ones((1, 6, 6))).astype(x_mask.dtype)
120 | elif 'MVTEC' in train_generator.dataset.dir_datasets:
121 | x_mask = np.zeros((1, 224, 224))
122 | x_mask[:, 14:-14, 14:-14] = 1
123 |
124 | # Move tensors to gpu
125 | x_n = torch.tensor(x_n).cuda().float()
126 |
127 | # Obtain latent space from normal sample via encoder
128 | z, z_mu, z_logvar, allF = self.E(x_n)
129 |
130 | # Obtain reconstructed images through decoder
131 | xhat, _ = self.Dec(z)
132 | if self.loss_reconstruction == 'l2':
133 | xhat = torch.sigmoid(xhat)
134 |
135 | # Calculate criterion
136 | self.lr_iteration = self.Lr(xhat, x_n) / (self.train_generator.batch_size) # Reconstruction loss
137 | self.kl_iteration = self.Lkl(mu=z_mu, logvar=z_logvar) / (self.train_generator.batch_size) # kl loss
138 |
139 | # Init overall losses
140 | L = self.lr_iteration + self.alpha_kl * self.kl_iteration
141 |
142 | # ---- Compute Attention Homogenization loss via Entropy
143 |
144 | am = torch.mean(allF[self.level_cams], 1)
145 |
146 | # Restore original shape
147 | am = torch.nn.functional.interpolate(am.unsqueeze(1),
148 | size=(self.input_shape[-1], self.input_shape[-1]),
149 | mode='bilinear',
150 | align_corners=True)
151 | am = am.view((am.shape[0], -1))
152 |
153 | # Prepare mask with brain
154 | if 'BRATS' in train_generator.dataset.dir_datasets or\
155 | 'MVTEC' in train_generator.dataset.dir_datasets or\
156 | 'PhysioNet' in train_generator.dataset.dir_datasets:
157 |
158 | x_mask = np.ravel(x_mask)
159 | x_mask = torch.tensor(np.array(np.argwhere(x_mask > 0.5))).cuda().squeeze()
160 | am = torch.index_select(am, dim=1, index=x_mask)
161 |
162 | # Probabilities
163 | p = torch.nn.functional.softmax(am.view((am.shape[0], -1)), dim=-1)
164 | # Mean entropy
165 | self.H_iteration = torch.mean(-torch.sum(p * torch.log(p + 1e-12), dim=(-1)))
166 |
167 | if self.i_epoch > self.pre_training_epochs:
168 |
169 | if self.alpha_entropy > 0:
170 | # Entropy Maximization
171 | L += - self.alpha_entropy * self.H_iteration
172 |
173 | # Update weights
174 | L.backward() # Backward
175 | self.opt.step() # Update weights
176 | self.opt.zero_grad() # Clear gradients
177 |
178 | """
179 | ON ITERATION/EPOCH END PROCESS
180 | """
181 |
182 | # Display losses per iteration
183 | self.display_losses(on_epoch_end=False)
184 |
185 | # Update epoch's losses
186 | self.lr_epoch += self.lr_iteration.cpu().detach().numpy() / len(self.train_generator)
187 | self.kl_epoch += self.kl_iteration.cpu().detach().numpy() / len(self.train_generator)
188 | self.H_epoch += self.H_iteration.cpu().detach().numpy() / len(self.train_generator)
189 |
190 | # Epoch-end processes
191 | self.on_epoch_end()
192 |
193 | def on_epoch_end(self):
194 |
195 | # Display losses
196 | self.display_losses(on_epoch_end=True)
197 |
198 | # Update learning curves
199 | self.lr_lc.append(self.lr_epoch)
200 | self.lkl_lc.append(self.kl_epoch)
201 | self.H_lc.append(self.H_epoch)
202 |
203 | # Each x epochs, test models and plot learning curves
204 | if (self.i_epoch + 1) % self.epochs_to_test == 0:
205 | # Save weights
206 | torch.save(self.E.state_dict(), self.dir_results + 'encoder_weights.pth')
207 | torch.save(self.Dec.state_dict(), self.dir_results + 'decoder_weights.pth')
208 |
209 | # Evaluate
210 | if self.i_epoch > (self.pre_training_epochs - 50):
211 | # Make predictions
212 | Y, Scores, M, Mhat, X, Xhat = inference_dataset(self, self.dataset_test)
213 |
214 | # Input to dataset
215 | self.dataset_test.Scores = Scores
216 | self.dataset_test.Mhat = Mhat
217 | self.dataset_test.Xhat = Xhat
218 |
219 | # Evaluate anomaly detection
220 | auroc_det, auprc_det, th_det = evaluate_anomaly_detection(self.dataset_test.Y, self.dataset_test.Scores,
221 | dir_out=self.dir_results,
222 | range=[np.min(Scores)-np.std(Scores), np.max(Scores)+np.std(Scores)],
223 | tit='kl')
224 | acc = accuracy_score(np.ravel(Y), np.ravel((Scores > th_det)).astype('int'))
225 | fs = f1_score(np.ravel(Y), np.ravel((Scores > th_det)).astype('int'))
226 |
227 | metrics_detection = {'auroc_det': auroc_det, 'auprc_det': auprc_det, 'th_det': th_det, 'acc_det': acc,
228 | 'fs_det': fs}
229 | print(metrics_detection)
230 |
231 | # Evaluate anomaly localization
232 | metrics, th = evaluate_anomaly_localization(self.dataset_test, save_maps=True, dir_out=self.dir_results)
233 | self.metrics = metrics
234 |
235 | # Save metrics as dict
236 | with open(self.dir_results + 'metrics.json', 'w') as fp:
237 | json.dump(metrics, fp)
238 | print(metrics)
239 |
240 | # Plot learning curve
241 | self.plot_learning_curves()
242 |
243 | # Save learning curves as dataframe
244 | self.aucroc_lc.append(metrics['AU_ROC'])
245 | self.auprc_lc.append(metrics['AU_PRC'])
246 | self.auroc_det_lc.append(auroc_det)
247 | history = pd.DataFrame(list(zip(self.lr_lc, self.lkl_lc, self.H_lc, self.aucroc_lc, self.auprc_lc, self.auroc_det_lc)),
248 | columns=['Lrec', 'Lkl', 'H', 'AUCROC', 'AUPRC', 'AUROC_det'])
249 | history.to_csv(self.dir_results + 'lc_on_direct.csv')
250 |
251 | else:
252 | self.aucroc_lc.append(0)
253 | self.auprc_lc.append(0)
254 | self.auroc_det_lc.append(0)
255 |
256 | def predict_score(self, x):
257 | self.E.eval()
258 | self.Dec.eval()
259 |
260 | # brain mask
261 | if 'BRATS' in self.train_generator.dataset.dir_datasets or 'PhysioNet' in self.train_generator.dataset.dir_datasets:
262 | x_mask = 1 - (x == 0).astype(int)
263 | if 'BRATS' in self.train_generator.dataset.dir_datasets:
264 | x_mask = ndimage.binary_erosion(x_mask, structure=np.ones((1, 6, 6))).astype(x_mask.dtype)
265 | else:
266 | x_mask = ndimage.binary_erosion(x_mask, structure=np.ones((1, 3, 3))).astype(x_mask.dtype)
267 | elif 'MVTEC' in self.train_generator.dataset.dir_datasets:
268 | x_mask = np.zeros((1, x.shape[-1], x.shape[-1]))
269 | x_mask[:, 14:-14, 14:-14] = 1
270 |
271 | # Get reconstruction error map
272 | z, z_mu, z_logvar, f = self.E(torch.tensor(x).cuda().float().unsqueeze(0))
273 | xhat = torch.sigmoid(self.Dec(z)[0]).squeeze().detach().cpu().numpy()
274 |
275 | am = torch.mean(f[self.level_cams], 1)
276 | # Restore original shape
277 | mhat = torch.nn.functional.interpolate(am.unsqueeze(0), size=(self.input_shape[-1], self.input_shape[-1]),
278 | mode='bilinear', align_corners=True).squeeze().detach().cpu().numpy()
279 |
280 | # brain mask - Keep only brain region
281 | if 'BRATS' in self.train_generator.dataset.dir_datasets or \
282 | 'PhysioNet' in self.train_generator.dataset.dir_datasets or \
283 | 'MVTEC' in self.train_generator.dataset.dir_datasets:
284 | mhat[x_mask[0, :, :] == 0] = 0
285 |
286 | # Get outputs
287 | anomaly_map = mhat
288 | # brain mask - Keep only brain region
289 | if 'BRATS' in self.train_generator.dataset.dir_datasets or \
290 | 'PhysioNet' in self.train_generator.dataset.dir_datasets or \
291 | 'MVTEC' in self.train_generator.dataset.dir_datasets:
292 | score = np.std(anomaly_map[x_mask[0, :, :] == 1])
293 | else:
294 | score = np.std(anomaly_map)
295 |
296 | self.E.train()
297 | self.Dec.train()
298 | return score, anomaly_map, xhat
299 |
300 | def display_losses(self, on_epoch_end=False):
301 |
302 | # Init info display
303 | info = "[INFO] Epoch {}/{} -- Step {}/{}: ".format(self.i_epoch + 1, self.epochs,
304 | self.i_iteration + 1, self.iterations)
305 | # Prepare values to show
306 | if on_epoch_end:
307 | lr = self.lr_epoch
308 | lkl = self.kl_epoch
309 | lH = self.H_epoch
310 |
311 | end = '\n'
312 | else:
313 | lr = self.lr_iteration
314 | lkl = self.kl_iteration
315 | lH = self.H_iteration
316 |
317 | end = '\r'
318 |
319 | # Init losses display
320 | info += "Reconstruction={:.4f} || KL={:.4f} || H={:.4f}".format(lr, lkl, lH)
321 | if self.train_generator.dataset.weak_supervision:
322 | info += " || H_a={:.4f}".format(lH_a)  # FIXME: lH_a is never defined in this method; raises NameError when weak_supervision is set
323 | # Print losses
324 | et = str(datetime.timedelta(seconds=timer() - self.init_time))
325 | print(info + ', ET=' + et, end=end)
326 |
327 | def plot_learning_curves(self):
328 | def plot_subplot(axes, x, y, y_axis):
329 | axes.grid()
330 | axes.plot(x, y, 'o-')
331 | axes.set_ylabel(y_axis)
332 |
333 | fig, axes = plt.subplots(2, 2, figsize=(20, 15))
334 | plot_subplot(axes[0, 0], np.arange(self.i_epoch + 1) + 1, np.array(self.lr_lc), "Reconstruction loss")
335 | plot_subplot(axes[0, 1], np.arange(self.i_epoch + 1) + 1, np.array(self.lkl_lc), "KL loss")
336 | plot_subplot(axes[1, 0], np.arange(self.i_epoch + 1) + 1, np.array(self.H_lc), "H")
337 | plt.savefig(self.dir_results + 'learning_curve.png')
338 | plt.close()
339 |
340 |
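The entropy term used by this trainer (a softmax over the flattened activation map, followed by the mean Shannon entropy across the batch) can be illustrated without PyTorch. The following is a minimal standard-library sketch, not the repository's implementation: the names `softmax` and `mean_entropy` are hypothetical, and each activation map is assumed to be a plain flat list of floats.

```python
import math


def softmax(values):
    """Numerically stable softmax over a flat list of floats."""
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def mean_entropy(activation_maps, eps=1e-12):
    """Mean Shannon entropy of softmax-normalised, flattened activation maps."""
    entropies = []
    for am in activation_maps:
        p = softmax(am)
        entropies.append(-sum(pi * math.log(pi + eps) for pi in p))
    return sum(entropies) / len(entropies)


uniform_map = [0.0] * 16          # flat activations: entropy close to log(16)
peaked_map = [10.0] + [0.0] * 15  # one dominant activation: low entropy

h_uniform = mean_entropy([uniform_map])
h_peaked = mean_entropy([peaked_map])
```

Maximising this quantity, as the `- self.alpha_entropy * self.H_iteration` term does, pushes the normalised map towards the uniform case rather than a single dominant location.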
--------------------------------------------------------------------------------
/code/methods/trainers/ae.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import kornia
3 | import json
4 | import torch
5 |
6 | import pandas as pd
7 | import matplotlib.pyplot as plt
8 |
9 | from scipy import ndimage
10 | from timeit import default_timer as timer
11 | from datasets.utils import augment_input_batch
12 | from models.models import Encoder, Decoder
13 | from evaluation.utils import *
14 |
15 |
16 | class AnomalyDetectorAE:
17 | def __init__(self, dir_results, item=['flair'], zdim=32, lr=1*1e-4, input_shape=(1, 224, 224), epochs_to_test=25,
18 | load_weigths=False, n_blocks=5, dense=True, context=False, bayesian=False,
19 | loss_reconstruction='bce', restoration=False):
20 |
21 | # Init input variables
22 | self.dir_results = dir_results
23 | self.item = item
24 | self.zdim = zdim
25 | self.lr = lr
26 | self.input_shape = input_shape
27 | self.epochs_to_test = epochs_to_test
28 | self.load_weigths = load_weigths
29 | self.n_blocks = n_blocks
30 | self.dense = dense
31 | self.context = context
32 | self.bayesian = bayesian
33 | self.loss_reconstruction = loss_reconstruction
34 | self.restoration = restoration
35 |
36 | # Init network
37 | self.E = Encoder(fin=self.input_shape[0], zdim=self.zdim, dense=self.dense, n_blocks=self.n_blocks,
38 | spatial_dim=self.input_shape[1]//2**self.n_blocks)
39 | self.Dec = Decoder(fin=self.zdim, nf0=self.E.backbone.nfeats//2, n_channels=self.input_shape[0],
40 | dense=self.dense, n_blocks=self.n_blocks, spatial_dim=self.input_shape[1]//2**self.n_blocks)
41 |
42 | if torch.cuda.is_available():
43 | self.E.cuda()
44 | self.Dec.cuda()
45 |
46 | if self.load_weigths:
47 | self.E.load_state_dict(torch.load(self.dir_results + '/encoder_weights.pth'))
48 | self.Dec.load_state_dict(torch.load(self.dir_results + '/decoder_weights.pth'))
49 |
50 | # Set parameters
51 | self.params = list(self.E.parameters()) + list(self.Dec.parameters())
52 |
53 | # Set losses
54 | if self.loss_reconstruction == 'l2':
55 | self.Lr = torch.nn.MSELoss(reduction='sum')
56 | elif self.loss_reconstruction == 'bce':
57 | self.Lr = torch.nn.BCEWithLogitsLoss(reduction='sum')
58 |
59 | # Set optimizers
60 | self.opt = torch.optim.Adam(self.params, lr=self.lr)
61 |
62 | # Init additional variables and objects
63 | self.epochs = 0.
64 | self.i_iteration = 0.
65 | self.iterations = 0.
66 | self.init_time = 0.
67 | self.lr_iteration = 0.
68 | self.lr_epoch = 0.
69 | self.i_epoch = 0.
70 | self.train_generator = []
71 | self.dataset_test = []
72 | self.metrics = {}
73 | self.aucroc_lc = []
74 | self.auprc_lc = []
75 | self.lr_lc = []
76 |
77 | def train(self, train_generator, epochs, dataset_test):
78 | self.epochs = epochs
79 | self.init_time = timer()
80 | self.train_generator = train_generator
81 | self.dataset_test = dataset_test
82 | self.iterations = len(self.train_generator)
83 |
84 | # Loop over epochs
85 | for self.i_epoch in range(self.epochs):
86 | # init epoch losses
87 | self.lr_epoch = 0
88 |
89 | # Loop over training dataset
90 | for self.i_iteration, (x_n, y_n, _, _) in enumerate(self.train_generator):
91 |
92 | if self.context: # if context option, data augmentation to apply context
93 | (x_n_context, _) = augment_input_batch(x_n.copy())
94 | x_n_context = torch.tensor(x_n_context).cuda().float()
95 |
96 | # Move tensors to gpu
97 | x_n = torch.tensor(x_n).cuda().float()
98 |
99 | # Obtain latent space from normal sample via encoder
100 | if not self.context:
101 | z, _, _, _ = self.E(x_n)
102 | else:
103 | z, _, _, _ = self.E(x_n_context)
104 |
105 | # Obtain reconstructed images through decoder
106 | xhat, _ = self.Dec(z)
107 | if self.loss_reconstruction == 'l2':
108 | xhat = torch.sigmoid(xhat)
109 |
110 | # Calculate criterion
111 | self.lr_iteration = self.Lr(xhat, x_n) / (self.train_generator.batch_size * self.input_shape[1] *
112 | self.input_shape[2]) # Reconstruction loss
113 |
114 | # Init overall losses
115 | L = self.lr_iteration
116 |
117 | # Update weights
118 | L.backward() # Backward
119 | self.opt.step() # Update weights
120 | self.opt.zero_grad() # Clear gradients
121 |
122 | """
123 | ON ITERATION/EPOCH END PROCESS
124 | """
125 |
126 | # Display losses per iteration
127 | self.display_losses(on_epoch_end=False)
128 |
129 | # Update epoch's losses
130 | self.lr_epoch += self.lr_iteration.cpu().detach().numpy() / len(self.train_generator)
131 |
132 | # Epoch-end processes
133 | self.on_epoch_end()
134 |
135 | def on_epoch_end(self):
136 |
137 | # Display losses
138 | self.display_losses(on_epoch_end=True)
139 |
140 | # Update learning curves
141 | self.lr_lc.append(self.lr_epoch)
142 |
143 | # Each x epochs, test models and plot learning curves
144 | if (self.i_epoch + 1) % self.epochs_to_test == 0:
145 | # Save weights
146 | torch.save(self.E.state_dict(), self.dir_results + 'encoder_weights.pth')
147 | torch.save(self.Dec.state_dict(), self.dir_results + 'decoder_weights.pth')
148 |
149 | # Make predictions
150 | Y, Scores, M, Mhat, X, Xhat = inference_dataset(self, self.dataset_test)
151 |
152 | # Input to dataset
153 | self.dataset_test.Scores = Scores
154 | self.dataset_test.Mhat = Mhat
155 | self.dataset_test.Xhat = Xhat
156 |
157 | # Evaluate
158 | metrics, th = evaluate_anomaly_localization(self.dataset_test, save_maps=True, dir_out=self.dir_results)
159 | self.metrics = metrics
160 |
161 | # Save metrics as dict
162 | with open(self.dir_results + 'metrics.json', 'w') as fp:
163 | json.dump(metrics, fp)
164 | print(metrics)
165 |
166 | # Plot learning curve
167 | self.plot_learning_curves()
168 |
169 | # Save learning curves as dataframe
170 | self.aucroc_lc.append(metrics['AU_ROC'])
171 | self.auprc_lc.append(metrics['AU_PRC'])
172 | history = pd.DataFrame(list(zip(self.lr_lc, self.aucroc_lc, self.auprc_lc)),
173 | columns=['Lrec', 'AUCROC', 'AUPRC'])
174 | history.to_csv(self.dir_results + 'lc_on_direct.csv')
175 |
176 | else:
177 | self.aucroc_lc.append(0)
178 | self.auprc_lc.append(0)
179 |
180 | def predict_score(self, x):
181 | self.E.eval()
182 | self.Dec.eval()
183 |
184 | # brain mask
185 | if 'BRATS' in self.train_generator.dataset.dir_datasets or 'PhysioNet' in self.train_generator.dataset.dir_datasets:
186 | x_mask = 1 - (x == 0).astype(int)
187 | if 'BRATS' in self.train_generator.dataset.dir_datasets:
188 | x_mask = ndimage.binary_erosion(x_mask, structure=np.ones((1, 10, 10))).astype(x_mask.dtype)
189 | else:
190 | x_mask = ndimage.binary_erosion(x_mask, structure=np.ones((1, 3, 3))).astype(x_mask.dtype)
191 | elif 'MVTEC' in self.train_generator.dataset.dir_datasets:
192 | x_mask = np.zeros((1, x.shape[-1], x.shape[-1]))
193 | x_mask[:, 14:-14, 14:-14] = 1
194 |
195 | # Get reconstruction error map
196 | if self.restoration: # restoration reconstruction
197 | mhat, xhat = self.restoration_reconstruction(x)
198 | elif self.bayesian: # bayesian reconstruction
199 | mhat, xhat = self.bayesian_reconstruction(x)
200 | else:
201 | # Network forward
202 | z, z_mu, z_logvar, f = self.E(torch.tensor(x).cuda().float().unsqueeze(0))
203 | xhat = np.squeeze(torch.sigmoid(self.Dec(z)[0]).cpu().detach().numpy())
204 | # Compute anomaly map
205 | #mhat = np.squeeze(np.abs(x - xhat))
206 | mhat = np.squeeze(x - xhat)
207 |
208 | # Keep only brain region
209 | mhat[x_mask[0, :, :] == 0] = 0
210 |
211 | # Get outputs
212 | anomaly_map = mhat
213 | score = np.mean(anomaly_map)
214 |
215 | self.E.train()
216 | self.Dec.train()
217 | return score, anomaly_map, xhat
218 |
219 | def bayesian_reconstruction(self, x):
220 |
221 | N = 100
222 | p_dropout = 0.20
223 | mhat = np.zeros((self.input_shape[1], self.input_shape[2]))
224 |
225 | # Network forward
226 | z, z_mu, z_logvar, f = self.E(torch.tensor(x).cuda().float().unsqueeze(0))
227 | xhat = self.Dec(torch.nn.Dropout(p_dropout)(z))[0].cpu().detach().numpy()
228 |
229 | for i in np.arange(N):
230 | if z_mu is None: # apply dropout to z
231 | mhat += np.squeeze(np.abs(np.squeeze(torch.sigmoid(
232 | self.Dec(torch.nn.Dropout(p_dropout)(z))[0]).cpu().detach().numpy()) - x)) / N
233 | else: # sample z
234 | mhat += np.squeeze(np.abs(np.squeeze(torch.sigmoid(
235 | self.Dec(self.E.reparameterize(z_mu, z_logvar))[0]).cpu().detach().numpy()) - x)) / N
236 | return mhat, xhat
237 |
238 | def restoration_reconstruction(self, x):
239 | N = 300
240 | step = 1 * 1e-3
241 | x_rest = torch.tensor(x).cuda().float().unsqueeze(0)
242 |
243 | for i in np.arange(N):
244 | # Forward
245 | x_rest.requires_grad = True
246 | z, z_mu, z_logvar, f = self.E(x_rest)
247 | xhat = self.Dec(z)[0]
248 |
249 | # Compute loss
250 | lr = kornia.losses.total_variation(torch.tensor(x).cuda().float().unsqueeze(0) - torch.sigmoid(xhat))
251 | L = lr / (self.input_shape[1] * self.input_shape[2])
252 |
253 | # Get gradients
254 | gradients = torch.autograd.grad(L, x_rest, grad_outputs=None, retain_graph=True,
255 | create_graph=True,
256 | only_inputs=True, allow_unused=True)[0]
257 |
258 | x_rest = x_rest - gradients * step
259 | x_rest = x_rest.clone().detach()
260 | xhat = np.squeeze(x_rest.cpu().numpy())
261 |
262 | # Compute difference
263 | mhat = np.squeeze(np.abs(x - xhat))
264 |
265 | return mhat, xhat
266 |
267 | def display_losses(self, on_epoch_end=False):
268 |
269 | # Init info display
270 | info = "[INFO] Epoch {}/{} -- Step {}/{}: ".format(self.i_epoch + 1, self.epochs,
271 | self.i_iteration + 1, self.iterations)
272 | # Prepare values to show
273 | if on_epoch_end:
274 | lr = self.lr_epoch
275 | end = '\n'
276 | else:
277 | lr = self.lr_iteration
278 | end = '\r'
279 | # Init losses display
280 | info += "Reconstruction={:.4f}".format(lr)
281 | # Print losses
282 | et = str(datetime.timedelta(seconds=timer() - self.init_time))
283 | print(info + ', ET=' + et, end=end)
284 |
285 | def plot_learning_curves(self):
286 | def plot_subplot(axes, x, y, y_axis):
287 | axes.grid()
288 | axes.plot(x, y, 'o-')
289 | axes.set_ylabel(y_axis)
290 |
291 | fig, axes = plt.subplots(1, 1, figsize=(20, 15))
292 | plot_subplot(axes, np.arange(self.i_epoch + 1) + 1, np.array(self.lr_lc), "Reconstruction loss")
293 | plt.savefig(self.dir_results + 'learning_curve.png')
294 | plt.close()
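The Monte Carlo averaging performed by `bayesian_reconstruction` (repeated stochastic decodes, with the absolute error against the input averaged over `N` passes) can be sketched with the standard library alone. This is an illustration under stated assumptions, not the repository's code: `mc_dropout_error_map` and `toy_decode` are hypothetical names, signals are flat lists of floats, and dropout is modelled as inverted dropout (surviving units rescaled by `1 / (1 - p)`), which is what `torch.nn.Dropout` applies in training mode.

```python
import random


def mc_dropout_error_map(x, z, decode, n_samples=100, p_dropout=0.2, seed=0):
    """Average |decode(dropout(z)) - x| over n_samples stochastic passes."""
    rng = random.Random(seed)
    keep = 1.0 - p_dropout
    mhat = [0.0] * len(x)
    for _ in range(n_samples):
        # Inverted dropout: zero each unit with probability p_dropout and
        # rescale the survivors by 1 / keep.
        z_drop = [zi / keep if rng.random() < keep else 0.0 for zi in z]
        xhat = decode(z_drop)
        for i, (xi, xhi) in enumerate(zip(x, xhat)):
            mhat[i] += abs(xhi - xi) / n_samples
    return mhat


def toy_decode(z):
    """Hypothetical stand-in for the decoder: passes the latent code through."""
    return list(z)


x = [0.2, 0.5, 0.8]
no_dropout = mc_dropout_error_map(x, x, toy_decode, p_dropout=0.0)
with_dropout = mc_dropout_error_map(x, x, toy_decode, p_dropout=0.2)
```

With `p_dropout=0.0` every pass reproduces the input and the error map is zero. With dropout enabled, each position accumulates a positive average error that reflects how sensitive the decode is to perturbations of the latent code.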
--------------------------------------------------------------------------------
/code/methods/trainers/anoVAEGAN.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import kornia
3 | import json
4 | import torch
5 |
6 | import pandas as pd
7 | import matplotlib.pyplot as plt
8 | import torch.nn.functional as F
9 |
10 | from scipy import ndimage
11 | from timeit import default_timer as timer
12 | from models.models import Encoder, Decoder, Discriminator
13 | from evaluation.utils import *
14 | from methods.losses.losses import kl_loss
15 | from datasets.utils import augment_input_batch
16 |
17 |
18 | class AnomalyDetectorAnoVAEGAN:
19 | def __init__(self, dir_results, item=['flair'], zdim=32, lr=1*1e-4, input_shape=(1, 224, 224), epochs_to_test=25,
20 | load_weigths=False, n_blocks=5, dense=True, context=False, bayesian=False,
21 | loss_reconstruction='bce', restoration=False, alpha_kl=1, alpha_adversial=1):
22 |
23 | # Init input variables
24 | self.dir_results = dir_results
25 | self.item = item
26 | self.zdim = zdim
27 | self.lr = lr
28 | self.input_shape = input_shape
29 | self.epochs_to_test = epochs_to_test
30 | self.load_weigths = load_weigths
31 | self.n_blocks = n_blocks
32 | self.dense = dense
33 | self.context = context
34 | self.bayesian = bayesian
35 | self.loss_reconstruction = loss_reconstruction
36 | self.restoration = restoration
37 | self.alpha_kl = alpha_kl
38 | self.alpha_adversial = alpha_adversial
39 | self.normal_label = 1
40 | self.generated_label = 0
41 |
42 | # Init network
43 | self.E = Encoder(fin=self.input_shape[0], zdim=self.zdim, dense=self.dense, n_blocks=self.n_blocks,
44 | spatial_dim=self.input_shape[1]//2**self.n_blocks, variational=True, gap=False)
45 | self.Dec = Decoder(fin=self.zdim, nf0=self.E.backbone.nfeats//2, n_channels=self.input_shape[0],
46 | dense=self.dense, n_blocks=self.n_blocks, spatial_dim=self.input_shape[1]//2**self.n_blocks,
47 | gap=False)
48 | self.Disc = Discriminator(fin=self.E.backbone.nfeats//2**self.n_blocks, n_channels=input_shape[0],
49 | n_blocks=self.n_blocks)
50 |
51 | if torch.cuda.is_available():
52 | self.E.cuda()
53 | self.Dec.cuda()
54 | self.Disc.cuda()
55 |
56 | if self.load_weigths:
57 | self.E.load_state_dict(torch.load(self.dir_results + '/encoder_weights.pth'))
58 | self.Dec.load_state_dict(torch.load(self.dir_results + '/decoder_weights.pth'))
59 |
60 | # Set parameters
61 | self.params = list(self.E.parameters()) + list(self.Dec.parameters())
62 |
63 | # Set losses
64 | if self.loss_reconstruction == 'l2':
65 | self.Lr = torch.nn.MSELoss(reduction='sum')
66 | elif self.loss_reconstruction == 'bce':
67 | self.Lr = torch.nn.BCEWithLogitsLoss(reduction='sum')
68 | self.DLoss = torch.nn.BCEWithLogitsLoss()
69 | self.Lkl = kl_loss
70 |
71 | # Set optimizers
72 | self.opt = torch.optim.Adam(self.params, lr=self.lr)
73 | self.opt_Disc = torch.optim.Adam(self.Disc.parameters(), lr=self.lr * 10)
74 |
75 | # Init additional variables and objects
76 | self.epochs = 0.
77 | self.iterations = 0.
78 | self.init_time = 0.
79 | self.lr_iteration = 0.
80 | self.lr_epoch = 0.
81 | self.kl_iteration = 0.
82 | self.kl_epoch = 0.
83 | self.i_epoch = 0.
84 | self.train_generator = []
85 | self.dataset_test = []
86 | self.metrics = {}
87 | self.aucroc_lc = []
88 | self.auprc_lc = []
89 | self.lr_lc = []
90 | self.lkl_lc = []
91 | self.ldisc_epoch = 0.
92 | self.ldisc_iteration = 0.
93 | self.ladv_epoch = 0.
94 | self.ladv_iteration = 0.
95 | self.ladv_lc = []
96 | self.ldisc_lc = []
97 |
98 | def train(self, train_generator, epochs, dataset_test):
99 | self.epochs = epochs
100 | self.init_time = timer()
101 | self.train_generator = train_generator
102 | self.dataset_test = dataset_test
103 | self.iterations = len(self.train_generator)
104 |
105 | # Loop over epochs
106 | for self.i_epoch in range(self.epochs):
107 | # init epoch losses
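108 | self.lr_epoch, self.kl_epoch = 0., 0.
109 | self.ladv_epoch, self.ldisc_epoch = 0., 0.  # reset as well, otherwise they accumulate across epochs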
110 |
111 | # Loop over training dataset
112 | for self.i_iteration, (x_n, y_n, _, _) in enumerate(self.train_generator):
113 |
114 | if self.context: # if context option, data augmentation to apply context
115 | (x_n_context, _) = augment_input_batch(x_n.copy())
116 | x_n_context = torch.tensor(x_n_context).cuda().float()
117 |
118 | # Move tensors to gpu
119 | x_n = torch.tensor(x_n).cuda().float()
120 |
121 | # Obtain latent space from normal sample via encoder
122 | if not self.context:
123 | z, z_mu, z_logvar, _ = self.E(x_n)
124 | else:
125 | z, z_mu, z_logvar, _ = self.E(x_n_context)
126 |
127 | # Obtain reconstructed images through decoder
128 | xhat, _ = self.Dec(z)
129 | if self.loss_reconstruction == 'l2':
130 | xhat = torch.sigmoid(xhat)
131 |
132 | # Forward discriminator
133 | d_x, _ = self.Disc(x_n)
134 | d_xhat, _ = self.Disc(torch.sigmoid(xhat))
135 |
136 | # Discriminator labels
137 | d_x_true = torch.tensor(self.normal_label * np.ones((self.train_generator.batch_size, 1))).cuda().float()
138 | d_f_false = torch.tensor(self.normal_label * np.ones((self.train_generator.batch_size, 1))).cuda().float()
139 | d_xhat_true = torch.tensor(self.generated_label * np.ones((self.train_generator.batch_size, 1))).cuda().float()
140 |
141 | # ------------D training------------------
142 | self.ldisc_iteration = 0.5 * F.binary_cross_entropy(d_x, d_x_true) + 0.5 * F.binary_cross_entropy(d_xhat, d_xhat_true)
143 |
144 | self.opt_Disc.zero_grad()
145 | self.ldisc_iteration.backward(retain_graph=True)
146 | self.opt_Disc.step()
147 |
148 | # ------------Encoder and Decoder training------------------
149 | # Discriminator prediction
150 | d_xhat, _ = self.Disc(torch.sigmoid(xhat))
151 |
152 | # Calculate criterion
153 | self.lr_iteration = self.Lr(xhat, x_n) / self.train_generator.batch_size # Reconstruction loss
154 | self.kl_iteration = self.Lkl(mu=z_mu, logvar=z_logvar) # kl loss (averaged per spatial feature)
155 | self.ladv_iteration = F.binary_cross_entropy(d_xhat, d_f_false) # Adversarial loss
156 |
157 | # Init overall losses
158 | L = self.lr_iteration + self.alpha_kl * self.kl_iteration + self.alpha_adversial * self.ladv_iteration
159 |
160 | # Update weights
161 | L.backward() # Backward
162 | self.opt.step() # Update weights
163 | self.opt.zero_grad() # Clear gradients
164 |
165 | """
166 | ON ITERATION/EPOCH END PROCESS
167 | """
168 |
169 | # Display losses per iteration
170 | self.display_losses(on_epoch_end=False)
171 |
172 | # Update epoch's losses
173 | self.lr_epoch += self.lr_iteration.cpu().detach().numpy() / len(self.train_generator)
174 | self.kl_epoch += self.kl_iteration.cpu().detach().numpy() / len(self.train_generator)
175 | self.ladv_epoch += self.ladv_iteration.cpu().detach().numpy() / len(self.train_generator)
176 | self.ldisc_epoch += self.ldisc_iteration.cpu().detach().numpy() / len(self.train_generator)
177 |
178 | # Epoch-end processes
179 | self.on_epoch_end()
180 |
181 | def on_epoch_end(self):
182 |
183 | # Display losses
184 | self.display_losses(on_epoch_end=True)
185 |
186 | # Update learning curves
187 | self.lr_lc.append(self.lr_epoch)
188 | self.lkl_lc.append(self.kl_epoch)
189 | self.ladv_lc.append(self.ladv_epoch)
190 | self.ldisc_lc.append(self.ldisc_epoch)
191 |
192 | # Each x epochs, test models and plot learning curves
193 | if (self.i_epoch + 1) % self.epochs_to_test == 0:
194 | # Save weights
195 | torch.save(self.E.state_dict(), self.dir_results + 'encoder_weights.pth')
196 | torch.save(self.Dec.state_dict(), self.dir_results + 'decoder_weights.pth')
197 |
198 | # Make predictions
199 | Y, Scores, M, Mhat, X, Xhat = inference_dataset(self, self.dataset_test)
200 |
201 | # Input to dataset
202 | self.dataset_test.Scores = Scores
203 | self.dataset_test.Mhat = Mhat
204 | self.dataset_test.Xhat = Xhat
205 |
206 | # Evaluate
207 | metrics, th = evaluate_anomaly_localization(self.dataset_test, save_maps=True, dir_out=self.dir_results)
208 | self.metrics = metrics
209 |
210 | # Save metrics as dict
211 | with open(self.dir_results + 'metrics.json', 'w') as fp:
212 | json.dump(metrics, fp)
213 | print(metrics)
214 |
215 | # Plot learning curve
216 | self.plot_learning_curves()
217 |
218 | # Save learning curves as dataframe
219 | self.aucroc_lc.append(metrics['AU_ROC'])
220 | self.auprc_lc.append(metrics['AU_PRC'])
221 | history = pd.DataFrame(list(zip(self.lr_lc, self.lkl_lc, self.aucroc_lc, self.auprc_lc)),
222 | columns=['Lrec', 'Lkl', 'AUCROC', 'AUPRC'])
223 | history.to_csv(self.dir_results + 'lc_on_direct.csv')
224 |
225 | else:
226 | self.aucroc_lc.append(0)
227 | self.auprc_lc.append(0)
228 |
229 | def predict_score(self, x):
230 | self.E.eval()
231 | self.Dec.eval()
232 |
233 | # Prepare brain eroded mask
234 | if 'BRATS' in self.train_generator.dataset.dir_datasets or 'PhysioNet' in self.train_generator.dataset.dir_datasets:
235 | x_mask = 1 - (x == 0).astype(int)
236 | if 'BRATS' in self.train_generator.dataset.dir_datasets:
237 | x_mask = ndimage.binary_erosion(x_mask, structure=np.ones((1, 10, 10))).astype(x_mask.dtype)
238 | else:
239 | x_mask = ndimage.binary_erosion(x_mask, structure=np.ones((1, 3, 3))).astype(x_mask.dtype)
240 | elif 'MVTEC' in self.train_generator.dataset.dir_datasets:
241 | x_mask = np.zeros((1, x.shape[-1], x.shape[-1]))
242 | x_mask[:, 14:-14, 14:-14] = 1
243 |
244 | # Get reconstruction error map
245 | if self.restoration: # restoration reconstruction
246 | mhat, xhat = self.restoration_reconstruction(x)
247 | elif self.bayesian: # bayesian reconstruction
248 | mhat, xhat = self.bayesian_reconstruction(x)
249 | else:
250 | # Network forward
251 | z, z_mu, z_logvar, f = self.E(torch.tensor(x).cuda().float().unsqueeze(0))
252 | xhat = np.squeeze(torch.sigmoid(self.Dec(z)[0]).cpu().detach().numpy())
253 | # Compute anomaly map
254 | mhat = np.squeeze(np.abs(x - xhat))
255 |
256 | # Keep only brain region
257 | mhat[x_mask[0, :, :] == 0] = 0
258 |
259 | # Get outputs
260 | anomaly_map = mhat
261 | score = np.mean(anomaly_map)
262 |
263 | self.E.train()
264 | self.Dec.train()
265 | return score, anomaly_map, xhat
266 |
267 | def bayesian_reconstruction(self, x):
268 |
269 | N = 100
270 | p_dropout = 0.20
271 | mhat = np.zeros((self.input_shape[1], self.input_shape[2]))
272 |
273 | # Network forward
274 | z, z_mu, z_logvar, f = self.E(torch.tensor(x).cuda().float().unsqueeze(0))
275 | xhat = self.Dec(torch.nn.Dropout(p_dropout)(z))[0].cpu().detach().numpy()
276 |
277 | for i in np.arange(N):
278 | if z_mu is None: # apply dropout to z
279 | mhat += np.squeeze(np.abs(np.squeeze(torch.sigmoid(
280 | self.Dec(torch.nn.Dropout(p_dropout)(z))[0]).cpu().detach().numpy()) - x)) / N
281 | else: # sample z
282 | mhat += np.squeeze(np.abs(np.squeeze(torch.sigmoid(
283 | self.Dec(self.E.reparameterize(z_mu, z_logvar))[0]).cpu().detach().numpy()) - x)) / N
284 | return mhat, xhat
285 |
286 | def restoration_reconstruction(self, x):
287 | N = 300
288 | step = 1 * 1e-3
289 | x_rest = torch.tensor(x).cuda().float().unsqueeze(0)
290 |
291 | for i in np.arange(N):
292 | # Forward
293 | x_rest.requires_grad = True
294 | z, z_mu, z_logvar, f = self.E(x_rest)
295 | xhat = self.Dec(z)[0]
296 |
297 | # Compute loss
298 | lr = kornia.losses.total_variation(torch.tensor(x).cuda().float().unsqueeze(0) - torch.sigmoid(xhat))
299 | L = lr / (self.input_shape[1] * self.input_shape[2])
300 |
301 | # Get gradients
302 | gradients = torch.autograd.grad(L, x_rest, grad_outputs=None, retain_graph=True,
303 | create_graph=True,
304 | only_inputs=True, allow_unused=True)[0]
305 |
306 | x_rest = x_rest - gradients * step
307 | x_rest = x_rest.clone().detach()
308 | xhat = np.squeeze(x_rest.cpu().numpy())
309 |
310 | # Compute difference
311 | mhat = np.squeeze(np.abs(x - xhat))
312 |
313 | return mhat, xhat
314 |
315 | def display_losses(self, on_epoch_end=False):
316 |
317 | # Init info display
318 | info = "[INFO] Epoch {}/{} -- Step {}/{}: ".format(self.i_epoch + 1, self.epochs,
319 | self.i_iteration + 1, self.iterations)
320 | # Prepare values to show
321 | if on_epoch_end:
322 | lr = self.lr_epoch
323 | ladv = self.ladv_epoch
324 | ldisc = self.ldisc_epoch
325 | lkl = self.kl_epoch
326 | end = '\n'
327 | else:
328 | lr = self.lr_iteration
329 | ladv = self.ladv_iteration
330 | ldisc = self.ldisc_iteration
331 | lkl = self.kl_iteration
332 | end = '\r'
333 | # Init losses display
334 | info += "Reconstruction={:.4f} || KL={:.4f} || ladv={:.4f} || ldisc={:.4f}".format(lr, lkl, ladv, ldisc)
335 | # Print losses
336 | et = str(datetime.timedelta(seconds=timer() - self.init_time))
337 | print(info + ', ET=' + et, end=end)
338 |
339 | def plot_learning_curves(self):
340 | def plot_subplot(axes, x, y, y_axis):
341 | axes.grid()
342 | axes.plot(x, y, 'o-')
343 | axes.set_ylabel(y_axis)
344 |
345 | fig, axes = plt.subplots(2, 2, figsize=(20, 15))
346 | plot_subplot(axes[0, 0], np.arange(self.i_epoch + 1) + 1, np.array(self.lr_lc), "Reconstruction loss")
347 | plot_subplot(axes[0, 1], np.arange(self.i_epoch + 1) + 1, np.array(self.lkl_lc), "KL loss")
348 | plot_subplot(axes[1, 0], np.arange(self.i_epoch + 1) + 1, np.array(self.ladv_lc), "Generator loss")
349 | plot_subplot(axes[1, 1], np.arange(self.i_epoch + 1) + 1, np.array(self.ldisc_lc), "Discriminator loss")
350 | plt.savefig(self.dir_results + 'learning_curve.png')
351 | plt.close()
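The label convention in this trainer (real images labelled 1, reconstructions labelled 0) gives two binary cross-entropy terms. The discriminator loss averages the error on real and reconstructed inputs. The adversarial term for the encoder and decoder scores reconstructions against the "normal" label. The sketch below works on scalar probabilities using only the standard library. The helper names `bce`, `discriminator_loss` and `adversarial_loss` are hypothetical, and it assumes the discriminator outputs are probabilities in (0, 1), which is what `F.binary_cross_entropy` expects.

```python
import math

NORMAL_LABEL = 1.0      # real (normal) images
GENERATED_LABEL = 0.0   # reconstructions


def bce(prob, target, eps=1e-12):
    """Binary cross-entropy for one probability and one target label."""
    prob = min(max(prob, eps), 1.0 - eps)
    return -(target * math.log(prob) + (1.0 - target) * math.log(1.0 - prob))


def discriminator_loss(d_x, d_xhat):
    """Average of the real-image term and the reconstruction term."""
    return 0.5 * bce(d_x, NORMAL_LABEL) + 0.5 * bce(d_xhat, GENERATED_LABEL)


def adversarial_loss(d_xhat):
    """Encoder/decoder term: reconstructions should be scored as normal."""
    return bce(d_xhat, NORMAL_LABEL)


l_disc = discriminator_loss(d_x=0.9, d_xhat=0.1)
l_adv = adversarial_loss(d_xhat=0.1)
```

When the discriminator is confident and correct, `l_disc` is small while `l_adv` is large, which is the pressure that drives the encoder and decoder towards more realistic reconstructions.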
--------------------------------------------------------------------------------
/code/methods/trainers/fanoGAN.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import kornia
3 | import json
4 | import torch
5 |
6 | import pandas as pd
7 | import matplotlib.pyplot as plt
8 | import torch.nn.functional as F
9 |
10 | from scipy import ndimage
11 | from timeit import default_timer as timer
12 | from models.models import Encoder, Decoder, Discriminator
13 | from evaluation.utils import *
14 | from methods.losses.losses import kl_loss
15 | from datasets.utils import augment_input_batch
16 |
17 |
18 | class AnomalyDetectorFanoGAN:
19 | def __init__(self, dir_results, item=['flair'], zdim=32, lr=1*1e-4, input_shape=(1, 224, 224), epochs_to_test=25,
20 | load_weigths=False, n_blocks=4, dense=True, context=False, bayesian=False,
21 | loss_reconstruction='bce', restoration=False, alpha_adversial=1):
22 |
23 | # Init input variables
24 | self.dir_results = dir_results
25 | self.item = item
26 | self.zdim = zdim
27 | self.lr = lr
28 | self.input_shape = input_shape
29 | self.epochs_to_test = epochs_to_test
30 | self.load_weigths = load_weigths
31 | self.n_blocks = n_blocks
32 | self.dense = dense
33 | self.context = context
34 | self.bayesian = bayesian
35 | self.loss_reconstruction = loss_reconstruction
36 | self.restoration = restoration
37 | self.alpha_adversial = alpha_adversial
38 | self.normal_label = 1
39 | self.generated_label = 0
40 |
41 | # Init network
42 | self.E = Encoder(fin=self.input_shape[0], zdim=self.zdim, dense=self.dense, n_blocks=self.n_blocks,
43 | spatial_dim=self.input_shape[1]//2**self.n_blocks, variational=False, gap=False)
44 | self.Dec = Decoder(fin=self.zdim, nf0=self.E.backbone.nfeats//2, n_channels=self.input_shape[0],
45 | dense=self.dense, n_blocks=self.n_blocks, spatial_dim=self.input_shape[1]//2**self.n_blocks,
46 | gap=False)
47 | self.Disc = Discriminator(fin=self.E.backbone.nfeats//2**self.n_blocks, n_channels=input_shape[0],
48 | n_blocks=self.n_blocks)
49 |
50 | if torch.cuda.is_available():
51 | self.E.cuda()
52 | self.Dec.cuda()
53 | self.Disc.cuda()
54 |
55 | if self.load_weigths:
56 | self.E.load_state_dict(torch.load(self.dir_results + '/encoder_weights.pth'))
57 | self.Dec.load_state_dict(torch.load(self.dir_results + '/decoder_weights.pth'))
58 |
59 | # Set parameters
60 | self.params = list(self.E.parameters()) + list(self.Dec.parameters())
61 |
62 | # Set losses
63 | if self.loss_reconstruction == 'l2':
64 | self.Lr = torch.nn.MSELoss(reduction='sum')
65 | elif self.loss_reconstruction == 'bce':
66 | self.Lr = torch.nn.BCEWithLogitsLoss(reduction='sum')
67 | self.DLoss = torch.nn.BCEWithLogitsLoss()
68 |
69 | # Set optimizers
70 | self.opt = torch.optim.Adam(self.params, lr=self.lr)
71 | self.opt_Disc = torch.optim.Adam(self.Disc.parameters(), lr=self.lr * 10)
72 |
73 | # Init additional variables and objects
74 | self.epochs = 0.
75 | self.iterations = 0.
76 | self.init_time = 0.
77 | self.lr_iteration = 0.
78 | self.lr_epoch = 0.
79 | self.i_epoch = 0.
80 | self.train_generator = []
81 | self.dataset_test = []
82 | self.metrics = {}
83 | self.aucroc_lc = []
84 | self.auprc_lc = []
85 | self.lr_lc = []
86 | self.ldisc_epoch = 0.
87 | self.ldisc_iteration = 0.
88 | self.ladv_epoch = 0.
89 | self.ladv_iteration = 0.
90 | self.ladv_lc = []
91 | self.ldisc_lc = []
92 |
93 | def train(self, train_generator, epochs, dataset_test):
94 | self.epochs = epochs
95 | self.init_time = timer()
96 | self.train_generator = train_generator
97 | self.dataset_test = dataset_test
98 | self.iterations = len(self.train_generator)
99 |
100 | # Loop over epochs
101 | for self.i_epoch in range(self.epochs):
102 | # init epoch losses
103 | self.lr_epoch = 0.
104 | self.ldisc_epoch = 0.
105 | self.ladv_epoch = 0.
106 |
107 | # Loop over training dataset
108 | for self.i_iteration, (x_n, y_n, _, _) in enumerate(self.train_generator):
109 |
110 | if self.context: # if context option, data augmentation to apply context
111 | (x_n_context, _) = augment_input_batch(x_n.copy())
112 | x_n_context = torch.tensor(x_n_context).cuda().float()
113 |
114 | # Move tensors to gpu
115 | x_n = torch.tensor(x_n).cuda().float()
116 |
117 | # Obtain latent space from normal sample via encoder
118 | if not self.context:
119 | z, _, _, _ = self.E(x_n)
120 | else:
121 | z, _, _, _ = self.E(x_n_context)
122 |
123 | # Obtain reconstructed images through decoder
124 | xhat, _ = self.Dec(z)
125 | if self.loss_reconstruction == 'l2':
126 | xhat = torch.sigmoid(xhat)
127 |
128 | # Forward discriminator
129 | d_x, _ = self.Disc(x_n)
130 |                 d_xhat, _ = self.Disc(torch.sigmoid(xhat).detach())  # detach: the discriminator loss must not update E/Dec
131 |
132 | # Discriminator labels
133 | d_x_true = torch.tensor(self.normal_label * np.ones((self.train_generator.batch_size, 1))).cuda().float()
134 | d_f_false = torch.tensor(self.normal_label * np.ones((self.train_generator.batch_size, 1))).cuda().float()
135 | d_xhat_true = torch.tensor(self.generated_label * np.ones((self.train_generator.batch_size, 1))).cuda().float()
136 |
137 | # ------------D training------------------
138 | self.ldisc_iteration = 0.5 * F.binary_cross_entropy(d_x, d_x_true) + 0.5 * F.binary_cross_entropy(d_xhat, d_xhat_true)
139 |
140 | self.opt_Disc.zero_grad()
141 | self.ldisc_iteration.backward(retain_graph=True)
142 | self.opt_Disc.step()
143 |
144 | # ------------Encoder and Decoder training------------------
145 | # Discriminator prediction
146 | _, f_x = self.Disc(x_n)
147 | d_xhat, f_xhat = self.Disc(torch.sigmoid(xhat))
148 |
149 | # Calculate criterion
150 | self.lr_iteration = self.Lr(xhat, x_n) / self.train_generator.batch_size # Reconstruction loss
151 | self.ladv_iteration = torch.mean(torch.pow(f_x[-1] - f_xhat[-1], 2)) # Feature matching loss
152 |
153 | # Init overall losses
154 | L = self.lr_iteration + self.alpha_adversial * self.ladv_iteration
155 |
156 | # Update weights
157 | L.backward() # Backward
158 | self.opt.step() # Update weights
159 | self.opt.zero_grad() # Clear gradients
160 |
161 | """
162 | ON ITERATION/EPOCH END PROCESS
163 | """
164 |
165 | # Display losses per iteration
166 | self.display_losses(on_epoch_end=False)
167 |
168 | # Update epoch's losses
169 | self.lr_epoch += self.lr_iteration.cpu().detach().numpy() / len(self.train_generator)
170 | self.ladv_epoch += self.ladv_iteration.cpu().detach().numpy() / len(self.train_generator)
171 | self.ldisc_epoch += self.ldisc_iteration.cpu().detach().numpy() / len(self.train_generator)
172 |
173 | # Epoch-end processes
174 | self.on_epoch_end()
175 |
176 | def on_epoch_end(self):
177 |
178 | # Display losses
179 | self.display_losses(on_epoch_end=True)
180 |
181 | # Update learning curves
182 | self.lr_lc.append(self.lr_epoch)
183 | self.ladv_lc.append(self.ladv_epoch)
184 | self.ldisc_lc.append(self.ldisc_epoch)
185 |
186 |         # Every `epochs_to_test` epochs, test the model and plot learning curves
187 | if (self.i_epoch + 1) % self.epochs_to_test == 0:
188 | # Save weights
189 | torch.save(self.E.state_dict(), self.dir_results + 'encoder_weights.pth')
190 | torch.save(self.Dec.state_dict(), self.dir_results + 'decoder_weights.pth')
191 |
192 | # Make predictions
193 | Y, Scores, M, Mhat, X, Xhat = inference_dataset(self, self.dataset_test)
194 |
195 | # Input to dataset
196 | self.dataset_test.Scores = Scores
197 | self.dataset_test.Mhat = Mhat
198 | self.dataset_test.Xhat = Xhat
199 |
200 | # Evaluate
201 | metrics, th = evaluate_anomaly_localization(self.dataset_test, save_maps=True, dir_out=self.dir_results)
202 | self.metrics = metrics
203 |
204 | # Save metrics as dict
205 | with open(self.dir_results + 'metrics.json', 'w') as fp:
206 | json.dump(metrics, fp)
207 | print(metrics)
208 |
209 | # Plot learning curve
210 | self.plot_learning_curves()
211 |
212 | # Save learning curves as dataframe
213 | self.aucroc_lc.append(metrics['AU_ROC'])
214 | self.auprc_lc.append(metrics['AU_PRC'])
215 |             history = pd.DataFrame(list(zip(self.lr_lc, self.ladv_lc, self.aucroc_lc, self.auprc_lc)),
216 |                                    columns=['Lrec', 'Ladv', 'AUCROC', 'AUPRC'])
217 | history.to_csv(self.dir_results + 'lc_on_direct.csv')
218 |
219 | else:
220 | self.aucroc_lc.append(0)
221 | self.auprc_lc.append(0)
222 |
223 | def predict_score(self, x):
224 | self.E.eval()
225 | self.Dec.eval()
226 |
227 | # Prepare brain eroded mask
228 | if 'BRATS' in self.train_generator.dataset.dir_datasets or 'PhysioNet' in self.train_generator.dataset.dir_datasets:
229 |             x_mask = 1 - (x == 0).astype(int)  # np.int was removed in NumPy 1.24
230 | if 'BRATS' in self.train_generator.dataset.dir_datasets:
231 | x_mask = ndimage.binary_erosion(x_mask, structure=np.ones((1, 10, 10))).astype(x_mask.dtype)
232 | else:
233 | x_mask = ndimage.binary_erosion(x_mask, structure=np.ones((1, 3, 3))).astype(x_mask.dtype)
234 | elif 'MVTEC' in self.train_generator.dataset.dir_datasets:
235 | x_mask = np.zeros((1, x.shape[-1], x.shape[-1]))
236 | x_mask[:, 14:-14, 14:-14] = 1
237 |
238 | # Get reconstruction error map
239 | if self.restoration: # restoration reconstruction
240 | mhat, xhat = self.restoration_reconstruction(x)
241 | elif self.bayesian: # bayesian reconstruction
242 | mhat, xhat = self.bayesian_reconstruction(x)
243 | else:
244 | # Network forward
245 | z, z_mu, z_logvar, f = self.E(torch.tensor(x).cuda().float().unsqueeze(0))
246 | xhat = np.squeeze(torch.sigmoid(self.Dec(z)[0]).cpu().detach().numpy())
247 | # Compute anomaly map
248 | mhat = np.squeeze(np.abs(x - xhat))
249 |
250 | # Keep only brain region
251 | mhat[x_mask[0, :, :] == 0] = 0
252 |
253 | # Get outputs
254 | anomaly_map = mhat
255 | score = np.mean(anomaly_map)
256 |
257 | self.E.train()
258 | self.Dec.train()
259 | return score, anomaly_map, xhat
260 |
261 | def bayesian_reconstruction(self, x):
262 |
263 | N = 100
264 | p_dropout = 0.20
265 | mhat = np.zeros((self.input_shape[1], self.input_shape[2]))
266 |
267 | # Network forward
268 | z, z_mu, z_logvar, f = self.E(torch.tensor(x).cuda().float().unsqueeze(0))
269 | xhat = self.Dec(torch.nn.Dropout(p_dropout)(z))[0].cpu().detach().numpy()
270 |
271 | for i in np.arange(N):
272 | if z_mu is None: # apply dropout to z
273 | mhat += np.squeeze(np.abs(np.squeeze(torch.sigmoid(
274 | self.Dec(torch.nn.Dropout(p_dropout)(z))[0]).cpu().detach().numpy()) - x)) / N
275 | else: # sample z
276 | mhat += np.squeeze(np.abs(np.squeeze(torch.sigmoid(
277 | self.Dec(self.E.reparameterize(z_mu, z_logvar))[0]).cpu().detach().numpy()) - x)) / N
278 | return mhat, xhat
279 |
280 | def restoration_reconstruction(self, x):
281 | N = 300
282 | step = 1 * 1e-3
283 | x_rest = torch.tensor(x).cuda().float().unsqueeze(0)
284 |
285 | for i in np.arange(N):
286 | # Forward
287 | x_rest.requires_grad = True
288 | z, z_mu, z_logvar, f = self.E(x_rest)
289 | xhat = self.Dec(z)[0]
290 |
291 | # Compute loss
292 | lr = kornia.losses.total_variation(torch.tensor(x).cuda().float().unsqueeze(0) - torch.sigmoid(xhat))
293 | L = lr / (self.input_shape[1] * self.input_shape[2])
294 |
295 | # Get gradients
296 | gradients = torch.autograd.grad(L, x_rest, grad_outputs=None, retain_graph=True,
297 | create_graph=True,
298 | only_inputs=True, allow_unused=True)[0]
299 |
300 | x_rest = x_rest - gradients * step
301 | x_rest = x_rest.clone().detach()
302 | xhat = np.squeeze(x_rest.cpu().numpy())
303 |
304 | # Compute difference
305 | mhat = np.squeeze(np.abs(x - xhat))
306 |
307 | return mhat, xhat
308 |
309 | def display_losses(self, on_epoch_end=False):
310 |
311 | # Init info display
312 | info = "[INFO] Epoch {}/{} -- Step {}/{}: ".format(self.i_epoch + 1, self.epochs,
313 | self.i_iteration + 1, self.iterations)
314 | # Prepare values to show
315 | if on_epoch_end:
316 | lr = self.lr_epoch
317 | ladv = self.ladv_epoch
318 | ldisc = self.ldisc_epoch
319 | end = '\n'
320 | else:
321 | lr = self.lr_iteration
322 | ladv = self.ladv_iteration
323 | ldisc = self.ldisc_iteration
324 | end = '\r'
325 | # Init losses display
326 | info += "Reconstruction={:.4f} || ladv={:.4f} || ldisc={:.4f}".format(lr, ladv, ldisc)
327 | # Print losses
328 | et = str(datetime.timedelta(seconds=timer() - self.init_time))
329 | print(info + ', ET=' + et, end=end)
330 |
331 | def plot_learning_curves(self):
332 | def plot_subplot(axes, x, y, y_axis):
333 | axes.grid()
334 | axes.plot(x, y, 'o-')
335 | axes.set_ylabel(y_axis)
336 |
337 | fig, axes = plt.subplots(2, 2, figsize=(20, 15))
338 |         plot_subplot(axes[0, 0], np.arange(self.i_epoch + 1) + 1, np.array(self.lr_lc), "Reconstruction loss")
339 | plot_subplot(axes[1, 0], np.arange(self.i_epoch + 1) + 1, np.array(self.ladv_lc), "Generator loss")
340 | plot_subplot(axes[0, 1], np.arange(self.i_epoch + 1) + 1, np.array(self.ldisc_lc), "Discriminator loss")
341 | plt.savefig(self.dir_results + 'learning_curve.png')
342 | plt.close()
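
The default branch of `predict_score` above builds its anomaly map as the absolute difference between the input and its reconstruction, zeroes it outside the foreground, and averages it into an image-level score. The sketch below reproduces only that step in plain NumPy so it can be checked in isolation. The function name `residual_anomaly_map` and the toy arrays are illustrative assumptions, not part of the repository, and the mask erosion used by the trainer is left out.

```python
import numpy as np


def residual_anomaly_map(x, xhat, background_value=0.0):
    """Pixel-wise residual map restricted to the foreground of x.

    x, xhat: arrays of shape (1, H, W) with intensities in [0, 1].
    Returns (score, anomaly_map), where anomaly_map has shape (H, W).
    Hypothetical helper for illustration only.
    """
    # Foreground mask: every pixel that is not background in the input.
    # (The trainer additionally erodes this mask; that step is omitted here.)
    mask = x[0] != background_value

    # Absolute reconstruction error per pixel.
    anomaly_map = np.abs(x[0] - xhat[0])

    # Suppress responses outside the foreground.
    anomaly_map[~mask] = 0.0

    # Image-level score: mean of the masked map over all pixels.
    score = float(anomaly_map.mean())
    return score, anomaly_map


# Toy usage: a 4x4 "image" whose right half is background (zeros).
x = np.zeros((1, 4, 4), dtype=np.float32)
x[0, :, :2] = 0.5
xhat = np.full((1, 4, 4), 0.25, dtype=np.float32)
score, amap = residual_anomaly_map(x, xhat)
```

Because the score is averaged over every pixel, background included, images with little foreground get lower scores for the same per-pixel error. That is worth keeping in mind when comparing scores across slices.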
--------------------------------------------------------------------------------
/code/methods/trainers/gradCAMCons.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 | import datetime
5 | import kornia
6 | import json
7 | import torch
8 |
9 | import pandas as pd
10 | import matplotlib.pyplot as plt
11 |
12 | from scipy import ndimage
13 | from timeit import default_timer as timer
14 | from models.models import Encoder, Decoder
15 | from evaluation.utils import *
16 | from methods.losses.losses import kl_loss
17 | from methods.losses.losses import log_barrier
18 | from sklearn.metrics import accuracy_score, f1_score
19 |
20 |
21 | class AnomalyDetectorGradCamCons:
22 | def __init__(self, dir_results, item=['flair'], zdim=32, lr=1 * 1e-4, input_shape=(1, 224, 224), epochs_to_test=25,
23 | load_weigths=False, n_blocks=5, dense=True, loss_reconstruction='bce', alpha_ae=0, alpha_kl=1,
24 | pre_training_epochs=0, level_cams=-4, t=25, p_activation_cam=0.2,
25 | expansion_loss_penalty='l2', avg_grads=True, gap=False):
26 |
27 | # Init input variables
28 | self.dir_results = dir_results
29 | self.item = item
30 | self.zdim = zdim
31 | self.lr = lr
32 | self.input_shape = input_shape
33 | self.epochs_to_test = epochs_to_test
34 | self.load_weigths = load_weigths
35 | self.n_blocks = n_blocks
36 | self.dense = dense
37 | self.loss_reconstruction = loss_reconstruction
38 | self.alpha_kl = alpha_kl
39 | self.pre_training_epochs = pre_training_epochs
40 | self.level_cams = level_cams
41 | self.t = t
42 | self.p_activation_cam = p_activation_cam
43 | self.expansion_loss_penalty = expansion_loss_penalty
44 | self.alpha_ae = alpha_ae
45 | self.avg_grads = avg_grads
46 | self.gap = gap
47 | self.scheduler = False
48 |
49 | # Init network
50 | self.E = Encoder(fin=self.input_shape[0], zdim=self.zdim, dense=self.dense, n_blocks=self.n_blocks,
51 | spatial_dim=self.input_shape[1] // 2 ** self.n_blocks, variational=True, gap=gap)
52 | self.Dec = Decoder(fin=self.zdim, nf0=self.E.backbone.nfeats // 2, n_channels=self.input_shape[0],
53 | dense=self.dense, n_blocks=self.n_blocks,
54 | spatial_dim=self.input_shape[1] // 2 ** self.n_blocks,
55 | gap=gap)
56 |
57 | if torch.cuda.is_available():
58 | self.E.cuda()
59 | self.Dec.cuda()
60 |
61 | if self.load_weigths:
62 | self.E.load_state_dict(torch.load(self.dir_results + 'encoder_weights.pth'))
63 | self.Dec.load_state_dict(torch.load(self.dir_results + 'decoder_weights.pth'))
64 |
65 | # Set parameters
66 | self.params = list(self.E.parameters()) + list(self.Dec.parameters())
67 |
68 | # Set losses
69 | if self.loss_reconstruction == 'l2':
70 | self.Lr = torch.nn.MSELoss(reduction='sum')
71 | elif self.loss_reconstruction == 'bce':
72 | self.Lr = torch.nn.BCEWithLogitsLoss(reduction='sum')
73 |
74 | self.Lkl = kl_loss
75 |
76 | # Set optimizers
77 | self.opt = torch.optim.Adam(self.params, lr=self.lr)
78 |
79 | # Init additional variables and objects
80 | self.epochs = 0.
81 | self.iterations = 0.
82 | self.init_time = 0.
83 | self.lr_iteration = 0.
84 | self.lr_epoch = 0.
85 | self.kl_iteration = 0.
86 | self.kl_epoch = 0.
87 | self.i_epoch = 0.
88 | self.lae_iteration = 0.
89 | self.lae_epoch = 0.
90 | self.train_generator = []
91 | self.dataset_test = []
92 | self.metrics = {}
93 | self.aucroc_lc = []
94 | self.auprc_lc = []
95 | self.auroc_det = []
96 | self.lr_lc = []
97 | self.lkl_lc = []
98 | self.lae_lc = []
99 | self.auroc_det_lc = []
100 | self.refCam = 0.
101 |
102 | def train(self, train_generator, epochs, test_dataset):
103 | self.epochs = epochs
104 | self.init_time = timer()
105 | self.train_generator = train_generator
106 | self.dataset_test = test_dataset
107 | self.iterations = len(self.train_generator)
108 |
109 | # Loop over epochs
110 | for self.i_epoch in range(self.epochs):
111 | # init epoch losses
112 | self.lr_epoch = 0
113 | self.kl_epoch = 0.
114 | self.lae_epoch = 0.
115 |
116 | if self.scheduler and self.expansion_loss_penalty == 'log_barrier' and self.i_epoch > self.pre_training_epochs:
117 | self.t = self.t * 1.01
118 | print(self.t, end='\n')
119 |
120 | # Loop over training dataset
121 | for self.i_iteration, (x_n, y_n, x_a, y_a) in enumerate(self.train_generator):
122 | # p = q
123 |
124 | # Move tensors to gpu
125 | x_n = torch.tensor(x_n).cuda().float()
126 |
127 | # Obtain latent space from normal sample via encoder
128 | z, z_mu, z_logvar, allF = self.E(x_n)
129 |
130 | # Obtain reconstructed images through decoder
131 | xhat, _ = self.Dec(z)
132 | if self.loss_reconstruction == 'l2':
133 | xhat = torch.sigmoid(xhat)
134 |
135 | # Calculate criterion
136 | self.lr_iteration = self.Lr(xhat, x_n) / (self.train_generator.batch_size) # Reconstruction loss
137 | self.kl_iteration = self.Lkl(mu=z_mu, logvar=z_logvar) / (self.train_generator.batch_size) # kl loss
138 |
139 | # Init overall losses
140 | L = self.lr_iteration + self.alpha_kl * self.kl_iteration
141 |
142 | # ---- Compute Attention expansion loss
143 |
144 | # Compute grad-cams
145 |                 gcam = grad_cam(allF[self.level_cams], torch.sum(z_mu), normalization='sigm',
146 |                                 avg_grads=self.avg_grads)
147 | # Restore original shape
148 | gcam = torch.nn.functional.interpolate(gcam.unsqueeze(1),
149 | size=(self.input_shape[-1], self.input_shape[-1]),
150 | mode='bilinear',
151 | align_corners=True).squeeze()
152 | self.lae_iteration = torch.mean(gcam)
153 |
154 | if self.i_epoch > self.pre_training_epochs:
155 | # Compute attention expansion loss
156 | if self.expansion_loss_penalty == 'l1': # L1
157 | lae = torch.mean(torch.abs(-torch.mean(gcam, (-1)) + 1 - self.p_activation_cam))
158 | elif self.expansion_loss_penalty == 'l2': # L2
159 | lae = torch.mean(torch.sqrt(torch.pow(-torch.mean(gcam, (-1)) + 1 - self.p_activation_cam, 2)))
160 | elif self.expansion_loss_penalty == 'log_barrier':
161 | z = -torch.mean(gcam, (1, 2)).unsqueeze(-1) + 1
162 | lae = log_barrier(z - self.p_activation_cam, t=self.t) / self.train_generator.batch_size
163 |
164 | # Update overall losses
165 | L += self.alpha_ae * lae.squeeze()
166 |
167 | # Update weights
168 | L.backward() # Backward
169 | self.opt.step() # Update weights
170 | self.opt.zero_grad() # Clear gradients
171 |
172 | """
173 | ON ITERATION/EPOCH END PROCESS
174 | """
175 |
176 | # Display losses per iteration
177 | self.display_losses(on_epoch_end=False)
178 |
179 | # Update epoch's losses
180 | self.lr_epoch += self.lr_iteration.cpu().detach().numpy() / len(self.train_generator)
181 | self.kl_epoch += self.kl_iteration.cpu().detach().numpy() / len(self.train_generator)
182 | self.lae_epoch += self.lae_iteration.cpu().detach().numpy() / len(self.train_generator)
183 |
184 | # Epoch-end processes
185 | self.on_epoch_end()
186 |
187 | def on_epoch_end(self):
188 |
189 | # Display losses
190 | self.display_losses(on_epoch_end=True)
191 |
192 | # Update learning curves
193 | self.lr_lc.append(self.lr_epoch)
194 | self.lkl_lc.append(self.kl_epoch)
195 | self.lae_lc.append(self.lae_epoch)
196 |
197 |         # Every `epochs_to_test` epochs, test the model and plot learning curves
198 | if (self.i_epoch + 1) % self.epochs_to_test == 0:
199 | # Save weights
200 | torch.save(self.E.state_dict(), self.dir_results + 'encoder_weights.pth')
201 | torch.save(self.Dec.state_dict(), self.dir_results + 'decoder_weights.pth')
202 |
203 | # Evaluate
204 | if self.i_epoch > (0):
205 | # Make predictions
206 | Y, Scores, M, Mhat, X, Xhat = inference_dataset(self, self.dataset_test)
207 |
208 | # Input to dataset
209 | self.dataset_test.Scores = Scores
210 | self.dataset_test.Mhat = Mhat
211 | self.dataset_test.Xhat = Xhat
212 |
213 | # Evaluate anomaly detection
214 | auroc_det, auprc_det, th_det = evaluate_anomaly_detection(self.dataset_test.Y, self.dataset_test.Scores,
215 | dir_out=self.dir_results,
216 | range=[np.min(Scores) - np.std(Scores),
217 | np.max(Scores) + np.std(Scores)],
218 | tit='kl')
219 | acc = accuracy_score(np.ravel(Y), np.ravel((Scores > th_det)).astype('int'))
220 | fs = f1_score(np.ravel(Y), np.ravel((Scores > th_det)).astype('int'))
221 |
222 | metrics_detection = {'auroc_det': auroc_det, 'auprc_det': auprc_det, 'th_det': th_det, 'acc_det': acc,
223 | 'fs_det': fs}
224 | print(metrics_detection)
225 |
226 | # Evaluate anomaly localization
227 | metrics, th = evaluate_anomaly_localization(self.dataset_test, save_maps=True, dir_out=self.dir_results)
228 | self.metrics = metrics
229 |
230 | # Save metrics as dict
231 | with open(self.dir_results + 'metrics.json', 'w') as fp:
232 | json.dump(metrics, fp)
233 | print(metrics)
234 |
235 | # Plot learning curve
236 | self.plot_learning_curves()
237 |
238 | # Save learning curves as dataframe
239 | self.aucroc_lc.append(metrics['AU_ROC'])
240 | self.auprc_lc.append(metrics['AU_PRC'])
241 | self.auroc_det_lc.append(auroc_det)
242 | history = pd.DataFrame(
243 | list(zip(self.lr_lc, self.lkl_lc, self.lae_lc, self.aucroc_lc, self.auprc_lc, self.auroc_det_lc)),
244 | columns=['Lrec', 'Lkl', 'lae', 'AUCROC', 'AUPRC', 'AUROC_det'])
245 | history.to_csv(self.dir_results + 'lc_on_direct.csv')
246 |
247 | else:
248 | self.aucroc_lc.append(0)
249 | self.auprc_lc.append(0)
250 | self.auroc_det_lc.append(0)
251 |
252 | def predict_score(self, x):
253 | self.E.eval()
254 | self.Dec.eval()
255 |
256 | # brain mask
257 | if 'BRATS' in self.train_generator.dataset.dir_datasets or 'PhysioNet' in self.train_generator.dataset.dir_datasets:
258 |             x_mask = 1 - (x == 0).astype(int)  # np.int was removed in NumPy 1.24
259 | if 'BRATS' in self.train_generator.dataset.dir_datasets:
260 | x_mask = ndimage.binary_erosion(x_mask, structure=np.ones((1, 6, 6))).astype(x_mask.dtype)
261 | else:
262 | x_mask = ndimage.binary_erosion(x_mask, structure=np.ones((1, 3, 3))).astype(x_mask.dtype)
263 | elif 'MVTEC' in self.train_generator.dataset.dir_datasets:
264 | x_mask = np.zeros((1, x.shape[-1], x.shape[-1]))
265 | x_mask[:, 14:-14, 14:-14] = 1
266 |
267 | # Get reconstruction error map
268 | z, z_mu, z_logvar, f = self.E(torch.tensor(x).cuda().float().unsqueeze(0))
269 | xhat = torch.sigmoid(self.Dec(z)[0]).squeeze().detach().cpu().numpy()
270 |
271 | # Compute gradients-cams
272 | gcam = grad_cam(f[self.level_cams], torch.sum(z_mu), normalization='min_max',
273 | avg_grads=self.avg_grads)
274 |
275 | # Restore original shape
276 | mhat = torch.nn.functional.interpolate(gcam.unsqueeze(0), size=(self.input_shape[-1], self.input_shape[-1]),
277 | mode='bilinear', align_corners=True).squeeze().detach().cpu().numpy()
278 |
279 | # brain mask - Keep only brain region
280 | if 'BRATS' in self.train_generator.dataset.dir_datasets or \
281 | 'PhysioNet' in self.train_generator.dataset.dir_datasets or \
282 | 'MVTEC' in self.train_generator.dataset.dir_datasets:
283 | mhat[x_mask[0, :, :] == 0] = 0
284 |
285 | # Get outputs
286 | anomaly_map = mhat
287 |
288 | # brain mask - Keep only brain region
289 | if 'BRATS' in self.train_generator.dataset.dir_datasets or \
290 | 'PhysioNet' in self.train_generator.dataset.dir_datasets or \
291 | 'MVTEC' in self.train_generator.dataset.dir_datasets:
292 | score = np.std(anomaly_map[x_mask[0, :, :] == 1])
293 | else:
294 | score = np.std(anomaly_map)
295 |
296 | self.E.train()
297 | self.Dec.train()
298 | return score, anomaly_map, xhat
299 |
300 | def display_losses(self, on_epoch_end=False):
301 |
302 | # Init info display
303 | info = "[INFO] Epoch {}/{} -- Step {}/{}: ".format(self.i_epoch + 1, self.epochs,
304 | self.i_iteration + 1, self.iterations)
305 | # Prepare values to show
306 | if on_epoch_end:
307 | lr = self.lr_epoch
308 | lkl = self.kl_epoch
309 | lae = self.lae_epoch
310 | end = '\n'
311 | else:
312 | lr = self.lr_iteration
313 | lkl = self.kl_iteration
314 | lae = self.lae_iteration
315 | end = '\r'
316 |
317 | # Init losses display
318 | info += "Reconstruction={:.4f} || KL={:.4f} || Lae={:.8f} ".format(lr, lkl, lae)
319 |
320 | # Print losses
321 | et = str(datetime.timedelta(seconds=timer() - self.init_time))
322 | print(info + ', ET=' + et, end=end)
323 |
324 | def plot_learning_curves(self):
325 | def plot_subplot(axes, x, y, y_axis):
326 | axes.grid()
327 | axes.plot(x, y, 'o-')
328 | axes.set_ylabel(y_axis)
329 |
330 | fig, axes = plt.subplots(2, 2, figsize=(20, 15))
331 |         plot_subplot(axes[0, 0], np.arange(self.i_epoch + 1) + 1, np.array(self.lr_lc), "Reconstruction loss")
332 | plot_subplot(axes[0, 1], np.arange(self.i_epoch + 1) + 1, np.array(self.lkl_lc), "KL loss")
333 | plot_subplot(axes[1, 0], np.arange(self.i_epoch + 1) + 1, np.array(self.lae_lc), "AE loss")
334 | plt.savefig(self.dir_results + 'learning_curve.png')
335 | plt.close()
336 |
337 |
338 | def grad_cam(activations, output, normalization='relu_min_max', avg_grads=False, norm_grads=False):
339 | def normalize(grads):
340 | l2_norm = torch.sqrt(torch.mean(torch.pow(grads, 2))) + 1e-5
341 | return grads * torch.pow(l2_norm, -1)
342 |
343 | # Obtain gradients
344 | gradients = torch.autograd.grad(output, activations, grad_outputs=None, retain_graph=True, create_graph=True,
345 | only_inputs=True, allow_unused=True)[0]
346 |
347 | # Normalize gradients
348 | if norm_grads:
349 | gradients = normalize(gradients)
350 |
351 | # pool the gradients across the channels
352 | if avg_grads:
353 | gradients = torch.mean(gradients, dim=[2, 3])
354 | # gradients = torch.nn.functional.softmax(gradients)
355 | gradients = gradients.unsqueeze(-1).unsqueeze(-1)
356 |
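357 |     # Weight activation maps: the gradient-weighted variants in the string below are disabled and the channel-wise mean of the activations is used instead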
358 | '''
359 | if 'relu' in normalization:
360 | GCAM = torch.sum(torch.relu(gradients * activations), 1)
361 | else:
362 | GCAM = gradients * activations
363 | if 'abs' in normalization:
364 | GCAM = torch.abs(GCAM)
365 | GCAM = torch.sum(GCAM, 1)
366 | '''
367 | GCAM = torch.mean(activations, 1)
368 |
369 | # Normalize CAM
370 | if 'sigm' in normalization:
371 | GCAM = torch.sigmoid(GCAM)
372 | if 'min' in normalization:
373 | norm_value = torch.min(torch.max(GCAM, -1)[0], -1)[0].unsqueeze(-1).unsqueeze(-1) + 1e-3
374 | GCAM = GCAM - norm_value
375 | if 'max' in normalization:
376 | norm_value = torch.max(torch.max(GCAM, -1)[0], -1)[0].unsqueeze(-1).unsqueeze(-1) + 1e-3
377 | GCAM = GCAM * norm_value.pow(-1)
378 | if 'tanh' in normalization:
379 | GCAM = torch.tanh(GCAM)
380 | if 'clamp' in normalization:
381 | GCAM = GCAM.clamp(max=1)
382 |
383 | return GCAM
384 |
385 |
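
The attention-expansion term in `AnomalyDetectorGradCamCons.train` penalises the gap between the mean activation of the attention map and the target `1 - p_activation_cam`. The NumPy sketch below shows that idea for the `'l1'` penalty. It is a simplified illustration: the function name `attention_expansion_l1` is hypothetical, and it assumes the map is averaged over both spatial axes per image (as the `'log_barrier'` branch does), whereas the `'l1'` branch in the trainer averages over the last axis only.

```python
import numpy as np


def attention_expansion_l1(cam, p_activation=0.2):
    """L1 penalty pushing each image's mean CAM activation towards 1 - p_activation.

    cam: array of shape (B, H, W) with values in [0, 1].
    Hypothetical helper for illustration only.
    """
    # Mean activation per image (assumption: reduce over both spatial axes).
    mean_activation = cam.mean(axis=(1, 2))

    # Signed gap to the target; zero when the target coverage is met.
    gap = 1.0 - p_activation - mean_activation

    # Average absolute gap over the batch.
    return float(np.abs(gap).mean())


# Toy usage: one fully active map and one empty map.
cam = np.stack([np.ones((3, 3)), np.zeros((3, 3))])
penalty = attention_expansion_l1(cam, p_activation=0.2)

# A map that already sits at the target incurs (numerically) no penalty.
on_target = attention_expansion_l1(np.full((2, 3, 3), 0.8), p_activation=0.2)
```

The penalty is symmetric around the target, so it discourages maps that are too dense as well as maps that are too sparse.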
--------------------------------------------------------------------------------
/code/methods/trainers/gradCons.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import kornia
3 | import json
4 | import torch
5 | import numpy as np
6 | import pandas as pd
7 | import matplotlib.pyplot as plt
8 |
9 | from scipy import ndimage
10 | from timeit import default_timer as timer
11 | from datasets.utils import augment_input_batch
12 | from models.models import Encoder, Decoder, GradConCAEEncoder, GradConCAEDecoder
13 | from evaluation.utils import *
14 | from sklearn.metrics import accuracy_score, f1_score
15 | from methods.losses.losses import kl_loss
16 |
17 |
18 | class AnomalyDetectorGradCons:
19 | def __init__(self, dir_results, item=['flair'], zdim=32, lr=1*1e-4, input_shape=(1, 224, 224), epochs_to_test=25,
20 | load_weigths=False, n_blocks=5, dense=True, loss_reconstruction='bce', alpha_gradloss=1,
21 | n_target_filters=1, alpha_kl=10, variational=True):
22 |
23 | # Init input variables
24 | self.dir_results = dir_results
25 | self.item = item
26 | self.zdim = zdim
27 | self.lr = lr
28 | self.input_shape = input_shape
29 | self.epochs_to_test = epochs_to_test
30 | self.load_weigths = load_weigths
31 | self.n_blocks = n_blocks
32 | self.dense = dense
33 | self.loss_reconstruction = loss_reconstruction
34 | self.alpha_gradloss = alpha_gradloss
35 | self.n_target_filters = n_target_filters
36 | self.alpha_kl = alpha_kl
37 | self.variational = variational
38 |
39 | # Init network
40 | self.E = Encoder(fin=self.input_shape[0], zdim=self.zdim, dense=self.dense, n_blocks=self.n_blocks,
41 | spatial_dim=self.input_shape[1]//2**self.n_blocks, variational=self.variational)
42 | self.Dec = Decoder(fin=self.zdim, nf0=self.E.backbone.nfeats//2, n_channels=self.input_shape[0],
43 | dense=self.dense, n_blocks=self.n_blocks, spatial_dim=self.input_shape[1]//2**self.n_blocks)
44 | '''
45 | # Init network
46 | self.E = GradConCAEEncoder(fin=self.input_shape[0], zdim=self.zdim, dense=self.dense, n_blocks=self.n_blocks,
47 | spatial_dim=self.input_shape[1]//2**self.n_blocks)
48 | self.Dec = GradConCAEDecoder(fin=self.zdim, n_channels=self.input_shape[0],
49 | dense=self.dense, n_blocks=self.n_blocks, spatial_dim=self.input_shape[1]//2**self.n_blocks)
50 | '''
51 |
52 | if torch.cuda.is_available():
53 | self.E.cuda()
54 | self.Dec.cuda()
55 |
56 | if self.load_weigths:
57 | self.E.load_state_dict(torch.load(self.dir_results + '/encoder_weights.pth'))
58 | self.Dec.load_state_dict(torch.load(self.dir_results + '/decoder_weights.pth'))
59 |
60 | # Set parameters
61 | self.params = list(self.E.parameters()) + list(self.Dec.parameters())
62 |
63 | # Set losses
64 | if self.loss_reconstruction == 'l2':
65 | self.Lr = torch.nn.MSELoss(reduction='sum')
66 | elif self.loss_reconstruction == 'bce':
67 | self.Lr = torch.nn.BCEWithLogitsLoss(reduction='sum')
68 | self.Lkl = kl_loss
69 |
70 | # Set optimizers
71 | self.opt = torch.optim.Adam(self.params, lr=self.lr)
72 |
73 | # Init additional variables and objects
74 | self.epochs = 0.
75 | self.i_iteration = 0.
76 | self.iterations = 0.
77 | self.init_time = 0.
78 | self.lr_iteration = 0.
79 | self.lr_epoch = 0.
80 | self.lgrad_iteration = 0.
81 | self.lgrad_epoch = 0.
82 | self.i_epoch = 0.
83 | self.kl_iteration = 0.
84 | self.kl_epoch = 0.
85 | self.train_generator = []
86 | self.dataset_test = []
87 | self.metrics = {}
88 | self.aucroc_lc = []
89 | self.auprc_lc = []
90 | self.lr_lc = []
91 | self.lgrad_lc = []
92 |
93 | def train(self, train_generator, epochs, dataset_test):
94 | self.epochs = epochs
95 | self.init_time = timer()
96 | self.train_generator = train_generator
97 | self.dataset_test = dataset_test
98 | self.iterations = len(self.train_generator)
99 | # Init gradients module
100 | self.ref_grad = self.initialise()
101 | self.k = 0
102 |
103 | # Loop over epochs
104 | for self.i_epoch in range(self.epochs):
105 | # init epoch losses
106 | self.lr_epoch = 0
107 | self.lgrad_epoch = 0.
108 | self.kl_epoch = 0.
109 |
110 | # Loop over training dataset
111 | for self.i_iteration, (x_n, y_n, _, _) in enumerate(self.train_generator):
112 |
113 | # Move tensors to gpu
114 | x_n = torch.tensor(x_n).cuda().float()
115 |
116 | # Obtain latent space from normal sample via encoder
117 | z, z_mu, z_logvar, _ = self.E(x_n)
118 |
119 | # Obtain reconstructed images through decoder
120 | xhat, _ = self.Dec(z)
121 |
122 | # Calculate criterion
123 | self.lr_iteration = self.Lr(xhat, x_n) / (self.train_generator.batch_size) # Reconstruction loss
124 | self.kl_iteration = self.alpha_kl * self.Lkl(mu=z_mu, logvar=z_logvar) / (self.train_generator.batch_size) # kl loss
125 |
126 | # Calculate gradient loss
127 | # self.kl_iteration.backward(create_graph=True, retain_graph=True)
128 | # self.lr_iteration.backward(create_graph=True, retain_graph=True)
129 |
130 | grad_loss = 0.
131 | i = 0
132 | for module in self.iterlist():
133 | if isinstance(module, torch.nn.Conv2d):
134 | wrt = module.weight
135 | #target_grad = wrt.grad
136 | target_grad = torch.autograd.grad(self.kl_iteration, wrt, create_graph=True, retain_graph=True)[0]
137 | if self.k > 0:
138 | grad_loss += -1 * torch.nn.functional.cosine_similarity(target_grad.view(-1, 1), self.ref_grad[i].view(-1, 1) / self.k, dim=0).squeeze()
139 | self.ref_grad[i] += target_grad.detach()
140 | i += 1
141 | if self.k == 0:
142 | self.lgrad_iteration = torch.tensor(1.).cuda().float()
143 | else:
144 | self.lgrad_iteration = grad_loss / i # Average over layers
145 |
146 |                 # Get overall losses - we already computed gradients from Lr, so it is not necessary to compute them again
147 | # L = self.lr_iteration + self.lgrad_iteration * self.alpha_gradloss
148 | L = self.lr_iteration + self.lgrad_iteration * self.alpha_gradloss
149 |
150 | L.backward() # Backward
151 |
152 | # Update the reference gradient
153 | i = 0
154 | for module in self.iterlist():
155 | if isinstance(module, torch.nn.Conv2d):
156 | self.ref_grad[i] += module.weight.grad
157 | i += 1
158 |
159 | # Update weights
160 | self.opt.step() # Update weights
161 | self.opt.zero_grad() # Clear gradients
162 |
163 | # Update k counter
164 | self.k += 1
165 |
166 | """
167 | ON ITERATION/EPOCH END PROCESS
168 | """
169 |
170 | # Display losses per iteration
171 | self.display_losses(on_epoch_end=False)
172 |
173 | # Update epoch's losses
174 | self.lr_epoch += self.lr_iteration.cpu().detach().numpy() / len(self.train_generator)
175 | self.lgrad_epoch += self.lgrad_iteration.cpu().detach().numpy() / len(self.train_generator)
176 |
177 | # Epoch-end processes
178 | self.on_epoch_end()
179 |
180 | def on_epoch_end(self):
181 |
182 | # Display losses
183 | self.display_losses(on_epoch_end=True)
184 |
185 | # Update learning curves
186 | self.lr_lc.append(self.lr_epoch)
187 | self.lgrad_lc.append(self.lgrad_epoch)
188 |
189 |         # Every `epochs_to_test` epochs, test the model and plot learning curves
190 | if (self.i_epoch + 1) % self.epochs_to_test == 0:
191 | # Save weights
192 | torch.save(self.E.state_dict(), self.dir_results + 'encoder_weights.pth')
193 | torch.save(self.Dec.state_dict(), self.dir_results + 'decoder_weights.pth')
194 |
195 | # Make predictions
196 | Y, Scores, M, Mhat, X, Xhat = inference_dataset(self, self.dataset_test)
197 |
198 | # Input to dataset
199 | self.dataset_test.Scores = Scores
200 | self.dataset_test.Mhat = Mhat
201 | self.dataset_test.Xhat = Xhat
202 |
203 | # Evaluate anomaly detection
204 | auroc_det, auprc_det, th_det = evaluate_anomaly_detection(self.dataset_test.Y, self.dataset_test.Scores,
205 | dir_out=self.dir_results)
206 | acc = accuracy_score(np.ravel(Y), np.ravel((Scores > th_det)).astype('int'))
207 | fs = f1_score(np.ravel(Y), np.ravel((Scores > th_det)).astype('int'))
208 |
209 | metrics_detection = {'auroc_det': auroc_det, 'auprc_det': auprc_det, 'th_det': th_det, 'acc_det': acc,
210 | 'fs_det': fs}
211 | print(metrics_detection)
212 |
213 | # Evaluate anomaly localization
214 | metrics, th = evaluate_anomaly_localization(self.dataset_test, save_maps=True, dir_out=self.dir_results)
215 | self.metrics = metrics
216 |
217 | # Save metrics as dict
218 | with open(self.dir_results + 'metrics.json', 'w') as fp:
219 | json.dump(metrics, fp)
220 | print(metrics)
221 |
222 | # Plot learning curve
223 | self.plot_learning_curves()
224 |
225 | # Save learning curves as dataframe
226 | self.aucroc_lc.append(metrics['AU_ROC'])
227 | self.auprc_lc.append(metrics['AU_PRC'])
228 | history = pd.DataFrame(list(zip(self.lr_lc, self.lgrad_lc, self.aucroc_lc, self.auprc_lc)),
229 | columns=['Lrec', 'LgradCons', 'AUCROC', 'AUPRC'])
230 | history.to_csv(self.dir_results + 'lc_on_direct.csv')
231 |
232 | else:
233 | self.aucroc_lc.append(0)
234 | self.auprc_lc.append(0)
235 |
236 | def predict_score(self, x):
237 | self.E.eval()
238 | self.Dec.eval()
239 |
240 | # Prepare brain eroded mask
241 | x_mask = 1 - (x == 0).astype(int)  # np.int was removed in NumPy 1.24; the builtin int is equivalent
242 | x_mask = ndimage.binary_erosion(x_mask, structure=np.ones((1, 6, 6))).astype(x_mask.dtype)
243 |
244 | # Get reconstruction error map
245 | x_n = torch.tensor(x).cuda().float().unsqueeze(0)
246 | z, z_mu, z_logvar, _ = self.E(x_n)
247 | xhat = self.Dec(z)[0]
248 |
249 | # Calculate criterion
250 | lr_sample = self.Lr(xhat, x_n) / (self.input_shape[1] * self.input_shape[2]) # Reconstruction loss
251 | kl_sample = self.alpha_kl * self.Lkl(mu=z_mu, logvar=z_logvar) # kl loss
252 |
253 | # Calculate gradient loss for anomaly score
254 | #lr_sample.backward(create_graph=True, retain_graph=True)
255 |
256 | score = 0.
257 | i = 0
258 | for module in self.iterlist():
259 | if isinstance(module, torch.nn.Conv2d):
260 | wrt = module.weight
261 | #target_grad = wrt.grad
262 | target_grad = torch.autograd.grad(kl_sample, wrt, create_graph=True, retain_graph=True)[0]
263 | if self.k > 0:
264 | score += 1 * torch.nn.functional.cosine_similarity(target_grad.view(-1, 1),
265 | self.ref_grad[i].view(-1, 1) / self.k,
266 | dim=0).squeeze()
267 | i += 1
268 |
269 | score = - score.cpu().detach().numpy() / i # Average over layers
270 |
271 | # Compute anomaly map
272 | xhat = torch.sigmoid(xhat).cpu().detach().numpy()
273 | mhat = np.squeeze(np.abs(x - xhat))
274 |
275 | # Keep only brain region
276 | mhat[x_mask[0, :, :] == 0] = 0
277 |
278 | # Get outputs
279 | anomaly_map = mhat
280 |
281 | self.opt.zero_grad() # Clear gradients
282 | self.E.train()
283 | self.Dec.train()
284 | return score, anomaly_map, xhat
285 |
286 | def display_losses(self, on_epoch_end=False):
287 |
288 | # Init info display
289 | info = "[INFO] Epoch {}/{} -- Step {}/{}: ".format(self.i_epoch + 1, self.epochs,
290 | self.i_iteration + 1, self.iterations)
291 | # Prepare values to show
292 | if on_epoch_end:
293 | lr = self.lr_epoch
294 | lgradCons = self.lgrad_epoch
295 | end = '\n'
296 | else:
297 | lr = self.lr_iteration
298 | lgradCons = self.lgrad_iteration
299 | end = '\r'
300 | # Init losses display
301 | info += "Reconstruction={:.4f} || gradCons={:.4f}".format(lr, lgradCons)
302 | # Print losses
303 | et = str(datetime.timedelta(seconds=timer() - self.init_time))
304 | print(info + ',ET=' + et, end=end)
305 |
306 | def plot_learning_curves(self):
307 | def plot_subplot(axes, x, y, y_axis):
308 | axes.grid()
309 | axes.plot(x, y, 'o-')
310 | axes.set_ylabel(y_axis)
311 |
312 | fig, axes = plt.subplots(1, 2, figsize=(20, 15))
313 | plot_subplot(axes[0], np.arange(self.i_epoch + 1) + 1, np.array(self.lr_lc), "Reconstruction loss")
314 | plot_subplot(axes[1], np.arange(self.i_epoch + 1) + 1, np.array(self.lgrad_lc), "gradCons loss")
315 | plt.savefig(self.dir_results + 'learning_curve.png')
316 | plt.close()
317 |
318 | def initialise(self):
319 |
320 | ref_grad = []
321 | n = 0
322 | for i in np.arange(0, len(list(self.E.children()))): # For each block
323 | for module in list(list(self.E.children())[i].modules()): # For each layer in the block
324 | if isinstance(module, torch.nn.Conv2d):
325 | # Get weight
326 | wrt = module.weight
327 | # Init 0s tensor
328 | grad_init_module = torch.zeros(wrt.shape).cuda().float()
329 | # Incorporate into list
330 | ref_grad.append(grad_init_module)
331 |
332 | n += 1
333 | if n == self.n_target_filters:
334 | break
335 | if n == self.n_target_filters: break  # also leave the outer (block) loop, as iterlist() does
336 | return ref_grad
337 |
338 | def iterlist(self):
339 | '''
340 | layers = []
341 | n = 0
342 | for i in np.arange(0, len(list(self.E.children()))): # For each block
343 | for module in list(list(self.E.children())[i].modules()): # For each layer in the block
344 | if isinstance(module, torch.nn.Conv2d):
345 | # Get layer
346 | layers.append(module)
347 |
348 | n += 1
349 | if n == self.n_target_filters:
350 | break
351 | if n == self.n_target_filters:
352 | break
353 |
354 | return layers
355 | '''
356 |
357 | layers = []
358 | n = 0
359 | for i in np.arange(0, len(list(self.E.children()))): # For each block
360 |
361 | for module in list(list(self.E.children())[i].modules()): # For each layer in the block
362 | if isinstance(module, torch.nn.Conv2d):
363 | # Get layer
364 | layers.append(module)
365 |
366 | n += 1
367 | if n == self.n_target_filters:
368 | break
369 | if n == self.n_target_filters:
370 | break
371 |
372 | return layers
--------------------------------------------------------------------------------
/code/methods/trainers/histEqualization.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import kornia
3 | import json
4 | import torch
5 |
6 | import pandas as pd
7 | import matplotlib.pyplot as plt
8 |
9 | from scipy import ndimage
10 | from timeit import default_timer as timer
11 | from models.models import Encoder, Decoder
12 | from evaluation.utils import *
13 | from methods.losses.losses import kl_loss
14 | from datasets.utils import augment_input_batch
15 | from skimage.exposure import equalize_hist
16 |
17 |
18 | class AnomalyDetectorHistEqualization:
19 | def __init__(self, dir_results, item=['flair']):
20 |
21 | # Init input variables
22 | self.dir_results = dir_results
23 | self.item = item
24 | self.train_generator = []
25 | self.dataset_test = []
26 |
27 | def train(self, train_generator, epochs, dataset_test):
28 | self.train_generator = train_generator
29 | self.dataset_test = dataset_test
30 | print('No training for histogram equalization method.', end='\n')
31 |
32 | # Make predictions
33 | Y, Scores, M, Mhat, X, Xhat = inference_dataset(self, self.dataset_test)
34 |
35 | # Input to dataset
36 | self.dataset_test.Scores = Scores
37 | self.dataset_test.Mhat = Mhat
38 | self.dataset_test.Xhat = Xhat
39 |
40 | # Evaluate
41 | metrics, th = evaluate_anomaly_localization(self.dataset_test, save_maps=True, dir_out=self.dir_results)
42 | self.metrics = metrics
43 |
44 | # Save metrics as dict
45 | with open(self.dir_results + 'metrics.json', 'w') as fp:
46 | json.dump(metrics, fp)
47 | print(metrics)
48 |
49 | def predict_score(self, x):
50 | def equalize_img(img):
51 | """
52 | Perform histogram equalization on the given image.
53 | """
54 | # Create equalization mask
55 | mask = np.zeros_like(img)
56 | mask[img > 0] = 1
57 |
58 | # Equalize
59 | img = img*255
60 | img = equalize_hist(img.astype(np.int64), nbins=256, mask=mask)
61 |
62 | # Assure that background still is 0
63 | img *= mask
64 | img *= (1/255)
65 |
66 | return img
67 |
68 | # Prepare brain eroded mask
69 | if 'BRATS' in self.train_generator.dataset.dir_datasets or 'PhysioNet' in self.train_generator.dataset.dir_datasets:
70 | x_mask = 1 - (x == 0).astype(int)  # np.int was removed in NumPy 1.24; the builtin int is equivalent
71 | if 'BRATS' in self.train_generator.dataset.dir_datasets:
72 | x_mask = ndimage.binary_erosion(x_mask, structure=np.ones((1, 10, 10))).astype(x_mask.dtype)
73 | else:
74 | x_mask = ndimage.binary_erosion(x_mask, structure=np.ones((1, 3, 3))).astype(x_mask.dtype)
75 | elif 'MVTEC' in self.train_generator.dataset.dir_datasets:
76 | x_mask = np.zeros((1, x.shape[-1], x.shape[-1]))
77 | x_mask[:, 14:-14, 14:-14] = 1
78 |
79 | # Get reconstruction error map
80 | mhat = equalize_img(x)
81 | mhat = mhat[0, :, :]
82 |
83 | # Keep only brain region
84 | mhat[x_mask[0, :, :] == 0] = 0
85 |
86 | # Get outputs
87 | anomaly_map = mhat
88 | score = np.mean(anomaly_map)
89 |
90 | return score, anomaly_map, x
91 |
92 |
93 |
--------------------------------------------------------------------------------
/code/methods/trainers/vae.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import kornia
3 | import json
4 | import torch
5 |
6 | import pandas as pd
7 | import matplotlib.pyplot as plt
8 |
9 | from scipy import ndimage
10 | from timeit import default_timer as timer
11 | from models.models import Encoder, Decoder
12 | from evaluation.utils import *
13 | from methods.losses.losses import kl_loss
14 | from datasets.utils import augment_input_batch
15 |
16 |
17 | class AnomalyDetectorVAE:
18 | def __init__(self, dir_results, item=['flair'], zdim=32, lr=1*1e-4, input_shape=(1, 224, 224), epochs_to_test=25,
19 | load_weigths=False, n_blocks=5, dense=True, context=False, bayesian=False,
20 | loss_reconstruction='bce', restoration=False, alpha_kl=0.1):
21 |
22 | # Init input variables
23 | self.dir_results = dir_results
24 | self.item = item
25 | self.zdim = zdim
26 | self.lr = lr
27 | self.input_shape = input_shape
28 | self.epochs_to_test = epochs_to_test
29 | self.load_weigths = load_weigths
30 | self.n_blocks = n_blocks
31 | self.dense = dense
32 | self.context = context
33 | self.bayesian = bayesian
34 | self.loss_reconstruction = loss_reconstruction
35 | self.restoration = restoration
36 | self.alpha_kl = alpha_kl
37 |
38 | # Init network
39 | self.E = Encoder(fin=self.input_shape[0], zdim=self.zdim, dense=self.dense, n_blocks=self.n_blocks,
40 | spatial_dim=self.input_shape[1]//2**self.n_blocks, variational=True, gap=False)
41 | self.Dec = Decoder(fin=self.zdim, nf0=self.E.backbone.nfeats//2, n_channels=self.input_shape[0],
42 | dense=self.dense, n_blocks=self.n_blocks, spatial_dim=self.input_shape[1]//2**self.n_blocks,
43 | gap=False)
44 |
45 | if torch.cuda.is_available():
46 | self.E.cuda()
47 | self.Dec.cuda()
48 |
49 | if self.load_weigths:
50 | self.E.load_state_dict(torch.load(self.dir_results + '/encoder_weights.pth'))
51 | self.Dec.load_state_dict(torch.load(self.dir_results + '/decoder_weights.pth'))
52 |
53 | # Set parameters
54 | self.params = list(self.E.parameters()) + list(self.Dec.parameters())
55 |
56 | # Set losses
57 | if self.loss_reconstruction == 'l2':
58 | self.Lr = torch.nn.MSELoss(reduction='sum')
59 | elif self.loss_reconstruction == 'bce':
60 | self.Lr = torch.nn.BCEWithLogitsLoss(reduction='sum')
61 |
62 | self.Lkl = kl_loss
63 |
64 | # Set optimizers
65 | self.opt = torch.optim.Adam(self.params, lr=self.lr)
66 |
67 | # Init additional variables and objects
68 | self.epochs = 0.
69 | self.iterations = 0.
70 | self.init_time = 0.
71 | self.lr_iteration = 0.
72 | self.lr_epoch = 0.
73 | self.kl_iteration = 0.
74 | self.kl_epoch = 0.
75 | self.i_epoch = 0.
76 | self.train_generator = []
77 | self.dataset_test = []
78 | self.metrics = {}
79 | self.aucroc_lc = []
80 | self.auprc_lc = []
81 | self.lr_lc = []
82 | self.lkl_lc = []
83 |
84 | def train(self, train_generator, epochs, dataset_test):
85 | self.epochs = epochs
86 | self.init_time = timer()
87 | self.train_generator = train_generator
88 | self.dataset_test = dataset_test
89 | self.iterations = len(self.train_generator)
90 |
91 | # Loop over epochs
92 | for self.i_epoch in range(self.epochs):
93 | # init epoch losses
94 | self.lr_epoch = 0
95 | self.kl_epoch = 0.
96 |
97 | # Loop over training dataset
98 | for self.i_iteration, (x_n, y_n, _, _) in enumerate(self.train_generator):
99 |
100 | if self.context: # if context option, data augmentation to apply context
101 | (x_n_context, _) = augment_input_batch(x_n.copy())
102 | x_n_context = torch.tensor(x_n_context).cuda().float()
103 |
104 | # Move tensors to gpu
105 | x_n = torch.tensor(x_n).cuda().float()
106 |
107 | # Obtain latent space from normal sample via encoder
108 | if not self.context:
109 | z, z_mu, z_logvar, _ = self.E(x_n)
110 | else:
111 | z, z_mu, z_logvar, _ = self.E(x_n_context)
112 |
113 | # Obtain reconstructed images through decoder
114 | xhat, _ = self.Dec(z)
115 | if self.loss_reconstruction == 'l2':
116 | xhat = torch.sigmoid(xhat)
117 |
118 | # Calculate criterion
119 | self.lr_iteration = self.Lr(xhat, x_n) / self.train_generator.batch_size # Reconstruction loss
120 | self.kl_iteration = self.Lkl(mu=z_mu, logvar=z_logvar) # kl loss (averaged per spatial feature)
121 |
122 | # Init overall losses
123 | L = self.lr_iteration + self.alpha_kl * self.kl_iteration
124 |
125 | # Update weights
126 | L.backward() # Backward
127 | self.opt.step() # Update weights
128 | self.opt.zero_grad() # Clear gradients
129 |
130 | """
131 | ON ITERATION/EPOCH END PROCESS
132 | """
133 |
134 | # Display losses per iteration
135 | self.display_losses(on_epoch_end=False)
136 |
137 | # Update epoch's losses
138 | self.lr_epoch += self.lr_iteration.cpu().detach().numpy() / len(self.train_generator)
139 | self.kl_epoch += self.kl_iteration.cpu().detach().numpy() / len(self.train_generator)
140 |
141 | # Epoch-end processes
142 | self.on_epoch_end()
143 |
144 | def on_epoch_end(self):
145 |
146 | # Display losses
147 | self.display_losses(on_epoch_end=True)
148 |
149 | # Update learning curves
150 | self.lr_lc.append(self.lr_epoch)
151 | self.lkl_lc.append(self.kl_epoch)
152 |
153 | # Every `epochs_to_test` epochs, test models and plot learning curves
154 | if (self.i_epoch + 1) % self.epochs_to_test == 0:
155 | # Save weights
156 | torch.save(self.E.state_dict(), self.dir_results + 'encoder_weights.pth')
157 | torch.save(self.Dec.state_dict(), self.dir_results + 'decoder_weights.pth')
158 |
159 | # Make predictions
160 | Y, Scores, M, Mhat, X, Xhat = inference_dataset(self, self.dataset_test)
161 |
162 | # Input to dataset
163 | self.dataset_test.Scores = Scores
164 | self.dataset_test.Mhat = Mhat
165 | self.dataset_test.Xhat = Xhat
166 |
167 | # Evaluate
168 | metrics, th = evaluate_anomaly_localization(self.dataset_test, save_maps=True, dir_out=self.dir_results)
169 | self.metrics = metrics
170 |
171 | # Save metrics as dict
172 | with open(self.dir_results + 'metrics.json', 'w') as fp:
173 | json.dump(metrics, fp)
174 | print(metrics)
175 |
176 | # Plot learning curve
177 | self.plot_learning_curves()
178 |
179 | # Save learning curves as dataframe
180 | self.aucroc_lc.append(metrics['AU_ROC'])
181 | self.auprc_lc.append(metrics['AU_PRC'])
182 | history = pd.DataFrame(list(zip(self.lr_lc, self.lkl_lc, self.aucroc_lc, self.auprc_lc)),
183 | columns=['Lrec', 'Lkl', 'AUCROC', 'AUPRC'])
184 | history.to_csv(self.dir_results + 'lc_on_direct.csv')
185 |
186 | else:
187 | self.aucroc_lc.append(0)
188 | self.auprc_lc.append(0)
189 |
190 | def predict_score(self, x):
191 | self.E.eval()
192 | self.Dec.eval()
193 |
194 | # Prepare brain eroded mask
195 | if 'BRATS' in self.train_generator.dataset.dir_datasets or 'PhysioNet' in self.train_generator.dataset.dir_datasets:
196 | x_mask = 1 - (x == 0).astype(int)  # np.int was removed in NumPy 1.24; the builtin int is equivalent
197 | if 'BRATS' in self.train_generator.dataset.dir_datasets:
198 | x_mask = ndimage.binary_erosion(x_mask, structure=np.ones((1, 10, 10))).astype(x_mask.dtype)
199 | else:
200 | x_mask = ndimage.binary_erosion(x_mask, structure=np.ones((1, 3, 3))).astype(x_mask.dtype)
201 | elif 'MVTEC' in self.train_generator.dataset.dir_datasets:
202 | x_mask = np.zeros((1, x.shape[-1], x.shape[-1]))
203 | x_mask[:, 14:-14, 14:-14] = 1
204 |
205 | # Get reconstruction error map
206 | if self.restoration: # restoration reconstruction
207 | mhat, xhat = self.restoration_reconstruction(x)
208 | elif self.bayesian: # bayesian reconstruction
209 | mhat, xhat = self.bayesian_reconstruction(x)
210 | else:
211 | # Network forward
212 | z, z_mu, z_logvar, f = self.E(torch.tensor(x).cuda().float().unsqueeze(0))
213 | xhat = np.squeeze(torch.sigmoid(self.Dec(z)[0]).cpu().detach().numpy())
214 | # Compute anomaly map
215 | mhat = np.squeeze(np.abs(x - xhat))
216 | # mhat = np.squeeze((x - xhat))
217 |
218 | # Keep only brain region
219 | mhat[x_mask[0, :, :] == 0] = 0
220 |
221 | # Get outputs
222 | anomaly_map = mhat
223 | score = np.mean(anomaly_map)
224 |
225 | self.E.train()
226 | self.Dec.train()
227 | return score, anomaly_map, xhat
228 |
229 | def bayesian_reconstruction(self, x):
230 |
231 | N = 100
232 | p_dropout = 0.20
233 | mhat = np.zeros((self.input_shape[1], self.input_shape[2]))
234 |
235 | # Network forward
236 | z, z_mu, z_logvar, f = self.E(torch.tensor(x).cuda().float().unsqueeze(0))
237 | xhat = np.squeeze(torch.sigmoid(self.Dec(torch.nn.Dropout(p_dropout)(z))[0]).cpu().detach().numpy())  # sigmoid so xhat is in [0, 1], as in predict_score
238 |
239 | for i in np.arange(N):
240 | if z_mu is None: # apply dropout to z
241 | mhat += np.squeeze(np.abs(np.squeeze(torch.sigmoid(
242 | self.Dec(torch.nn.Dropout(p_dropout)(z))[0]).cpu().detach().numpy()) - x)) / N
243 | else: # sample z
244 | mhat += np.squeeze(np.abs(np.squeeze(torch.sigmoid(
245 | self.Dec(self.E.reparameterize(z_mu, z_logvar))[0]).cpu().detach().numpy()) - x)) / N
246 | return mhat, xhat
247 |
248 | def restoration_reconstruction(self, x):
249 | N = 300
250 | step = 1 * 1e-3
251 | x_rest = torch.tensor(x).cuda().float().unsqueeze(0)
252 |
253 | for i in np.arange(N):
254 | # Forward
255 | x_rest.requires_grad = True
256 | z, z_mu, z_logvar, f = self.E(x_rest)
257 | xhat = self.Dec(z)[0]
258 |
259 | # Compute loss
260 | lr = kornia.losses.total_variation(torch.tensor(x).cuda().float().unsqueeze(0) - torch.sigmoid(xhat))
261 | L = lr / (self.input_shape[1] * self.input_shape[2])
262 |
263 | # Get gradients
264 | gradients = torch.autograd.grad(L, x_rest, grad_outputs=None, retain_graph=True,
265 | create_graph=True,
266 | only_inputs=True, allow_unused=True)[0]
267 |
268 | x_rest = x_rest - gradients * step
269 | x_rest = x_rest.clone().detach()
270 | xhat = np.squeeze(x_rest.cpu().numpy())
271 |
272 | # Compute difference
273 | mhat = np.squeeze(np.abs(x - xhat))
274 |
275 | return mhat, xhat
276 |
277 | def display_losses(self, on_epoch_end=False):
278 |
279 | # Init info display
280 | info = "[INFO] Epoch {}/{} -- Step {}/{}: ".format(self.i_epoch + 1, self.epochs,
281 | self.i_iteration + 1, self.iterations)
282 | # Prepare values to show
283 | if on_epoch_end:
284 | lr = self.lr_epoch
285 | lkl = self.kl_epoch
286 | end = '\n'
287 | else:
288 | lr = self.lr_iteration
289 | lkl = self.kl_iteration
290 | end = '\r'
291 | # Init losses display
292 | info += "Reconstruction={:.4f} || KL={:.4f}".format(lr, lkl)
293 | # Print losses
294 | et = str(datetime.timedelta(seconds=timer() - self.init_time))
295 | print(info + ', ET=' + et, end=end)
296 |
297 | def plot_learning_curves(self):
298 | def plot_subplot(axes, x, y, y_axis):
299 | axes.grid()
300 | axes.plot(x, y, 'o-')
301 | axes.set_ylabel(y_axis)
302 |
303 | fig, axes = plt.subplots(1, 2, figsize=(20, 15))
304 | plot_subplot(axes[0], np.arange(self.i_epoch + 1) + 1, np.array(self.lr_lc), "Reconstruction loss")
305 | plot_subplot(axes[1], np.arange(self.i_epoch + 1) + 1, np.array(self.lkl_lc), "KL loss")
306 | plt.savefig(self.dir_results + 'learning_curve.png')
307 | plt.close()
--------------------------------------------------------------------------------
/code/models/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | import numpy as np
4 |
5 | # ------------------
6 | # RESIDUAL MODELS
7 |
8 |
9 | class Resnet(torch.nn.Module):
10 | def __init__(self, in_channels, n_blocks=4):
11 | super(Resnet, self).__init__()
12 | self.n_blocks = n_blocks
13 | self.nfeats = 512 // (2**(4-n_blocks))
14 |
15 | self.input = torch.nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=(7, 7), stride=(2, 2),
16 | padding=(3, 3), bias=False)
17 | resnet18_model = torchvision.models.resnet18(pretrained=False)
18 | self.resnet = torch.nn.Sequential(*(list(resnet18_model.children())[i+4] for i in range(0, self.n_blocks)))
19 |
20 | # placeholder for the gradients
21 | self.gradients = None
22 |
23 | def forward(self, x):
24 | x = self.input(x)
25 | F = []
26 | for iBlock in range(0, self.n_blocks):
27 | x = list(self.resnet.children())[iBlock](x)
28 | F.append(x)
29 |
30 | return x, F
31 |
32 |
33 | class Encoder(torch.nn.Module):
34 | def __init__(self, fin=1, zdim=128, dense=False, variational=False, n_blocks=4, spatial_dim=7,
35 | gap=False):
36 | super(Encoder, self).__init__()
37 | self.fin = fin
38 | self.zdim = zdim
39 | self.dense = dense
40 | self.n_blocks = n_blocks
41 | self.gap = gap
42 | self.variational = variational
43 |
44 | # 1) Feature extraction
45 | self.backbone = Resnet(in_channels=self.fin, n_blocks=self.n_blocks)
46 | # 2) Latent space (dense or spatial)
47 | if self.dense: # dense
48 | if gap:
49 | if self.variational:
50 | self.mu = torch.nn.Conv2d(self.backbone.nfeats, zdim, (1, 1))
51 | self.log_var = torch.nn.Conv2d(self.backbone.nfeats, zdim, (1, 1))
52 | else:
53 | self.z = torch.nn.Conv2d(self.backbone.nfeats, zdim, (1, 1))
54 | else:
55 | if self.variational:
56 | self.mu = torch.nn.Linear(self.backbone.nfeats*spatial_dim**2, zdim)
57 | self.log_var = torch.nn.Linear(self.backbone.nfeats*spatial_dim**2, zdim)
58 | else:
59 | self.z = torch.nn.Linear(self.backbone.nfeats * spatial_dim ** 2, zdim)
60 | else: # spatial
61 | if self.variational:
62 | self.mu = torch.nn.Conv2d(self.backbone.nfeats, zdim, (1, 1))
63 | self.log_var = torch.nn.Conv2d(self.backbone.nfeats, zdim, (1, 1))
64 | else:
65 | self.z = torch.nn.Conv2d(self.backbone.nfeats, zdim, (1, 1))
66 |
67 | def reparameterize(self, mu, log_var):
68 | std = torch.exp(0.5 * log_var) # standard deviation
69 | eps = torch.randn_like(std)
70 |
71 | sample = mu + (eps * std) # sampling
72 | return sample
73 |
74 | def forward(self, x):
75 |
76 | # 1) Feature extraction
77 | x, allF = self.backbone(x)
78 |
79 | if self.dense and not self.gap:
80 | x = torch.nn.Flatten()(x)
81 |
82 | if self.dense and self.gap:
83 | x = torch.nn.functional.adaptive_avg_pool2d(x, 1)
84 | x = x.view(x.size(0), self.backbone.nfeats, 1, 1)
85 |
86 | # 2) Latent space
87 | if self.variational:
88 | # get `mu` and `log_var`
89 | z_mu = self.mu(x)
90 | z_logvar = self.log_var(x)
91 | # get the latent vector through reparameterization
92 | z = self.reparameterize(z_mu, z_logvar)
93 | else:
94 | z = self.z(x)
95 | z_mu, z_logvar = None, None
96 |
97 | return z, z_mu, z_logvar, allF
98 |
99 |
100 | class Decoder(torch.nn.Module):
101 |
102 | def __init__(self, fin=256, nf0=128, n_channels=1, dense=False, n_blocks=4, spatial_dim=7,
103 | gap=False):
104 | super(Decoder, self).__init__()
105 | self.n_blocks = n_blocks
106 | self.dense = dense
107 | self.spatial_dim = spatial_dim
108 | self.fin = fin
109 | self.gap = gap
110 | self.nf0 = nf0
111 |
112 | if self.dense and not self.gap:
113 | self.dense_layer = torch.nn.Sequential(torch.nn.Linear(fin, nf0*spatial_dim**2))
114 | if not dense:
115 | self.dense_layer = torch.nn.Sequential(torch.nn.Conv2d(fin, nf0, (1, 1)))
116 |
117 | # Set number of input and output channels
118 | n_filters_in = [nf0//2**(i) for i in range(0, self.n_blocks + 1)]
119 | n_filters_out = [nf0//2**(i+1) for i in range(0, self.n_blocks)] + [n_channels]
120 |
121 | self.blocks = torch.nn.ModuleList()
122 | for i in np.arange(0, self.n_blocks):
123 | self.blocks.append(torch.nn.Sequential(BasicBlock(n_filters_in[i], n_filters_out[i], downsample=True),
124 | BasicBlock(n_filters_out[i], n_filters_out[i])))
125 | self.out = torch.nn.Conv2d(n_filters_in[-1], n_filters_out[-1], kernel_size=(3, 3), padding=(1, 1))
126 |
127 | self.n_filters_in = n_filters_in
128 | self.n_filters_out = n_filters_out
129 |
130 | def forward(self, x):
131 |
132 | if self.dense and not self.gap:
133 | x = self.dense_layer(x)
134 | x = torch.nn.Unflatten(-1, (self.nf0, self.spatial_dim, self.spatial_dim))(x)
135 |
136 | if self.dense and self.gap:
137 | x = torch.nn.functional.interpolate(x, scale_factor=self.spatial_dim)
138 |
139 | if not self.dense:
140 | x = self.dense_layer(x)
141 |
142 | for i in np.arange(0, self.n_blocks):
143 | x = self.blocks[i](x)
144 | f = x
145 | out = self.out(f)
146 |
147 | return out, f
148 |
149 |
150 | class ResBlock(torch.nn.Module):
151 |
152 | def __init__(self, fin, fout):
153 | super(ResBlock, self).__init__()
154 | self.conv_straight_1 = torch.nn.Conv2d(fin, fout, kernel_size=(3, 3), padding=(1, 1))
155 | self.bn_1 = torch.nn.BatchNorm2d(fout)
156 | self.conv_straight_2 = torch.nn.Conv2d(fout, fout, kernel_size=(3, 3), padding=(1, 1))
157 | self.bn_2 = torch.nn.BatchNorm2d(fout)
158 | self.conv_skip = torch.nn.Conv2d(fin, fout, kernel_size=(3, 3), padding=(1, 1))
159 | self.upsampling = torch.nn.Upsample(scale_factor=(2, 2))
160 | self.relu = torch.nn.ReLU()
161 |
162 | def forward(self, x):
163 |
164 | x_st = self.upsampling(x)
165 | x_st = self.conv_straight_1(x_st)
166 | x_st = self.relu(x_st)
167 | x_st = self.bn_1(x_st)
168 | x_st = self.conv_straight_2(x_st)
169 | x_st = self.relu(x_st)
170 | x_st = self.bn_2(x_st)
171 |
172 | x_sk = self.upsampling(x)
173 | x_sk = self.conv_skip(x_sk)
174 |
175 | out = x_sk + x_st
176 |
177 | return out
178 |
179 |
180 | class BasicBlock(torch.nn.Module):
181 |
182 | def __init__(self, inplanes=32, planes=64, stride=1, downsample=False, bn=True):
183 | super().__init__()
184 | norm_layer = torch.nn.BatchNorm2d
185 | self.conv1 = torch.nn.Conv2d(inplanes, planes, kernel_size=(3, 3), padding=(1, 1))
186 | self.bn1 = norm_layer(planes)
187 | self.relu = torch.nn.ReLU(inplace=True)
188 | self.conv2 = torch.nn.Conv2d(planes, planes, kernel_size=(3, 3), padding=(1, 1))
189 | self.bn2 = norm_layer(planes)
190 | self.downsample = downsample
191 | if self.downsample:
192 | self.downsample_layer_conv = torch.nn.Sequential(torch.nn.Conv2d(inplanes, planes, kernel_size=(1, 1)))
193 | self.downsample_layer = torch.nn.Upsample(scale_factor=(2, 2))
194 | self.downsample_layer_bn = norm_layer(planes)
195 | self.stride = stride
196 | self.bn = bn
197 |
198 | def forward(self, x):
199 | identity = x
200 |
201 | out = self.conv1(x)
202 | if self.downsample:
203 | out = self.downsample_layer(out)
204 | if self.bn:
205 | out = self.bn1(out)
206 | out = self.relu(out)
207 |
208 | out = self.conv2(out)
209 | if self.bn:
210 | out = self.bn2(out)
211 |
212 | if self.downsample:
213 | identity = self.downsample_layer_conv(identity)
214 | identity = self.downsample_layer(identity)
215 | if self.bn:
216 | identity = self.downsample_layer_bn(identity)
217 |
218 | out += identity
219 | out = self.relu(out)
220 |
221 | return out
222 |
223 |
224 | class ConvBlock(torch.nn.Module):
225 |
226 | def __init__(self, fin, fout):
227 | super(ConvBlock, self).__init__()
228 |
229 | self.conv = torch.nn.Conv2d(fin, fout, kernel_size=4, stride=2, padding=1, bias=False)
230 | self.bn = torch.nn.BatchNorm2d(fout)
231 | self.act = torch.nn.LeakyReLU(0.2, inplace=False)
232 |
233 | def forward(self, x):
234 |
235 | out = self.act(self.bn(self.conv(x)))
236 |
237 | return out
238 |
239 |
240 | class DeconvBlock(torch.nn.Module):
241 |
242 | def __init__(self, fin, fout):
243 | super(DeconvBlock, self).__init__()
244 |
245 | self.conv = torch.nn.Conv2d(fin, fout, kernel_size=4, stride=2, padding=1, bias=False)
246 | self.bn = torch.nn.BatchNorm2d(fout)
247 | self.act = torch.nn.LeakyReLU(0.2, inplace=False)
248 |
249 | def forward(self, x):
250 |
251 | out = self.act(self.bn(self.conv(x)))
252 |
253 | return out
254 |
255 |
256 | class Discriminator(torch.nn.Module):
257 | def __init__(self, fin=32, n_channels=1, n_blocks=4, type='DCGAN'):
258 | super(Discriminator, self).__init__()
259 |
260 | # Number of feature extractor blocks
261 | self.n_blocks = n_blocks
262 | # Set number of input and output channels
263 | n_filters_in = [n_channels] + [fin*(2**i) for i in range(0, self.n_blocks)]
264 | n_filters_out = [fin*(2**i) for i in range(0, self.n_blocks+1)]
265 | # Number of output features
266 | self.nFeats = n_filters_out[-1]
267 | # Prepare blocks:
268 | self.blocks = torch.nn.ModuleList()
269 | for i in np.arange(0, self.n_blocks + 1):
270 | if type == 'DCGAN':
271 | self.blocks.append(ConvBlock(n_filters_in[i], n_filters_out[i]))
272 | # Output for binary classification
273 | self.pool_feats = torch.nn.AdaptiveAvgPool2d((7, 7))
274 | self.out = torch.nn.Sequential(torch.nn.AdaptiveAvgPool2d(1),
275 | torch.nn.Flatten(),
276 | torch.nn.Linear(self.nFeats, 1),
277 | torch.nn.Sigmoid())
278 |
279 | def forward(self, x):
280 |
281 | for i in np.arange(0, self.n_blocks+1):
282 | x = self.blocks[i](x)
283 | f = self.pool_feats(x)
284 |
285 | out = self.out(f)
286 |
287 | return out, f
288 |
289 |
290 | # ------------------
291 | # SEQUENTIAL MODELS
292 |
293 | class GradConCAEEncoder(torch.nn.Module):
294 | def __init__(self, fin=1, zdim=128, dense=False, variational=False, n_blocks=4, spatial_dim=7,
295 | gap=False):
296 | super(GradConCAEEncoder, self).__init__()
297 | self.fin = fin
298 | self.zdim = zdim
299 | self.dense = dense
300 | self.n_blocks = n_blocks
301 | self.gap = gap
302 | self.variational = variational
303 |
304 | # 1) Feature extraction
305 | self.backbone = torch.nn.Sequential(GradConCAEDownBlock(self.fin, 32),
306 | GradConCAEDownBlock(32, 32),
307 | GradConCAEDownBlock(32, 64),
308 | GradConCAEDownBlock(64, 64))
309 | # 2) Latent space (dense or spatial)
310 | if self.dense: # dense
311 | if gap:
312 | if self.variational:
313 | self.mu = torch.nn.Conv2d(64, zdim, (1, 1))
314 | self.log_var = torch.nn.Conv2d(64, zdim, (1, 1))
315 | else:
316 | self.z = torch.nn.Conv2d(64, zdim, (1, 1))
317 | else:
318 | if self.variational:
319 | self.mu = torch.nn.Linear(64*spatial_dim**2, zdim)
320 | self.log_var = torch.nn.Linear(64*spatial_dim**2, zdim)
321 | else:
322 | self.z = torch.nn.Linear(64 * spatial_dim ** 2, zdim)
323 | else: # spatial
324 | if self.variational:
325 | self.mu = torch.nn.Conv2d(64, zdim, (1, 1))
326 | self.log_var = torch.nn.Conv2d(64, zdim, (1, 1))
327 | else:
328 | self.z = torch.nn.Conv2d(64, zdim, (1, 1))
329 |
330 | def reparameterize(self, mu, log_var):
331 | std = torch.exp(0.5 * log_var) # standard deviation
332 | eps = torch.randn_like(std)
333 |
334 | sample = mu + (eps * std) # sampling
335 | return sample
336 |
337 | def forward(self, x):
338 |
339 | # 1) Feature extraction
340 | x = self.backbone(x)
341 |
342 | if self.dense and not self.gap:
343 | x = torch.nn.Flatten()(x)
344 |
345 | if self.dense and self.gap:
346 | x = torch.nn.functional.adaptive_avg_pool2d(x, 1)
347 | x = x.view(x.size(0), 64, 1, 1)  # backbone is a Sequential (no nfeats attribute); its last block outputs 64 channels
348 |
349 | # 2) Latent space
350 | if self.variational:
351 | # get `mu` and `log_var`
352 | z_mu = self.mu(x)
353 | z_logvar = self.log_var(x)
354 | # get the latent vector through reparameterization
355 | z = self.reparameterize(z_mu, z_logvar)
356 | else:
357 | z = self.z(x)
358 | z_mu, z_logvar = None, None
359 |
360 | return z, z_mu, z_logvar, None
361 |
362 |
363 | class GradConCAEDecoder(torch.nn.Module):
364 |
365 | def __init__(self, fin=256, nf0=128, n_channels=1, dense=False, n_blocks=4, spatial_dim=7,
366 | gap=False):
367 | super(GradConCAEDecoder, self).__init__()
368 | self.n_blocks = n_blocks
369 | self.dense = dense
370 | self.spatial_dim = spatial_dim
371 | self.fin = fin
372 | self.gap = gap
373 |
374 | if self.dense and not self.gap:
375 | self.dense_layer = torch.nn.Sequential(torch.nn.Linear(fin, 64*spatial_dim**2))
376 | if not dense:
377 | self.dense_layer = torch.nn.Sequential(torch.nn.Conv2d(fin, 64, (1, 1)))
378 |
379 | # Set number of input and output channels
380 | n_filters_in = [64, 64, 32, 32]
381 | n_filters_out = [64, 32, 32] + [n_channels]
382 |
383 | self.blocks = torch.nn.ModuleList()
384 | for i in np.arange(0, 4):
385 | if i == 0:
386 | self.blocks.append(torch.nn.Sequential(GradConCAEUpBlock(n_filters_in[i], n_filters_out[i], padding=2)))
387 | else:
388 | self.blocks.append(torch.nn.Sequential(GradConCAEUpBlock(n_filters_in[i], n_filters_out[i], padding=1)))
389 |
390 | self.n_filters_in = n_filters_in
391 | self.n_filters_out = n_filters_out
392 |
393 | def forward(self, x):
394 |
395 | if self.dense and not self.gap:
396 | x = self.dense_layer(x)
397 | x = torch.nn.Unflatten(-1, (64, self.spatial_dim, self.spatial_dim))(x)  # dense_layer outputs 64 * spatial_dim ** 2 features
398 |
399 | if self.dense and self.gap:
400 | x = torch.nn.functional.interpolate(x, scale_factor=self.spatial_dim)
401 |
402 | if not self.dense:
403 | x = self.dense_layer(x)
404 |
405 | for i in np.arange(0, 4):
406 | x = self.blocks[i](x)
407 | f = x
408 | out = x
409 |
410 | return out, f
411 |
412 |
413 | class GradConCAEDownBlock(torch.nn.Module):
414 |     def __init__(self, in_channel, out_channel, stride=2, padding=2):
415 |         super(GradConCAEDownBlock, self).__init__()
416 |
417 |         self.conv1 = torch.nn.Conv2d(in_channel, out_channel, (4, 4), stride=stride, padding=padding)
418 |         self.act1 = torch.nn.ReLU()
419 |
420 |     def forward(self, x):
421 |
422 |         x = self.conv1(x)
423 |         x = self.act1(x)
424 |
425 |         return x
426 |
427 |
428 | class GradConCAEUpBlock(torch.nn.Module):
429 |     def __init__(self, in_channel, out_channel, stride=2, padding=2):
430 |         super(GradConCAEUpBlock, self).__init__()
431 |
432 |         self.conv1 = torch.nn.ConvTranspose2d(in_channel, out_channel, (4, 4), stride=stride, padding=padding)
433 |         self.act1 = torch.nn.ReLU()
434 |
435 |     def forward(self, x):
436 |         x = self.conv1(x)
437 |         x = self.act1(x)
438 |
439 |         return x
--------------------------------------------------------------------------------
/data/MICCAI_BraTS_2019_Data_Training/README.md:
--------------------------------------------------------------------------------
1 | ### Experiments on BRATS dataset
2 |
3 | To reproduce the experiments on the BRATS dataset, download the data from the following link: [BRATS](https://drive.google.com/file/d/1NgHMcIcfVGcoAYWd0ABI6AEZCkpFpvJ8/view?usp=sharing). Then prepare the MRI volumes and generate the data splits with the following command:
4 |
5 | ```
6 | python adecuate_BRATS.py --dir_datasets ../data/MICCAI_BraTS_2019_Data_Training/ --dir_out ../data/BRATS_5slices/ --scan flair --nSlices 5
7 | ```
8 |
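If the preprocessing has to be launched from Python (for example, to try several values of `--nSlices`), the command above can be assembled programmatically. The following is only a minimal sketch: `build_brats_command` is a hypothetical helper, not part of this repository, and it uses just the flags that appear in the command above. Nothing is executed; the sketch only builds and prints the argument list.

```python
import shlex
import sys


def build_brats_command(dir_datasets, dir_out, scan="flair", n_slices=5,
                        script="adecuate_BRATS.py"):
    """Return the argv list for the preprocessing command shown above.

    Hypothetical helper: only the flags from the README command are used.
    """
    return [
        sys.executable, script,
        "--dir_datasets", dir_datasets,
        "--dir_out", dir_out,
        "--scan", scan,
        "--nSlices", str(n_slices),
    ]


cmd = build_brats_command(
    "../data/MICCAI_BraTS_2019_Data_Training/",
    "../data/BRATS_5slices/",
)
# Print a copy-pasteable version of the arguments (nothing is run here).
print(shlex.join(cmd[1:]))
```

To actually run it, the list can be passed to `subprocess.run(cmd, check=True)` from the directory that the relative paths in the command assume.
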
9 | For more information on the data description and usage requirements, please refer to the original authors at the following [LINK](https://www.med.upenn.edu/cbica/brats2020/data.html).
10 |
--------------------------------------------------------------------------------
/data/PhysioNet-ICH/README.md:
--------------------------------------------------------------------------------
1 | ### Experiments on PhysioNet-ICH dataset
2 |
3 | To reproduce the experiments on the PhysioNet-ICH dataset, download the data from the following link: [PhysioNet-ICH](https://physionet.org/content/ct-ich/1.3.1/). Then prepare the CT scans and generate the data splits with the following command:
4 |
5 | ```
6 | python adecuate_PhysionNet_ICH.py
7 | ```
8 |
9 | For more information on the data description and usage requirements, please refer to the original authors at the following [LINK](https://physionet.org/content/ct-ich/1.3.1/).
10 |
--------------------------------------------------------------------------------
/data/PhysioNet-ICH/readme:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/images/visualizations.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jusiro/constrained_anomaly_segmentation/c5c963c47ffc924b8611824c09fdfabfc92a48e1/images/visualizations.png
--------------------------------------------------------------------------------