├── README.md
├── animated.gif
├── common
│   ├── data.py
│   ├── dto
│   │   ├── CaeDto.py
│   │   ├── Dto.py
│   │   ├── MetricMeasuresDto.py
│   │   └── UnetDto.py
│   ├── inference
│   │   ├── CaeEncInference.py
│   │   ├── CaeInference.py
│   │   ├── Inference.py
│   │   └── UnetInference.py
│   ├── metrics.py
│   ├── model
│   │   ├── Cae3D.py
│   │   └── Unet3D.py
│   └── util.py
├── learner
│   ├── CaePredictionLearner.py
│   ├── CaeReconstructionLearner.py
│   ├── CaeStepLearner.py
│   ├── Learner.py
│   └── UnetSegmentationLearner.py
├── requirements.txt
├── sample_output.png
├── test_sdm_resampling.py
├── test_shape_reconstruction.py
├── test_shape_reconstruction_CurveAnalysis.py
├── test_unet_segmentation.py
├── tester
│   ├── CaeReconstructionTester.py
│   ├── CaeReconstructionTesterCurve.py
│   ├── Tester.py
│   └── UnetSegmentationTester.py
├── train_interpolationstep_after_reconstruction.py
├── train_shape_prediction.py
├── train_shape_reconstruction.py
├── train_shape_reconstruction_with_ctp.py
└── train_unet_segmentation.py
/README.md:
--------------------------------------------------------------------------------
 1 | 
 2 | ![](animated.gif)
 3 | 
 4 | # stroke-prediction
 5 | Stroke infarct growth prediction (3D, PyTorch 0.3)
 6 | 
 7 | ## Objective
 8 | Learning to Predict Stroke Infarcted Tissue Outcome based on Multivariate CT Images
 9 | 
10 | ## Data
11 | The source code works only from within the IMI network at the University of Luebeck, as the closed dataset of 29 subjects is only accessible if you are a member of the bvstaff group. The filenames have been renamed, and each case is represented by a subfolder. The CTP modalities CBV and TTD are used as input, together with the corresponding manual segmentations for core and penumbra, as well as the follow-up lesion segmentation (FUCTMap). The directory contains further files stemming from the work for Linda Aulmann's Master's thesis.
12 | 
13 | The dataset specified in [data.py](common/data.py) inherits from [torch.utils.data.Dataset](https://pytorch.org/docs/stable/_modules/torch/utils/data/dataset.html#Dataset) and can thus be exchanged with other datasets and loaders (at the moment, there are two datasets with different transformations for training and validation). The existing Learners expect 3D PyTorch tensors of shape `BxCxDxHxW`, but implementing your own [Learner](learner/Learner.py) enables the use of 2D data as well.
14 | 
15 | ## Setup
16 | Set up a Python 3.5 environment including the packages listed in the [requirements.txt](requirements.txt) file.
17 | 
18 | ## Structure of repository
19 | The repository consists of the following subfolders:
20 | - common: contains commonly used files such as [DTOs](common/dto/Dto.py), the models to be learned, helper files, and [Inference](common/inference/Inference.py) classes
21 | - learner: contains the [Learner](learner/Learner.py) classes for running the different trainings
22 | - tester: contains the [Tester](tester/Tester.py) classes for running the different tests (i.e. inference-only)
23 | 
24 | Further, there are other important files:
25 | - data.py: defines the dataset as described in section [Data](README.md#data) and contains the required transformations
26 | - util.py: contains helper functions
27 | 
28 | ## Usage
29 | Activate the environment created in section [Setup](README.md#setup).
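
If you want to exchange the dataset before running the commands below (see section [Data](README.md#data)), the following minimal sketch illustrates the dict and tensor contract the Learners expect. The dataset class and shapes are hypothetical examples; only the keys (`case_id`, `images`, `labels`) and the `BxCxDxHxW` layout match [data.py](common/data.py):

```python
import torch
from torch.utils.data import Dataset, DataLoader

class RandomStrokeDataset(Dataset):
    """Hypothetical stand-in dataset; data.py defines the real one."""
    def __len__(self):
        return 4

    def __getitem__(self, idx):
        return {'case_id': idx,
                'images': torch.rand(2, 28, 128, 128),   # CxDxHxW, e.g. CBV + TTD
                'labels': torch.rand(3, 28, 128, 128)}   # core, penumbra, follow-up

loader = DataLoader(RandomStrokeDataset(), batch_size=2)
batch = next(iter(loader))
print(batch['images'].shape)  # torch.Size([2, 2, 28, 128, 128]) == BxCxDxHxW
```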
30 | 
31 | To learn the shape space on the manual segmentations, run the following command:
32 | 
33 | `train_shape_reconstruction.py ~/tmp/shape_f3.model --lrsteps 200 250 --epochs 300 --outbasepath ~/tmp/shape --channelscae 1 16 24 32 100 200 1 --validsetsize 0.3 --fold 17 6 2 26 11 4 1 21 16 27 24 18 9 22 12 0 3 8 23 25 7 10 19`
34 | 
35 | The `--fold` argument is an arbitrary but fixed list of indices between 0 and 28 that specifies a fold out of the 29 dataset subjects; the fraction specified by `--validsetsize` will be used as validation data (e.g., for 0.275 and the above fold, the Learner uses 17 training and 6 validation cases).
36 | 
37 | Always specify an `--outbasepath` to which files are saved. This includes the `*.model` file, written whenever a new validation minimum has been reached, and `*.png` files that plot the losses and metrics and visualize some samples during the training run:
38 | 
39 | ![](sample_output.png)
40 | 
41 | Train a Unet with the same fold as specified before, so that the Unet segmentations can be used for further training of an adapted encoder that predicts on segmentations of unseen CTP modalities:
42 | 
43 | `train_unet_segmentation.py ~/tmp/unet_f3.model --epochs 200 --outbasepath ~/tmp/unet --channels 2 16 32 64 32 16 32 2 --validsetsize 0.275 --fold 17 6 2 26 11 4 1 21 16 27 24 18 9 22 12 0 3 8 23 25 7 10 19`
44 | 
45 | The `--channels` argument specifies the number of channels per layer, including input and output. E.g., for the above command, the three-scale Unet processes the input with 16 channels, downsamples and processes with 32 channels, and downsamples and processes again with 64 channels, before the maps are upsampled again with the channel numbers per scale in reverse order. To test the trained Unet on some cases, run:
46 | 
47 | `test_unet_segmentation.py ~/tmp/unet_f3.model --outbasepath ~/tmp/tmp --channels 2 16 32 64 32 16 32 2 --fold 5 13 14 15 20 28`
48 | 
49 | For comparison purposes, you can run a shape interpolation via signed distance maps:
50 | 
51 | `test_sdm_resampling.py /share/data_zoe1/lucas/Linda_Segmentations/tmp/tmp_unet_f3.model --fold 22 --downsample 0 --groundtruth 1`
52 | 
53 | ## Experimental setup
54 | 
55 | The experiments in the article "[Learning to predict ischemic stroke growth on acute CT perfusion data by interpolating low-dimensional shape representations](https://www.frontiersin.org/articles/10.3389/fneur.2018.00989/)" have been conducted with the following parameters (command for fold 5):
56 | 
57 | `python train_shape_reconstruction.py --channelscae 1 16 24 32 100 800 1 --outbasepath /tmp/shape_f5 --validsetsize 0.275 --epochs 200 --fold 17 6 2 26 11 4 1 21 16 27 24 18 15 20 28 14 5 13 9 22 12 0 3 8`
58 | 
59 | ## Literature
60 | Listed under: https://www.researchgate.net/project/Learning-to-predict-stroke-outcome-on-multivariate-CT-data
61 | 
--------------------------------------------------------------------------------
/animated.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/multimodallearning/stroke-prediction/58da5be2c16637d47587cb09ac87ddebf028e093/animated.gif
--------------------------------------------------------------------------------
/common/data.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import csv
 3 | import nibabel as nib
 4 | import random
 5 | import datetime
 6 | 
 7 | import torch
 8 | from torchvision import transforms
 9 | from torch.utils.data import Dataset, DataLoader
10 | from torch.utils.data.sampler import
SubsetRandomSampler 11 | 12 | import numpy as np 13 | import scipy.ndimage as ndi 14 | from scipy.ndimage.interpolation import map_coordinates 15 | from scipy.ndimage.filters import gaussian_filter 16 | 17 | 18 | KEY_CASE_ID = 'case_id' 19 | KEY_CLINICAL_IDX = 'clinical_idx' 20 | KEY_IMAGES = 'images' 21 | KEY_LABELS = 'labels' 22 | KEY_GLOBAL = 'clinical' 23 | 24 | DIM_HORIZONTAL_NUMPY_3D = 0 25 | DIM_DEPTH_NUMPY_3D = 2 26 | DIM_CHANNEL_NUMPY_3D = 3 27 | DIM_CHANNEL_TORCH3D_5 = 1 28 | 29 | 30 | class StrokeLindaDataset3D(Dataset): 31 | """Ischemic stroke dataset with CBV, TTD, clinical data, and CBVmap, TTDmap, FUmap, and interpolations.""" 32 | PATH_ROOT = '/share/data_zoe1/lucas/Linda_Segmentations' 33 | PATH_CSV = '/share/data_zoe1/lucas/Linda_Segmentations/clinical_cleaned.csv' 34 | FN_PREFIX = 'train' 35 | FN_PATTERN = '{1}/{0}{1}{2}.nii.gz' 36 | ROW_OFFSET = 1 37 | COL_OFFSET = 1 38 | 39 | def __init__(self, root_dir=PATH_ROOT, modalities=[], labels=[], clinical=PATH_CSV, transform=None, 40 | single_case_id=None): 41 | self._root_dir = root_dir 42 | self._clinical = self._load_clinical_data_from_csv(clinical, row_offset=self.ROW_OFFSET, col_offset=0) 43 | self._transform = transform 44 | self._modalities = modalities 45 | self._labels = labels 46 | 47 | self._item_index_map = [] 48 | for index in range(len(self._clinical)): 49 | case_id = int(self._clinical[index][0]) 50 | if single_case_id is not None and single_case_id != case_id: 51 | continue 52 | self._item_index_map.append({KEY_CASE_ID: case_id, KEY_CLINICAL_IDX: index}) 53 | 54 | def _load_clinical_data_from_csv(self, filename, col_offset=0, row_offset=0): 55 | result = [] 56 | with open(filename, 'r') as f: 57 | rows = csv.reader(f, delimiter=',') 58 | for row in rows: 59 | if row_offset == 0: 60 | result.append(row[col_offset:]) 61 | else: 62 | row_offset -= 1 63 | return result 64 | 65 | def _load_image_data_from_nifti(self, case_id, suffix): 66 | img_name = self.FN_PATTERN.format(self.FN_PREFIX, str(case_id), suffix) 67 | filename = os.path.join(self._root_dir, img_name) 68 | img_data = nib.load(filename).get_data() 69 | return img_data[:, :, :, np.newaxis] 70 | 71 | def __len__(self): 72 | return len(self._item_index_map) 73 | 74 | def __getitem__(self, item): 75 | item_id = self._item_index_map[item] 76 | case_id = item_id[KEY_CASE_ID] 77 | clinical_data = self._clinical[item_id[KEY_CLINICAL_IDX]][1:] 78 | 79 | result = {KEY_CASE_ID: case_id, KEY_IMAGES: [], KEY_LABELS: [], KEY_GLOBAL: []} 80 | 81 | for value in clinical_data: 82 | result[KEY_GLOBAL].append(float(value)) 83 | if result[KEY_GLOBAL]: 84 | result[KEY_GLOBAL] = np.array(result[KEY_GLOBAL]).reshape((1, 1, 1, len(clinical_data))) 85 | 86 | for label in self._labels: 87 | result[KEY_LABELS].append(self._load_image_data_from_nifti(case_id, label)) 88 | if result[KEY_LABELS]: 89 | result[KEY_LABELS] = np.concatenate(result[KEY_LABELS], axis=DIM_CHANNEL_NUMPY_3D) 90 | 91 | for modality in self._modalities: 92 | result[KEY_IMAGES].append(self._load_image_data_from_nifti(case_id, modality)) 93 | if result[KEY_IMAGES]: 94 | result[KEY_IMAGES] = np.concatenate(result[KEY_IMAGES], axis=DIM_CHANNEL_NUMPY_3D) 95 | 96 | if self._transform: 97 | result = self._transform(result) 98 | 99 | return result 100 | 101 | 102 | def emptyCopyFromSample(sample): 103 | result = {KEY_CASE_ID: int(sample[KEY_CASE_ID]), KEY_IMAGES: [], KEY_LABELS: [], KEY_GLOBAL: []} 104 | return result 105 | 106 | 107 | def set_np_seed(workerid): 108 | torch_seed = torch.initial_seed() 109 | numpy_seed 
= torch_seed % np.iinfo(np.int32).max 110 | np.random.seed(numpy_seed) 111 | 112 | 113 | def split_data_loader3D(modalities, labels, indices, batch_size, random_seed=None, valid_size=0.5, shuffle=True, 114 | num_workers=4, pin_memory=False, train_transform=[], valid_transform=[]): 115 | assert ((valid_size >= 0) and (valid_size <= 1)), "[!] valid_size should be in the range [0, 1]." 116 | assert train_transform, "You must provide at least a numpy-to-torch transformation." 117 | assert valid_transform, "You must provide at least a numpy-to-torch transformation." 118 | 119 | # load the dataset 120 | dataset_train = StrokeLindaDataset3D(modalities=modalities, labels=labels, 121 | transform=transforms.Compose(train_transform)) 122 | dataset_valid = StrokeLindaDataset3D(modalities=modalities, labels=labels, 123 | transform=transforms.Compose(valid_transform)) 124 | 125 | items = list(set(range(len(dataset_train))).intersection(set(indices))) 126 | num_train = len(items) 127 | split = int(np.floor(valid_size * num_train)) 128 | 129 | if shuffle == True: 130 | random_state = np.random.RandomState(random_seed) 131 | random_state.shuffle(items) 132 | 133 | train_idx, valid_idx = items[split:], items[:split] 134 | 135 | train_sampler = SubsetRandomSampler(train_idx) 136 | valid_sampler = SubsetRandomSampler(valid_idx) 137 | 138 | train_loader = DataLoader(dataset_train, 139 | batch_size=batch_size, sampler=train_sampler, 140 | num_workers=num_workers, pin_memory=pin_memory, 141 | worker_init_fn=set_np_seed) 142 | 143 | valid_loader = DataLoader(dataset_valid, 144 | batch_size=batch_size, sampler=valid_sampler, 145 | num_workers=num_workers, pin_memory=pin_memory) 146 | 147 | return (train_loader, valid_loader) 148 | 149 | 150 | def single_data_loader3D(modalities, labels, indices, batch_size, random_seed=None, valid_size=0.5, shuffle=True, 151 | num_workers=4, pin_memory=False, train_transform=[]): 152 | assert ((valid_size >= 0) and (valid_size <= 1)), "[!] valid_size should be in the range [0, 1]." 153 | assert train_transform, "You must provide at least a numpy-to-torch transformation." 
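    # Note (added comment): unlike split_data_loader3D above, this loader uses the
    # entire fold for training; valid_size is only range-checked by the assert
    # above and is not used for splitting here.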
154 | 155 | # load the dataset 156 | dataset_train = StrokeLindaDataset3D(modalities=modalities, labels=labels, 157 | transform=transforms.Compose(train_transform)) 158 | 159 | items = list(set(range(len(dataset_train))).intersection(set(indices))) 160 | 161 | if shuffle == True: 162 | random_state = np.random.RandomState(random_seed) 163 | random_state.shuffle(items) 164 | 165 | train_sampler = SubsetRandomSampler(items) 166 | 167 | train_loader = DataLoader(dataset_train, 168 | batch_size=batch_size, sampler=train_sampler, 169 | num_workers=num_workers, pin_memory=pin_memory, 170 | worker_init_fn=set_np_seed) 171 | 172 | return train_loader 173 | 174 | 175 | def get_stroke_shape_training_data(modalities, labels, train_transform, valid_transform, fold_indices, ratio, seed=4, 176 | batchsize=2, split=True): 177 | if split: 178 | return split_data_loader3D(modalities, labels, fold_indices, batchsize, random_seed=seed, 179 | valid_size=ratio, train_transform=train_transform, 180 | valid_transform=valid_transform, num_workers=0) 181 | return single_data_loader3D(modalities, labels, fold_indices, batchsize, random_seed=seed, 182 | valid_size=ratio, train_transform=train_transform, num_workers=0), None 183 | 184 | 185 | def get_stroke_prediction_training_data(modalities, labels, train_transform, valid_transform, fold_indices, ratio, 186 | seed=4, batchsize=2, split=True): 187 | if split: 188 | return split_data_loader3D(modalities, labels, fold_indices, batchsize, random_seed=seed, 189 | valid_size=ratio, train_transform=train_transform, 190 | valid_transform=valid_transform, num_workers=0) 191 | return single_data_loader3D(modalities, labels, fold_indices, batchsize, random_seed=seed, 192 | valid_size=ratio, train_transform=train_transform, num_workers=0), None 193 | 194 | 195 | def get_testdata(modalities, labels, indices, random_seed=None, shuffle=True, num_workers=4, pin_memory=False, 196 | transform=[]): 197 | assert transform, "You must provide at least a numpy-to-torch transformation." 
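    # Example call (added comment; the modality/label suffixes are hypothetical,
    # as the on-disk naming is not shown here, while the indices match the
    # README's test fold):
    #   loader = get_testdata(modalities=['_CBV', '_TTD'],
    #                         labels=['_CBVmap', '_TTDmap', '_FUCTMap'],
    #                         indices=[5, 13, 14, 15, 20, 28],
    #                         transform=[ResamplePlaneXY(0.5), ToTensor()])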
198 | 199 | dataset = StrokeLindaDataset3D(modalities=modalities, labels=labels, transform=transforms.Compose(transform)) 200 | 201 | items = list(set(range(len(dataset))).intersection(set(indices))) 202 | 203 | if shuffle == True: 204 | random_state = np.random.RandomState(random_seed) 205 | random_state.shuffle(items) 206 | 207 | sampler = SubsetRandomSampler(items) 208 | 209 | loader = DataLoader(dataset, batch_size=1, sampler=sampler, num_workers=num_workers, pin_memory=pin_memory, 210 | worker_init_fn=set_np_seed) # important to have batchsize=1 because metrics is computed on batch 211 | 212 | return loader 213 | 214 | 215 | class HemisphericFlipFixedToCaseId(object): 216 | """Flip numpy images along X-axis.""" 217 | 218 | def __init__(self, split_id): 219 | self.split_id = split_id 220 | 221 | def __call__(self, sample): 222 | if int(sample[KEY_CASE_ID]) > self.split_id: 223 | result = emptyCopyFromSample(sample) 224 | if sample[KEY_IMAGES] != []: 225 | result[KEY_IMAGES] = np.flip(sample[KEY_IMAGES], DIM_HORIZONTAL_NUMPY_3D).copy() 226 | if sample[KEY_LABELS] != []: 227 | result[KEY_LABELS] = np.flip(sample[KEY_LABELS], DIM_HORIZONTAL_NUMPY_3D).copy() 228 | if sample[KEY_GLOBAL] != []: 229 | result[KEY_GLOBAL] = np.flip(sample[KEY_GLOBAL], DIM_HORIZONTAL_NUMPY_3D).copy() 230 | return result 231 | return sample 232 | 233 | 234 | class HemisphericFlip(object): 235 | """Flip numpy images along X-axis.""" 236 | def __call__(self, sample): 237 | if random.random() > 0.5: 238 | result = emptyCopyFromSample(sample) 239 | if sample[KEY_IMAGES] != []: 240 | result[KEY_IMAGES] = np.flip(sample[KEY_IMAGES], DIM_HORIZONTAL_NUMPY_3D).copy() 241 | if sample[KEY_LABELS] != []: 242 | result[KEY_LABELS] = np.flip(sample[KEY_LABELS], DIM_HORIZONTAL_NUMPY_3D).copy() 243 | if sample[KEY_GLOBAL] != []: 244 | result[KEY_GLOBAL] = np.flip(sample[KEY_GLOBAL], DIM_HORIZONTAL_NUMPY_3D).copy() 245 | return result 246 | return sample 247 | 248 | 249 | class RandomPatch(object): 250 | """Random patches of certain size.""" 251 | def __init__(self, w, h, d, pad_x, pad_y, pad_z): 252 | self._padx = pad_x 253 | self._pady = pad_y 254 | self._padz = pad_z 255 | self._w = w 256 | self._h = h 257 | self._d = d 258 | 259 | def __call__(self, sample): 260 | sx, sy, sz, _ = sample[KEY_IMAGES].shape 261 | 262 | rand_x = random.randint(0, sx - self._w) 263 | rand_y = random.randint(0, sy - self._h) 264 | rand_z = random.randint(0, sz - self._d) 265 | 266 | result = emptyCopyFromSample(sample) 267 | if sample[KEY_IMAGES] != []: 268 | result[KEY_IMAGES] = sample[KEY_IMAGES][rand_x: rand_x + self._w, 269 | rand_y: rand_y + self._h, 270 | rand_z: rand_z + self._d, :] 271 | if sample[KEY_LABELS] != []: 272 | result[KEY_LABELS] = sample[KEY_LABELS][rand_x: rand_x + self._w - 2 * self._padx, 273 | rand_y: rand_y + self._h - 2 * self._pady, 274 | rand_z: rand_z + self._d - 2 * self._padz, :] 275 | result[KEY_GLOBAL] = sample[KEY_GLOBAL] 276 | 277 | return result 278 | 279 | 280 | class PadImages(object): 281 | """Pad images with constant pad_value in all 6 directions (3D).""" 282 | def __init__(self, pad_x, pad_y, pad_z, pad_value=0): 283 | self._padx = pad_x 284 | self._pady = pad_y 285 | self._padz = pad_z 286 | self._pad_value = float(pad_value) 287 | 288 | def __call__(self, sample): 289 | sx, sy, sz, sc = sample[KEY_IMAGES].shape 290 | result = emptyCopyFromSample(sample) 291 | if sample[KEY_IMAGES] != []: 292 | result[KEY_IMAGES] = np.ones((sx + 2 * self._padx, sy + 2 * self._pady, sz + 2 * self._padz, sc), dtype=np.float32) 
* self._pad_value 293 | result[KEY_IMAGES][self._padx:-self._padx, self._pady:-self._pady, self._padz:-self._padz, :] = sample[KEY_IMAGES] 294 | result[KEY_LABELS] = sample[KEY_LABELS] 295 | result[KEY_GLOBAL] = sample[KEY_GLOBAL] 296 | return result 297 | 298 | 299 | class ToTensor(object): 300 | """Convert ndarrays in sample to Tensors.""" 301 | 302 | def __call__(self, sample): 303 | result = emptyCopyFromSample(sample) 304 | if sample[KEY_IMAGES] != []: 305 | result[KEY_IMAGES] = torch.from_numpy(sample[KEY_IMAGES]).permute(3, 2, 1, 0) 306 | if sample[KEY_LABELS] != []: 307 | result[KEY_LABELS] = torch.from_numpy(sample[KEY_LABELS]).permute(3, 2, 1, 0) 308 | if sample[KEY_GLOBAL] != []: 309 | result[KEY_GLOBAL] = torch.from_numpy(sample[KEY_GLOBAL]).permute(3, 2, 1, 0) 310 | return result 311 | 312 | 313 | class ElasticDeform(object): 314 | """Elastic deformation of images as described in [Simard2003] 315 | Simard, Steinkraus and Platt, "Best Practices for Convolutional 316 | Neural Networks applied to Visual Document Analysis", in Proc. 317 | of the International Conference on Document Analysis and 318 | Recognition, 2003. 319 | """ 320 | 321 | def __init__(self, alpha=100, sigma=4, apply_to_images=False): 322 | self._alpha = alpha 323 | self._sigma = sigma 324 | self._apply_to_images = apply_to_images 325 | 326 | def elastic_transform(self, image, alpha=100, sigma=4, random_state=None): 327 | new_seed = datetime.datetime.now().second + datetime.datetime.now().microsecond 328 | if random_state is None: 329 | random_state = np.random.RandomState(new_seed) 330 | 331 | shape = image.shape 332 | dx = gaussian_filter((random_state.rand(*shape) * 2 - 1), sigma, mode="constant", cval=0) * alpha 333 | dy = gaussian_filter((random_state.rand(*shape) * 2 - 1), sigma, mode="constant", cval=0) * alpha 334 | dz = gaussian_filter((random_state.rand(*shape) * 2 - 1), sigma, mode="constant", cval=0) * alpha * 0.22 # 28/128 TODO: correct according to voxel spacing 335 | 336 | x, y, z = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), np.arange(shape[2])) 337 | indices = np.reshape(y + dy, (-1, 1)), np.reshape(x + dx, (-1, 1)), np.reshape(z + dz, (-1, 1)) 338 | 339 | return map_coordinates(image, indices, order=1).reshape(shape), random_state 340 | 341 | def __call__(self, sample): 342 | sample[KEY_LABELS][:, :, :, 0], random_state = self.elastic_transform(sample[KEY_LABELS][:, :, :, 0], 343 | self._alpha, self._sigma) 344 | for c in range(1, sample[KEY_LABELS].shape[3]): 345 | sample[KEY_LABELS][:, :, :, c], _ = self.elastic_transform(sample[KEY_LABELS][:, :, :, c], self._alpha, 346 | self._sigma, random_state=random_state) 347 | if self._apply_to_images and sample[KEY_IMAGES] != []: 348 | for c in range(sample[KEY_IMAGES].shape[3]): 349 | sample[KEY_IMAGES][:, :, :, c], _ = self.elastic_transform(sample[KEY_IMAGES][:, :, :, c], self._alpha, 350 | self._sigma, random_state=random_state) 351 | return sample 352 | 353 | 354 | class ResamplePlaneXY(object): 355 | """Down- or upsample images.""" 356 | def __init__(self, scale_factor=1, mode='nearest'): 357 | self._scale_factor = scale_factor 358 | if mode == 'bilinear': 359 | self._order = 1 360 | else: 361 | self._order = 0 362 | 363 | def __call__(self, sample): 364 | result = emptyCopyFromSample(sample) 365 | result[KEY_GLOBAL] = sample[KEY_GLOBAL] 366 | 367 | if sample[KEY_IMAGES] != []: 368 | sx, sy = ndi.zoom(sample[KEY_IMAGES][:, :, 0], self._scale_factor, order=0).shape[0:2] # just for init 369 | result[KEY_IMAGES] = sample[KEY_IMAGES][:sx, 
:sy, :, :]  # just to init a correctly sized array; values get overwritten below
370 |             for c in range(sample[KEY_IMAGES].shape[DIM_CHANNEL_NUMPY_3D]):
371 |                 for z in range(sample[KEY_IMAGES].shape[DIM_DEPTH_NUMPY_3D]):
372 |                     result[KEY_IMAGES][:, :, z, c] = ndi.zoom(sample[KEY_IMAGES][:, :, z, c], self._scale_factor, order=self._order)
373 | 
374 |         if sample[KEY_LABELS] != []:
375 |             sx, sy = ndi.zoom(sample[KEY_LABELS][:, :, 0], self._scale_factor, order=0).shape[0:2]  # just for init
376 |             result[KEY_LABELS] = sample[KEY_LABELS][:sx, :sy, :, :]  # just to init a correctly sized array; values get overwritten below
377 |             for c in range(sample[KEY_LABELS].shape[DIM_CHANNEL_NUMPY_3D]):
378 |                 for z in range(sample[KEY_LABELS].shape[DIM_DEPTH_NUMPY_3D]):
379 |                     result[KEY_LABELS][:, :, z, c] = ndi.zoom(sample[KEY_LABELS][:, :, z, c], self._scale_factor, order=self._order)
380 | 
381 |         return result
--------------------------------------------------------------------------------
/common/dto/CaeDto.py:
--------------------------------------------------------------------------------
 1 | from common.dto.Dto import Dto
 2 | 
 3 | FLAG_DEFAULT = 'default'
 4 | FLAG_GTRUTH = 'gtruth'
 5 | FLAG_INPUTS = 'inputs'
 6 | 
 7 | class CaeDto(Dto):
 8 |     """ DTO for CAE usage.
 9 |     """
10 |     def __init__(self, given_variables: Dto, latents: Dto, reconstructions: Dto):
11 |         super().__init__()
12 |         self.given_variables = given_variables
13 |         self.latents = latents
14 |         self.reconstructions = reconstructions
15 |         self.flag = FLAG_DEFAULT
16 | 
17 | 
18 | def init_dto(global_variables, time_to_treatment, type_core, type_penumbra, inputs_core, inputs_penu,
19 |              gtruth_core, gtruth_penumbra, gtruth_lesion):
20 |     """
21 |     Inits a CaeDto with the given variables.
22 |     :param global_variables: global clinical scalar variables, such as age et cetera
23 |     :param time_to_treatment: global clinical scalar variable time_to_treatment
24 |     :param type_core: aux value to represent core
25 |     :param type_penumbra: aux value to represent penumbra
26 |     :param inputs_core: input data for core, can be the CTP CBV image or its segmentation from the Unet
27 |     :param inputs_penu: input data for penu, can be the CTP TTD image or its segmentation from the Unet
28 |     :param gtruth_core: manual segmentation mask for core
29 |     :param gtruth_penumbra: manual segmentation mask for penumbra
30 |     :param gtruth_lesion: manual segmentation mask for follow-up lesion
31 |     :return: CaeDto
32 |     """
33 | 
34 |     given_variables = Dto(globals=global_variables,
35 |                           time_to_treatment=time_to_treatment,
36 |                           scalar_types=Dto(core=type_core, penu=type_penumbra),
37 |                           inputs=Dto(core=inputs_core, penu=inputs_penu),
38 |                           gtruth=Dto(core=gtruth_core, penu=gtruth_penumbra, lesion=gtruth_lesion))
39 | 
40 |     latents = Dto(inputs=Dto(core=None, penu=None, interpolation=None),
41 |                   gtruth=Dto(core=None, penu=None, interpolation=None, lesion=None))
42 | 
43 |     reconstructions = Dto(inputs=Dto(core=None, penu=None, interpolation=None),
44 |                           gtruth=Dto(core=None, penu=None, interpolation=None, lesion=None))
45 | 
46 |     return CaeDto(given_variables=given_variables, latents=latents, reconstructions=reconstructions)
47 | 
--------------------------------------------------------------------------------
/common/dto/Dto.py:
--------------------------------------------------------------------------------
 1 | class Dto():
 2 |     """ Data Transfer Object.
 3 |     Usually not required here, but makes it easier for
 4 |     passing arguments and consistent naming of variables.
 5 |     Allows iterating through its members.
6 | """ 7 | def __init__(self, **kwargs): 8 | self.__dict__ = kwargs 9 | 10 | def __iter__(self): 11 | for attr, value in self.__dict__.items(): 12 | yield attr, value 13 | 14 | def __str__(self, indent=None): 15 | """ 16 | Indicates the fill level, i.e. which attributes 17 | have non-None values 18 | :param indent: indent when print out the result 19 | :return: str representation of the fill level 20 | """ 21 | result = '' 22 | if indent is None: 23 | result += 'Fill level of ' + super().__str__() + ':\n' 24 | indent = '' 25 | for key in sorted(self.__dict__.keys()): 26 | txt = '[ ]' 27 | val = self.__dict__[key] 28 | if val is not None: 29 | txt = '[x]' 30 | result += indent + txt + ' ' + key + '\n' 31 | if isinstance(val, Dto): 32 | result += val.__str__(indent=(indent + ' ')) 33 | return result 34 | 35 | 36 | def _is_empty(self): 37 | for key in sorted(self.__dict__.keys()): 38 | val = self.__dict__[key] 39 | if val is not None: 40 | if isinstance(val, Dto): 41 | val._is_empty() 42 | else: 43 | return False 44 | return True 45 | -------------------------------------------------------------------------------- /common/dto/MetricMeasuresDto.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | from common.dto.Dto import Dto 3 | 4 | 5 | class MeasuresDto(Dto): 6 | def add(self, other): 7 | if isinstance(other, type(self)): 8 | for attr, value in other: 9 | if self.__dict__[attr] is None: 10 | self.__dict__[attr] = value 11 | elif isinstance(value, MeasuresDto): 12 | self.__dict__[attr].add(value) 13 | else: 14 | self.__dict__[attr] += value 15 | else: 16 | raise Exception('A' + str(type(self)) + 'must be added') 17 | 18 | def div(self, divisor): 19 | for attr, value in self: 20 | if value is not None and value != numpy.Inf: 21 | if isinstance(value, MeasuresDto): 22 | self.__dict__[attr].div(divisor) 23 | else: 24 | self.__dict__[attr] = value / divisor 25 | 26 | 27 | class BinaryMeasuresDto(MeasuresDto): 28 | """ DTO for the metric measures on binary images. 29 | """ 30 | def __init__(self, dc, hd, assd, precision, sensitivity, specificity): 31 | super().__init__() 32 | self.dc = dc 33 | self.hd = hd 34 | self.assd = assd 35 | self.precision = precision 36 | self.sensitivity = sensitivity # Recall 37 | self.specificity = specificity 38 | 39 | @property 40 | def prc_euclidean_distance(self): 41 | """Computes the distance to top-right corner (1,1) 42 | supposed to be ideal in the precision recall plot. 43 | Consequently, aim to minimize the distance 44 | :return: euclidean distance to top-right corner 45 | """ 46 | return numpy.sqrt((1-self.precision)**2 + (1-self.sensitivity)**2) 47 | 48 | 49 | class MetricMeasuresDto(MeasuresDto): 50 | """ DTO for all evaluation metric measures. 51 | """ 52 | def __init__(self, loss, core:BinaryMeasuresDto, penu:BinaryMeasuresDto, lesion:BinaryMeasuresDto): 53 | super().__init__() 54 | self.loss = loss 55 | self.core = core 56 | self.penu = penu 57 | self.lesion = lesion 58 | 59 | 60 | def init_dto(loss=None, core_dc=None, core_hd=None, core_assd=None, 61 | penu_dc=None, penu_hd=None, penu_assd=None, 62 | lesion_dc=None, lesion_hd=None, lesion_assd=None, 63 | lesion_precision=None, lesion_sensitivity=None, 64 | lesion_specificity=None): 65 | """ 66 | Inits a MetricMeasuresDto with the evaluation measures. 
67 | :return: MetricMeasuresDto 68 | """ 69 | 70 | core = BinaryMeasuresDto(core_dc, core_hd, core_assd, None, None, None) 71 | penu = BinaryMeasuresDto(penu_dc, penu_hd, penu_assd, None, None, None) 72 | lesion = BinaryMeasuresDto(lesion_dc, lesion_hd, lesion_assd, lesion_precision, lesion_sensitivity, 73 | lesion_specificity) 74 | 75 | return MetricMeasuresDto(loss, core, penu, lesion) 76 | -------------------------------------------------------------------------------- /common/dto/UnetDto.py: -------------------------------------------------------------------------------- 1 | from common.dto.Dto import Dto 2 | 3 | 4 | class UnetDto(Dto): 5 | """ DTO for Unet usage. 6 | """ 7 | def __init__(self, given_variables: Dto, outputs: Dto): 8 | super().__init__() 9 | self.given_variables = given_variables 10 | self.outputs = outputs 11 | 12 | 13 | def init_dto(input_modalities, gtruth_core=None, gtruth_penumbra=None, gtruth_lesion=None): 14 | """ 15 | Inits a UnetDto with the given variables. 16 | :param input_modalities: CTP input modalities 17 | :param gtruth_core: manual segmentation mask for core 18 | :param gtruth_penumbra: manual segmentation mask for penumbra 19 | :param gtruth_lesion: manual segmentation mask for follow-up lesion 20 | :return: UnetDto 21 | """ 22 | 23 | given_variables = Dto(input_modalities=input_modalities, core=gtruth_core, penu=gtruth_penumbra, 24 | lesion=gtruth_lesion) 25 | 26 | outputs = Dto(core=None, penu=None, lesion=None) 27 | 28 | return UnetDto(given_variables=given_variables, outputs=outputs) 29 | -------------------------------------------------------------------------------- /common/inference/CaeEncInference.py: -------------------------------------------------------------------------------- 1 | from common.model.Cae3D import Cae3D, Enc3D 2 | from common.inference.CaeInference import CaeInference 3 | from common.dto.CaeDto import CaeDto 4 | import common.dto.CaeDto as CaeDtoUtil 5 | from torch.autograd import Variable 6 | from common import data 7 | 8 | 9 | class CaeEncInference(CaeInference): 10 | """Common inference for training and testing, 11 | i.e. 
feed-forward of the CAE and the previous Encoder
12 |     """
13 |     def __init__(self, model: Cae3D, new_enc: Enc3D, normalization_hours_penumbra=10):
14 |         CaeInference.__init__(self, model, normalization_hours_penumbra)
15 |         self._new_enc = new_enc
16 | 
17 |     def infer(self, dto: CaeDto):
18 |         pass
19 | 
20 |     def init_unet_segm_variables(self, batch: dict, dto: CaeDto):
21 |         unet_core = Variable(batch[data.KEY_IMAGES][:, 0, :, :, :].unsqueeze(data.DIM_CHANNEL_TORCH3D_5))
22 |         unet_penu = Variable(batch[data.KEY_IMAGES][:, 1, :, :, :].unsqueeze(data.DIM_CHANNEL_TORCH3D_5))
23 |         if self.is_cuda:
24 |             unet_core = unet_core.cuda()
25 |             unet_penu = unet_penu.cuda()
26 |         dto.given_variables.inputs.core = unet_core
27 |         dto.given_variables.inputs.penu = unet_penu
28 |         return dto
29 | 
30 |     def inference_step(self, batch: dict, step=None):
31 |         dto = self.init_clinical_variables(batch, step)
32 | 
33 |         dto.flag = CaeDtoUtil.FLAG_INPUTS
34 |         dto = self.init_unet_segm_variables(batch, dto)
35 |         dto = self._new_enc(dto)
36 |         dto = self._model.dec(dto)
37 | 
38 |         dto.flag = CaeDtoUtil.FLAG_GTRUTH
39 |         dto = self.init_gtruth_segm_variables(batch, dto)
40 |         dto = self._model(dto)
41 | 
42 |         return dto
43 | 
--------------------------------------------------------------------------------
/common/inference/CaeInference.py:
--------------------------------------------------------------------------------
 1 | from common.model.Cae3D import Cae3D
 2 | from common.inference.Inference import Inference
 3 | from torch.autograd import Variable
 4 | import common.dto.CaeDto as CaeDtoUtil
 5 | import torch
 6 | from common import data
 7 | from common.dto.CaeDto import CaeDto
 8 | 
 9 | 
10 | class CaeInference(Inference):
11 |     """Common inference for training and testing,
12 |     i.e. feed-forward of the CAE
13 |     """
14 |     def __init__(self, model: Cae3D, normalization_hours_penumbra=10):
15 |         Inference.__init__(self, model)
16 |         self._normalization_hours_penumbra = normalization_hours_penumbra
17 | 
18 |     def _get_normalization(self, batch):
19 |         to_to_ta = batch[data.KEY_GLOBAL][:, 0, :, :, :].unsqueeze(data.DIM_CHANNEL_TORCH3D_5).type(torch.FloatTensor)
20 |         normalization = torch.ones(to_to_ta.size()[0], 1).type(torch.FloatTensor) * \
21 |                         self._normalization_hours_penumbra - to_to_ta.squeeze().unsqueeze(data.DIM_CHANNEL_TORCH3D_5)
22 |         return normalization
23 | 
24 |     def get_time_to_treatment(self, batch, global_variables, step):
25 |         normalization = self._get_normalization(batch)
26 |         if step is None:
27 |             ta_to_tr = batch[data.KEY_GLOBAL][:, 1, :, :, :].squeeze().unsqueeze(data.DIM_CHANNEL_TORCH3D_5)
28 |             time_to_treatment = Variable(ta_to_tr.type(torch.FloatTensor) / normalization)
29 |         else:
30 |             time_to_treatment = Variable((step * torch.ones(global_variables.size()[0], 1)) / normalization)
31 |         return time_to_treatment.unsqueeze(2).unsqueeze(3).unsqueeze(4)
32 | 
33 |     def init_clinical_variables(self, batch: dict, step):
34 |         globals_incl_time = Variable(batch[data.KEY_GLOBAL].type(torch.FloatTensor))
35 |         type_core = Variable(torch.zeros(globals_incl_time.size()[0], 1, 1, 1, 1))
36 |         type_penumbra = Variable(torch.ones(globals_incl_time.size()[0], 1, 1, 1, 1))
37 |         time_to_treatment = self.get_time_to_treatment(batch, globals_incl_time, step)
38 | 
39 |         if self.is_cuda:
40 |             if time_to_treatment is not None:
41 |                 time_to_treatment = time_to_treatment.cuda()
42 |             globals_incl_time = globals_incl_time.cuda()
43 |             type_core = type_core.cuda()
44 |             type_penumbra = type_penumbra.cuda()
45 | 
46 |         return CaeDtoUtil.init_dto(globals_incl_time, time_to_treatment,
47 |                                    type_core, type_penumbra, None, None, None, None, None)
48 | 
49 |     def init_gtruth_segm_variables(self, batch: dict, dto: CaeDto):
50 |         core_gt = Variable(batch[data.KEY_LABELS][:, 0, :, :, :].unsqueeze(data.DIM_CHANNEL_TORCH3D_5))
51 |         penu_gt = Variable(batch[data.KEY_LABELS][:, 1, :, :, :].unsqueeze(data.DIM_CHANNEL_TORCH3D_5))
52 |         lesion_gt = Variable(batch[data.KEY_LABELS][:, 2, :, :, :].unsqueeze(data.DIM_CHANNEL_TORCH3D_5))
53 |         if self.is_cuda:
54 |             core_gt = core_gt.cuda()
55 |             penu_gt = penu_gt.cuda()
56 |             lesion_gt = lesion_gt.cuda()
57 |         dto.given_variables.gtruth.core = core_gt
58 |         dto.given_variables.gtruth.penu = penu_gt
59 |         dto.given_variables.gtruth.lesion = lesion_gt
60 |         return dto
61 | 
62 |     def infer(self, dto: CaeDto):
63 |         return self._model(dto)
64 | 
65 |     def inference_step(self, batch: dict, step=None):
66 |         dto = self.init_clinical_variables(batch, step)
67 |         dto.flag = CaeDtoUtil.FLAG_GTRUTH
68 |         dto = self.init_gtruth_segm_variables(batch, dto)
69 |         return self.infer(dto)
70 | 
--------------------------------------------------------------------------------
/common/inference/Inference.py:
--------------------------------------------------------------------------------
 1 | from abc import abstractmethod
 2 | 
 3 | 
 4 | class Inference():
 5 |     """Base class for all classes that require model inference.
 6 |     """
 7 |     IMSHOW_VMAX_CBV = 12
 8 |     IMSHOW_VMAX_TTD = 40
 9 |     FN_VIS_BASE = '_visual_'
10 |     INFERENCE_INITIALIZED = False
11 | 
12 |     @abstractmethod
13 |     def __init__(self, model):
14 |         if not self.INFERENCE_INITIALIZED:
15 |             self._model = model
16 |             self.INFERENCE_INITIALIZED = True
17 | 
18 |     @abstractmethod
19 |     def inference_step(self, batch: dict):
20 |         pass
21 | 
22 |     @property
23 |     def is_cuda(self) -> bool:
24 |         return next(self._model.parameters()).is_cuda
25 | 
--------------------------------------------------------------------------------
/common/inference/UnetInference.py:
--------------------------------------------------------------------------------
 1 | from common.model.Unet3D import Unet3D
 2 | from common.inference.Inference import Inference
 3 | from torch.autograd import Variable
 4 | import common.dto.UnetDto as UnetDtoUtil
 5 | from common import data
 6 | 
 7 | 
 8 | class UnetInference(Inference):
 9 |     """Common inference for training and testing,
10 |     i.e.
feed-forward of Unet 11 | """ 12 | def __init__(self, model:Unet3D): 13 | Inference.__init__(self, model) 14 | 15 | def inference_step(self, batch): 16 | input_modalities = Variable(batch[data.KEY_IMAGES]) 17 | core_gt = Variable(batch[data.KEY_LABELS][:, 0, :, :, :].unsqueeze(data.DIM_CHANNEL_TORCH3D_5)) 18 | penu_gt = Variable(batch[data.KEY_LABELS][:, 1, :, :, :].unsqueeze(data.DIM_CHANNEL_TORCH3D_5)) 19 | 20 | if self.is_cuda: 21 | input_modalities = input_modalities.cuda() 22 | core_gt = core_gt.cuda() 23 | penu_gt = penu_gt.cuda() 24 | 25 | dto = UnetDtoUtil.init_dto(input_modalities, core_gt, penu_gt) 26 | 27 | return self._model(dto) -------------------------------------------------------------------------------- /common/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | import medpy.metric.binary as mpm 3 | from common.dto.MetricMeasuresDto import BinaryMeasuresDto 4 | from torch.nn.modules.loss import _Loss as LossModule 5 | from torch.autograd import Variable 6 | 7 | 8 | class BatchDiceLoss(LossModule): 9 | def __init__(self, label_weights, epsilon=0.0000001, dim=1): 10 | super(BatchDiceLoss, self).__init__() 11 | self._epsilon = epsilon 12 | self._dim = dim 13 | self._label_weights = label_weights 14 | print("DICE Loss weights classes' output by", label_weights) 15 | 16 | def forward(self, outputs, targets): 17 | assert targets.shape[self._dim] == len(self._label_weights), \ 18 | 'Ground truth number of labels does not match with label weight vector' 19 | loss = 0.0 20 | for label in range(len(self._label_weights)): 21 | oflat = outputs.narrow(self._dim, label, 1).contiguous().view(-1) 22 | tflat = targets.narrow(self._dim, label, 1).contiguous().view(-1) 23 | assert oflat.size() == tflat.size() 24 | intersection = (oflat * tflat).sum() 25 | numerator = 2.*intersection + self._epsilon 26 | denominator = (oflat * oflat).sum() + (tflat * tflat).sum() + self._epsilon 27 | loss += self._label_weights[label] * (numerator / denominator) 28 | return 1.0 - loss 29 | 30 | 31 | def binary_measures_numpy(result, target, binary_threshold=0.5): 32 | result_binary = (result > binary_threshold).astype(numpy.uint8) 33 | target_binary = (target > binary_threshold).astype(numpy.uint8) 34 | 35 | result = BinaryMeasuresDto(mpm.dc(result_binary, target_binary), 36 | numpy.Inf, 37 | numpy.Inf, 38 | mpm.precision(result_binary, target_binary), 39 | mpm.sensitivity(result_binary, target_binary), 40 | mpm.specificity(result_binary, target_binary)) 41 | 42 | if result_binary.any() and target_binary.any(): 43 | result.hd = mpm.hd(result_binary, target_binary) 44 | result.assd = mpm.assd(result_binary, target_binary) 45 | 46 | return result 47 | 48 | 49 | def binary_measures_torch(result, target, cuda, binary_threshold=0.5): 50 | if cuda: 51 | result = result.cpu() 52 | target = target.cpu() 53 | 54 | if isinstance(result, Variable): 55 | result = result.data 56 | if isinstance(target, Variable): 57 | target = target.data 58 | 59 | result = result.numpy() 60 | target = target.numpy() 61 | 62 | return binary_measures_numpy(result, target, binary_threshold=binary_threshold) -------------------------------------------------------------------------------- /common/model/Cae3D.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import common.data as data 4 | from common.dto.CaeDto import CaeDto 5 | import common.dto.CaeDto as CaeDtoUtil 6 | 7 | 8 | class 
CaeBase(nn.Module): 9 | 10 | def __init__(self, size_input_xy=128, size_input_z=28, channels=[1, 16, 32, 64, 128, 1024, 128, 1], n_ch_global=2, 11 | alpha=0.01, inner_xy=12, inner_z=3): 12 | super().__init__() 13 | assert size_input_xy % 4 == 0 and size_input_z % 4 == 0 14 | self.n_ch_origin = channels[1] 15 | self.n_ch_down2x = channels[2] 16 | self.n_ch_down4x = channels[3] 17 | self.n_ch_down8x = channels[4] 18 | self.n_ch_fc = channels[5] 19 | 20 | self._inner_ch = self.n_ch_down8x 21 | self._inner_xy = inner_xy 22 | self._inner_z = inner_z 23 | 24 | self.n_ch_global = n_ch_global 25 | self.n_input = channels[0] 26 | self.n_classes = channels[-1] 27 | self.alpha = alpha 28 | 29 | def freeze(self, freeze=False): 30 | requires_grad = not freeze 31 | for param in self.parameters(): 32 | param.requires_grad = requires_grad 33 | 34 | 35 | class Enc3D(CaeBase): 36 | def __init__(self, size_input_xy, size_input_z, channels, n_ch_global, alpha): 37 | super().__init__(size_input_xy, size_input_z, channels, n_ch_global, alpha, inner_xy=10, inner_z=3) 38 | 39 | self.encoder = nn.Sequential( 40 | nn.BatchNorm3d(self.n_input), 41 | nn.Conv3d(self.n_input, self.n_ch_origin, 3, stride=1, padding=(1, 0, 0)), 42 | nn.ELU(self.alpha, True), 43 | nn.BatchNorm3d(self.n_ch_origin), 44 | nn.Conv3d(self.n_ch_origin, self.n_ch_origin, 3, stride=1, padding=(1, 0, 0)), 45 | nn.ELU(self.alpha, True), 46 | 47 | nn.BatchNorm3d(self.n_ch_origin), 48 | nn.Conv3d(self.n_ch_origin, self.n_ch_down2x, 3, stride=2, padding=1), 49 | nn.ELU(self.alpha, True), 50 | 51 | nn.BatchNorm3d(self.n_ch_down2x), 52 | nn.Conv3d(self.n_ch_down2x, self.n_ch_down2x, 3, stride=1, padding=(1, 0, 0)), 53 | nn.ELU(self.alpha, True), 54 | nn.BatchNorm3d(self.n_ch_down2x), 55 | nn.Conv3d(self.n_ch_down2x, self.n_ch_down2x, 3, stride=1, padding=(1, 0, 0)), 56 | nn.ELU(self.alpha, True), 57 | 58 | nn.BatchNorm3d(self.n_ch_down2x), 59 | nn.Conv3d(self.n_ch_down2x, self.n_ch_down4x, 3, stride=2, padding=1), 60 | nn.ELU(self.alpha, True), 61 | 62 | nn.BatchNorm3d(self.n_ch_down4x), 63 | nn.Conv3d(self.n_ch_down4x, self.n_ch_down4x, 3, stride=1, padding=(1, 0, 0)), 64 | nn.ELU(self.alpha, True), 65 | nn.BatchNorm3d(self.n_ch_down4x), 66 | nn.Conv3d(self.n_ch_down4x, self.n_ch_down4x, 3, stride=1, padding=(1, 0, 0)), 67 | nn.ELU(self.alpha, True), 68 | 69 | nn.BatchNorm3d(self.n_ch_down4x), 70 | nn.Conv3d(self.n_ch_down4x, self.n_ch_down8x, 3, stride=2, padding=0), 71 | nn.ELU(self.alpha, True), 72 | 73 | nn.BatchNorm3d(self.n_ch_down8x), 74 | nn.Conv3d(self.n_ch_down8x, self.n_ch_fc, 3, stride=1, padding=0), 75 | nn.ELU(self.alpha, True), 76 | ) 77 | 78 | def _interpolate(self, latent_core, latent_penu, step): 79 | assert step is not None, 'Step must be given for interpolation!' 
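        # Note (added comment): the code below linearly interpolates in latent
        # space between the core and penumbra codes,
        #     z(step) = z_core + step * (z_penu - z_core),
        # where step is the normalized time-to-treatment, so step=0 reproduces
        # the core representation and step=1 the penumbra representation
        # (computed separately for each batch sample).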
80 | if latent_core is None or latent_penu is None: 81 | return None 82 | core_to_penumbra = latent_penu - latent_core 83 | results = [] 84 | for batch_sample in range(step.size()[0]): 85 | results.append( 86 | (latent_core[batch_sample, :, :, :, :] + 87 | step[batch_sample, :, :, :, :] * core_to_penumbra[batch_sample, :, :, :, :]).unsqueeze(0) 88 | ) 89 | return torch.cat(results, dim=0) 90 | 91 | def _forward_single(self, input_image): 92 | if input_image is None: 93 | return None 94 | return self.encoder(input_image) 95 | 96 | def _get_step(self, dto: CaeDto): 97 | step = dto.given_variables.time_to_treatment 98 | return step 99 | 100 | def forward(self, dto: CaeDto): 101 | step = self._get_step(dto) 102 | 103 | if dto.flag == CaeDtoUtil.FLAG_GTRUTH or dto.flag == CaeDtoUtil.FLAG_DEFAULT: 104 | assert dto.latents.gtruth._is_empty() # Don't accidentally overwrite other results by code mistakes 105 | dto.latents.gtruth.core = self._forward_single(dto.given_variables.gtruth.core) 106 | dto.latents.gtruth.penu = self._forward_single(dto.given_variables.gtruth.penu) 107 | dto.latents.gtruth.lesion = self._forward_single(dto.given_variables.gtruth.lesion) 108 | dto.latents.gtruth.interpolation = self._interpolate(dto.latents.gtruth.core, 109 | dto.latents.gtruth.penu, 110 | step) 111 | if dto.flag == CaeDtoUtil.FLAG_INPUTS or dto.flag == CaeDtoUtil.FLAG_DEFAULT: 112 | assert dto.latents.inputs._is_empty() # Don't accidentally overwrite other results by code mistakes 113 | dto.latents.inputs.core = self._forward_single(dto.given_variables.inputs.core) 114 | dto.latents.inputs.penu = self._forward_single(dto.given_variables.inputs.penu) 115 | dto.latents.inputs.interpolation = self._interpolate(dto.latents.inputs.core, 116 | dto.latents.inputs.penu, 117 | step) 118 | return dto 119 | 120 | 121 | class Enc3DStep(Enc3D): 122 | def __init__(self, size_input_xy, size_input_z, channels, n_ch_global, alpha): 123 | super().__init__(size_input_xy, size_input_z, channels, n_ch_global, alpha) 124 | 125 | self.reduce = nn.Sequential( 126 | nn.Conv3d(self.n_ch_global, self.n_ch_global, 1), 127 | nn.ELU(self.alpha, True), 128 | nn.Conv3d(self.n_ch_global, self.n_ch_global // 2, 1), 129 | nn.ELU(self.alpha, True), 130 | ) 131 | 132 | self.step = nn.Conv3d(self.n_ch_global // 2, 1, 1) 133 | torch.nn.init.normal(self.step.weight, 0, 0.001) # crucial and important! 134 | torch.nn.init.normal(self.step.bias, 0.5, 0.01) # crucial and important! 
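        # Interpretation (added comment): with near-zero weights and a bias
        # around 0.5, the learned step initially stays close to
        # sigmoid(0.5) ~ 0.62 for all cases, i.e. in the non-saturated middle
        # of the [0, 1] range of the sigmoid applied in _get_step below.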
135 | 136 | self.sigmoid = nn.Sigmoid() # slows down learning, but ensures [0,1] range and adds another non-linearity 137 | 138 | def _get_step(self, dto: CaeDto): 139 | step = dto.given_variables.time_to_treatment 140 | if step is None: 141 | step = self.sigmoid(self.step(self.reduce(dto.given_variables.globals))) 142 | return step 143 | 144 | 145 | class Enc3DCtp(Enc3D): 146 | def __init__(self, size_input_xy, size_input_z, channels, n_ch_global, alpha, padding): 147 | Enc3D.__init__(self, size_input_xy, size_input_z, channels, n_ch_global, alpha) 148 | assert channels[0] > 2, 'At least 3 channels required to process input' 149 | self._padding = padding 150 | 151 | def forward(self, dto: CaeDto): 152 | step = self._get_step(dto) 153 | cbv = dto.given_variables.inputs.core[:, :, self._padding[0]:-self._padding[0], 154 | self._padding[1]:-self._padding[1], 155 | self._padding[2]:-self._padding[2]] 156 | ttd = dto.given_variables.inputs.penu[:, :, self._padding[0]:-self._padding[0], 157 | self._padding[1]:-self._padding[1], 158 | self._padding[2]:-self._padding[2]] 159 | if dto.flag == CaeDtoUtil.FLAG_GTRUTH or dto.flag == CaeDtoUtil.FLAG_DEFAULT: 160 | cat_core = torch.cat((dto.given_variables.gtruth.core, cbv, ttd), dim=data.DIM_CHANNEL_TORCH3D_5) 161 | cat_penu = torch.cat((dto.given_variables.gtruth.penu, cbv, ttd), dim=data.DIM_CHANNEL_TORCH3D_5) 162 | cat_lesion = torch.cat((dto.given_variables.gtruth.lesion, cbv, ttd), dim=data.DIM_CHANNEL_TORCH3D_5) 163 | dto.latents.gtruth.core = self._forward_single(cat_core) 164 | dto.latents.gtruth.penu = self._forward_single(cat_penu) 165 | dto.latents.gtruth.lesion = self._forward_single(cat_lesion) 166 | dto.latents.gtruth.interpolation = self._interpolate(dto.latents.gtruth.core, 167 | dto.latents.gtruth.penu, 168 | step) 169 | return dto 170 | 171 | 172 | class Dec3D(CaeBase): 173 | def __init__(self, size_input_xy, size_input_z, channels, n_ch_global, alpha): 174 | super().__init__(size_input_xy, size_input_z, channels, n_ch_global, alpha, inner_xy=10, inner_z=3) 175 | 176 | self.decoder = nn.Sequential( 177 | nn.BatchNorm3d(self.n_ch_fc), 178 | nn.ConvTranspose3d(self.n_ch_fc, self.n_ch_down8x, 3, stride=1, padding=0, output_padding=0), 179 | nn.ELU(alpha, True), 180 | 181 | nn.BatchNorm3d(self.n_ch_down8x), 182 | nn.ConvTranspose3d(self.n_ch_down8x, self.n_ch_down4x, 3, stride=2, padding=0, output_padding=0), 183 | nn.ELU(alpha, True), 184 | 185 | nn.BatchNorm3d(self.n_ch_down4x), 186 | nn.Conv3d(self.n_ch_down4x, self.n_ch_down4x, 3, stride=1, padding=(1, 2, 2)), 187 | nn.ELU(alpha, True), 188 | nn.BatchNorm3d(self.n_ch_down4x), 189 | nn.Conv3d(self.n_ch_down4x, self.n_ch_down2x, 3, stride=1, padding=(1, 2, 2)), 190 | nn.ELU(alpha, True), 191 | 192 | nn.BatchNorm3d(self.n_ch_down2x), 193 | nn.ConvTranspose3d(self.n_ch_down2x, self.n_ch_down2x, 2, stride=2, padding=0, output_padding=0), 194 | nn.ELU(alpha, True), 195 | 196 | nn.BatchNorm3d(self.n_ch_down2x), 197 | nn.Conv3d(self.n_ch_down2x, self.n_ch_down2x, 3, stride=1, padding=(1, 2, 2)), 198 | nn.ELU(alpha, True), 199 | nn.BatchNorm3d(self.n_ch_down2x), 200 | nn.Conv3d(self.n_ch_down2x, self.n_ch_origin, 3, stride=1, padding=(1, 2, 2)), 201 | nn.ELU(alpha, True), 202 | 203 | nn.BatchNorm3d(self.n_ch_origin), 204 | nn.ConvTranspose3d(self.n_ch_origin, self.n_ch_origin, 2, stride=2, padding=0, output_padding=0), 205 | nn.ELU(alpha, True), 206 | 207 | nn.BatchNorm3d(self.n_ch_origin), 208 | nn.Conv3d(self.n_ch_origin, self.n_ch_origin, 3, stride=1, padding=(1, 2, 2)), 209 | 
nn.ELU(alpha, True), 210 | nn.BatchNorm3d(self.n_ch_origin), 211 | nn.Conv3d(self.n_ch_origin, self.n_ch_origin, 3, stride=1, padding=(1, 2, 2)), 212 | nn.ELU(alpha, True), 213 | 214 | nn.BatchNorm3d(self.n_ch_origin), 215 | nn.Conv3d(self.n_ch_origin, self.n_ch_origin, 1, stride=1, padding=0), 216 | nn.ELU(alpha, True), 217 | nn.BatchNorm3d(self.n_ch_origin), 218 | nn.Conv3d(self.n_ch_origin, self.n_classes, 1, stride=1, padding=0), 219 | nn.Sigmoid() 220 | ) 221 | 222 | def _forward_single(self, input_latent): 223 | if input_latent is None: 224 | return None 225 | return self.decoder(input_latent) 226 | 227 | def forward(self, dto: CaeDto): 228 | if dto.flag == CaeDtoUtil.FLAG_GTRUTH or dto.flag == CaeDtoUtil.FLAG_DEFAULT: 229 | assert dto.reconstructions.gtruth._is_empty() # Don't accidentally overwrite other results by code mistakes 230 | dto.reconstructions.gtruth.core = self._forward_single(dto.latents.gtruth.core) 231 | dto.reconstructions.gtruth.penu = self._forward_single(dto.latents.gtruth.penu) 232 | dto.reconstructions.gtruth.lesion = self._forward_single(dto.latents.gtruth.lesion) 233 | dto.reconstructions.gtruth.interpolation = self._forward_single(dto.latents.gtruth.interpolation) 234 | if dto.flag == CaeDtoUtil.FLAG_INPUTS or dto.flag == CaeDtoUtil.FLAG_DEFAULT: 235 | assert dto.reconstructions.inputs._is_empty() # Don't accidentally overwrite other results by code mistakes 236 | dto.reconstructions.inputs.core = self._forward_single(dto.latents.inputs.core) 237 | dto.reconstructions.inputs.penu = self._forward_single(dto.latents.inputs.penu) 238 | dto.reconstructions.inputs.interpolation = self._forward_single(dto.latents.inputs.interpolation) 239 | return dto 240 | 241 | 242 | class Cae3D(nn.Module): 243 | def __init__(self, enc: Enc3D, dec: Dec3D): 244 | super().__init__() 245 | self.enc = enc 246 | self.dec = dec 247 | 248 | def forward(self, dto: CaeDto): 249 | dto = self.enc(dto) 250 | dto = self.dec(dto) 251 | return dto 252 | 253 | def freeze(self, freeze: bool): 254 | self.enc.freeze(freeze) 255 | self.dec.freeze(freeze) 256 | 257 | 258 | class Cae3DCtp(Cae3D): 259 | def __init__(self, enc: Enc3DCtp, dec: Dec3D): 260 | Cae3D.__init__(self, enc, dec) 261 | -------------------------------------------------------------------------------- /common/model/Unet3D.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from common.dto.UnetDto import UnetDto 4 | 5 | 6 | def crop(tensor_in, crop_as, dims=[]): 7 | assert len(dims) > 0, "Specify dimensions to be cropped" 8 | result = tensor_in 9 | for dim in dims: 10 | result = result.narrow(dim, (tensor_in.size()[dim] - crop_as.size()[dim]) // 2, crop_as.size()[dim]) 11 | return result 12 | 13 | 14 | class Block3x3x3(nn.Module): 15 | def __init__(self, n_input, n_channels): 16 | super(Block3x3x3, self).__init__() 17 | self.bn_conv_relu_2x = nn.Sequential( 18 | nn.BatchNorm3d(n_input), 19 | nn.Conv3d(n_input, n_channels, 3, stride=1, padding=0), 20 | nn.LeakyReLU(0.01, True), 21 | nn.BatchNorm3d(n_channels), 22 | nn.Conv3d(n_channels, n_channels, 3, stride=1, padding=0), 23 | nn.LeakyReLU(0.01, True) 24 | ) 25 | 26 | def forward(self, input_maps): 27 | return self.bn_conv_relu_2x(input_maps) 28 | 29 | 30 | class Unet3D(nn.Module): 31 | def __init__(self, channels=[2, 32, 64, 128, 64, 32, 32, 2], channel_dim=1, channels_crop=[2,3,4]): 32 | super(Unet3D, self).__init__() 33 | n_ch_in, ch_b1, ch_b2, ch_b3, ch_b4, ch_b5, ch_bC, n_classes = channels 34 | 35 
| self.channel_dim = channel_dim
36 |         self.channels_crop = channels_crop
37 | 
38 |         self.block1 = Block3x3x3(n_ch_in, ch_b1)
39 |         self.pool12 = nn.MaxPool3d(2, 2)
40 |         self.block2 = Block3x3x3(ch_b1, ch_b2)
41 |         self.pool23 = nn.MaxPool3d(2, 2)
42 |         self.block3 = Block3x3x3(ch_b2, ch_b3)
43 | 
44 |         self.upsa34 = nn.Upsample(scale_factor=2, mode='trilinear')
45 |         self.block4 = Block3x3x3(ch_b3 + ch_b2, ch_b4)
46 |         self.upsa45 = nn.Upsample(scale_factor=2, mode='trilinear')
47 |         self.block5 = Block3x3x3(ch_b4 + ch_b1, ch_b5)
48 | 
49 |         self.classify = nn.Sequential(
50 |             nn.Conv3d(ch_b5, ch_bC, 1, stride=1, padding=0),
51 |             nn.LeakyReLU(0.01, True),
52 |             nn.Conv3d(ch_bC, n_classes, 1, stride=1, padding=0),
53 |             nn.Sigmoid()
54 |         )
55 | 
56 |     def forward(self, dto: UnetDto):
57 |         block1_result = self.block1(dto.given_variables.input_modalities)
58 | 
59 |         block2_input = self.pool12(block1_result)
60 |         block2_result = self.block2(block2_input)
61 | 
62 |         block3_input = self.pool23(block2_result)
63 |         block3_result = self.block3(block3_input)
64 |         block3_unpool = self.upsa34(block3_result)
65 | 
66 |         block2_crop = crop(block2_result, block3_unpool, dims=self.channels_crop)
67 |         block4_input = torch.cat((block3_unpool, block2_crop), dim=self.channel_dim)
68 |         block4_result = self.block4(block4_input)
69 |         block4_unpool = self.upsa45(block4_result)
70 | 
71 |         block1_crop = crop(block1_result, block4_unpool, dims=self.channels_crop)
72 |         block5_input = torch.cat((block4_unpool, block1_crop), dim=self.channel_dim)
73 |         block5_result = self.block5(block5_input)
74 | 
75 |         segmentation = self.classify(block5_result)
76 |         dto.outputs.core = segmentation[:, 0, :, :, :].unsqueeze(1)
77 |         dto.outputs.penu = segmentation[:, 1, :, :, :].unsqueeze(1)
78 | 
79 |         return dto
80 | 
81 |     def freeze(self, freeze=False):
82 |         requires_grad = not freeze
83 |         for param in self.parameters():
84 |             param.requires_grad = requires_grad
85 | 
86 | 
87 | class LargeUnet3D(nn.Module):
88 |     def __init__(self, channels=[2, 32, 64, 128, 256, 128, 64, 32, 32, 2], channel_dim=1, channels_crop=[2,3,4]):
89 |         super(LargeUnet3D, self).__init__()
90 |         n_ch_in, ch_b1, ch_b2, ch_b3, ch_b4, ch_b5, ch_b6, ch_b7, ch_bC, n_classes = channels
91 | 
92 |         self.channel_dim = channel_dim
93 |         self.channels_crop = channels_crop
94 | 
95 |         self.block1 = Block3x3x3(n_ch_in, ch_b1)
96 |         self.pool12 = nn.MaxPool3d(2, 2)
97 |         self.block2 = Block3x3x3(ch_b1, ch_b2)
98 |         self.pool23 = nn.MaxPool3d(2, 2)
99 |         self.block3 = Block3x3x3(ch_b2, ch_b3)
100 |         self.pool34 = nn.MaxPool3d(2, 2)
101 |         self.block4 = Block3x3x3(ch_b3, ch_b4)
102 | 
103 |         self.upsa45 = nn.Upsample(scale_factor=2, mode='trilinear')
104 |         self.block5 = Block3x3x3(ch_b4 + ch_b3, ch_b5)
105 |         self.upsa56 = nn.Upsample(scale_factor=2, mode='trilinear')
106 |         self.block6 = Block3x3x3(ch_b5 + ch_b2, ch_b6)
107 |         self.upsa67 = nn.Upsample(scale_factor=2, mode='trilinear')
108 |         self.block7 = Block3x3x3(ch_b6 + ch_b1, ch_b7)
109 | 
110 |         self.classify = nn.Sequential(
111 |             nn.Conv3d(ch_b7, ch_bC, 1, stride=1, padding=0),
112 |             nn.LeakyReLU(0.01, True),
113 |             nn.Conv3d(ch_bC, n_classes, 1, stride=1, padding=0),
114 |             nn.Sigmoid()
115 |         )
116 | 
117 |     def forward(self, dto: UnetDto):
118 |         block1_result = self.block1(dto.given_variables.input_modalities)
119 | 
120 |         block2_input = self.pool12(block1_result)
121 |         block2_result = self.block2(block2_input)
122 | 
123 |         block3_input = self.pool23(block2_result)
124 |         block3_result = self.block3(block3_input)
125 | 
126 |         block4_input = self.pool34(block3_result)
127 | 
block4_result = self.block4(block4_input) 128 | block4_unpool = self.upsa45(block4_result) 129 | 130 | block3_crop = crop(block3_result, block4_unpool, dims=self.channels_crop) 131 | block5_input = torch.cat((block4_unpool, block3_crop), dim=self.channel_dim) 132 | block5_result = self.block5(block5_input) 133 | block5_unpool = self.upsa56(block5_result) 134 | 135 | block2_crop = crop(block2_result, block5_unpool, dims=self.channels_crop) 136 | block6_input = torch.cat((block5_unpool, block2_crop), dim=self.channel_dim) 137 | block6_result = self.block6(block6_input) 138 | block6_unpool = self.upsa67(block6_result) 139 | 140 | block1_crop = crop(block1_result, block6_unpool, dims=self.channels_crop) 141 | block7_input = torch.cat((block6_unpool, block1_crop), dim=self.channel_dim) 142 | block7_result = self.block7(block7_input) 143 | 144 | segmentation = self.classify(block7_result) 145 | dto.outputs.core = segmentation[:, 0, :, :, :].unsqueeze(1) 146 | dto.outputs.penu = segmentation[:, 1, :, :, :].unsqueeze(1) 147 | 148 | return dto -------------------------------------------------------------------------------- /common/util.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from common import data 3 | 4 | 5 | # ======================= DETERMINISTIC DATA =========================== 6 | 7 | 8 | def get_vis_samples(train_loader, valid_loader): 9 | n_vis_samples = 6 10 | visual_samples = [] 11 | visual_times = [] 12 | for i in train_loader.sampler.indices: 13 | sample = train_loader.dataset[i] 14 | sample[data.KEY_IMAGES] = sample[data.KEY_IMAGES].unsqueeze(0) 15 | sample[data.KEY_LABELS] = sample[data.KEY_LABELS].unsqueeze(0) 16 | sample[data.KEY_GLOBAL] = sample[data.KEY_GLOBAL].unsqueeze(0) 17 | visual_samples.append(sample) 18 | tA_to_tR_tmp = sample[data.KEY_GLOBAL][0, 1, :, :, :] 19 | visual_times.append(float(tA_to_tR_tmp)) 20 | if len(visual_samples) > n_vis_samples / 2 - 1: 21 | break; 22 | if valid_loader is not None: 23 | for i in valid_loader.sampler.indices: 24 | sample = valid_loader.dataset[i] 25 | sample[data.KEY_IMAGES] = sample[data.KEY_IMAGES].unsqueeze(0) 26 | sample[data.KEY_LABELS] = sample[data.KEY_LABELS].unsqueeze(0) 27 | sample[data.KEY_GLOBAL] = sample[data.KEY_GLOBAL].unsqueeze(0) 28 | visual_samples.append(sample) 29 | tA_to_tR_tmp = sample[data.KEY_GLOBAL][0, 1, :, :, :] 30 | visual_times.append(float(tA_to_tR_tmp)) 31 | if len(visual_samples) > n_vis_samples - 1: 32 | break; 33 | 34 | return visual_samples, visual_times 35 | 36 | 37 | # =================================== PARSER =========================== 38 | 39 | 40 | class ExpParser(argparse.ArgumentParser): 41 | def __init__(self): 42 | super().__init__() 43 | self.add_argument('--fold', type=int, nargs='+', help='Fold case indices', 44 | default=list(range(29))) # Internal indices, NOT case numbers on disk) 45 | self.add_argument('--hemisflipid', type=float, help='Case id or greater, at which hemispheric flip is applied', 46 | default=15) 47 | self.add_argument('--validsetsize', type=float, help='Fraction of validation set size', default=0.5) 48 | self.add_argument('--seed', type=int, help='Seed for any randomization', default=4) 49 | self.add_argument('--xyoriginal', type=int, help='Original size of slices', default=256) 50 | self.add_argument('--xyresample', type=int, help='Factor for resampling slices', default=0.5) 51 | self.add_argument('--zsize', type=int, help='Number of z slices', default=28) 52 | self.add_argument('--padding', type=int, 
nargs='+', help='Padding of patches', default=[20, 20, 20]) 53 | self.add_argument('--lrsteps', type=int, nargs='+', help='MultiStepLR epochs', default=[]) 54 | 55 | def parse_args(self, args=None, namespace=None): 56 | args = super().parse_args(args, namespace) 57 | print(args) 58 | return args 59 | 60 | 61 | class CAEParser(ExpParser): 62 | def __init__(self): 63 | super().__init__() 64 | self.add_argument('--epochs', type=int, help='Number of epochs', default=300) 65 | self.add_argument('--batchsize', type=int, help='Batch size', default=4) 66 | self.add_argument('--globals', type=int, help='Number of global variables', default=5) 67 | self.add_argument('--normalize', type=int, help='Normalization corresponding to penumbra (hours)', default=10) 68 | self.add_argument('--inbasepath', type=str, help='Path and filename base for loading', default=None) 69 | self.add_argument('--outbasepath', type=str, help='Path and filename base for saving', default='/tmp/tmp_out') 70 | self.add_argument('--steplearning', action='store_true', help='Also learn interpolation step from clinical data', default=False) 71 | 72 | 73 | class UnetParser(ExpParser): 74 | def __init__(self): 75 | super().__init__() 76 | self.add_argument('unetpath', type=str, help='Path to model of Unet', 77 | default='/share/data_zoe1/Linda_Segmentations/tmp/unet.model') 78 | self.add_argument('--channels', type=int, nargs='+', help='Unet channels', 79 | default=[2, 16, 32, 64, 32, 16, 32, 2]) 80 | self.add_argument('--epochs', type=int, help='Number of epochs', default=200) 81 | self.add_argument('--outbasepath', type=str, help='Path and filename base for outputs', 82 | default='/share/data_zoe1/lucas/Linda_Segmentations/tmp/unet') 83 | 84 | 85 | class SDMParser(ExpParser): 86 | def __init__(self): 87 | super().__init__() 88 | self.add_argument('unet', type=str, help='Path to model of Segmentation Unet', 89 | default='/share/data_zoe1/lucas/unet1dcm.model') 90 | self.add_argument('--channels', type=int, nargs='+', help='Unet channels', 91 | default=[2, 16, 32, 64, 32, 16, 32, 2]) 92 | self.add_argument('--downsample', type=int, help='Downsampling to CAE latent representation size', default=1) 93 | self.add_argument('--groundtruth', type=int, help='Use groundtruth instead of UNet segmentations', default=1) 94 | self.add_argument('--visualinspection', type=int, help='Inspect visually before it is saved', default=0) 95 | self.add_argument('--outbasepath', type=str, help='Path and filename base for outputs', 96 | default='/share/data_zoe1/lucas/Linda_Segmentations/tmp/sdm') 97 | 98 | 99 | def get_args_sdm(): 100 | parser = SDMParser() 101 | args = parser.parse_args() 102 | return args 103 | 104 | 105 | def get_args_shape_training(): 106 | parser = CAEParser() 107 | parser.add_argument('--channelscae', type=int, nargs='+', help='CAE channels', default=[1, 16, 24, 32, 100, 200, 1]) 108 | args = parser.parse_args() 109 | return args 110 | 111 | def get_args_step_training(): 112 | parser = CAEParser() 113 | parser.add_argument('caepath', type=str, help='Path to previously trained cae phase1 model') 114 | parser.add_argument('--channelscae', type=int, nargs='+', help='CAE channels', default=[1, 16, 24, 32, 100, 200, 1]) 115 | args = parser.parse_args() 116 | return args 117 | 118 | 119 | def get_args_shape_prediction_training(): 120 | parser = CAEParser() 121 | parser.add_argument('caepath', type=str, help='Path to previously trained cae phase1 model') 122 | parser.add_argument('--channelsenc', type=int, nargs='+', help='CAE channels', 
default=[1, 16, 24, 32, 100, 200, 1]) 123 | parser.add_argument('--initbycae', action='store_true', help='Init enc weights by cae\'s enc', default=False) 124 | args = parser.parse_args() 125 | return args 126 | 127 | 128 | def get_args_shape_testing(): 129 | parser = argparse.ArgumentParser() 130 | parser.add_argument('--path', action='append', type=str, help='Path to model of Shape CAE') 131 | parser.add_argument('--fold', action='append', type=int, nargs='+', help='Fold case indices') # Internal indices, NOT case numbers on disk 132 | parser.add_argument('--normalize', type=int, help='Normalization value corresponding to penumbra (hours)', 133 | default=10) 134 | parser.add_argument('--outbasepath', type=str, help='Path and filename base for outputs', 135 | default='/share/data_zoe1/lucas/Linda_Segmentations/tmp/shape') 136 | parser.add_argument('--xyresample', type=float, help='Factor for resampling slices', default=0.5) 137 | parser.add_argument('--padding', type=int, nargs='+', help='Padding of patches', default=[20, 20, 20]) 138 | args = parser.parse_args() 139 | return args 140 | 141 | 142 | def get_args_unet_training(): 143 | parser = UnetParser() 144 | args = parser.parse_args() 145 | return args -------------------------------------------------------------------------------- /learner/CaePredictionLearner.py: -------------------------------------------------------------------------------- 1 | from learner.Learner import Learner 2 | from common.dto.CaeDto import CaeDto 3 | from common.inference.CaeEncInference import CaeEncInference 4 | from common import data, util, metrics 5 | import common.dto.MetricMeasuresDto as MetricMeasuresDtoInit 6 | import matplotlib.pyplot as plt 7 | import torch 8 | 9 | 10 | class CaePredictionLearner(Learner, CaeEncInference): 11 | """ A Learner that trains the CAE for prediction based 12 | on Unet segmentations and compares latent codes with 13 | those of the previous shape encoder.
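    Note: the pretrained CAE itself stays frozen during this phase; only the
    new encoder is optimized, matching its latent codes to those of the frozen
    encoder and its decoded interpolation to the follow-up lesion (cf. loss_step).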
14 | """ 15 | FN_VIS_BASE = '_cae2_' 16 | FNB_MARKS = '_cae2' 17 | N_EPOCHS_ADAPT_BETA1 = 4 18 | 19 | def __init__(self, dataloader_training, dataloader_validation, cae_model, enc_model, optimizer, scheduler, n_epochs, 20 | path_previous_base, path_outputs_base, criterion, normalization_hours_penumbra=10): 21 | Learner.__init__(self, dataloader_training, dataloader_validation, cae_model, optimizer, scheduler, n_epochs, 22 | path_previous_base, path_outputs_base) 23 | CaeEncInference.__init__(self, cae_model, enc_model, normalization_hours_penumbra) 24 | self._model.freeze(True) 25 | self._criterion = criterion # main loss criterion 26 | 27 | def load_model(self, cuda=True): 28 | Learner.load_model(self, self.is_cuda) 29 | if cuda: 30 | self._new_enc = torch.load(self.path('load', self.FNB_MODEL, '_enc')).cuda() 31 | else: 32 | self._new_enc = torch.load(self.path('load', self.FNB_MODEL, '_enc')) 33 | 34 | def save_model(self, suffix=''): 35 | Learner.save_model(self, suffix) 36 | torch.save(self._new_enc.cpu(), self.path('save', self.FNB_MODEL, '_enc' + suffix)) 37 | self._new_enc.cuda() 38 | 39 | def adapt_betas(self, epoch): 40 | pass 41 | 42 | def loss_step(self, dto: CaeDto, epoch): 43 | loss = 0.0 44 | divd = 6 45 | 46 | diff_penu_fuct = dto.reconstructions.inputs.penu - dto.reconstructions.inputs.interpolation 47 | diff_penu_core = dto.reconstructions.inputs.penu - dto.reconstructions.inputs.core 48 | loss += 1 * torch.mean(torch.abs(diff_penu_fuct) - diff_penu_fuct) 49 | loss += 1 * torch.mean(torch.abs(diff_penu_core) - diff_penu_core) 50 | 51 | loss += 1 * self._criterion(dto.reconstructions.inputs.interpolation, dto.given_variables.gtruth.lesion) 52 | 53 | loss += 1 * torch.mean(torch.abs(dto.latents.gtruth.interpolation - dto.latents.inputs.interpolation)) 54 | loss += 1 * torch.mean(torch.abs(dto.latents.gtruth.core - dto.latents.inputs.core)) 55 | loss += 1 * torch.mean(torch.abs(dto.latents.gtruth.penu - dto.latents.inputs.penu)) 56 | 57 | return loss / divd 58 | 59 | def batch_metrics_step(self, dto: CaeDto, epoch): 60 | batch_metrics = MetricMeasuresDtoInit.init_dto() 61 | batch_metrics.lesion = metrics.binary_measures_torch(dto.reconstructions.gtruth.interpolation, 62 | dto.given_variables.gtruth.lesion, self.is_cuda) 63 | batch_metrics.core = metrics.binary_measures_torch(dto.reconstructions.gtruth.core, 64 | dto.given_variables.gtruth.core, self.is_cuda) 65 | batch_metrics.penu = metrics.binary_measures_torch(dto.reconstructions.gtruth.penu, 66 | dto.given_variables.gtruth.penu, self.is_cuda) 67 | return batch_metrics 68 | 69 | def print_epoch(self, epoch, phase, epoch_metrics): 70 | output = '\nEpoch {}/{} {} loss: {:.3} - DC:{:.3}, HD:{:.3}, ASSD:{:.3}, DC core:{:.3}, DC penu.:{:.3}' 71 | print(output.format(epoch + 1, self._n_epochs, phase, 72 | epoch_metrics.loss, 73 | epoch_metrics.lesion.dc, 74 | epoch_metrics.lesion.hd, 75 | epoch_metrics.lesion.assd, 76 | epoch_metrics.core.dc, 77 | epoch_metrics.penu.dc), end=' ') 78 | 79 | def plot_epoch(self, plot, epochs): 80 | plot.plot(epochs, [dto.loss for dto in self._metric_dtos['training']], 'r-') 81 | plot.plot(epochs, [dto.loss for dto in self._metric_dtos['validate']], 'g-') 82 | plot.plot(epochs, [dto.lesion.dc for dto in self._metric_dtos['validate']], 'k-') 83 | plot.plot(epochs, [dto.core.dc for dto in self._metric_dtos['validate']], 'c+') 84 | plot.plot(epochs, [dto.penu.dc for dto in self._metric_dtos['validate']], 'm+') 85 | plot.set_ylabel('L Train.(red)/Val.(green) | Dice Val. 
Lesion(b), Core(c), Penu(m)') 86 | plot.set_ylim(0, 1) 87 | ax2 = plot.twinx() 88 | ax2.plot(epochs, [dto.lesion.assd for dto in self._metric_dtos['validate']], 'b-') 89 | ax2.set_ylabel('Validation ASSD (blue)', color='b') 90 | ax2.tick_params('y', colors='b') 91 | 92 | def visualize_epoch(self, epoch): 93 | visual_samples, visual_times = util.get_vis_samples(self._dataloader_training, self._dataloader_validation) 94 | 95 | f, axarr = plt.subplots(len(visual_samples), 15) 96 | inc = 0 97 | for sample, time in zip(visual_samples, visual_times): 98 | 99 | col = 3 100 | for step in [None, -10, -1, 0, 1, 2, 3, 4, 5, 20]: 101 | dto = self.inference_step(sample, step) 102 | axarr[inc, col].imshow(dto.reconstructions.gtruth.interpolation.cpu().data.numpy()[0, 0, 14, :, :], 103 | vmin=0, vmax=1, cmap='gray') 104 | if col == 3: 105 | col += 1 106 | col += 1 107 | 108 | axarr[inc, 0].imshow(sample[data.KEY_IMAGES].numpy()[0, 0, 14, :, :], vmin=0, vmax=1, cmap='gray') 109 | axarr[inc, 1].imshow(sample[data.KEY_IMAGES].numpy()[0, 1, 14, :, :], vmin=0, vmax=1, cmap='gray') 110 | 111 | axarr[inc, 2].imshow(dto.given_variables.gtruth.lesion.cpu().data.numpy()[0, 0, 14, :, :], 112 | vmin=0, vmax=1, cmap='gray') 113 | axarr[inc, 4].imshow(dto.given_variables.gtruth.core.cpu().data.numpy()[0, 0, 14, :, :], 114 | vmin=0, vmax=1, cmap='gray') 115 | axarr[inc, 14].imshow(dto.given_variables.gtruth.penu.cpu().data.numpy()[0, 0, 14, :, :], 116 | vmin=0, vmax=1, cmap='gray') 117 | 118 | del sample 119 | del dto 120 | 121 | titles = ['CBV', 'TTD', 'Lesion', 'p(' + 122 | ('{:03.1f}'.format(float(time))) 123 | + 'h)', 'Core', 'p(-10h)', 'p(-1h)', 'p(0h)', 'p(1h)', 'p(2h)', 'p(3h)', 'p(4h)', 'p(5h)', 124 | 'p(20h)', 125 | 'Penumbra'] 126 | 127 | for ax, title in zip(axarr[inc], titles): 128 | ax.set_title(title) 129 | 130 | inc += 1 131 | 132 | for ax in axarr.flatten(): 133 | ax.title.set_fontsize(3) 134 | ax.xaxis.set_visible(False) 135 | ax.yaxis.set_visible(False) 136 | 137 | f.subplots_adjust(hspace=0.05) 138 | f.savefig(self._path_outputs_base + self.FN_VIS_BASE + str(epoch + 1) + '.png', bbox_inches='tight', dpi=300) 139 | 140 | del f 141 | del axarr 142 | -------------------------------------------------------------------------------- /learner/CaeReconstructionLearner.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | matplotlib.use('Agg') 3 | from learner.Learner import Learner 4 | from common.dto.CaeDto import CaeDto 5 | from common.inference.CaeInference import CaeInference 6 | import common.dto.MetricMeasuresDto as MetricMeasuresDtoInit 7 | import matplotlib.pyplot as plt 8 | import torch 9 | from common import data, util, metrics 10 | import numpy 11 | 12 | 13 | class CaeReconstructionLearner(Learner, CaeInference): 14 | """ A Learner to train a CAE on the reconstruction of 15 | shape segmentations. Uses CaeDto data transfer objects. 16 | """ 17 | FN_VIS_BASE = '_cae1_' 18 | FNB_MARKS = '_cae1' 19 | N_EPOCHS_ADAPT_BETA1 = 4 20 | 21 | def __init__(self, dataloader_training, dataloader_validation, cae_model, optimizer, scheduler, n_epochs, 22 | path_previous_base, path_outputs_base, criterion, normalization_hours_penumbra=10): 23 | Learner.__init__(self, dataloader_training, dataloader_validation, cae_model, optimizer, scheduler, n_epochs, 24 | path_previous_base, path_outputs_base) 25 | CaeInference.__init__(self, cae_model, normalization_hours_penumbra) # TODO: refactor double initialization?! 
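        # Learner and CaeInference are cooperating base classes, hence both explicit
        # __init__ calls above. A minimal construction sketch with illustrative values,
        # following train_shape_reconstruction.py (not a prescribed API):
        #
        #     cae = Cae3D(Enc3D(...), Dec3D(...)).cuda()
        #     optimizer = torch.optim.Adam(cae.parameters(), lr=1e-3, betas=(0.9, 0.999))
        #     learner = CaeReconstructionLearner(ds_train, ds_valid, cae, optimizer, None, 300,
        #                                        None, '~/tmp/shape', metrics.BatchDiceLoss([1.0]))
        #     learner.run_training()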
26 | self._criterion = criterion # main loss criterion 27 | 28 | def adapt_betas(self, epoch): 29 | betas = self._optimizer.defaults['betas'] 30 | if epoch < self.N_EPOCHS_ADAPT_BETA1: 31 | betas = list(betas) 32 | betas[0] -= 0.1 * (self.N_EPOCHS_ADAPT_BETA1 - epoch) 33 | betas = tuple(betas) 34 | for param_group in self._optimizer.param_groups: 35 | param_group['betas'] = betas 36 | print('Momentum betas have been set to:', param_group['betas'], end=' ') 37 | elif epoch == self.N_EPOCHS_ADAPT_BETA1: 38 | for param_group in self._optimizer.param_groups: 39 | param_group['betas'] = betas 40 | print('Momentum betas have been set to:', param_group['betas'], end=' ') 41 | 42 | def get_start_epoch(self): 43 | if self._metric_dtos['training']: 44 | return len([dto.loss for dto in self._metric_dtos['training']]) 45 | return 0 46 | 47 | def get_start_min_loss(self): 48 | if self._metric_dtos['validate']: 49 | return min([dto.loss for dto in self._metric_dtos['validate']]) 50 | return numpy.Inf 51 | 52 | def loss_step(self, dto: CaeDto, epoch): 53 | factor = min(0.04 * max(0, epoch - 25), 1) 54 | print(factor, end=' ') 55 | 56 | loss = 0.0 57 | divd = 5 + factor 58 | 59 | diff_penu_fuct = dto.reconstructions.gtruth.penu - dto.reconstructions.gtruth.interpolation 60 | diff_penu_core = dto.reconstructions.gtruth.penu - dto.reconstructions.gtruth.core 61 | loss += 1 * torch.mean(torch.abs(diff_penu_fuct) - diff_penu_fuct) 62 | loss += 1 * torch.mean(torch.abs(diff_penu_core) - diff_penu_core) 63 | 64 | loss += 1 * self._criterion(dto.reconstructions.gtruth.core, dto.given_variables.gtruth.core) 65 | loss += 1 * self._criterion(dto.reconstructions.gtruth.penu, dto.given_variables.gtruth.penu) 66 | loss += 1 * self._criterion(dto.reconstructions.gtruth.lesion, dto.given_variables.gtruth.lesion) 67 | 68 | loss += factor * torch.mean(torch.abs(dto.latents.gtruth.interpolation - dto.latents.gtruth.lesion)) 69 | 70 | return loss / divd 71 | 72 | def batch_metrics_step(self, dto: CaeDto, epoch): 73 | batch_metrics = MetricMeasuresDtoInit.init_dto() 74 | batch_metrics.lesion = metrics.binary_measures_torch(dto.reconstructions.gtruth.interpolation, 75 | dto.given_variables.gtruth.lesion, self.is_cuda) 76 | batch_metrics.core = metrics.binary_measures_torch(dto.reconstructions.gtruth.core, 77 | dto.given_variables.gtruth.core, self.is_cuda) 78 | batch_metrics.penu = metrics.binary_measures_torch(dto.reconstructions.gtruth.penu, 79 | dto.given_variables.gtruth.penu, self.is_cuda) 80 | return batch_metrics 81 | 82 | def print_epoch(self, epoch, phase, epoch_metrics): 83 | output = '\nEpoch {}/{} {} loss: {:.3} - DC:{:.3}, HD:{:.3}, ASSD:{:.3}, DC core:{:.3}, DC penu.:{:.3}' 84 | print(output.format(epoch + 1, self._n_epochs, phase, 85 | epoch_metrics.loss, 86 | epoch_metrics.lesion.dc, 87 | epoch_metrics.lesion.hd, 88 | epoch_metrics.lesion.assd, 89 | epoch_metrics.core.dc, 90 | epoch_metrics.penu.dc), end=' ') 91 | 92 | def plot_epoch(self, plot, epochs): 93 | plot.plot(epochs, [dto.loss for dto in self._metric_dtos['training']], 'r-') 94 | plot.plot(epochs, [dto.loss for dto in self._metric_dtos['validate']], 'g-') 95 | plot.plot(epochs, [dto.lesion.dc for dto in self._metric_dtos['validate']], 'k-') 96 | plot.plot(epochs, [dto.core.dc for dto in self._metric_dtos['validate']], 'c+') 97 | plot.plot(epochs, [dto.penu.dc for dto in self._metric_dtos['validate']], 'm+') 98 | plot.set_ylabel('L Train.(red)/Val.(green) | Dice Val. 
Lesion(b), Core(c), Penu(m)') 99 | plot.set_ylim(0, 1) 100 | ax2 = plot.twinx() 101 | ax2.plot(epochs, [dto.lesion.assd for dto in self._metric_dtos['validate']], 'b-') 102 | ax2.set_ylabel('Validation ASSD (blue)', color='b') 103 | ax2.tick_params('y', colors='b') 104 | 105 | def visualize_epoch(self, epoch): 106 | visual_samples, visual_times = util.get_vis_samples(self._dataloader_training, self._dataloader_validation) 107 | 108 | f, axarr = plt.subplots(len(visual_samples), 15) 109 | inc = 0 110 | for sample, time in zip(visual_samples, visual_times): 111 | 112 | col = 3 113 | for step in [None, -10, -1, 0, 1, 2, 3, 4, 5, 20]: 114 | dto = self.inference_step(sample, step) 115 | axarr[inc, col].imshow(dto.reconstructions.gtruth.interpolation.cpu().data.numpy()[0, 0, 14, :, :], 116 | vmin=0, vmax=1, cmap='gray') 117 | if col == 3: 118 | col += 1 119 | col += 1 120 | 121 | axarr[inc, 0].imshow(sample[data.KEY_IMAGES].numpy()[0, 0, 14, :, :], 122 | vmin=0, vmax=self.IMSHOW_VMAX_CBV, cmap='jet') 123 | axarr[inc, 1].imshow(sample[data.KEY_IMAGES].numpy()[0, 1, 14, :, :], 124 | vmin=0, vmax=self.IMSHOW_VMAX_TTD, cmap='jet') 125 | axarr[inc, 2].imshow(dto.given_variables.gtruth.lesion.cpu().data.numpy()[0, 0, 14, :, :], 126 | vmin=0, vmax=1, cmap='gray') 127 | axarr[inc, 4].imshow(dto.given_variables.gtruth.core.cpu().data.numpy()[0, 0, 14, :, :], 128 | vmin=0, vmax=1, cmap='gray') 129 | axarr[inc, 14].imshow(dto.given_variables.gtruth.penu.cpu().data.numpy()[0, 0, 14, :, :], 130 | vmin=0, vmax=1, cmap='gray') 131 | 132 | del sample 133 | del dto 134 | 135 | titles = ['CBV', 'TTD', 'Lesion', 'p(' + 136 | ('{:03.1f}'.format(float(time))) 137 | + 'h)', 'Core', 'p(-10h)', 'p(-1h)', 'p(0h)', 'p(1h)', 'p(2h)', 'p(3h)', 'p(4h)', 'p(5h)', 138 | 'p(20h)', 139 | 'Penumbra'] 140 | 141 | for ax, title in zip(axarr[inc], titles): 142 | ax.set_title(title) 143 | 144 | inc += 1 145 | 146 | for ax in axarr.flatten(): 147 | ax.title.set_fontsize(3) 148 | ax.xaxis.set_visible(False) 149 | ax.yaxis.set_visible(False) 150 | 151 | f.subplots_adjust(hspace=0.05) 152 | f.savefig(self._path_outputs_base + self.FN_VIS_BASE + str(epoch + 1) + '.png', bbox_inches='tight', dpi=300) 153 | 154 | del f 155 | del axarr 156 | -------------------------------------------------------------------------------- /learner/CaeStepLearner.py: -------------------------------------------------------------------------------- 1 | from learner.CaeReconstructionLearner import CaeReconstructionLearner 2 | from common.dto.CaeDto import CaeDto 3 | from torch.autograd import Variable 4 | import torch 5 | 6 | 7 | class CaeStepLearner(CaeReconstructionLearner): 8 | """ A Learner to learn best interpolation steps for the 9 | reconstruction shape space. Uses CaeDto data transfer objects. 
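    Only the encoder's step branch is trained from scratch; the shape encoder
    and decoder are taken from a previously trained reconstruction CAE and kept
    frozen (cf. train_interpolationstep_after_reconstruction.py).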
10 | """ 11 | FN_VIS_BASE = '_cae1step_' 12 | FNB_MARKS = '_cae1step' 13 | N_EPOCHS_ADAPT_BETA1 = 4 14 | 15 | def loss_step(self, dto: CaeDto, epoch): 16 | loss = 0.0 17 | divd = 2 18 | diff_penu_fuct = dto.reconstructions.gtruth.penu - dto.reconstructions.gtruth.interpolation 19 | loss += 1 * torch.mean(torch.abs(diff_penu_fuct) - diff_penu_fuct) 20 | loss += 1 * self._criterion(dto.reconstructions.gtruth.interpolation, dto.given_variables.gtruth.lesion) 21 | return loss / divd 22 | 23 | def get_time_to_treatment(self, batch, global_variables, step): 24 | normalization = self._get_normalization(batch) 25 | if step is None: 26 | time_to_treatment = None 27 | else: 28 | time_to_treatment = Variable((step * torch.ones(global_variables.size()[0], 1)) / normalization).unsqueeze(2).unsqueeze(3).unsqueeze(4) 29 | return time_to_treatment -------------------------------------------------------------------------------- /learner/Learner.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from torch.utils.data import DataLoader 3 | from common.dto.Dto import Dto 4 | from common.inference.Inference import Inference 5 | from common.dto.MetricMeasuresDto import MetricMeasuresDto 6 | import common.dto.MetricMeasuresDto as MetricMeasuresDtoInit 7 | from torch.optim.optimizer import Optimizer 8 | from torch.optim.lr_scheduler import _LRScheduler 9 | from torch.nn import Module 10 | import matplotlib.pyplot as plt 11 | import torch 12 | import numpy 13 | import jsonpickle 14 | 15 | 16 | class Learner(Inference): 17 | """Base class with a standard routine for 18 | a training procedure. The single steps can 19 | be overridden by subclasses to specify the 20 | procedures required for a specific training. 21 | """ 22 | FNB_MODEL = 'model' # filename base for model data 23 | FNB_OPTIM = 'optimizer' # filename base for optimizer state dict 24 | FNB_TRAIN = 'training' # filename base for training data 25 | FNB_PLOTS = 'plots' # filename base for plots form training 26 | FNB_IMAGE = 'visual' # filename base for sample visualizations 27 | FNB_MARKS = '_learner' # filename base for specific learner 28 | EXT_MODEL = '.model' # filename extension for model data 29 | EXT_OPTIM = '.optim' # filename extension for optimizer state dict 30 | EXT_TRAIN = '.json' # filename extension for training data 31 | EXT_IMAGE = '.png' # filename extension for any image data 32 | 33 | def __init__(self, dataloader_training: DataLoader, dataloader_validation: DataLoader, model: Module, 34 | optimizer: Optimizer, scheduler: _LRScheduler, n_epochs: int, path_previous_base: str = None, 35 | path_outputs_base: str = '/tmp/stroke-prediction'): 36 | # init inference 37 | Inference.__init__(self, model) 38 | 39 | # init learning data and optimizing schedule 40 | assert dataloader_training.batch_size > 1, 'For normalization layers batch_size > 1 is required.' 
41 | self._dataloader_training = dataloader_training 42 | self._dataloader_validation = dataloader_validation 43 | self._optimizer = optimizer 44 | self._scheduler = scheduler 45 | self._n_epochs = n_epochs 46 | 47 | self._path_outputs_base = path_outputs_base 48 | self._path_previous_base = path_previous_base 49 | 50 | # load previous training data to continue 51 | if path_previous_base is not None: 52 | self.load_model(self.is_cuda) 53 | self.load_training() # restore training curves from previous training 54 | print('Continue training', path_previous_base, '...') 55 | else: 56 | self._metric_dtos = {'training': [], 'validate': []} 57 | assert len(self._metric_dtos['training']) == len(self._metric_dtos['validate']), 'Incomplete training data!' 58 | 59 | def path(self, mode: str, type: str, suffix: str=''): 60 | if mode == 'load': 61 | base_path = self._path_previous_base 62 | elif mode == 'save': 63 | base_path = self._path_outputs_base 64 | else: 65 | return None 66 | 67 | if type == self.FNB_MODEL: 68 | return base_path + self.FNB_MARKS + suffix + self.EXT_MODEL 69 | elif type == self.FNB_OPTIM: 70 | return base_path + self.FNB_MARKS + suffix + self.EXT_OPTIM 71 | elif type == self.FNB_TRAIN: 72 | return base_path + self.FNB_MARKS + suffix + self.EXT_TRAIN 73 | elif type == self.FNB_PLOTS: 74 | return base_path + self.FNB_MARKS + suffix + self.EXT_IMAGE 75 | elif type == self.FNB_IMAGE: 76 | return base_path + self.FNB_MARKS + suffix + self.EXT_IMAGE 77 | else: 78 | return None 79 | 80 | @abstractmethod 81 | def loss_step(self, dto: Dto, epoch): 82 | pass 83 | 84 | def get_start_epoch(self): 85 | return 0 86 | 87 | def get_start_min_loss(self): 88 | return numpy.Inf 89 | 90 | def load_model(self, cuda=True): 91 | path_model = self.path('load', self.FNB_MODEL) 92 | if cuda: 93 | self._model = torch.load(path_model).cuda() 94 | else: 95 | self._model = torch.load(path_model) 96 | 97 | def load_training(self): 98 | path_training = self.path('load', self.FNB_TRAIN) 99 | path_optimizer = self.path('load', self.FNB_OPTIM) 100 | print('Loading:', path_training, path_optimizer) 101 | self._optimizer.load_state_dict(torch.load(path_optimizer)) 102 | with open(path_training, 'r') as fp: 103 | self._metric_dtos = jsonpickle.decode(fp.read()) 104 | 105 | def save_training(self): 106 | path_training = self.path('save', self.FNB_TRAIN) 107 | path_optimizer = self.path('save', self.FNB_OPTIM) 108 | torch.save(self._optimizer.state_dict(), path_optimizer) 109 | with open(path_training, 'w') as fp: 110 | fp.write(jsonpickle.encode(self._metric_dtos)) 111 | 112 | def save_model(self, suffix=''): 113 | torch.save(self._model.cpu(), self.path('save', self.FNB_MODEL, suffix)) 114 | self._model.cuda() 115 | 116 | def train_batch(self, batch: dict, epoch) -> MetricMeasuresDto: 117 | dto = self.inference_step(batch) 118 | loss = self.loss_step(dto, epoch) 119 | 120 | self._optimizer.zero_grad() 121 | loss.backward() 122 | self._optimizer.step() 123 | 124 | batch_metrics = self.batch_metrics_step(dto, epoch) 125 | batch_metrics.loss = loss.squeeze().cpu().data.numpy()[0] 126 | 127 | del loss 128 | del dto 129 | 130 | return batch_metrics 131 | 132 | def validate_batch(self, batch: dict, epoch) -> MetricMeasuresDto: 133 | dto = self.inference_step(batch) 134 | loss = self.loss_step(dto, epoch) 135 | 136 | batch_metrics = self.batch_metrics_step(dto, epoch) 137 | batch_metrics.loss = loss.squeeze().cpu().data.numpy()[0] 138 | 139 | del loss 140 | del dto 141 | 142 | return batch_metrics 143 | 144 | def 
batch_metrics_step(self, dto: Dto, epoch) -> MetricMeasuresDto: 145 | return MetricMeasuresDtoInit.init_dto() 146 | 147 | def print_epoch(self, epoch, phase, epoch_metrics: MetricMeasuresDto): 148 | pass 149 | 150 | def plot_epoch(self, plotter, epochs): 151 | pass 152 | 153 | def visualize_epoch(self, epoch): 154 | pass 155 | 156 | def adapt_lr(self, epoch): 157 | if self._scheduler is not None: 158 | self._scheduler.step() 159 | 160 | def adapt_betas(self, epoch): 161 | pass 162 | 163 | def run_training(self): 164 | min_loss = self.get_start_min_loss() 165 | 166 | for epoch in range(self.get_start_epoch(), self._n_epochs): 167 | self.adapt_lr(epoch) 168 | self.adapt_betas(epoch) 169 | 170 | # ---------------------------- (1) TRAINING ---------------------------- # 171 | 172 | self._model.train() 173 | 174 | epoch_metrics = MetricMeasuresDtoInit.init_dto() 175 | for batch in self._dataloader_training: 176 | epoch_metrics.add(self.train_batch(batch, epoch)) 177 | epoch_metrics.div(len(self._dataloader_training)) 178 | del batch 179 | 180 | self.print_epoch(epoch, 'training', epoch_metrics) 181 | self._metric_dtos['training'].append(epoch_metrics) 182 | del epoch_metrics 183 | 184 | # ---------------------------- (2) VALIDATE ---------------------------- # 185 | 186 | self._model.eval() 187 | 188 | if self._dataloader_validation is None: 189 | epoch_metrics = MetricMeasuresDtoInit.init_dto(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 190 | 0.0, 0.0) 191 | else: 192 | epoch_metrics = MetricMeasuresDtoInit.init_dto() 193 | for batch in self._dataloader_validation: 194 | epoch_metrics.add(self.validate_batch(batch, epoch)) 195 | epoch_metrics.div(len(self._dataloader_validation)) 196 | del batch 197 | 198 | self.print_epoch(epoch, 'validate', epoch_metrics) 199 | self._metric_dtos['validate'].append(epoch_metrics) 200 | del epoch_metrics 201 | 202 | # ------------ (3) SAVE MODEL / VISUALIZE (if new optimum) ------------ # 203 | 204 | if self._metric_dtos['validate'] and self._metric_dtos['validate'][-1].loss < min_loss: 205 | min_loss = self._metric_dtos['validate'][-1].loss 206 | self.save_model() 207 | self.save_training() # allows to continue if training has been interrupted 208 | print('(New optimum: Training saved)', end=' ') 209 | self.visualize_epoch(epoch) 210 | 211 | if epoch % 50 == 0: 212 | self.visualize_epoch(epoch) 213 | 214 | # ----------------- (4) PLOT / SAVE EVALUATION METRICS ---------------- # 215 | 216 | if epoch > 0: 217 | fig, plot = plt.subplots() 218 | self.plot_epoch(plot, range(1, epoch + 2)) 219 | fig.savefig(self._path_outputs_base + self.FN_VIS_BASE + 'plots.png', bbox_inches='tight', dpi=300) 220 | del plot 221 | del fig 222 | 223 | # ------------ (5) SAVE FINAL MODEL / VISUALIZE ------------- # 224 | 225 | self.save_model('_final') 226 | self.visualize_epoch(epoch) 227 | -------------------------------------------------------------------------------- /learner/UnetSegmentationLearner.py: -------------------------------------------------------------------------------- 1 | from common.inference.UnetInference import UnetInference 2 | import common.dto.MetricMeasuresDto as MetricMeasuresDtoInit 3 | from learner.Learner import Learner 4 | from common.dto.UnetDto import UnetDto 5 | import matplotlib.pyplot as plt 6 | from common import data, metrics, util 7 | import numpy 8 | 9 | 10 | class UnetSegmentationLearner(Learner, UnetInference): 11 | """ A Learner to train a Unet on shape segmentations. 
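    Optimizes the supplied criterion on the predicted core and penumbra masks
    (UnetDto outputs) against the given ground-truth segmentations.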
12 | """ 13 | FNB_MARKS = '_unet' 14 | 15 | def __init__(self, dataloader_training, dataloader_validation, unet_model, optimizer, scheduler, n_epochs, 16 | criterion, path_previous_base=None, path_outputs_base='/tmp/unet-segmentation'): 17 | Learner.__init__(dataloader_training, dataloader_validation, unet_model, optimizer, scheduler, n_epochs, 18 | path_previous_base, path_outputs_base) 19 | self._criterion = criterion # main loss criterion 20 | 21 | def loss_step(self, dto: UnetDto, epoch): 22 | loss = 0.0 23 | divd = 2 24 | 25 | loss += 1 * self._criterion(dto.outputs.core, dto.given_variables.core) 26 | loss += 1 * self._criterion(dto.outputs.penu, dto.given_variables.penu) 27 | 28 | return loss / divd 29 | 30 | def batch_metrics_step(self, dto: UnetDto, epoch): 31 | batch_metrics = MetricMeasuresDtoInit.init_dto() 32 | batch_metrics.core = metrics.binary_measures_torch(dto.outputs.core, 33 | dto.given_variables.core, self.is_cuda) 34 | batch_metrics.penu = metrics.binary_measures_torch(dto.outputs.penu, 35 | dto.given_variables.penu, self.is_cuda) 36 | return batch_metrics 37 | 38 | def get_start_epoch(self): 39 | if self._metric_dtos['training']: 40 | return len([dto.loss for dto in self._metric_dtos['training']]) 41 | return 0 42 | 43 | def get_start_min_loss(self): 44 | if self._metric_dtos['validate']: 45 | return min([dto.loss for dto in self._metric_dtos['validate']]) 46 | return numpy.Inf 47 | 48 | def print_epoch(self, epoch, phase, epoch_metrics): 49 | output = '\nEpoch {}/{} {} loss: {:.3} - DC Core:{:.3}, DC Penumbra:{:.3}' 50 | print(output.format(epoch + 1, self._n_epochs, phase, 51 | epoch_metrics.loss, 52 | epoch_metrics.core.dc, 53 | epoch_metrics.penu.dc), end=' ') 54 | 55 | def plot_epoch(self, plot, epochs): 56 | plot.plot(epochs, [dto.loss for dto in self._metric_dtos['training']], 'r-') 57 | plot.plot(epochs, [dto.loss for dto in self._metric_dtos['validate']], 'g-') 58 | plot.plot(epochs, [dto.core.dc for dto in self._metric_dtos['validate']], 'c+') 59 | plot.plot(epochs, [dto.penu.dc for dto in self._metric_dtos['validate']], 'm+') 60 | plot.set_ylabel('L Train.(red)/Val.(green) | Dice Val. Core(c), Penu(m)') 61 | 62 | def visualize_epoch(self, epoch): 63 | visual_samples, visual_times = util.get_vis_samples(self._dataloader_training, self._dataloader_validation) 64 | 65 | pad = [20, 20, 20] 66 | 67 | f, axarr = plt.subplots(len(visual_samples), 6) 68 | inc = 0 69 | for sample in visual_samples: 70 | dto = self.inference_step(sample) 71 | zslice = 34 72 | axarr[inc, 0].imshow(sample[data.KEY_IMAGES].numpy()[0, 0, zslice, pad[1]:-pad[1], pad[2]:-pad[2]], 73 | vmin=0, vmax=self.IMSHOW_VMAX_CBV, cmap='jet') 74 | axarr[inc, 1].imshow(dto.given_variables.core.cpu().data.numpy()[0, 0, 14, :, :], 75 | vmin=0, vmax=1, cmap='gray') 76 | axarr[inc, 2].imshow(dto.outputs.core.cpu().data.numpy()[0, 0, 14, :, :], 77 | vmin=0, vmax=1, cmap='gray') 78 | axarr[inc, 3].imshow(dto.outputs.penu.cpu().data.numpy()[0, 0, 14, :, :], 79 | vmin=0, vmax=1, cmap='gray') 80 | axarr[inc, 4].imshow(dto.given_variables.penu.cpu().data.numpy()[0, 0, 14, :, :], 81 | vmin=0, vmax=1, cmap='gray') 82 | axarr[inc, 5].imshow(sample[data.KEY_IMAGES].numpy()[0, 1, zslice, pad[1]:-pad[1], pad[2]:-pad[2]], 83 | vmin=0, vmax=self.IMSHOW_VMAX_TTD, cmap='jet') 84 | 85 | del sample 86 | 87 | titles = ['CBV', 'Core GT', 'p(Core)', 'p(Penu.)', 'Penu. 
GT', 'TTD'] 88 | 89 | for ax, title in zip(axarr[inc], titles): 90 | ax.set_title(title) 91 | 92 | inc += 1 93 | 94 | for ax in axarr.flatten(): 95 | ax.title.set_fontsize(3) 96 | ax.xaxis.set_visible(False) 97 | ax.yaxis.set_visible(False) 98 | 99 | f.subplots_adjust(hspace=0.05) 100 | f.savefig(self._path_outputs_base + self.FN_VIS_BASE + str(epoch + 1) + '.png', bbox_inches='tight', dpi=300) 101 | 102 | del f 103 | del axarr -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib==2.1.0 2 | MedPy==0.3.0 3 | nibabel==2.2.1 4 | numpy==1.13.3 5 | scikit-learn==0.19.1 6 | scipy==1.0.0 7 | torch==0.3.1 8 | torchvision==0.2.0 9 | jsonpickle==0.9.6 10 | -------------------------------------------------------------------------------- /sample_output.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/multimodallearning/stroke-prediction/58da5be2c16637d47587cb09ac87ddebf028e093/sample_output.png -------------------------------------------------------------------------------- /test_sdm_resampling.py: -------------------------------------------------------------------------------- 1 | #from common.model.Unet3D import Unet3D TODO Unet live segmentation 2 | #import common.dto.UnetDto as UnetDtoInit TODO Unet live segmentation 3 | from common import data, util, metrics 4 | import torch 5 | import numpy as np 6 | import nibabel as nib 7 | from scipy import ndimage as ndi 8 | import scipy.ndimage.measurements as scim 9 | import scipy.ndimage.morphology as scimorph 10 | import datetime 11 | import matplotlib.pyplot as plt 12 | from torch.autograd import Variable 13 | 14 | 15 | def sdm_interpolate_numpy(core, penu, interpolation, threshold=0.5, zoom=12, dilate=3, resample=True): 16 | penu_bin = penu[0, 0, :, :, :] > threshold 17 | penu_dist = ndi.distance_transform_edt(penu_bin) 18 | penu_dist -= ndi.distance_transform_edt(penu[0, 0, :, :, :] < threshold) 19 | latent_penu = ndi.zoom(penu_dist, (1, 1.0 / zoom, 1.0 / zoom)) 20 | if not resample: 21 | recon_penu = penu_dist # NO DOWNSAMPLING 22 | del penu_dist 23 | del penu 24 | 25 | core_bin = (core[0, 0, :, :, :] > threshold) 26 | if not core_bin.any(): # all signal below threshold, thus missing binary segmentation 27 | cog = [int(v) for v in scim.center_of_mass(penu_bin)] 28 | core_bin[cog[0], cog[1], cog[2]] = 1 29 | core_bin = scimorph.binary_dilation(core_bin, iterations=dilate) 30 | print('------------------------------------> artificial core', cog) 31 | del penu_bin 32 | core_dist = ndi.distance_transform_edt(1 - core_bin) - ndi.distance_transform_edt( 33 | core[0, 0, :, :, :] > threshold) 34 | del core_bin 35 | del core 36 | if not resample: 37 | recon_core = core_dist # NO DOWNSAMPLING 38 | latent_core = ndi.zoom(core_dist, (1, 1.0 / zoom, 1.0 / zoom)) 39 | del core_dist 40 | 41 | if resample: 42 | recon_core = ndi.zoom(latent_core, (1, zoom, zoom))[:, 2:130, 2:130] 43 | recon_penu = ndi.zoom(latent_penu, (1, zoom, zoom))[:, 2:130, 2:130] 44 | 45 | latent_intp = latent_penu * interpolation - latent_core * (1 - interpolation) 46 | 47 | if not resample: 48 | recon_intp = recon_penu * interpolation - recon_core * (1 - interpolation) 49 | else: 50 | recon_intp = ndi.zoom(latent_intp, (1, zoom, zoom))[:, 2:130, 2:130] 51 | 52 | return recon_core, recon_intp, recon_penu, latent_core, latent_intp, latent_penu 53 | 54 | 55 | def
get_normalized_time(batch, normalization_hours_penumbra): 56 | to_to_ta = batch[data.KEY_GLOBAL][:, 0, :, :, :].unsqueeze(data.DIM_CHANNEL_TORCH3D_5).type(torch.FloatTensor) 57 | normalization = torch.ones(to_to_ta.size()[0], 1).type(torch.FloatTensor) * \ 58 | normalization_hours_penumbra - to_to_ta.squeeze().unsqueeze(data.DIM_CHANNEL_TORCH3D_5) 59 | return to_to_ta, normalization 60 | 61 | 62 | def infer(): 63 | args = util.get_args_sdm() 64 | 65 | print('Evaluate validation set', args.fold) 66 | 67 | # Params / Config 68 | normalization_hours_penumbra = 10 69 | #channels_unet = args.channels TODO Unet live segmentation 70 | #pad = args.padding TODO Unet live segmentation 71 | 72 | transform = [data.ResamplePlaneXY(args.xyresample), 73 | data.HemisphericFlipFixedToCaseId(split_id=args.hemisflipid), 74 | #data.PadImages(pad[0], pad[1], pad[2], pad_value=0), TODO Unet live segmentation 75 | data.ToTensor()] 76 | 77 | ds_test = data.get_testdata(modalities=['_unet_core', '_unet_penu'], # modalities=['_CBV_reg1_downsampled', '_TTD_reg1_downsampled'], TODO Unet live segmentation 78 | labels=['_CBVmap_subset_reg1_downsampled', '_TTDmap_subset_reg1_downsampled', 79 | '_FUCT_MAP_T_Samplespace_subset_reg1_downsampled'], 80 | transform=transform, 81 | indices=args.fold) 82 | 83 | # Unet 84 | #unet = None TODO Unet live segmentation 85 | #if not args.groundtruth: TODO Unet live segmentation 86 | # unet = Unet3D(channels=channels_unet) TODO Unet live segmentation 87 | # unet.load_state_dict(torch.load(args.unet)) TODO Unet live segmentation 88 | # unet.train(False) # fixate regularization for forward-only! TODO Unet live segmentation 89 | 90 | for sample in ds_test: 91 | case_id = sample[data.KEY_CASE_ID].cpu().numpy()[0] 92 | 93 | nifph = nib.load('/share/data_zoe1/lucas/Linda_Segmentations/' + str(case_id) + '/train' + str(case_id) + 94 | '_CBVmap_reg1_downsampled.nii.gz').affine 95 | 96 | to_to_ta, normalization = get_normalized_time(sample, normalization_hours_penumbra) 97 | 98 | lesion = Variable(sample[data.KEY_LABELS][:, 2, :, :, :].unsqueeze(data.DIM_CHANNEL_TORCH3D_5)) 99 | if args.groundtruth: 100 | core = Variable(sample[data.KEY_LABELS][:, 0, :, :, :].unsqueeze(data.DIM_CHANNEL_TORCH3D_5)) 101 | penu = Variable(sample[data.KEY_LABELS][:, 1, :, :, :].unsqueeze(data.DIM_CHANNEL_TORCH3D_5)) 102 | else: 103 | #dto = UnetDtoInit.init_dto(Variable(sample[data.KEY_IMAGES]), None, None) TODO Unet live segmentation 104 | #dto = unet(dto) TODO Unet live segmentation 105 | #core = dto.outputs.core TODO Unet live segmentation 106 | #penu = dto.outputs.penu, TODO Unet live segmentation 107 | core = Variable(sample[data.KEY_IMAGES][:, 0, :, :, :].unsqueeze(data.DIM_CHANNEL_TORCH3D_5)) 108 | penu = Variable(sample[data.KEY_IMAGES][:, 1, :, :, :].unsqueeze(data.DIM_CHANNEL_TORCH3D_5)) 109 | 110 | ta_to_tr = sample[data.KEY_GLOBAL][:, 1, :, :, :].squeeze().unsqueeze(data.DIM_CHANNEL_TORCH3D_5) 111 | time_to_treatment = Variable(ta_to_tr.type(torch.FloatTensor) / normalization) 112 | 113 | del to_to_ta 114 | del normalization 115 | 116 | recon_core, recon_intp, recon_penu, latent_core, latent_intp, latent_penu = \ 117 | sdm_interpolate_numpy(core.data.cpu().numpy(), penu.data.cpu().numpy(), threshold=0.5, 118 | interpolation=time_to_treatment.data.cpu().numpy().squeeze(), zoom=12, 119 | resample=args.downsample) 120 | 121 | print(int(sample[data.KEY_CASE_ID]), 'TO-->TR', float(time_to_treatment)) 122 | 123 | if args.visualinspection: 124 | fig, axes = plt.subplots(3, 4) 125 | 126 | axes[0, 
0].imshow(core.cpu().data.numpy()[0, 0, 16, :, :], cmap='gray', vmin=0, vmax=1) 127 | axes[1, 0].imshow(lesion.cpu().data.numpy()[0, 0, 16, :, :], cmap='gray', vmin=0, vmax=1) 128 | axes[2, 0].imshow(penu.cpu().data.numpy()[0, 0, 16, :, :], cmap='gray', vmin=0, vmax=1) 129 | 130 | axes[0, 1].imshow(latent_core[16, :, :], cmap='gray') 131 | axes[1, 1].imshow(latent_intp[16, :, :], cmap='gray') 132 | axes[2, 1].imshow(latent_penu[16, :, :], cmap='gray') 133 | 134 | axes[0, 2].imshow(recon_core[16, :, :], cmap='gray') 135 | axes[1, 2].imshow(recon_intp[16, :, :], cmap='gray') 136 | axes[2, 2].imshow(recon_penu[16, :, :], cmap='gray') 137 | 138 | axes[0, 3].imshow(recon_core[16, :, :] < 0, cmap='gray', vmin=0, vmax=1) 139 | axes[1, 3].imshow(recon_intp[16, :, :] > 0, cmap='gray', vmin=0, vmax=1) 140 | axes[2, 3].imshow(recon_penu[16, :, :] > 0, cmap='gray', vmin=0, vmax=1) 141 | plt.show() 142 | 143 | results = metrics.binary_measures_numpy((recon_intp > 0).astype(np.float), 144 | lesion.cpu().data.numpy()[0, 0, :, :, :], binary_threshold=0.5) 145 | 146 | c_res = metrics.binary_measures_numpy((recon_core < 0).astype(np.float), 147 | core.cpu().data.numpy()[0, 0, :, :, :], binary_threshold=0.5) 148 | 149 | p_res = metrics.binary_measures_numpy((recon_penu > 0).astype(np.float), 150 | penu.cpu().data.numpy()[0, 0, :, :, :], binary_threshold=0.5) 151 | 152 | with open('/data_zoe1/lucas/Linda_Segmentations/tmp/sdm_results.txt', 'a') as f: 153 | print('Evaluate case: {} - DC:{:.3}, HD:{:.3}, ASSD:{:.3}, Core recon DC:{:.3}, Penu recon DC:{:.3}'.format(case_id, 154 | results.dc, results.hd, results.assd, c_res.dc, p_res.dc), file=f) 155 | 156 | zoomed = ndi.interpolation.zoom(recon_intp.transpose((2, 1, 0)), zoom=(2, 2, 1)) 157 | nib.save(nib.Nifti1Image((zoomed > 0).astype(np.float32), nifph), args.outbasepath + '_' + str(case_id) + '_lesion.nii.gz') 158 | del zoomed 159 | 160 | zoomed = ndi.interpolation.zoom(lesion.cpu().data.numpy().astype(np.int8).transpose((4, 3, 2, 1, 0))[:, :, :, 0, 0], zoom=(2, 2, 1)) 161 | nib.save(nib.Nifti1Image(zoomed, nifph), args.outbasepath + '_' + str(case_id) + '_fuctgt.nii.gz') 162 | del zoomed 163 | 164 | zoomed = ndi.interpolation.zoom(recon_core.transpose((2, 1, 0)), zoom=(2, 2, 1)) 165 | nib.save(nib.Nifti1Image((zoomed < 0).astype(np.float32), nifph), args.outbasepath + '_' + str(case_id) + '_core.nii.gz') 166 | del zoomed 167 | 168 | zoomed = ndi.interpolation.zoom(recon_penu.transpose((2, 1, 0)), zoom=(2, 2, 1)) 169 | nib.save(nib.Nifti1Image((zoomed > 0).astype(np.float32), nifph), args.outbasepath + '_' + str(case_id) + '_penu.nii.gz') 170 | 171 | del nifph 172 | 173 | del sample 174 | 175 | 176 | if __name__ == '__main__': 177 | print(datetime.datetime.now()) 178 | infer() 179 | print(datetime.datetime.now()) 180 | -------------------------------------------------------------------------------- /test_shape_reconstruction.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | from tester.CaeReconstructionTester import CaeReconstructionTester 3 | from common import data, util 4 | 5 | 6 | def test(args): 7 | # Params / Config 8 | modalities = ['_CBV_reg1_downsampled', '_TTD_reg1_downsampled'] 9 | labels = ['_CBVmap_subset_reg1_downsampled', '_TTDmap_subset_reg1_downsampled', 10 | '_FUCT_MAP_T_Samplespace_subset_reg1_downsampled'] 11 | normalization_hours_penumbra = args.normalize 12 | pad = args.padding 13 | pad_value = 0 14 | 15 | for idx in range(len(args.path)): 16 | # Data 17 | transform = 
[data.ResamplePlaneXY(args.xyresample), 18 | data.PadImages(pad[0], pad[1], pad[2], pad_value=pad_value), 19 | data.ToTensor()] 20 | ds_test = data.get_testdata(modalities=modalities, labels=labels, transform=transform, indices=args.fold[idx]) 21 | 22 | print('Size test set:', len(ds_test.sampler.indices), '| # batches:', len(ds_test)) 23 | 24 | # Single case evaluation 25 | tester = CaeReconstructionTester(ds_test, args.path[idx], args.outbasepath, normalization_hours_penumbra) 26 | tester.run_inference() 27 | 28 | 29 | if __name__ == '__main__': 30 | print(datetime.datetime.now()) 31 | test(util.get_args_shape_testing()) 32 | print(datetime.datetime.now()) 33 | -------------------------------------------------------------------------------- /test_shape_reconstruction_CurveAnalysis.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | from tester.CaeReconstructionTesterCurve import CaeReconstructionTesterCurve 3 | from common import data, util 4 | 5 | 6 | def test(): 7 | args = util.get_args_shape_testing() 8 | 9 | assert len(args.fold) == len(args.path), 'You must provide as many --fold arguments as caepath model arguments\ 10 | in the exact same order!' 11 | 12 | # Params / Config 13 | modalities = ['_CBV_reg1_downsampled', '_TTD_reg1_downsampled'] 14 | labels = ['_CBVmap_subset_reg1_downsampled', '_TTDmap_subset_reg1_downsampled', 15 | '_FUCT_MAP_T_Samplespace_subset_reg1_downsampled'] 16 | normalization_hours_penumbra = args.normalize 17 | steps = range(6) # fixed steps for tAdmission-->tReca: 0-5 hrs 18 | pad = args.padding 19 | pad_value = 0 20 | 21 | # Data 22 | transform = [data.ResamplePlaneXY(args.xyresample), 23 | data.PadImages(pad[0], pad[1], pad[2], pad_value=pad_value), 24 | data.ToTensor()] 25 | 26 | # Fold-wise evaluation: each model path given as argument is evaluated on the fold indices given in the same order: 27 | for i, path in enumerate(args.path): 28 | print('Model ' + path + ' of fold ' + str(i+1) + '/' + str(len(args.fold)) + ' with indices: ' + str(args.fold[i])) 29 | ds_test = data.get_testdata(modalities=modalities, labels=labels, transform=transform, indices=args.fold[i]) 30 | print('Size test set:', len(ds_test.sampler.indices), '| # batches:', len(ds_test)) 31 | # Single case evaluation for all cases in fold 32 | tester = CaeReconstructionTesterCurve(ds_test, path, args.outbasepath, normalization_hours_penumbra, steps) 33 | tester.run_inference() 34 | 35 | 36 | if __name__ == '__main__': 37 | print(datetime.datetime.now()) 38 | test() 39 | print(datetime.datetime.now()) 40 | -------------------------------------------------------------------------------- /test_unet_segmentation.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | from tester.UnetSegmentationTester import UnetSegmentationTester 3 | from common.model.Unet3D import Unet3D 4 | from common import data, util 5 | 6 | 7 | def test(args): 8 | 9 | # Params / Config 10 | modalities = ['_CBV_reg1_downsampled', '_TTD_reg1_downsampled'] 11 | labels = ['_CBVmap_subset_reg1_downsampled', '_TTDmap_subset_reg1_downsampled'] 12 | path_saved_model = args.unetpath 13 | pad = args.padding 14 | pad_value = 0 15 | 16 | # Data 17 | # Trained on patches, but the fully convolutional approach lets us apply the net to bigger images (thus, omit patch transform) 18 | transform = [data.ResamplePlaneXY(args.xyresample), 19 | data.PadImages(pad[0], pad[1], pad[2], pad_value=pad_value), 20 | data.ToTensor()] 21 | ds_test =
data.get_testdata(modalities=modalities, labels=labels, transform=transform, indices=args.fold) 22 | 23 | print('Size test set:', len(ds_test.sampler.indices), '| # batches:', len(ds_test)) 24 | 25 | # Single case evaluation 26 | tester = UnetSegmentationTester(ds_test, path_saved_model, args.outbasepath, None) 27 | tester.run_inference() 28 | 29 | 30 | if __name__ == '__main__': 31 | print(datetime.datetime.now()) 32 | args = util.get_args_unet_training() 33 | test(args) 34 | print(datetime.datetime.now()) 35 | -------------------------------------------------------------------------------- /tester/CaeReconstructionTester.py: -------------------------------------------------------------------------------- 1 | from common.inference.CaeInference import CaeInference 2 | from common.dto.CaeDto import CaeDto 3 | from common.dto.MetricMeasuresDto import MetricMeasuresDto 4 | import common.dto.MetricMeasuresDto as MetricMeasuresDtoInit 5 | from common import metrics, data 6 | from tester.Tester import Tester 7 | import scipy.ndimage.interpolation as ndi 8 | import nibabel as nib 9 | import numpy as np 10 | 11 | 12 | class CaeReconstructionTester(Tester, CaeInference): 13 | def __init__(self, dataloader, path_model, path_outputs_base='/tmp/', normalization_hours_penumbra=10): 14 | Tester.__init__(self, dataloader, path_model, path_outputs_base=path_outputs_base) 15 | CaeInference.__init__(self, self._model, normalization_hours_penumbra) 16 | # TODO: This needs some refactoring (double initialization of model, path etc) 17 | 18 | def batch_metrics_step(self, dto: CaeDto): 19 | batch_metrics = MetricMeasuresDtoInit.init_dto() 20 | batch_metrics.lesion = metrics.binary_measures_torch(dto.reconstructions.gtruth.interpolation, 21 | dto.given_variables.gtruth.lesion, self.is_cuda) 22 | batch_metrics.core = metrics.binary_measures_torch(dto.reconstructions.gtruth.core, 23 | dto.given_variables.gtruth.core, self.is_cuda) 24 | batch_metrics.penu = metrics.binary_measures_torch(dto.reconstructions.gtruth.penu, 25 | dto.given_variables.gtruth.penu, self.is_cuda) 26 | return batch_metrics 27 | 28 | def save_inference(self, dto: CaeDto, batch: dict, suffix=''): 29 | case_id = int(batch[data.KEY_CASE_ID]) 30 | # Output results on which metrics have been computed 31 | nifph = nib.load('/share/data_zoe1/lucas/Linda_Segmentations/' + str(case_id) + '/train' + str(case_id) + 32 | '_CBVmap_reg1_downsampled.nii.gz').affine 33 | image = np.transpose(dto.reconstructions.gtruth.core.cpu().data.numpy(), (4, 3, 2, 1, 0))[:, :, :, 0, 0] 34 | nib.save(nib.Nifti1Image(ndi.zoom(image, zoom=(2, 2, 1)), nifph), self._fn(case_id, '_core', suffix)) 35 | 36 | nifph = nib.load('/share/data_zoe1/lucas/Linda_Segmentations/' + str(case_id) + '/train' + str(case_id) + 37 | '_FUCT_MAP_T_Samplespace_reg1_downsampled.nii.gz').affine 38 | image = np.transpose(dto.reconstructions.gtruth.interpolation.cpu().data.numpy(), (4, 3, 2, 1, 0))[:, :, :, 0, 0] 39 | nib.save(nib.Nifti1Image(ndi.zoom(image, zoom=(2, 2, 1)), nifph), self._fn(case_id, '_pred', suffix)) 40 | 41 | nifph = nib.load('/share/data_zoe1/lucas/Linda_Segmentations/' + str(case_id) + '/train' + str(case_id) + 42 | '_TTDmap_reg1_downsampled.nii.gz').affine 43 | image = np.transpose(dto.reconstructions.gtruth.penu.cpu().data.numpy(), (4, 3, 2, 1, 0))[:, :, :, 0, 0] 44 | nib.save(nib.Nifti1Image(ndi.zoom(image, zoom=(2, 2, 1)), nifph), self._fn(case_id, '_penu', suffix)) 45 | 46 | def print_inference(self, batch: dict, batch_metrics: MetricMeasuresDto, dto: CaeDto, note=''): 47 | 
output = 'Case Id={}\ttA-tO={:.3f}\ttR-tA={:.3f}\tnormalized_time_to_treatment={:.3f}\t-->\ 48 | \tDC={:.3f}\tHD={:.3f}\tASSD={:.3f}\tDC Core={:.3f}\tDC Penumbra={:.3f}\t\ 49 | Precision={:.3}\tRecall/Sensitivity={:.3}\tSpecificity={:.3}\tDistToCornerPRC={:.3}\t{}' 50 | print(output.format(int(batch[data.KEY_CASE_ID]), 51 | float(batch[data.KEY_GLOBAL][:, 0, :, :, :]), 52 | float(batch[data.KEY_GLOBAL][:, 1, :, :, :]), 53 | float(dto.given_variables.time_to_treatment), 54 | batch_metrics.lesion.dc, 55 | batch_metrics.lesion.hd, 56 | batch_metrics.lesion.assd, 57 | batch_metrics.core.dc, 58 | batch_metrics.penu.dc, 59 | batch_metrics.lesion.precision, 60 | batch_metrics.lesion.sensitivity, 61 | batch_metrics.lesion.specificity, 62 | batch_metrics.lesion.prc_euclidean_distance, 63 | note)) -------------------------------------------------------------------------------- /tester/CaeReconstructionTesterCurve.py: -------------------------------------------------------------------------------- 1 | import common.data as data 2 | from tester.CaeReconstructionTester import CaeReconstructionTester 3 | 4 | 5 | class CaeReconstructionTesterCurve(CaeReconstructionTester): 6 | def __init__(self, dataloader, path_model, path_outputs_base='/tmp/', normalization_hours_penumbra=10, 7 | ta_to_tr_fixed_hours=range(11), ta_to_tr_relative_steps=[0, 0.25, 0.5, 0.75, 1, 1.25, 1.5, 1.75, 2]): 8 | CaeReconstructionTester.__init__(self, dataloader, path_model, path_outputs_base=path_outputs_base, 9 | normalization_hours_penumbra=normalization_hours_penumbra) 10 | self._steps_fixed = ta_to_tr_fixed_hours 11 | self._steps_relative = ta_to_tr_relative_steps 12 | 13 | def infer_batch(self, batch: dict, step: float): 14 | dto = self.inference_step(batch, step) 15 | batch_metrics = self.batch_metrics_step(dto) 16 | return batch_metrics, dto 17 | 18 | def run_inference(self): 19 | for batch in self._dataloader: 20 | 21 | # 1) Evaluate on ground truth tA-->tR 22 | batch_metrics, dto = self.infer_batch(batch, None) 23 | self.print_inference(batch, batch_metrics, dto) 24 | self.save_inference(dto, batch) 25 | 26 | # 2) Evaluate metrics curve on fixed tA-->tR: 0 .. 
5 hrs 27 | for step in self._steps_fixed: 28 | batch_metrics, dto = self.infer_batch(batch, step) 29 | self.print_inference(batch, batch_metrics, dto, 'ta_to_tr fixed=' + str(step)) 30 | 31 | # 3) Evaluate metrics curve on relative tA-->tR: 32 | ta_to_tr = float(batch[data.KEY_GLOBAL][:, 1, :, :, :]) 33 | for step in self._steps_relative: 34 | batch_metrics, dto = self.infer_batch(batch, step * ta_to_tr) 35 | self.print_inference(batch, batch_metrics, dto, 'ta_to_tr ratio=' + str(step) + '\t(' + str(step * ta_to_tr) + ')') 36 | 37 | # 4) Evaluate metrics curve on uniform interval [0,1] between core/penumbra 38 | to_to_ta = float(batch[data.KEY_GLOBAL][:, 0, :, :, :]) 39 | tr_to_penu = self._normalization_hours_penumbra - to_to_ta 40 | for step in [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]: 41 | batch_metrics, dto = self.infer_batch(batch, step * tr_to_penu) 42 | self.print_inference(batch, batch_metrics, dto, 'tr_to_penumbra=' + str(step) + '\t(' + str(step * tr_to_penu) + ')') -------------------------------------------------------------------------------- /tester/Tester.py: -------------------------------------------------------------------------------- 1 | from common.dto.Dto import Dto 2 | from common.inference.Inference import Inference 3 | from common.dto.MetricMeasuresDto import MetricMeasuresDto 4 | import common.dto.MetricMeasuresDto as MetricMeasuresDtoInit 5 | from torch.utils.data import DataLoader 6 | import torch 7 | 8 | 9 | class Tester(Inference): 10 | """Base class with a standard routine for 11 | a testing procedure. The single steps can 12 | be overridden by subclasses to specify the 13 | procedures required for a specific test run. 14 | """ 15 | 16 | def __init__(self, dataloader: DataLoader, path_model: str, path_outputs_base: str='/tmp/'): 17 | Inference.__init__(self, torch.load(path_model)) 18 | assert dataloader.batch_size == 1, "You must ensure a batch size of 1 for correct case metric measures." 
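        # Metrics and saved outputs are produced per case, so each batch must hold
        # exactly one case. Typical use via a subclass, with illustrative paths
        # (cf. test_unet_segmentation.py):
        #
        #     tester = UnetSegmentationTester(ds_test, '~/tmp/unet_f3.model', '~/tmp/tmp')
        #     tester.run_inference()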
19 | self._dataloader = dataloader 20 | self._path_outputs_base = path_outputs_base 21 | self._model.freeze(True) 22 | self._model.eval() 23 | 24 | def infer_batch(self, batch: dict): 25 | dto = self.inference_step(batch) 26 | batch_metrics = self.batch_metrics_step(dto) 27 | self.save_inference(dto, batch) 28 | return batch_metrics, dto 29 | 30 | def batch_metrics_step(self, dto: Dto): 31 | return MetricMeasuresDtoInit.init_dto() 32 | 33 | def _fn(self, case_id, type, suffix): 34 | return self._path_outputs_base + '_' + str(case_id) + str(type) + str(suffix) + '.nii.gz' 35 | 36 | def save_inference(self, dto: Dto, batch: dict): 37 | pass 38 | 39 | def print_inference(self, batch: dict, metrics: MetricMeasuresDto, dto: Dto = None): 40 | pass 41 | 42 | def run_inference(self): 43 | for batch in self._dataloader: 44 | batch_metrics, dto = self.infer_batch(batch) 45 | self.print_inference(batch, batch_metrics, dto) 46 | -------------------------------------------------------------------------------- /tester/UnetSegmentationTester.py: -------------------------------------------------------------------------------- 1 | from tester.Tester import Tester 2 | from common.inference.UnetInference import UnetInference 3 | from common.dto.UnetDto import UnetDto 4 | from common.dto.MetricMeasuresDto import MetricMeasuresDto 5 | import common.dto.MetricMeasuresDto as MetricMeasuresDtoInit 6 | from common import data, metrics 7 | import nibabel as nib 8 | import numpy as np 9 | import scipy.ndimage.interpolation as ndi 10 | 11 | 12 | class UnetSegmentationTester(Tester, UnetInference): 13 | def __init__(self, dataloader, path_model, path_outputs_base='/tmp/', padding=None): 14 | Tester.__init__(self, dataloader, path_model, path_outputs_base=path_outputs_base) 15 | self._pad = padding 16 | 17 | def batch_metrics_step(self, dto: UnetDto): 18 | batch_metrics = MetricMeasuresDtoInit.init_dto() 19 | batch_metrics.core = metrics.binary_measures_torch(dto.outputs.core, 20 | dto.given_variables.core, self.is_cuda) 21 | batch_metrics.penu = metrics.binary_measures_torch(dto.outputs.penu, 22 | dto.given_variables.penu, self.is_cuda) 23 | return batch_metrics 24 | 25 | def _transpose_unpad_zoom(self, image): 26 | image = np.transpose(image, (4, 3, 2, 1, 0)) 27 | if self._pad is not None: 28 | image = image[self._pad[0]:-self._pad[0], self._pad[1]:-self._pad[1], self._pad[2]:-self._pad[2], :, :] 29 | return ndi.zoom(image[:, :, :, 0, 0], zoom=(2, 2, 1)) 30 | 31 | def save_inference(self, dto: UnetDto, batch: dict, suffix=''): 32 | case_id = int(batch[data.KEY_CASE_ID]) 33 | # Output the results on which metrics have been computed 34 | nifph = nib.load('/share/data_zoe1/lucas/Linda_Segmentations/' + str(case_id) + '/train' + str(case_id) + 35 | '_TTDmap_reg1_downsampled.nii.gz').affine 36 | core = self._transpose_unpad_zoom(dto.outputs.core.cpu().data.numpy()) 37 | nib.save(nib.Nifti1Image(core, nifph), self._fn(case_id, '_core', suffix)) 38 | penu = self._transpose_unpad_zoom(dto.outputs.penu.cpu().data.numpy()) 39 | nib.save(nib.Nifti1Image(penu, nifph), self._fn(case_id, '_penu', suffix)) 40 | 41 | def print_inference(self, batch: dict, batch_metrics: MetricMeasuresDto, dto: UnetDto): 42 | output = 'Case Id {}:\t DC Core:{:.3},\tDC Penumbra:{:.3}' 43 | print(output.format(int(batch[data.KEY_CASE_ID]), 44 | batch_metrics.core.dc, 45 | batch_metrics.penu.dc)) 46 | -------------------------------------------------------------------------------- /train_interpolationstep_after_reconstruction.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import datetime 3 | from learner.CaeStepLearner import CaeStepLearner 4 | from common.model.Cae3D import Cae3D, Enc3DStep 5 | from common import data, util, metrics 6 | 7 | 8 | def train(args): 9 | # Params / Config 10 | learning_rate = 1e-3 11 | momentums_cae = (0.9, 0.999) 12 | weight_decay = 1e-5 13 | criterion = metrics.BatchDiceLoss([1.0]) # nn.BCELoss() 14 | channels_cae = args.channelscae 15 | n_globals = args.globals # type(core/penu), tO_to_tA, NIHSS, sex, age 16 | resample_size = int(args.xyoriginal * args.xyresample) 17 | alpha = 1.0 18 | cuda = True 19 | 20 | # CAE model 21 | cae = torch.load(args.caepath) 22 | cae.freeze(True) 23 | enc = Enc3DStep(size_input_xy=resample_size, size_input_z=args.zsize, 24 | channels=channels_cae, n_ch_global=n_globals, alpha=alpha) 25 | enc.encoder = cae.enc.encoder # enc.step will be trained from scratch for the given shape representations 26 | dec = cae.dec 27 | cae = Cae3D(enc, dec) 28 | 29 | if cuda: 30 | cae = cae.cuda() 31 | 32 | # Model params 33 | params = [p for p in cae.parameters() if p.requires_grad] 34 | print('# optimizing params', sum([p.nelement() * p.requires_grad for p in params]), 35 | '/ total: cae', sum([p.nelement() for p in cae.parameters()])) 36 | 37 | # Optimizer with scheduler 38 | optimizer = torch.optim.Adam(params, lr=learning_rate, weight_decay=weight_decay, betas=momentums_cae) 39 | if args.lrsteps: 40 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, args.lrsteps) 41 | else: 42 | scheduler = None 43 | 44 | # Data 45 | common_transform = [data.ResamplePlaneXY(args.xyresample)] # before: also data.HemisphericFlipFixedToCaseId(split_id=args.hemisflipid) 46 | train_transform = common_transform + [data.HemisphericFlip(), data.ElasticDeform(), data.ToTensor()] 47 | valid_transform = common_transform + [data.ToTensor()] 48 | 49 | modalities = ['_CBV_reg1_downsampled', 50 | '_TTD_reg1_downsampled'] # dummy data only needed for visualization 51 | labels = ['_CBVmap_subset_reg1_downsampled', 52 | '_TTDmap_subset_reg1_downsampled', 53 | '_FUCT_MAP_T_Samplespace_subset_reg1_downsampled'] 54 | 55 | ds_train, ds_valid = data.get_stroke_shape_training_data(modalities, labels, train_transform, valid_transform, 56 | args.fold, args.validsetsize, batchsize=args.batchsize) 57 | print('Size training set:', len(ds_train.sampler.indices), 58 | 'samples | Size validation set:', len(ds_valid.sampler.indices), 59 | 'samples | Capacity batch:', args.batchsize, 'samples') 60 | print('# training batches:', len(ds_train), 61 | '| # validation batches:', len(ds_valid)) 62 | 63 | # Training 64 | learner = CaeStepLearner(ds_train, ds_valid, cae, optimizer, scheduler, 65 | n_epochs=args.epochs, 66 | path_previous_base=args.inbasepath, 67 | path_outputs_base=args.outbasepath, 68 | criterion=criterion) 69 | learner.run_training() 70 | 71 | 72 | if __name__ == '__main__': 73 | print(datetime.datetime.now()) 74 | args = util.get_args_step_training() 75 | train(args) 76 | print(datetime.datetime.now()) 77 | --------------------------------------------------------------------------------
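All training scripts use metrics.BatchDiceLoss as criterion (metrics.py is not reproduced here). A soft binary Dice loss over BxCxDxHxW probability maps is conventionally computed as below — a minimal sketch for orientation only, not the repository's implementation; the list argument of BatchDiceLoss suggests per-label weighting, which this sketch omits:

# sketch of a soft Dice loss; assumes predictions are probabilities in [0, 1]
import torch.nn as nn

class SoftDiceLossSketch(nn.Module):
    def __init__(self, smooth=1e-7):
        super(SoftDiceLossSketch, self).__init__()
        self.smooth = smooth  # avoids division by zero for empty masks

    def forward(self, pred, target):
        pred = pred.contiguous().view(pred.size(0), -1)      # flatten per batch item
        target = target.contiguous().view(target.size(0), -1)
        intersection = (pred * target).sum(dim=1)
        union = pred.sum(dim=1) + target.sum(dim=1)
        dice = (2.0 * intersection + self.smooth) / (union + self.smooth)
        return 1.0 - dice.mean()                             # minimize (1 - Dice)

/train_shape_prediction.py: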
-------------------------------------------------------------------------------- 1 | import torch 2 | import datetime 3 | from learner.CaePredictionLearner import CaePredictionLearner 4 | from common import data, util, metrics 5 | from common.model.Cae3D import Enc3D 6 | 7 | 8 | def train(args): 9 | # Params / Config 10 | learning_rate = 1e-3 11 | momentums_cae = (0.9, 0.999) 12 | weight_decay = 1e-5 13 | criterion = metrics.BatchDiceLoss([1.0]) # nn.BCELoss() 14 | resample_size = int(args.xyoriginal * args.xyresample) 15 | n_globals = args.globals # type(core/penu), tO_to_tA, NIHSS, sex, age 16 | channels_enc = args.channelsenc 17 | alpha = 1.0 18 | cuda = True 19 | 20 | # TODO assert initbycae XOR channels_enc 21 | 22 | # CAE model 23 | path_saved_model = args.caepath 24 | cae = torch.load(path_saved_model) 25 | cae.freeze(True) 26 | if args.initbycae: 27 | enc = torch.load(path_saved_model).enc 28 | else: 29 | enc = Enc3D(size_input_xy=resample_size, size_input_z=args.zsize, 30 | channels=channels_enc, n_ch_global=n_globals, alpha=alpha) 31 | 32 | if cuda: 33 | cae = cae.cuda() 34 | enc = enc.cuda() 35 | 36 | # Model params 37 | params = [p for p in enc.parameters() if p.requires_grad] 38 | print('# optimizing params', sum([p.nelement() * p.requires_grad for p in params]), 39 | '/ total new enc + old cae', sum([p.nelement() for p in enc.parameters()] + [p.nelement() for p in cae.parameters()])) 40 | 41 | # Optimizer with scheduler 42 | optimizer = torch.optim.Adam(params, lr=learning_rate, weight_decay=weight_decay, betas=momentums_cae) 43 | if args.lrsteps: 44 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, args.lrsteps) 45 | else: 46 | scheduler = None 47 | 48 | # Data 49 | common_transform = [data.ResamplePlaneXY(args.xyresample), 50 | data.HemisphericFlipFixedToCaseId(split_id=args.hemisflipid)] 51 | train_transform = common_transform + [data.ElasticDeform(apply_to_images=True), data.ToTensor()] 52 | valid_transform = common_transform + [data.ToTensor()] 53 | modalities = ['_unet_core', '_unet_penu'] 54 | labels = ['_CBVmap_subset_reg1_downsampled', '_TTDmap_subset_reg1_downsampled', 55 | '_FUCT_MAP_T_Samplespace_subset_reg1_downsampled'] 56 | ds_train, ds_valid = data.get_stroke_prediction_training_data(modalities, labels, train_transform, valid_transform, 57 | args.fold, args.validsetsize, batchsize=args.batchsize) 58 | print('Size training set:', len(ds_train.sampler.indices), 'samples | Size validation set:', len(ds_valid.sampler.indices), 59 | 'samples | Capacity batch:', args.batchsize, 'samples') 60 | print('# training batches:', len(ds_train), '| # validation batches:', len(ds_valid)) 61 | 62 | # Training 63 | learner = CaePredictionLearner(ds_train, ds_valid, cae, enc, optimizer, scheduler, 64 | n_epochs=args.epochs, 65 | path_previous_base=args.inbasepath, 66 | path_outputs_base=args.outbasepath, 67 | criterion=criterion) 68 | learner.run_training() 69 | 70 | 71 | if __name__ == '__main__': 72 | print(datetime.datetime.now()) 73 | args = util.get_args_shape_prediction_training() 74 | train(args) 75 | print(datetime.datetime.now()) 76 | --------------------------------------------------------------------------------
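The TODO in train_shape_prediction.py above asks for a mutual-exclusion check of --initbycae and --channelsenc. A possible realization, assuming --initbycae is a boolean flag and --channelsenc is None (or empty) when unset:

# possible realization of '# TODO assert initbycae XOR channels_enc'
# (assumptions as stated above; not part of the repository)
assert bool(args.initbycae) != bool(args.channelsenc), \
    'Specify either --initbycae or --channelsenc, but not both.'

/train_shape_reconstruction.py: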
-------------------------------------------------------------------------------- 1 | import torch 2 | import datetime 3 | from learner.CaeReconstructionLearner import CaeReconstructionLearner 4 | from common.model.Cae3D import Cae3D, Enc3D, Enc3DStep, Dec3D 5 | from common import data, util, metrics 6 | 7 | 8 | def train(args): 9 | # Params / Config 10 | use_validation = not args.steplearning 11 | learning_rate = 1e-3 12 | momentums_cae = (0.9, 0.999) 13 | weight_decay = 1e-5 14 | criterion = metrics.BatchDiceLoss([1.0]) # nn.BCELoss() 15 | channels_cae = args.channelscae 16 | n_globals = args.globals # type(core/penu), tO_to_tA, NIHSS, sex, age 17 | resample_size = int(args.xyoriginal * args.xyresample) 18 | alpha = 1.0 19 | cuda = True 20 | 21 | # CAE model 22 | if args.steplearning: 23 | enc = Enc3DStep(size_input_xy=resample_size, size_input_z=args.zsize, 24 | channels=channels_cae, n_ch_global=n_globals, alpha=alpha) 25 | else: 26 | enc = Enc3D(size_input_xy=resample_size, size_input_z=args.zsize, 27 | channels=channels_cae, n_ch_global=n_globals, alpha=alpha) 28 | dec = Dec3D(size_input_xy=resample_size, size_input_z=args.zsize, 29 | channels=channels_cae, n_ch_global=n_globals, alpha=alpha) 30 | cae = Cae3D(enc, dec) 31 | if cuda: 32 | cae = cae.cuda() 33 | 34 | # Model params 35 | params = [p for p in cae.parameters() if p.requires_grad] 36 | print('# optimizing params', sum([p.nelement() * p.requires_grad for p in params]), 37 | '/ total: cae', sum([p.nelement() for p in cae.parameters()])) 38 | 39 | # Optimizer with scheduler 40 | optimizer = torch.optim.Adam(params, lr=learning_rate, weight_decay=weight_decay, betas=momentums_cae) 41 | if args.lrsteps: 42 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, args.lrsteps) 43 | else: 44 | scheduler = None 45 | 46 | # Data 47 | common_transform = [data.ResamplePlaneXY(args.xyresample)] # before: also data.HemisphericFlipFixedToCaseId(split_id=args.hemisflipid) 48 | train_transform = common_transform + [data.HemisphericFlip(), data.ElasticDeform(), data.ToTensor()] 49 | valid_transform = common_transform + [data.ToTensor()] 50 | 51 | modalities = ['_CBV_reg1_downsampled', '_TTD_reg1_downsampled'] # dummy data only needed for visualization 52 | labels = ['_CBVmap_subset_reg1_downsampled', '_TTDmap_subset_reg1_downsampled', 53 | '_FUCT_MAP_T_Samplespace_subset_reg1_downsampled'] 54 | 55 | ds_train, ds_valid = data.get_stroke_shape_training_data(modalities, labels, train_transform, valid_transform, 56 | args.fold, args.validsetsize, batchsize=args.batchsize, split=use_validation) 57 | if use_validation: 58 | print('Size training set:', len(ds_train.sampler.indices), 'samples | Size validation set:', len(ds_valid.sampler.indices), 59 | 'samples | Capacity batch:', args.batchsize, 'samples') 60 | print('# training batches:', len(ds_train), '| # validation batches:', len(ds_valid)) 61 | else: 62 | print('Size training set:', len(ds_train.sampler.indices), 63 | 'samples | Size validation set: 0 samples | Capacity batch:', args.batchsize, 'samples') 64 | print('# training batches:', len(ds_train), '| # validation batches:', 0) 65 | 66 | # Training 67 | learner = CaeReconstructionLearner(ds_train, ds_valid, cae, optimizer, scheduler, 68 | n_epochs=args.epochs, 69 | path_previous_base=args.inbasepath, 70 | path_outputs_base=args.outbasepath, 71 | criterion=criterion) 72 | learner.run_training() 73 | 74 | 75 | if __name__ == '__main__': 76 | print(datetime.datetime.now()) 77 | args = util.get_args_shape_training() 78 | train(args) 79 | print(datetime.datetime.now()) 80 | --------------------------------------------------------------------------------
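A note on the scheduler shared by all scripts above: torch.optim.lr_scheduler.MultiStepLR decays the learning rate by its default factor gamma=0.1 at each epoch listed in --lrsteps (the Learner is presumably responsible for calling scheduler.step() once per epoch). For illustration, with milestones at epochs 100 and 200 and the initial lr of 1e-3:

# illustrative schedule (milestone values chosen arbitrarily here):
#   epochs   0- 99: lr = 1e-3
#   epochs 100-199: lr = 1e-4
#   epochs 200-...: lr = 1e-5
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 200], gamma=0.1)

/train_shape_reconstruction_with_ctp.py: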
-------------------------------------------------------------------------------- 1 | import torch 2 | import datetime 3 | from learner.CaeReconstructionLearner import CaeReconstructionLearner 4 | from common.model.Cae3D import Cae3DCtp, Enc3DCtp, Dec3D 5 | from common import data, util, metrics 6 | 7 | 8 | def train(): 9 | args = util.get_args_shape_training() 10 | 11 | # Params / Config 12 | learning_rate = 1e-3 13 | momentums_cae = (0.99, 0.999) 14 | criterion = metrics.BatchDiceLoss([1.0]) # nn.BCELoss(); note: defined but not passed to the Learner below 15 | path_training_metrics = args.continuetraining # --continuetraining /share/data_zoe1/lucas/Linda_Segmentations/tmp/tmp_shape_f3.json; assigned but unused below 16 | path_saved_model = args.caepath 17 | channels_cae = args.channelscae 18 | n_globals = args.globals # type(core/penu), tO_to_tA, NIHSS, sex, age 19 | resample_size = int(args.xyoriginal * args.xyresample) 20 | pad = args.padding 21 | pad_value = 0 22 | leakage = 0.01 23 | cuda = True 24 | 25 | # CAE model 26 | enc = Enc3DCtp(size_input_xy=resample_size, size_input_z=args.zsize, 27 | channels=channels_cae, n_ch_global=n_globals, leakage=leakage, padding=pad) 28 | dec = Dec3D(size_input_xy=resample_size, size_input_z=args.zsize, 29 | channels=channels_cae, n_ch_global=n_globals, leakage=leakage) 30 | cae = Cae3DCtp(enc, dec) 31 | if cuda: 32 | cae = cae.cuda() 33 | 34 | # Model params 35 | params = [p for p in cae.parameters() if p.requires_grad] 36 | print('# optimizing params', sum([p.nelement() * p.requires_grad for p in params]), 37 | '/ total: cae', sum([p.nelement() for p in cae.parameters()])) 38 | 39 | # Optimizer with scheduler 40 | optimizer = torch.optim.Adam(params, lr=learning_rate, weight_decay=1e-5, betas=momentums_cae) 41 | if args.lrsteps: 42 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, args.lrsteps) 43 | else: 44 | scheduler = None 45 | 46 | # Data 47 | common_transform = [data.ResamplePlaneXY(args.xyresample), 48 | data.HemisphericFlipFixedToCaseId(split_id=args.hemisflipid), 49 | data.PadImages(pad[0], pad[1], pad[2], pad_value=pad_value)] 50 | train_transform = common_transform + [data.ElasticDeform(), data.ToTensor()] 51 | valid_transform = common_transform + [data.ToTensor()] 52 | modalities = ['_CBV_reg1_downsampled', '_TTD_reg1_downsampled'] 53 | labels = ['_CBVmap_subset_reg1_downsampled', '_TTDmap_subset_reg1_downsampled', 54 | '_FUCT_MAP_T_Samplespace_subset_reg1_downsampled'] 55 | ds_train, ds_valid = data.get_stroke_shape_training_data(modalities, labels, train_transform, valid_transform, 56 | args.fold, args.validsetsize, batchsize=args.batchsize) 57 | print('Size training set:', len(ds_train.sampler.indices), 'samples | Size validation set:', len(ds_valid.sampler.indices), 58 | 'samples | Capacity batch:', args.batchsize, 'samples') 59 | print('# training batches:', len(ds_train), '| # validation batches:', len(ds_valid)) 60 | 61 | # Training 62 | learner = CaeReconstructionLearner(ds_train, ds_valid, cae, path_saved_model, optimizer, scheduler, 63 | path_outputs_base=args.outbasepath) 64 | learner.run_training() 65 | 66 | 67 | if __name__ == '__main__': 68 | print(datetime.datetime.now()) 69 | train() 70 | print(datetime.datetime.now()) 71 | --------------------------------------------------------------------------------
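Every script above repeats the same parameter-count print. A tiny helper along these lines (hypothetical — not part of util.py) would deduplicate it:

# hypothetical helper, not part of the repository
def count_params(module, only_trainable=True):
    # counts trainable parameters by default; pass only_trainable=False for the total
    return sum(p.nelement() for p in module.parameters() if p.requires_grad or not only_trainable)

# usage equivalent to the repeated print statements:
# print('# optimizing params', count_params(cae), '/ total: cae', count_params(cae, only_trainable=False))

/train_unet_segmentation.py: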
-------------------------------------------------------------------------------- 1 | import torch 2 | import datetime 3 | from learner.UnetSegmentationLearner import UnetSegmentationLearner 4 | from common.model.Unet3D import Unet3D 5 | from common import data, util, metrics 6 | 7 | 8 | def train(): 9 | args = util.get_args_unet_training() 10 | 11 | # Params / Config 12 | batchsize = 6 # 17 training, 6 validation 13 | learning_rate = 1e-3 14 | momentums_unet = (0.99, 0.999) 15 | criterion = metrics.BatchDiceLoss([1.0]) # nn.BCELoss() 16 | path_saved_model = args.unetpath 17 | channels = args.channels 18 | pad = args.padding 19 | cuda = True 20 | 21 | # Unet model 22 | unet = Unet3D(channels) 23 | if cuda: 24 | unet = unet.cuda() 25 | 26 | # Model params 27 | params = [p for p in unet.parameters() if p.requires_grad] 28 | print('# optimizing params', sum([p.nelement() * p.requires_grad for p in params]), 29 | '/ total: unet', sum([p.nelement() for p in unet.parameters()])) 30 | 31 | # Optimizer with scheduler 32 | optimizer = torch.optim.Adam(params, lr=learning_rate, weight_decay=1e-5, betas=momentums_unet) 33 | if args.lrsteps: 34 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, args.lrsteps) 35 | else: 36 | scheduler = None 37 | 38 | # Data (note: valid_transform is identical to train_transform, incl. RandomPatch) 39 | train_transform = [data.ResamplePlaneXY(args.xyresample), 40 | data.HemisphericFlipFixedToCaseId(split_id=args.hemisflipid), 41 | data.PadImages(pad[0], pad[1], pad[2], pad_value=0), 42 | data.RandomPatch(104, 104, 68, pad[0], pad[1], pad[2]), 43 | data.ToTensor()] 44 | valid_transform = [data.ResamplePlaneXY(args.xyresample), 45 | data.HemisphericFlipFixedToCaseId(split_id=args.hemisflipid), 46 | data.PadImages(pad[0], pad[1], pad[2], pad_value=0), 47 | data.RandomPatch(104, 104, 68, pad[0], pad[1], pad[2]), 48 | data.ToTensor()] 49 | ds_train, ds_valid = data.get_stroke_shape_training_data(train_transform, valid_transform, args.fold, 50 | args.validsetsize, batchsize=batchsize) # note: no modalities/labels arguments here, unlike the CAE scripts 51 | print('Size training set:', len(ds_train.sampler.indices), 'samples | Size validation set:', len(ds_valid.sampler.indices), 52 | 'samples | Capacity batch:', batchsize, 'samples') 53 | print('# training batches:', len(ds_train), '| # validation batches:', len(ds_valid)) 54 | 55 | # Training 56 | learner = UnetSegmentationLearner(ds_train, ds_valid, unet, path_saved_model, optimizer, scheduler, criterion, 57 | path_previous_base=args.inbasepath, path_outputs_base=args.outbasepath) 58 | learner.run_training() 59 | 60 | 61 | if __name__ == '__main__': 62 | print(datetime.datetime.now()) 63 | train() 64 | print(datetime.datetime.now()) 65 | --------------------------------------------------------------------------------
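Closing note: the *.model files written during training store the entire module (cf. torch.load in tester/Tester.py), so they can be loaded for standalone inference without reconstructing the architecture first. A minimal sketch in PyTorch 0.3 idioms; the path and input sizes are assumptions for illustration:

# standalone inference sketch (PyTorch 0.3: Variable with volatile=True for inference)
import torch
from torch.autograd import Variable

model = torch.load('/tmp/unet_f3.model')  # example path; *.model stores the whole module
model = model.cuda().eval()
x = Variable(torch.rand(1, 2, 28, 128, 128).cuda(), volatile=True)  # BxCxDxHxW dummy input; sizes assumed
output = model(x)  # structure of the output depends on the model class (cf. common/dto)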