├── src └── ovseg │ ├── HPO │ └── __init__.py │ ├── run │ ├── __init__.py │ └── run_inference.py │ ├── data │ ├── __init__.py │ ├── SegmentationDataV2.py │ ├── SegmentationData.py │ ├── DataBase.py │ ├── utils.py │ ├── SegmentationDataloader.py │ └── SegmentationDoubleBiasDataloader.py │ ├── model │ ├── __init__.py │ ├── SegmentationEnsembleV2.py │ ├── SegmentationModelV2.py │ ├── ClaraWrappers.py │ └── SegmentationEnsemble.py │ ├── training │ ├── __init__.py │ ├── TrainingBase.py │ └── loss_functions_combined.py │ ├── utils │ ├── __init__.py │ ├── path_utils.py │ ├── torch_np_utils.py │ ├── download_pretrained_utils.py │ ├── torch_morph.py │ ├── dict_equal.py │ └── label_utils.py │ ├── augmentation │ ├── __init__.py │ ├── SegmentationAugmentation.py │ ├── ConcatenatedAugmentation.py │ ├── AffineAugmentation.py │ ├── GridAugmentation.py │ ├── MaskAugmentation.py │ └── myRandAugment.py │ ├── prediction │ ├── __init__.py │ └── SlidingWindowPrediction.py │ ├── preprocessing │ └── __init__.py │ ├── postprocessing │ └── __init__.py │ ├── networks │ ├── __init__.py │ └── custom_normalization.py │ └── __init__.py ├── ovseg_manual.pdf ├── nifti_in_itk_snap_example.png ├── example_scripts ├── convert_kits21_to_raw_data.py ├── example_kits21_kidneys_low.py ├── my_preprocessing.py ├── example_kits21_kidneys_cascade_full.py ├── example_kits21_masses_deep_supervision.py ├── my_training.py └── train_my_cascade.py ├── LICENSE ├── setup.py ├── preprocess_ovarian_data.py ├── .gitignore ├── test_global_metrics.py ├── run_training.py ├── plan_and_preprocess.py └── README.md /src/ovseg/HPO/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from . import * -------------------------------------------------------------------------------- /src/ovseg/run/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from . import * -------------------------------------------------------------------------------- /ovseg_manual.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ThomasBudd/ovseg/HEAD/ovseg_manual.pdf -------------------------------------------------------------------------------- /src/ovseg/data/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from . import * -------------------------------------------------------------------------------- /src/ovseg/model/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from . import * -------------------------------------------------------------------------------- /src/ovseg/training/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from . import * -------------------------------------------------------------------------------- /src/ovseg/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from . import * -------------------------------------------------------------------------------- /src/ovseg/augmentation/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from . 
import * -------------------------------------------------------------------------------- /src/ovseg/prediction/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from . import * -------------------------------------------------------------------------------- /src/ovseg/preprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from . import * -------------------------------------------------------------------------------- /src/ovseg/postprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from . import * -------------------------------------------------------------------------------- /nifti_in_itk_snap_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ThomasBudd/ovseg/HEAD/nifti_in_itk_snap_example.png -------------------------------------------------------------------------------- /src/ovseg/networks/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from . import * 3 | from ovseg.networks.UNet import UNet 4 | from ovseg.networks.resUNet import UNetResEncoder, UNetResEncoderV2, UNetResDecoder, UResNet, UNetResStemEncoder, UNetResShuffleEncoder 5 | -------------------------------------------------------------------------------- /src/ovseg/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from . import * 3 | import os 4 | 5 | if 'OV_DATA_BASE' not in os.environ: 6 | OV_DATA_BASE = os.path.join(os.path.dirname(__file__), 'ov_data_base') 7 | print('No environment variable OV_DATA_BASE specified. Model weights, ' 8 | 'the resulting segmentations etc. 
will be stored at ' 9 | f'{OV_DATA_BASE}.') 10 | os.environ['OV_DATA_BASE'] = OV_DATA_BASE 11 | else: 12 | OV_DATA_BASE = os.environ['OV_DATA_BASE'] 13 | 14 | os.makedirs(OV_DATA_BASE, exist_ok=True) 15 | 16 | if 'OV_PREPROCESSED' in os.environ: 17 | OV_PREPROCESSED = os.environ['OV_PREPROCESSED'] 18 | else: 19 | OV_PREPROCESSED = os.path.join(OV_DATA_BASE, 'preprocessed') 20 | -------------------------------------------------------------------------------- /example_scripts/convert_kits21_to_raw_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import nibabel as nib 3 | import matplotlib.pyplot as plt 4 | import shutil 5 | from tqdm import tqdm 6 | 7 | dp = 'PATH_TO_KITS_DATA_FOLDER' # should be .../kits21/kits21/data 8 | tp = os.path.join(os.environ['OV_DATA_BASE'], 9 | 'raw_data', 'kits21') 10 | 11 | timp = os.path.join(tp, 'images') 12 | tlbp = os.path.join(tp, 'labels') 13 | 14 | for p in [timp, tlbp]: 15 | if not os.path.exists(p): 16 | os.makedirs(p) 17 | 18 | for case in tqdm(os.listdir(dp)): 19 | 20 | if os.path.exists(os.path.join(dp, case, 'aggregated_MAJ_seg.nii.gz')): 21 | shutil.copy(os.path.join(dp, case, 'imaging.nii.gz'), 22 | os.path.join(timp, case+'.nii.gz')) 23 | shutil.copy(os.path.join(dp, case, 'aggregated_MAJ_seg.nii.gz'), 24 | os.path.join(tlbp, case+'.nii.gz')) 25 | -------------------------------------------------------------------------------- /src/ovseg/utils/path_utils.py: -------------------------------------------------------------------------------- 1 | from os.path import join, exists, split, isdir 2 | from os import listdir, mkdir, sep 3 | 4 | 5 | def maybe_create_path(path): 6 | 7 | if path: 8 | counter = 0 9 | subfs = [] 10 | bp = path 11 | 12 | while (not exists(bp)) and counter < 100: 13 | if bp.find(sep) >= 0: 14 | bp, subf = split(bp) 15 | subfs.append(subf) 16 | else: 17 | break 18 | 19 | if len(subfs) > 0: 20 | 21 | for subf in subfs[::-1]: 22 | mkdir(join(bp, subf)) 23 | bp = join(bp, subf) 24 | else: 25 | if not exists(bp): 26 | mkdir(bp) 27 | 28 | 29 | def my_listdir(path, return_pathes=False): 30 | content = listdir(path) 31 | content.sort() 32 | if return_pathes: 33 | return [join(path, cont) for cont in content] 34 | else: 35 | return content 36 | -------------------------------------------------------------------------------- /src/ovseg/augmentation/SegmentationAugmentation.py: -------------------------------------------------------------------------------- 1 | from ovseg.augmentation.ConcatenatedAugmentation import torch_concatenated_augmentation, \ 2 | np_concatenated_augmentation 3 | 4 | 5 | class SegmentationAugmentation(object): 6 | ''' 7 | SegmentationAugmentation(...) 
8 | 9 | Performs spatial and gray value augmentations 10 | ''' 11 | 12 | def __init__(self, torch_params={}, np_params={}): 13 | 14 | self.torch_params = torch_params 15 | self.np_params = np_params 16 | 17 | self.torch_augmentation = torch_concatenated_augmentation(self.torch_params) 18 | if self.np_params == {}: 19 | self.np_augmentation = None 20 | else: 21 | self.np_augmentation = np_concatenated_augmentation(self.np_params) 22 | 23 | def update_prg_trn(self, param_dict, h): 24 | 25 | self.torch_augmentation.update_prg_trn(param_dict, h) 26 | if self.np_augmentation is not None: 27 | self.np_augmentation.update_prg_trn(param_dict, h) 28 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 ThomasBudd 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | import os 3 | 4 | setup( 5 | # Needed to silence warnings (and to be a worthwhile package) 6 | name='ovseg', 7 | url='https://github.com/ThomasBudd/ovseg', 8 | author='Thomas Buddenkotte', 9 | author_email='thomasbuddenkotte@googlemail.com', 10 | # Needed to actually package something 11 | packages=find_packages('src'), 12 | package_dir={'': 'src'}, 13 | # Needed for dependencies 14 | install_requires=[ 15 | "torch>=1.7.0", 16 | "tqdm", 17 | "scikit-image>=0.14", 18 | "scipy", 19 | "numpy", 20 | "nibabel", 21 | "rt_utils" 22 | ], 23 | entry_points={'console_scripts': ['ovseg_inference = ovseg.run.run_inference:main']}, 24 | # *strongly* suggested for sharing 25 | version='1.0', 26 | # The license can be anything you like 27 | license='MIT', 28 | description='A deep learning based library for ovarian cancer segmentation', 29 | # We will also need a readme eventually (there will be a warning) 30 | # long_description=open('README.txt').read(), 31 | ) 32 | -------------------------------------------------------------------------------- /src/ovseg/utils/torch_np_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def stack(items, axis=0): 6 | # wraps np.stack and torch.stack 7 | if isinstance(items[0], np.ndarray): 8 | # numpy stacking 9 | return np.stack(items, axis=axis) 10 | elif torch.is_tensor(items[0]): 11 | # torch stacking 12 | return torch.stack(items, dim=axis) 13 | else: 14 | # error 15 | raise ValueError('Input of stack must be np.ndarray or torch.tensor.') 16 | 17 | 18 | def check_type(inpt): 19 | is_np = isinstance(inpt, np.ndarray) 20 | is_torch = torch.is_tensor(inpt) 21 | if not is_np and not is_torch: 22 | raise TypeError('Expected input to be np.ndarray or torch.tensor. 
' 23 | 'Got {}'.format(type(inpt))) 24 | return is_np, is_torch 25 | 26 | 27 | def maybe_add_channel_dim(inpt): 28 | 29 | is_np, _ = check_type(inpt) 30 | if len(inpt.shape) == 3: 31 | if is_np: 32 | return inpt[np.newaxis] 33 | else: 34 | return inpt.unsqueeze(0) 35 | elif len(inpt.shape) == 4: 36 | return inpt 37 | else: 38 | raise ValueError('Expected input to be 3d or 4d, got {}d'.format(len(inpt.shape))) 39 | -------------------------------------------------------------------------------- /preprocess_ovarian_data.py: -------------------------------------------------------------------------------- 1 | from ovseg.preprocessing.SegmentationPreprocessing import SegmentationPreprocessing 2 | from ovseg import OV_PREPROCESSED 3 | import os 4 | import argparse 5 | 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument("raw_data", default='OV04', nargs='+',) 8 | args = parser.parse_args() 9 | 10 | raw_data = args.raw_data 11 | data_name = '_'.join(sorted(raw_data)) 12 | 13 | lb_classes_list = [[1, 9], [1, 2, 3, 5, 6, 7], [13, 14, 15, 17]] 14 | p_name_list = ['pod_om', 'abdominal_lesions', 'lymph_nodes'] 15 | target_spacing_list = [[5.0, 0.8, 0.8], [5.0, 0.8, 0.8], [5.0, 0.67, 0.67]] 16 | 17 | for lb_classes, p_name, target_spacing in zip(lb_classes_list, 18 | p_name_list, 19 | target_spacing_list): 20 | prep = SegmentationPreprocessing(apply_resizing=True, 21 | apply_pooling=False, 22 | apply_windowing=True, 23 | lb_classes=lb_classes, 24 | target_spacing=target_spacing, 25 | save_only_fg_scans=False) 26 | 27 | prep.plan_preprocessing_raw_data(raw_data) 28 | prep.preprocess_raw_data(raw_data, p_name) 29 | 30 | print('Converted the following datasets:') 31 | print(raw_data) 32 | print('The preprocessed data is stored under the name '+data_name) 33 | print(data_name) -------------------------------------------------------------------------------- /src/ovseg/data/SegmentationDataV2.py: -------------------------------------------------------------------------------- 1 | from ovseg.data.DataBase import DataBase 2 | from ovseg.data.SegmentationDataloaderV2 import SegmentationDataloaderV2 3 | 4 | 5 | class SegmentationDataV2(DataBase): 6 | 7 | def __init__(self, augmentation=None, use_double_bias=False, *args, **kwargs): 8 | self.augmentation = augmentation 9 | self.use_double_bias = use_double_bias 10 | super().__init__(*args, **kwargs) 11 | 12 | def initialise_dataloader(self, is_train): 13 | if is_train: 14 | print('Initialise training dataloader') 15 | 16 | self.trn_dl = SegmentationDataloaderV2(self.trn_ds, 17 | augmentation=self.augmentation, 18 | **self.trn_dl_params) 19 | else: 20 | print('Initialise validation dataloader') 21 | try: 22 | self.val_dl = SegmentationDataloaderV2(self.val_ds, 23 | augmentation=self.augmentation, 24 | **self.val_dl_params) 25 | 26 | except (AttributeError, TypeError): 27 | print('No validatation dataloader initialised') 28 | self.val_dl = None 29 | 30 | def clean(self): 31 | self.trn_dl.dataset._maybe_clean_stored_data() 32 | if self.val_dl is not None: 33 | self.val_dl.dataset._maybe_clean_stored_data() 34 | -------------------------------------------------------------------------------- /src/ovseg/networks/custom_normalization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class no_z_InstNorm(nn.Module): 6 | 7 | def __init__(self, n_channels, **kwargs): 8 | super().__init__() 9 | self.norm = nn.InstanceNorm2d(n_channels, 
**kwargs) 10 | 11 | def forward(self, xb): 12 | 13 | nb, nc, nz, nx, ny = xb.shape 14 | 15 | # put the z slices in the batch dimension 16 | xb = xb.permute((0, 2, 1, 3, 4)).reshape((nb*nz, nc, nx, ny)) 17 | xb = self.norm(xb) 18 | # undo 19 | xb = xb.reshape((nb, nz, nc, nx, ny)).permute((0, 2, 1, 3, 4)) 20 | 21 | return xb 22 | 23 | class my_LayerNorm(nn.Module): 24 | 25 | def __init__(self, n_channels, affine=True, eps=1e-5): 26 | super().__init__() 27 | 28 | self.n_channels = n_channels 29 | self.affine = affine 30 | self.eps = eps 31 | 32 | self.gamma = nn.Parameter(torch.ones((1, self.n_channels, 1, 1, 1))) 33 | 34 | if self.affine: 35 | self.beta = nn.Parameter(torch.zeros((1, self.n_channels, 1, 1, 1))) 36 | 37 | 38 | def forward(self, xb): 39 | 40 | # normalize 41 | xb = (xb - torch.mean(xb, 1, keepdim=True))/(torch.std(xb, 1, unbiased=False, keepdim=True) + self.eps) 42 | # affine trafo 43 | xb = xb * self.gamma 44 | 45 | if self.affine: 46 | xb = xb + self.beta 47 | 48 | return xb 49 | -------------------------------------------------------------------------------- /src/ovseg/utils/download_pretrained_utils.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import os 3 | import zipfile 4 | from tqdm import tqdm 5 | 6 | def download_and_install(url): 7 | # borrowed from https://stackoverflow.com/questions/16694907/download-large-file-in-python-with-requests 8 | local_filename = os.path.join(os.environ['OV_DATA_BASE'], 'temp.zip') 9 | # NOTE the stream=True parameter below 10 | with requests.get(url, stream=True) as r: 11 | r.raise_for_status() 12 | with open(local_filename, 'wb') as f: 13 | for chunk in tqdm(r.iter_content(chunk_size=8192 * 16)): 14 | # If you have chunk encoded response uncomment if 15 | # and set chunk_size parameter to None. 16 | # if chunk: 17 | f.write(chunk) 18 | 19 | # extracting the zip and removing zip file 20 | # borrowed from nnUNet: 21 | # https://github.com/MIC-DKFZ/nnUNet/blob/6d02b5a4e2a7eae14361cde9599bbf4ccde2cd37/nnunet/inference/pretrained_models/download_pretrained_model.py#L294 22 | with zipfile.ZipFile(local_filename, 'r') as zip_ref: 23 | zip_ref.extractall(os.environ['OV_DATA_BASE']) 24 | 25 | if os.path.isfile(local_filename): 26 | os.remove(local_filename) 27 | 28 | 29 | def maybe_download_clara_models(): 30 | 31 | if 'OV_DATA_BASE' not in os.environ: 32 | raise FileNotFoundError('Environment variable \'OV_DATA_BASE\' was not set.' 
33 | 'Please do so to specify where pretrained models, raw data' 34 | 'and predictions should be stored.') 35 | 36 | if not os.path.exists(os.path.join(os.environ['OV_DATA_BASE'], 'clara_models')): 37 | 38 | print('Downloading pretrained models (4080 chunks)...') 39 | 40 | # url = "https://sandbox.zenodo.org/record/1071186/files/clara_models.zip?download=1" 41 | url = "https://sandbox.zenodo.org/record/33549/files/clara_models.zip?download=1" 42 | download_and_install(url) 43 | print('Done!') 44 | 45 | -------------------------------------------------------------------------------- /src/ovseg/utils/torch_morph.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | def _2d_morph_conv(seg_oh, selem=None): 5 | 6 | assert torch.is_tensor(seg_oh), 'input seg must be a torch tensor' 7 | 8 | assert len(seg_oh.shape) == 4, 'seg must be 4d' 9 | 10 | nz = seg_oh.shape[1] 11 | 12 | if selem is None: 13 | selem = torch.tensor([[0, 1/5, 0], [1/5, 1/5, 1/5], [0, 1/5, 0]]) 14 | else: 15 | if isinstance(selem, np.ndarray): 16 | selem = torch.from_numpy(selem) 17 | 18 | selem = selem / selem.sum() 19 | 20 | if len(selem.shape) == 2: 21 | selem = selem.unsqueeze(0) 22 | 23 | selem = selem.to(seg_oh.device).type(torch.float) 24 | selem = torch.stack(nz * [selem]) 25 | 26 | padding = (selem.shape[2]//2, selem.shape[3]//2 ) 27 | 28 | return torch.nn.functional.conv2d(seg_oh.type(torch.float), 29 | selem, 30 | padding=padding, 31 | groups=nz) 32 | 33 | def dial_2d(seg_oh, selem=None): 34 | 35 | seg_conv = _2d_morph_conv(seg_oh, selem) 36 | 37 | return (seg_conv > 0).type(seg_oh.dtype) 38 | 39 | def eros_2d(seg_oh, selem=None): 40 | 41 | seg_conv = _2d_morph_conv(seg_oh, selem) 42 | 43 | return (seg_conv >= 1-1e-5).type(seg_oh.dtype) 44 | 45 | def opening_2d(seg_oh, selem=None): 46 | 47 | return dial_2d(eros_2d(seg_oh, selem), selem) 48 | 49 | def closing_2d(seg_oh, selem=None): 50 | 51 | return eros_2d(dial_2d(seg_oh, selem), selem) 52 | 53 | def morph_cleaning(seg, selem=None): 54 | 55 | assert torch.is_tensor(seg) 56 | # to one hot encoding (in batch dimension) 57 | n_cl = int(seg.max()) 58 | 59 | if n_cl == 0: 60 | return seg 61 | 62 | seg_oh = torch.stack([seg == cl for cl in range(1, n_cl+1)]).type(torch.float) 63 | 64 | # cleaning with opening and closing 65 | seg_oh_clean = opening_2d(closing_2d(seg_oh, selem), selem) 66 | 67 | # now back to labels 68 | seg_clean = torch.zeros_like(seg) 69 | for cl in range(n_cl): 70 | seg_clean += (cl + 1)*seg_oh_clean[cl] 71 | 72 | return seg_clean -------------------------------------------------------------------------------- /src/ovseg/data/SegmentationData.py: -------------------------------------------------------------------------------- 1 | from ovseg.data.DataBase import DataBase 2 | from ovseg.data.SegmentationDataloader import SegmentationDataloader 3 | from ovseg.data.SegmentationDoubleBiasDataloader import SegmentationDoubleBiasDataloader 4 | from os import listdir 5 | from os.path import join 6 | 7 | 8 | class SegmentationData(DataBase): 9 | 10 | def __init__(self, augmentation=None, use_double_bias=False, *args, **kwargs): 11 | self.augmentation = augmentation 12 | self.use_double_bias = use_double_bias 13 | super().__init__(*args, **kwargs) 14 | 15 | def initialise_dataloader(self, is_train): 16 | if is_train: 17 | print('Initialise training dataloader') 18 | 19 | if self.use_double_bias: 20 | 21 | self.trn_dl = SegmentationDoubleBiasDataloader(self.trn_ds, 22 | 
augmentation=self.augmentation, 23 | **self.trn_dl_params) 24 | else: 25 | self.trn_dl = SegmentationDataloader(self.trn_ds, 26 | augmentation=self.augmentation, 27 | **self.trn_dl_params) 28 | else: 29 | print('Initialise validation dataloader') 30 | try: 31 | if self.use_double_bias: 32 | self.val_dl = SegmentationDoubleBiasDataloader(self.val_ds, 33 | augmentation=self.augmentation, 34 | **self.val_dl_params) 35 | else: 36 | self.val_dl = SegmentationDataloader(self.val_ds, 37 | augmentation=self.augmentation, 38 | **self.val_dl_params) 39 | 40 | except (AttributeError, TypeError): 41 | print('No validation dataloader initialised') 42 | self.val_dl = None 43 | 44 | def clean(self): 45 | self.trn_dl.dataset._maybe_clean_stored_data() 46 | if self.val_dl is not None: 47 | self.val_dl.dataset._maybe_clean_stored_data() 48 | -------------------------------------------------------------------------------- /src/ovseg/augmentation/ConcatenatedAugmentation.py: -------------------------------------------------------------------------------- 1 | from ovseg.augmentation.myRandAugment import torch_myRandAugment 2 | from ovseg.augmentation.MaskAugmentation import MaskAugmentation 3 | from ovseg.augmentation.GridAugmentation import torch_inplane_grid_augmentations 4 | from ovseg.augmentation.GrayValueAugmentation import torch_gray_value_augmentation 5 | import torch.nn as nn 6 | 7 | 8 | # %% 9 | class torch_concatenated_augmentation(nn.Module): 10 | 11 | def __init__(self, torch_params={}): 12 | 13 | super().__init__() 14 | 15 | for key in torch_params: 16 | assert key in ['grid_inplane', 'grayvalue', 'myRandAugment'], \ 17 | 'got unrecognised augmentation ' + key 18 | 19 | self.aug_list = [] 20 | if 'grid_inplane' in torch_params: 21 | self.aug_list.append(torch_inplane_grid_augmentations(**torch_params['grid_inplane'])) 22 | 23 | if 'grayvalue' in torch_params: 24 | self.aug_list.append(torch_gray_value_augmentation(**torch_params['grayvalue'])) 25 | 26 | if 'myRandAugment' in torch_params: 27 | self.aug_list.append(torch_myRandAugment(**torch_params['myRandAugment'])) 28 | 29 | if len(self.aug_list) > 0: 30 | self.module = nn.Sequential(*self.aug_list) 31 | else: 32 | self.module = nn.Identity() 33 | 34 | def forward(self, xb): 35 | return self.module(xb) 36 | 37 | def update_prg_trn(self, param_dict, h, indx=None): 38 | 39 | for aug in self.aug_list: 40 | aug.update_prg_trn(param_dict, h, indx) 41 | 42 | 43 | # %% 44 | class np_concatenated_augmentation(): 45 | 46 | def __init__(self, np_params={}): 47 | 48 | if 'grayvalue' in np_params.keys(): 49 | raise NotImplementedError('gray value augmentations not implemented for np yet...') 50 | 51 | for key in np_params: 52 | assert key in ['mask'], 'got unrecognised augmentation ' + key 53 | 54 | self.ops_list = [] 55 | if 'mask' in np_params: 56 | self.ops_list.append(MaskAugmentation(**np_params['mask'])) 57 | 58 | def __call__(self, xb): 59 | for op in self.ops_list: 60 | xb = op(xb) 61 | return xb 62 | 63 | def update_prg_trn(self, param_dict, h, indx=None): 64 | 65 | for aug in self.ops_list: 66 | aug.update_prg_trn(param_dict, h, indx) 67 | 68 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | 
eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /src/ovseg/augmentation/AffineAugmentation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | # %% 8 | class torch_grid_augmentation(nn.Module): 9 | 10 | def __init__(self, 11 | p_rot=0.2, 12 | p_zoom=0.2, 13 | p_transl=0, 14 | p_shear=0, 15 | mm_zoom=[0.7, 1.4], 16 | mm_rot=[-15, 15], 17 | mm_transl=[-0.25, 0.25], 18 | mm_shear=[-0.1, 0.1], 19 | apply_flipping=True, 20 | threeD_affine=False, 21 | mode='only_one_axis', 22 | out_shape=None 23 | ): 24 | super().__init__() 25 | self.p_rot = p_rot 26 | self.p_zoom = p_zoom 27 | self.p_transl = p_transl 28 | self.p_shear = p_shear 29 | self.mm_zoom = mm_zoom 30 | self.mm_rot = mm_rot 31 | self.mm_transl = mm_transl 32 | self.mm_shear = mm_shear 33 | self.apply_flipping = apply_flipping 34 | self.threeD_affine = threeD_affine 35 | self.mode = mode 36 | self.out_shape = out_shape 37 | 38 | if self.threeD_affine: 39 | raise NotImplementedError('Only standard 2d and 2.5 affine transformations ' 40 | 'implemented at the moment.') 41 | 42 | def _rot_z(self, theta, angle): 43 | 44 | rot_m = torch.zeros_like(theta[:, :-1]) 45 | cos, sin = torch.cos(angle), 
torch.sin(angle) 46 | rot_m[0, 0] = cos 47 | rot_m[0, 1] = sin 48 | rot_m[1, 0] = -1 * sin 49 | rot_m[1, 1] = cos 50 | 51 | return torch.mm(rot_m, theta) 52 | 53 | def _zoom(self, theta, fac): 54 | return theta * fac 55 | 56 | def _get_theta(self, xb): 57 | 58 | bs = xb.shape[0] 59 | dims = len(xb.shape) - 2 60 | theta = torch.zeros((bs, dims, dims+1), device=xb.device, dtype=xb.dtype) 61 | do_aug = False 62 | for i in range(bs): 63 | if np.random.rand() < self.p_rot: 64 | angle = np.random.uniform(*self.mm_rot) 65 | theta[i] = self._rot_z(theta[i], angle) 66 | do_aug = True 67 | if np.random.rand() < self.p_zoom: 68 | fac = np.random.uniform(*self.mm_zoom) 69 | theta[i] = self._zoom(theta, fac) 70 | do_aug = True 71 | 72 | if do_aug: 73 | grid = F.affine_grid(theta, imt.size()).cuda() 74 | im_trsf = F.grid_sample(imt, grid).cpu().numpy() 75 | xb_aug = -------------------------------------------------------------------------------- /example_scripts/example_kits21_kidneys_low.py: -------------------------------------------------------------------------------- 1 | from ovseg.model.SegmentationModelV2 import SegmentationModelV2 2 | from ovseg.preprocessing.SegmentationPreprocessingV2 import SegmentationPreprocessingV2 3 | from ovseg.model.SegmentationEnsembleV2 import SegmentationEnsembleV2 4 | from ovseg.model.model_parameters_segmentation import get_model_params_3d_UNet 5 | from ovseg.utils.io import load_pkl 6 | from time import sleep 7 | import os 8 | from ovseg import OV_PREPROCESSED, OV_DATA_BASE 9 | 10 | ''' 11 | Script for preprocessing and training lowres kidney models 12 | ''' 13 | 14 | # name of your raw dataset 15 | data_name = 'kits21' 16 | # name of preprocessed data 17 | preprocessed_name = 'kidneys_lowres' 18 | 19 | # give each model a unique name. This way the code will be able to identify them 20 | # both models (lowres and fullres) will have the same name and be differentiated 21 | # by the name of preprocessed data 22 | 23 | model_name = 'U-Net32' 24 | val_fold = 0 25 | # %% preprocess lowres data if it hasn't been done yet 26 | if not os.path.exists(os.path.join(OV_PREPROCESSED, data_name, preprocessed_name)): 27 | 28 | # ADD SOME PREPROCESSING PARAMETERS HERE 29 | prep = SegmentationPreprocessingV2(apply_resizing=True, 30 | apply_pooling=False, 31 | apply_windowing=True, 32 | target_spacing=[4,4,4], # downsample to 4mm^3 33 | reduce_lb_to_single_class=True) # in this first stage segment kidneys plus masses 34 | prep.plan_preprocessing_raw_data(data_name) 35 | 36 | prep.preprocess_raw_data(raw_data=data_name, 37 | preprocessed_name=preprocessed_name) 38 | 39 | 40 | # %% now get hyper-parameters for low resolution and train 41 | patch_size = [64, 64, 64] 42 | n_2d_convs = 0 43 | use_prg_trn = False # on low resolution prg trn can harm the performance 44 | n_fg_classes = 1 45 | use_fp32 = False 46 | model_params = get_model_params_3d_UNet(patch_size=patch_size, 47 | n_2d_convs=n_2d_convs, 48 | use_prg_trn=use_prg_trn, 49 | n_fg_classes=n_fg_classes, 50 | fp32=use_fp32) 51 | 52 | # CHANGE YOUR HYPER-PARAMETERS FOR LOWRES STAGE HERE! 53 | 54 | model = SegmentationModelV2(val_fold=val_fold, 55 | data_name=data_name, 56 | model_name=model_name, 57 | preprocessed_name=preprocessed_name, 58 | model_parameters=model_params) 59 | 60 | # execute the trainig, simple as that! 
61 | # It will check for previous checkpoints and load them 62 | model.training.train() 63 | 64 | if val_fold < model_params['data']['n_folds']: 65 | model.eval_validation_set() -------------------------------------------------------------------------------- /test_global_metrics.py: -------------------------------------------------------------------------------- 1 | from ovseg.model.SegmentationModelV2 import SegmentationModelV2 2 | from ovseg.preprocessing.SegmentationPreprocessingV2 import SegmentationPreprocessingV2 3 | from ovseg.model.SegmentationEnsembleV2 import SegmentationEnsembleV2 4 | from ovseg.model.model_parameters_segmentation import get_model_params_3d_UNet 5 | from ovseg.utils.io import load_pkl 6 | from time import sleep 7 | import os 8 | from ovseg import OV_PREPROCESSED, OV_DATA_BASE 9 | 10 | ''' 11 | Script for preprocessing and training lowres kidney models 12 | ''' 13 | 14 | # name of your raw dataset 15 | data_name = 'kits21' 16 | # name of preprocessed data 17 | preprocessed_name = 'kidneys_lowres' 18 | 19 | # give each model a unique name. This way the code will be able to identify them 20 | # both models (lowres and fullres) will have the same name and be differentiated 21 | # by the name of preprocessed data 22 | 23 | model_name = 'U-Net32' 24 | val_fold = 0 25 | # %% preprocess lowres data if it hasn't been done yet 26 | if not os.path.exists(os.path.join(OV_PREPROCESSED, data_name, preprocessed_name)): 27 | 28 | # ADD SOME PREPROCESSING PARAMETERS HERE 29 | prep = SegmentationPreprocessingV2(apply_resizing=True, 30 | apply_pooling=False, 31 | apply_windowing=True, 32 | target_spacing=[4,4,4], # downsample to 4mm^3 33 | reduce_lb_to_single_class=True) # in this first stage segment kidneys plus masses 34 | prep.plan_preprocessing_raw_data(data_name) 35 | 36 | prep.preprocess_raw_data(raw_data=data_name, 37 | preprocessed_name=preprocessed_name) 38 | 39 | 40 | # %% now get hyper-parameters for low resolution and train 41 | patch_size = [64, 64, 64] 42 | n_2d_convs = 0 43 | use_prg_trn = False # on low resolution prg trn can harm the performance 44 | n_fg_classes = 1 45 | use_fp32 = False 46 | model_params = get_model_params_3d_UNet(patch_size=patch_size, 47 | n_2d_convs=n_2d_convs, 48 | use_prg_trn=use_prg_trn, 49 | n_fg_classes=n_fg_classes, 50 | fp32=use_fp32) 51 | 52 | # CHANGE YOUR HYPER-PARAMETERS FOR LOWRES STAGE HERE! 53 | model_params['training']['num_epochs'] = 100 54 | model_params['network']['filters'] = 8 55 | model_params['data']['n_folds'] = 2 56 | 57 | 58 | for val_fold in [0,1]: 59 | model = SegmentationModelV2(val_fold=val_fold, 60 | data_name=data_name, 61 | model_name=model_name, 62 | preprocessed_name=preprocessed_name, 63 | model_parameters=model_params) 64 | 65 | # execute the trainig, simple as that! 66 | # It will check for previous checkpoints and load them 67 | model.training.train() 68 | 69 | if val_fold < model_params['data']['n_folds']: 70 | model.eval_validation_set() -------------------------------------------------------------------------------- /run_training.py: -------------------------------------------------------------------------------- 1 | from ovseg.model.SegmentationModel import SegmentationModel 2 | from ovseg.model.SegmentationEnsemble import SegmentationEnsemble 3 | from ovseg.model.model_parameters_segmentation import get_model_params_3d_res_encoder_U_Net 4 | import argparse 5 | import numpy as np 6 | 7 | parser = argparse.ArgumentParser() 8 | # vf should be 5,6,7. VF stands for validation folds. 
In this case 9 | # the training is done on 100% of the data using 3 random seeds 10 | parser.add_argument("vf", type=int) 11 | # this should be either of ['pod_om', 'abdominal_lesions','lymph_nodes'] 12 | # to indicate which of the three segmentation models is trained 13 | parser.add_argument("--model", default='pod_om') 14 | # add all the names of the labled training data sets as trn_data 15 | parser.add_argument("--trn_data", default=['OV04', 'BARTS', 'ApolloTCGA'], nargs='+') 16 | args = parser.parse_args() 17 | 18 | vf = args.vf 19 | trn_data = args.trn_data 20 | data_name = '_'.join(sorted(trn_data)) 21 | preprocessed_name = args.model 22 | 23 | assert preprocessed_name in ['pod_om', 'abdominal_lesions','lymph_nodes'], 'Unkown model' 24 | 25 | # hyper-paramteres for the model 26 | wd = 1e-4 27 | use_prg_trn = True 28 | 29 | if preprocessed_name == 'pod_om': 30 | patch_size = [32, 216, 216] 31 | out_shape = [[20, 128, 128], 32 | [22, 152, 152], 33 | [30, 192, 192], 34 | [32, 216, 216]] 35 | z_to_xy_ratio = 5.0/0.8 36 | larger_res_encoder = False 37 | n_fg_classes = 2 38 | elif preprocessed_name == 'abdominal_lesions': 39 | patch_size = [32, 216, 216] 40 | out_shape = [[20, 128, 128], 41 | [22, 152, 152], 42 | [30, 192, 192], 43 | [32, 216, 216]] 44 | z_to_xy_ratio = 5.0/0.8 45 | larger_res_encoder = False 46 | n_fg_classes = 6 47 | elif preprocessed_name == 'lymph_nodes': 48 | patch_size = [32, 256, 256] 49 | use_prg_trn = True 50 | out_shape = [[20, 160, 160], 51 | [24, 192, 192], 52 | [28, 224, 224], 53 | [32, 256, 256]] 54 | z_to_xy_ratio = 5.0/0.67 55 | larger_res_encoder = True 56 | n_fg_classes = 4 57 | 58 | model_params = get_model_params_3d_res_encoder_U_Net(patch_size, 59 | z_to_xy_ratio=z_to_xy_ratio, 60 | out_shape=out_shape, 61 | n_fg_classes=n_fg_classes, 62 | use_prg_trn=use_prg_trn) 63 | model_params['data']['trn_dl_params']['batch_size'] = 4 64 | model_params['data']['val_dl_params']['batch_size'] = 4 65 | model_params['training']['opt_params']['momentum'] = 0.98 66 | model_params['training']['opt_params']['weight_decay'] = wd 67 | 68 | # change the model name when using other hyper-paramters 69 | model_name = 'clara_model' 70 | 71 | model = SegmentationModel(val_fold=vf, 72 | data_name=data_name, 73 | model_name=model_name, 74 | preprocessed_name=preprocessed_name, 75 | model_parameters=model_params) 76 | model.training.train() 77 | if vf < model_params['data']['n_folds']: 78 | model.eval_validation_set() -------------------------------------------------------------------------------- /example_scripts/my_preprocessing.py: -------------------------------------------------------------------------------- 1 | from ovseg.preprocessing.SegmentationPreprocessingV2 import SegmentationPreprocessingV2 2 | from ovseg import OV_PREPROCESSED 3 | import os 4 | 5 | # name of your raw dataset 6 | data_name = 'MY_DATASET_NAME' 7 | # name after preprocessing 8 | preprocessed_name = 'MY_PREPROCESSED_NAME' 9 | 10 | # whether to apply resizing, pooling and windowing during preprocessing 11 | apply_resizing = True 12 | apply_pooing = False 13 | apply_windowing = True 14 | 15 | # if apply_resizing all scans are resized to this spacing (before potential pooling) 16 | # default: inferre median voxel spacing from dataset and use this 17 | target_spacing = None 18 | 19 | # stride used for mean pooling of images/max pooling of labels e.g. (1,2,2) 20 | # will only complain about pooling_stride = None if apply_pooling 21 | pooling_stride = None 22 | 23 | # clipping of gray values e.g. 
(-150, 250) for abdominal CT 24 | # default: inferre 0.5 and 99.5 gray value of foreground voxel 25 | window = None 26 | 27 | # gray value scaling applied after pooling e.g. 28 | # default: inferre Z normalization from data (recommended) 29 | scaling = None 30 | 31 | # if you have many classes in your segmentation problem and you only want to 32 | # segment some of them at a time, use e.g. lb_classes = [1, 3, 4] 33 | # default: use all classes 34 | lb_classes = None 35 | 36 | # removed class information and reduces segmentaiton problem to be binary 37 | reduce_lb_to_single_class = False 38 | 39 | # number of image channels, only tested for n_im_channels=1!! 40 | n_im_channels = 1 41 | 42 | # if true saves only scans that contain at least one foreground voxel 43 | save_only_fg_scans = False 44 | 45 | # set this variable if you want to input the segmentaiton masks from a previous 46 | # model, e.g. in a cascade 47 | # prev_stage_for_input = {'data_name': MY_DATA, 'preprocessed_name': MY_PREPROCESSED_NAME, 'model_name': MY_MODEL_NAME} 48 | prev_stage_for_input = {} 49 | 50 | # similarly to mask the segmentation, e.g. in deep supervision when segmenting first and organ and then lesions 51 | prev_stage_for_mask = {} 52 | 53 | # if you want to increase the size of the mask you can dialate it as a preprocessing setp 54 | r_dial_mask = 0 55 | 56 | 57 | # creat preprocessing object 58 | prep = SegmentationPreprocessingV2(apply_resizing=apply_resizing, 59 | apply_pooling=apply_pooing, 60 | apply_windowing=apply_windowing, 61 | target_spacing=target_spacing, 62 | pooling_stride=pooling_stride, 63 | window=window, 64 | scaling=scaling, 65 | lb_classes=lb_classes, 66 | reduce_lb_to_single_class=reduce_lb_to_single_class, 67 | n_im_channels=n_im_channels, 68 | save_only_fg_scans=save_only_fg_scans, 69 | prev_stage_for_input=prev_stage_for_input, 70 | prev_stage_for_mask=prev_stage_for_mask, 71 | r_dial_mask=r_dial_mask) 72 | 73 | # inferre preprocessing parameters (window, target_spacing, scaling) 74 | prep.plan_preprocessing_raw_data(data_name) 75 | 76 | # execute preprocessing 77 | prep.preprocess_raw_data(data_name, preprocessed_name) 78 | 79 | print('Preprocessing done!') -------------------------------------------------------------------------------- /example_scripts/example_kits21_kidneys_cascade_full.py: -------------------------------------------------------------------------------- 1 | from ovseg.model.SegmentationModelV2 import SegmentationModelV2 2 | from ovseg.preprocessing.SegmentationPreprocessingV2 import SegmentationPreprocessingV2 3 | from ovseg.model.SegmentationEnsembleV2 import SegmentationEnsembleV2 4 | from ovseg.model.model_parameters_segmentation import get_model_params_3d_UNet 5 | from ovseg.utils.io import load_pkl 6 | from time import sleep 7 | import os 8 | from ovseg import OV_PREPROCESSED, OV_DATA_BASE 9 | 10 | ''' 11 | Script for preprocessing and training fullres kidney models 12 | ''' 13 | 14 | # name of your raw dataset 15 | data_name = 'kits21' 16 | # name of preprocessed data 17 | preprocessed_name = 'kidneys_fullres' 18 | 19 | # give each model a unique name. 
This way the code will be able to identify them 20 | # both models (lowres and fullres) will have the same name and be differentiated 21 | # by the name of preprocessed data 22 | 23 | model_name = 'U-Net32' 24 | val_fold = 0 25 | # %% preprocess lowres data if it hasn't been done yet 26 | if not os.path.exists(os.path.join(OV_PREPROCESSED, data_name, preprocessed_name)): 27 | 28 | # doing a cascade inputting the previous stage 29 | prev_stage = {'data_name': data_name, 30 | 'preprocessed_name': 'kidneys_lowres', 31 | 'model_name': 'U-Net32'} 32 | 33 | # ADD SOME PREPROCESSING PARAMETERS HERE 34 | prep = SegmentationPreprocessingV2(apply_resizing=True, 35 | apply_pooling=False, 36 | apply_windowing=True, 37 | target_spacing=[2,1,1], # 2mm in z direction 1 in xy 38 | reduce_lb_to_single_class=True, # in this first stage segment kidneys plus masses 39 | prev_stage_for_input=prev_stage) 40 | prep.plan_preprocessing_raw_data(data_name) 41 | 42 | prep.preprocess_raw_data(raw_data=data_name, 43 | preprocessed_name=preprocessed_name) 44 | 45 | 46 | # %% now get hyper-parameters for low resolution and train 47 | patch_size = [48, 96, 96] 48 | n_2d_convs = 1 49 | use_prg_trn = True # on low resolution prg trn can harm the performance 50 | n_fg_classes = 1 51 | use_fp32 = False 52 | out_shape = [[24, 64, 64], [32, 64, 64], [32, 80, 80], [48, 96, 96]] 53 | model_params = get_model_params_3d_UNet(patch_size=patch_size, 54 | n_2d_convs=n_2d_convs, 55 | use_prg_trn=use_prg_trn, 56 | n_fg_classes=n_fg_classes, 57 | fp32=use_fp32, 58 | out_shape=out_shape) 59 | 60 | # for the cascade we input the masks of the previous stage 61 | model_params['network']['in_channels'] = 2 62 | model_params['data']['folders'] = ['images', 'labels', 'prev_preds'] 63 | model_params['data']['keys'] = ['image', 'label', 'prev_pred'] 64 | 65 | model = SegmentationModelV2(val_fold=val_fold, 66 | data_name=data_name, 67 | model_name=model_name, 68 | preprocessed_name=preprocessed_name, 69 | model_parameters=model_params) 70 | 71 | # execute the trainig, simple as that! 
72 | # It will check for previous checkpoints and load them 73 | model.training.train() 74 | 75 | if val_fold < model_params['data']['n_folds']: 76 | model.eval_validation_set() -------------------------------------------------------------------------------- /src/ovseg/utils/dict_equal.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def dict_equal(dict1, dict2): 6 | 7 | if not dict1.keys() == dict2.keys(): 8 | return False 9 | 10 | keys = dict1.keys() 11 | 12 | for key in keys: 13 | item1, item2 = dict1[key], dict2[key] 14 | if not isinstance(item1, type(item2)): 15 | return False 16 | 17 | try: 18 | if isinstance(item1, np.ndarray): 19 | if not np.all(item1 == item2): 20 | return False 21 | elif torch.is_tensor(item1): 22 | if not np.all(item1.detach().cpu().numpy() == item2.detach().cpu().numpy()): 23 | return False 24 | elif isinstance(item1, (list, tuple)): 25 | if not np.all(np.array(item1) == np.array(item2)): 26 | return False 27 | 28 | elif isinstance(item1, dict): 29 | if not dict_equal(item1, item2): 30 | return False 31 | else: 32 | if not item1 == item2: 33 | return False 34 | except ValueError: 35 | print('Value Error when compating {} and {}.'.format(item1, item2)) 36 | return False 37 | 38 | return True 39 | 40 | 41 | def print_dict_diff(dict1, dict2, dict1_name='input dict', dict2_name='loaded dict', pref=''): 42 | 43 | # check if keys are missing 44 | common_keys = [] 45 | for key in dict1: 46 | if key not in dict2: 47 | print(pref+ key+' missing in '+dict2_name) 48 | else: 49 | common_keys.append(key) 50 | 51 | for key in dict2: 52 | if key not in dict1: 53 | print(pref+key+' missing in '+dict1_name) 54 | 55 | # type checking 56 | remaining_keys = [] 57 | for key in common_keys: 58 | if type(dict1[key]) != type(dict2[key]): 59 | print(pref+key+' type missmatch, got {} for {} and {} for {}'.format(type(dict1[key]), 60 | dict1_name, 61 | type(dict2[key]), 62 | dict2_name)) 63 | print('Values: {}, {}'.format(dict1[key], dict2[key])) 64 | else: 65 | remaining_keys.append(key) 66 | 67 | # now we can check the content 68 | for key in remaining_keys: 69 | item1, item2 = dict1[key], dict2[key] 70 | try: 71 | if isinstance(item1, np.ndarray): 72 | if not np.all(item1 == item2): 73 | print(pref+key+' missmatch: {}, {}'.format(item1, item2)) 74 | elif torch.is_tensor(item1): 75 | if not np.all(item1.detach().cpu().numpy() == item2.detach().cpu().numpy()): 76 | print(pref+key+' missmatch: {}, {}'.format(item1, item2)) 77 | elif isinstance(item1, (list, tuple)): 78 | if not np.all(np.array(item1) == np.array(item2)): 79 | print(pref+key+' missmatch: {}, {}'.format(item1, item2)) 80 | 81 | elif isinstance(item1, dict): 82 | print_dict_diff(item1, item2, dict1_name, dict2_name, pref+key+' -> ') 83 | else: 84 | if not item1 == item2: 85 | print(pref+key+' missmatch: {}, {}'.format(item1, item2)) 86 | except ValueError: 87 | print('Value Error when compating {} and {}.'.format(item1, item2)) 88 | 89 | -------------------------------------------------------------------------------- /src/ovseg/run/run_inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | from ovseg.utils.io import read_nii, save_nii 3 | from ovseg.utils.download_pretrained_utils import maybe_download_clara_models 4 | from ovseg.model.InferenceWrapper import InferenceWrapper 5 | import argparse 6 | 7 | 8 | def is_nii_file(path_to_file): 9 | return path_to_file.endswith('.nii') or 
path_to_file.endswith('.nii.gz') 10 | 11 | 12 | def run_inference(path_to_data, 13 | models=['pod_om'], 14 | fast=False): 15 | 16 | if is_nii_file(path_to_data): 17 | path_to_data, nii_file = os.path.split(path_to_data) 18 | nii_files = [nii_file] 19 | else: 20 | nii_files = [f for f in os.listdir(path_to_data) if is_nii_file(f)] 21 | if len(nii_files) == 0: 22 | raise FileNotFoundError(f"No nifti images were found at {path_to_data}") 23 | 24 | # if the pretrained models were not downloaded yet, we're doing it now here 25 | maybe_download_clara_models() 26 | 27 | # some variables we need for saving the predictions 28 | pred_folder_name = "ovseg_predictions" 29 | if 'pod_om' in models: 30 | pred_folder_name += "_pod_om" 31 | if 'abdominal_lesions' in models: 32 | pred_folder_name += "_abdominal_lesions" 33 | if 'lymph_nodes' in models: 34 | pred_folder_name += "_lymph_nodes" 35 | 36 | # iterate over the dataset and save predictions 37 | out_folder = os.path.join(path_to_data, pred_folder_name) 38 | os.makedirs(out_folder, exist_ok=True) 39 | 40 | for i, nii_file in enumerate(nii_files): 41 | 42 | print(f"Evaluate image {i} out of {len(nii_files)}") 43 | 44 | im, sp = read_nii(os.path.join(path_to_data, nii_file)) 45 | 46 | pred = InferenceWrapper(im, sp, models, fast=fast) 47 | 48 | save_nii(pred, os.path.join(out_folder, nii_file), os.path.join(path_to_data, nii_file)) 49 | 50 | 51 | def main(): 52 | parser = argparse.ArgumentParser() 53 | parser.add_argument("path_to_data", 54 | help='Either (i) path to a single nifti file like PATH/TO/IMAGE.nii(.gz),\n ' 55 | '(ii) path to a folder containing mulitple nifti files.') 56 | # add all the names of the labled training data sets as trn_data 57 | parser.add_argument("--models", 58 | default=['pod_om'], nargs='+', 59 | help='Name(s) of models used during inference. Options are ' 60 | 'the following.\n' 61 | '(i) pod_om: model for main disease sites in the pelvis/ovaries' 62 | ' and the omentum. The two sites are encoded as 9 and 1.\n' 63 | '(ii) abdominal_lesions: model for various lesions between ' 64 | 'the pelvis and diaphram. 
The model considers lesions in the ' 65 | 'omentum (1), right upper quadrant (2), left upper quadrant (3), ' 66 | 'mesenterium (5), left paracolic gutter (6) and right ' 67 | 'paracolic gutter (7).\n' 68 | '(iii) lymph_nodes: segments disease in the lymph nodes ' 69 | 'namely infrarenal lymph nodes (13), suprarenal lymph nodes ' 70 | '(14), supradiaphragmatic lymph nodes (15) and inguinal ' 71 | 'lymph nodes (17).\n' 72 | 'Any combination of the three are viable options.') 73 | parser.add_argument("--fast", action='store_true', 74 | default=False, 75 | help='Increases inference speed by disabling dynamic z spacing, ' 76 | 'model ensembling and test-time augmentations.') 77 | 78 | args = parser.parse_args() 79 | 80 | path_to_data = args.path_to_data 81 | models = args.models 82 | fast = args.fast 83 | 84 | run_inference(path_to_data, models, fast) -------------------------------------------------------------------------------- /example_scripts/example_kits21_masses_deep_supervision.py: -------------------------------------------------------------------------------- 1 | from ovseg.model.SegmentationModelV2 import SegmentationModelV2 2 | from ovseg.preprocessing.SegmentationPreprocessingV2 import SegmentationPreprocessingV2 3 | from ovseg.model.SegmentationEnsembleV2 import SegmentationEnsembleV2 4 | from ovseg.model.model_parameters_segmentation import get_model_params_3d_UNet 5 | from ovseg.utils.io import load_pkl 6 | from time import sleep 7 | import os 8 | from ovseg import OV_PREPROCESSED, OV_DATA_BASE 9 | 10 | ''' 11 | Script for preprocessing and training of kidney masses using deep supervision 12 | e.g. only training and validating where the previous stage (kidney model) found foreground 13 | ''' 14 | 15 | # name of your raw dataset 16 | data_name = 'kits21' 17 | # name of preprocessed data 18 | preprocessed_name = 'kidneys_masses' 19 | 20 | # give each model a unique name. 
This way the code will be able to identify them 21 | # both models (lowres and fullres) will have the same name and be differentiated 22 | # by the name of preprocessed data 23 | 24 | model_name = 'U-Net32' 25 | val_fold = 0 26 | # %% preprocess lowres data if it hasn't been done yet 27 | if not os.path.exists(os.path.join(OV_PREPROCESSED, data_name, preprocessed_name)): 28 | 29 | # doing a cascade inputting the previous stage 30 | prev_stage = {'data_name': data_name, 31 | 'preprocessed_name': 'kidneys_fullres', 32 | 'model_name': 'U-Net32'} 33 | 34 | # ADD SOME PREPROCESSING PARAMETERS HERE 35 | prep = SegmentationPreprocessingV2(apply_resizing=True, 36 | apply_pooling=False, 37 | apply_windowing=True, 38 | target_spacing=[2,1,1], # downsample to 4mm^3 39 | lb_classes=[2,3], # exclude class 1 (kidneys) and only segment 2 (tumors) and 3 (cysts) 40 | prev_stage_for_mask=prev_stage) 41 | prep.plan_preprocessing_raw_data(data_name) 42 | 43 | prep.preprocess_raw_data(raw_data=data_name, 44 | preprocessed_name=preprocessed_name) 45 | 46 | 47 | # %% now get hyper-parameters for low resolution and train 48 | patch_size = [48, 96, 96] 49 | n_2d_convs = 1 50 | use_prg_trn = True # on low resolution prg trn can harm the performance 51 | n_fg_classes = 2 52 | use_fp32 = False 53 | out_shapes = [[24, 64, 64], [32, 64, 64], [32, 80, 80], [48, 96, 96]] 54 | model_params = get_model_params_3d_UNet(patch_size=patch_size, 55 | n_2d_convs=n_2d_convs, 56 | use_prg_trn=use_prg_trn, 57 | n_fg_classes=n_fg_classes, 58 | fp32=use_fp32, 59 | out_shape=out_shapes) 60 | 61 | # tell data object it should also load the masks 62 | model_params['data']['folders'] = ['images', 'labels', 'masks'] 63 | model_params['data']['keys'] = ['image', 'label', 'mask'] 64 | 65 | for s in ['trn_dl_params', 'val_dl_params']: 66 | # tell dataloaders to use the bias where first a class and then a 67 | # foreground voxel is chosen 68 | model_params['data'][s]['bias'] = 'cl_fg' 69 | # number of foreground classes 70 | model_params['data'][s]['n_fg_classes'] = 2 71 | 72 | # tell the training object to use the masks for the loss function 73 | model_params['training']['batches_have_masks'] = True 74 | # apply the mask during post-processing 75 | model_params['postprocessing'] = {'mask_with_reg': True} 76 | 77 | model = SegmentationModelV2(val_fold=val_fold, 78 | data_name=data_name, 79 | model_name=model_name, 80 | preprocessed_name=preprocessed_name, 81 | model_parameters=model_params) 82 | 83 | # execute the trainig, simple as that! 84 | # It will check for previous checkpoints and load them 85 | model.training.train() 86 | 87 | if val_fold < model_params['data']['n_folds']: 88 | model.eval_validation_set() -------------------------------------------------------------------------------- /example_scripts/my_training.py: -------------------------------------------------------------------------------- 1 | from ovseg.model.SegmentationModelV2 import SegmentationModelV2 2 | from ovseg.model.SegmentationEnsembleV2 import SegmentationEnsembleV2 3 | from ovseg.model.model_parameters_segmentation import get_model_params_3d_UNet 4 | 5 | # name of your raw dataset 6 | data_name = 'MY_DATASET_NAME' 7 | # same name as in the preprocessing script 8 | preprocessed_name = 'MY_PREPROCESSED_NAME' 9 | # give each model a unique name. This way the code will be able to identify them 10 | model_name = 'MY_MODEL_NAME' 11 | # which fold of the training is performed? 12 | # Example 5-fold cross-vadliation: CV folds are 0,1,...,4. 
13 | # For each val_fold > 4 no CV is applied and 14 | # 100% of the training data is used 15 | val_fold = 0 16 | 17 | # now get hyper-parameters 18 | # patch size used during (last stage of) training and inference 19 | # z axis first, then xy 20 | patch_size = [32, 256, 256] 21 | # for standard UNet the number of inplane convolutions 22 | n_2d_convs = 3 23 | # wheter to use progressive learning or not. I often found it to have no 24 | # effect on the performance, but reduces training time by up to 40% 25 | use_prg_trn = True 26 | # number of different foreground classes you want to segment 27 | n_fg_classes = 1 28 | # it is recommended to perform the training with mixed precision (fp16) 29 | # instead of full precision (fp32) 30 | use_fp32 = False 31 | # shapes introduced to the network during progressive learning 32 | # rule of thumb: reduce total number of voxels by a factor of 4,3,2 in the 33 | # first three stages and train last stage as usual 34 | # be careful that the patch size is still executable for your U-Net 35 | # e.g. a U-Net that downsamples 4 times inplane should have a patch size 36 | # where the inplane size is divisible by 2**4 37 | out_shape = [[24, 128, 128], [24, 192, 192], [24, 256, 256], [32, 256, 256]] 38 | 39 | 40 | model_params = get_model_params_3d_UNet(patch_size=patch_size, 41 | n_2d_convs=n_2d_convs, 42 | use_prg_trn=use_prg_trn, 43 | n_fg_classes=n_fg_classes, 44 | fp32=use_fp32, 45 | out_shape=out_shape) 46 | 47 | # CHANGE YOUR HYPER-PARAMETERS HERE! For example 48 | # change batch size to 4 49 | #model_params['data']['trn_dl_params']['batch_size'] = 4 50 | #model_params['data']['val_dl_params']['batch_size'] = 4 51 | # change momentum 52 | #model_params['training']['opt_params']['momentum'] = 0.98 53 | # change weight decay 54 | #model_params['training']['opt_params']['weight_decay'] = wd 55 | 56 | # creat model object. 57 | # this object holds all objects that define a deep neural network model 58 | # - preprocessing 59 | # - augmentation 60 | # - training 61 | # - slinding window evaluation 62 | # - postprocessing 63 | # - data and data sampling 64 | # - functions to iterate over datasets 65 | # - I'm sure I forgot something 66 | 67 | model = SegmentationModelV2(val_fold=val_fold, 68 | data_name=data_name, 69 | model_name=model_name, 70 | preprocessed_name=preprocessed_name, 71 | model_parameters=model_params) 72 | # execute the trainig, simple as that! 73 | # It will check for previous checkpoints and load them 74 | model.training.train() 75 | 76 | # if cross-validation is applied you can evaluate the validation scans like this 77 | # as stated above, val_fold > n_folds means using 100% training data e.g. no validation data 78 | if val_fold < model_params['data']['n_folds']: 79 | model.eval_validation_set() 80 | 81 | # uncomment to evaluate raw (test) dataset with the model 82 | # model.eval_raw_dataset('MY_TEST_DATA') 83 | 84 | 85 | # uncomment to evaluate ensemble e.g. 
of cross-validation models 86 | # ens = SegmentationEnsembleV2(val_fold=list(range(model_params['data']['n_folds'])), 87 | # model_name=model_name, 88 | # data_name=data_name, 89 | # preprocessed_name=preprocessed_name) 90 | # typically I train all folds on different GPUs in parallel, this let's you wait 91 | # until all trainings are done 92 | # ens.wait_until_all_folds_complete() 93 | # evaluate ensemble on test data 94 | # ens.eval_raw_dataset('MY_TEST_DATA') 95 | 96 | 97 | -------------------------------------------------------------------------------- /src/ovseg/data/DataBase.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from os.path import exists, join, basename 3 | from ovseg.utils import io 4 | from os import environ, sep, listdir 5 | import pickle 6 | from ovseg.data.utils import split_scans_random_uniform, split_scans_by_patient_id 7 | from ovseg.data.Dataset import Dataset 8 | 9 | 10 | class DataBase(): 11 | 12 | def __init__(self, val_fold, preprocessed_path, keys, folders, n_folds=5, 13 | fixed_shuffle=True, trn_dl_params={}, ds_params={}, 14 | val_dl_params={}): 15 | ''' 16 | DataBase(val_fold, preprocessed_path, n_folds=5, fixed_shuffle=True, 17 | trn_dl_params={}, ds_params={}, val_dl_params={}) 18 | 19 | Basic class that splits data and creates train and validation datasets 20 | and dataloaders. Can be initialised gived fixed folds \'folds\' or 21 | with the scans to all datatuples. In the latter case this class does 22 | the splitting automatically. 23 | ''' 24 | # set number of validation fold 25 | self.val_fold = val_fold 26 | self.preprocessed_path = preprocessed_path 27 | self.keys = keys 28 | self.folders = folders 29 | # other arguments. we don't need them when finding exisiting splits 30 | # but anyways... 31 | self.n_folds = n_folds 32 | self.fixed_shuffle = fixed_shuffle 33 | # now save the additional arguments for creating the datasets 34 | # and dataloaders 35 | self.ds_params = ds_params 36 | self.trn_dl_params = trn_dl_params 37 | self.val_dl_params = val_dl_params 38 | 39 | # let the important bit start: The splitting of the data 40 | # check if there is alreay some split 41 | path_to_splits = join(self.preprocessed_path, 'splits.pkl') 42 | if exists(path_to_splits): 43 | # in this case a split of data is given 44 | print('Found existing data split') 45 | self.splits = io.load_pkl(path_to_splits) 46 | self.n_folds = len(self.splits) 47 | else: 48 | print('No data split found.') 49 | print('Computing new one..') 50 | 51 | self.scans = listdir(join(self.preprocessed_path, self.folders[0])) 52 | patient_ids = {} 53 | for scan in self.scans: 54 | path_to_fingerprint = join(self.preprocessed_path, 'fingerprints', scan) 55 | if exists(path_to_fingerprint): 56 | fngprnt = np.load(path_to_fingerprint, 57 | allow_pickle=True).item() 58 | patient_ids[scan] = fngprnt['dataset'] + '_' + fngprnt['pat_id'] 59 | else: 60 | patient_ids[scan] = scan[:-4] 61 | 62 | self.splits = split_scans_by_patient_id(self.scans, 63 | patient_ids, 64 | self.n_folds, 65 | self.fixed_shuffle) 66 | # we add an additional fold with 100% of the data being used as training data 67 | # this is usefull for hyperparameter tuning where we can train on this 68 | # fold instead of doing a full CV 69 | self.splits.append({'train': self.scans, 'val': []}) 70 | io.save_pkl(self.splits, path_to_splits) 71 | print('New split saved.\n') 72 | 73 | if self.val_fold >= len(self.splits): 74 | print('WARNING! More val_fold > len(splits)! 
Picking the last fold. Unless you have ' 75 | 'created a custom split this will be the 100% training, no validation data fold.') 76 | self.split = self.splits[-1] 77 | else: 78 | self.split = self.splits[self.val_fold] 79 | self.trn_scans = self.split['train'] 80 | self.val_scans = self.split['val'] 81 | 82 | # now create datasets 83 | self.initialise_dataset(is_train=True) 84 | self.initialise_dataset(is_train=False) 85 | 86 | # and the dataloaders 87 | self.initialise_dataloader(is_train=True) 88 | self.initialise_dataloader(is_train=False) 89 | 90 | def initialise_dataset(self, is_train): 91 | if is_train: 92 | self.trn_ds = Dataset(self.trn_scans, self.preprocessed_path, 93 | self.keys, self.folders, **self.ds_params) 94 | elif len(self.val_scans) > 0: 95 | self.val_ds = Dataset(self.val_scans, self.preprocessed_path, 96 | self.keys, self.folders, **self.ds_params) 97 | 98 | def initialise_dataloader(self, is_train): 99 | raise NotImplementedError('function \'initialise_dataloader\' was not ' 100 | ' overloaded in child class. This function ' 101 | 'need to create the attributes trn_dl and ' 102 | 'val_dl.') 103 | -------------------------------------------------------------------------------- /src/ovseg/utils/label_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | try: 3 | from skimage.measure import label 4 | except ImportError: 5 | print('Caught Import Error while importing some function from scipy or skimage. ' 6 | 'Please use a newer version of gcc.') 7 | 8 | 9 | def remove_small_connected_components(lb, min_vol, spacing=None): 10 | ''' 11 | 12 | Parameters 13 | ---------- 14 | lb : np.ndarray 15 | integer valued label array 16 | min_vol : scalar or list 17 | smallest volume 18 | spacing : TYPE, optional 19 | DESCRIPTION. The default is None. 20 | 21 | Returns 22 | ------- 23 | lb : TYPE 24 | DESCRIPTION. 25 | 26 | ''' 27 | 28 | n_classes = int(lb.max()) 29 | if np.isscalar(min_vol): 30 | min_vol = [min_vol for _ in range(n_classes)] 31 | else: 32 | if len(min_vol) < n_classes: 33 | raise ValueError('min_vol was given as a list, but with less volumes then classes. ' 34 | 'Choose either one univsersal number for all classes or give ' 35 | 'at least as many volumes as classes.') 36 | 37 | if spacing is None: 38 | spacing = [1, 1, 1] 39 | 40 | fac = np.prod(spacing) 41 | 42 | for i, mvol in enumerate(min_vol): 43 | bin_label = (lb == i + 1) > 0 44 | 45 | conn_comps = label(bin_label) 46 | n_comps = conn_comps.max() 47 | 48 | for c in range(1, n_comps+1): 49 | conn_comp = (conn_comps == c) 50 | 51 | if np.sum(conn_comp.astype(float)) * fac < mvol: 52 | lb[conn_comp] = 0 53 | 54 | return lb 55 | 56 | def remove_connected_components_by_volume(lb, min_vol=None, max_vol=None, spacing=None): 57 | ''' 58 | 59 | Parameters 60 | ---------- 61 | lb : np.ndarray 62 | integer valued label array 63 | min_vol : scalar or list 64 | smallest volume 65 | max_vol : scalar or list 66 | largest volume 67 | spacing : TYPE, optional 68 | DESCRIPTION. The default is None. 69 | 70 | Returns 71 | ------- 72 | lb : TYPE 73 | DESCRIPTION. 74 | 75 | ''' 76 | 77 | n_classes = int(lb.max()) 78 | 79 | if min_vol is None: 80 | min_vol = 0 81 | if max_vol is None: 82 | max_vol = np.inf 83 | 84 | if np.isscalar(min_vol): 85 | min_vol = [min_vol for _ in range(n_classes)] 86 | else: 87 | if len(min_vol) < n_classes: 88 | raise ValueError('min_vol was given as a list, but with less volumes then classes. 
' 89 | 'Choose either one univsersal number for all classes or give ' 90 | 'at least as many volumes as classes.') 91 | if np.isscalar(max_vol): 92 | max_vol = [max_vol for _ in range(n_classes)] 93 | else: 94 | if len(max_vol) < n_classes: 95 | raise ValueError('min_vol was given as a list, but with less volumes then classes. ' 96 | 'Choose either one univsersal number for all classes or give ' 97 | 'at least as many volumes as classes.') 98 | 99 | if spacing is None: 100 | spacing = [1, 1, 1] 101 | 102 | fac = np.prod(spacing) 103 | 104 | for i, (mnvol, mxvol) in enumerate(zip(min_vol, max_vol)): 105 | bin_label = (lb == i + 1) > 0 106 | 107 | conn_comps = label(bin_label) 108 | n_comps = conn_comps.max() 109 | 110 | for c in range(1, n_comps+1): 111 | conn_comp = (conn_comps == c) 112 | 113 | conn_comp_vol = np.sum(conn_comp.astype(float)) * fac 114 | if conn_comp_vol < mnvol: 115 | lb[conn_comp] = 0 116 | elif conn_comp_vol > mxvol: 117 | lb[conn_comp] = 0 118 | 119 | return lb 120 | 121 | 122 | def remove_small_connected_components_from_batch(lbb, min_vol, spacing=None): 123 | 124 | batch_list = [] 125 | for b in range(lbb.shape[0]): 126 | channel_list = [] 127 | for c in range(lbb.shape[1]): 128 | lb = remove_small_connected_components(lbb[b, c], min_vol, spacing) 129 | channel_list.append(lb) 130 | batch_list.append(np.stack(channel_list)) 131 | return np.stack(batch_list) 132 | 133 | def remove_connected_components_by_volume_from_batch(lbb, min_vol=None, max_vol=None, spacing=None): 134 | 135 | batch_list = [] 136 | for b in range(lbb.shape[0]): 137 | channel_list = [] 138 | for c in range(lbb.shape[1]): 139 | lb = remove_connected_components_by_volume(lbb[b, c], min_vol, max_vol, spacing) 140 | channel_list.append(lb) 141 | batch_list.append(np.stack(channel_list)) 142 | return np.stack(batch_list) 143 | 144 | 145 | def reduce_classes(lb, classes, to_single_class=False): 146 | 147 | lb_new = np.zeros_like(lb) 148 | 149 | for i, c in enumerate(classes): 150 | lb_new[lb == c] = i + 1 151 | 152 | if to_single_class: 153 | lb_new = (lb_new > 0).astype(lb.dtype) 154 | 155 | return lb_new 156 | -------------------------------------------------------------------------------- /src/ovseg/data/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def crop_and_pad_image(volume, coord, patch_size, padded_patch_size, mode='minimum'): 5 | ''' 6 | crop_and_pad_image(volume, coord, patch_size, padded_patch_size) 7 | crops from a volume with 8 | Parameters 9 | ---------- 10 | volume : 4d tensor 11 | coord : len 3 12 | upper left coordinate of the patch 13 | patch_size : len 3 14 | size of the patch before padding 15 | patch_size : len 3 16 | size of the padded patch 17 | Returns 18 | ------- 19 | None. 
20 | 21 | ''' 22 | assert len(volume.shape) == 4 23 | shape = np.array(volume.shape)[1:] 24 | 25 | # global coordinates, possible outside volume 26 | cmn_in = coord - (padded_patch_size - patch_size)//2 27 | cmx_in = cmn_in + padded_patch_size 28 | 29 | # clip the coordinates to not violate the arrays axes 30 | cmn_vol = np.maximum(cmn_in, 0) 31 | cmx_vol = np.minimum(cmx_in, shape) 32 | 33 | # let's cut out of the volume as much as we can 34 | crop = volume[:, cmn_vol[0]:cmx_vol[0], cmn_vol[1]:cmx_vol[1], 35 | cmn_vol[2]:cmx_vol[2]] 36 | 37 | # now the padding 38 | pad_low = -1 * np.minimum(0, cmn_in) 39 | # pad_up = np.maximum(0, cmn_in - cmx_vol) 40 | pad_up = np.maximum(0, cmx_in - cmx_vol) 41 | pad_width = [(int(pl), int(pu)) for pl, pu in zip(pad_low, pad_up)] 42 | 43 | # we apply the mode for each channel, e.g. when the mode is minium we pad with the minmal value 44 | # in each channel 45 | return np.stack([np.pad(crop[c], pad_width, mode=mode) for c in range(crop.shape[0])]) 46 | 47 | 48 | def folds_to_splits(folds): 49 | splits = [] 50 | for i in range(len(folds)): 51 | train = [] 52 | for j in range(len(folds)): 53 | if i == j: 54 | val = folds[j] 55 | else: 56 | train.extend(folds[j]) 57 | splits.append({'train': train, 'val': val}) 58 | return splits 59 | 60 | 61 | def split_scans_random_uniform(scans, n_folds=5, fixed_shuffle=True): 62 | 63 | if fixed_shuffle: 64 | # fix the splitting of the data 65 | scans = sorted(scans) 66 | np.random.seed(12345) 67 | np.random.shuffle(scans) 68 | # number of items in all but the last fold 69 | size_fold = int(len(scans) / n_folds) 70 | # folds 0,1,...,n_folds-2 71 | folds = [scans[i * size_fold: (i + 1) * size_fold] for i in 72 | range(n_folds - 1)] 73 | # fold n_fold -1 74 | folds.append(scans[(n_folds - 1) * size_fold:]) 75 | return folds_to_splits(folds) 76 | 77 | 78 | def split_scans_by_patient_id(scans, patient_ids, n_folds=4, 79 | fixed_shuffle=True): 80 | 81 | # first we check if the patient ID dict is like we want it 82 | if not isinstance(patient_ids, dict): 83 | raise TypeError('patient_ids must be dict. The keys must be the names ' 84 | 'of the files and the items the patient ids.') 85 | 86 | # let's see if all the scans are in the dict 87 | not_in_patient_ids = [] 88 | for scan in scans: 89 | if scan not in patient_ids: 90 | not_in_patient_ids.append(scan) 91 | if len(not_in_patient_ids) > 0: 92 | raise ValueError('Some names of scans were not found in the ' 93 | 'patient_ids: \n' + str(not_in_patient_ids)) 94 | 95 | # shuffle the scans either randomly of fixed at random 96 | if fixed_shuffle: 97 | scans = sorted(scans) 98 | np.random.seed(12345) 99 | np.random.shuffle(scans) 100 | 101 | # now we put the images in tuples of matching patient ids 102 | scans_used = [] 103 | scan_tuples = [] 104 | patient_id_list = [patient_ids[scan] for scan in scans] 105 | for scan in scans: 106 | if scan not in scans_used: 107 | pat_id = patient_ids[scan] 108 | # find all scans with same patient id 109 | scan_tuple = [s for pid, s in zip(patient_id_list, scans) if 110 | pid == pat_id] 111 | 112 | # save this tuple 113 | scan_tuples.append(scan_tuple) 114 | 115 | # and mark all the scans as used 116 | for s in scan_tuple: 117 | scans_used.append(s) 118 | 119 | # now we sort the tuples again to have the largest ones at first 120 | # this way we're trying to make the folds equally large. 
If a very large 121 | # tuple would be the last one in scan_tuples this might lead to very 122 | # unequally sized folds 123 | scan_tuples_len = [len(tpl) for tpl in scan_tuples] 124 | scan_tuples = [tpl for n, tpl in sorted(zip(scan_tuples_len, scan_tuples))] 125 | 126 | # now we unpack the guys and put them in folds 127 | folds = [[] for _ in range(n_folds)] 128 | for tpl in scan_tuples: 129 | # find out the fold with the currently lowest amount of scans 130 | ind = np.argmin([len(fold) for fold in folds]) 131 | folds[ind].extend(tpl) 132 | 133 | # now we turn the folds into splits and return 134 | print('Smart splitting successfull length of folds:') 135 | print([len(fold) for fold in folds]) 136 | print() 137 | 138 | return folds_to_splits(folds) 139 | -------------------------------------------------------------------------------- /src/ovseg/model/SegmentationEnsembleV2.py: -------------------------------------------------------------------------------- 1 | from ovseg.model.SegmentationEnsemble import SegmentationEnsemble 2 | from ovseg.model.SegmentationModelV2 import SegmentationModelV2 3 | from ovseg.data.Dataset import raw_Dataset 4 | from os.path import join, exists 5 | from os import environ 6 | import torch 7 | import numpy as np 8 | 9 | 10 | class SegmentationEnsembleV2(SegmentationEnsemble): 11 | 12 | def create_model(self, fold): 13 | model = SegmentationModelV2(val_fold=fold, 14 | data_name=self.data_name, 15 | model_name=self.model_name, 16 | model_parameters=self.model_parameters, 17 | preprocessed_name=self.preprocessed_name, 18 | network_name=self.network_name, 19 | is_inference_only=True, 20 | fmt_write=self.fmt_write, 21 | model_parameters_name=self.model_parameters_name 22 | ) 23 | return model 24 | 25 | def __call__(self, data_tpl): 26 | if not self.all_folds_complete(): 27 | print('WARNING: Ensemble is used without all training folds being completed!!') 28 | 29 | if not self.models_initialised: 30 | print('Models were not initialised. Trying to do it now...') 31 | self.wait_until_all_folds_complete() 32 | 33 | scan = data_tpl['scan'] 34 | 35 | # also the path where we will look for already executed npz prediction 36 | pred_npz_path = join(environ['OV_DATA_BASE'], 'npz_predictions', self.data_name, 37 | self.preprocessed_name, self.model_name) 38 | 39 | # the preprocessing will only do something if the image is not preprocessed yet 40 | if not self.preprocessing.is_preprocessed_data_tpl(data_tpl): 41 | for model in self.models: 42 | # try find the npz file if there was already a prediction. 43 | path_to_npz = join(pred_npz_path, model.val_fold_str, scan+'.npz') 44 | path_to_npy = join(pred_npz_path, model.val_fold_str, scan+'.npy') 45 | 46 | if exists(path_to_npy) or exists(path_to_npz): 47 | im, mask = None, None 48 | continue 49 | else: 50 | im = self.preprocessing(data_tpl, preprocess_only_im=True) 51 | if self.preprocessing.has_ps_mask: 52 | im, mask = im[:-1], im[-1:] 53 | else: 54 | mask = None 55 | break 56 | 57 | # now the importat part: the actual enembling of sliding window evaluations 58 | preds = [] 59 | with torch.no_grad(): 60 | for model in self.models: 61 | # try find the npz file if there was already a prediction. 
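# (per-fold predictions are cached as .npy/.npz files under 'npz_predictions'; if a cached file is found, the sliding window evaluation for that fold is skipped and the stored array is loaded instead)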
62 | path_to_npz = join(pred_npz_path, model.val_fold_str, scan+'.npz') 63 | path_to_npy = join(pred_npz_path, model.val_fold_str, scan+'.npy') 64 | if exists(path_to_npy): 65 | try: 66 | pred = np.load(path_to_npy) 67 | except ValueError: 68 | 69 | if im is None: 70 | im = self.preprocessing(data_tpl, preprocess_only_im=True) 71 | if self.preprocessing.has_ps_mask: 72 | im, mask = im[:-1], im[-1:] 73 | else: 74 | mask = None 75 | pred = model.prediction(im).cpu().numpy() 76 | elif exists(path_to_npz): 77 | try: 78 | pred = np.load(path_to_npz)['arr_0'] 79 | except ValueError: 80 | if im is None: 81 | im = self.preprocessing(data_tpl, preprocess_only_im=True) 82 | if self.preprocessing.has_ps_mask: 83 | im, mask = im[:-1], im[-1:] 84 | else: 85 | mask = None 86 | pred = model.prediction(im).cpu().numpy() 87 | 88 | else: 89 | pred = model.prediction(im).cpu().numpy() 90 | preds.append(pred) 91 | 92 | ens_pred = np.stack(preds).mean(0) 93 | 94 | data_tpl[self.pred_key] = ens_pred 95 | 96 | # inside the postprocessing the result will be attached to the data_tpl 97 | self.postprocessing.postprocess_data_tpl(data_tpl, self.pred_key, mask) 98 | 99 | torch.cuda.empty_cache() 100 | return data_tpl[self.pred_key] 101 | 102 | def eval_raw_dataset(self, data_name, save_preds=True, save_plots=False, 103 | force_evaluation=False, scans=None, image_folder=None, dcm_revers=True, 104 | dcm_names_dict=None): 105 | 106 | prev_stages = {**self.preprocessing.prev_stage_for_input, 107 | **self.preprocessing.prev_stage_for_mask} 108 | if len(prev_stages) == 0: 109 | prev_stages = None 110 | 111 | ds = raw_Dataset(join(environ['OV_DATA_BASE'], 'raw_data', data_name), 112 | scans=scans, 113 | image_folder=image_folder, 114 | dcm_revers=dcm_revers, 115 | dcm_names_dict=dcm_names_dict, 116 | prev_stages=prev_stages) 117 | self.eval_ds(ds, ds_name=data_name, save_preds=save_preds, save_plots=save_plots, 118 | force_evaluation=force_evaluation) -------------------------------------------------------------------------------- /src/ovseg/model/SegmentationModelV2.py: -------------------------------------------------------------------------------- 1 | from ovseg.model.SegmentationModel import SegmentationModel 2 | from ovseg.preprocessing.SegmentationPreprocessingV2 import SegmentationPreprocessingV2 3 | from ovseg.data.SegmentationDataV2 import SegmentationDataV2 4 | from ovseg.utils.io import save_nii_from_data_tpl, save_npy_from_data_tpl, load_pkl, read_nii, save_dcmrt_from_data_tpl, is_dcm_path 5 | from ovseg.utils.torch_np_utils import maybe_add_channel_dim 6 | from ovseg.utils.dict_equal import dict_equal, print_dict_diff 7 | from os.path import join 8 | import numpy as np 9 | 10 | class SegmentationModelV2(SegmentationModel): 11 | 12 | def initialise_preprocessing(self): 13 | if 'preprocessing' not in self.model_parameters: 14 | print('No preprocessing parameters found in model_parameters. ' 15 | 'Trying to load from preprocessed_folder...') 16 | if not hasattr(self, 'preprocessed_path'): 17 | raise AttributeError('preprocessed_path wasn\'t initialiased. 
' 18 | 'Make sure to either pass the ' 19 | 'preprocessing parameters or the path ' 20 | 'to the preprocessed folder were an ' 21 | 'extra copy is stored.') 22 | else: 23 | prep_params = load_pkl(join(self.preprocessed_path, 24 | 'preprocessing_parameters.pkl')) 25 | self.model_parameters['preprocessing'] = prep_params 26 | 27 | 28 | params = self.model_parameters['preprocessing'].copy() 29 | 30 | self.preprocessing = SegmentationPreprocessingV2(**params) 31 | 32 | # now for the computation of loss metrics we need the number of prevalent fg classes 33 | if self.preprocessing.reduce_lb_to_single_class: 34 | self.n_fg_classes = 1 35 | elif self.preprocessing.lb_classes is not None: 36 | self.n_fg_classes = len(self.preprocessing.lb_classes) 37 | elif self.model_parameters['network']['out_channels'] is not None: 38 | self.n_fg_classes = self.model_parameters['network']['out_channels'] - 1 39 | elif hasattr(self.preprocessing, 'dataset_properties'): 40 | print('Using all foreground classes for computing the DSCS') 41 | self.n_fg_classes = self.preprocessing.dataset_properties['n_fg_classes'] 42 | else: 43 | raise AttributeError('Something seems to be wrong. Could not figure out the number ' 44 | 'of foreground classes in the problem...') 45 | if self.preprocessing.lb_classes is None and hasattr(self.preprocessing, 'dataset_properties'): 46 | 47 | if self.preprocessing.reduce_lb_to_single_class: 48 | self.lb_classes = [1] 49 | else: 50 | self.lb_classes = list(range(1, self.n_fg_classes+1)) 51 | if self.n_fg_classes != self.preprocessing.dataset_properties['n_fg_classes']: 52 | print('There seems to be a missmatch between the number of forground ' 53 | 'classes in the preprocessed data and the number of network ' 54 | 'output channels....') 55 | else: 56 | self.lb_classes = self.preprocessing.lb_classes 57 | 58 | def initialise_data(self): 59 | # the data object holds the preprocessed data (training and validation) 60 | # for each it has both a dataset returning the data tuples and the dataloaders 61 | # returning the batches 62 | if 'data' not in self.model_parameters: 63 | raise AttributeError('model_parameters must have key ' 64 | '\'data\'. These must contain the ' 65 | 'dict of training paramters.') 66 | 67 | # Let's get the parameters and add the cpu augmentation 68 | params = self.model_parameters['data'].copy() 69 | 70 | # if we don't want to store our data in ram... 71 | if self.dont_store_data_in_ram: 72 | for key in ['trn_dl_params', 'val_dl_params']: 73 | params[key]['store_data_in_ram'] = False 74 | params[key]['store_coords_in_ram'] = False 75 | self.data = SegmentationDataV2(val_fold=self.val_fold, 76 | preprocessed_path=self.preprocessed_path, 77 | augmentation= self.augmentation.np_augmentation, 78 | **params) 79 | print('Data initialised') 80 | 81 | 82 | def __call__(self, data_tpl, do_postprocessing=True): 83 | ''' 84 | This function just predict the segmentation for the given data tpl 85 | There are a lot of differnt ways to do prediction. Some do require direct preprocessing 86 | some don't need the postprocessing imidiately (e.g. when ensembling) 87 | Same holds for the resizing to original shape. In the validation case we wan't to apply 88 | some postprocessing (argmax and removing of small lesions) but not the resizing. 
89 | ''' 90 | self.network = self.network.eval() 91 | 92 | # first let's get the image and maybe the bin_pred as well 93 | # the preprocessing will only do something if the image is not preprocessed yet 94 | if not self.preprocessing.is_preprocessed_data_tpl(data_tpl): 95 | # the image already contains the binary prediction as additional channel 96 | im = self.preprocessing(data_tpl, preprocess_only_im=True) 97 | if self.preprocessing.has_ps_mask: 98 | im, mask = im[:-1], im[-1:] 99 | else: 100 | mask = None 101 | else: 102 | # the data_tpl is already preprocessed, let's just get the arrays 103 | im = data_tpl['image'] 104 | im = maybe_add_channel_dim(im) 105 | if self.preprocessing.has_ps_input: 106 | 107 | pred = maybe_add_channel_dim(data_tpl['prev_pred']) 108 | if pred.max() > 1: 109 | raise NotImplementedError('Didn\'t implement the casacde for multiclass' 110 | 'prev stages. Add one hot encoding.') 111 | im = np.concatenate([im, pred]) 112 | 113 | if self.preprocessing.has_ps_mask: 114 | mask = maybe_add_channel_dim(data_tpl['mask']) 115 | else: 116 | mask = None 117 | 118 | 119 | # now the importat part: the sliding window evaluation (or derivatives of it) 120 | pred = self.prediction(im) 121 | data_tpl[self.pred_key] = pred 122 | 123 | # inside the postprocessing the result will be attached to the data_tpl 124 | if do_postprocessing: 125 | self.postprocessing.postprocess_data_tpl(data_tpl, self.pred_key, mask) 126 | 127 | return data_tpl[self.pred_key] -------------------------------------------------------------------------------- /src/ovseg/augmentation/GridAugmentation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | # %% 8 | class torch_inplane_grid_augmentations(nn.Module): 9 | 10 | def __init__(self, 11 | p_rot=0.2, 12 | p_zoom=0.2, 13 | p_scale_if_zoom=0, 14 | p_transl=0, 15 | p_shear=0, 16 | mm_zoom=[0.7, 1.4], 17 | mm_rot=[-15, 15], 18 | mm_transl=[-0.25, 0.25], 19 | mm_shear=[-0.2, 0.2], 20 | apply_flipping=True, 21 | n_im_channels: int = 1, 22 | out_shape=None 23 | ): 24 | super().__init__() 25 | self.p_rot = p_rot 26 | self.p_zoom = p_zoom 27 | self.p_scale_if_zoom = p_scale_if_zoom 28 | self.p_transl = p_transl 29 | self.p_shear = p_shear 30 | self.mm_zoom = mm_zoom 31 | self.mm_rot = mm_rot 32 | self.mm_transl = mm_transl 33 | self.mm_shear = mm_shear 34 | self.apply_flipping = apply_flipping 35 | self.n_im_channels = n_im_channels 36 | self.out_shape = out_shape 37 | if out_shape is not None: 38 | self.out_shape = np.array(self.out_shape) 39 | 40 | def _rot(self, theta): 41 | 42 | angle = np.random.uniform(*self.mm_rot) 43 | rot_m = torch.zeros_like(theta[:, :-1]) 44 | cos, sin = np.cos(np.deg2rad(angle)), np.sin(np.deg2rad(angle)) 45 | rot_m[0, 0] = cos 46 | rot_m[0, 1] = sin 47 | rot_m[1, 0] = -1 * sin 48 | rot_m[1, 1] = cos 49 | if theta.shape[0] == 3: 50 | rot_m[2, 2] = 1 51 | 52 | return torch.mm(rot_m, theta) 53 | 54 | def _zoom(self, theta): 55 | fac1 = np.random.uniform(*self.mm_zoom) 56 | if np.random.rand() < self.p_scale_if_zoom: 57 | fac2 = np.random.uniform(*self.mm_zoom) 58 | else: 59 | fac2 = fac1 60 | theta[0, 0] *= fac1 61 | theta[1, 1] *= fac2 62 | theta[0, -1] *= fac1 63 | theta[1, -1] *= fac2 64 | return theta 65 | 66 | def _translate(self, theta): 67 | theta[0, -1] = np.random.uniform(*self.mm_transl) 68 | theta[1, -1] = np.random.uniform(*self.mm_transl) 69 | return theta 70 | 71 | def _shear(self, 
theta): 72 | s = np.random.uniform(*self.mm_shear) 73 | shear_m = torch.zeros_like(theta[:, :-1]) 74 | for i in range(theta.shape[0]): 75 | shear_m[i, i] = 1 76 | if np.random.rand() < 0.5: 77 | shear_m[0, 1] = s 78 | else: 79 | shear_m[1, 0] = s 80 | return torch.mm(shear_m, theta) 81 | 82 | def _get_ops_list(self): 83 | ops_list = [] 84 | if np.random.rand() < self.p_rot: 85 | ops_list.append(self._rot) 86 | if np.random.rand() < self.p_zoom: 87 | ops_list.append(self._zoom) 88 | if np.random.rand() < self.p_transl: 89 | ops_list.append(self._translate) 90 | if np.random.rand() < self.p_shear: 91 | ops_list.append(self._shear) 92 | np.random.shuffle(ops_list) 93 | 94 | return ops_list 95 | 96 | def _flip(self, xb): 97 | 98 | bs = xb.shape[0] 99 | img_dims = len(xb.shape) - 2 100 | flp_list = [np.random.rand(img_dims) < 0.5 for _ in range(bs)] 101 | 102 | for b, flp in enumerate(flp_list): 103 | dims = [i + 1 for i, f in enumerate(flp) if f] 104 | if len(dims) > 0: 105 | xb[b] = torch.flip(xb[b], dims) 106 | return xb 107 | 108 | def forward(self, xb): 109 | 110 | bs, n_ch = xb.shape[0:2] 111 | img_dims = len(xb.shape) - 2 112 | theta = torch.zeros((bs, img_dims, img_dims+1), device=xb.device, dtype=xb.dtype) 113 | for j in range(img_dims): 114 | theta[:, j, j] = 1 115 | for i in range(bs): 116 | ops_list = self._get_ops_list() 117 | for op in ops_list: 118 | theta[i] = op(theta[i]) 119 | 120 | grid = F.affine_grid(theta, xb.size()).cuda().type(xb.dtype) 121 | if self.out_shape is not None: 122 | # crop from the grid 123 | crp_l = (np.array(xb.shape[2:]) - self.out_shape) // 2 124 | crp_u = (crp_l + self.out_shape) 125 | grid = grid[:, crp_l[0]:crp_u[0], crp_l[1]:crp_u[1], crp_l[2]:crp_u[2]] 126 | xb = torch.cat([F.grid_sample(xb[:, :self.n_im_channels], grid, mode='bilinear'), 127 | F.grid_sample(xb[:, self.n_im_channels:], grid, mode='nearest')], dim=1) 128 | 129 | # now flipping 130 | if self.apply_flipping: 131 | xb = self._flip(xb) 132 | return xb 133 | 134 | def update_prg_trn(self, param_dict, h, indx=None): 135 | 136 | attr_list = ['p_rot', 'p_zoom', 'p_transl', 'p_shear', 'mm_zoom', 'mm_rot', 137 | 'mm_transl', 'mm_shear'] 138 | 139 | for attr in attr_list: 140 | if attr in param_dict: 141 | self.__setattr__(attr, (1 - h) * param_dict[attr][0] + h * param_dict[attr][1]) 142 | 143 | if 'out_shape' in param_dict: 144 | self.out_shape = param_dict['out_shape'][indx] 145 | 146 | # %% 147 | if __name__ == '__main__': 148 | import matplotlib.pyplot as plt 149 | from time import perf_counter 150 | plt.close('all') 151 | 152 | im_full = np.load('D:\\PhD\\Data\\ov_data_base\\preprocessed\\OV04_test\\default\\images' 153 | '\\OV04_034_20091014.npy') 154 | lb_full = np.load('D:\\PhD\\Data\\ov_data_base\\preprocessed\\OV04_test\\default\\labels' 155 | '\\OV04_034_20091014.npy') > 0 156 | 157 | im_crop = im_full[30:78, 100:292, 100:292].astype(np.float32) 158 | imt = torch.from_numpy(im_crop).cuda().unsqueeze(0).unsqueeze(0).type(torch.float) 159 | lb_crop = lb_full[30:78, 100:292, 100:292].astype(np.float32) 160 | lbt = torch.from_numpy(lb_crop).cuda().unsqueeze(0).unsqueeze(0).type(torch.float) 161 | xb = torch.cat([imt, lbt], 1).cuda() 162 | xb = torch.cat([xb, xb], 0) 163 | aug = torch_inplane_grid_augmentations(p_rot=0.5, p_zoom=0.5, p_scale_if_zoom=0.5, 164 | p_transl=0.0, p_shear=0.5, 165 | mm_zoom=[0.8,1.2], mm_rot=[-20, 20], 166 | apply_flipping=False) 167 | 168 | # %% 169 | xb_aug = aug(xb).cpu().numpy() 170 | 171 | z = np.argmax(np.sum(lb_crop > 0, (1, 2))) 172 | plt.subplot(1, 3, 
1) 173 | plt.imshow(xb_aug[0, 0, z], cmap='gray') 174 | plt.contour(xb_aug[0, 1, z]) 175 | plt.subplot(1, 3, 2) 176 | plt.imshow(im_crop[z], cmap='gray') 177 | plt.contour(lb_crop[z]) 178 | plt.subplot(1, 3, 3) 179 | plt.imshow(xb_aug[1, 0, z], cmap='gray') 180 | plt.contour(xb_aug[1, 1, z]) 181 | 182 | # %% 183 | 184 | st = perf_counter() 185 | for _ in range(50): 186 | xb_aug = aug(xb) 187 | torch.cuda.synchronize() 188 | et = perf_counter() 189 | print('It took {:.7f}s for augmenting with batch size 2'.format((et-st)/50)) -------------------------------------------------------------------------------- /plan_and_preprocess.py: -------------------------------------------------------------------------------- 1 | from ovseg.preprocessing.SegmentationPreprocessing import SegmentationPreprocessing 2 | from os.path import join 3 | from os import environ 4 | import argparse 5 | import warnings 6 | warnings.filterwarnings("ignore", category=UserWarning) 7 | 8 | parser = argparse.ArgumentParser() 9 | 10 | # first some arguments we need for the planning and saving 11 | parser.add_argument("raw_data", nargs='+', 12 | help='Name or names of folders in \'raw_data\' that are used for planning ' 13 | 'and that are preprocessed.') 14 | parser.add_argument("preprocessed_name", 15 | help='Name that the folder in \'preprocessed\' will have.') 16 | parser.add_argument("--data_name", required=False, default=None, 17 | help='set this in case you want to give the parent folder of the preprocessed ' 18 | 'data a special name. If unset this name will be the name given in raw_data.') 19 | parser.add_argument("--save_as_fp32", required=False, default=False, action='store_true', 20 | help='Usually the preprocessed images are stored as fp16 to save disk space. ' 21 | 'Set this if you are sure that you need them to be stored as fp32 instead.') 22 | parser.add_argument("--save_scans_without_fg", required=False, default=False, action='store_true', 23 | help='By default images that contain foreground are preprocessed and ' 24 | 'used for training. Set this flag to undo so.') 25 | 26 | # now the arguments that determine the preprocessing 27 | parser.add_argument("--dont_apply_resizing", required=False, default=False, action='store_true') 28 | parser.add_argument("--dont_apply_windowing", required=False, default=False, action='store_true') 29 | parser.add_argument("--target_spacing", required=False, default=None, nargs='+') 30 | parser.add_argument("--pooling_stride", required=False, default=None, nargs='+') 31 | parser.add_argument("--window", required=False, default=None, nargs='+') 32 | parser.add_argument("--scaling", required=False, default=None, nargs='+') 33 | parser.add_argument("--lb_classes", required=False, default=None, nargs='+', 34 | help='If you have a multiclass problem with a lot of classes and each model ' 35 | 'is only supposed to use a subset of these, determine here which classes ' 36 | 'you want to consider for this batch of preprocessed data.') 37 | parser.add_argument("--reduce_lb_to_single_class", required=False, default=False, 38 | action='store_true', 39 | help='If you are in a multiclass setting and you wishes to reduce your ' 40 | 'problem to a simply binary foreground vs. background setting, set this flag.') 41 | parser.add_argument("--lb_min_vol", required=False, default=None, nargs='+', 42 | help='You can give one value or a list with as many values as there are ' 43 | 'foreground classes. 
Connected components with a volume of less then this/these' 44 | ' constant(s) are being discraded as a part of the preprocessing. Choose ' 45 | 'constants in the same unit as the images spacing, typically mm^3.') 46 | parser.add_argument('--prev_stages', required=False, default=[], nargs='+', 47 | help='Name the data_name, preprocessed_name and model_name of arbritraily ' 48 | 'many previous stages here to use them as an input for model cascades.') 49 | args = parser.parse_args() 50 | 51 | # input arguments 52 | apply_resizing = not args.dont_apply_resizing 53 | if apply_resizing and args.target_spacing is not None: 54 | target_spacing = [float(ts) for ts in args.target_spacing] 55 | assert len(target_spacing) == 3, 'you must give exactly 3 floats as a target spacing' 56 | else: 57 | target_spacing = None 58 | 59 | if args.pooling_stride is not None: 60 | pooling_stride = [int(ps) for ps in args.pooling_stride] 61 | assert len(pooling_stride) == 3, 'you must give exactly 3 integers as a pooling stride' 62 | apply_pooling = True 63 | else: 64 | pooling_stride = None 65 | apply_pooling = False 66 | 67 | apply_windowing = not args.dont_apply_windowing 68 | if apply_windowing and args.window is not None: 69 | window = [float(w) for w in args.window] 70 | assert len(window) == 2, 'you must give exactyle 2 floats as a window' 71 | else: 72 | window = None 73 | 74 | if args.scaling is not None: 75 | scaling = [float(s) for s in args.scaling] 76 | assert len(scaling) == 2, 'you must give exactly 2 floats for the input scaling' 77 | else: 78 | scaling = None 79 | 80 | if args.lb_classes is not None: 81 | lb_classes = [int(lbc) for lbc in args.lb_classes] 82 | else: 83 | lb_classes = None 84 | 85 | if args.lb_min_vol is not None: 86 | lb_min_vol = [float(lmv) for lmv in args.lb_min_vol] 87 | else: 88 | lb_min_vol = None 89 | 90 | save_only_fg_scans = not args.save_scans_without_fg 91 | 92 | if len(args.prev_stages) % 3 != 0: 93 | raise ValueError('The arguments given in previous stages must be divisible by three.' 94 | 'The Input shold be likedata_name1, preprocessed_name1, model_name1, ...., ' 95 | 'data_namek, preprocessed_namek, model_namek') 96 | 97 | n_stages = len(args.prev_stages) // 3 98 | prev_stages = [] 99 | for i in range(n_stages): 100 | prev_stages.append({'data_name': args.prev_stages[3*i], 101 | 'preprocessed_name': args.prev_stages[3*i+1], 102 | 'model_name': args.prev_stages[3*i+2]}) 103 | 104 | preprocessing = SegmentationPreprocessing(apply_resizing=apply_resizing, 105 | apply_pooling=apply_pooling, 106 | apply_windowing=apply_windowing, 107 | target_spacing=target_spacing, 108 | pooling_stride=pooling_stride, 109 | window=window, 110 | scaling=scaling, 111 | lb_classes=lb_classes, 112 | reduce_lb_to_single_class=args.reduce_lb_to_single_class, 113 | lb_min_vol=lb_min_vol, 114 | prev_stages=prev_stages, 115 | save_only_fg_scans=save_only_fg_scans) 116 | 117 | preprocessing.plan_preprocessing_raw_data(args.raw_data, 118 | force_planning=True) 119 | 120 | preprocessing.preprocess_raw_data(raw_data=args.raw_data, 121 | preprocessed_name=args.preprocessed_name, 122 | data_name=args.data_name, 123 | save_as_fp16=not args.save_as_fp32) 124 | 125 | print('Done! 
Here are the preprocessing parameters:') 126 | # use the given data_name or fall back to the joined names of the raw data folders 127 | data_name = args.data_name if args.data_name is not None else '_'.join(sorted(args.raw_data)) 128 | 129 | # path to the text file in which the preprocessing parameters were stored 130 | path_to_file = join(environ['OV_DATA_BASE'], 'preprocessed', data_name, args.preprocessed_name, 131 | 'preprocessing_parameters.txt') 132 | 133 | with open(path_to_file, 'r') as file: 134 | print(file.read()) 135 | -------------------------------------------------------------------------------- /src/ovseg/training/TrainingBase.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from os.path import join, exists 4 | import time 5 | import pickle 6 | import sys 7 | from ovseg.utils.path_utils import maybe_create_path 8 | 9 | 10 | class TrainingBase(): 11 | ''' 12 | Basic class for Trainer. Inherit this one for your training needs. 13 | Overload all methods like 14 | 15 | def function(...): 16 | super().function(...) 17 | 18 | ''' 19 | 20 | def __init__(self, trn_dl, num_epochs, model_path): 21 | 22 | # overwrite defaults 23 | self.trn_dl = trn_dl 24 | self.num_epochs = num_epochs 25 | self.model_path = model_path 26 | 27 | # some training stuff 28 | self.epochs_done = 0 29 | self.trn_start_time = -1 30 | self.trn_end_time = -1 31 | self.total_train_time = 0 32 | 33 | # these attributes will be stored and recovered, append this list 34 | # with the attributes you want to save 35 | self.checkpoint_attributes = ['epochs_done', 'trn_start_time', 'trn_end_time', 36 | 'total_train_time', 'num_epochs'] 37 | self.print_attributes = ['model_path', 'num_epochs'] 38 | # make model_path and training_log 39 | maybe_create_path(self.model_path) 40 | self.training_log = join(self.model_path, 'training_log.txt') 41 | 42 | def print_and_log(self, s, n_newlines=0): 43 | ''' 44 | prints, flushes and writes to the training_log 45 | ''' 46 | if len(s) > 0: 47 | print(s) 48 | for _ in range(n_newlines): 49 | print('') 50 | sys.stdout.flush() 51 | t = time.localtime() 52 | if len(s) > 0: 53 | ts = time.strftime('%Y-%m-%d %H:%M:%S: ', t) 54 | s = ts + s + '\n' 55 | else: 56 | s = '\n' 57 | mode = 'a' if exists(self.training_log) else 'w' 58 | with open(self.training_log, mode) as log_file: 59 | log_file.write(s) 60 | for _ in range(n_newlines): 61 | log_file.write('\n') 62 | 63 | def train(self): 64 | ''' 65 | Basic training function where everything is happening 66 | ''' 67 | 68 | if self.epochs_done >= self.num_epochs: 69 | # do nothing if the training was already finished 70 | print('Training was already completed.
Doing nothing here.') 71 | return 72 | else: 73 | # if we've stopped the training before by setting stop_training = True 74 | # this resumes it 75 | self.stop_training = False 76 | 77 | self.on_training_start() 78 | 79 | # we're using the stop_training flag to easily allow early stopping in the training 80 | while not self.stop_training: 81 | 82 | self.on_epoch_start() 83 | 84 | for step, batch in enumerate(self.trn_dl): 85 | 86 | self.do_trn_step(batch, step) 87 | 88 | self.on_epoch_end() 89 | # we save the checkpoint after calling on_epoch_end so that 90 | # computations added to on_epoch_end will be saved as well 91 | self.save_checkpoint() 92 | self.print_and_log('', 1) 93 | 94 | if self.epochs_done >= self.num_epochs: 95 | self.on_training_end() 96 | 97 | def on_training_start(self): 98 | 99 | self.print_and_log('Start training.', 2) 100 | sys.stdout.flush() 101 | 102 | if self.trn_start_time == -1: 103 | # keep date when training started 104 | self.trn_start_time = time.asctime() 105 | 106 | if self.epochs_done == 0: 107 | # print some training infos 108 | self.print_and_log('Training parameters:', 1) 109 | for key in self.print_attributes: 110 | item = self 111 | try: 112 | for k in key.split('.'): 113 | item = item.__getattribute__(k) 114 | self.print_and_log(str(key)+': '+str(item)) 115 | except AttributeError: 116 | self.print_and_log(str(key)+': ERROR, item not found.') 117 | self.print_and_log('', 1) 118 | 119 | def on_epoch_start(self): 120 | 121 | self.print_and_log('Epoch:%d' % self.epochs_done) 122 | self.epoch_start_time = time.time() 123 | 124 | def do_trn_step(self, data_tpl, step): 125 | 126 | raise NotImplementedError('\'do_trn_step\' must be overloaded in the child class.\n It has ' 127 | 'a data tpl as input and performs one optimisation step.') 128 | 129 | def on_epoch_end(self): 130 | ''' 131 | Basic function on what we\'re doing after each epoch. 132 | Add e.g. printing of training error or computations of validation error here 133 | ''' 134 | epoch_time = time.time() - self.epoch_start_time 135 | self.total_train_time += epoch_time 136 | 137 | self.print_and_log('Epoch training {} done after {:.2f} seconds' 138 | .format(self.epochs_done, epoch_time)) 139 | self.epochs_done += 1 140 | self.print_and_log('Average epoch training time: {:.2f} seconds' 141 | .format(self.total_train_time/self.epochs_done)) 142 | if self.epochs_done >= self.num_epochs: 143 | self.stop_training = True 144 | 145 | def on_training_end(self): 146 | 147 | self.print_and_log('Training finished!') 148 | if self.trn_end_time == -1: 149 | self.trn_end_time = time.asctime() 150 | self.print_and_log('Training time: {} - {} ({:.3f}h)'.format(self.trn_start_time, 151 | self.trn_end_time, 152 | self.total_train_time/3600)) 153 | self.save_checkpoint() 154 | 155 | def save_checkpoint(self, path=None): 156 | ''' 157 | Saves attributes of this trainer class as .pkl file for 158 | later restoring 159 | ''' 160 | if path is None: 161 | path = self.model_path 162 | attribute_dict = {} 163 | 164 | for key in self.checkpoint_attributes: 165 | 166 | item = self.__getattribute__(key) 167 | attribute_dict.update({key: item}) 168 | 169 | with open(os.path.join(path, 'attribute_checkpoint.pkl'), 'wb') as outfile: 170 | pickle.dump(attribute_dict, outfile) 171 | 172 | self.print_and_log('Training attributes saved') 173 | 174 | def load_last_checkpoint(self, path=None): 175 | ''' 176 | Loads trainers checkpoint, if added any custom attributes that are not 177 | of type scalar, tuple, list or np.ndarray. 
178 | Overload for attributes of other types. 179 | ''' 180 | if path is None: 181 | path = self.model_path 182 | path_to_trainer_checkpoint = join(path, 'attribute_checkpoint.pkl') 183 | if exists(path_to_trainer_checkpoint): 184 | with open(path_to_trainer_checkpoint, 'rb') as pickle_file: 185 | attribute_dict = pickle.load(pickle_file) 186 | 187 | for key in self.checkpoint_attributes: 188 | try: 189 | item = attribute_dict[key] 190 | self.__setattr__(key, item) 191 | except KeyError: 192 | print('key {} was not found in loaded checkpoint. Skipping!'.format(key)) 193 | return True 194 | else: 195 | return False 196 | -------------------------------------------------------------------------------- /example_scripts/train_my_cascade.py: -------------------------------------------------------------------------------- 1 | from ovseg.model.SegmentationModelV2 import SegmentationModelV2 2 | from ovseg.preprocessing.SegmentationPreprocessingV2 import SegmentationPreprocessingV2 3 | from ovseg.model.SegmentationEnsembleV2 import SegmentationEnsembleV2 4 | from ovseg.model.model_parameters_segmentation import get_model_params_3d_UNet 5 | from ovseg.utils.io import load_pkl 6 | from time import sleep 7 | import os 8 | from ovseg import OV_PREPROCESSED, OV_DATA_BASE 9 | 10 | ''' 11 | Script for preprocessing and training a cascade in one go. 12 | ''' 13 | 14 | # name of your raw dataset 15 | data_name = 'MY_DATASET_NAME' 16 | # name of lowres preprocessed data 17 | preprocessed_name_lowres = 'MY_PREPROCESSED_NAME_lowres' 18 | # name of fullres preprocessed data 19 | # WARNING: If you change the lowres model the preprocessing have to be done 20 | # again to with a new name 21 | preprocessed_name_fullres = 'MY_PREPROCESSED_NAME_fullres' 22 | 23 | # give each model a unique name. This way the code will be able to identify them 24 | # both models (lowres and fullres) will have the same name and be differentiated 25 | # by the name of preprocessed data 26 | model_name = 'MY_MODEL_NAME' 27 | val_fold = 0 28 | # %% preprocess lowres data if it hasn't been done yet 29 | if not os.path.exists(os.path.join(OV_PREPROCESSED, data_name, preprocessed_name_lowres)): 30 | 31 | # downsample in xy plane by factor 2 for lowres model 32 | pooling_stride = (1,2,2) 33 | 34 | # ADD SOME PREPROCESSING PARAMETERS HERE 35 | prep = SegmentationPreprocessingV2(apply_resizing=True, 36 | apply_pooling=True, 37 | apply_windowing=True, 38 | pooling_stride=pooling_stride) 39 | prep.plan_preprocessing_raw_data(data_name) 40 | 41 | prep.preprocess_raw_data(raw_data=data_name, 42 | preprocessed_name=preprocessed_name_lowres) 43 | 44 | 45 | # %% now get hyper-parameters for low resolution and train 46 | patch_size = [32, 256, 256] 47 | n_2d_convs = 3 48 | use_prg_trn = False # on low resolution prg trn can harm the performance 49 | n_fg_classes = 1 50 | use_fp32 = False 51 | model_params = get_model_params_3d_UNet(patch_size=patch_size, 52 | n_2d_convs=n_2d_convs, 53 | use_prg_trn=use_prg_trn, 54 | n_fg_classes=n_fg_classes, 55 | fp32=use_fp32) 56 | 57 | # CHANGE YOUR HYPER-PARAMETERS FOR LOWRES STAGE HERE! 58 | 59 | model = SegmentationModelV2(val_fold=val_fold, 60 | data_name=data_name, 61 | model_name=model_name, 62 | preprocessed_name=preprocessed_name_lowres, 63 | model_parameters=model_params) 64 | # execute the trainig, simple as that! 
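# (this call trains the lowres stage of the cascade; the fullres stage further down repeats the same pattern, but feeds the lowres predictions as an additional input channel)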
65 | # It will check for previous checkpoints and load them 66 | model.training.train() 67 | 68 | if val_fold < model_params['data']['n_folds']: 69 | model.eval_validation_set() 70 | 71 | # %% now we have to wait until all models have finished their predicitons 72 | # we need the predictions from the low resolution before we can start the preprocessing 73 | # of the next stage 74 | wait = True 75 | while wait: 76 | num_epochs = model_params['training']['num_epochs'] 77 | not_finished_folds = [] 78 | for fold in range(model_params['data']['n_folds']): 79 | # path to training checkpoints 80 | path_to_attr = os.path.join(OV_DATA_BASE, 81 | 'trained_models', 82 | data_name, 83 | preprocessed_name_lowres, 84 | model_name, 85 | f'fold_{fold}', 86 | 'attribute_checkpoint.pkl') 87 | if not os.path.exists(path_to_attr): 88 | print(f'No checkpoint found for fold {fold}. Training not started?') 89 | not_finished_folds.append(fold) 90 | continue 91 | 92 | attr = load_pkl(path_to_attr) 93 | 94 | if attr['epochs_done'] < attr['num_epochs']: 95 | not_finished_folds.append(fold) 96 | 97 | if len(not_finished_folds) > 0: 98 | print(f'Waiting for folds {not_finished_folds}') 99 | sleep(60) 100 | else: 101 | wait = False 102 | 103 | # uncomment to evaluate ensemble e.g. of cross-validation models 104 | # ens = SegmentationEnsembleV2(val_fold=list(range(model_params['data']['n_folds'])), 105 | # model_name=model_name, 106 | # data_name=data_name, 107 | # preprocessed_name=preprocessed_name_lowres) 108 | # ens.eval_raw_dataset('MY_TEST_DATA') 109 | 110 | # %% preprocess fullres data if it hasn't been done yet 111 | if not os.path.exists(os.path.join(OV_PREPROCESSED, data_name, preprocessed_name_fullres)): 112 | 113 | prev_stage = {'data_name': data_name, 114 | 'preprocessed_name': preprocessed_name_lowres, 115 | 'model_name': model_name} 116 | 117 | # ADD SOME PREPROCESSING PARAMETERS HERE 118 | prep = SegmentationPreprocessingV2(apply_resizing=True, 119 | apply_pooling=False, 120 | apply_windowing=True, 121 | prev_stage_for_input=prev_stage) 122 | prep.plan_preprocessing_raw_data(data_name) 123 | 124 | prep.preprocess_raw_data(raw_data=data_name, 125 | preprocessed_name=preprocessed_name_fullres) 126 | 127 | # %% now get hyper-parameters for full resolution and train 128 | patch_size = [32, 256, 256] 129 | n_2d_convs = 3 130 | use_prg_trn = True 131 | n_fg_classes = 1 132 | use_fp32 = False 133 | out_shape = [[24, 128, 128], [24, 192, 192], [24, 256, 256], [32, 256, 256]] 134 | model_params = get_model_params_3d_UNet(patch_size=patch_size, 135 | n_2d_convs=n_2d_convs, 136 | use_prg_trn=use_prg_trn, 137 | n_fg_classes=n_fg_classes, 138 | fp32=use_fp32, 139 | out_shape=out_shape) 140 | 141 | # for the cascade we input the masks of the previous stage 142 | model_params['network']['in_channels'] = 1 + n_fg_classes 143 | model_params['data']['folders'] = ['images', 'labels', 'prev_preds'] 144 | model_params['data']['keys'] = ['image', 'label', 'prev_pred'] 145 | 146 | # CHANGE YOUR HYPER-PARAMETERS FOR LOWRES STAGE HERE! 147 | model = SegmentationModelV2(val_fold=val_fold, 148 | data_name=data_name, 149 | model_name=model_name, 150 | preprocessed_name=preprocessed_name_fullres, 151 | model_parameters=model_params) 152 | 153 | 154 | # execute the trainig, simple as that! 155 | # It will check for previous checkpoints and load them 156 | model.training.train() 157 | 158 | if val_fold < model_params['data']['n_folds']: 159 | model.eval_validation_set() 160 | 161 | # uncomment to evaluate ensemble e.g. 
of cross-validation models 162 | # ens = SegmentationEnsembleV2(val_fold=list(range(model_params['data']['n_folds'])), 163 | # model_name=model_name, 164 | # data_name=data_name, 165 | # preprocessed_name=preprocessed_name_fullres) 166 | # typically I train all folds on different GPUs in parallel, this let's you wait 167 | # until all trainings are done 168 | # ens.wait_until_all_folds_complete() 169 | # evaluate ensemble on test data 170 | # ens.eval_raw_dataset('MY_TEST_DATA') -------------------------------------------------------------------------------- /src/ovseg/augmentation/MaskAugmentation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from ovseg.utils.torch_np_utils import check_type, stack 4 | try: 5 | from scipy.ndimage import morphology 6 | from skimage.measure import label 7 | except ImportError: 8 | print('Caught Import Error while importing some function from scipy or skimage. ' 9 | 'Please use a newer version of gcc.') 10 | 11 | 12 | TORCH_WARNING_PRINTED = False 13 | 14 | 15 | class MaskAugmentation(object): 16 | ''' 17 | MaksAugmentation(PARAMETERS!!) 18 | Performs the following augmentations: 19 | - morphological changes of segmentation masks 20 | - removing of small lesions 21 | 22 | Parameter: 23 | ---------------- 24 | p_xxx : 25 | - probability with which xxx is applied to the image 26 | xxx_mm : 27 | - min and max of the uniform distribution which is used to draw the 28 | parameters for xxx 29 | vol_percentage_removal/vol_threshold_removal: 30 | if vol_threshold_removal and spacing is given the leions removal 31 | threshold is computed in real world units, else the threshold is 32 | computed as the percentage of the patch size 33 | ''' 34 | 35 | def __init__(self, spacing=None, p_morph=0.4, radius_mm=[1, 8], p_removal=0.2, 36 | vol_percentage_removal=0.15, vol_threshold_removal=None, 37 | threeD_morph_ops=False, aug_channels=[1]): 38 | 39 | # morphological operations 40 | self.p_morph = p_morph 41 | self.radius_mm = radius_mm 42 | self.threeD_morph_ops = threeD_morph_ops 43 | 44 | # removal of small components 45 | self.p_removal = p_removal 46 | self.vol_threshold_removal = vol_threshold_removal 47 | self.vol_percentage_removal = vol_percentage_removal 48 | self.spacing = spacing 49 | 50 | # determins which channels are being augmented 51 | self.aug_channels = aug_channels 52 | 53 | self.morph_operations = [morphology.binary_closing, 54 | morphology.binary_dilation, 55 | morphology.binary_opening, 56 | morphology.binary_erosion] 57 | 58 | if spacing is not None: 59 | self.spacing = np.array(spacing) 60 | 61 | def _morphological_augmentation(self, img): 62 | 63 | # should be 2 or 3 64 | img_dim = len(img.shape) 65 | assert img_dim in [2, 3] 66 | 67 | if img_dim == 3: 68 | spacing = self.spacing if self.spacing is not None else \ 69 | np.mean(img.shape) / np.array(img.shape) 70 | if img.shape[0] * 2 < np.min(img.shape[1:]): 71 | spacing = spacing[1:] 72 | else: 73 | spacing = self.spacing[1:] if self.spacing is not None else np.array([1, 1]) 74 | 75 | classes = list(range(1, int(img.max())+1)) 76 | # turn integer in one hot encoding boolean 77 | img_one_hot = np.stack([img == c for c in classes]) 78 | 79 | # perform the operations on a random order of the classes 80 | np.random.shuffle(classes) 81 | 82 | # radius in mm, e.g. 
real world units 83 | r_mm = np.random.uniform(self.radius_mm[0], self.radius_mm[1]) 84 | # radius in amount of pixel 85 | r_pixel = (r_mm / spacing).astype(int) 86 | 87 | # zero centered axes in mm 88 | axes = [np.linspace(-1 * sp * rp, sp * rp, 2 * rp + 1) for sp, rp in zip(spacing, r_pixel)] 89 | grid = np.stack(np.meshgrid(*axes, indexing='ij')) 90 | 91 | # the structure is a L2 ball with radius r_mm 92 | structure = np.sum(grid**2, 0) < r_mm**2 93 | if len(spacing) == 2 and img_dim == 3: 94 | structure = structure[np.newaxis] 95 | 96 | # binary operation 97 | operation = np.random.choice(self.morph_operations) 98 | 99 | for class_idx in classes: 100 | # change only this one class 101 | # we have the -1 since the background is not in the one hot vector 102 | class_aug = operation(img_one_hot[class_idx - 1], structure) 103 | img_one_hot[class_idx - 1] = class_aug 104 | # and for all other classes we remove the fg in this region 105 | # in case we get intersections from 106 | for other_class_idx in classes: 107 | if other_class_idx != class_idx: 108 | img_one_hot[other_class_idx - 1][class_aug] = False 109 | 110 | # now we add the background channel again 111 | img_one_hot = np.concatenate([np.zeros((1, *img.shape)), img_one_hot]) 112 | 113 | # from one hot back to interget encoding. 114 | return np.argmax(img_one_hot, 0) 115 | 116 | def _removal_augmentation(self, img): 117 | 118 | img_dim = len(img.shape) 119 | mask = np.ones_like(img) 120 | 121 | components = label(img > 0) 122 | n_components = components.max() 123 | 124 | if self.vol_threshold_removal is not None and self.spacing is not None: 125 | vol_threshold = self.vol_threshold_removal \ 126 | / np.prod(self.spacing[:img_dim]) 127 | else: 128 | vol_threshold = self.vol_percentage_removal * np.prod(img.shape) 129 | 130 | for c in range(1, n_components + 1): 131 | comp = components == c 132 | if np.sum(comp) < vol_threshold: 133 | mask[comp] = 0 134 | 135 | return img * mask 136 | 137 | def augment_image(self, img): 138 | ''' 139 | augment_img(img) 140 | (nx, ny) 141 | ''' 142 | global TORCH_WARNING_PRINTED 143 | 144 | is_np, _ = check_type(img) 145 | if not is_np: 146 | img = img.cpu().numpy() 147 | if not TORCH_WARNING_PRINTED: 148 | print('Warning: Maks augmentations can only be done in ' 149 | 'numpy. Still got a torch tensor as input. Transferring ' 150 | ' it to the CPU, this kills gradients and might be ' 151 | 'slow.\n') 152 | 153 | if img.max() == 0: 154 | # no foreground nothing to do! 155 | return img 156 | # first collect what we want to do 157 | self.do_morph = np.random.rand() < self.p_morph 158 | self.do_removal = np.random.rand() < self.p_removal 159 | 160 | # Let's-a go! 
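# (apply the augmentations that were drawn above: morphological changes first, then removal of small components)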
161 | if self.do_morph: 162 | img = self._morphological_augmentation(img) 163 | if self.do_removal: 164 | img = self._removal_augmentation(img) 165 | 166 | if not is_np: 167 | img = torch.from_numpy(img).cuda() 168 | 169 | return img 170 | 171 | def augment_sample(self, sample): 172 | ''' 173 | augment_sample(sample) 174 | augments only the first image of the sample as we assume single channel 175 | images like CT 176 | ''' 177 | for c in self.aug_channels: 178 | sample[c] = self.augment_image(sample[c]) 179 | return sample 180 | 181 | def augment_batch(self, batch): 182 | ''' 183 | augment_batch(batch) 184 | augments every sample of the batch, in each sample only the image in 185 | the first channel will be augmented as we assume single channel images 186 | like CT 187 | ''' 188 | return stack([self.augment_sample(batch[i]) 189 | for i in range(len(batch))]) 190 | 191 | def __call__(self, batch): 192 | return self.augment_batch(batch) 193 | 194 | def augment_volume(self, volume, is_inverse: bool = False): 195 | if not is_inverse: 196 | if len(volume.shape) == 3: 197 | volume = self.augment_image(volume) 198 | else: 199 | volume = self.augment_sample(volume) 200 | return volume 201 | 202 | def update_prg_trn(self, param_dict, h, indx=None): 203 | 204 | for attr in ['p_morph', 'p_removal', 'radius_mm', 'vol_percentage_removal']: 205 | if attr in param_dict: 206 | self.__setattr__(attr, (1 - h) * param_dict[attr][0] + h * param_dict[attr][1]) 207 | -------------------------------------------------------------------------------- /src/ovseg/training/loss_functions_combined.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from ovseg.training.loss_functions import cross_entropy, dice_loss 6 | from ovseg.training.loss_functions import __dict__ as loss_functions_dict 7 | 8 | def to_one_hot_encoding(yb, n_ch): 9 | 10 | yb = yb.long() 11 | yb_oh = torch.cat([(yb == c) for c in range(n_ch)], 1).float() 12 | return yb_oh 13 | 14 | 15 | class CE_dice_loss(nn.Module): 16 | # weighted sum of the two losses 17 | # this functions is just here for historic reason 18 | def __init__(self, eps=1e-5, ce_weight=1.0, dice_weight=1.0): 19 | super().__init__() 20 | self.ce_loss = cross_entropy() 21 | self.dice_loss = dice_loss(eps) 22 | self.dice_weight = dice_weight 23 | self.ce_weight = ce_weight 24 | 25 | def forward(self, logs, yb, mask=None): 26 | if yb.shape[1] == 1: 27 | # turn yb to one hot encoding 28 | yb = to_one_hot_encoding(yb, logs.shape[1]) 29 | ce = self.ce_loss(logs, yb, mask) * self.ce_weight 30 | dice = self.dice_loss(logs, yb, mask) * self.dice_weight 31 | loss = ce + dice 32 | return loss 33 | 34 | class weighted_combined_loss(nn.Module): 35 | # arbritray loss functions weighted and summed up 36 | def __init__(self, loss_names, loss_weights=None, loss_kwargs=None): 37 | super().__init__() 38 | self.loss_names = loss_names 39 | # if no weights are given the losses are just summed without any weight 40 | self.loss_weights = loss_weights if loss_weights is not None else [1] * len(self.loss_names) 41 | # if no kwargs are given we just use blanks 42 | 43 | if loss_kwargs is None: 44 | 45 | self.loss_kwargs = [{}] * len(self.loss_names) 46 | 47 | elif len(loss_names) == 1 and isinstance(loss_kwargs, dict): 48 | self.loss_kwargs = [loss_kwargs] 49 | else: 50 | self.loss_kwargs = loss_kwargs 51 | 52 | 53 | assert len(loss_names) > 0, 'no names for losses given.' 
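# (illustrative example: with the names imported from loss_functions above, a call like
#  weighted_combined_loss(loss_names=['cross_entropy', 'dice_loss'], loss_weights=[1.0, 1.0])
#  sums both losses with equal weight, essentially mirroring CE_dice_loss with its default settings)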
54 | assert len(loss_names) == len(self.loss_weights), 'Got different numbers of loss names and weights' 55 | assert len(loss_names) == len(self.loss_kwargs), 'Got different numbers of loss names and kwargs' 56 | 57 | 58 | self.losses = [] 59 | for name, kwargs in zip(self.loss_names, self.loss_kwargs): 60 | if name not in loss_functions_dict: 61 | losses_found = [key for key in loss_functions_dict 62 | if not key.startswith('_') and key not in ['torch', 'nn', 'np']] 63 | raise ValueError('Name {} for a loss function was not found in loss_functions.py. ' 64 | 'Got the modules {}'.format(name, losses_found)) 65 | self.losses.append(loss_functions_dict[name](**kwargs)) 66 | 67 | self.losses = nn.ModuleList(self.losses) 68 | 69 | def forward(self, logs, yb, mask=None): 70 | if yb.shape[1] == 1: 71 | # turn yb to one hot encoding 72 | yb = to_one_hot_encoding(yb, logs.shape[1]) 73 | 74 | l = self.losses[0](logs, yb, mask) * self.loss_weights[0] 75 | for loss, weight in zip(self.losses[1:], self.loss_weights[1:]): 76 | l += loss(logs, yb, mask) * weight 77 | 78 | return l 79 | 80 | def downsample_yb(logs_list, yb): 81 | 82 | # get pytorch 2d or 3d adaptive max pooling function 83 | f = F.adaptive_max_pool3d if len(yb.shape) == 5 else F.adaptive_max_pool2d 84 | 85 | # target downsampled to same size as logits 86 | return [f(yb, logs.shape[2:]) for logs in logs_list] 87 | 88 | 89 | def downsample_yb_old(logs_list, yb): 90 | # NOT IN USE ANYMORE 91 | # ugly implementation of maxpooling, replaced by the function 'downsample_yb' 92 | 93 | # this function downsamples the target (or masks) to the same shapes as the outputs 94 | # from the different resolutions of the decoder path of the U-Net. 95 | yb_list = [yb] 96 | is_3d = len(yb.shape) == 5 97 | for logs in logs_list[1:]: 98 | if is_3d: 99 | # maybe downsample in first spatial direction 100 | if logs.shape[2] == yb.shape[2] // 2: 101 | yb = torch.maximum(yb[:, :, ::2], yb[:, :, 1::2]) 102 | elif not logs.shape[2] == yb.shape[2]: 103 | raise ValueError('shapes of logs and labels aren\'t matching for ' 104 | 'downsampling. got {} and {}' 105 | .format(logs.shape, yb.shape)) 106 | # maybe downsample in second spatial direction 107 | if logs.shape[3] == yb.shape[3] // 2: 108 | yb = torch.maximum(yb[:, :, :, ::2], yb[:, :, :, 1::2]) 109 | elif not logs.shape[3] == yb.shape[3]: 110 | raise ValueError('shapes of logs and labels aren\'t matching for ' 111 | 'downsampling. got {} and {}' 112 | .format(logs.shape, yb.shape)) 113 | # maybe downsample in third direction 114 | if logs.shape[4] == yb.shape[4] // 2: 115 | yb = torch.maximum(yb[:, :, :, :, ::2], yb[:, :, :, :, 1::2]) 116 | elif not logs.shape[4] == yb.shape[4]: 117 | raise ValueError('shapes of logs and labels aren\'t matching for ' 118 | 'downsampling. got {} and {}' 119 | .format(logs.shape, yb.shape)) 120 | else: 121 | # maybe downsample in first spatial direction 122 | if logs.shape[2] == yb.shape[2] // 2: 123 | yb = yb[:, :, ::2] + yb[:, :, 1::2] 124 | elif not logs.shape[2] == yb.shape[2]: 125 | raise ValueError('shapes of logs and labels aren\'t matching for ' 126 | 'downsampling. got {} and {}' 127 | .format(logs.shape, yb.shape)) 128 | # maybe downsample in second spatial direction 129 | if logs.shape[3] == yb.shape[3] // 2: 130 | yb = yb[:, :, :, ::2] + yb[:, :, :, 1::2] 131 | elif not logs.shape[3] == yb.shape[3]: 132 | raise ValueError('shapes of logs and labels aren\'t matching for ' 133 | 'downsampling. 
got {} and {}' 134 | .format(logs.shape, yb.shape)) 135 | # now append 136 | yb_list.append(yb) 137 | return yb_list 138 | 139 | 140 | class CE_dice_pyramid_loss(nn.Module): 141 | 142 | def __init__(self, eps=1e-5, ce_weight=1.0, dice_weight=1.0, 143 | pyramid_weight=0.5): 144 | super().__init__() 145 | self.ce_dice_loss = CE_dice_loss(eps, ce_weight, dice_weight) 146 | self.pyramid_weight = pyramid_weight 147 | 148 | def forward(self, logs_list, yb, mask=None): 149 | if yb.shape[1] == 1: 150 | yb = to_one_hot_encoding(yb, logs_list[0].shape[1]) 151 | # compute the weights to be powers of pyramid_weight 152 | scale_weights = self.pyramid_weight ** np.arange(len(logs_list)) 153 | # let them sum to one 154 | scale_weights = scale_weights / np.sum(scale_weights) 155 | # turn labels into one hot encoding and downsample to same resolutions 156 | # as the logits 157 | yb_list = downsample_yb(logs_list, yb) 158 | if torch.is_tensor(mask): 159 | mask_list = downsample_yb(logs_list, mask) 160 | else: 161 | mask_list = [None] * len(yb_list) 162 | 163 | # now let's compute the loss for each scale 164 | loss = 0 165 | for logs, yb, m, w in zip(logs_list, yb_list, mask_list, scale_weights): 166 | loss += w * self.ce_dice_loss(logs, yb, m) 167 | 168 | return loss 169 | 170 | class weighted_combined_pyramid_loss(nn.Module): 171 | 172 | def __init__(self, loss_names, loss_weights=None, loss_kwargs=None, pyramid_weight=0.5): 173 | super().__init__() 174 | self.loss = weighted_combined_loss(loss_names, loss_weights, loss_kwargs) 175 | self.pyramid_weight = pyramid_weight 176 | 177 | 178 | def forward(self, logs_list, yb, mask=None): 179 | if yb.shape[1] == 1: 180 | yb = to_one_hot_encoding(yb, logs_list[0].shape[1]) 181 | # compute the weights to be powers of pyramid_weight 182 | scale_weights = self.pyramid_weight ** np.arange(len(logs_list)) 183 | # let them sum to one 184 | scale_weights = scale_weights / np.sum(scale_weights) 185 | # turn labels into one hot encoding and downsample to same resolutions 186 | # as the logits 187 | yb_list = downsample_yb(logs_list, yb) 188 | if torch.is_tensor(mask): 189 | mask_list = downsample_yb(logs_list, mask) 190 | else: 191 | mask_list = [None] * len(yb_list) 192 | 193 | # now let's compute the loss for each scale 194 | loss = 0 195 | for logs, yb, m, w in zip(logs_list, yb_list, mask_list, scale_weights): 196 | loss += w * self.loss(logs, yb, m) 197 | return loss -------------------------------------------------------------------------------- /src/ovseg/model/ClaraWrappers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | import os 5 | from ovseg.utils.io import load_pkl 6 | from ovseg.networks.resUNet import UNetResEncoder 7 | from ovseg.prediction.SlidingWindowPrediction import SlidingWindowPrediction 8 | 9 | FLIP_AND_ROTATE_IMAGE = True 10 | 11 | 12 | def ClaraWrapperOvarian(data_tpl, 13 | models, 14 | path_to_clara_models='/aiaa_workspace/aiaa-1/lib/ovseg_zxy/clara_models'): 15 | ''' 16 | General wrapper for HGSOC segmentation. 17 | Can run the segmentation for different and multiple locations 18 | 19 | Parameters 20 | ---------- 21 | data_tpl : dict 22 | contains 'image', 3D or 4D np array z first, and 'spacing' of len 3 23 | models : str or list of strings 24 | name of the folder in which the model parameters and weights are stored 25 | path_to_clara_models : str, optional 26 | location of the models on the server. 
The default is '/aiaa_workspace/aiaa-1/lib/ovseg_zxy/clara_models'. 27 | 28 | Returns 29 | ------- 30 | TYPE 31 | DESCRIPTION. 32 | 33 | ''' 34 | 35 | if isinstance(models, str): 36 | 37 | models = [models] 38 | 39 | pred = evaluate_segmentation_ensemble(data_tpl, 40 | models[0], 41 | path_to_clara_models) 42 | 43 | torch.cuda.empty_cache() 44 | 45 | for model in models[1:]: 46 | 47 | arr = evaluate_segmentation_ensemble(data_tpl, 48 | model, 49 | path_to_clara_models) 50 | 51 | # fill in new prediction and overwrite previous one 52 | pred = pred * (arr == 0).type(torch.float) + arr 53 | torch.cuda.empty_cache() 54 | 55 | # back to numpy 56 | pred = pred.cpu().numpy() 57 | 58 | # put z axis back 59 | # numpy does this faster than torch 60 | return np.moveaxis(pred, 0, -1) 61 | 62 | # %% 63 | def preprocess_dynamic_z_spacing(data_tpl, 64 | prep_params): 65 | ''' 66 | This function implements the dynamic z resizing used during inference. 67 | When given an image with a low slice thickness, it is resized to multiple 68 | images with non overlapping slices of target spacing to provide 69 | high resolution predictions. 70 | 71 | Parameters 72 | ---------- 73 | data_tpl : dict 74 | Need to contain 'image', 3D or 4D np.ndarray and 'spacing' of len 3. 75 | prep_params : dict 76 | preprocessing parameters as used for the network training 77 | 78 | Returns 79 | ------- 80 | im_list : list 81 | contains 5D torch cuda tensors of all images 82 | ''' 83 | 84 | print('*** PREPROCESSING ***') 85 | im = data_tpl['image'] 86 | if len(im.shape) == 3: 87 | im = im[np.newaxis] 88 | 89 | rif = data_tpl['raw_image_file'] 90 | is_nifti = rif.endswith('.nii.gz') or rif.endswith('.nii') 91 | 92 | if FLIP_AND_ROTATE_IMAGE and is_nifti: 93 | # this corrects for differences in how dcms are read in ov_seg 94 | # and how clara creates nifti files from dcms 95 | im = np.rot90(im[:, ::-1, :, ::-1], -1, (1,2)) 96 | 97 | # now the image should be 5d 98 | im = torch.from_numpy(im.copy()).type(torch.float).unsqueeze(0).cuda() 99 | 100 | # %% resizing, the funny part 101 | z_sp = data_tpl['spacing'][0] 102 | 103 | target_spacing = prep_params['target_spacing'] 104 | target_z_spacing = target_spacing[0] 105 | 106 | # %% dynamic z spacing 107 | n_ims = int(np.max([np.floor(target_z_spacing/z_sp), 1])) 108 | print(f'Creating {n_ims} images with z spacing {target_z_spacing}') 109 | dynamic_z_spacing = target_z_spacing / n_ims 110 | 111 | scale_factor = [data_tpl['spacing'][0] / dynamic_z_spacing, 112 | data_tpl['spacing'][1] / target_spacing[1], 113 | data_tpl['spacing'][2] / target_spacing[2]] 114 | 115 | # resizing 116 | im = F.interpolate(im, 117 | scale_factor=scale_factor, 118 | mode='trilinear') 119 | 120 | # apply windowing 121 | if prep_params['apply_windowing']: 122 | im = im.clamp(*prep_params['window']) 123 | 124 | # now rescaling 125 | scaling = prep_params['scaling'] 126 | im = (im - scaling[1]) / scaling[0] 127 | 128 | # split images 129 | im_list = [im[:, :, i::n_ims] for i in range(n_ims)] 130 | 131 | # %% finally pooling 132 | if prep_params['apply_pooling']: 133 | stride = prep_params['pooling_stride'] 134 | im_list = [F.avg_pool3d(im, kernel_size=stride, stride=stride) for im in im_list] 135 | 136 | # remove batch dimension 137 | im_list = [im[0] for im in im_list] 138 | 139 | return im_list 140 | 141 | # %% 142 | def evaluate_segmentation_ensemble(data_tpl, 143 | model, 144 | path_to_clara_models='/aiaa_workspace/aiaa-1/lib/ovseg_zxy/clara_models'): 145 | 146 | print(f'*** EVALUATING {model} ***') 147 | # At 
this path the model parameters and networks weights should be 148 | # stored 149 | path_to_model = os.path.join(path_to_clara_models, model) 150 | # Read model parameters 151 | path_to_model_params = os.path.join(path_to_model, 'model_parameters.pkl') 152 | model_params = load_pkl(path_to_model_params) 153 | 154 | im_list = preprocess_dynamic_z_spacing(data_tpl, 155 | model_params['preprocessing']) 156 | 157 | # dimensions of target tensor 158 | nz = np.sum([im.shape[1] for im in im_list]) 159 | nx, ny = im_list[0].shape[2], im_list[0].shape[3] 160 | 161 | print('*** RUNNING THE MODEL ***') 162 | print('the fun starts...') 163 | 164 | # this needs updating to allow general architecture 165 | if not model_params['architecture'] == 'unetresencoder': 166 | raise NotImplementedError('Only implemented for ResEncoder so far...') 167 | 168 | network = UNetResEncoder(**model_params['network']).cuda() 169 | 170 | n_ch = model_params['network']['out_channels'] 171 | 172 | # %% Sliding window prediction time! 173 | prediction = SlidingWindowPrediction(network=network, 174 | **model_params['prediction']) 175 | 176 | # collect all weights we got from the ensemble 177 | weight_files = [os.path.join(path_to_model, file) for file in os.listdir(path_to_model) 178 | if file.startswith('network_weights')] 179 | 180 | # list of predictions from each weight 181 | pred_list = [] 182 | # iterate over all weights used in the ensemble 183 | for j, weight_file in enumerate(weight_files): 184 | print(f'Evaluate network {j+1} out of {len(weight_files)}') 185 | # load weights 186 | prediction.network.load_state_dict(torch.load(weight_file, 187 | map_location=torch.device('cuda'))) 188 | 189 | # full tensor of softmax outputs 190 | # pred = torch.zeros((n_ch, nz, nx, ny), device='cuda', dtype=torch.float) 191 | # we're using numpy arrays here to prevent OOM errors 192 | pred = np.zeros((n_ch, nz, nx, ny), dtype=np.float32) 193 | 194 | # for each image in the list, evaluate sliding window and fill in 195 | for i, im in enumerate(im_list): 196 | pred[:, i::len(im_list)] = prediction(im).detach().cpu().numpy() 197 | 198 | pred_list.append(pred) 199 | 200 | # this solution is ugly, but otherwise there might be OOM errors 201 | pred = np.stack(pred_list).mean(0) 202 | pred = torch.from_numpy(pred).cuda() 203 | torch.cuda.empty_cache() 204 | # %% we do the postprocessing manually here to save some moving to the 205 | # GPU back and fourth 206 | print('*** POSTPROCESSING ***') 207 | if 'postprocessing' in model_params: 208 | 209 | print('WARNING: Only resizing and argmax is performed here') 210 | 211 | # first trilinear resizing 212 | size = [int(s) for s in data_tpl['image'].shape[-3:]] 213 | 214 | try: 215 | pred = F.interpolate(pred.unsqueeze(0), 216 | size=size, 217 | mode='trilinear')[0] 218 | except RuntimeError: 219 | print('Went out of memory. Resizing again on the CPU, but this can be slow...') 220 | 221 | pred = F.interpolate(pred.unsqueeze(0).cpu(), 222 | size=size, 223 | mode='trilinear')[0] 224 | 225 | 226 | # now applying argmax 227 | pred = torch.argmax(pred, 0).type(torch.float) 228 | 229 | # now convert labels back to their orig. 
classes 230 | pred_lb = torch.zeros_like(pred) 231 | for i, lb in enumerate(model_params['preprocessing']['lb_classes']): 232 | # this should be the fastest way on the GPU to get the job done 233 | pred_lb = pred_lb + lb * (pred == i+1).type(torch.float) 234 | 235 | return pred_lb -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | A deep learning based library for segmentation of high grade serous ovarian cancer on CT images. 2 | The library contains the code and the final models created during my [PhD Thesis](https://doi.org/10.17863/CAM.87940), some of which have been described in our [paper](https://eurradiolexp.springeropen.com/articles/10.1186/s41747-023-00388-z). 3 | 4 | While the code was mostly used for ovarian cancer segmentation, the library is general purpose and can be used to train other kinds of segmentation models. In case you have questions you can send me a mail: thomasbuddenkotte@googlemail.com 5 | 6 | The code design is in some ways similar to the nnU-Net library (https://github.com/MIC-DKFZ/nnUNet). Many thanks to the authors for sharing their code and letting me learn from it. 7 | 8 | # Installation 9 | 10 | Before you install the library make sure that your machine has a CUDA compatible GPU. 11 | To install ovseg simply clone the repo and install via pip: 12 | 13 | ``` 14 | git clone https://github.com/ThomasBudd/ovseg 15 | cd ovseg 16 | pip install . 17 | ``` 18 | 19 | # Inference 20 | 21 | We've recently updated the library to make its inference usage more convenient. To run inference you need to provide the data in the nifti format (as either .nii or .nii.gz files). We recommend checking your nifti images with the [ITK-SNAP](http://www.itksnap.org/pmwiki/pmwiki.php) viewer to see if the images have the correct orientation. The results will contain errors in case the axial, sagittal or coronal views are swapped. The correct orientation looks like this: 22 | 23 | ![plot](./nifti_in_itk_snap_example.png) 24 | 25 | In case your image data is stored as dicom files, you can use a tool like [this](https://github.com/icometrix/dicom2nifti) one for conversion and check the results with ITK-SNAP as described above. 26 | 27 | Now you can call the inference from the terminal/command line by providing the path to the folder in which all nifti images are stored: 28 | 29 | > ovseg_inference path_to_data 30 | 31 | If you prefer to write python code directly you can also do this with a small script: 32 | 33 | ``` 34 | from ovseg.run.run_inference import run_inference 35 | run_inference("\PATH\TO\DATA") 36 | ``` 37 | 38 | In case you want to run the inference only on a single image you can also provide the path to that image instead. By default the code will run the inference for the model segmenting the pelvic/ovarian and omental disease. If you want to use other models, specify them in the command above with the --models specifier on the command line or the "models=[...]" argument of the python function. Available models are 39 | 40 | - pod_om: model for the main disease sites in the pelvis/ovaries and the omentum. The two sites are encoded as 9 and 1 in the predictions. 41 | - abdominal_lesions: model for various lesions between the pelvis and diaphragm. The model considers lesions in the omentum (1), right upper quadrant (2), left upper quadrant (3), mesenterium (5), left paracolic gutter (6) and right paracolic gutter (7). 42 | - lymph_nodes: segments disease in the lymph nodes, namely infrarenal lymph nodes (13), suprarenal lymph nodes (14), supradiaphragmatic lymph nodes (15) and inguinal 43 | lymph nodes (17). 44 | 45 | You can also combine the models to run two or all three at the same time. Lastly, you have the option to run the inference in fast mode using the --fast option of the command line tool or the argument "fast=True". This drastically decreases the computation time at the cost of accuracy. This option was mainly built to try and run the code on a local machine that does not have a GPU. On my laptop this took ~13 mins per image. The results will be stored as nifti files; you can use ITK-SNAP to take a look at them. 46 | 
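For example, a call that runs both the pod_om and the lymph_nodes model in fast mode might look like this (a sketch using the "models" and "fast" arguments described above; the path is only a placeholder):

```
from ovseg.run.run_inference import run_inference

# combine two models and trade some accuracy for speed with the fast mode
run_inference("\PATH\TO\DATA", models=["pod_om", "lymph_nodes"], fast=True)
```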
47 | The code will download the network weights automatically and store them in the cloned git repository. In case you want to store them somewhere else you have to set up an environment variable called OV_DATA_BASE and specify the desired location. 48 | 49 | # Training 50 | 51 | Before you can run training we recommend setting up an environment variable called OV_DATA_BASE. 52 | All predictions, (pre-)trained models, raw data, etc. will be stored in this location. By default the code will set it to the folder of the cloned git repository of this code. 53 | If you're planning to run training on a multi-server system it is advised to set up the OV_DATA_BASE at a central location all servers can access (see run training). 54 | 
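For example, on Linux the variable could be set like this before preprocessing and training are started (the path is only a placeholder; on Windows use the system environment variable settings instead):

> export OV_DATA_BASE=/path/to/ov_data_base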
55 | # Data management 56 | 57 | To run inference or training you first need to store the datasets in a particular way to make them accessible for the library. All datasets should be stored at $OV_DATA_BASE/raw_data and should be given a unique name. Currently the library supports datasets in which images (and segmentations) are stored as nifti or dicom files. In the current version only single channel images were tested. 58 | 59 | If you're using **nifti files** create a folder called 'images' in $OV_DATA_BASE/raw_data/DATASET_NAME and simply put all images in there. In the case of training, create a second folder called 'labels' with the corresponding segmentations. The segmentation files should have the same names as the image files or follow the Medical Decathlon naming convention (image: case_xyz_0000.nii.gz, seg: case_xyz.nii.gz). For example 60 | 61 | OV_DATA_BASE/raw_data/DATASET_NAME/ 62 | ├── images 63 | │ ├── case_001_0000.nii.gz 64 | │ ├── case_002_0000.nii.gz 65 | │ ├── case_003_0000.nii.gz 66 | │ ├── ... 67 | ├── labels 68 | │ ├── case_001.nii.gz 69 | │ ├── case_002.nii.gz 70 | │ ├── case_003.nii.gz 71 | │ ├── ... 72 | 73 | 74 | For **dicom images** any type of folder structure is allowed. Make sure that only axial reconstructions are contained in your dataset; the code won't remove other types of reconstructions such as topograms or sagittal slices by itself. The code also assumes that all dicoms found in one folder belong to the same reconstruction, so make sure that each reconstruction is contained in a separate folder. If you're performing training, include the segmentations as dicomrt files. Each folder with reconstruction dicoms should have exactly one additional dicomrt file with the corresponding segmentation. Missing segmentations are interpreted as empty segmentation masks (only background). 75 | 76 | Examples are 77 | 78 | OV_DATA_BASE/raw_data/DATASET_NAME/ 79 | ├── patient1 80 | │ ├── segmentation.dcm 81 | │ ├── slice1.dcm 82 | │ ├── slice2.dcm 83 | │ ├── slice3.dcm 84 | │ ├── ... 85 | ├── patient2 86 | │ ├── ... 87 | ├── patient3 88 | │ ├── ... 89 | ├── ... 90 | 91 | 92 | Or 93 | 94 | OV_DATA_BASE/raw_data/DATASET_NAME/ 95 | ├── patient1 96 | │ ├── timepoint1 97 | │ │ ├── segmentation.dcm 98 | │ │ ├── slice1.dcm 99 | │ │ ├── slice2.dcm 100 | │ │ ├── slice3.dcm 101 | │ │ ├── ... 102 | │ ├── timepoint2 103 | │ │ ├── ... 104 | ├── patient2 105 | │ ├── ... 106 | ├── patient3 107 | │ ├── ... 108 | ├── ... 109 | 110 | Or a mixture of the above. Note that it is not necessary to rename your dcm files to "segmentation.dcm" or "sliceX.dcm", the library will recognise them automatically. 111 | 112 | # Rerun ovarian cancer segmentation training 113 | 114 | Repeating the ovarian cancer segmentation training can be done via the command line without changing any python code. Before the training can be started the raw data has to be preprocessed and stored. If you're running the training on a multi-server system it is advised to place the OV_DATA_BASE in a central storage location. However, this is not a good place for preprocessed data. The preprocessed data should be kept on a fast local disk to ensure that loading times do not become a bottleneck of the training. In this case create a second environment variable called OV_PREPROCESSED that points to such a fast local disk. If this variable is not created, the preprocessed data will simply be stored at $OV_DATA_BASE/preprocessed. 115 | 116 | To perform preprocessing call the script 'preprocess_ovarian_data.py' with the names of all datasets you want to use for training as arguments. For example 117 | > python preprocess_ovarian_data.py DATANAME1 DATANAME2 118 | 119 | Next the training can be started by running 'run_training.py'. The first input needed is the number of the validation fold. By default the library will split the preprocessed data using a fivefold cross-validation scheme. For an input 0,1,...,4 the training will be launched using 80% of the available data for training and 20% for validation. For inputs 5,6,... the training will use 100% of the preprocessed data for training. The type of model trained is specified via the --model input. The models have the same naming as in inference (pod_om, abdominal_lesions, lymph_nodes). The training datasets used are specified via the --trn_data input. 120 | 121 | For example, to train the model for the two main disease sites on 100% of the data (no validation) using datasets called DATANAME1 and DATANAME2, run 122 | > python run_training.py 5 --model pod_om --trn_data DATANAME1 DATANAME2 123 | 124 | # Running training for new segmentation problems 125 | 126 | One advantage of ovseg is that it is very simple to run training and inference and to modify hyper-parameters for your own data. 127 | For this it is necessary to write your own preprocessing and training scripts such as the previously mentioned 'preprocess_ovarian_data.py' and 'run_training.py'. 128 | Check out the example_scripts folder for helpful code templates. The templates demonstrate how to train models e.g. for the kits21 dataset, including cascade models (full resolution models refining low resolution models) and deep supervision models (first segmenting the kidney, then searching for tumors only inside the kidney). 129 | 130 | Make sure to give a unique model_name to each model once you change hyper-parameters. The trained models can be found at OV_DATA_BASE/trained_models/data_name/preprocessed_name/model_name.
131 | This folder will contain information on the hyper-parameters used, network weights (for each fold trained), training checkpoints and result files on the model evaluation on validation or test data. 132 | Similarly, the predictions can be found at OV_DATA_BASE/predictions/data_name/preprocessed_name/model_name. In case your raw_data was given in DICOM format, the predictions will be stored as nifti and DICOM files. If the raw data was given in nifti format, the predictions will be stored only in nifti format. 133 | 134 | For more explanation on how the library works and details on the model hyper-parameters, please see the manual (ovseg_manual.pdf). Please note that a previous version of the library required a different more custom orientation of the nifti images. This has now changed and all nifti images should have the orientation described in the inference section. 135 | -------------------------------------------------------------------------------- /src/ovseg/prediction/SlidingWindowPrediction.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.nn import functional as F 4 | from scipy.ndimage.filters import gaussian_filter 5 | from ovseg.utils.torch_np_utils import check_type, maybe_add_channel_dim 6 | 7 | 8 | class SlidingWindowPrediction(object): 9 | 10 | def __init__(self, network, patch_size, batch_size=1, overlap=0.5, fp32=False, 11 | patch_weight_type='gaussian', sigma_gaussian_weight=1/8, linear_min=0.1, 12 | mode='flip'): 13 | 14 | self.dev = 'cuda' if torch.cuda.is_available() else 'cpu' 15 | self.network = network.to(self.dev) 16 | self.batch_size = batch_size 17 | self.overlap = overlap 18 | self.fp32 = fp32 19 | self.patch_weight_type = patch_weight_type 20 | self.sigma_gaussian_weight = sigma_gaussian_weight 21 | self.linear_min = linear_min 22 | self.mode = mode 23 | 24 | assert self.patch_weight_type.lower() in ['constant', 'gaussian', 'linear'] 25 | assert self.mode.lower() in ['simple', 'flip'] 26 | 27 | self._set_patch_size_and_weight(patch_size) 28 | 29 | 30 | def _set_patch_size_and_weight(self, patch_size): 31 | # check and build up the patch weight 32 | # we can use a gaussian weighting since the predictions on the edge of the patch are less 33 | # reliable than the ones in the middle 34 | self.patch_size = np.array(patch_size).astype(int) 35 | if self.patch_weight_type.lower() == 'constant': 36 | self.patch_weight = np.ones(self.patch_size) 37 | elif self.patch_weight_type.lower() == 'gaussian': 38 | # we distrust the edge voxel the same in each direction regardless of the 39 | # patch size in that dimension 40 | 41 | # thanks to Fabian Isensee! I took this from his code: 42 | # https://github.com/MIC-DKFZ/nnUNet/blob/14992342919e63e4916c038b6dc2b050e2c62e3c/nnunet/network_architecture/neural_network.py#L250 43 | tmp = np.zeros(self.patch_size) 44 | center_coords = [i // 2 for i in self.patch_size] 45 | sigmas = [i * self.sigma_gaussian_weight for i in self.patch_size] 46 | tmp[tuple(center_coords)] = 1 47 | self.patch_weight = gaussian_filter(tmp, sigmas, 0, mode='constant', cval=0) 48 | self.patch_weight = self.patch_weight / np.max(self.patch_weight) * 1 49 | self.patch_weight = self.patch_weight.astype(np.float32) 50 | 51 | # self.patch_weight cannot be 0, otherwise we may end up with nans! 
52 | self.patch_weight[self.patch_weight == 0] = np.min( 53 | self.patch_weight[self.patch_weight != 0]) 54 | 55 | elif self.patch_weight_type.lower() == 'linear': 56 | lin_slopes = [np.linspace(self.linear_min, 1, s//2) for s in self.patch_size] 57 | hats = [np.concatenate([lin_slope, lin_slope[::-1]]) for lin_slope in lin_slopes] 58 | hats = [np.expand_dims(hat, [j for j in range(len(self.patch_size)) if j != i]) 59 | for i, hat in enumerate(hats)] 60 | 61 | self.patch_weight = np.ones(self.patch_size) 62 | for hat in hats: 63 | self.patch_weight *= hat 64 | 65 | self.patch_weight = self.patch_weight[np.newaxis] 66 | self.patch_weight = torch.from_numpy(self.patch_weight).to(self.dev).type(torch.float) 67 | 68 | # add an axis to the patch size and set is_2d 69 | if len(self.patch_size) == 2: 70 | self.is_2d = True 71 | self.patch_size = np.concatenate([[1], self.patch_size]) 72 | elif len(self.patch_size) == 3: 73 | self.is_2d = False 74 | else: 75 | raise ValueError('patch_size must be of len 2 or 3 (for 2d and 3d networks).') 76 | 77 | 78 | 79 | def _get_xyz_list(self, shape, ROI=None): 80 | 81 | nz, nx, ny = shape 82 | 83 | if ROI is None: 84 | # not ROI is given take all coordinates 85 | ROI = torch.ones((1, nz, nx, ny)) > 0 86 | 87 | n_patches = np.ceil((np.array([nz, nx, ny]) - self.patch_size) / 88 | (self.overlap * self.patch_size)).astype(int) + 1 89 | 90 | # upper left corners of all patches 91 | if self.is_2d: 92 | z_list = np.arange(nz).astype(int).tolist() 93 | else: 94 | z_list = np.linspace(0, nz - self.patch_size[0], n_patches[0]).astype(int).tolist() 95 | x_list = np.linspace(0, nx - self.patch_size[1], n_patches[1]).astype(int).tolist() 96 | y_list = np.linspace(0, ny - self.patch_size[2], n_patches[2]).astype(int).tolist() 97 | 98 | zxy_list = [] 99 | for z in z_list: 100 | for x in x_list: 101 | for y in y_list: 102 | # we only predict the patch if the middle cube with half side length 103 | # intersects the ROI 104 | if self.is_2d: 105 | z1, z2 = z, z+1 106 | else: 107 | z1, z2 = z+self.patch_size[0]//4, z+self.patch_size[0]*3//4 108 | x1, x2 = x+self.patch_size[1]//4, x+self.patch_size[1]*3//4 109 | y1, y2 = y+self.patch_size[2]//4, y+self.patch_size[2]*3//4 110 | if ROI[0, z1:z2, x1:x2, y1:y2].any().item(): 111 | zxy_list.append((z, x, y)) 112 | 113 | return zxy_list 114 | 115 | def _sliding_window(self, volume, ROI=None): 116 | 117 | if not torch.is_tensor(volume): 118 | raise TypeError('Input must be torch tensor') 119 | if not len(volume.shape) == 4: 120 | raise ValueError('Volume must be a 4d tensor (incl channel axis)') 121 | 122 | # in case the volume is smaller than the patch size we pad it 123 | # and save the input size to crop again before returning 124 | shape_in = np.array(volume.shape) 125 | 126 | # %% possible padding of too small volumes 127 | pad = [0, self.patch_size[2] - shape_in[3], 0, self.patch_size[1] - shape_in[2], 128 | 0, self.patch_size[0] - shape_in[1]] 129 | pad = np.maximum(pad, 0).tolist() 130 | volume = F.pad(volume, pad).type(torch.float) 131 | shape = volume.shape[1:] 132 | 133 | # %% reserve storage 134 | pred = torch.zeros((self.network.out_channels, *shape), 135 | device=self.dev, 136 | dtype=torch.float) 137 | # this is for the voxel where we have no prediction in the end 138 | # for each of those the method will return the (1,0,..,0) vector 139 | # pred[0] = 1 140 | ovlp = torch.zeros((1, *shape), 141 | device=self.dev, 142 | dtype=torch.float) 143 | 144 | # %% get all top left coordinates of patches 145 | zxy_list = 
self._get_xyz_list(shape, ROI) 146 | 147 | # introduce batch size 148 | # some people say that introducing a batch size at inference time makes it faster 149 | # I couldn't see that so far 150 | n_full_batches = len(zxy_list) // self.batch_size 151 | zxy_batched = [zxy_list[i * self.batch_size: (i + 1) * self.batch_size] 152 | for i in range(n_full_batches)] 153 | 154 | if n_full_batches * self.batch_size < len(zxy_list): 155 | zxy_batched.append(zxy_list[n_full_batches * self.batch_size:]) 156 | 157 | # %% now the magic! 158 | with torch.no_grad(): 159 | for zxy_batch in zxy_batched: 160 | # crop 161 | batch = torch.stack([volume[:, 162 | z:z+self.patch_size[0], 163 | x:x+self.patch_size[1], 164 | y:y+self.patch_size[2]] for z, x, y in zxy_batch]) 165 | 166 | # remove z axis if we have 2d prediction 167 | batch = batch[:, :, 0] if self.is_2d else batch 168 | # remember that the network is outputting a list of predictions for each scale 169 | if not self.fp32 and torch.cuda.is_available(): 170 | with torch.cuda.amp.autocast(): 171 | out = self.network(batch)[0] 172 | else: 173 | out = self.network(batch)[0] 174 | 175 | # add z axis again maybe 176 | out = out.unsqueeze(2) if self.is_2d else out 177 | 178 | # update pred and overlap 179 | for i, (z, x, y) in enumerate(zxy_batch): 180 | pred[:, z:z+self.patch_size[0], x:x+self.patch_size[1], 181 | y:y+self.patch_size[2]] += F.softmax(out[i], 0) * self.patch_weight 182 | ovlp[:, z:z+self.patch_size[0], x:x+self.patch_size[1], 183 | y:y+self.patch_size[2]] += self.patch_weight 184 | 185 | # %% bring maybe back to old shape 186 | pred = pred[:, :shape_in[1], :shape_in[2], :shape_in[3]] 187 | ovlp = ovlp[:, :shape_in[1], :shape_in[2], :shape_in[3]] 188 | 189 | # set the prediction to background and prevent zero division where 190 | # we did not evaluate the network 191 | pred[0, ovlp[0] == 0] = 1 192 | ovlp[ovlp == 0] = 1 193 | 194 | pred /= ovlp 195 | 196 | # just to be sure 197 | if torch.cuda.is_available(): 198 | torch.cuda.empty_cache() 199 | 200 | return pred 201 | 202 | def predict_volume(self, volume, ROI=None, mode=None): 203 | # evaluates the siliding window on this volume 204 | # predictions are returned as soft segmentations 205 | if mode is None: 206 | mode = self.mode 207 | 208 | if ROI is not None: 209 | ROI = maybe_add_channel_dim(ROI) 210 | 211 | self.network.eval() 212 | 213 | # check the type and bring to device 214 | is_np, _ = check_type(volume) 215 | if is_np: 216 | volume = torch.from_numpy(volume).to(self.dev) 217 | 218 | # check if inpt is 3d or 4d for the output 219 | volume = maybe_add_channel_dim(volume) 220 | 221 | if mode.lower() == 'simple': 222 | pred = self._predict_volume_simple(volume, ROI) 223 | elif mode.lower() == 'flip': 224 | pred = self._predict_volume_flip(volume, ROI) 225 | 226 | if is_np: 227 | pred = pred.cpu().numpy() 228 | 229 | return pred 230 | 231 | def __call__(self, volume, ROI=None, mode=None): 232 | return self.predict_volume(volume, ROI, mode) 233 | 234 | def _predict_volume_simple(self, volume, ROI=None): 235 | return self._sliding_window(volume, ROI) 236 | 237 | def _predict_volume_flip(self, volume, ROI=None): 238 | 239 | flip_z_list = [False] if self.is_2d else [False, True] 240 | 241 | if ROI is not None and isinstance(ROI, np.ndarray): 242 | ROI = torch.from_numpy(ROI) 243 | 244 | # collect all combinations of flipping 245 | flip_list = [] 246 | for fz in flip_z_list: 247 | for fx in [False, True]: 248 | for fy in [False, True]: 249 | flip_list.append((fz, fx, fy)) 250 | 251 | # do the 
first one outside the loop for initialisation 252 | pred = self._sliding_window(volume, ROI=ROI) 253 | 254 | # now some flippings! 255 | for f in flip_list[1:]: 256 | volume = self._flip_volume(volume, f) 257 | if ROI is not None: 258 | ROI = self._flip_volume(ROI, f) 259 | 260 | # predict flipped volume 261 | pred_flipped = self._sliding_window(volume, ROI) 262 | 263 | # flip back and update 264 | pred += self._flip_volume(pred_flipped, f) 265 | volume = self._flip_volume(volume, f) 266 | if ROI is not None: 267 | ROI = self._flip_volume(ROI, f) 268 | 269 | return pred / len(flip_list) 270 | 271 | def _flip_volume(self, volume, f): 272 | for i in range(3): 273 | if f[i]: 274 | volume = volume.flip(i+1) 275 | return volume 276 | -------------------------------------------------------------------------------- /src/ovseg/augmentation/myRandAugment.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn.functional import interpolate 5 | import numpy as np 6 | import random 7 | 8 | 9 | # %% 10 | class torch_myRandAugment(torch.nn.Module): 11 | ''' 12 | This is really just the nnU-Net gray value Augmentation but parametrised differently 13 | ''' 14 | 15 | def __init__(self, 16 | P=0.15, 17 | M=15, 18 | n_im_channels: int = 1 19 | ): 20 | super().__init__() 21 | self.P = P 22 | self.M = M 23 | self.n_im_channels = n_im_channels 24 | 25 | def _uniform(self, mm, device='cpu'): 26 | return (mm[1] - mm[0]) * torch.rand([], device=device) + mm[0] 27 | 28 | def _sign(self): 29 | return np.random.choice([-1, 1]) 30 | 31 | def _noise(self, img, m): 32 | var = self._uniform([0, 0.1*m/15], device=img.device) 33 | sigma = torch.sqrt(var) 34 | return img + sigma * torch.randn_like(img) 35 | 36 | def _blur(self, img, m): 37 | sigma = self._uniform([0.5 * m/15, 0.5 + m/15], device=img.device) 38 | var = sigma ** 2 39 | axes = torch.arange(-5, 6, device=img.device) 40 | grid = torch.stack(torch.meshgrid([axes for _ in range(2)])) 41 | gkernel = torch.exp(-1*torch.sum(grid**2, dim=0)/2.0/var) 42 | gkernel = gkernel/gkernel.sum() 43 | if len(img.shape) == 4: 44 | # 2d case 45 | gkernel = gkernel.view(1, 1, 11, 11).to(img.device).type(img.dtype) 46 | return torch.nn.functional.conv2d(img, gkernel, padding=5) 47 | else: 48 | gkernel = gkernel.view(1, 1, 1, 11, 11).to(img.device).type(img.dtype) 49 | return torch.nn.functional.conv3d(img, gkernel, padding=(0, 5, 5)) 50 | 51 | def _brightness(self, img, m): 52 | fac = self._uniform([1 - 0.3 * m/15, 1 + 0.3 * m/15], device=img.device) 53 | return img * fac 54 | 55 | def _contrast(self, img, m): 56 | fac = self._uniform([1 - 0.45 * m/15, 1 + 0.5 * m/15], device=img.device) 57 | mean = img.mean() 58 | mn = img.min().item() 59 | mx = img.max().item() 60 | img = (img - mean) * fac + mean 61 | return img.clamp(mn, mx) 62 | 63 | def _low_res(self, img, m): 64 | size = img.size()[2:] 65 | mode = 'bilinear' if len(size) == 2 else 'trilinear' 66 | fac = np.random.uniform(*[1, 1 + m/15]) 67 | img = interpolate(img, scale_factor=1/fac) 68 | return interpolate(img, size=size, mode=mode) 69 | 70 | def _gamma(self, img, m): 71 | with torch.cuda.amp.autocast(enabled=False): 72 | mn, mx = img.min(), img.max() 73 | img = (img - mn)/(mx - mn) 74 | gamma = np.random.uniform(*[1 - 0.3 * m/15, 1 + 0.5 * m/15]) 75 | if np.random.rand() < self.P: 76 | img = 1 - (1 - img) ** gamma 77 | else: 78 | img = img ** gamma 79 | 80 | return (mx - mn) * img + mn 81 | 82 | def 
_get_ops_mag_list(self): 83 | ops_mag_list = [] 84 | if np.random.rand() < self.P: 85 | ops_mag_list.append((self._noise, np.random.rand() * self.M)) 86 | if np.random.rand() < self.P: 87 | ops_mag_list.append((self._blur, np.random.rand() * self.M)) 88 | if np.random.rand() < self.P: 89 | ops_mag_list.append((self._brightness, np.random.rand() * self.M)) 90 | if np.random.rand() < self.P: 91 | ops_mag_list.append((self._contrast, np.random.rand() * self.M)) 92 | if np.random.rand() < self.P: 93 | ops_mag_list.append((self._low_res, np.random.rand() * self.M)) 94 | np.random.shuffle(ops_mag_list) 95 | 96 | return ops_mag_list 97 | 98 | def forward(self, xb): 99 | 100 | c = self.n_im_channels 101 | 102 | for b in range(xb.shape[0]): 103 | for op, m in self._get_ops_mag_list(): 104 | xb[b:b+1, :c] = op(xb[b:b+1, :c], m) 105 | return xb 106 | 107 | def update_prg_trn(self, param_dict, h, indx=None): 108 | 109 | if 'M' in param_dict: 110 | self.M = (1 - h) * param_dict['M'][0] + h * param_dict['M'][1] 111 | 112 | if 'P' in param_dict: 113 | self.P = (1 - h) * param_dict['P'][0] + h * param_dict['P'][1] 114 | 115 | 116 | # %% 117 | class torch_myRandAugment_old(nn.Module): 118 | 119 | def __init__(self, n, m, n_im_channels=1, use_3d_spatials=False): 120 | super().__init__() 121 | self.n = n 122 | self.m = m 123 | self.n_im_channels = n_im_channels 124 | self.use_3d_spatials = use_3d_spatials 125 | 126 | # smooth_kernel = [[1, 1, 1], [1, 15, 1], [1, 1, 1]] 127 | smooth_kernel = [[1, 1, 1, 1, 1], [1, 5, 5, 5, 1], [1, 5, 44, 5, 1], 128 | [1, 5, 5, 5, 1], [1, 1, 1, 1, 1]] 129 | smooth_kernel = torch.tensor(smooth_kernel).type(torch.float) 130 | smooth_kernel = smooth_kernel / smooth_kernel.sum() 131 | self.smooth_kernel_2d = smooth_kernel.unsqueeze(0).unsqueeze(0) 132 | self.smooth_kernel_3d = self.smooth_kernel_2d.unsqueeze(0) 133 | self.padding_2d = (2, 2) 134 | self.padding_3d = (0, 2, 2) 135 | 136 | self.all_ops = [(self._identity, 0, 1), 137 | (self._translate_x, 0, 0.33), 138 | (self._translate_y, 0, 0.33), 139 | (self._shear_x, 0, 0.3), 140 | (self._shear_y, 0, 0.3), 141 | (self._contrast, 0, 0.9), 142 | (self._brightness, 0, 0.9), 143 | (self._darkness, 0, 0.9), 144 | #self._narrow_window, 145 | (self._sharpness, 0, 0.9), 146 | (self._noise, 0, 1.0)] 147 | 148 | def _get_theta_id(self, img): 149 | # helper to create the identity matrix for spatial operations 150 | bs, n_ch = img.shape[0:2] 151 | img_dims = len(img.shape) - 2 152 | theta = torch.zeros((bs, img_dims, img_dims+1), device=img.device, dtype=img.dtype) 153 | for j in range(img_dims): 154 | theta[:, j, j] = 1 155 | return theta 156 | 157 | def _interp_img(self, img, theta): 158 | # performs spatial operations by interpolation 159 | grid = F.affine_grid(theta, img.size()).to(img.device).type(img.dtype) 160 | img = torch.cat([F.grid_sample(img[:, :self.n_im_channels], grid, mode='bilinear'), 161 | F.grid_sample(img[:, self.n_im_channels:], grid, mode='nearest')], dim=1) 162 | return img 163 | 164 | def _sign(self): 165 | return np.random.choice([-1, 1]) 166 | 167 | # list of all transformations we take into account 168 | def _identity(self, img, val): 169 | return img 170 | 171 | def _translate_x(self, img, val): 172 | theta = self._get_theta_id(img) 173 | theta[:, 1, -1] = self._sign() * val 174 | return self._interp_img(img, theta) 175 | 176 | def _translate_y(self, img, val): 177 | theta = self._get_theta_id(img) 178 | theta[:, 0, -1] = self._sign() * val 179 | return self._interp_img(img, theta) 180 | 181 | def _shear_x(self, 
img, val): 182 | theta = self._get_theta_id(img) 183 | theta[:, 0, 1] = val * self._sign() 184 | return self._interp_img(img, theta) 185 | 186 | def _shear_y(self, img, val): 187 | theta = self._get_theta_id(img) 188 | theta[:, 1, 0] = val * self._sign() 189 | return self._interp_img(img, theta) 190 | 191 | def _contrast(self, img, val): 192 | val = val * self._sign() 193 | for ch in range(self.n_im_channels): 194 | img[:, ch] = (1 - val) * img[:, ch] + val * img[:, ch].mean() 195 | return img 196 | 197 | def _brightness(self, img, val): 198 | val = val * self._sign() 199 | for ch in range(self.n_im_channels): 200 | img[:, ch] = (1 - val) * img[:, ch] + val * img[:, ch].min() 201 | return img 202 | 203 | def _darkness(self, img, val): 204 | val = val * self._sign() 205 | for ch in range(self.n_im_channels): 206 | img[:, ch] = (1 - val) * img[:, ch] + val * img[:, ch].max() 207 | return img 208 | 209 | def _narrow_window(self, img, val): 210 | 211 | for ch in range(self.n_im_channels): 212 | mn, mx = img[:, ch].min(), img[:, ch].max() 213 | mn_new = mn * (1 - val) + val * mx 214 | mx_new = mn * val + (1 - val) * mx 215 | img[:, ch] = img[:, ch].clip(mn_new, mx_new) 216 | return img 217 | 218 | def _sharpness(self, img, val): 219 | val = val * self._sign() 220 | if len(img.shape) == 4: 221 | img_smooth = [F.conv2d(img[:, ch:ch+1], 222 | self.smooth_kernel_2d.to(img.device).type(img.dtype), 223 | padding=self.padding_2d) for ch in range(self.n_im_channels)] 224 | else: 225 | img_smooth = [F.conv3d(img[:, ch:ch+1], 226 | self.smooth_kernel_3d.to(img.device).type(img.dtype), 227 | padding=self.padding_3d) for ch in range(self.n_im_channels)] 228 | img_smooth = torch.cat(img_smooth, 1) 229 | for ch in range(self.n_im_channels): 230 | img[:, ch] = img_smooth[:, ch] * val + (1 - val) * img[:, ch] 231 | return img 232 | 233 | def _noise(self, img, val): 234 | for ch in range(self.n_im_channels): 235 | img[:, ch] = img[:, ch] + val * torch.randn_like(img[:, ch]) 236 | return img 237 | 238 | def forward(self, xb): 239 | 240 | for b in range(xb.shape[0]): 241 | ops_list = random.choices(self.all_ops, k=self.n) 242 | for op, mn, mx in ops_list: 243 | val = mn * (1 - self.m/30) + self.m/30 * mx 244 | xb[b:b+1] = op(xb[b:b+1], val) 245 | 246 | return xb 247 | 248 | def update_prg_trn(self, param_dict, h, indx=None): 249 | 250 | if 'm' in param_dict: 251 | self.m = (1 - h) * param_dict['m'][0] + h * param_dict['m'][1] 252 | 253 | if 'n' in param_dict: 254 | self.n = (1 - h) * param_dict['n'][0] + h * param_dict['n'][1] 255 | self.n = int(self.n + 0.5) 256 | 257 | 258 | # %% 259 | if __name__ == '__main__': 260 | import matplotlib.pyplot as plt 261 | plt.close() 262 | im = np.load('D:\\PhD\\Data\\ov_data_base\\preprocessed\\OV04\\pod_half\\images\\case_000.npy') 263 | lb = np.load('D:\\PhD\\Data\\ov_data_base\\preprocessed\\OV04\\pod_half\\labels\\case_000.npy') 264 | volume = np.stack([im, lb]).astype(np.float32) 265 | xb = torch.from_numpy(volume[np.newaxis, :, 37:69, 64:192, 64:192]).cuda() 266 | img = xb[:1] 267 | aug = torch_myRandAugment(n=1, m=5) 268 | # %% 269 | vmin, vmax = img[0, 0].min(), img[0, 0].max() 270 | img_aug = aug(torch.clone(img)) 271 | plt.subplot(1, 3, 1) 272 | plt.imshow(img[0, 0, -1].cpu().numpy(), cmap='gray') 273 | plt.contour(img[0, 1, -1].cpu().numpy(), colors='red') 274 | plt.subplot(1, 3, 2) 275 | plt.imshow(img_aug[0, 0, -1].cpu().numpy(), cmap='gray', vmin=vmin, vmax=vmax) 276 | plt.contour(img_aug[0, 1, -1].cpu().numpy(), colors='red') 277 | plt.subplot(1, 3, 3) 278 | 
plt.imshow((img[0, 0, -1] - img_aug[0, 0, -1]).cpu().numpy(), cmap='gray', vmin=vmin, vmax=vmax) 279 | plt.contour(img_aug[0, 1, -1].cpu().numpy(), colors='red') 280 | 281 | # %% 282 | m = 5 283 | vmin, vmax = img[0, 0].min(), img[0, 0].max() 284 | op, mn, mx = aug.all_ops[9] 285 | print(op) 286 | val = mn * (1 - m/30) + m/30 * mx 287 | img_aug = op(torch.clone(img), val) 288 | plt.subplot(1, 3, 1) 289 | plt.imshow(img[0, 0, -1].cpu().numpy(), cmap='gray') 290 | plt.contour(img[0, 1, -1].cpu().numpy(), colors='red') 291 | plt.subplot(1, 3, 2) 292 | plt.imshow(img_aug[0, 0, -1].cpu().numpy(), cmap='gray', vmin=vmin, vmax=vmax) 293 | plt.contour(img_aug[0, 1, -1].cpu().numpy(), colors='red') 294 | plt.subplot(1, 3, 3) 295 | plt.imshow((img[0, 0, -1] - img_aug[0, 0, -1]).cpu().numpy(), cmap='gray', vmin=vmin, vmax=vmax) 296 | plt.contour(img_aug[0, 1, -1].cpu().numpy(), colors='red') 297 | -------------------------------------------------------------------------------- /src/ovseg/data/SegmentationDataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from ovseg.data.utils import crop_and_pad_image 4 | from ovseg.utils.torch_np_utils import maybe_add_channel_dim 5 | import os 6 | from time import sleep 7 | try: 8 | from tqdm import tqdm 9 | except ModuleNotFoundError: 10 | print('No tqdm found, using no pretty progressing bars') 11 | tqdm = lambda x: x 12 | 13 | 14 | class SegmentationBatchDataset(object): 15 | 16 | def __init__(self, vol_ds, patch_size, batch_size, epoch_len=250, p_bias_sampling=0, 17 | min_biased_samples=1, augmentation=None, padded_patch_size=None, 18 | n_im_channels: int = 1, store_coords_in_ram=True, memmap='r', image_key='image', 19 | label_key='label', store_data_in_ram=False, return_fp16=True, n_max_volumes=None, 20 | bias='fg', n_fg_classes=None, *args, **kwargs): 21 | self.vol_ds = vol_ds 22 | self.patch_size = np.array(patch_size) 23 | self.batch_size = batch_size 24 | self.epoch_len = epoch_len 25 | self.p_bias_sampling = p_bias_sampling 26 | self.min_biased_samples = min_biased_samples 27 | self.augmentation = augmentation 28 | self.store_coords_in_ram = store_coords_in_ram 29 | self.memmap = memmap 30 | self.image_key = image_key 31 | self.label_key = label_key 32 | self.store_data_in_ram = store_data_in_ram 33 | self.n_im_channels = n_im_channels 34 | self.return_fp16 = return_fp16 35 | self.bias = bias 36 | self.n_fg_classes = n_fg_classes 37 | 38 | if self.bias == 'cl_fg': 39 | assert isinstance(self.n_fg_classes, int) 40 | assert self.n_fg_classes > 0 41 | else: 42 | # does not need to be true, but makes our life easier 43 | self.n_fg_classes = 1 44 | 45 | self.dtype = np.float16 if self.return_fp16 else np.float32 46 | if n_max_volumes is None: 47 | self.n_volumes = len(self.vol_ds) 48 | else: 49 | self.n_volumes = np.min([n_max_volumes, len(self.vol_ds)]) 50 | 51 | if len(self.patch_size) == 2: 52 | self.twoD_patches = True 53 | self.patch_size = np.concatenate([[1], self.patch_size]) 54 | else: 55 | self.twoD_patches = False 56 | 57 | # overwrite default in case we're not using padding here 58 | if padded_patch_size is None: 59 | self.padded_patch_size = self.patch_size 60 | else: 61 | self.padded_patch_size = np.array(padded_patch_size) 62 | 63 | self._maybe_store_data_in_ram() 64 | 65 | if len(args) > 0: 66 | print('Warning, got unused args: {}'.format(args)) 67 | if len(kwargs) > 0: 68 | print('Warning, got unused kwargs: {}'.format(kwargs)) 69 | 70 | def 
_get_bias_coords(self, volume): 71 | 72 | if self.bias == 'fg': 73 | return [np.stack(np.where(volume[-1] > 0)).astype(np.int16)] 74 | elif self.bias == 'cl_fg': 75 | return [np.stack(np.where(volume[-1] == cl)).astype(np.int16) 76 | for cl in range(1, self.n_fg_classes + 1)] 77 | elif self.bias == 'mask': 78 | return [np.stack(np.where(volume[-2] > 0)).astype(np.int16)] 79 | 80 | def _maybe_store_data_in_ram(self): 81 | # maybe cleaning first, just to be sure 82 | self._maybe_clean_stored_data() 83 | 84 | if self.store_data_in_ram: 85 | print('Store data in RAM.\n') 86 | self.data = [] 87 | sleep(1) 88 | for ind in tqdm(range(self.n_volumes)): 89 | path_dict = self.vol_ds.path_dicts[ind] 90 | labels = np.load(path_dict[self.label_key]).astype(np.uint8) 91 | 92 | labels = maybe_add_channel_dim(labels) 93 | 94 | im = np.load(path_dict[self.image_key]).astype(self.dtype) 95 | 96 | im = maybe_add_channel_dim(im) 97 | 98 | self.data.append((im, labels)) 99 | 100 | # store coords in ram 101 | if self.store_coords_in_ram: 102 | print('Precomputing bias coordinates to store them in RAM.\n') 103 | self.coords_list = [] 104 | self.contains_fg_list = [[] for _ in range(self.n_fg_classes)] 105 | sleep(1) 106 | for ind in tqdm(range(self.n_volumes)): 107 | if self.store_data_in_ram: 108 | labels = self.data[ind][1] 109 | else: 110 | labels = np.load(self.vol_ds.path_dicts[ind][self.label_key]) 111 | 112 | # ensure 4d array 113 | labels = maybe_add_channel_dim(labels) 114 | coords = self._get_bias_coords(labels) 115 | self.coords_list.append(coords) 116 | 117 | # save which index has which fg class 118 | for i in range(self.n_fg_classes): 119 | if coords[i].shape[1] > 0: 120 | self.contains_fg_list[i].append(ind) 121 | print('Done') 122 | else: 123 | # if we don't store them in ram we will compute them and store them as .npy files 124 | # in the preprocessed path 125 | self.contains_fg_list = [[] for _ in range(self.n_fg_classes)] 126 | self.bias_coords_fol = os.path.join(self.vol_ds.preprocessed_path, 127 | 'bias_coordinates_'+self.bias) 128 | if not os.path.exists(self.bias_coords_fol): 129 | os.mkdir(self.bias_coords_fol) 130 | 131 | # now we check if come cases are missing in the folder 132 | print('Checking if all bias coordinates are stored in '+self.bias_coords_fol) 133 | for ind, d in enumerate(self.vol_ds.path_dicts): 134 | case = os.path.basename(d[self.label_key]) 135 | if case not in os.listdir(self.bias_coords_fol): 136 | labels = np.load(d[self.label_key]) 137 | coords = self._get_bias_coords(labels) 138 | np.save(os.path.join(self.bias_coords_fol, case), coords) 139 | else: 140 | coords = np.load(os.path.join(self.bias_coords_fol, case)) 141 | 142 | # save which index has which fg class 143 | for i in range(self.n_fg_classes): 144 | if coords[i].shape[1] > 0: 145 | self.contains_fg_list[i].append(ind) 146 | 147 | # print how many scans we have with which class 148 | for c in range(self.n_fg_classes): 149 | print('Found {} scans with fg {}'.format(len(self.contains_fg_list[c]), c)) 150 | 151 | # available classes start from 0 152 | self.availble_classes = [i for i, l in enumerate(self.contains_fg_list) if len(l) > 0] 153 | 154 | if len(self.availble_classes) < self.n_fg_classes: 155 | missing_classes = [i+1 for i, l in enumerate(self.contains_fg_list) if len(l) == 0] 156 | print('Warning! Some fg classes were not found in this dataset. 
' 157 | 'Missing classes: {}'.format(missing_classes)) 158 | 159 | sleep(1) 160 | 161 | def _maybe_clean_stored_data(self): 162 | # delte stuff we stored in RAM 163 | # first for the full volumes 164 | if hasattr(self, 'data'): 165 | for tpl in self.data: 166 | for arr in tpl: 167 | del arr 168 | del tpl 169 | del self.data 170 | 171 | # now for the bias coordinates 172 | if hasattr(self, 'coords_list'): 173 | for coord in self.coords_list: 174 | for crds in coord: 175 | del crds 176 | del coord 177 | del self.coords_list 178 | 179 | def change_folders_and_keys(self, new_folders, new_keys): 180 | # for progressive training, we might change the folder of image and label data during 181 | # training if we've stored the rescaled volumes on the hard drive. 182 | print('Dataloader: chaning keys and folders') 183 | print('new keys: ', *new_keys) 184 | print('new folders: ', *new_folders) 185 | print() 186 | self.vol_ds.change_folders_and_keys(new_folders, new_keys) 187 | self._maybe_store_data_in_ram() 188 | 189 | def _get_volume_tuple(self, ind=None): 190 | 191 | if ind is None: 192 | ind = np.random.randint(self.n_volumes) 193 | 194 | load_from_ram = hasattr(self, 'data') 195 | if load_from_ram: 196 | load_from_ram = ind < len(self.data) 197 | 198 | if load_from_ram: 199 | volumes = self.data[ind] 200 | else: 201 | path_dict = self.vol_ds.path_dicts[ind] 202 | im = np.load(path_dict[self.image_key], 'r') 203 | labels = np.load(path_dict[self.label_key], 'r') 204 | volumes = [im, labels] 205 | 206 | # maybe add an additional axis 207 | volumes = [maybe_add_channel_dim(vol) for vol in volumes] 208 | 209 | return volumes 210 | 211 | def _get_random_volume_ind(self, biased_sampling): 212 | if biased_sampling: 213 | # when we do biased sampling we have to make sure that the 214 | # volume we're sampling actually has fg 215 | if len(self.availble_classes) > 0: 216 | cl = np.random.choice(self.availble_classes) 217 | return np.random.choice(self.contains_fg_list[cl]), cl 218 | else: 219 | return np.random.randint(self.n_volumes), -1 220 | else: 221 | return np.random.randint(self.n_volumes), -1 222 | 223 | def __len__(self): 224 | return self.epoch_len * self.batch_size 225 | 226 | def __getitem__(self, index): 227 | 228 | if index % self.batch_size < self.min_biased_samples: 229 | biased_sampling = True 230 | else: 231 | biased_sampling = np.random.rand() < self.p_bias_sampling 232 | 233 | ind, cl = self._get_random_volume_ind(biased_sampling) 234 | volumes = self._get_volume_tuple(ind) 235 | shape = np.array(volumes[0].shape)[1:] 236 | 237 | if biased_sampling and cl >= 0: 238 | # let's get the list of bias coordinates 239 | if self.store_coords_in_ram: 240 | # loading from RAM 241 | coords = self.coords_list[ind][cl] 242 | else: 243 | # or hard drive 244 | case = os.path.basename(self.vol_ds.path_dicts[ind][self.label_key]) 245 | coords = np.load(os.path.join(self.bias_coords_fol, case))[cl] 246 | 247 | # pick a random item from the list and compute the upper left corner of the patch 248 | n_coords = coords.shape[1] 249 | coord = coords[:, np.random.randint(n_coords)] - self.patch_size//2 250 | else: 251 | # random coordinate uniform from the whole volume 252 | coord = np.random.randint(np.maximum(shape - self.patch_size+1, 1)) 253 | coord = np.minimum(np.maximum(coord, 0), shape - self.patch_size) 254 | # now get the cropped and padded sample 255 | 256 | volume = np.concatenate([crop_and_pad_image(vol, 257 | coord, 258 | self.patch_size, 259 | self.padded_patch_size) for vol in volumes]) 260 | 
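        # volume now contains the image and label channels, each cropped and padded to the padded patch size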
261 | if self.twoD_patches: 262 | # remove z axis 263 | volume = volume[:, 0] 264 | 265 | if self.augmentation is not None: 266 | # in augmentation we need batch style arrays with an additional 267 | # axes for the batch size 268 | # the label augmentation expects integer valued predictions as input 269 | volume = self.augmentation(volume[np.newaxis])[0] 270 | 271 | return volume.astype(self.dtype) 272 | 273 | 274 | 275 | def SegmentationDataloader(vol_ds, patch_size, batch_size, num_workers=None, 276 | pin_memory=True, epoch_len=250, *args, **kwargs): 277 | dataset = SegmentationBatchDataset(vol_ds=vol_ds, patch_size=patch_size, batch_size=batch_size, 278 | epoch_len=epoch_len, *args, **kwargs) 279 | if num_workers is None: 280 | num_workers = 0 if os.name == 'nt' else 5 281 | worker_init_fn = lambda _: np.random.seed() 282 | sampler = torch.utils.data.SequentialSampler(range(batch_size * epoch_len)) 283 | return torch.utils.data.DataLoader(dataset, 284 | sampler=sampler, 285 | batch_size=batch_size, 286 | pin_memory=pin_memory, 287 | num_workers=num_workers, 288 | worker_init_fn=worker_init_fn) 289 | -------------------------------------------------------------------------------- /src/ovseg/model/SegmentationEnsemble.py: -------------------------------------------------------------------------------- 1 | from ovseg.utils.io import load_pkl 2 | from ovseg.model.SegmentationModel import SegmentationModel 3 | from ovseg.model.ModelBase import ModelBase 4 | from ovseg.data.Dataset import raw_Dataset 5 | from os import environ, listdir 6 | from os.path import join, isdir, exists 7 | import torch 8 | from ovseg.utils.torch_np_utils import check_type 9 | import numpy as np 10 | from tqdm import tqdm 11 | from time import sleep 12 | 13 | 14 | class SegmentationEnsemble(ModelBase): 15 | ''' 16 | Ensembling Model that is used to add over softmax outputs before applying the argmax 17 | It is always called in inference mode! 
18 | ''' 19 | 20 | def __init__(self, data_name: str, model_name: str, preprocessed_name: str, val_fold=None, 21 | network_name='network', fmt_write='{:.4f}', 22 | model_parameters_name='model_parameters'): 23 | self.model_cv_path = join(environ['OV_DATA_BASE'], 24 | 'trained_models', 25 | data_name, 26 | preprocessed_name, 27 | model_name) 28 | if val_fold is None: 29 | fold_folders = [f for f in listdir(self.model_cv_path) 30 | if isdir(join(self.model_cv_path, f)) and f.startswith('fold')] 31 | val_fold = [int(f.split('_')[-1]) for f in fold_folders] 32 | super().__init__(val_fold=val_fold, data_name=data_name, model_name=model_name, 33 | preprocessed_name=preprocessed_name, 34 | network_name=network_name, is_inference_only=True, 35 | fmt_write=fmt_write, model_parameters_name=model_parameters_name) 36 | 37 | # create all models 38 | self.models = [] 39 | 40 | 41 | self.models_initialised = False 42 | if self.all_folds_complete(): 43 | self.initialise_models() 44 | 45 | def create_model(self, fold): 46 | model = SegmentationModel(val_fold=fold, 47 | data_name=self.data_name, 48 | model_name=self.model_name, 49 | model_parameters=self.model_parameters, 50 | preprocessed_name=self.preprocessed_name, 51 | network_name=self.network_name, 52 | is_inference_only=True, 53 | fmt_write=self.fmt_write, 54 | model_parameters_name=self.model_parameters_name 55 | ) 56 | return model 57 | 58 | def initialise_models(self): 59 | 60 | if self.models_initialised: 61 | print('Models were already initialised') 62 | return 63 | 64 | not_finished_folds = self._find_incomplete_folds() 65 | for fold in self.val_fold: 66 | if fold in not_finished_folds: 67 | print('Skipping fold {}. Training was not finished.'.format(fold)) 68 | continue 69 | print('Creating model from fold: '+str(fold)) 70 | model = self.create_model(fold) 71 | self.models.append(model) 72 | 73 | # change in evaluation mode 74 | for model in self.models: 75 | model.network.eval() 76 | 77 | self.models_initialised = True 78 | 79 | # now we do a hack by initialising the two objects like this... 
80 |         self.preprocessing = self.models[0].preprocessing
81 |         self.postprocessing = self.models[0].postprocessing
82 | 
83 |         self.n_fg_classes = self.models[0].n_fg_classes
84 |         if self.is_cascade():
85 |             self.prev_stages = self.model_parameters['prev_stages']
86 |             self.prev_stages_keys = []
87 |             for prev_stage in self.prev_stages:
88 |                 key = '_'.join(['prediction',
89 |                                 prev_stage['data_name'],
90 |                                 prev_stage['preprocessed_name'],
91 |                                 prev_stage['model_name']])
92 |                 self.prev_stages_keys.append(key)
93 | 
94 |     def is_cascade(self):
95 |         return 'prev_stages' in self.model_parameters
96 | 
97 |     def _find_incomplete_folds(self):
98 |         num_epochs = self.model_parameters['training']['num_epochs']
99 |         not_finished_folds = []
100 |         for fold in self.val_fold:
101 |             path_to_attr = join(self.model_cv_path,
102 |                                 'fold_'+str(fold),
103 |                                 'attribute_checkpoint.pkl')
104 |             if not exists(path_to_attr):
105 |                 print('Trying to check if the training is done for all folds,'
106 |                       ' but no checkpoint was found for fold '+str(fold)+'.')
107 |                 not_finished_folds.append(fold)
108 |                 continue
109 | 
110 |             attr = load_pkl(path_to_attr)
111 | 
112 |             if attr['epochs_done'] < attr['num_epochs']:
113 |                 not_finished_folds.append(fold)
114 |         return not_finished_folds
115 | 
116 |     def all_folds_complete(self):
117 |         not_finished_folds = self._find_incomplete_folds()
118 |         if len(not_finished_folds) == 0:
119 |             return True
120 | 
121 |         else:
122 |             print("It seems like the folds " + str(not_finished_folds) +
123 |                   " have not finished training.")
124 |             return False
125 | 
126 |     def wait_until_all_folds_complete(self):
127 | 
128 |         waited = 0
129 |         while not self.all_folds_complete():
130 |             sleep(60)
131 |             waited += 60
132 | 
133 |             if waited % 600 == 0:
134 |                 print('Waited {} seconds'.format(waited))
135 | 
136 |         self.initialise_models()
137 | 
138 |     def initialise_preprocessing(self):
139 |         return
140 | 
141 |     def initialise_augmentation(self):
142 |         return
143 | 
144 |     def initialise_network(self):
145 |         return
146 | 
147 |     def initialise_postprocessing(self):
148 |         return
149 | 
150 |     def initialise_data(self):
151 |         return
152 | 
153 |     def initialise_training(self):
154 |         return
155 | 
156 |     def __call__(self, data_tpl):
157 |         if not self.all_folds_complete():
158 |             print('WARNING: Ensemble is used without all training folds being completed!!')
159 | 
160 |         if not self.models_initialised:
161 |             print('Models were not initialised. Trying to do it now...')
162 |             self.wait_until_all_folds_complete()
163 | 
164 |         scan = data_tpl['scan']
165 | 
166 |         # also the path where we will look for already computed npz predictions
167 |         pred_npz_path = join(environ['OV_DATA_BASE'], 'npz_predictions', self.data_name,
168 |                              self.preprocessed_name, self.model_name)
169 | 
170 |         # the preprocessing will only do something if the image is not preprocessed yet
171 |         if not self.preprocessing.is_preprocessed_data_tpl(data_tpl):
172 |             for model in self.models:
173 |                 # try to find the npz file if there was already a prediction.
174 |                 path_to_npz = join(pred_npz_path, model.val_fold_str, scan+'.npz')
175 |                 path_to_npy = join(pred_npz_path, model.val_fold_str, scan+'.npy')
176 | 
177 |                 if exists(path_to_npy) or exists(path_to_npz):
178 |                     im = None
179 |                     continue
180 |                 else:
181 |                     im = self.preprocessing(data_tpl, preprocess_only_im=True)
182 |                     break
183 | 
184 |         # now the important part: the actual ensembling of sliding window evaluations
185 |         preds = []
186 |         with torch.no_grad():
187 |             for model in self.models:
188 |                 # try to find the npz file if there was already a prediction.
189 |                 path_to_npz = join(pred_npz_path, model.val_fold_str, scan+'.npz')
190 |                 path_to_npy = join(pred_npz_path, model.val_fold_str, scan+'.npy')
191 |                 if exists(path_to_npy):
192 |                     try:
193 |                         pred = np.load(path_to_npy)
194 |                     except ValueError:
195 | 
196 |                         if im is None:
197 |                             im = self.preprocessing(data_tpl, preprocess_only_im=True)
198 |                         pred = model.prediction(im).cpu().numpy()
199 |                 elif exists(path_to_npz):
200 |                     try:
201 |                         pred = np.load(path_to_npz)['arr_0']
202 |                     except ValueError:
203 |                         if im is None:
204 |                             im = self.preprocessing(data_tpl, preprocess_only_im=True)
205 |                         pred = model.prediction(im).cpu().numpy()
206 | 
207 |                 else:
208 |                     pred = model.prediction(im).cpu().numpy()
209 |                 preds.append(pred)
210 | 
211 |         ens_pred = np.stack(preds).mean(0)
212 | 
213 |         data_tpl[self.pred_key] = ens_pred
214 | 
215 |         # inside the postprocessing the result will be attached to the data_tpl
216 |         self.postprocessing.postprocess_data_tpl(data_tpl, self.pred_key)
217 | 
218 |         torch.cuda.empty_cache()
219 |         return data_tpl[self.pred_key]
220 | 
221 |     def save_prediction(self, data_tpl, folder_name, filename=None):
222 | 
223 |         self.models[0].save_prediction(data_tpl, folder_name, filename)
224 | 
225 |     def plot_prediction(self, data_tpl, folder_name, filename=None, image_key='image'):
226 | 
227 |         self.models[0].plot_prediction(data_tpl, folder_name, filename, image_key)
228 | 
229 |     def compute_error_metrics(self, data_tpl):
230 |         return self.models[0].compute_error_metrics(data_tpl)
231 | 
232 |     def _init_global_metrics(self):
233 |         self.global_metrics_helper = {}
234 |         self.global_metrics = {}
235 |         for c in self.models[0].lb_classes:
236 |             self.global_metrics_helper.update({s+str(c): 0 for s in ['overlap_',
237 |                                                                      'gt_volume_',
238 |                                                                      'pred_volume_']})
239 |             self.global_metrics.update({'dice_'+str(c): -1,
240 |                                         'recall_'+str(c): -1,
241 |                                         'precision_'+str(c): -1})
242 | 
243 |     def _update_global_metrics(self, data_tpl):
244 | 
245 |         if 'label' not in data_tpl:
246 |             return
247 | 
248 |         if self.models[0].preprocessing.is_preprocessed_data_tpl(data_tpl):
249 |             raise NotImplementedError('Ensemble prediction only implemented '
250 |                                       'for raw data not for preprocessed data.')
251 | 
252 |         label = data_tpl['label']
253 |         if self.models[0].preprocessing.reduce_lb_to_single_class:
254 |             label = (label > 0).astype(label.dtype)
255 |         pred = data_tpl[self.pred_key]
256 | 
257 |         # volume of one voxel
258 |         fac = np.prod(data_tpl['spacing'])
259 |         for c in self.models[0].lb_classes:
260 |             lb_c = (label == c).astype(float)
261 |             pred_c = (pred == c).astype(float)
262 |             ovlp = self.global_metrics_helper['overlap_'+str(c)] + np.sum(lb_c * pred_c) * fac
263 |             gt_vol = self.global_metrics_helper['gt_volume_'+str(c)] + np.sum(lb_c) * fac
264 |             pred_vol = self.global_metrics_helper['pred_volume_'+str(c)] + np.sum(pred_c) * fac
265 |             # update global dice, recall and precision
266 |             if gt_vol + pred_vol > 0:
267 |                 self.global_metrics['dice_'+str(c)] = 200 * ovlp / (gt_vol + pred_vol)
268 | else: 269 | self.global_metrics['dice_'+str(c)] = 100 270 | if gt_vol > 0: 271 | self.global_metrics['recall_'+str(c)] = 100 * ovlp / gt_vol 272 | else: 273 | self.global_metrics['recall_'+str(c)] = 100 if pred_vol == 0 else 0 274 | if pred_vol > 0: 275 | self.global_metrics['precision_'+str(c)] = 100 * ovlp / pred_vol 276 | else: 277 | self.global_metrics['precision_'+str(c)] = 100 if gt_vol == 0 else 0 278 | 279 | # now update global metrics helper 280 | self.global_metrics_helper['overlap_'+str(c)] = ovlp 281 | self.global_metrics_helper['gt_volume_'+str(c)] = gt_vol 282 | self.global_metrics_helper['pred_volume_'+str(c)] = pred_vol 283 | 284 | def clean(self): 285 | for model in self.models: 286 | model.clean() 287 | 288 | def fill_cross_validation(self): 289 | 290 | ds = raw_Dataset(join(environ['OV_DATA_BASE'], 'raw_data', self.data_name), 291 | prev_stages=self.prev_stages if hasattr(self, 'prev_stages') else None) 292 | pred_folder = join(environ['OV_DATA_BASE'], 'predictions', self.data_name, 293 | self.preprocessed_name, self.model_name, 'cross_validation') 294 | for i in tqdm(range(len(ds))): 295 | data_tpl = ds[i] 296 | filename = data_tpl['scan'] + '.nii.gz' 297 | if filename not in listdir(pred_folder): 298 | self.__call__(data_tpl) 299 | self.save_prediction(data_tpl, folder_name='cross_validation') 300 | -------------------------------------------------------------------------------- /src/ovseg/data/SegmentationDoubleBiasDataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from ovseg.data.utils import crop_and_pad_image 4 | from ovseg.utils.torch_np_utils import maybe_add_channel_dim 5 | from ovseg.utils.io import read_nii 6 | import os 7 | import nibabel as nib 8 | from time import sleep 9 | try: 10 | from tqdm import tqdm 11 | except ModuleNotFoundError: 12 | print('No tqdm found, using no pretty progressing bars') 13 | tqdm = lambda x: x 14 | 15 | 16 | def torch_resize(label, pred): 17 | 18 | pred_gpu = torch.from_numpy(pred[np.newaxis,np.newaxis]).cuda() 19 | 20 | size = label.shape 21 | if len(size) == 4: 22 | size = size[0] 23 | 24 | pred_rsz = torch.nn.functional.interpolate(pred_gpu, size) 25 | 26 | return pred_rsz[0,0].cpu().numpy() 27 | 28 | 29 | class SegmentationDoubleBiasBatchDataset(object): 30 | 31 | def __init__(self, vol_ds, patch_size, batch_size, epoch_len=250, 32 | n_bias1=1,n_bias2=1, prev_preds:list = [], 33 | augmentation=None, padded_patch_size=None, 34 | n_im_channels: int = 1, memmap='r', image_key='image', 35 | label_key='label', store_data_in_ram=False, return_fp16=True, n_max_volumes=None, 36 | bias1='fg', n_fg_classes=None, lb_classes=None, *args, **kwargs): 37 | self.vol_ds = vol_ds 38 | self.patch_size = np.array(patch_size) 39 | self.batch_size = batch_size 40 | self.epoch_len = epoch_len 41 | self.n_bias1 = n_bias1 42 | self.n_bias2 = n_bias2 43 | self.prev_preds = prev_preds 44 | self.augmentation = augmentation 45 | self.memmap = memmap 46 | self.image_key = image_key 47 | self.label_key = label_key 48 | self.store_data_in_ram = store_data_in_ram 49 | self.n_im_channels = n_im_channels 50 | self.return_fp16 = return_fp16 51 | self.bias1 = bias1 52 | self.n_fg_classes = n_fg_classes 53 | self.lb_classes = lb_classes 54 | 55 | if self.bias1 == 'cl_fg': 56 | assert isinstance(self.n_fg_classes, int) 57 | assert self.n_fg_classes > 0 58 | else: 59 | # does not need to be true, but makes our life easier 60 | self.n_fg_classes = 1 61 | 62 | self.dtype = 
np.float16 if self.return_fp16 else np.float32 63 | if n_max_volumes is None: 64 | self.n_volumes = len(self.vol_ds) 65 | else: 66 | self.n_volumes = np.min([n_max_volumes, len(self.vol_ds)]) 67 | 68 | if len(self.patch_size) == 2: 69 | self.twoD_patches = True 70 | self.patch_size = np.concatenate([[1], self.patch_size]) 71 | else: 72 | self.twoD_patches = False 73 | 74 | # overwrite default in case we're not using padding here 75 | if padded_patch_size is None: 76 | self.padded_patch_size = self.patch_size 77 | else: 78 | self.padded_patch_size = np.array(padded_patch_size) 79 | 80 | 81 | 82 | assert len(self.prev_preds) > 0, 'Need infos for previous predictions' 83 | self.path_to_previous_preds = os.path.join(os.environ['OV_DATA_BASE'], 84 | 'predictions', 85 | *self.prev_preds) 86 | 87 | self._maybe_store_data_in_ram() 88 | 89 | if len(args) > 0: 90 | print('Warning, got unused args: {}'.format(args)) 91 | if len(kwargs) > 0: 92 | print('Warning, got unused kwargs: {}'.format(kwargs)) 93 | 94 | 95 | 96 | def _get_bias_coords(self, labels, pred): 97 | 98 | if self.bias1 == 'fg': 99 | coords1 = [np.stack(np.where(labels[-1] > 0)).astype(np.int16)] 100 | elif self.bias1 == 'cl_fg': 101 | coords1 = [np.stack(np.where(labels[-1] == cl)).astype(np.int16) 102 | for cl in range(1, self.n_fg_classes + 1)] 103 | 104 | bin_lb = (labels[-1] > 0).astype(float) 105 | bin_pred = (pred > 0).astype(float) 106 | coords2 = np.stack(np.where(np.abs(bin_lb-bin_pred) > 0)).astype(np.int16) 107 | 108 | return [coords1, coords2] 109 | 110 | def _get_bias2_weight(self, labels, pred): 111 | 112 | lb = labels[-1] 113 | 114 | w = 0 115 | for i, cl in enumerate(self.lb_classes): 116 | bin_lb = (lb == i+1).astype(float) 117 | bin_pred = (pred ==cl).astype(float) 118 | w += 1 - (2*np.sum(bin_lb*bin_pred) + 1) / (np.sum(bin_lb + bin_pred) + 1) 119 | 120 | return w 121 | 122 | 123 | def _get_prev_pred(self, d): 124 | case = os.path.basename(d[self.label_key]).split('.')[0] 125 | pred, _, _ = read_nii(os.path.join(self.path_to_previous_preds, 126 | case+'.nii.gz')) 127 | return pred 128 | 129 | def _maybe_store_data_in_ram(self): 130 | # maybe cleaning first, just to be sure 131 | self._maybe_clean_stored_data() 132 | 133 | if self.store_data_in_ram: 134 | print('Store data in RAM.\n') 135 | self.data = [] 136 | sleep(1) 137 | for ind in tqdm(range(self.n_volumes)): 138 | path_dict = self.vol_ds.path_dicts[ind] 139 | labels = np.load(path_dict[self.label_key]).astype(np.uint8) 140 | 141 | labels = maybe_add_channel_dim(labels) 142 | 143 | im = np.load(path_dict[self.image_key]).astype(self.dtype) 144 | 145 | im = maybe_add_channel_dim(im) 146 | 147 | self.data.append((im, labels)) 148 | 149 | # store coords in ram 150 | print('Precomputing bias coordinates to store them in RAM.\n') 151 | self.coords_list = [] 152 | self.bias2_weights = [] 153 | self.contains_fg_list = [[] for _ in range(self.n_fg_classes)] 154 | sleep(1) 155 | for ind in tqdm(range(self.n_volumes)): 156 | if self.store_data_in_ram: 157 | labels = self.data[ind][1] 158 | else: 159 | labels = np.load(self.vol_ds.path_dicts[ind][self.label_key]) 160 | # ensure 4d array 161 | labels = maybe_add_channel_dim(labels) 162 | 163 | # get prev prediction in right shape 164 | pred = self._get_prev_pred(self.vol_ds.path_dicts[ind]) 165 | pred = torch_resize(labels, pred) 166 | 167 | coords = self._get_bias_coords(labels, pred) 168 | self.coords_list.append(coords) 169 | 170 | self.bias2_weights.append(self._get_bias2_weight(labels, pred)) 171 | 172 | # save which 
index has which fg class
173 |             for i in range(self.n_fg_classes):
174 |                 if coords[0][i].shape[1] > 0:
175 |                     self.contains_fg_list[i].append(ind)
176 |         print('Done')
177 | 
178 |         # print how many scans we have with which class
179 |         for c in range(self.n_fg_classes):
180 |             print('Found {} scans with fg {}'.format(len(self.contains_fg_list[c]), c))
181 | 
182 |         # available classes start from 0
183 |         self.availble_classes = [i for i, l in enumerate(self.contains_fg_list) if len(l) > 0]
184 | 
185 |         if len(self.availble_classes) < self.n_fg_classes:
186 |             missing_classes = [i+1 for i, l in enumerate(self.contains_fg_list) if len(l) == 0]
187 |             print('Warning! Some fg classes were not found in this dataset. '
188 |                   'Missing classes: {}'.format(missing_classes))
189 | 
190 |         # now make the bias2_weight a probability distribution
191 |         self.bias2_weights = np.array(self.bias2_weights)
192 |         self.bias2_weights /= np.sum(self.bias2_weights)
193 | 
194 |         print('')
195 | 
196 |         sleep(1)
197 | 
198 |     def _maybe_clean_stored_data(self):
199 |         # delete stuff we stored in RAM
200 |         # first for the full volumes
201 |         if hasattr(self, 'data'):
202 |             for tpl in self.data:
203 |                 for arr in tpl:
204 |                     del arr
205 |                 del tpl
206 |             del self.data
207 | 
208 |         # now for the bias coordinates
209 |         if hasattr(self, 'coords_list'):
210 |             for coord in self.coords_list:
211 |                 for crds in coord:
212 |                     del crds
213 |                 del coord
214 |             del self.coords_list
215 | 
216 |     def change_folders_and_keys(self, new_folders, new_keys):
217 |         # for progressive training, we might change the folder of image and label data during
218 |         # training if we've stored the rescaled volumes on the hard drive.
219 |         print('Dataloader: changing keys and folders')
220 |         print('new keys: ', *new_keys)
221 |         print('new folders: ', *new_folders)
222 |         print()
223 |         self.vol_ds.change_folders_and_keys(new_folders, new_keys)
224 |         self._maybe_store_data_in_ram()
225 | 
226 |     def _get_volume_tuple(self, ind=None):
227 | 
228 |         if ind is None:
229 |             ind = np.random.randint(self.n_volumes)
230 | 
231 |         load_from_ram = hasattr(self, 'data')
232 |         if load_from_ram:
233 |             load_from_ram = ind < len(self.data)
234 | 
235 |         if load_from_ram:
236 |             volumes = self.data[ind]
237 |         else:
238 |             path_dict = self.vol_ds.path_dicts[ind]
239 |             im = np.load(path_dict[self.image_key], 'r')
240 |             labels = np.load(path_dict[self.label_key], 'r')
241 |             volumes = [im, labels]
242 | 
243 |         # maybe add an additional axis
244 |         volumes = [maybe_add_channel_dim(vol) for vol in volumes]
245 | 
246 |         return volumes
247 | 
248 |     def _get_random_volume_ind(self, bias):
249 | 
250 |         if bias == 0:
251 |             return np.random.randint(self.n_volumes), -1
252 | 
253 |         elif bias == 1:
254 |             # when we do biased sampling we have to make sure that the
255 |             # volume we're sampling actually has fg
256 |             if len(self.availble_classes) > 0:
257 |                 cl = np.random.choice(self.availble_classes)
258 |                 return np.random.choice(self.contains_fg_list[cl]), cl
259 |             else:
260 |                 return np.random.randint(self.n_volumes), -1
261 | 
262 |         else:
263 |             return np.random.choice(list(range(self.n_volumes)), p=self.bias2_weights), -1
264 | 
265 |     def __len__(self):
266 |         return self.epoch_len * self.batch_size
267 | 
268 |     def __getitem__(self, index):
269 | 
270 | 
271 |         rel_indx = index % self.batch_size
272 |         if rel_indx < self.n_bias1:
273 |             bias = 1
274 |         elif rel_indx < self.n_bias1 + self.n_bias2:
275 |             bias = 2
276 |         else:
277 |             bias = 0
278 | 
279 |         ind, cl = self._get_random_volume_ind(bias)
280 |         volumes = self._get_volume_tuple(ind)
281 | shape = np.array(volumes[0].shape)[1:] 282 | 283 | if bias == 1 and cl >= 0: 284 | # let's get the list of bias coordinates 285 | # loading from RAM 286 | coords = self.coords_list[ind][0][cl] 287 | 288 | # pick a random item from the list and compute the upper left corner of the patch 289 | n_coords = coords.shape[1] 290 | coord = coords[:, np.random.randint(n_coords)] - self.patch_size//2 291 | elif bias == 2 and (self.coords_list[ind][1]).shape[1] > 0: 292 | # loading from RAM 293 | coords = self.coords_list[ind][1] 294 | 295 | # pick a random item from the list and compute the upper left corner of the patch 296 | n_coords = coords.shape[1] 297 | coord = coords[:, np.random.randint(n_coords)] - self.patch_size//2 298 | else: 299 | # random coordinate uniform from the whole volume 300 | coord = np.random.randint(np.maximum(shape - self.patch_size+1, 1)) 301 | coord = np.minimum(np.maximum(coord, 0), shape - self.patch_size) 302 | # now get the cropped and padded sample 303 | 304 | volume = np.concatenate([crop_and_pad_image(vol, 305 | coord, 306 | self.patch_size, 307 | self.padded_patch_size) for vol in volumes]) 308 | 309 | if self.twoD_patches: 310 | # remove z axis 311 | volume = volume[:, 0] 312 | 313 | if self.augmentation is not None: 314 | # in augmentation we need batch style arrays with an additional 315 | # axes for the batch size 316 | # the label augmentation expects integer valued predictions as input 317 | volume = self.augmentation(volume[np.newaxis])[0] 318 | 319 | return volume.astype(self.dtype) 320 | 321 | 322 | 323 | def SegmentationDoubleBiasDataloader(vol_ds, patch_size, batch_size, num_workers=None, 324 | pin_memory=True, epoch_len=250, *args, **kwargs): 325 | dataset = SegmentationDoubleBiasBatchDataset(vol_ds=vol_ds, 326 | patch_size=patch_size, 327 | batch_size=batch_size, 328 | epoch_len=epoch_len, *args, **kwargs) 329 | if num_workers is None: 330 | num_workers = 0 if os.name == 'nt' else 5 331 | worker_init_fn = lambda _: np.random.seed() 332 | sampler = torch.utils.data.SequentialSampler(range(batch_size * epoch_len)) 333 | return torch.utils.data.DataLoader(dataset, 334 | sampler=sampler, 335 | batch_size=batch_size, 336 | pin_memory=pin_memory, 337 | num_workers=num_workers, 338 | worker_init_fn=worker_init_fn) 339 | --------------------------------------------------------------------------------
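A note on the batch composition implemented in SegmentationDoubleBiasBatchDataset.__getitem__ above: within each batch, the first n_bias1 samples are centred on foreground coordinates (bias 1), the next n_bias2 samples come from volumes drawn with probability proportional to their previous-stage segmentation error (bias 2, via bias2_weights), and all remaining samples are uniform random crops. The minimal sketch below reproduces only this index-to-bias mapping; the parameter values are made up for illustration and are not part of the package.

def bias_for_index(index, batch_size=4, n_bias1=1, n_bias2=1):
    # mirrors the decision at the top of __getitem__: the position of the
    # sample within its batch decides which sampling strategy is used
    rel_indx = index % batch_size
    if rel_indx < n_bias1:
        return 1   # patch centred on a foreground voxel of a randomly picked class
    elif rel_indx < n_bias1 + n_bias2:
        return 2   # volume picked with probability proportional to its previous-stage error
    return 0       # uniform random crop from a random volume

print([bias_for_index(i) for i in range(8)])   # [1, 2, 0, 0, 1, 2, 0, 0]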
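Similarly, for the global metrics computed in SegmentationEnsemble._update_global_metrics further above: overlap, ground-truth volume and predicted volume are accumulated per class across all evaluated scans (in physical units, voxel count times voxel volume), and the global Dice is then 200 * overlap / (gt_volume + pred_volume), i.e. the usual Dice coefficient expressed as a percentage; recall and precision follow analogously. A small worked example with made-up accumulated volumes:

# hypothetical accumulated quantities for one class, e.g. in mm^3
ovlp, gt_vol, pred_vol = 120.0, 150.0, 160.0

dice = 200 * ovlp / (gt_vol + pred_vol)    # ~77.4 (percent)
recall = 100 * ovlp / gt_vol               # 80.0
precision = 100 * ovlp / pred_vol          # 75.0
print(dice, recall, precision)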