├── longiseg ├── __init__.py ├── imageio │ ├── __init__.py │ ├── readme.md │ ├── natural_image_reader_writer.py │ ├── reader_writer_registry.py │ └── tif_reader_writer.py ├── run │ ├── __init__.py │ └── load_pretrained_weights.py ├── ensembling │ └── __init__.py ├── evaluation │ ├── __init__.py │ ├── metrics │ │ ├── __init__.py │ │ ├── volumetric_metrics.py │ │ ├── distance_metrics.py │ │ └── detection_metrics.py │ └── accumulate_cv_results.py ├── inference │ ├── __init__.py │ ├── sliding_window_prediction.py │ └── export_utils.py ├── training │ ├── __init__.py │ ├── logging │ │ ├── __init__.py │ │ └── nnunet_logger.py │ ├── loss │ │ ├── __init__.py │ │ ├── robust_ce_loss.py │ │ └── deep_supervision.py │ ├── dataloading │ │ ├── __init__.py │ │ └── utils.py │ ├── lr_scheduler │ │ ├── __init__.py │ │ └── polylr.py │ ├── LongiSegTrainer │ │ ├── __init__.py │ │ └── variants │ │ │ ├── __init__.py │ │ │ ├── loss │ │ │ ├── __init__.py │ │ │ ├── nnUNetTrainerCELoss.py │ │ │ ├── nnUNetTrainerDiceLoss.py │ │ │ └── nnUNetTrainerTopkLoss.py │ │ │ ├── sampling │ │ │ ├── __init__.py │ │ │ └── nnUNetTrainer_probabilisticOversampling.py │ │ │ ├── benchmarking │ │ │ ├── __init__.py │ │ │ ├── nnUNetTrainerBenchmark_5epochs_noDataLoading.py │ │ │ └── nnUNetTrainerBenchmark_5epochs.py │ │ │ ├── competitions │ │ │ ├── __init__.py │ │ │ └── aortaseg24.py │ │ │ ├── longitudinal │ │ │ ├── __init__.py │ │ │ └── LongiSegTrainerDiffWeighting.py │ │ │ ├── lr_schedule │ │ │ ├── __init__.py │ │ │ └── nnUNetTrainerCosAnneal.py │ │ │ ├── optimizer │ │ │ ├── __init__.py │ │ │ ├── nnUNetTrainerAdam.py │ │ │ └── nnUNetTrainerAdan.py │ │ │ ├── data_augmentation │ │ │ ├── __init__.py │ │ │ ├── nnUNetTrainer_noDummy2DDA.py │ │ │ └── nnUNetTrainerNoDA.py │ │ │ ├── training_length │ │ │ ├── __init__.py │ │ │ ├── nnUNetTrainer_Xepochs_NoMirroring.py │ │ │ └── nnUNetTrainer_Xepochs.py │ │ │ └── network_architecture │ │ │ ├── __init__.py │ │ │ ├── nnUNetTrainerNoDeepSupervision.py │ │ │ └── nnUNetTrainerBN.py │ └── data_augmentation │ │ ├── __init__.py │ │ ├── custom_transforms │ │ ├── __init__.py │ │ ├── masking.py │ │ ├── region_based_training.py │ │ ├── transforms_for_dummy_2d.py │ │ ├── deep_supervision_donwsampling.py │ │ └── longi_transforms.py │ │ └── compute_initial_patch_size.py ├── utilities │ ├── __init__.py │ ├── label_handling │ │ └── __init__.py │ ├── plans_handling │ │ └── __init__.py │ ├── network_initialization.py │ ├── helpers.py │ ├── find_class_by_name.py │ ├── collate_outputs.py │ ├── crossval_split.py │ ├── ddp_allgather.py │ ├── default_n_proc_DA.py │ ├── json_export.py │ ├── get_network_from_plans.py │ ├── dataset_name_id_conversion.py │ ├── utils.py │ └── file_path_utilities.py ├── model_sharing │ ├── __init__.py │ ├── model_import.py │ ├── model_download.py │ └── entry_points.py ├── postprocessing │ └── __init__.py ├── preprocessing │ ├── __init__.py │ ├── cropping │ │ ├── __init__.py │ │ └── cropping.py │ ├── normalization │ │ ├── __init__.py │ │ ├── readme.md │ │ ├── map_channel_name_to_normalization.py │ │ └── default_normalization_schemes.py │ ├── preprocessors │ │ └── __init__.py │ └── resampling │ │ ├── __init__.py │ │ └── utils.py ├── dataset_conversion │ ├── __init__.py │ └── generate_dataset_json.py ├── experiment_planning │ ├── __init__.py │ ├── dataset_fingerprint │ │ └── __init__.py │ ├── experiment_planners │ │ ├── __init__.py │ │ ├── resampling │ │ │ └── __init__.py │ │ ├── residual_unets │ │ │ └── __init__.py │ │ └── network_topology.py │ └── plans_for_pretraining │ │ ├── __init__.py │ │ └── 
move_plans_between_datasets.py ├── configuration.py └── paths.py ├── documentation ├── __init__.py ├── assets │ ├── HI_Logo.png │ ├── LongiSeg.jpg │ ├── dkfz_logo.png │ ├── time_series.jpg │ ├── nnU-Net_overview.png │ ├── regions_vs_labels.png │ ├── scribble_example.png │ ├── amos2022_sparseseg10.png │ ├── HIDSS4Health_Logo_RGB.png │ ├── sparse_annotation_amos.png │ └── amos2022_sparseseg10_2d.png ├── run_inference_with_pretrained_models.md ├── setting_up_paths.md ├── dataset_format_inference.md ├── manual_data_splits.md ├── explanation_normalization.md ├── set_environment_variables.md ├── extending_nnunet.md ├── pretraining_and_finetuning.md ├── how_to_use_longiseg.md ├── region_based_training.md ├── installation_instructions.md └── ignore_label.md ├── setup.py ├── .github └── workflows │ └── codespell.yml ├── .gitignore └── pyproject.toml /longiseg/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /documentation/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /longiseg/imageio/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /longiseg/run/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /longiseg/ensembling/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /longiseg/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /longiseg/inference/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /longiseg/training/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /longiseg/utilities/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /longiseg/model_sharing/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /longiseg/postprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /longiseg/preprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /longiseg/training/logging/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /longiseg/training/loss/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /longiseg/dataset_conversion/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /longiseg/evaluation/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /longiseg/experiment_planning/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /longiseg/preprocessing/cropping/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /longiseg/training/dataloading/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /longiseg/training/lr_scheduler/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /longiseg/preprocessing/normalization/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /longiseg/preprocessing/preprocessors/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /longiseg/preprocessing/resampling/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /longiseg/training/LongiSegTrainer/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /longiseg/training/data_augmentation/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /longiseg/utilities/label_handling/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /longiseg/utilities/plans_handling/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /longiseg/training/LongiSegTrainer/variants/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /longiseg/experiment_planning/dataset_fingerprint/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /longiseg/experiment_planning/experiment_planners/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /longiseg/experiment_planning/plans_for_pretraining/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /longiseg/training/LongiSegTrainer/variants/loss/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /longiseg/training/LongiSegTrainer/variants/sampling/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /longiseg/training/LongiSegTrainer/variants/benchmarking/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /longiseg/training/LongiSegTrainer/variants/competitions/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /longiseg/training/LongiSegTrainer/variants/longitudinal/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /longiseg/training/LongiSegTrainer/variants/lr_schedule/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /longiseg/training/LongiSegTrainer/variants/optimizer/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /longiseg/training/data_augmentation/custom_transforms/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /longiseg/experiment_planning/experiment_planners/resampling/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /longiseg/training/LongiSegTrainer/variants/data_augmentation/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /longiseg/training/LongiSegTrainer/variants/training_length/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /longiseg/experiment_planning/experiment_planners/residual_unets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /longiseg/training/LongiSegTrainer/variants/network_architecture/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | if __name__ == "__main__": 4 | 
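    # no arguments are needed here: setuptools reads the package metadata from pyproject.toml (present in this repo)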
setuptools.setup() 5 | -------------------------------------------------------------------------------- /documentation/assets/HI_Logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MIC-DKFZ/LongiSeg/HEAD/documentation/assets/HI_Logo.png -------------------------------------------------------------------------------- /documentation/assets/LongiSeg.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MIC-DKFZ/LongiSeg/HEAD/documentation/assets/LongiSeg.jpg -------------------------------------------------------------------------------- /documentation/assets/dkfz_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MIC-DKFZ/LongiSeg/HEAD/documentation/assets/dkfz_logo.png -------------------------------------------------------------------------------- /documentation/assets/time_series.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MIC-DKFZ/LongiSeg/HEAD/documentation/assets/time_series.jpg -------------------------------------------------------------------------------- /documentation/assets/nnU-Net_overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MIC-DKFZ/LongiSeg/HEAD/documentation/assets/nnU-Net_overview.png -------------------------------------------------------------------------------- /documentation/assets/regions_vs_labels.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MIC-DKFZ/LongiSeg/HEAD/documentation/assets/regions_vs_labels.png -------------------------------------------------------------------------------- /documentation/assets/scribble_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MIC-DKFZ/LongiSeg/HEAD/documentation/assets/scribble_example.png -------------------------------------------------------------------------------- /documentation/assets/amos2022_sparseseg10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MIC-DKFZ/LongiSeg/HEAD/documentation/assets/amos2022_sparseseg10.png -------------------------------------------------------------------------------- /documentation/assets/HIDSS4Health_Logo_RGB.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MIC-DKFZ/LongiSeg/HEAD/documentation/assets/HIDSS4Health_Logo_RGB.png -------------------------------------------------------------------------------- /documentation/assets/sparse_annotation_amos.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MIC-DKFZ/LongiSeg/HEAD/documentation/assets/sparse_annotation_amos.png -------------------------------------------------------------------------------- /documentation/assets/amos2022_sparseseg10_2d.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MIC-DKFZ/LongiSeg/HEAD/documentation/assets/amos2022_sparseseg10_2d.png -------------------------------------------------------------------------------- /longiseg/model_sharing/model_import.py: 
-------------------------------------------------------------------------------- 1 | import zipfile 2 | 3 | from longiseg.paths import LongiSeg_results 4 | 5 | 6 | def install_model_from_zip_file(zip_file: str): 7 | with zipfile.ZipFile(zip_file, 'r') as zip_ref: 8 | zip_ref.extractall(LongiSeg_results) -------------------------------------------------------------------------------- /longiseg/imageio/readme.md: -------------------------------------------------------------------------------- 1 | - Derive your adapter from `BaseReaderWriter`. 2 | - Reimplement all abstract methods. 3 | - Make sure to support 2d and 3d input images (or raise some error). 4 | - Place it in this folder or nnU-Net won't find it! 5 | - Add it to LIST_OF_IO_CLASSES in `reader_writer_registry.py` 6 | 7 | Bam, you're done! -------------------------------------------------------------------------------- /longiseg/preprocessing/normalization/readme.md: -------------------------------------------------------------------------------- 1 | The channel_names entry in dataset.json only determines the normalization scheme. If you want to use a different scheme, 2 | you can simply: 3 | - create a new subclass of ImageNormalization 4 | - map your custom channel identifier to that subclass in channel_name_to_normalization_mapping 5 | - run plan and preprocess again with your custom normalization scheme -------------------------------------------------------------------------------- /longiseg/training/LongiSegTrainer/variants/competitions/aortaseg24.py: -------------------------------------------------------------------------------- 1 | from longiseg.training.LongiSegTrainer.variants.data_augmentation.nnUNetTrainerNoMirroring import nnUNetTrainer_onlyMirror01 2 | from longiseg.training.LongiSegTrainer.variants.data_augmentation.nnUNetTrainerDA5 import nnUNetTrainerDA5 3 | 4 | class nnUNetTrainer_onlyMirror01_DA5(nnUNetTrainer_onlyMirror01, nnUNetTrainerDA5): 5 | pass 6 | -------------------------------------------------------------------------------- /longiseg/configuration.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from longiseg.utilities.default_n_proc_DA import get_allowed_n_proc_DA 4 | 5 | default_num_processes = 8 if 'nnUNet_def_n_proc' not in os.environ else int(os.environ['nnUNet_def_n_proc']) 6 | 7 | ANISO_THRESHOLD = 3 # determines when a sample is considered anisotropic (3 means that the spacing in the low 8 | # resolution axis must be 3x as large as the next largest spacing) 9 | 10 | default_n_proc_DA = get_allowed_n_proc_DA() 11 | -------------------------------------------------------------------------------- /.github/workflows/codespell.yml: -------------------------------------------------------------------------------- 1 | --- 2 | name: Codespell 3 | 4 | on: 5 | push: 6 | branches: [master] 7 | pull_request: 8 | branches: [master] 9 | 10 | permissions: 11 | contents: read 12 | 13 | jobs: 14 | codespell: 15 | name: Check for spelling errors 16 | runs-on: ubuntu-latest 17 | 18 | steps: 19 | - name: Checkout 20 | uses: actions/checkout@v3 21 | - name: Codespell 22 | uses: codespell-project/actions-codespell@v2 23 | -------------------------------------------------------------------------------- /documentation/run_inference_with_pretrained_models.md: -------------------------------------------------------------------------------- 1 | # How to run inference with pretrained models 2 | **Important:** Pretrained weights from nnU-Net v1 are NOT compatible
with V2. You will need to retrain with the new 3 | version. But honestly, you already have a fully trained model with which you can run inference (in v1), so 4 | just continue using that! 5 | 6 | Not yet available for V2 :-( 7 | If you wish to run inference with pretrained models, check out the old nnU-Net for now. We are working on this full steam! 8 | -------------------------------------------------------------------------------- /longiseg/utilities/network_initialization.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class InitWeights_He(object): 5 | def __init__(self, neg_slope=1e-2): 6 | self.neg_slope = neg_slope 7 | 8 | def __call__(self, module): 9 | if isinstance(module, nn.Conv3d) or isinstance(module, nn.Conv2d) or isinstance(module, nn.ConvTranspose2d) or isinstance(module, nn.ConvTranspose3d): 10 | module.weight = nn.init.kaiming_normal_(module.weight, a=self.neg_slope) 11 | if module.bias is not None: 12 | module.bias = nn.init.constant_(module.bias, 0) 13 | -------------------------------------------------------------------------------- /longiseg/training/LongiSegTrainer/variants/network_architecture/nnUNetTrainerNoDeepSupervision.py: -------------------------------------------------------------------------------- 1 | from longiseg.training.LongiSegTrainer.nnUNetTrainerLongi import nnUNetTrainerNoLongi 2 | import torch 3 | 4 | 5 | class nnUNetTrainerNoDeepSupervision(nnUNetTrainerNoLongi): 6 | def __init__( 7 | self, 8 | plans: dict, 9 | configuration: str, 10 | fold: int, 11 | dataset_json: dict, 12 | device: torch.device = torch.device("cuda"), 13 | ): 14 | super().__init__(plans, configuration, fold, dataset_json, device) 15 | self.enable_deep_supervision = False 16 | -------------------------------------------------------------------------------- /longiseg/training/LongiSegTrainer/variants/longitudinal/LongiSegTrainerDiffWeighting.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from longiseg.training.LongiSegTrainer.LongiSegTrainer import LongiSegTrainer 4 | 5 | 6 | class LongiSegTrainerDiffWeighting(LongiSegTrainer): 7 | architecture_class_name = "LongiUNetDiffWeighting" 8 | 9 | 10 | class LongiSegTrainerDiffWeightingRP(LongiSegTrainerDiffWeighting): 11 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, 12 | device: torch.device = torch.device('cuda')): 13 | super().__init__(plans, configuration, fold, dataset_json, device) 14 | self.random_prior = True -------------------------------------------------------------------------------- /longiseg/training/LongiSegTrainer/variants/lr_schedule/nnUNetTrainerCosAnneal.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim.lr_scheduler import CosineAnnealingLR 3 | 4 | from longiseg.training.LongiSegTrainer.nnUNetTrainerLongi import nnUNetTrainerNoLongi 5 | 6 | 7 | class nnUNetTrainerCosAnneal(nnUNetTrainerNoLongi): 8 | def configure_optimizers(self): 9 | optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay, 10 | momentum=0.99, nesterov=True) 11 | lr_scheduler = CosineAnnealingLR(optimizer, T_max=self.num_epochs) 12 | return optimizer, lr_scheduler 13 | 14 | -------------------------------------------------------------------------------- /longiseg/utilities/helpers.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def softmax_helper_dim0(x: torch.Tensor) -> torch.Tensor: 5 | return torch.softmax(x, 0) 6 | 7 | 8 | def softmax_helper_dim1(x: torch.Tensor) -> torch.Tensor: 9 | return torch.softmax(x, 1) 10 | 11 | 12 | def empty_cache(device: torch.device): 13 | if device.type == 'cuda': 14 | torch.cuda.empty_cache() 15 | elif device.type == 'mps': 16 | from torch import mps 17 | mps.empty_cache() 18 | else: 19 | pass 20 | 21 | 22 | class dummy_context(object): 23 | def __enter__(self): 24 | pass 25 | 26 | def __exit__(self, exc_type, exc_val, exc_tb): 27 | pass 28 | -------------------------------------------------------------------------------- /longiseg/preprocessing/resampling/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | import longiseg 4 | from batchgenerators.utilities.file_and_folder_operations import join 5 | from longiseg.utilities.find_class_by_name import recursive_find_python_class 6 | 7 | 8 | def recursive_find_resampling_fn_by_name(resampling_fn: str) -> Callable: 9 | ret = recursive_find_python_class(join(longiseg.__path__[0], "preprocessing", "resampling"), resampling_fn, 10 | 'longiseg.preprocessing.resampling') 11 | if ret is None: 12 | raise RuntimeError("Unable to find resampling function named '%s'. Please make sure this fn is located in the " 13 | "longiseg.preprocessing.resampling module." % resampling_fn) 14 | else: 15 | return ret 16 | -------------------------------------------------------------------------------- /longiseg/training/lr_scheduler/polylr.py: -------------------------------------------------------------------------------- 1 | from torch.optim.lr_scheduler import _LRScheduler 2 | 3 | 4 | class PolyLRScheduler(_LRScheduler): 5 | def __init__(self, optimizer, initial_lr: float, max_steps: int, exponent: float = 0.9, current_step: int = None): 6 | self.optimizer = optimizer 7 | self.initial_lr = initial_lr 8 | self.max_steps = max_steps 9 | self.exponent = exponent 10 | self.ctr = 0 11 | super().__init__(optimizer, current_step if current_step is not None else -1) 12 | 13 | def step(self, current_step=None): 14 | if current_step is None or current_step == -1: 15 | current_step = self.ctr 16 | self.ctr += 1 17 | 18 | new_lr = self.initial_lr * (1 - current_step / self.max_steps) ** self.exponent 19 | for param_group in self.optimizer.param_groups: 20 | param_group['lr'] = new_lr 21 | -------------------------------------------------------------------------------- /longiseg/utilities/find_class_by_name.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import pkgutil 3 | 4 | from batchgenerators.utilities.file_and_folder_operations import join 5 | 6 | 7 | def recursive_find_python_class(folder: str, class_name: str, current_module: str): 8 | tr = None 9 | for importer, modname, ispkg in pkgutil.iter_modules([folder]): 10 | # print(modname, ispkg) 11 | if not ispkg: 12 | m = importlib.import_module(current_module + "." + modname) 13 | if hasattr(m, class_name): 14 | tr = getattr(m, class_name) 15 | break 16 | 17 | if tr is None: 18 | for importer, modname, ispkg in pkgutil.iter_modules([folder]): 19 | if ispkg: 20 | next_current_module = current_module + "." 
+ modname 21 | tr = recursive_find_python_class(join(folder, modname), class_name, current_module=next_current_module) 22 | if tr is not None: 23 | break 24 | return tr 25 | -------------------------------------------------------------------------------- /longiseg/utilities/collate_outputs.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import numpy as np 4 | 5 | 6 | def collate_outputs(outputs: List[dict]): 7 | """ 8 | used to collate default train_step and validation_step outputs. If you want something different then you gotta 9 | extend this 10 | 11 | we expect outputs to be a list of dictionaries where each of the dict has the same set of keys 12 | """ 13 | collated = {} 14 | for k in outputs[0].keys(): 15 | if np.isscalar(outputs[0][k]): 16 | collated[k] = [o[k] for o in outputs] 17 | elif isinstance(outputs[0][k], np.ndarray): 18 | collated[k] = np.vstack([o[k][None] for o in outputs]) 19 | elif isinstance(outputs[0][k], list): 20 | collated[k] = [item for o in outputs for item in o[k]] 21 | else: 22 | raise ValueError(f'Cannot collate input of type {type(outputs[0][k])}. ' 23 | f'Modify collate_outputs to add this functionality') 24 | return collated -------------------------------------------------------------------------------- /longiseg/evaluation/metrics/volumetric_metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def compute_tp_fp_fn_tn(mask_ref: np.ndarray, mask_pred: np.ndarray, ignore_mask: np.ndarray = None): 5 | if ignore_mask is None: 6 | use_mask = np.ones_like(mask_ref, dtype=bool) 7 | else: 8 | use_mask = ~ignore_mask 9 | tp = np.sum((mask_ref & mask_pred) & use_mask) 10 | fp = np.sum(((~mask_ref) & mask_pred) & use_mask) 11 | fn = np.sum((mask_ref & (~mask_pred)) & use_mask) 12 | tn = np.sum(((~mask_ref) & (~mask_pred)) & use_mask) 13 | return tp, fp, fn, tn 14 | 15 | 16 | def compute_volumetric_metrics(mask_ref, mask_pred, ignore_mask): 17 | tp, fp, fn, tn = compute_tp_fp_fn_tn(mask_ref, mask_pred, ignore_mask) 18 | dice = 2 * tp / (2 * tp + fp + fn) if tp + fp + fn > 0 else 1 19 | iou = tp / (tp + fp + fn) if tp + fp + fn > 0 else 1 20 | recall = tp / (tp + fn) if tp + fn > 0 else 1 21 | precision = tp / (tp + fp) if tp + fp > 0 else 1 22 | return dice, iou, recall, precision, tp, fp, fn, tn -------------------------------------------------------------------------------- /longiseg/training/data_augmentation/custom_transforms/masking.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from batchgenerators.transforms.abstract_transforms import AbstractTransform 4 | 5 | 6 | class MaskTransform(AbstractTransform): 7 | def __init__(self, apply_to_channels: List[int], mask_idx_in_seg: int = 0, set_outside_to: int = 0, 8 | data_key: str = "data", seg_key: str = "seg"): 9 | """ 10 | Sets everything outside the mask to 0. CAREFUL! outside is defined as < 0, not =0 (in the Mask)!!! 
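        :param apply_to_channels: indices of the data channels the mask is applied to
        :param mask_idx_in_seg: channel of the segmentation that contains the mask
        :param set_outside_to: value assigned to everything outside the mask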
11 | """ 12 | self.apply_to_channels = apply_to_channels 13 | self.seg_key = seg_key 14 | self.data_key = data_key 15 | self.set_outside_to = set_outside_to 16 | self.mask_idx_in_seg = mask_idx_in_seg 17 | 18 | def __call__(self, **data_dict): 19 | mask = data_dict[self.seg_key][:, self.mask_idx_in_seg] < 0 20 | for c in self.apply_to_channels: 21 | data_dict[self.data_key][:, c][mask] = self.set_outside_to 22 | return data_dict 23 | -------------------------------------------------------------------------------- /longiseg/preprocessing/normalization/map_channel_name_to_normalization.py: -------------------------------------------------------------------------------- 1 | from typing import Type 2 | 3 | from longiseg.preprocessing.normalization.default_normalization_schemes import CTNormalization, NoNormalization, \ 4 | ZScoreNormalization, RescaleTo01Normalization, RGBTo01Normalization, ImageNormalization 5 | 6 | channel_name_to_normalization_mapping = { 7 | 'ct': CTNormalization, 8 | 'nonorm': NoNormalization, 9 | 'zscore': ZScoreNormalization, 10 | 'rescale_to_0_1': RescaleTo01Normalization, 11 | 'rgb_to_0_1': RGBTo01Normalization 12 | } 13 | 14 | 15 | def get_normalization_scheme(channel_name: str) -> Type[ImageNormalization]: 16 | """ 17 | If we find the channel_name in channel_name_to_normalization_mapping return the corresponding normalization. If it is 18 | not found, use the default (ZScoreNormalization) 19 | """ 20 | norm_scheme = channel_name_to_normalization_mapping.get(channel_name.casefold()) 21 | if norm_scheme is None: 22 | norm_scheme = ZScoreNormalization 23 | # print('Using %s for image normalization' % norm_scheme.__name__) 24 | return norm_scheme 25 | -------------------------------------------------------------------------------- /longiseg/training/data_augmentation/compute_initial_patch_size.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def get_patch_size(final_patch_size, rot_x, rot_y, rot_z, scale_range): 5 | if isinstance(rot_x, (tuple, list)): 6 | rot_x = max(np.abs(rot_x)) 7 | if isinstance(rot_y, (tuple, list)): 8 | rot_y = max(np.abs(rot_y)) 9 | if isinstance(rot_z, (tuple, list)): 10 | rot_z = max(np.abs(rot_z)) 11 | rot_x = min(90 / 360 * 2. * np.pi, rot_x) 12 | rot_y = min(90 / 360 * 2. * np.pi, rot_y) 13 | rot_z = min(90 / 360 * 2. 
* np.pi, rot_z) 14 | from batchgenerators.augmentations.utils import rotate_coords_3d, rotate_coords_2d 15 | coords = np.array(final_patch_size) 16 | final_shape = np.copy(coords) 17 | if len(coords) == 3: 18 | final_shape = np.max(np.vstack((np.abs(rotate_coords_3d(coords, rot_x, 0, 0)), final_shape)), 0) 19 | final_shape = np.max(np.vstack((np.abs(rotate_coords_3d(coords, 0, rot_y, 0)), final_shape)), 0) 20 | final_shape = np.max(np.vstack((np.abs(rotate_coords_3d(coords, 0, 0, rot_z)), final_shape)), 0) 21 | elif len(coords) == 2: 22 | final_shape = np.max(np.vstack((np.abs(rotate_coords_2d(coords, rot_x)), final_shape)), 0) 23 | final_shape /= min(scale_range) 24 | return final_shape.astype(int) 25 | -------------------------------------------------------------------------------- /longiseg/training/loss/robust_ce_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, Tensor 3 | import numpy as np 4 | 5 | 6 | class RobustCrossEntropyLoss(nn.CrossEntropyLoss): 7 | """ 8 | this is just a compatibility layer because my target tensor is float and has an extra dimension 9 | 10 | input must be logits, not probabilities! 11 | """ 12 | def forward(self, input: Tensor, target: Tensor) -> Tensor: 13 | if target.ndim == input.ndim: 14 | assert target.shape[1] == 1 15 | target = target[:, 0] 16 | return super().forward(input, target.long()) 17 | 18 | 19 | class TopKLoss(RobustCrossEntropyLoss): 20 | """ 21 | input must be logits, not probabilities! 22 | """ 23 | def __init__(self, weight=None, ignore_index: int = -100, k: float = 10, label_smoothing: float = 0): 24 | self.k = k 25 | super(TopKLoss, self).__init__(weight, False, ignore_index, reduce=False, label_smoothing=label_smoothing) 26 | 27 | def forward(self, inp, target): 28 | target = target[:, 0].long() 29 | res = super(TopKLoss, self).forward(inp, target) 30 | num_voxels = np.prod(res.shape, dtype=np.int64) 31 | res, _ = torch.topk(res.view((-1, )), int(num_voxels * self.k / 100), sorted=False) 32 | return res.mean() 33 | -------------------------------------------------------------------------------- /longiseg/preprocessing/cropping/cropping.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.ndimage import binary_fill_holes 3 | from acvl_utils.cropping_and_padding.bounding_boxes import get_bbox_from_mask, bounding_box_to_slice 4 | 5 | 6 | def create_nonzero_mask(data): 7 | """ 8 | 9 | :param data: 10 | :return: the mask is True where the data is nonzero 11 | """ 12 | assert data.ndim in (3, 4), "data must have shape (C, X, Y, Z) or shape (C, X, Y)" 13 | nonzero_mask = data[0] != 0 14 | for c in range(1, data.shape[0]): 15 | nonzero_mask |= data[c] != 0 16 | return binary_fill_holes(nonzero_mask) 17 | 18 | 19 | def crop_to_nonzero(data, seg=None, bbox=None, nonzero_label=-1): 20 | """ 21 | 22 | :param data: 23 | :param seg: 24 | :param nonzero_label: this will be written into the segmentation map 25 | :return: 26 | """ 27 | nonzero_mask = create_nonzero_mask(data) 28 | if bbox is None: 29 | bbox = get_bbox_from_mask(nonzero_mask) 30 | slicer = bounding_box_to_slice(bbox) 31 | nonzero_mask = nonzero_mask[slicer][None] 32 | 33 | slicer = (slice(None), ) + slicer 34 | data = data[slicer] 35 | if seg is not None: 36 | seg = seg[slicer] 37 | seg[(seg == 0) & (~nonzero_mask)] = nonzero_label 38 | else: 39 | seg = np.where(nonzero_mask, np.int8(0), np.int8(nonzero_label)) 40 | return data, seg, 
bbox 41 | 42 | 43 | -------------------------------------------------------------------------------- /longiseg/utilities/crossval_split.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import numpy as np 4 | from sklearn.model_selection import KFold, GroupKFold 5 | 6 | 7 | def generate_crossval_split(train_identifiers: List[str], seed=12345, n_splits=5) -> List[dict[str, List[str]]]: 8 | splits = [] 9 | kfold = KFold(n_splits=n_splits, shuffle=True, random_state=seed) 10 | for i, (train_idx, test_idx) in enumerate(kfold.split(train_identifiers)): 11 | train_keys = np.array(train_identifiers)[train_idx] 12 | test_keys = np.array(train_identifiers)[test_idx] 13 | splits.append({}) 14 | splits[-1]['train'] = list(train_keys) 15 | splits[-1]['val'] = list(test_keys) 16 | return splits 17 | 18 | 19 | def generate_crossval_split_longi(train_patients, seed=12345, n_splits=5): 20 | splits = [] 21 | all_keys = [] 22 | groups = [] 23 | for idx, patient in enumerate(train_patients): 24 | all_keys = all_keys + train_patients[patient] 25 | groups = groups + [idx] * len(train_patients[patient]) 26 | groups = np.array(groups) 27 | group_kfold = GroupKFold(n_splits=n_splits) 28 | 29 | for i, (train_idx, test_idx) in enumerate(group_kfold.split(all_keys, groups=groups)): 30 | train_keys = np.array(all_keys)[train_idx] 31 | test_keys = np.array(all_keys)[test_idx] 32 | splits.append({}) 33 | splits[-1]['train'] = list(train_keys) 34 | splits[-1]['val'] = list(test_keys) 35 | return splits -------------------------------------------------------------------------------- /longiseg/training/data_augmentation/custom_transforms/region_based_training.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Union 2 | 3 | from batchgenerators.transforms.abstract_transforms import AbstractTransform 4 | import numpy as np 5 | 6 | 7 | class ConvertSegmentationToRegionsTransform(AbstractTransform): 8 | def __init__(self, regions: Union[List, Tuple], 9 | seg_key: str = "seg", output_key: str = "seg", seg_channel: int = 0): 10 | """ 11 | regions are tuple of tuples where each inner tuple holds the class indices that are merged into one region, 12 | example: 13 | regions= ((1, 2), (2, )) will result in 2 regions: one covering the region of labels 1&2 and the other just 2 14 | :param regions: 15 | :param seg_key: 16 | :param output_key: 17 | """ 18 | self.seg_channel = seg_channel 19 | self.output_key = output_key 20 | self.seg_key = seg_key 21 | self.regions = regions 22 | 23 | def __call__(self, **data_dict): 24 | seg = data_dict.get(self.seg_key) 25 | if seg is not None: 26 | b, c, *shape = seg.shape 27 | region_output = np.zeros((b, len(self.regions), *shape), dtype=bool) 28 | for region_id, region_labels in enumerate(self.regions): 29 | region_output[:, region_id] |= np.isin(seg[:, self.seg_channel], region_labels) 30 | data_dict[self.output_key] = region_output.astype(np.uint8, copy=False) 31 | return data_dict 32 | 33 | -------------------------------------------------------------------------------- /longiseg/training/loss/deep_supervision.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class DeepSupervisionWrapper(nn.Module): 6 | def __init__(self, loss, weight_factors=None): 7 | """ 8 | Wraps a loss function so that it can be applied to multiple outputs. 
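A typical use case is deep supervision, where weight_factors holds one weight per output resolution (e.g. (1, 0.5, 0.25, ...)).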
Forward accepts an arbitrary number of 9 | inputs. Each input is expected to be a tuple/list. Each tuple/list must have the same length. The loss is then 10 | applied to each entry like this: 11 | l = w0 * loss(input0[0], input1[0], ...) + w1 * loss(input0[1], input1[1], ...) + ... 12 | If weights are None, all w will be 1. 13 | """ 14 | super(DeepSupervisionWrapper, self).__init__() 15 | if weight_factors is not None: 16 | assert any([x != 0 for x in weight_factors]), "At least one weight factor should be != 0.0" 17 | self.weight_factors = tuple(weight_factors) 18 | else: 19 | self.weight_factors = None 20 | self.loss = loss 21 | 22 | def forward(self, *args): 23 | assert all([isinstance(i, (tuple, list)) for i in args]), \ 24 | f"all args must be either tuple or list, got {[type(i) for i in args]}" 25 | # we could check for equal lengths here as well, but we really shouldn't overdo it with checks because 26 | # this code is executed a lot of times! 27 | 28 | if self.weight_factors is None: 29 | weights = (1, ) * len(args[0]) 30 | else: 31 | weights = self.weight_factors 32 | 33 | return sum([weights[i] * self.loss(*inputs) for i, inputs in enumerate(zip(*args)) if weights[i] != 0.0]) 34 | -------------------------------------------------------------------------------- /documentation/setting_up_paths.md: -------------------------------------------------------------------------------- 1 | # Setting up Paths 2 | 3 | nnU-Net relies on environment variables to know where raw data, preprocessed data and trained model weights are stored. 4 | To use the full functionality of nnU-Net, the following three environment variables must be set: 5 | 6 | 1) `LongiSeg_raw`: This is where you place the raw datasets. This folder will have one subfolder for each dataset, named 7 | DatasetXXX_YYY where XXX is a 3-digit identifier (such as 001, 002, 043, 999, ...) and YYY is the (unique) 8 | dataset name. The datasets must be in nnU-Net format, see [here](dataset_format.md). 9 | 10 | Example tree structure: 11 | ``` 12 | LongiSeg_raw/Dataset001_NAME1 13 | ├── dataset.json 14 | ├── imagesTr 15 | │   ├── ... 16 | ├── imagesTs 17 | │   ├── ... 18 | └── labelsTr 19 | ├── ... 20 | LongiSeg_raw/Dataset002_NAME2 21 | ├── dataset.json 22 | ├── imagesTr 23 | │   ├── ... 24 | ├── imagesTs 25 | │   ├── ... 26 | └── labelsTr 27 | ├── ... 28 | ``` 29 | 30 | 2) `LongiSeg_preprocessed`: This is the folder where the preprocessed data will be saved. The data will also be read from 31 | this folder during training. It is important that this folder is located on a drive with low access latency and high 32 | throughput (such as an NVMe SSD (PCIe gen 3 is sufficient)). 33 | 34 | 3) `LongiSeg_results`: This specifies where nnU-Net will save the model weights. If pretrained models are downloaded, this 35 | is where it will save them. 36 | 37 | ### How to set environment variables 38 | See [here](set_environment_variables.md).
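
To quickly check that all three variables are visible to your Python environment, a minimal sketch like the following can help (it only prints what is currently set):

```python
import os

for var in ("LongiSeg_raw", "LongiSeg_preprocessed", "LongiSeg_results"):
    print(f"{var} = {os.environ.get(var, 'NOT SET')}")
```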
-------------------------------------------------------------------------------- /longiseg/evaluation/metrics/distance_metrics.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import numpy as np 3 | from scipy.spatial import KDTree 4 | from scipy import ndimage 5 | 6 | 7 | def SSD(gt: np.ndarray, pred: np.ndarray, spacing: tuple[float, ...] = (1., 1., 1.), ignore_mask: Optional[np.ndarray] = None) -> tuple[float, float]: 8 | # slightly adapted from https://github.com/emrekavur/CHAOS-evaluation/blob/master/Python/CHAOSmetrics.py 9 | gt_mask = gt if ignore_mask is None else gt & ~ignore_mask 10 | pred_mask = pred if ignore_mask is None else pred & ~ignore_mask 11 | gt_sum = np.sum(gt_mask) 12 | pred_sum = np.sum(pred_mask) 13 | if gt_sum == 0 or pred_sum == 0: 14 | return (0, 0) if gt_sum == pred_sum else (1000, 1000) # maximum value chosen to fit the largest volumes 15 | 16 | struct = ndimage.generate_binary_structure(3, 1) 17 | spacing = np.array(spacing) 18 | 19 | gt_border = gt_mask ^ ndimage.binary_erosion(gt_mask, structure=struct, border_value=1) 20 | gt_border_voxels = np.array(np.where(gt_border)).T * spacing 21 | 22 | pred_border = pred_mask ^ ndimage.binary_erosion(pred_mask, structure=struct, border_value=1) 23 | pred_border_voxels = np.array(np.where(pred_border)).T * spacing 24 | 25 | tree_ref = KDTree(gt_border_voxels) 26 | dist_seg_to_ref, _ = tree_ref.query(pred_border_voxels) 27 | tree_seg = KDTree(pred_border_voxels) 28 | dist_ref_to_seg, _ = tree_seg.query(gt_border_voxels) 29 | 30 | assd = (dist_seg_to_ref.sum() + dist_ref_to_seg.sum()) / (len(dist_seg_to_ref) + len(dist_ref_to_seg)) 31 | hd95 = np.percentile(np.concatenate((dist_seg_to_ref, dist_ref_to_seg)), 95) 32 | return assd, hd95 -------------------------------------------------------------------------------- /documentation/dataset_format_inference.md: -------------------------------------------------------------------------------- 1 | # Data format for Inference 2 | Read the documentation on the overall [data format](dataset_format.md) first! 3 | 4 | The data format for inference must match the one used for the raw data (**specifically, the images must be in exactly 5 | the same format as in the imagesTr folder**). As before, the filenames must start with a 6 | unique identifier, followed by a 4-digit modality identifier. Here is an example for two different datasets: 7 | 8 | 1) Task005_Prostate: 9 | 10 | This task has 2 modalities, so the files in the input folder must look like this: 11 | 12 | input_folder 13 | ├── prostate_03_0000.nii.gz 14 | ├── prostate_03_0001.nii.gz 15 | ├── prostate_05_0000.nii.gz 16 | ├── prostate_05_0001.nii.gz 17 | ├── prostate_08_0000.nii.gz 18 | ├── prostate_08_0001.nii.gz 19 | ├── ... 20 | 21 | _0000 has to be the T2 image and _0001 has to be the ADC image (as specified by 'channel_names' in the 22 | dataset.json), exactly the same as was used for training. 23 | 24 | 2) Task002_Heart: 25 | 26 | imagesTs 27 | ├── la_001_0000.nii.gz 28 | ├── la_002_0000.nii.gz 29 | ├── la_006_0000.nii.gz 30 | ├── ... 31 | 32 | Task002 only has one modality, so each case only has one _0000.nii.gz file. 33 | 34 | 35 | The segmentations in the output folder will be named {CASE_IDENTIFIER}.nii.gz (omitting the modality identifier). 36 | 37 | Remember that the file format used for inference (.nii.gz in this example) must be the same as was used for training 38 | (and as was specified in 'file_ending' in the dataset.json)!
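
As a sanity check before running inference, a small helper script can verify that every file in the input folder follows this naming scheme. This is only a sketch; `input_folder` and `file_ending` are placeholders you need to adjust:

```python
import os
import re

input_folder = "input_folder"  # hypothetical path, adjust to your data
file_ending = ".nii.gz"        # must match 'file_ending' in dataset.json

# filenames must end in a 4-digit modality identifier plus the file ending
pattern = re.compile(r".+_\d{4}" + re.escape(file_ending) + r"$")
for fname in sorted(os.listdir(input_folder)):
    if not pattern.match(fname):
        print(f"unexpected file name: {fname}")
```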
39 | -------------------------------------------------------------------------------- /longiseg/training/LongiSegTrainer/variants/data_augmentation/nnUNetTrainer_noDummy2DDA.py: -------------------------------------------------------------------------------- 1 | from longiseg.training.data_augmentation.compute_initial_patch_size import get_patch_size 2 | from longiseg.training.LongiSegTrainer.nnUNetTrainerLongi import nnUNetTrainerNoLongi 3 | import numpy as np 4 | 5 | 6 | class nnUNetTrainer_noDummy2DDA(nnUNetTrainerNoLongi): 7 | def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self): 8 | do_dummy_2d_data_aug = False 9 | 10 | patch_size = self.configuration_manager.patch_size 11 | dim = len(patch_size) 12 | if dim == 2: 13 | if max(patch_size) / min(patch_size) > 1.5: 14 | rotation_for_DA = (-15. / 360 * 2. * np.pi, 15. / 360 * 2. * np.pi) 15 | else: 16 | rotation_for_DA = (-180. / 360 * 2. * np.pi, 180. / 360 * 2. * np.pi) 17 | mirror_axes = (0, 1) 18 | elif dim == 3: 19 | rotation_for_DA = (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi) 20 | mirror_axes = (0, 1, 2) 21 | else: 22 | raise RuntimeError() 23 | 24 | # todo this function is stupid. It doesn't even use the correct scale range (we keep things as they were in the 25 | # old nnunet for now) 26 | initial_patch_size = get_patch_size(patch_size[-dim:], 27 | rotation_for_DA, 28 | rotation_for_DA, 29 | rotation_for_DA, 30 | (0.85, 1.25)) 31 | 32 | self.print_to_log_file(f'do_dummy_2d_data_aug: {do_dummy_2d_data_aug}') 33 | self.inference_allowed_mirroring_axes = mirror_axes 34 | 35 | return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | local_settings.py 55 | 56 | # Flask stuff: 57 | instance/ 58 | .webassets-cache 59 | 60 | # Scrapy stuff: 61 | .scrapy 62 | 63 | # Sphinx documentation 64 | docs/_build/ 65 | 66 | # PyBuilder 67 | target/ 68 | 69 | # IPython Notebook 70 | .ipynb_checkpoints 71 | 72 | # pyenv 73 | .python-version 74 | 75 | # celery beat schedule file 76 | celerybeat-schedule 77 | 78 | # dotenv 79 | .env 80 | 81 | # virtualenv 82 | venv/ 83 | ENV/ 84 | 85 | # Spyder project settings 86 | .spyderproject 87 | 88 | # Rope project settings 89 | .ropeproject 90 | 91 | *.memmap 92 | *.png 93 | *.zip 94 | *.npz 95 | *.npy 96 | *.jpg 97 | *.jpeg 98 | .idea 99 | *.txt 100 | .idea/* 101 | *.png 102 | *.nii.gz 103 | *.nii 104 | *.tif 105 | *.bmp 106 | *.pkl 107 | *.xml 108 | *.pkl 109 | *.pdf 110 | *.png 111 | *.jpg 112 | *.jpeg 113 | 114 | *.model 115 | 116 | !documentation/assets/* -------------------------------------------------------------------------------- /longiseg/training/LongiSegTrainer/variants/data_augmentation/nnUNetTrainerNoDA.py: -------------------------------------------------------------------------------- 1 | from typing import Union, Tuple, List 2 | 3 | import numpy as np 4 | from batchgeneratorsv2.helpers.scalar_type import RandomScalar 5 | from batchgeneratorsv2.transforms.base.basic_transform import BasicTransform 6 | 7 | from longiseg.training.LongiSegTrainer.nnUNetTrainerLongi import nnUNetTrainerNoLongi 8 | 9 | 10 | class nnUNetTrainerNoDA(nnUNetTrainerNoLongi): 11 | @staticmethod 12 | def get_training_transforms( 13 | patch_size: Union[np.ndarray, Tuple[int]], 14 | rotation_for_DA: RandomScalar, 15 | deep_supervision_scales: Union[List, Tuple, None], 16 | mirror_axes: Tuple[int, ...], 17 | do_dummy_2d_data_aug: bool, 18 | use_mask_for_norm: List[bool] = None, 19 | is_cascaded: bool = False, 20 | foreground_labels: Union[Tuple[int, ...], List[int]] = None, 21 | regions: List[Union[List[int], Tuple[int, ...], int]] = None, 22 | ignore_label: int = None, 23 | ) -> BasicTransform: 24 | return nnUNetTrainerNoLongi.get_validation_transforms(deep_supervision_scales, is_cascaded, foreground_labels, 25 | regions, ignore_label) 26 | 27 | def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self): 28 | # we need to disable mirroring here so that no mirroring will be applied in inference!
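        # (setting self.inference_allowed_mirroring_axes to None below is what actually switches off test-time mirroring)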
29 | rotation_for_DA, do_dummy_2d_data_aug, _, _ = \ 30 | super().configure_rotation_dummyDA_mirroring_and_inital_patch_size() 31 | mirror_axes = None 32 | self.inference_allowed_mirroring_axes = None 33 | initial_patch_size = self.configuration_manager.patch_size 34 | return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes 35 | 36 | -------------------------------------------------------------------------------- /longiseg/training/LongiSegTrainer/variants/loss/nnUNetTrainerCELoss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from longiseg.training.loss.deep_supervision import DeepSupervisionWrapper 3 | from longiseg.training.LongiSegTrainer.nnUNetTrainerLongi import nnUNetTrainerNoLongi 4 | from longiseg.training.loss.robust_ce_loss import RobustCrossEntropyLoss 5 | import numpy as np 6 | 7 | 8 | class nnUNetTrainerCELoss(nnUNetTrainerNoLongi): 9 | def _build_loss(self): 10 | assert not self.label_manager.has_regions, "regions not supported by this trainer" 11 | loss = RobustCrossEntropyLoss( 12 | weight=None, ignore_index=self.label_manager.ignore_label if self.label_manager.has_ignore_label else -100 13 | ) 14 | 15 | # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases 16 | # this gives higher resolution outputs more weight in the loss 17 | if self.enable_deep_supervision: 18 | deep_supervision_scales = self._get_deep_supervision_scales() 19 | weights = np.array([1 / (2**i) for i in range(len(deep_supervision_scales))]) 20 | weights[-1] = 0 21 | 22 | # the lowest resolution output is not used (weight 0). Normalize weights so that they sum to 1 23 | weights = weights / weights.sum() 24 | # now wrap the loss 25 | loss = DeepSupervisionWrapper(loss, weights) 26 | return loss 27 | 28 | 29 | class nnUNetTrainerCELoss_5epochs(nnUNetTrainerCELoss): 30 | def __init__( 31 | self, 32 | plans: dict, 33 | configuration: str, 34 | fold: int, 35 | dataset_json: dict, 36 | device: torch.device = torch.device("cuda"), 37 | ): 38 | """used for debugging plans etc""" 39 | super().__init__(plans, configuration, fold, dataset_json, device) 40 | self.num_epochs = 5 41 | -------------------------------------------------------------------------------- /longiseg/training/LongiSegTrainer/variants/network_architecture/nnUNetTrainerBN.py: -------------------------------------------------------------------------------- 1 | from typing import Union, Tuple, List 2 | from dynamic_network_architectures.building_blocks.helper import get_matching_batchnorm 3 | from torch import nn 4 | 5 | from longiseg.training.LongiSegTrainer.nnUNetTrainerLongi import nnUNetTrainerNoLongi 6 | 7 | 8 | class nnUNetTrainerBN(nnUNetTrainerNoLongi): 9 | @staticmethod 10 | def build_network_architecture(architecture_class_name: str, 11 | arch_init_kwargs: dict, 12 | arch_init_kwargs_req_import: Union[List[str], Tuple[str, ...]], 13 | num_input_channels: int, 14 | num_output_channels: int, 15 | enable_deep_supervision: bool = True) -> nn.Module: 16 | 17 | if 'norm_op' not in arch_init_kwargs.keys(): 18 | raise RuntimeError("'norm_op' not found in arch_init_kwargs. This does not look like an architecture " 19 | "I can hack BN into. This trainer only works with default nnU-Net architectures.") 20 | 21 | from pydoc import locate 22 | conv_op = locate(arch_init_kwargs['conv_op']) 23 | bn_class = get_matching_batchnorm(conv_op) 24 | arch_init_kwargs['norm_op'] = bn_class.__module__ + '.'
+ bn_class.__name__ 25 | arch_init_kwargs['norm_op_kwargs'] = {'eps': 1e-5, 'affine': True} 26 | 27 | return nnUNetTrainerNoLongi.build_network_architecture(architecture_class_name, 28 | arch_init_kwargs, 29 | arch_init_kwargs_req_import, 30 | num_input_channels, 31 | num_output_channels, enable_deep_supervision) 32 | 33 | -------------------------------------------------------------------------------- /longiseg/utilities/ddp_allgather.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from typing import Any, Optional, Tuple 15 | 16 | import torch 17 | from torch import distributed 18 | 19 | 20 | def print_if_rank0(*args): 21 | if distributed.get_rank() == 0: 22 | print(*args) 23 | 24 | 25 | class AllGatherGrad(torch.autograd.Function): 26 | # stolen from pytorch lightning 27 | @staticmethod 28 | def forward( 29 | ctx: Any, 30 | tensor: torch.Tensor, 31 | group: Optional["torch.distributed.ProcessGroup"] = None, 32 | ) -> torch.Tensor: 33 | ctx.group = group 34 | 35 | gathered_tensor = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())] 36 | 37 | torch.distributed.all_gather(gathered_tensor, tensor, group=group) 38 | gathered_tensor = torch.stack(gathered_tensor, dim=0) 39 | 40 | return gathered_tensor 41 | 42 | @staticmethod 43 | def backward(ctx: Any, *grad_output: torch.Tensor) -> Tuple[torch.Tensor, None]: 44 | grad_output = torch.cat(grad_output) 45 | 46 | torch.distributed.all_reduce(grad_output, op=torch.distributed.ReduceOp.SUM, async_op=False, group=ctx.group) 47 | 48 | return grad_output[torch.distributed.get_rank()], None 49 | 50 | -------------------------------------------------------------------------------- /longiseg/utilities/default_n_proc_DA.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import os 3 | 4 | 5 | def get_allowed_n_proc_DA(): 6 | """ 7 | This function is used to set the number of processes used on different Systems. It is specific to our cluster 8 | infrastructure at DKFZ. You can modify it to suit your needs. Everything is allowed. 9 | 10 | IMPORTANT: if the environment variable nnUNet_n_proc_DA is set it will overwrite anything in this script 11 | (see first line). 12 | 13 | Interpret the output as the number of processes used for data augmentation PER GPU. 14 | 15 | The way it is implemented here is simply a look up table. We know the hostnames, CPU and GPU configurations of our 16 | systems and set the numbers accordingly. For example, a system with 4 GPUs and 48 threads can use 12 threads per 17 | GPU without overloading the CPU (technically 11 because we have a main process as well), so that's what we use.
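Unknown hostnames fall back to a default of 12 processes, and the result is always capped at os.cpu_count().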
18 | """ 19 | 20 | if 'nnUNet_n_proc_DA' in os.environ.keys(): 21 | use_this = int(os.environ['nnUNet_n_proc_DA']) 22 | else: 23 | hostname = subprocess.getoutput('hostname') 24 | if hostname in ['Fabian', ]: 25 | use_this = 12 26 | elif hostname in ['hdf19-gpu16', 'hdf19-gpu17', 'hdf19-gpu18', 'hdf19-gpu19', 'e230-AMDworkstation']: 27 | use_this = 16 28 | elif hostname.startswith('e230-dgx1'): 29 | use_this = 10 30 | elif hostname.startswith('hdf18-gpu') or hostname.startswith('e132-comp'): 31 | use_this = 16 32 | elif hostname.startswith('e230-dgx2'): 33 | use_this = 6 34 | elif hostname.startswith('e230-dgxa100-'): 35 | use_this = 28 36 | elif hostname.startswith('lsf22-gpu'): 37 | use_this = 28 38 | elif hostname.startswith('hdf19-gpu') or hostname.startswith('e071-gpu'): 39 | use_this = 12 40 | else: 41 | use_this = 12 # default value 42 | 43 | use_this = min(use_this, os.cpu_count()) 44 | return use_this 45 | -------------------------------------------------------------------------------- /longiseg/model_sharing/model_download.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import requests 4 | from batchgenerators.utilities.file_and_folder_operations import join, isfile 5 | from time import time 6 | from longiseg.model_sharing.model_import import install_model_from_zip_file 7 | from longiseg.paths import LongiSeg_results 8 | from tqdm import tqdm 9 | 10 | 11 | def download_and_install_from_url(url): 12 | assert LongiSeg_results is not None, "Cannot install model because LongiSeg_results is not " \ 13 | "set (environment variable missing, see " \ 14 | "installation instructions)" 15 | print('Downloading pretrained model from url:', url) 16 | import http.client 17 | http.client.HTTPConnection._http_vsn = 10 18 | http.client.HTTPConnection._http_vsn_str = 'HTTP/1.0' 19 | 20 | import os 21 | home = os.path.expanduser('~') 22 | random_number = int(time() * 1e7) 23 | tempfile = join(home, f'.nnunetdownload_{str(random_number)}') 24 | 25 | try: 26 | download_file(url=url, local_filename=tempfile, chunk_size=8192 * 16) 27 | print("Download finished. Extracting...") 28 | install_model_from_zip_file(tempfile) 29 | print("Done") 30 | except Exception as e: 31 | raise e 32 | finally: 33 | if isfile(tempfile): 34 | os.remove(tempfile) 35 | 36 | 37 | def download_file(url: str, local_filename: str, chunk_size: Optional[int] = 8192 * 16) -> str: 38 | # borrowed from https://stackoverflow.com/questions/16694907/download-large-file-in-python-with-requests 39 | # NOTE the stream=True parameter below 40 | with requests.get(url, stream=True, timeout=100) as r: 41 | r.raise_for_status() 42 | with tqdm.wrapattr(open(local_filename, 'wb'), "write", total=int(r.headers.get("Content-Length"))) as f: 43 | for chunk in r.iter_content(chunk_size=chunk_size): 44 | f.write(chunk) 45 | return local_filename 46 | 47 | 48 | -------------------------------------------------------------------------------- /documentation/manual_data_splits.md: -------------------------------------------------------------------------------- 1 | # How to generate custom splits in nnU-Net 2 | 3 | Sometimes, the default 5-fold cross-validation split by nnU-Net does not fit a project. Maybe you want to run 3-fold 4 | cross-validation instead? Or maybe your training cases cannot be split randomly and require careful stratification. 5 | Fear not, for nnU-Net has got you covered (it really can do anything <3).
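Once you know what the file looks like (explained below), writing one takes only a few lines. A minimal sketch for a
3-fold split (the case identifiers are made up):

```python
from batchgenerators.utilities.file_and_folder_operations import save_json

# hypothetical case identifiers; stratify these however your project requires
all_cases = [f'case_{i:03d}' for i in range(12)]

splits = []
for fold in range(3):
    val = all_cases[fold::3]  # every third case goes to validation
    train = [c for c in all_cases if c not in val]
    splits.append({'train': train, 'val': val})

save_json(splits, 'splits_final.json')  # place this file as described below
```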
6 | 7 | The splits nnU-Net uses are generated in the `do_split` function of nnUNetTrainer. This function will first look for 8 | existing splits, stored as a file, and if no split exists it will create one. So if you wish to influence the split, 9 | manually creating a split file that will then be recognized and used is the way to go! 10 | 11 | The split file is located in the `LongiSeg_preprocessed/DATASETXXX_NAME` folder. So it is best practice to first 12 | populate this folder by running `nnUNetv2_plan_and_preprocess`. 13 | 14 | Splits are stored as a .json file. They are a simple Python list. The length of that list is the number of splits it 15 | contains (so it's 5 in the default nnU-Net). Each list entry is a dictionary with keys 'train' and 'val'. Values are 16 | again simply lists with the case identifiers of each set. To illustrate this, I am just messing with the Dataset002 17 | file as an example: 18 | 19 | ```commandline 20 | In [1]: from batchgenerators.utilities.file_and_folder_operations import load_json 21 | 22 | In [2]: splits = load_json('splits_final.json') 23 | 24 | In [3]: len(splits) 25 | Out[3]: 5 26 | 27 | In [4]: splits[0].keys() 28 | Out[4]: dict_keys(['train', 'val']) 29 | 30 | In [5]: len(splits[0]['train']) 31 | Out[5]: 16 32 | 33 | In [6]: len(splits[0]['val']) 34 | Out[6]: 4 35 | 36 | In [7]: print(splits[0]) 37 | {'train': ['la_003', 'la_004', 'la_005', 'la_009', 'la_010', 'la_011', 'la_014', 'la_017', 'la_018', 'la_019', 'la_020', 'la_022', 'la_023', 'la_026', 'la_029', 'la_030'], 38 | 'val': ['la_007', 'la_016', 'la_021', 'la_024']} 39 | ``` 40 | 41 | If you are still not sure what splits are supposed to look like, simply download some reference dataset from the 42 | [Medical Decathlon](http://medicaldecathlon.com/), start some training (to generate the splits) and manually inspect 43 | the .json file with your text editor of choice! 44 | 45 | In order to generate your custom splits, all you need to do is reproduce the data structure explained above and save it as 46 | `splits_final.json` in the `LongiSeg_preprocessed/DATASETXXX_NAME` folder. Then use `nnUNetv2_train` etc. as usual. -------------------------------------------------------------------------------- /longiseg/paths.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | 15 | import os 16 | 17 | """ 18 | PLEASE READ paths.md FOR INFORMATION TO HOW TO SET THIS UP 19 | """ 20 | 21 | LongiSeg_raw = os.environ.get('LongiSeg_raw') 22 | LongiSeg_preprocessed = os.environ.get('LongiSeg_preprocessed') 23 | LongiSeg_results = os.environ.get('LongiSeg_results') 24 | 25 | if LongiSeg_raw is None: 26 | print("Could not find LongiSeg_raw environment variable, falling back to nnUNet_raw") 27 | LongiSeg_raw = os.environ.get('nnUNet_raw') 28 | if LongiSeg_preprocessed is None: 29 | print("Could not find LongiSeg_preprocessed environment variable, falling back to nnUNet_preprocessed") 30 | LongiSeg_preprocessed = os.environ.get('nnUNet_preprocessed') 31 | if LongiSeg_results is None: 32 | print("Could not find LongiSeg_results environment variable, falling back to nnUNet_results") 33 | LongiSeg_results = os.environ.get('nnUNet_results') 34 | 35 | if LongiSeg_raw is None: 36 | print("LongiSeg_raw is not defined and LongiSeg can only be used on data for which preprocessed files " 37 | "are already present on your system. LongiSeg cannot be used for experiment planning and preprocessing like " 38 | "this. If this is not intended, please read documentation/setting_up_paths.md for information on how to set " 39 | "this up properly.") 40 | 41 | if LongiSeg_preprocessed is None: 42 | print("LongiSeg_preprocessed is not defined and LongiSeg can not be used for preprocessing " 43 | "or training. If this is not intended, please read documentation/setting_up_paths.md for information on how " 44 | "to set this up.") 45 | 46 | if LongiSeg_results is None: 47 | print("LongiSeg_results is not defined and LongiSeg cannot be used for training or " 48 | "inference. If this is not intended behavior, please read documentation/setting_up_paths.md for information " 49 | "on how to set this up.") 50 | -------------------------------------------------------------------------------- /longiseg/utilities/json_export.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Iterable 2 | 3 | import numpy as np 4 | import torch 5 | 6 | 7 | def recursive_fix_for_json_export(my_dict: dict): 8 | # json is ... a very nice thing to have 9 | # 'cannot serialize object of type bool_/int64/float64'. Apart from that of course... 10 | keys = list(my_dict.keys()) # cannot iterate over keys() if we change keys.... 
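    # JSON cannot serialize numpy scalar types: the loop below first converts numpy integer
    # keys to plain int, then recursively fixes the corresponding values in place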
11 | for k in keys: 12 | if isinstance(k, (np.int64, np.int32, np.int8, np.uint8)): 13 | tmp = my_dict[k] 14 | del my_dict[k] 15 | my_dict[int(k)] = tmp 16 | del tmp 17 | k = int(k) 18 | 19 | if isinstance(my_dict[k], dict): 20 | recursive_fix_for_json_export(my_dict[k]) 21 | elif isinstance(my_dict[k], np.ndarray): 22 | assert my_dict[k].ndim == 1, 'only 1d arrays are supported' 23 | my_dict[k] = fix_types_iterable(my_dict[k], output_type=list) 24 | elif isinstance(my_dict[k], (np.bool_,)): 25 | my_dict[k] = bool(my_dict[k]) 26 | elif isinstance(my_dict[k], (np.int64, np.int32, np.int8, np.uint8)): 27 | my_dict[k] = int(my_dict[k]) 28 | elif isinstance(my_dict[k], (np.float32, np.float64, np.float16)): 29 | my_dict[k] = float(my_dict[k]) 30 | elif isinstance(my_dict[k], list): 31 | my_dict[k] = fix_types_iterable(my_dict[k], output_type=type(my_dict[k])) 32 | elif isinstance(my_dict[k], tuple): 33 | my_dict[k] = fix_types_iterable(my_dict[k], output_type=tuple) 34 | elif isinstance(my_dict[k], torch.device): 35 | my_dict[k] = str(my_dict[k]) 36 | else: 37 | pass # pray it can be serialized 38 | 39 | 40 | def fix_types_iterable(iterable, output_type): 41 | # this sh!t is hacky as hell and will break if you use it for anything outside nnunet. Keep your hands off of this. 42 | out = [] 43 | for i in iterable: 44 | if type(i) in (np.int64, np.int32, np.int8, np.uint8): 45 | out.append(int(i)) 46 | elif isinstance(i, dict): 47 | recursive_fix_for_json_export(i) 48 | out.append(i) 49 | elif type(i) in (np.float32, np.float64, np.float16): 50 | out.append(float(i)) 51 | elif type(i) in (np.bool_,): 52 | out.append(bool(i)) 53 | elif isinstance(i, str): 54 | out.append(i) 55 | elif isinstance(i, Iterable): 56 | # print('recursive call on', i, type(i)) 57 | out.append(fix_types_iterable(i, type(i))) 58 | else: 59 | out.append(i) 60 | return output_type(out) 61 | -------------------------------------------------------------------------------- /longiseg/training/data_augmentation/custom_transforms/transforms_for_dummy_2d.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Union, List 2 | 3 | from batchgenerators.transforms.abstract_transforms import AbstractTransform 4 | 5 | 6 | class Convert3DTo2DTransform(AbstractTransform): 7 | def __init__(self, apply_to_keys: Union[List[str], Tuple[str]] = ('data', 'seg')): 8 | """ 9 | Transforms a 5D array (b, c, x, y, z) to a 4D array (b, c * x, y, z) by overloading the color channel 10 | """ 11 | self.apply_to_keys = apply_to_keys 12 | 13 | def __call__(self, **data_dict): 14 | for k in self.apply_to_keys: 15 | shp = data_dict[k].shape 16 | assert len(shp) == 5, 'This transform only works on 3D data, so expects 5D tensor (b, c, x, y, z) as input.' 17 | data_dict[k] = data_dict[k].reshape((shp[0], shp[1] * shp[2], shp[3], shp[4])) 18 | shape_key = f'orig_shape_{k}' 19 | assert shape_key not in data_dict.keys(), f'Convert3DTo2DTransform needs to store the original shape. ' \ 20 | f'It does that using the {shape_key} key. That key is ' \ 21 | f'already taken. Bummer.' 
22 | data_dict[shape_key] = shp 23 | return data_dict 24 | 25 | 26 | class Convert2DTo3DTransform(AbstractTransform): 27 | def __init__(self, apply_to_keys: Union[List[str], Tuple[str]] = ('data', 'seg')): 28 | """ 29 | Reverts Convert3DTo2DTransform by transforming a 4D array (b, c * x, y, z) back to 5D (b, c, x, y, z) 30 | """ 31 | self.apply_to_keys = apply_to_keys 32 | 33 | def __call__(self, **data_dict): 34 | for k in self.apply_to_keys: 35 | shape_key = f'orig_shape_{k}' 36 | assert shape_key in data_dict.keys(), f'Did not find key {shape_key} in data_dict. Shitty. ' \ 37 | f'Convert2DTo3DTransform only works in tandem with ' \ 38 | f'Convert3DTo2DTransform and you probably forgot to add ' \ 39 | f'Convert3DTo2DTransform to your pipeline. (Convert3DTo2DTransform ' \ 40 | f'is where the missing key is generated)' 41 | original_shape = data_dict[shape_key] 42 | current_shape = data_dict[k].shape 43 | data_dict[k] = data_dict[k].reshape((original_shape[0], original_shape[1], original_shape[2], 44 | current_shape[-2], current_shape[-1])) 45 | return data_dict 46 | -------------------------------------------------------------------------------- /longiseg/training/LongiSegTrainer/variants/benchmarking/nnUNetTrainerBenchmark_5epochs_noDataLoading.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from longiseg.training.LongiSegTrainer.variants.benchmarking.nnUNetTrainerBenchmark_5epochs import ( 4 | nnUNetTrainerBenchmark_5epochs, 5 | ) 6 | from longiseg.utilities.label_handling.label_handling import determine_num_input_channels 7 | 8 | 9 | class nnUNetTrainerBenchmark_5epochs_noDataLoading(nnUNetTrainerBenchmark_5epochs): 10 | def __init__( 11 | self, 12 | plans: dict, 13 | configuration: str, 14 | fold: int, 15 | dataset_json: dict, 16 | device: torch.device = torch.device("cuda"), 17 | ): 18 | super().__init__(plans, configuration, fold, dataset_json, device) 19 | self._set_batch_size_and_oversample() 20 | num_input_channels = determine_num_input_channels( 21 | self.plans_manager, self.configuration_manager, self.dataset_json 22 | ) 23 | patch_size = self.configuration_manager.patch_size 24 | dummy_data = torch.rand((self.batch_size, num_input_channels, *patch_size), device=self.device) 25 | if self.enable_deep_supervision: 26 | dummy_target = [ 27 | torch.round( 28 | torch.rand((self.batch_size, 1, *[int(i * j) for i, j in zip(patch_size, k)]), device=self.device) 29 | * max(self.label_manager.all_labels) 30 | ) 31 | for k in self._get_deep_supervision_scales() 32 | ] 33 | else: 34 | raise NotImplementedError("This trainer requires deep supervision to be enabled") 35 | self.dummy_batch = {"data": dummy_data, "target": dummy_target} 36 | 37 | def get_dataloaders(self): 38 | return None, None 39 | 40 | def run_training(self): 41 | try: 42 | self.on_train_start() 43 | 44 | for epoch in range(self.current_epoch, self.num_epochs): 45 | self.on_epoch_start() 46 | 47 | self.on_train_epoch_start() 48 | train_outputs = [] 49 | for batch_id in range(self.num_iterations_per_epoch): 50 | train_outputs.append(self.train_step(self.dummy_batch)) 51 | self.on_train_epoch_end(train_outputs) 52 | 53 | with torch.no_grad(): 54 | self.on_validation_epoch_start() 55 | val_outputs = [] 56 | for batch_id in range(self.num_val_iterations_per_epoch): 57 | val_outputs.append(self.validation_step(self.dummy_batch)) 58 | self.on_validation_epoch_end(val_outputs) 59 | 60 | self.on_epoch_end() 61 | 62 | self.on_train_end() 63 | except RuntimeError: 64 |
self.crashed_with_runtime_error = True 65 | -------------------------------------------------------------------------------- /longiseg/training/data_augmentation/custom_transforms/deep_supervision_donwsampling.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Union, List 2 | 3 | from batchgenerators.augmentations.utils import resize_segmentation 4 | from batchgenerators.transforms.abstract_transforms import AbstractTransform 5 | import numpy as np 6 | 7 | 8 | class DownsampleSegForDSTransform2(AbstractTransform): 9 | ''' 10 | data_dict['output_key'] will be a list of segmentations scaled according to ds_scales 11 | ''' 12 | def __init__(self, ds_scales: Union[List, Tuple], 13 | order: int = 0, input_key: str = "seg", 14 | output_key: str = "seg", axes: Tuple[int] = None): 15 | """ 16 | Downscales data_dict[input_key] according to ds_scales. Each entry in ds_scales specifies one deep supervision 17 | output and its resolution relative to the original data, for example 0.25 specifies 1/4 of the original shape. 18 | ds_scales can also be a tuple of tuples, for example ((1, 1, 1), (0.5, 0.5, 0.5)) to specify the downsampling 19 | for each axis independently 20 | """ 21 | self.axes = axes 22 | self.output_key = output_key 23 | self.input_key = input_key 24 | self.order = order 25 | self.ds_scales = ds_scales 26 | 27 | def __call__(self, **data_dict): 28 | if self.axes is None: 29 | axes = list(range(2, data_dict[self.input_key].ndim)) 30 | else: 31 | axes = self.axes 32 | 33 | output = [] 34 | for s in self.ds_scales: 35 | if not isinstance(s, (tuple, list)): 36 | s = [s] * len(axes) 37 | else: 38 | assert len(s) == len(axes), f'If ds_scales is a tuple for each resolution (one downsampling factor ' \ 39 | f'for each axis) then the number of entries in that tuple (here ' \ 40 | f'{len(s)}) must be the same as the number of axes (here {len(axes)}).'
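# a scale of 1 in every axis means this deep supervision output stays at full resolution,
# so the segmentation below is passed through unchanged; otherwise it is resized axis-wise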
41 | 42 | if all([i == 1 for i in s]): 43 | output.append(data_dict[self.input_key]) 44 | else: 45 | new_shape = np.array(data_dict[self.input_key].shape).astype(float) 46 | for i, a in enumerate(axes): 47 | new_shape[a] *= s[i] 48 | new_shape = np.round(new_shape).astype(int) 49 | out_seg = np.zeros(new_shape, dtype=data_dict[self.input_key].dtype) 50 | for b in range(data_dict[self.input_key].shape[0]): 51 | for c in range(data_dict[self.input_key].shape[1]): 52 | out_seg[b, c] = resize_segmentation(data_dict[self.input_key][b, c], new_shape[2:], self.order) 53 | output.append(out_seg) 54 | data_dict[self.output_key] = output 55 | return data_dict 56 | -------------------------------------------------------------------------------- /longiseg/training/LongiSegTrainer/variants/optimizer/nnUNetTrainerAdam.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim import Adam, AdamW 3 | 4 | from longiseg.training.lr_scheduler.polylr import PolyLRScheduler 5 | from longiseg.training.LongiSegTrainer.nnUNetTrainerLongi import nnUNetTrainerNoLongi 6 | 7 | 8 | class nnUNetTrainerAdam(nnUNetTrainerNoLongi): 9 | def configure_optimizers(self): 10 | optimizer = AdamW(self.network.parameters(), 11 | lr=self.initial_lr, 12 | weight_decay=self.weight_decay, 13 | amsgrad=True) 14 | # optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay, 15 | # momentum=0.99, nesterov=True) 16 | lr_scheduler = PolyLRScheduler(optimizer, self.initial_lr, self.num_epochs) 17 | return optimizer, lr_scheduler 18 | 19 | 20 | class nnUNetTrainerVanillaAdam(nnUNetTrainerNoLongi): 21 | def configure_optimizers(self): 22 | optimizer = Adam(self.network.parameters(), 23 | lr=self.initial_lr, 24 | weight_decay=self.weight_decay) 25 | # optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay, 26 | # momentum=0.99, nesterov=True) 27 | lr_scheduler = PolyLRScheduler(optimizer, self.initial_lr, self.num_epochs) 28 | return optimizer, lr_scheduler 29 | 30 | 31 | class nnUNetTrainerVanillaAdam1en3(nnUNetTrainerVanillaAdam): 32 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, 33 | device: torch.device = torch.device('cuda')): 34 | super().__init__(plans, configuration, fold, dataset_json, device) 35 | self.initial_lr = 1e-3 36 | 37 | 38 | class nnUNetTrainerVanillaAdam3en4(nnUNetTrainerVanillaAdam): 39 | # https://twitter.com/karpathy/status/801621764144971776?lang=en 40 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, 41 | device: torch.device = torch.device('cuda')): 42 | super().__init__(plans, configuration, fold, dataset_json, device) 43 | self.initial_lr = 3e-4 44 | 45 | 46 | class nnUNetTrainerAdam1en3(nnUNetTrainerAdam): 47 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, 48 | device: torch.device = torch.device('cuda')): 49 | super().__init__(plans, configuration, fold, dataset_json, device) 50 | self.initial_lr = 1e-3 51 | 52 | 53 | class nnUNetTrainerAdam3en4(nnUNetTrainerAdam): 54 | # https://twitter.com/karpathy/status/801621764144971776?lang=en 55 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, 56 | device: torch.device = torch.device('cuda')): 57 | super().__init__(plans, configuration, fold, dataset_json, device) 58 | self.initial_lr = 3e-4 59 | 
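# Usage sketch (hedged: assumes the standard nnU-Net v2 CLI, where -tr selects the trainer class by name):
#   nnUNetv2_train DATASET_ID 3d_fullres 0 -tr nnUNetTrainerVanillaAdam3en4
# This swaps the default SGD optimizer for vanilla Adam with an initial learning rate of 3e-4.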
-------------------------------------------------------------------------------- /documentation/explanation_normalization.md: -------------------------------------------------------------------------------- 1 | # Intensity normalization in nnU-Net 2 | 3 | The type of intensity normalization applied in nnU-Net can be controlled via the `channel_names` (former `modalities`) 4 | entry in the dataset.json. Just like the old nnU-Net, per-channel z-scoring as well as dataset-wide z-scoring based on 5 | foreground intensities are supported. However, there have been a few additions as well. 6 | 7 | Reminder: The `channel_names` entry typically looks like this: 8 | 9 | "channel_names": { 10 | "0": "T2", 11 | "1": "ADC" 12 | }, 13 | 14 | It has as many entries as there are input channels for the given dataset. 15 | 16 | To tell you a secret, nnU-Net does not really care what your channels are called. We just use this to determine what normalization 17 | scheme will be used for the given dataset. nnU-Net requires you to specify a normalization strategy for each of your input channels! 18 | If you enter a channel name that is not in the following list, the default (`zscore`) will be used. 19 | 20 | Here is a list of currently available normalization schemes: 21 | 22 | - `CT`: Perform CT normalization. Specifically, collect intensity values from the foreground classes (all but the 23 | background and ignore) from all training cases, compute the mean, standard deviation as well as the 0.5 and 24 | 99.5 percentile of the values. Then clip to the percentiles, followed by subtraction of the mean and division with the 25 | standard deviation. The normalization that is applied is the same for each training case (for this input channel). 26 | The values used by nnU-Net for normalization are stored in the `foreground_intensity_properties_per_channel` entry in the 27 | corresponding plans file. This normalization is suitable for modalities presenting physical quantities such as CT 28 | images and ADC maps. 29 | - `noNorm` : do not perform any normalization at all 30 | - `rescale_to_0_1`: rescale the intensities to [0, 1] 31 | - `rgb_to_0_1`: assumes uint8 inputs. Divides by 255 to rescale uint8 to [0, 1] 32 | - `zscore`/anything else: perform z-scoring (subtract mean and standard deviation) separately for each train case 33 | 34 | **Important:** The nnU-Net default is to perform 'CT' normalization for CT images and 'zscore' for everything else! If 35 | you deviate from that path, make sure to benchmark whether that actually improves results! 36 | 37 | # How to implement custom normalization strategies? 38 | - Head over to longiseg/preprocessing/normalization 39 | - implement a new image normalization class by deriving from ImageNormalization 40 | - register it in longiseg/preprocessing/normalization/map_channel_name_to_normalization.py:channel_name_to_normalization_mapping. 41 | This is where you specify a channel name that should be associated with it 42 | - use it by specifying the correct channel_name 43 | 44 | Normalization can only be applied to one channel at a time. There is currently no way of implementing a normalization scheme 45 | that gets multiple channels as input to be used jointly! 
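To make the steps above concrete, here is a minimal sketch of a custom scheme. It assumes the `ImageNormalization`
base class exposes a `run(image, seg)` method and a `target_dtype` attribute, as the default schemes in
`default_normalization_schemes.py` do; the class name and channel name are made up:

```python
import numpy as np

from longiseg.preprocessing.normalization.default_normalization_schemes import ImageNormalization


class PercentileClipNormalization(ImageNormalization):
    """Hypothetical scheme: clip to the 5th/95th percentile of each image, then z-score."""
    leaves_pixels_outside_mask_at_zero_if_use_mask_for_norm = False

    def run(self, image: np.ndarray, seg: np.ndarray = None) -> np.ndarray:
        image = image.astype(self.target_dtype, copy=False)
        lower, upper = np.percentile(image, (5, 95))  # clip bounds computed per image
        image = np.clip(image, lower, upper)
        return (image - image.mean()) / max(float(image.std()), 1e-8)
```

To activate it, add an entry such as `'percentile_clip': PercentileClipNormalization` to
`channel_name_to_normalization_mapping` in `map_channel_name_to_normalization.py` and use `"percentile_clip"` as the
channel name in your dataset.json.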
-------------------------------------------------------------------------------- /longiseg/inference/sliding_window_prediction.py: -------------------------------------------------------------------------------- 1 | from functools import lru_cache 2 | 3 | import numpy as np 4 | import torch 5 | from typing import Union, Tuple, List 6 | from acvl_utils.cropping_and_padding.padding import pad_nd_image 7 | from scipy.ndimage import gaussian_filter 8 | 9 | 10 | @lru_cache(maxsize=2) 11 | def compute_gaussian(tile_size: Union[Tuple[int, ...], List[int]], sigma_scale: float = 1. / 8, 12 | value_scaling_factor: float = 1, dtype=torch.float16, device=torch.device('cuda', 0)) \ 13 | -> torch.Tensor: 14 | tmp = np.zeros(tile_size) 15 | center_coords = [i // 2 for i in tile_size] 16 | sigmas = [i * sigma_scale for i in tile_size] 17 | tmp[tuple(center_coords)] = 1 18 | gaussian_importance_map = gaussian_filter(tmp, sigmas, 0, mode='constant', cval=0) 19 | 20 | gaussian_importance_map = torch.from_numpy(gaussian_importance_map) 21 | 22 | gaussian_importance_map /= (torch.max(gaussian_importance_map) / value_scaling_factor) 23 | gaussian_importance_map = gaussian_importance_map.to(device=device, dtype=dtype) 24 | # gaussian_importance_map cannot be 0, otherwise we may end up with nans! 25 | mask = gaussian_importance_map == 0 26 | gaussian_importance_map[mask] = torch.min(gaussian_importance_map[~mask]) 27 | return gaussian_importance_map 28 | 29 | 30 | def compute_steps_for_sliding_window(image_size: Tuple[int, ...], tile_size: Tuple[int, ...], tile_step_size: float) -> \ 31 | List[List[int]]: 32 | assert all(i >= j for i, j in zip(image_size, tile_size)), "image size must be as large or larger than patch_size" 33 | assert 0 < tile_step_size <= 1, 'step_size must be larger than 0 and smaller or equal to 1' 34 | 35 | # our step width is patch_size*step_size at most, but can be narrower.
For example if we have image size of 36 | # 110, patch size of 64 and step_size of 0.5, then we want to make 3 steps starting at coordinate 0, 23, 46 37 | target_step_sizes_in_voxels = [i * tile_step_size for i in tile_size] 38 | 39 | num_steps = [int(np.ceil((i - k) / j)) + 1 for i, j, k in zip(image_size, target_step_sizes_in_voxels, tile_size)] 40 | 41 | steps = [] 42 | for dim in range(len(tile_size)): 43 | # the highest step value for this dimension is 44 | max_step_value = image_size[dim] - tile_size[dim] 45 | if num_steps[dim] > 1: 46 | actual_step_size = max_step_value / (num_steps[dim] - 1) 47 | else: 48 | actual_step_size = 99999999999 # does not matter because there is only one step at 0 49 | 50 | steps_here = [int(np.round(actual_step_size * i)) for i in range(num_steps[dim])] 51 | 52 | steps.append(steps_here) 53 | 54 | return steps 55 | 56 | 57 | if __name__ == '__main__': 58 | a = torch.rand((4, 2, 32, 23)) 59 | a_npy = a.numpy() 60 | 61 | a_padded = pad_nd_image(a, new_shape=(48, 27)) 62 | a_npy_padded = pad_nd_image(a_npy, new_shape=(48, 27)) 63 | assert all([i == j for i, j in zip(a_padded.shape, (4, 2, 48, 27))]) 64 | assert all([i == j for i, j in zip(a_npy_padded.shape, (4, 2, 48, 27))]) 65 | assert np.all(a_padded.numpy() == a_npy_padded) 66 | -------------------------------------------------------------------------------- /longiseg/training/dataloading/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import multiprocessing 3 | import os 4 | from typing import List 5 | from pathlib import Path 6 | from warnings import warn 7 | 8 | import numpy as np 9 | from batchgenerators.utilities.file_and_folder_operations import isfile, subfiles 10 | from longiseg.configuration import default_num_processes 11 | 12 | 13 | def _convert_to_npy(npz_file: str, unpack_segmentation: bool = True, overwrite_existing: bool = False, 14 | verify_npy: bool = False, fail_ctr: int = 0) -> None: 15 | data_npy = npz_file[:-3] + "npy" 16 | seg_npy = npz_file[:-4] + "_seg.npy" 17 | try: 18 | npz_content = None # will only be opened on demand 19 | 20 | if overwrite_existing or not isfile(data_npy): 21 | try: 22 | npz_content = np.load(npz_file) if npz_content is None else npz_content 23 | except Exception as e: 24 | print(f"Unable to open preprocessed file {npz_file}. Rerun nnUNetv2_preprocess!") 25 | raise e 26 | np.save(data_npy, npz_content['data']) 27 | 28 | if unpack_segmentation and (overwrite_existing or not isfile(seg_npy)): 29 | try: 30 | npz_content = np.load(npz_file) if npz_content is None else npz_content 31 | except Exception as e: 32 | print(f"Unable to open preprocessed file {npz_file}. Rerun nnUNetv2_preprocess!") 33 | raise e 34 | np.save(npz_file[:-4] + "_seg.npy", npz_content['seg']) 35 | 36 | if verify_npy: 37 | try: 38 | np.load(data_npy, mmap_mode='r') 39 | if isfile(seg_npy): 40 | np.load(seg_npy, mmap_mode='r') 41 | except ValueError: 42 | os.remove(data_npy) 43 | os.remove(seg_npy) 44 | print(f"Error when checking {data_npy} and {seg_npy}, fixing...") 45 | if fail_ctr < 2: 46 | _convert_to_npy(npz_file, unpack_segmentation, overwrite_existing, verify_npy, fail_ctr+1) 47 | else: 48 | raise RuntimeError("Unable to fix unpacking. 
Please check your system or rerun nnUNetv2_preprocess") 49 | 50 | except KeyboardInterrupt: 51 | if isfile(data_npy): 52 | os.remove(data_npy) 53 | if isfile(seg_npy): 54 | os.remove(seg_npy) 55 | raise KeyboardInterrupt 56 | 57 | 58 | def unpack_dataset(folder: str, unpack_segmentation: bool = True, overwrite_existing: bool = False, 59 | num_processes: int = default_num_processes, 60 | verify: bool = False): 61 | """ 62 | all npz files in this folder belong to the dataset, unpack them all 63 | """ 64 | with multiprocessing.get_context("spawn").Pool(num_processes) as p: 65 | npz_files = subfiles(folder, True, None, ".npz", True) 66 | p.starmap(_convert_to_npy, zip(npz_files, 67 | [unpack_segmentation] * len(npz_files), 68 | [overwrite_existing] * len(npz_files), 69 | [verify] * len(npz_files)) 70 | ) -------------------------------------------------------------------------------- /longiseg/training/LongiSegTrainer/variants/training_length/nnUNetTrainer_Xepochs_NoMirroring.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from longiseg.training.LongiSegTrainer.nnUNetTrainerLongi import nnUNetTrainerNoLongi 4 | 5 | 6 | class nnUNetTrainer_250epochs_NoMirroring(nnUNetTrainerNoLongi): 7 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, 8 | device: torch.device = torch.device('cuda')): 9 | super().__init__(plans, configuration, fold, dataset_json, device) 10 | self.num_epochs = 250 11 | 12 | def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self): 13 | rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \ 14 | super().configure_rotation_dummyDA_mirroring_and_inital_patch_size() 15 | mirror_axes = None 16 | self.inference_allowed_mirroring_axes = None 17 | return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes 18 | 19 | 20 | class nnUNetTrainer_2000epochs_NoMirroring(nnUNetTrainerNoLongi): 21 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, 22 | device: torch.device = torch.device('cuda')): 23 | super().__init__(plans, configuration, fold, dataset_json, device) 24 | self.num_epochs = 2000 25 | 26 | def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self): 27 | rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \ 28 | super().configure_rotation_dummyDA_mirroring_and_inital_patch_size() 29 | mirror_axes = None 30 | self.inference_allowed_mirroring_axes = None 31 | return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes 32 | 33 | 34 | class nnUNetTrainer_4000epochs_NoMirroring(nnUNetTrainerNoLongi): 35 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, 36 | device: torch.device = torch.device('cuda')): 37 | super().__init__(plans, configuration, fold, dataset_json, device) 38 | self.num_epochs = 4000 39 | 40 | def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self): 41 | rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \ 42 | super().configure_rotation_dummyDA_mirroring_and_inital_patch_size() 43 | mirror_axes = None 44 | self.inference_allowed_mirroring_axes = None 45 | return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes 46 | 47 | 48 | class nnUNetTrainer_8000epochs_NoMirroring(nnUNetTrainerNoLongi): 49 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, 50 | device: torch.device = torch.device('cuda')): 51 | 
super().__init__(plans, configuration, fold, dataset_json, device) 52 | self.num_epochs = 8000 53 | 54 | def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self): 55 | rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \ 56 | super().configure_rotation_dummyDA_mirroring_and_inital_patch_size() 57 | mirror_axes = None 58 | self.inference_allowed_mirroring_axes = None 59 | return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes 60 | 61 | -------------------------------------------------------------------------------- /longiseg/imageio/natural_image_reader_writer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 HIP Applied Computer Vision Lab, Division of Medical Image Computing, German Cancer Research Center 2 | # (DKFZ), Heidelberg, Germany 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from typing import Tuple, Union, List 17 | import numpy as np 18 | from longiseg.imageio.base_reader_writer import BaseReaderWriter 19 | from skimage import io 20 | 21 | 22 | class NaturalImage2DIO(BaseReaderWriter): 23 | """ 24 | ONLY SUPPORTS 2D IMAGES!!! 25 | """ 26 | 27 | # there are surely more we could add here. Everything that can be read by skimage.io should be supported 28 | supported_file_endings = [ 29 | '.png', 30 | # '.jpg', 31 | # '.jpeg', # jpg not supported because we cannot allow lossy compression! segmentation maps! 32 | '.bmp', 33 | '.tif' 34 | ] 35 | 36 | def read_images(self, image_fnames: Union[List[str], Tuple[str, ...]]) -> Tuple[np.ndarray, dict]: 37 | images = [] 38 | for f in image_fnames: 39 | npy_img = io.imread(f) 40 | if npy_img.ndim == 3: 41 | # rgb image, last dimension should be the color channel and the size of that channel should be 3 42 | # (or 4 if we have alpha) 43 | assert npy_img.shape[-1] == 3 or npy_img.shape[-1] == 4, "If image has three dimensions then the last " \ 44 | "dimension must have shape 3 or 4 " \ 45 | f"(RGB or RGBA). Image shape here is {npy_img.shape}" 46 | # move RGB(A) to front, add additional dim so that we have shape (c, 1, X, Y), where c is either 3 or 4 47 | images.append(npy_img.transpose((2, 0, 1))[:, None]) 48 | elif npy_img.ndim == 2: 49 | # grayscale image 50 | images.append(npy_img[None, None]) 51 | 52 | if not self._check_all_same([i.shape for i in images]): 53 | print('ERROR! 
Not all input images have the same shape!') 54 | print('Shapes:') 55 | print([i.shape for i in images]) 56 | print('Image files:') 57 | print(image_fnames) 58 | raise RuntimeError() 59 | return np.vstack(images, dtype=np.float32, casting='unsafe'), {'spacing': (999, 1, 1)} 60 | 61 | def read_seg(self, seg_fname: str) -> Tuple[np.ndarray, dict]: 62 | return self.read_images((seg_fname, )) 63 | 64 | def write_seg(self, seg: np.ndarray, output_fname: str, properties: dict) -> None: 65 | io.imsave(output_fname, seg[0].astype(np.uint8 if np.max(seg) < 255 else np.uint16, copy=False), check_contrast=False) -------------------------------------------------------------------------------- /longiseg/training/LongiSegTrainer/variants/benchmarking/nnUNetTrainerBenchmark_5epochs.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | 3 | import torch 4 | from batchgenerators.utilities.file_and_folder_operations import save_json, join, isfile, load_json 5 | 6 | from longiseg.training.LongiSegTrainer.nnUNetTrainerLongi import nnUNetTrainerNoLongi 7 | from torch import distributed as dist 8 | 9 | 10 | class nnUNetTrainerBenchmark_5epochs(nnUNetTrainerNoLongi): 11 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, 12 | device: torch.device = torch.device('cuda')): 13 | super().__init__(plans, configuration, fold, dataset_json, device) 14 | assert self.fold == 0, "It makes absolutely no sense to specify a certain fold. Stick with 0 so that we can parse the results." 15 | self.disable_checkpointing = True 16 | self.num_epochs = 5 17 | assert torch.cuda.is_available(), "This only works on GPU" 18 | self.crashed_with_runtime_error = False 19 | 20 | def perform_actual_validation(self, save_probabilities: bool = False): 21 | pass 22 | 23 | def save_checkpoint(self, filename: str) -> None: 24 | # do not trust people to remember that self.disable_checkpointing must be True for this trainer 25 | pass 26 | 27 | def run_training(self): 28 | try: 29 | super().run_training() 30 | except RuntimeError: 31 | self.crashed_with_runtime_error = True 32 | self.on_train_end() 33 | 34 | def on_train_end(self): 35 | super().on_train_end() 36 | 37 | if not self.is_ddp or self.local_rank == 0: 38 | torch_version = torch.__version__ 39 | cudnn_version = torch.backends.cudnn.version() 40 | gpu_name = torch.cuda.get_device_name() 41 | if self.crashed_with_runtime_error: 42 | fastest_epoch = 'Not enough VRAM!' 
43 | else: 44 | epoch_times = [i - j for i, j in zip(self.logger.my_fantastic_logging['epoch_end_timestamps'], 45 | self.logger.my_fantastic_logging['epoch_start_timestamps'])] 46 | fastest_epoch = min(epoch_times) 47 | 48 | if self.is_ddp: 49 | num_gpus = dist.get_world_size() 50 | else: 51 | num_gpus = 1 52 | 53 | benchmark_result_file = join(self.output_folder, 'benchmark_result.json') 54 | if isfile(benchmark_result_file): 55 | old_results = load_json(benchmark_result_file) 56 | else: 57 | old_results = {} 58 | # generate some unique key 59 | hostname = subprocess.getoutput('hostname') 60 | my_key = f"{hostname}__{cudnn_version}__{torch_version.replace(' ', '')}__{gpu_name.replace(' ', '')}__num_gpus_{num_gpus}" 61 | old_results[my_key] = { 62 | 'torch_version': torch_version, 63 | 'cudnn_version': cudnn_version, 64 | 'gpu_name': gpu_name, 65 | 'fastest_epoch': fastest_epoch, 66 | 'num_gpus': num_gpus, 67 | 'hostname': hostname 68 | } 69 | save_json(old_results, 70 | join(self.output_folder, 'benchmark_result.json')) 71 | -------------------------------------------------------------------------------- /longiseg/training/LongiSegTrainer/variants/loss/nnUNetTrainerDiceLoss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from longiseg.training.loss.compound_losses import DC_and_BCE_loss, DC_and_CE_loss 5 | from longiseg.training.loss.deep_supervision import DeepSupervisionWrapper 6 | from longiseg.training.loss.dice import MemoryEfficientSoftDiceLoss 7 | from longiseg.training.LongiSegTrainer.nnUNetTrainerLongi import nnUNetTrainerNoLongi 8 | from longiseg.utilities.helpers import softmax_helper_dim1 9 | 10 | 11 | class nnUNetTrainerDiceLoss(nnUNetTrainerNoLongi): 12 | def _build_loss(self): 13 | loss = MemoryEfficientSoftDiceLoss(**{'batch_dice': self.configuration_manager.batch_dice, 14 | 'do_bg': self.label_manager.has_regions, 'smooth': 1e-5, 'ddp': self.is_ddp}, 15 | apply_nonlin=torch.sigmoid if self.label_manager.has_regions else softmax_helper_dim1) 16 | 17 | if self.enable_deep_supervision: 18 | deep_supervision_scales = self._get_deep_supervision_scales() 19 | 20 | # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases 21 | # this gives higher resolution outputs more weight in the loss 22 | weights = np.array([1 / (2 ** i) for i in range(len(deep_supervision_scales))]) 23 | weights[-1] = 0 24 | 25 | # we don't use the lowest 2 outputs. 
Normalize weights so that they sum to 1 26 | weights = weights / weights.sum() 27 | # now wrap the loss 28 | loss = DeepSupervisionWrapper(loss, weights) 29 | return loss 30 | 31 | 32 | class nnUNetTrainerDiceCELoss_noSmooth(nnUNetTrainerNoLongi): 33 | def _build_loss(self): 34 | # set smooth to 0 35 | if self.label_manager.has_regions: 36 | loss = DC_and_BCE_loss({}, 37 | {'batch_dice': self.configuration_manager.batch_dice, 38 | 'do_bg': True, 'smooth': 0, 'ddp': self.is_ddp}, 39 | use_ignore_label=self.label_manager.ignore_label is not None, 40 | dice_class=MemoryEfficientSoftDiceLoss) 41 | else: 42 | loss = DC_and_CE_loss({'batch_dice': self.configuration_manager.batch_dice, 43 | 'smooth': 0, 'do_bg': False, 'ddp': self.is_ddp}, {}, weight_ce=1, weight_dice=1, 44 | ignore_label=self.label_manager.ignore_label, 45 | dice_class=MemoryEfficientSoftDiceLoss) 46 | 47 | if self.enable_deep_supervision: 48 | deep_supervision_scales = self._get_deep_supervision_scales() 49 | 50 | # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases 51 | # this gives higher resolution outputs more weight in the loss 52 | weights = np.array([1 / (2 ** i) for i in range(len(deep_supervision_scales))]) 53 | weights[-1] = 0 54 | 55 | # we don't use the lowest 2 outputs. Normalize weights so that they sum to 1 56 | weights = weights / weights.sum() 57 | # now wrap the loss 58 | loss = DeepSupervisionWrapper(loss, weights) 59 | return loss 60 | 61 | -------------------------------------------------------------------------------- /longiseg/training/LongiSegTrainer/variants/optimizer/nnUNetTrainerAdan.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from longiseg.training.lr_scheduler.polylr import PolyLRScheduler 4 | from longiseg.training.LongiSegTrainer.nnUNetTrainerLongi import nnUNetTrainerNoLongi 5 | from torch.optim.lr_scheduler import CosineAnnealingLR 6 | try: 7 | from adan_pytorch import Adan 8 | except ImportError: 9 | Adan = None 10 | 11 | 12 | class nnUNetTrainerAdan(nnUNetTrainerNoLongi): 13 | def configure_optimizers(self): 14 | if Adan is None: 15 | raise RuntimeError('This trainer requires adan_pytorch to be installed, install with "pip install adan-pytorch"') 16 | optimizer = Adan(self.network.parameters(), 17 | lr=self.initial_lr, 18 | # betas=(0.02, 0.08, 0.01), defaults 19 | weight_decay=self.weight_decay) 20 | # optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay, 21 | # momentum=0.99, nesterov=True) 22 | lr_scheduler = PolyLRScheduler(optimizer, self.initial_lr, self.num_epochs) 23 | return optimizer, lr_scheduler 24 | 25 | 26 | class nnUNetTrainerAdan1en3(nnUNetTrainerAdan): 27 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, 28 | device: torch.device = torch.device('cuda')): 29 | super().__init__(plans, configuration, fold, dataset_json, device) 30 | self.initial_lr = 1e-3 31 | 32 | 33 | class nnUNetTrainerAdan3en4(nnUNetTrainerAdan): 34 | # https://twitter.com/karpathy/status/801621764144971776?lang=en 35 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, 36 | device: torch.device = torch.device('cuda')): 37 | super().__init__(plans, configuration, fold, dataset_json, device) 38 | self.initial_lr = 3e-4 39 | 40 | 41 | class nnUNetTrainerAdan1en1(nnUNetTrainerAdan): 42 | # this trainer makes no sense -> nan! 
43 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, 44 | device: torch.device = torch.device('cuda')): 45 | super().__init__(plans, configuration, fold, dataset_json, device) 46 | self.initial_lr = 1e-1 47 | 48 | 49 | class nnUNetTrainerAdanCosAnneal(nnUNetTrainerAdan): 50 | # def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, 51 | # device: torch.device = torch.device('cuda')): 52 | # super().__init__(plans, configuration, fold, dataset_json, device) 53 | # self.num_epochs = 15 54 | 55 | def configure_optimizers(self): 56 | if Adan is None: 57 | raise RuntimeError('This trainer requires adan_pytorch to be installed, install with "pip install adan-pytorch"') 58 | optimizer = Adan(self.network.parameters(), 59 | lr=self.initial_lr, 60 | # betas=(0.02, 0.08, 0.01), defaults 61 | weight_decay=self.weight_decay) 62 | # optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay, 63 | # momentum=0.99, nesterov=True) 64 | lr_scheduler = CosineAnnealingLR(optimizer, T_max=self.num_epochs) 65 | return optimizer, lr_scheduler 66 | 67 | -------------------------------------------------------------------------------- /documentation/set_environment_variables.md: -------------------------------------------------------------------------------- 1 | # How to set environment variables 2 | 3 | nnU-Net requires some environment variables so that it always knows where the raw data, preprocessed data and trained 4 | models are. Depending on the operating system, these environment variables need to be set in different ways. 5 | 6 | Variables can either be set permanently (recommended!) or you can decide to set them every time you call nnU-Net. 7 | 8 | # Linux & MacOS 9 | 10 | ## Permanent 11 | Locate the `.bashrc` file in your home folder and add the following lines to the bottom: 12 | 13 | ```bash 14 | export LongiSeg_raw="/path/to/LongiSeg_raw" 15 | export LongiSeg_preprocessed="/path/to/LongiSeg_preprocessed" 16 | export LongiSeg_results="/path/to/LongiSeg_results" 17 | ``` 18 | 19 | (Of course you need to adapt the paths to the actual folders you intend to use). 20 | If you are using a different shell, such as zsh, you will need to find the correct script for it. For zsh this is `.zshrc`. 21 | 22 | ## Temporary 23 | Just execute the following lines whenever you run nnU-Net: 24 | ```bash 25 | export LongiSeg_raw="/path/to/LongiSeg_raw" 26 | export LongiSeg_preprocessed="/path/to/LongiSeg_preprocessed" 27 | export LongiSeg_results="/path/to/LongiSeg_results" 28 | ``` 29 | (Of course you need to adapt the paths to the actual folders you intend to use). 30 | 31 | Important: These variables will be deleted if you close your terminal! They will also only apply to the current 32 | terminal window and DO NOT transfer to other terminals! 33 | 34 | Alternatively you can also just prefix them to your nnU-Net commands: 35 | 36 | `LongiSeg_results="/path/to/LongiSeg_results" LongiSeg_preprocessed="/path/to/LongiSeg_preprocessed" nnUNetv2_train [...]` 37 | 38 | ## Verify that environment parameters are set 39 | You can always execute `echo ${LongiSeg_raw}` etc to print the environment variables. This will return an empty string if 40 | they were not set. 41 | 42 | # Windows 43 | Useful links: 44 | - [https://www3.ntu.edu.sg](https://www3.ntu.edu.sg/home/ehchua/programming/howto/Environment_Variables.html)
45 | - [https://phoenixnap.com](https://phoenixnap.com/kb/windows-set-environment-variable) 46 | 47 | ## Permanent 48 | See `Set Environment Variable in Windows via GUI` [here](https://phoenixnap.com/kb/windows-set-environment-variable). 49 | Or read about setx (command prompt). 50 | 51 | ## Temporary 52 | Just execute the following before you run nnU-Net: 53 | 54 | (PowerShell) 55 | ```PowerShell 56 | $Env:LongiSeg_raw = "C:/path/to/LongiSeg_raw" 57 | $Env:LongiSeg_preprocessed = "C:/path/to/LongiSeg_preprocessed" 58 | $Env:LongiSeg_results = "C:/path/to/LongiSeg_results" 59 | ``` 60 | 61 | (Command Prompt) 62 | ```Command Prompt 63 | set LongiSeg_raw=C:/path/to/LongiSeg_raw 64 | set LongiSeg_preprocessed=C:/path/to/LongiSeg_preprocessed 65 | set LongiSeg_results=C:/path/to/LongiSeg_results 66 | ``` 67 | 68 | (Of course you need to adapt the paths to the actual folders you intend to use). 69 | 70 | Important: These variables will be deleted if you close your session! They will also only apply to the current 71 | window and DO NOT transfer to other sessions! 72 | 73 | ## Verify that environment parameters are set 74 | Printing in Windows works differently depending on the environment you are in: 75 | 76 | PowerShell: `echo $Env:[variable_name]` 77 | 78 | Command Prompt: `echo %[variable_name]%` 79 | -------------------------------------------------------------------------------- /longiseg/utilities/get_network_from_plans.py: -------------------------------------------------------------------------------- 1 | import pydoc 2 | import warnings 3 | from typing import Union 4 | 5 | from longiseg.utilities.find_class_by_name import recursive_find_python_class 6 | from batchgenerators.utilities.file_and_folder_operations import join 7 | 8 | 9 | def get_network_from_plans(arch_class_name, arch_kwargs, arch_kwargs_req_import, input_channels, output_channels, 10 | allow_init=True, deep_supervision: Union[bool, None] = None): 11 | network_class = arch_class_name 12 | architecture_kwargs = dict(**arch_kwargs) 13 | for ri in arch_kwargs_req_import: 14 | if architecture_kwargs[ri] is not None: 15 | architecture_kwargs[ri] = pydoc.locate(architecture_kwargs[ri]) 16 | 17 | nw_class = pydoc.locate(network_class) 18 | # sometimes things move around, this makes it so that we can at least recover some of that 19 | if nw_class is None: 20 | warnings.warn(f'Network class {network_class} not found. 
Attempting to locate it within ' 21 | f'dynamic_network_architectures.architectures...') 22 | import dynamic_network_architectures 23 | nw_class = recursive_find_python_class(join(dynamic_network_architectures.__path__[0], "architectures"), 24 | network_class.split(".")[-1], 25 | 'dynamic_network_architectures.architectures') 26 | if nw_class is not None: 27 | print(f'FOUND IT: {nw_class}') 28 | else: 29 | raise ImportError('Network class could not be found, please check/correct your plans file') 30 | 31 | if deep_supervision is not None: 32 | architecture_kwargs['deep_supervision'] = deep_supervision 33 | 34 | network = nw_class( 35 | input_channels=input_channels, 36 | num_classes=output_channels, 37 | **architecture_kwargs 38 | ) 39 | 40 | if hasattr(network, 'initialize') and allow_init: 41 | network.apply(network.initialize) 42 | 43 | return network 44 | 45 | if __name__ == "__main__": 46 | import torch 47 | 48 | model = get_network_from_plans( 49 | arch_class_name="dynamic_network_architectures.architectures.unet.ResidualEncoderUNet", 50 | arch_kwargs={ 51 | "n_stages": 7, 52 | "features_per_stage": [32, 64, 128, 256, 512, 512, 512], 53 | "conv_op": "torch.nn.modules.conv.Conv2d", 54 | "kernel_sizes": [[3, 3], [3, 3], [3, 3], [3, 3], [3, 3], [3, 3], [3, 3]], 55 | "strides": [[1, 1], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2]], 56 | "n_blocks_per_stage": [1, 3, 4, 6, 6, 6, 6], 57 | "n_conv_per_stage_decoder": [1, 1, 1, 1, 1, 1], 58 | "conv_bias": True, 59 | "norm_op": "torch.nn.modules.instancenorm.InstanceNorm2d", 60 | "norm_op_kwargs": {"eps": 1e-05, "affine": True}, 61 | "dropout_op": None, 62 | "dropout_op_kwargs": None, 63 | "nonlin": "torch.nn.LeakyReLU", 64 | "nonlin_kwargs": {"inplace": True}, 65 | }, 66 | arch_kwargs_req_import=["conv_op", "norm_op", "dropout_op", "nonlin"], 67 | input_channels=1, 68 | output_channels=4, 69 | allow_init=True, 70 | deep_supervision=True, 71 | ) 72 | data = torch.rand((8, 1, 256, 256)) 73 | target = torch.rand(size=(8, 1, 256, 256)) 74 | outputs = model(data) # this should be a list of torch.Tensor -------------------------------------------------------------------------------- /longiseg/evaluation/accumulate_cv_results.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | from typing import Union, List, Tuple 3 | 4 | from batchgenerators.utilities.file_and_folder_operations import load_json, join, isdir, maybe_mkdir_p, subfiles, isfile 5 | 6 | from longiseg.configuration import default_num_processes 7 | from longiseg.evaluation.evaluate_predictions import compute_metrics_on_folder 8 | from longiseg.paths import LongiSeg_raw, LongiSeg_preprocessed 9 | from longiseg.utilities.plans_handling.plans_handler import PlansManager 10 | 11 | 12 | def accumulate_cv_results(trained_model_folder, 13 | merged_output_folder: str, 14 | folds: Union[List[int], Tuple[int, ...]], 15 | num_processes: int = default_num_processes, 16 | overwrite: bool = True): 17 | """ 18 | There are a lot of things that can get fucked up, so the simplest way to deal with potential problems is to 19 | collect the cv results into a separate folder and then evaluate them again. No messing with summary_json files! 
20 | """ 21 | 22 | if overwrite and isdir(merged_output_folder): 23 | shutil.rmtree(merged_output_folder) 24 | maybe_mkdir_p(merged_output_folder) 25 | 26 | dataset_json = load_json(join(trained_model_folder, 'dataset.json')) 27 | plans_manager = PlansManager(join(trained_model_folder, 'plans.json')) 28 | rw = plans_manager.image_reader_writer_class() 29 | shutil.copy(join(trained_model_folder, 'dataset.json'), join(merged_output_folder, 'dataset.json')) 30 | shutil.copy(join(trained_model_folder, 'plans.json'), join(merged_output_folder, 'plans.json')) 31 | 32 | did_we_copy_something = False 33 | for f in folds: 34 | expected_validation_folder = join(trained_model_folder, f'fold_{f}', 'validation') 35 | if not isdir(expected_validation_folder): 36 | raise RuntimeError(f"fold {f} of model {trained_model_folder} is missing. Please train it!") 37 | predicted_files = subfiles(expected_validation_folder, suffix=dataset_json['file_ending'], join=False) 38 | for pf in predicted_files: 39 | if overwrite and isfile(join(merged_output_folder, pf)): 40 | raise RuntimeError(f'More than one of your folds has a prediction for case {pf}') 41 | if overwrite or not isfile(join(merged_output_folder, pf)): 42 | shutil.copy(join(expected_validation_folder, pf), join(merged_output_folder, pf)) 43 | did_we_copy_something = True 44 | 45 | if did_we_copy_something or not isfile(join(merged_output_folder, 'summary.json')): 46 | label_manager = plans_manager.get_label_manager(dataset_json) 47 | gt_folder = join(LongiSeg_raw, plans_manager.dataset_name, 'labelsTr') 48 | if not isdir(gt_folder): 49 | gt_folder = join(LongiSeg_preprocessed, plans_manager.dataset_name, 'gt_segmentations') 50 | compute_metrics_on_folder(gt_folder, 51 | merged_output_folder, 52 | join(merged_output_folder, 'summary.json'), 53 | rw, 54 | dataset_json['file_ending'], 55 | label_manager.foreground_regions if label_manager.has_regions else 56 | label_manager.foreground_labels, 57 | label_manager.ignore_label, 58 | num_processes) 59 | -------------------------------------------------------------------------------- /longiseg/run/load_pretrained_weights.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch._dynamo import OptimizedModule 3 | from torch.nn.parallel import DistributedDataParallel as DDP 4 | import torch.distributed as dist 5 | 6 | 7 | def load_pretrained_weights(network, fname, verbose=False):  8 | """ 9 | Transfers all weights between matching keys in state_dicts. Matching is done by name and we only transfer if the 10 | shape is also the same. Segmentation layers (the 1x1(x1) layers that produce the segmentation maps, 11 | identified by keys containing '.seg_layers.') are not transferred! 12 | 13 | If the pretrained weights were obtained with a training outside nnU-Net and DDP or torch.compile was used, 14 | you need to change the keys of the pretrained state_dict. DDP adds a 'module.' prefix and torch.compile adds 15 | an '_orig_mod.' prefix. You DO NOT need to worry about this if pretraining was done with nnU-Net as 16 | nnUNetTrainer.save_checkpoint takes care of that!
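A minimal usage sketch (the checkpoint path is made up): `load_pretrained_weights(trainer.network,
'/path/to/checkpoint_final.pth', verbose=True)`. The file must be a checkpoint dict containing a 'network_weights'
entry, which is what nnU-Net checkpoints provide.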
17 | 18 | """ 19 | if dist.is_initialized(): 20 | saved_model = torch.load(fname, map_location=torch.device('cuda', dist.get_rank()), weights_only=False) 21 | else: 22 | saved_model = torch.load(fname, weights_only=False) 23 | pretrained_dict = saved_model['network_weights'] 24 | 25 | skip_strings_in_pretrained = [ 26 | '.seg_layers.', 27 | ] 28 | 29 | if isinstance(network, DDP): 30 | mod = network.module 31 | else: 32 | mod = network 33 | if isinstance(mod, OptimizedModule): 34 | mod = mod._orig_mod 35 | 36 | model_dict = mod.state_dict() 37 | # verify that all but the segmentation layers have the same shape 38 | for key, _ in model_dict.items(): 39 | if all([i not in key for i in skip_strings_in_pretrained]): 40 | assert key in pretrained_dict, \ 41 | f"Key {key} is missing in the pretrained model weights. The pretrained weights do not seem to be " \ 42 | f"compatible with your network." 43 | assert model_dict[key].shape == pretrained_dict[key].shape, \ 44 | f"The shape of the parameters of key {key} is not the same. Pretrained model: " \ 45 | f"{pretrained_dict[key].shape}; your network: {model_dict[key].shape}. The pretrained model " \ 46 | f"does not seem to be compatible with your network." 47 | 48 | # fun fact: in principle this allows loading from parameters that do not cover the entire network. For example pretrained 49 | # encoders. Not supported by this function though (see assertions above) 50 | 51 | # commenting out this abomination of a dict comprehension for preservation in the archives of 'what not to do' 52 | # pretrained_dict = {'module.' + k if is_ddp else k: v 53 | # for k, v in pretrained_dict.items() 54 | # if (('module.' + k if is_ddp else k) in model_dict) and 55 | # all([i not in k for i in skip_strings_in_pretrained])} 56 | 57 | pretrained_dict = {k: v for k, v in pretrained_dict.items() 58 | if k in model_dict.keys() and all([i not in k for i in skip_strings_in_pretrained])} 59 | 60 | model_dict.update(pretrained_dict) 61 | 62 | print("################### Loading pretrained weights from file ", fname, '###################') 63 | if verbose: 64 | print("Below is the list of overlapping blocks in pretrained model and nnUNet architecture:") 65 | for key, value in pretrained_dict.items(): 66 | print(key, 'shape', value.shape) 67 | print("################### Done ###################") 68 | mod.load_state_dict(model_dict) 69 | 70 | 71 | -------------------------------------------------------------------------------- /documentation/extending_nnunet.md: -------------------------------------------------------------------------------- 1 | # Extending nnU-Net 2 | We hope that the new structure of nnU-Net v2 makes it much more intuitive to modify! We cannot give an 3 | extensive tutorial on how each and every bit of it can be modified. It is better for you to search for the position 4 | in the repository where the thing you intend to change is implemented and start working your way through the code from 5 | there. Setting breakpoints and debugging into nnU-Net really helps in understanding it and thus will help you make the 6 | necessary modifications! 7 | 8 | Here are some things you might want to read before you start: 9 | - Editing nnU-Net configurations through plans files is really powerful now and allows you to change a lot of things regarding 10 | preprocessing, resampling, network topology etc. Read [this](explanation_plans_files.md)!
11 | - [Image normalization](explanation_normalization.md) and [i/o formats](dataset_format.md#supported-file-formats) are easy to extend!
12 | - Manual data splits can be defined as described [here](manual_data_splits.md)
13 | - You can chain arbitrary configurations together into cascades, see [this again](explanation_plans_files.md)
14 | - Read about our support for [region-based training](region_based_training.md)
15 | - If you intend to modify the training procedure (loss, sampling, data augmentation, lr scheduler, etc.) then you need
16 | to implement your own trainer class. Best practice is to create a class that inherits from nnUNetTrainer and
17 | implements the necessary changes. Head over to our [trainer classes folder](../longiseg/training/LongiSegTrainer) for
18 | inspiration! There will likely be trainers similar to what you intend to change, and you can use them as a guide. nnUNetTrainer
19 | classes are structured similarly to PyTorch Lightning trainers, which should also make things easier!
20 | - Integrating new network architectures can be done in two ways:
21 |   - Quick and dirty: implement a new nnUNetTrainer class and overwrite its `build_network_architecture` function.
22 |     Make sure your architecture is compatible with deep supervision (if not, use `nnUNetTrainerNoDeepSupervision`
23 |     as basis!) and that it can handle the patch sizes that are thrown at it! Your architecture should NOT apply any
24 |     nonlinearities at the end (softmax, sigmoid etc). nnU-Net does that! (A minimal sketch follows after this list.)
25 |   - The 'proper' (but difficult) way: Build a dynamically configurable architecture such as the `PlainConvUNet` class
26 |     used by default. It needs to have some sort of GPU memory estimation method that can be used to evaluate whether
27 |     certain patch sizes and
28 |     topologies fit into a specified GPU memory target. Build a new `ExperimentPlanner` that can configure your new
29 |     class and communicate with its memory budget estimation. Run `nnUNetv2_plan_and_preprocess` while specifying your
30 |     custom `ExperimentPlanner` and a custom `plans_name`. Implement a nnUNetTrainer that can use the plans generated by
31 |     your `ExperimentPlanner` to instantiate the network architecture. Specify your plans and trainer when running `nnUNetv2_train`.
32 |     It always pays off to first read and understand the corresponding nnU-Net code and use it as a template for your implementation!
33 | - Remember that multi-GPU training, region-based training, ignore label and cascaded training are now simply integrated
34 | into one unified nnUNetTrainer class. No separate classes needed (keep that in mind when implementing your own trainer
35 | classes, and ensure support for all of these features or raise `NotImplementedError`)!
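As an illustration of the 'quick and dirty' route, here is a minimal sketch. Note the assumptions: the class names
`TinySegNet`/`nnUNetTrainerTinySegNet`, the hardcoded channel and class counts, and the catch-all `*args, **kwargs`
signature are made up for this example; look up the exact `build_network_architecture` signature in the trainer base
class before using this pattern.

```python
from torch import nn

from longiseg.training.LongiSegTrainer.nnUNetTrainerLongi import nnUNetTrainerNoLongi


class TinySegNet(nn.Module):
    """Toy stand-in architecture. It returns raw logits: no softmax/sigmoid at the end, nnU-Net handles that."""

    def __init__(self, in_channels: int, num_classes: int):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv3d(in_channels, 32, kernel_size=3, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.Conv3d(32, num_classes, kernel_size=1),
        )

    def forward(self, x):
        return self.layers(x)


class nnUNetTrainerTinySegNet(nnUNetTrainerNoLongi):
    @staticmethod
    def build_network_architecture(*args, **kwargs) -> nn.Module:
        # The real signature receives plans/configuration information (number of input channels, number of
        # output classes, ...); the hardcoded values below are purely illustrative. TinySegNet also returns
        # a single output, so deep supervision would have to be disabled (see nnUNetTrainerNoDeepSupervision).
        return TinySegNet(in_channels=1, num_classes=2)
```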
36 | 
37 | [//]: # (- Read about our support for [ignore label](ignore_label.md) and [region-based training](region_based_training.md))
38 | 
--------------------------------------------------------------------------------
/longiseg/training/LongiSegTrainer/variants/sampling/nnUNetTrainer_probabilisticOversampling.py:
--------------------------------------------------------------------------------
 1 | import numpy as np
 2 | import torch
 3 | from torch import distributed as dist
 4 | 
 5 | from longiseg.training.LongiSegTrainer.nnUNetTrainerLongi import nnUNetTrainerNoLongi
 6 | 
 7 | 
 8 | class nnUNetTrainer_probabilisticOversampling(nnUNetTrainerNoLongi):
 9 |     """
10 |     With this trainer, sampling of foreground happens randomly per sample instead of deterministically for the last
11 |     33% of samples in a batch (the default behavior). Since most trainings happen with batch size 2 and nnU-Net
12 |     guarantees at least one fg sample per batch, the effective default oversampling can be 50%.
13 |     Here we compute the actual oversampling percentage used by nnUNetTrainer in order to be as consistent as possible.
14 |     If we switch to this oversampling we could instead keep it at a constant 0.33 (or any other value).
15 |     """
16 |     def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,
17 |                  device: torch.device = torch.device('cuda')):
18 |         super().__init__(plans, configuration, fold, dataset_json, device)
19 |         self.probabilistic_oversampling = True
20 |         self.oversample_foreground_percent = float(np.mean(
21 |             [not sample_idx < round(self.configuration_manager.batch_size * (1 - self.oversample_foreground_percent))
22 |              for sample_idx in range(self.configuration_manager.batch_size)]))
23 |         self.print_to_log_file(f"self.oversample_foreground_percent {self.oversample_foreground_percent}")
24 | 
25 |     def _set_batch_size_and_oversample(self):
26 |         if not self.is_ddp:
27 |             # set batch size to what the plan says, leave oversample untouched
28 |             self.batch_size = self.configuration_manager.batch_size
29 |         else:
30 |             # batch size is distributed over DDP workers and we need to change oversample_percent for each worker
31 | 
32 |             world_size = dist.get_world_size()
33 |             my_rank = dist.get_rank()
34 | 
35 |             global_batch_size = self.configuration_manager.batch_size
36 |             assert global_batch_size >= world_size, 'Cannot run DDP if the batch size is smaller than the number of ' \
37 |                                                     'GPUs... Duh.'
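            # Worked example for the split computed below (illustration only): with global_batch_size = 5 and
            # world_size = 2, every worker starts at 5 // 2 = 2 samples; the condition (2 * 2 + i) < 5 then grants
            # worker 0 one extra sample, yielding [3, 2], which sums to the global batch size again.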
38 | 
39 |             batch_size_per_GPU = [global_batch_size // world_size] * world_size
40 |             batch_size_per_GPU = [batch_size_per_GPU[i] + 1
41 |                                   if (batch_size_per_GPU[i] * world_size + i) < global_batch_size
42 |                                   else batch_size_per_GPU[i]
43 |                                   for i in range(len(batch_size_per_GPU))]
44 |             assert sum(batch_size_per_GPU) == global_batch_size
45 |             print("worker", my_rank, "batch_size", batch_size_per_GPU[my_rank])
46 |             print("worker", my_rank, "oversample", self.oversample_foreground_percent)
47 | 
48 |             self.batch_size = batch_size_per_GPU[my_rank]
49 | 
50 | 
51 | class nnUNetTrainer_probabilisticOversampling_033(nnUNetTrainer_probabilisticOversampling):
52 |     def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,
53 |                  device: torch.device = torch.device('cuda')):
54 |         super().__init__(plans, configuration, fold, dataset_json, device)
55 |         self.oversample_foreground_percent = 0.33
56 | 
57 | 
58 | class nnUNetTrainer_probabilisticOversampling_010(nnUNetTrainer_probabilisticOversampling):
59 |     def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,
60 |                  device: torch.device = torch.device('cuda')):
61 |         super().__init__(plans, configuration, fold, dataset_json, device)
62 |         self.oversample_foreground_percent = 0.1
63 | 
--------------------------------------------------------------------------------
/longiseg/model_sharing/entry_points.py:
--------------------------------------------------------------------------------
 1 | from longiseg.model_sharing.model_download import download_and_install_from_url
 2 | from longiseg.model_sharing.model_export import export_pretrained_model
 3 | from longiseg.model_sharing.model_import import install_model_from_zip_file
 4 | 
 5 | 
 6 | def print_license_warning():
 7 |     print('')
 8 |     print('######################################################')
 9 |     print('!!!!!!!!!!!!!!!!!!!!!!!!WARNING!!!!!!!!!!!!!!!!!!!!!!!')
10 |     print('######################################################')
11 |     print("Using the pretrained model weights is subject to the license of the dataset they were trained on. Some "
12 |           "allow commercial use, others don't. It is your responsibility to make sure you use them appropriately! Use "
13 |           "nnUNet_print_pretrained_model_info(task_name) to see a summary of the dataset and where to find its license!")
14 |     print('######################################################')
15 |     print('')
16 | 
17 | 
18 | def download_by_url():
19 |     import argparse
20 |     parser = argparse.ArgumentParser(
21 |         description="Use this to download pretrained models. This script is intended to download models via url only. "
22 |                     "CAREFUL: This script will overwrite "
23 |                     "existing models (if they share the same trainer class and plans as "
24 |                     "the pretrained model).")
25 |     parser.add_argument("url", type=str, help='URL of the pretrained model')
26 |     args = parser.parse_args()
27 |     url = args.url
28 |     download_and_install_from_url(url)
29 | 
30 | 
31 | def install_from_zip_entry_point():
32 |     import argparse
33 |     parser = argparse.ArgumentParser(
34 |         description="Use this to install a zip file containing a pretrained model.")
35 |     parser.add_argument("zip", type=str, help='zip file')
36 |     args = parser.parse_args()
37 |     zip = args.zip
38 |     install_model_from_zip_file(zip)
39 | 
40 | 
41 | def export_pretrained_model_entry():
42 |     import argparse
43 |     parser = argparse.ArgumentParser(
44 |         description="Use this to export a trained model as a zip file.")
45 |     parser.add_argument('-d', type=str, required=True, help='Dataset name or id')
46 |     parser.add_argument('-o', type=str, required=True, help='Output file name')
47 |     parser.add_argument('-c', nargs='+', type=str, required=False,
48 |                         default=('3d_lowres', '3d_fullres', '2d', '3d_cascade_fullres'),
49 |                         help="List of configuration names")
50 |     parser.add_argument('-tr', required=False, type=str, default='nnUNetTrainer', help='Trainer class')
51 |     parser.add_argument('-p', required=False, type=str, default='nnUNetPlans', help='plans identifier')
52 |     parser.add_argument('-f', required=False, nargs='+', type=str, default=(0, 1, 2, 3, 4), help='list of fold ids')
53 |     parser.add_argument('-chk', required=False, nargs='+', type=str, default=('checkpoint_final.pth', ),
54 |                         help='List of checkpoint names to export. Default: checkpoint_final.pth')
55 |     parser.add_argument('--not_strict', action='store_true', default=False, required=False, help='Set this to allow missing folds and/or configurations')
56 |     parser.add_argument('--exp_cv_preds', action='store_true', required=False, help='Set this to export the cross-validation predictions as well')
57 |     args = parser.parse_args()
58 | 
59 |     export_pretrained_model(dataset_name_or_id=args.d, output_file=args.o, configurations=args.c, trainer=args.tr,
60 |                             plans_identifier=args.p, folds=args.f, strict=not args.not_strict, save_checkpoints=args.chk,
61 |                             export_crossval_predictions=args.exp_cv_preds)
62 | 
--------------------------------------------------------------------------------
/longiseg/training/data_augmentation/custom_transforms/longi_transforms.py:
--------------------------------------------------------------------------------
 1 | from typing import Union, List, Tuple
 2 | 
 3 | from batchgeneratorsv2.transforms.base.basic_transform import BasicTransform
 4 | from batchgeneratorsv2.transforms.utils.deep_supervision_downsampling import DownsampleSegForDSTransform
 5 | import torch
 6 | 
 7 | 
 8 | class MergeTransform(BasicTransform):
 9 |     def apply(self, data_dict, **params):
10 |         if data_dict.get('image_current') is not None and data_dict.get('image_prior') is not None:
11 |             data_dict['image'] = self._apply_to_tensor(data_dict['image_current'], data_dict['image_prior'], **params)
12 |             del data_dict['image_current'], data_dict['image_prior']
13 |         else:
14 |             raise RuntimeError("MergeTransform requires 'image_current' and 'image_prior' in data_dict")
15 |         if data_dict.get('segmentation_current') is not None and data_dict.get('segmentation_prior') is not None:
16 |             data_dict['segmentation'] = self._apply_to_tensor(data_dict['segmentation_current'], data_dict['segmentation_prior'], **params)
17 |             del data_dict['segmentation_current'], data_dict['segmentation_prior']
18 |         else:
19 |             raise RuntimeError("MergeTransform requires 'segmentation_current' and 'segmentation_prior' in data_dict")
20 |         return data_dict
21 | 
22 |     def _apply_to_tensor(self, current: torch.Tensor, prior: torch.Tensor, **params) -> torch.Tensor:
23 |         merged = torch.cat([current, prior], dim=0)
24 |         return merged
25 | 
26 | 
27 | class SplitTransform(BasicTransform):
28 |     def apply(self, data_dict, **params):
29 |         if data_dict.get('image') is not None:
30 |             data_dict['image_current'], data_dict['image_prior'] = self._apply_to_tensor(data_dict['image'], **params)
31 |             del data_dict['image']
32 |         else:
33 |             raise RuntimeError("SplitTransform requires 'image' in data_dict")
34 |         if data_dict.get('segmentation') is not None:
35 |             data_dict['segmentation_current'], data_dict['segmentation_prior'] = self._apply_to_tensor(data_dict['segmentation'], **params)
36 |             del data_dict['segmentation']
37 |         else:
38 |             raise RuntimeError("SplitTransform requires 'segmentation' in data_dict")
39 |         return data_dict
40 | 
41 |     def _apply_to_tensor(self, tensor: torch.Tensor, **params) -> Tuple[torch.Tensor, torch.Tensor]:
42 |         channels = tensor.shape[0]
43 |         current = tensor[:channels//2]
44 |         prior = tensor[channels//2:]
45 |         return current, prior
46 | 
47 | 
48 | class ConvertSegToOneHot(BasicTransform):
49 |     def __init__(self, foreground_labels: Union[Tuple[int, ...], List[int]], key: str = "segmentation_prior",
50 |                  dtype_key: str = "image_prior"):
51 |         super().__init__()
52 |         self.foreground_labels = foreground_labels
53 |         self.key = key
54 |         self.dtype_key = dtype_key
55 | 
56 |     def apply(self, data_dict, **params):
57 |         seg = data_dict[self.key]
58 |         seg_onehot = torch.zeros((len(self.foreground_labels), *seg.shape), dtype=data_dict[self.dtype_key].dtype)
59 |         for i, l in enumerate(self.foreground_labels):
60 |             seg_onehot[i][seg == l] = 1
61 |         data_dict[self.key] = seg_onehot.squeeze_(1)
62 |         return data_dict
63 | 
64 | 
65 | class DownsampleSegForDSTransformLongi(DownsampleSegForDSTransform):
66 |     def __init__(self, ds_scales: Union[List, Tuple], key: str = "segmentation_current"):
67 |         super().__init__(ds_scales)
68 |         self.key = key
69 | 
70 |     def apply(self, data_dict, **params):
71 |         data_dict[self.key] = self._apply_to_segmentation(data_dict[self.key], **params)
72 |         return data_dict
--------------------------------------------------------------------------------
/documentation/pretraining_and_finetuning.md:
--------------------------------------------------------------------------------
 1 | # Pretraining with nnU-Net
 2 | 
 3 | ## Intro
 4 | 
 5 | So far nnU-Net only supports supervised pre-training, meaning that you train a regular nnU-Net on some pretraining dataset
 6 | and then use the final network weights as initialization for your target dataset.
 7 | 
 8 | As a reminder, many training hyperparameters such as patch size and network topology differ between datasets as a
 9 | result of the automated dataset analysis and experiment planning nnU-Net is known for. So, out of the box, it is not
10 | possible to simply take the network weights from some dataset and then reuse them for another.
11 | 
12 | Consequently, the plans need to be aligned between the two tasks. In this README we show how this can be achieved and
13 | how the resulting weights can then be used for initialization.
14 | 
15 | ### Terminology
16 | 
17 | Throughout this README we use the following terminology:
18 | 
19 | - `pretraining dataset` is the dataset you intend to run the pretraining on
20 | - `finetuning dataset` is the dataset you are interested in; the one you wish to fine-tune on
21 | 
22 | 
23 | ## Training on the pretraining dataset
24 | 
25 | In order to obtain matching network topologies we need to transfer the plans from one dataset to another. Since we are
26 | only interested in the finetuning dataset, we first need to run experiment planning (and preprocessing) for it:
27 | 
28 | ```bash
29 | nnUNetv2_plan_and_preprocess -d FINETUNING_DATASET
30 | ```
31 | 
32 | Then we need to extract the dataset fingerprint of the pretraining dataset, if not yet available:
33 | 
34 | ```bash
35 | nnUNetv2_extract_fingerprint -d PRETRAINING_DATASET
36 | ```
37 | 
38 | Now we can take the plans from the finetuning dataset and transfer them to the pretraining dataset:
39 | 
40 | ```bash
41 | nnUNetv2_move_plans_between_datasets -s FINETUNING_DATASET -t PRETRAINING_DATASET -sp FINETUNING_PLANS_IDENTIFIER -tp PRETRAINING_PLANS_IDENTIFIER
42 | ```
43 | 
44 | `FINETUNING_PLANS_IDENTIFIER` will typically be nnUNetPlans unless you changed the experiment planner in
45 | nnUNetv2_plan_and_preprocess. For `PRETRAINING_PLANS_IDENTIFIER` we recommend you set something custom in order to not
46 | overwrite default plans.
47 | 
48 | Note that EVERYTHING is transferred between the datasets. Not just the network topology, batch size and patch size but
49 | also the normalization scheme! Therefore, a transfer between datasets that use different normalization schemes may not
50 | work well (but it could, depending on the schemes!).
51 | 
52 | Note on CT normalization: Yes, also the clip values, mean and std are transferred!
53 | 
54 | Now you can run the preprocessing on the pretraining dataset:
55 | 
56 | ```bash
57 | nnUNetv2_preprocess -d PRETRAINING_DATASET -plans_name PRETRAINING_PLANS_IDENTIFIER
58 | ```
59 | 
60 | And run the training as usual:
61 | 
62 | ```bash
63 | nnUNetv2_train PRETRAINING_DATASET CONFIG all -p PRETRAINING_PLANS_IDENTIFIER
64 | ```
65 | 
66 | Note how we use the 'all' fold to train on all available data. For pretraining it does not make sense to split the data.
67 | 
68 | ## Using pretrained weights
69 | 
70 | Once pretraining is completed (or you obtain compatible weights by other means) you can use them to initialize your model:
71 | 
72 | ```bash
73 | nnUNetv2_train FINETUNING_DATASET CONFIG FOLD -pretrained_weights PATH_TO_CHECKPOINT
74 | ```
75 | 
76 | `PATH_TO_CHECKPOINT` must point to the checkpoint file of the pretraining run (e.g. `checkpoint_final.pth`).
77 | 
78 | When loading pretrained weights, all layers except the segmentation layers will be used!
79 | 
80 | So far there are no specific nnUNet trainers for fine tuning, so the current recommendation is to just use
81 | nnUNetTrainer. You can however easily write your own trainers with learning rate ramp up, fine-tuning of segmentation
82 | heads or shorter training time; a minimal sketch of such a trainer is shown below.
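As an illustration, here is a minimal sketch of such a fine-tuning trainer. The class name and the concrete values
are made up for this example; it assumes the trainer interface and the `initial_lr`/`num_epochs` attributes used by
the trainer variants elsewhere in this repository:

```python
import torch

from longiseg.training.LongiSegTrainer.nnUNetTrainerLongi import nnUNetTrainerNoLongi


class nnUNetTrainer_finetuning(nnUNetTrainerNoLongi):
    def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,
                 device: torch.device = torch.device('cuda')):
        super().__init__(plans, configuration, fold, dataset_json, device)
        # illustrative values: a smaller learning rate avoids destroying the pretrained features,
        # and fine-tuning typically needs fewer epochs than training from scratch
        self.initial_lr = 1e-3
        self.num_epochs = 250
```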
--------------------------------------------------------------------------------
/longiseg/training/LongiSegTrainer/variants/loss/nnUNetTrainerTopkLoss.py:
--------------------------------------------------------------------------------
 1 | from longiseg.training.loss.compound_losses import DC_and_topk_loss
 2 | from longiseg.training.loss.deep_supervision import DeepSupervisionWrapper
 3 | from longiseg.training.LongiSegTrainer.nnUNetTrainerLongi import nnUNetTrainerNoLongi
 4 | import numpy as np
 5 | from longiseg.training.loss.robust_ce_loss import TopKLoss
 6 | 
 7 | 
 8 | class nnUNetTrainerTopk10Loss(nnUNetTrainerNoLongi):
 9 |     def _build_loss(self):
10 |         assert not self.label_manager.has_regions, "regions not supported by this trainer"
11 |         loss = TopKLoss(
12 |             ignore_index=self.label_manager.ignore_label if self.label_manager.has_ignore_label else -100, k=10
13 |         )
14 | 
15 |         if self.enable_deep_supervision:
16 |             deep_supervision_scales = self._get_deep_supervision_scales()
17 | 
18 |             # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases
19 |             # this gives higher resolution outputs more weight in the loss
20 |             weights = np.array([1 / (2**i) for i in range(len(deep_supervision_scales))])
21 |             weights[-1] = 0
22 | 
23 |             # we don't use the lowest 2 outputs (the lowest resolution is excluded from the deep supervision scales and the weight of the next one is set to 0 above). Normalize weights so that they sum to 1
24 |             weights = weights / weights.sum()
25 |             # now wrap the loss
26 |             loss = DeepSupervisionWrapper(loss, weights)
27 |         return loss
28 | 
29 | 
30 | class nnUNetTrainerTopk10LossLS01(nnUNetTrainerNoLongi):
31 |     def _build_loss(self):
32 |         assert not self.label_manager.has_regions, "regions not supported by this trainer"
33 |         loss = TopKLoss(
34 |             ignore_index=self.label_manager.ignore_label if self.label_manager.has_ignore_label else -100,
35 |             k=10,
36 |             label_smoothing=0.1,
37 |         )
38 | 
39 |         if self.enable_deep_supervision:
40 |             deep_supervision_scales = self._get_deep_supervision_scales()
41 | 
42 |             # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases
43 |             # this gives higher resolution outputs more weight in the loss
44 |             weights = np.array([1 / (2**i) for i in range(len(deep_supervision_scales))])
45 |             weights[-1] = 0
46 | 
47 |             # we don't use the lowest 2 outputs (the lowest resolution is excluded from the deep supervision scales and the weight of the next one is set to 0 above). Normalize weights so that they sum to 1
48 |             weights = weights / weights.sum()
49 |             # now wrap the loss
50 |             loss = DeepSupervisionWrapper(loss, weights)
51 |         return loss
52 | 
53 | 
54 | class nnUNetTrainerDiceTopK10Loss(nnUNetTrainerNoLongi):
55 |     def _build_loss(self):
56 |         assert not self.label_manager.has_regions, "regions not supported by this trainer"
57 |         loss = DC_and_topk_loss(
58 |             {"batch_dice": self.configuration_manager.batch_dice, "smooth": 1e-5, "do_bg": False, "ddp": self.is_ddp},
59 |             {"k": 10, "label_smoothing": 0.0},
60 |             weight_ce=1,
61 |             weight_dice=1,
62 |             ignore_label=self.label_manager.ignore_label,
63 |         )
64 |         if self.enable_deep_supervision:
65 |             deep_supervision_scales = self._get_deep_supervision_scales()
66 | 
67 |             # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases
68 |             # this gives higher resolution outputs more weight in the loss
69 |             weights = np.array([1 / (2**i) for i in range(len(deep_supervision_scales))])
70 |             weights[-1] = 0
71 | 
72 |             # we don't use the lowest 2 outputs (the lowest resolution is excluded from the deep supervision scales and the weight of the next one is set to 0 above). Normalize weights so that they sum to 1
73 |             weights = weights / weights.sum()
74 |             # now wrap the loss
75 |             loss = DeepSupervisionWrapper(loss, weights)
76 |         return loss
77 | 
--------------------------------------------------------------------------------
/documentation/how_to_use_longiseg.md:
--------------------------------------------------------------------------------
 1 | # How to use LongiSeg
 2 | LongiSeg inherits nnU-Net's self-configuring capabilities, allowing it to easily adapt to new datasets. It extends nnU-Net to handle medical image time series, leveraging temporal relationships to improve segmentation accuracy.
 3 | 
 4 | ## Dataset Format
 5 | LongiSeg expects the same structured [dataset format](dataset_format.md) as nnU-Net. Datasets can either be saved in the same
 6 | folders as nnU-Net or in LongiSeg's own folders (`LongiSeg_raw`, `LongiSeg_preprocessed` and `LongiSeg_results`).
 7 | In contrast to nnU-Net, LongiSeg expects an additional `patientsTr.json` file in the dataset folder. This file lists patient IDs and their corresponding scans in chronological order.
 8 | 
 9 |     Dataset001_BrainTumour/
10 |     ├── dataset.json
11 |     ├── patientsTr.json
12 |     ├── imagesTr
13 |     ├── imagesTs  # optional
14 |     └── labelsTr
15 | 
16 | This json file should have the following structure:
17 | 
18 |     {
19 |         "patient_1": [
20 |             "patient_1_scan_1",
21 |             "patient_1_scan_2",
22 |             ...
23 |         ],
24 |         "patient_2": [
25 |             "patient_2_scan_1",
26 |             "patient_2_scan_2",
27 |             ...
28 |         ],
29 |         ...
30 |     }
31 | 
32 | ## Experiment planning and preprocessing
33 | To run experiment planning and preprocessing of a new dataset, simply run
34 | ```bash
35 | LongiSeg_plan_and_preprocess -d DATASET_ID
36 | ```
37 | or if you prefer to keep things separate, you can also use `LongiSeg_extract_fingerprint`, `LongiSeg_plan_experiment`
38 | and `LongiSeg_preprocess` (in that order). We refer to the [nnU-Net documentation](how_to_use_nnunet.md#experiment-planning-and-preprocessing) for additional details on experiment planning and preprocessing.
39 | 
40 | ## Training
41 | To train a model using LongiSeg, simply run
42 | ```bash
43 | LongiSeg_train DATASET_NAME_OR_ID UNET_CONFIGURATION FOLD
44 | ```
45 | 
46 | By default, LongiSeg uses the `LongiSegTrainer`, which integrates the temporal dimension by concatenating multi-timepoint images as additional input channels. Other trainers are available with the `-tr` option:
47 | 
48 | - `nnUNetTrainerNoLongi`: A modified `nnUNetTrainer` where training data is split **at the patient level** instead of per scan. (*non-longitudinal baseline*)
49 | - `LongiSegTrainerDiffWeighting`: A longitudinal trainer that incorporates the **Difference Weighting Block** for temporal feature fusion.
50 | - `LongiSegTrainerRP`, `LongiSegTrainerDiffWeightingRP`: longitudinal trainers with a randomly sampled instead of a fixed prior scan of the same patient (cf. [longi_dataset](../longiseg/training/dataloading/longi_dataset.py#L149-L153))
51 | 
52 | Other options for training are available as well (`LongiSeg_train -h`).
53 | 
54 | ## Inference
55 | Inference with LongiSeg works in a similar way to the [nnU-Net inference](how_to_use_nnunet.md#run-inference), with the added requirement of specifying a patient file (`-pat`) in either the `LongiSeg_predict` or `LongiSeg_predict_from_modelfolder` commands. The patient file needs to detail the patient structure in the same way as during training. Only cases present in both the input folder **and** `patients.json` will be processed during inference!
56 | ```bash
57 | LongiSeg_predict -i INPUT_FOLDER -o OUTPUT_FOLDER -pat /path/to/patients.json -d DATASET_ID
58 | ```
59 | 
60 | ## Evaluation
61 | By default LongiSeg performs nnU-Net's standard evaluation on the 5-fold cross-validation. This, however, does not account for individual patients and only calculates a handful of metrics. LongiSeg extends nnU-Net's evaluation by incorporating additional metrics that provide a more comprehensive assessment of segmentation performance, including volumetric, surface-based, distance-based, and detection metrics.
62 | To run evaluation with LongiSeg, use
63 | ```bash
64 | LongiSeg_evaluate_folder GT_FOLDER PRED_FOLDER -djfile /path/to/dataset.json -pfile /path/to/plans.json -patfile /path/to/patients.json
65 | ```
66 | 
67 | If no patient file is provided, LongiSeg will default to standard nnU-Net evaluation while incorporating the additional metrics.
--------------------------------------------------------------------------------
/longiseg/imageio/reader_writer_registry.py:
--------------------------------------------------------------------------------
 1 | import traceback
 2 | from typing import Type
 3 | 
 4 | from batchgenerators.utilities.file_and_folder_operations import join
 5 | 
 6 | import longiseg
 7 | from longiseg.imageio.natural_image_reader_writer import NaturalImage2DIO
 8 | from longiseg.imageio.nibabel_reader_writer import NibabelIO, NibabelIOWithReorient
 9 | from longiseg.imageio.simpleitk_reader_writer import SimpleITKIO, SimpleITKIOWithReorient
10 | from longiseg.imageio.tif_reader_writer import Tiff3DIO
11 | from longiseg.imageio.base_reader_writer import BaseReaderWriter
12 | from longiseg.utilities.find_class_by_name import recursive_find_python_class
13 | 
14 | LIST_OF_IO_CLASSES = [
15 |     NaturalImage2DIO,
16 |     SimpleITKIOWithReorient,
17 |     SimpleITKIO,
18 |     Tiff3DIO,
19 |     NibabelIO,
20 |     NibabelIOWithReorient
21 | ]
22 | 
23 | 
24 | def determine_reader_writer_from_dataset_json(dataset_json_content: dict, example_file: str = None,
25 |                                               allow_nonmatching_filename: bool = False, verbose: bool = True
26 |                                               ) -> Type[BaseReaderWriter]:
27 |     if 'overwrite_image_reader_writer' in dataset_json_content.keys() and \
28 |             dataset_json_content['overwrite_image_reader_writer'] != 'None':
29 |         ioclass_name = dataset_json_content['overwrite_image_reader_writer']
30 |         # trying to find that class in the longiseg.imageio module
31 |         try:
32 |             ret = recursive_find_reader_writer_by_name(ioclass_name)
33 |             if verbose: print(f'Using {ret} reader/writer')
34 |             return ret
35 |         except RuntimeError:
36 |             if verbose: print(f'Warning: Unable to find ioclass specified in dataset.json: {ioclass_name}')
37 |             if verbose: print('Trying to automatically determine desired class')
38 |     return determine_reader_writer_from_file_ending(dataset_json_content['file_ending'], example_file,
39 |                                                     allow_nonmatching_filename, verbose)
40 | 
41 | 
42 | def determine_reader_writer_from_file_ending(file_ending: str, example_file: str = None, allow_nonmatching_filename: bool = False,
43 |                                              verbose: bool = True):
44 |     for rw in LIST_OF_IO_CLASSES:
45 |         if file_ending.lower() in rw.supported_file_endings:
46 |             if example_file is not None:
47 |                 # if an example file is provided, check whether it can actually be read; if not, move on to the next reader
48 |                 try:
49 |                     tmp = rw()
50 |                     _ = tmp.read_images((example_file,))
51 |                     if verbose: print(f'Using {rw} as reader/writer')
52 |                     return rw
53 |                 except Exception:
54 |                     if verbose: print(f'Failed to open file {example_file} with reader {rw}:')
55 |                     traceback.print_exc()
56 |                     pass
57 |             else:
58 |                 if verbose: print(f'Using {rw} as reader/writer')
59 |                 return rw
60 |         else:
61 |             if allow_nonmatching_filename and example_file is not None:
62 |                 try:
63 |                     tmp = rw()
64 |                     _ = tmp.read_images((example_file,))
65 |                     if verbose: print(f'Using {rw} as reader/writer')
66 |                     return rw
67 |                 except Exception:
68 |                     if verbose: print(f'Failed to open file {example_file} with reader {rw}:')
69 |                     if verbose: traceback.print_exc()
70 |                     pass
71 |     raise RuntimeError(f"Unable to determine a reader for file ending {file_ending} and file {example_file} (None means no example file was provided).")
72 | 
73 | 
74 | def recursive_find_reader_writer_by_name(rw_class_name: str) -> Type[BaseReaderWriter]:
75 |     ret = recursive_find_python_class(join(longiseg.__path__[0], "imageio"), rw_class_name, 'longiseg.imageio')
76 |     if ret is None:
77 |         raise RuntimeError("Unable to find reader writer class '%s'. Please make sure this class is located in the "
78 |                            "longiseg.imageio module." % rw_class_name)
79 |     else:
80 |         return ret
81 | 
--------------------------------------------------------------------------------
/longiseg/utilities/dataset_name_id_conversion.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #     http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from typing import Union
15 | 
16 | import os
17 | import numpy as np
18 | 
19 | from longiseg.paths import LongiSeg_preprocessed, LongiSeg_raw, LongiSeg_results
20 | from batchgenerators.utilities.file_and_folder_operations import subdirs, isdir
21 | 
22 | 
23 | def find_candidate_datasets(dataset_id: int):
24 |     startswith = "Dataset%03.0d" % dataset_id
25 |     if LongiSeg_preprocessed is not None and isdir(LongiSeg_preprocessed):
26 |         candidates_preprocessed = subdirs(LongiSeg_preprocessed, prefix=startswith, join=False)
27 |     else:
28 |         candidates_preprocessed = []
29 | 
30 |     if LongiSeg_raw is not None and isdir(LongiSeg_raw):
31 |         candidates_raw = subdirs(LongiSeg_raw, prefix=startswith, join=False)
32 |     else:
33 |         candidates_raw = []
34 | 
35 |     candidates_trained_models = []
36 |     if LongiSeg_results is not None and isdir(LongiSeg_results):
37 |         candidates_trained_models += subdirs(LongiSeg_results, prefix=startswith, join=False)
38 | 
39 |     all_candidates = candidates_preprocessed + candidates_raw + candidates_trained_models
40 |     unique_candidates = np.unique(all_candidates)
41 |     return unique_candidates
42 | 
43 | 
44 | def convert_id_to_dataset_name(dataset_id: int):
45 |     unique_candidates = find_candidate_datasets(dataset_id)
46 |     if len(unique_candidates) > 1:
47 |         raise RuntimeError("More than one dataset name found for dataset id %d. Please correct that. (I looked in the "
48 |                            "following folders:\n%s\n%s\n%s)" % (dataset_id, LongiSeg_raw, LongiSeg_preprocessed, LongiSeg_results))
49 |     if len(unique_candidates) == 0:
50 |         raise RuntimeError(f"Could not find a dataset with the ID {dataset_id}. Make sure the requested dataset ID "
51 |                            f"exists and that nnU-Net knows where raw and preprocessed data are located "
52 |                            f"(see Documentation - Installation). Here are your currently defined folders:\n"
53 |                            f"LongiSeg_preprocessed={os.environ.get('LongiSeg_preprocessed') if os.environ.get('LongiSeg_preprocessed') is not None else 'None'}\n"
54 |                            f"LongiSeg_results={os.environ.get('LongiSeg_results') if os.environ.get('LongiSeg_results') is not None else 'None'}\n"
55 |                            f"LongiSeg_raw={os.environ.get('LongiSeg_raw') if os.environ.get('LongiSeg_raw') is not None else 'None'}\n"
56 |                            f"If something is not right, adapt your environment variables.")
57 |     return unique_candidates[0]
58 | 
59 | 
60 | def convert_dataset_name_to_id(dataset_name: str):
61 |     assert dataset_name.startswith("Dataset")
62 |     dataset_id = int(dataset_name[7:10])
63 |     return dataset_id
64 | 
65 | 
66 | def maybe_convert_to_dataset_name(dataset_name_or_id: Union[int, str]) -> str:
67 |     if isinstance(dataset_name_or_id, str) and dataset_name_or_id.startswith("Dataset"):
68 |         return dataset_name_or_id
69 |     if isinstance(dataset_name_or_id, str):
70 |         try:
71 |             dataset_name_or_id = int(dataset_name_or_id)
72 |         except ValueError:
73 |             raise ValueError("dataset_name_or_id was a string and did not start with 'Dataset' so we tried to "
74 |                              "convert it to a dataset ID (int). That failed, however. Please give an integer number "
75 |                              "('1', '2', etc) or a correct dataset name. 
Your input: %s" % dataset_name_or_id) 76 | return convert_id_to_dataset_name(dataset_name_or_id) 77 | -------------------------------------------------------------------------------- /longiseg/training/LongiSegTrainer/variants/training_length/nnUNetTrainer_Xepochs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from longiseg.training.LongiSegTrainer.nnUNetTrainerLongi import nnUNetTrainerNoLongi 4 | 5 | 6 | class nnUNetTrainer_5epochs(nnUNetTrainerNoLongi): 7 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, 8 | device: torch.device = torch.device('cuda')): 9 | """used for debugging plans etc""" 10 | super().__init__(plans, configuration, fold, dataset_json, device) 11 | self.num_epochs = 5 12 | 13 | 14 | class nnUNetTrainer_1epoch(nnUNetTrainerNoLongi): 15 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, 16 | device: torch.device = torch.device('cuda')): 17 | """used for debugging plans etc""" 18 | super().__init__(plans, configuration, fold, dataset_json, device) 19 | self.num_epochs = 1 20 | 21 | 22 | class nnUNetTrainer_10epochs(nnUNetTrainerNoLongi): 23 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, 24 | device: torch.device = torch.device('cuda')): 25 | """used for debugging plans etc""" 26 | super().__init__(plans, configuration, fold, dataset_json, device) 27 | self.num_epochs = 10 28 | 29 | 30 | class nnUNetTrainer_20epochs(nnUNetTrainerNoLongi): 31 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, 32 | device: torch.device = torch.device('cuda')): 33 | super().__init__(plans, configuration, fold, dataset_json, device) 34 | self.num_epochs = 20 35 | 36 | 37 | class nnUNetTrainer_50epochs(nnUNetTrainerNoLongi): 38 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, 39 | device: torch.device = torch.device('cuda')): 40 | super().__init__(plans, configuration, fold, dataset_json, device) 41 | self.num_epochs = 50 42 | 43 | 44 | class nnUNetTrainer_100epochs(nnUNetTrainerNoLongi): 45 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, 46 | device: torch.device = torch.device('cuda')): 47 | super().__init__(plans, configuration, fold, dataset_json, device) 48 | self.num_epochs = 100 49 | 50 | 51 | class nnUNetTrainer_250epochs(nnUNetTrainerNoLongi): 52 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, 53 | device: torch.device = torch.device('cuda')): 54 | super().__init__(plans, configuration, fold, dataset_json, device) 55 | self.num_epochs = 250 56 | 57 | 58 | class nnUNetTrainer_500epochs(nnUNetTrainerNoLongi): 59 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, 60 | device: torch.device = torch.device('cuda')): 61 | super().__init__(plans, configuration, fold, dataset_json, device) 62 | self.num_epochs = 500 63 | 64 | 65 | class nnUNetTrainer_750epochs(nnUNetTrainerNoLongi): 66 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, 67 | device: torch.device = torch.device('cuda')): 68 | super().__init__(plans, configuration, fold, dataset_json, device) 69 | self.num_epochs = 750 70 | 71 | 72 | class nnUNetTrainer_2000epochs(nnUNetTrainerNoLongi): 73 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, 74 | device: torch.device = torch.device('cuda')): 75 | super().__init__(plans, 
configuration, fold, dataset_json, device) 76 | self.num_epochs = 2000 77 | 78 | 79 | class nnUNetTrainer_4000epochs(nnUNetTrainerNoLongi): 80 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, 81 | device: torch.device = torch.device('cuda')): 82 | super().__init__(plans, configuration, fold, dataset_json, device) 83 | self.num_epochs = 4000 84 | 85 | 86 | class nnUNetTrainer_8000epochs(nnUNetTrainerNoLongi): 87 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, 88 | device: torch.device = torch.device('cuda')): 89 | super().__init__(plans, configuration, fold, dataset_json, device) 90 | self.num_epochs = 8000 91 | -------------------------------------------------------------------------------- /longiseg/evaluation/metrics/detection_metrics.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import numpy as np 3 | import cc3d 4 | from skimage.morphology import dilation 5 | from scipy import sparse 6 | 7 | 8 | def get_instances(gt: np.ndarray, pred: np.ndarray, footprint: int = 0, spacing: tuple[float] = (1., 1., 1.)): 9 | if footprint: 10 | x, y, z = np.ceil(np.divide(footprint, spacing)).astype(int) 11 | struct = np.ones((x, y, z), dtype=np.uint8) 12 | dilated_gt = dilation(gt.astype(np.uint8), struct) 13 | dilated_pred = dilation(pred.astype(np.uint8), struct) 14 | else: 15 | dilated_gt = gt.astype(np.uint8) 16 | dilated_pred = pred.astype(np.uint8) 17 | gt_inst, gt_inst_num = cc3d.connected_components(dilated_gt, return_N=True) 18 | pred_inst, pred_inst_num = cc3d.connected_components(dilated_pred, return_N=True) 19 | gt_inst[(gt!=1)] = 0 20 | pred_inst[(pred!=1)] = 0 21 | return gt_inst, gt_inst_num, pred_inst, pred_inst_num 22 | 23 | 24 | def get_inst_TPFPFN(gt_inst, gt_inst_num, pred_inst, pred_inst_num): 25 | M, N = gt_inst_num, pred_inst_num 26 | if M == 0 or N == 0: 27 | return 0, 0, M, N 28 | 29 | gt_inst_flat = gt_inst.flatten() 30 | pred_inst_flat = pred_inst.flatten() 31 | 32 | gt_masks = gt_inst_flat==np.arange(1, M+1)[:, None] 33 | pred_masks = pred_inst_flat==np.arange(1, N+1)[:, None] 34 | 35 | gt_size = np.bincount(gt_inst_flat) 36 | pred_size = np.bincount(pred_inst_flat) 37 | 38 | gt_masks_sparse = [sparse.coo_matrix(mask) for mask in gt_masks] 39 | pred_masks_sparse = [sparse.coo_matrix(mask) for mask in pred_masks] 40 | 41 | iou_data = [] 42 | rows, cols = [], [] 43 | for i in range(M): 44 | for j in range(N): 45 | # Efficient intersection computation for sparse matrices 46 | gt_mask, pred_mask = gt_masks_sparse[i], pred_masks_sparse[j] 47 | intersection = len(set(gt_mask.col) & set(pred_mask.col)) 48 | if intersection > 0: 49 | union = gt_size[i + 1] + pred_size[j + 1] - intersection 50 | iou_data.append(intersection / union) 51 | rows.append(i) 52 | cols.append(j) 53 | 54 | iou_data_gt = np.array(iou_data) 55 | rows_gt = np.array(rows) 56 | cols_gt = np.array(cols) 57 | iou_data_pred = np.array(iou_data) 58 | rows_pred = np.array(rows) 59 | cols_pred = np.array(cols) 60 | 61 | TP_gt, FN = 0, 0 62 | for i in range(M): 63 | if not i in rows_gt: 64 | FN += 1 65 | continue 66 | iou_i = iou_data_gt[rows_gt == i] 67 | col_i = cols_gt[rows_gt == i] 68 | argmax_iou = np.argmax(iou_i) 69 | max_iou = iou_i[argmax_iou] 70 | if max_iou > 0.1: 71 | TP_gt += 1 72 | iou_data_gt[cols_gt==col_i[argmax_iou]] = 0 73 | else: 74 | FN += 1 75 | 76 | TP_pred, FP = 0, 0 77 | for j in range(N): 78 | if not j in cols_pred: 79 | FP += 1 80 | continue 81 | 
iou_j = iou_data_pred[cols_pred == j]
 82 |         row_j = rows_pred[cols_pred == j]
 83 |         argmax_iou = np.argmax(iou_j)
 84 |         max_iou = iou_j[argmax_iou]
 85 |         if max_iou > 0.1:
 86 |             TP_pred += 1
 87 |             iou_data_pred[rows_pred==row_j[argmax_iou]] = 0
 88 |         else:
 89 |             FP += 1
 90 | 
 91 |     return TP_gt, TP_pred, FN, FP
 92 | 
 93 | 
 94 | def compute_detection_metrics(mask_ref: np.ndarray, mask_pred: np.ndarray, footprint: int = 0,
 95 |                               spacing: tuple[float] = (1., 1., 1.), ignore_mask: Optional[np.ndarray] = None):
 96 |     gt_mask = mask_ref if ignore_mask is None else mask_ref & ~ignore_mask
 97 |     pred_mask = mask_pred if ignore_mask is None else mask_pred & ~ignore_mask
 98 | 
 99 |     gt_inst, gt_inst_num, pred_inst, pred_inst_num = get_instances(gt_mask, pred_mask, footprint=footprint, spacing=spacing)
100 |     TP_gt, TP_pred, FN, FP = get_inst_TPFPFN(gt_inst, gt_inst_num, pred_inst, pred_inst_num)
101 |     recall = TP_gt / (TP_gt + FN) if TP_gt + FN > 0 else 1
102 |     precision = TP_pred / (TP_pred + FP) if TP_pred + FP > 0 else 1
103 |     F1 = (2 * recall * precision) / (recall + precision) if recall + precision > 0 else 0
104 |     return F1, recall, precision, TP_gt, TP_pred, FP, FN
--------------------------------------------------------------------------------
/documentation/region_based_training.md:
--------------------------------------------------------------------------------
 1 | # Region-based training
 2 | 
 3 | ## What is this about?
 4 | In some segmentation tasks, most prominently the
 5 | [Brain Tumor Segmentation Challenge](http://braintumorsegmentation.org/), the target areas (based on which the metric
 6 | will be computed) are different from the labels provided in the training data. This is the case because for some
 7 | clinical applications, it is more relevant to detect the whole tumor, tumor core and enhancing tumor instead of the
 8 | individual labels (edema, necrosis and non-enhancing tumor, enhancing tumor).
 9 | 
10 | 
11 | 
12 | The figure shows an example BraTS case along with label-based representation of the task (top) and region-based
13 | representation (bottom). The challenge evaluation is done on the regions. As we have shown in our
14 | [BraTS 2018 contribution](https://arxiv.org/abs/1809.10483), directly optimizing those
15 | overlapping areas over the individual labels yields better scoring models!
16 | 
17 | ## What can nnU-Net do?
18 | nnU-Net's region-based training allows you to learn areas that are constructed by merging individual labels. For
19 | some segmentation tasks this provides a benefit, as this shifts the importance allocated to different labels during training.
20 | Most prominently, this feature can be used to represent **hierarchical classes**, for example when organs +
21 | substructures are to be segmented. Imagine a liver segmentation problem, where vessels and tumors are also to be
22 | segmented. The first target region could thus be the entire liver (including the substructures), while the remaining
23 | targets are the individual substructures.
24 | 
25 | Important: nnU-Net still requires integer label maps as input and will produce integer label maps as output!
26 | Region-based training can be used to learn overlapping labels, but there must be a way to model these overlaps
27 | for nnU-Net to work (see below how this is done).
28 | 
29 | ## How do you use it?
30 | 
31 | When declaring the labels in the `dataset.json` file, BraTS would typically look like this:
32 | 
33 | ```python
34 | ...
35 |     "labels": {
36 |         "background": 0,
37 |         "edema": 1,
38 |         "non_enhancing_and_necrosis": 2,
39 |         "enhancing_tumor": 3
40 |     },
41 | ...
42 | ```
43 | (we use different int values than the challenge because nnU-Net needs consecutive integers!)
44 | 
45 | This representation corresponds to the upper row in the figure above.
46 | 
47 | For region-based training, the labels need to be changed to the following:
48 | 
49 | ```python
50 | ...
51 |     "labels": {
52 |         "background": 0,
53 |         "whole_tumor": [1, 2, 3],
54 |         "tumor_core": [2, 3],
55 |         "enhancing_tumor": 3  # or [3]
56 |     },
57 |     "regions_class_order": [1, 2, 3],
58 | ...
59 | ```
60 | This corresponds to the bottom row in the figure above. Note how an additional entry in the dataset.json is
61 | required: `regions_class_order`. This tells nnU-Net how to convert the region representations back to an integer map.
62 | It essentially just tells nnU-Net what labels to place for which region in what order. The length of the
63 | list here needs to be the same as the number of regions (excl. background). Each element in the list corresponds
64 | to the label that is placed instead of the region into the final segmentation. Later entries will overwrite earlier ones!
65 | Concretely, for the example given here, nnU-Net
66 | will first place the label 1 (edema) where the 'whole_tumor' region was predicted, then place the label 2
67 | (non-enhancing tumor and necrosis) where the "tumor_core" was predicted and finally place the label 3 in the
68 | predicted 'enhancing_tumor' area. With each step, part of the previously set pixels
69 | will be overwritten with the new label! So when setting your `regions_class_order`, place encompassing regions
70 | (like whole tumor etc) first, followed by substructures.
71 | 
72 | **IMPORTANT** Because the conversion back to a segmentation map is sensitive to the order in which the regions are
73 | declared ("place label X in the first region"), you need to make sure that this order is not perturbed! When
74 | automatically generating the dataset.json, make sure the dictionary keys do not get sorted alphabetically! Set
75 | `sort_keys=False` in `json.dump()`!!!
76 | 
77 | nnU-Net will perform the evaluation + model selection also on the regions, not the individual labels!
78 | 
79 | That's all. Easy, huh?
--------------------------------------------------------------------------------
/longiseg/inference/export_utils.py:
--------------------------------------------------------------------------------
 1 | from typing import List, Union
 2 | import itertools
 3 | import torch
 4 | import numpy as np
 5 | 
 6 | from acvl_utils.cropping_and_padding.bounding_boxes import int_bbox
 7 | 
 8 | 
 9 | # adapted from https://github.com/MIC-DKFZ/acvl_utils/blob/master/acvl_utils/cropping_and_padding/bounding_boxes.py#L357
10 | # to work with None in bbox
11 | def insert_crop_into_image(
12 |         image: Union[torch.Tensor, np.ndarray],
13 |         crop: Union[torch.Tensor, np.ndarray],
14 |         bbox: List[List[int]]
15 | ) -> Union[torch.Tensor, np.ndarray]:
16 |     """
17 |     Inserts a cropped patch back into the original image at the position specified by bbox.
18 |     If the bounding box extends beyond the image boundaries, only the valid portions are inserted.
19 |     If the bounding box lies entirely outside the image, the original image is returned.
20 | 
21 |     Parameters:
22 |     - image: Original N-dimensional torch.Tensor or np.ndarray to which the crop will be inserted.
23 |     - crop: Cropped patch of the image to be reinserted.
May have additional dimensions compared to bbox. 24 | - bbox: List of [[dim_min, dim_max], ...] defining the bounding box for the last dimensions of the crop in the original image. 25 | 26 | Returns: 27 | - image: The original image with the crop reinserted at the specified location (modified in-place). 28 | """ 29 | # If the bounding box is None and shapes of image and crop are the same, return the crop directly 30 | if all([b is None for b in itertools.chain(*bbox)]) and image.shape==crop.shape: 31 | return crop 32 | 33 | # make sure bounding boxes are int and not uint. Otherwise we may get underflow 34 | bbox = int_bbox(bbox) 35 | 36 | # Ensure that bbox only applies to the last len(bbox) dimensions of crop and image 37 | num_dims = len(image.shape) 38 | crop_dims = len(crop.shape) 39 | bbox_dims = len(bbox) 40 | 41 | if crop_dims < bbox_dims: 42 | raise ValueError("Bounding box dimensions cannot exceed crop dimensions.") 43 | 44 | # Validate that non-cropped leading dimensions match between image and crop 45 | leading_dims = num_dims - bbox_dims 46 | if image.shape[:leading_dims] != crop.shape[:leading_dims]: 47 | raise ValueError("Leading dimensions of crop and image must match.") 48 | 49 | # Check if the bounding box lies completely outside the image bounds for each cropped dimension 50 | for i in range(bbox_dims): 51 | min_val, max_val = bbox[i] 52 | dim_idx = leading_dims + i # Corresponding dimension in the image 53 | 54 | if max_val <= 0 or min_val >= image.shape[dim_idx]: 55 | # If completely out of bounds in any dimension, return the original image 56 | return image 57 | 58 | # Prepare slices for inserting the crop into the original image 59 | image_slices = [] 60 | crop_slices = [] 61 | 62 | # Iterate over all dimensions, applying bbox only to the last len(bbox) dimensions 63 | for i in range(num_dims): 64 | if i < leading_dims: 65 | # For leading dimensions, use entire dimension (slice(None)) and validate shape 66 | image_slices.append(slice(None)) 67 | crop_slices.append(slice(None)) 68 | else: 69 | # For dimensions specified by bbox, calculate the intersection with image bounds 70 | dim_idx = i - leading_dims 71 | min_val, max_val = bbox[dim_idx] 72 | 73 | crop_start = max(0, -min_val) # Start of the crop within the valid area 74 | image_start = max(0, min_val) # Start of the image where the crop will be inserted 75 | image_end = min(max_val, image.shape[i]) # Exclude upper bound by using max_val directly 76 | 77 | # Adjusted range for insertion 78 | crop_end = crop_start + (image_end - image_start) 79 | 80 | # Append slices for both image and crop insertion ranges 81 | image_slices.append(slice(image_start, image_end)) 82 | crop_slices.append(slice(crop_start, crop_end)) 83 | 84 | # Insert the valid part of the crop back into the original image 85 | if isinstance(image, torch.Tensor): 86 | image[tuple(image_slices)] = crop[tuple(crop_slices)] 87 | elif isinstance(image, np.ndarray): 88 | image[tuple(image_slices)] = crop[tuple(crop_slices)] 89 | else: 90 | raise ValueError(f"Unsupported image type {type(image)}") 91 | 92 | return image -------------------------------------------------------------------------------- /longiseg/experiment_planning/experiment_planners/network_topology.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | import numpy as np 3 | 4 | 5 | def get_shape_must_be_divisible_by(net_numpool_per_axis): 6 | return 2 ** np.array(net_numpool_per_axis) 7 | 8 | 9 | def pad_shape(shape, 
must_be_divisible_by): 10 | """ 11 | pads shape so that it is divisible by must_be_divisible_by 12 | :param shape: 13 | :param must_be_divisible_by: 14 | :return: 15 | """ 16 | if not isinstance(must_be_divisible_by, (tuple, list, np.ndarray)): 17 | must_be_divisible_by = [must_be_divisible_by] * len(shape) 18 | else: 19 | assert len(must_be_divisible_by) == len(shape) 20 | 21 | new_shp = [shape[i] + must_be_divisible_by[i] - shape[i] % must_be_divisible_by[i] for i in range(len(shape))] 22 | 23 | for i in range(len(shape)): 24 | if shape[i] % must_be_divisible_by[i] == 0: 25 | new_shp[i] -= must_be_divisible_by[i] 26 | new_shp = np.array(new_shp).astype(int) 27 | return new_shp 28 | 29 | 30 | def get_pool_and_conv_props(spacing, patch_size, min_feature_map_size, max_numpool): 31 | """ 32 | this is the same as get_pool_and_conv_props_v2 from old nnunet 33 | 34 | :param spacing: 35 | :param patch_size: 36 | :param min_feature_map_size: min edge length of feature maps in bottleneck 37 | :param max_numpool: 38 | :return: 39 | """ 40 | # todo review this code 41 | dim = len(spacing) 42 | 43 | current_spacing = deepcopy(list(spacing)) 44 | current_size = deepcopy(list(patch_size)) 45 | 46 | pool_op_kernel_sizes = [[1] * len(spacing)] 47 | conv_kernel_sizes = [] 48 | 49 | num_pool_per_axis = [0] * dim 50 | kernel_size = [1] * dim 51 | 52 | while True: 53 | # exclude axes that we cannot pool further because of min_feature_map_size constraint 54 | valid_axes_for_pool = [i for i in range(dim) if current_size[i] >= 2*min_feature_map_size] 55 | if len(valid_axes_for_pool) < 1: 56 | break 57 | 58 | spacings_of_axes = [current_spacing[i] for i in valid_axes_for_pool] 59 | 60 | # find axis that are within factor of 2 within smallest spacing 61 | min_spacing_of_valid = min(spacings_of_axes) 62 | valid_axes_for_pool = [i for i in valid_axes_for_pool if current_spacing[i] / min_spacing_of_valid < 2] 63 | 64 | # max_numpool constraint 65 | valid_axes_for_pool = [i for i in valid_axes_for_pool if num_pool_per_axis[i] < max_numpool] 66 | 67 | if len(valid_axes_for_pool) == 1: 68 | if current_size[valid_axes_for_pool[0]] >= 3 * min_feature_map_size: 69 | pass 70 | else: 71 | break 72 | if len(valid_axes_for_pool) < 1: 73 | break 74 | 75 | # now we need to find kernel sizes 76 | # kernel sizes are initialized to 1. They are successively set to 3 when their associated axis becomes within 77 | # factor 2 of min_spacing. Once they are 3 they remain 3 78 | for d in range(dim): 79 | if kernel_size[d] == 3: 80 | continue 81 | else: 82 | if current_spacing[d] / min(current_spacing) < 2: 83 | kernel_size[d] = 3 84 | 85 | other_axes = [i for i in range(dim) if i not in valid_axes_for_pool] 86 | 87 | pool_kernel_sizes = [0] * dim 88 | for v in valid_axes_for_pool: 89 | pool_kernel_sizes[v] = 2 90 | num_pool_per_axis[v] += 1 91 | current_spacing[v] *= 2 92 | current_size[v] = np.ceil(current_size[v] / 2) 93 | for nv in other_axes: 94 | pool_kernel_sizes[nv] = 1 95 | 96 | pool_op_kernel_sizes.append(pool_kernel_sizes) 97 | conv_kernel_sizes.append(deepcopy(kernel_size)) 98 | #print(conv_kernel_sizes) 99 | 100 | must_be_divisible_by = get_shape_must_be_divisible_by(num_pool_per_axis) 101 | patch_size = pad_shape(patch_size, must_be_divisible_by) 102 | 103 | def _to_tuple(lst): 104 | return tuple(_to_tuple(i) if isinstance(i, list) else i for i in lst) 105 | 106 | # we need to add one more conv_kernel_size for the bottleneck. 
We always use 3x3(x3) conv here 107 | conv_kernel_sizes.append([3]*dim) 108 | return num_pool_per_axis, _to_tuple(pool_op_kernel_sizes), _to_tuple(conv_kernel_sizes), tuple(patch_size), must_be_divisible_by 109 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "longiseg" 3 | version = "1.0.0" 4 | requires-python = ">=3.10" 5 | description = "LongiSeg is a framework for longitudinal medical image segmentation built on nnU-Net." 6 | readme = "readme.md" 7 | license = { file = "LICENSE" } 8 | authors = [ 9 | { name = "Yannick Kirchhoff", email = "yannick.kirchhoff@dkfz-heidelberg.de"} 10 | ] 11 | classifiers = [ 12 | "Development Status :: 5 - Production/Stable", 13 | "Intended Audience :: Developers", 14 | "Intended Audience :: Science/Research", 15 | "Intended Audience :: Healthcare Industry", 16 | "Programming Language :: Python :: 3", 17 | "License :: OSI Approved :: Apache Software License", 18 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 19 | "Topic :: Scientific/Engineering :: Image Recognition", 20 | "Topic :: Scientific/Engineering :: Medical Science Apps.", 21 | ] 22 | keywords = [ 23 | 'longitudinal medical image segmentation', 24 | 'deep learning', 25 | 'image segmentation', 26 | 'semantic segmentation', 27 | 'medical image analysis', 28 | 'medical image segmentation', 29 | 'nnU-Net', 30 | 'nnunet' 31 | ] 32 | dependencies = [ 33 | "torch>=2.1.2", 34 | "acvl-utils>=0.2.3,<0.3", # 0.3 may bring breaking changes. Careful! 35 | "dynamic-network-architectures>=0.3.1,<0.4", # 0.3.1 and lower are supported, 0.4 may have breaking changes. Let's be careful here 36 | "tqdm", 37 | "dicom2nifti", 38 | "scipy", 39 | "batchgenerators>=0.25.1", 40 | "numpy>=1.24", 41 | "scikit-learn", 42 | "scikit-image>=0.19.3", 43 | "SimpleITK>=2.2.1", 44 | "pandas", 45 | "graphviz", 46 | 'tifffile', 47 | 'requests', 48 | "nibabel", 49 | "matplotlib", 50 | "seaborn", 51 | "imagecodecs", 52 | "yacs", 53 | "batchgeneratorsv2>=0.2", 54 | "einops", 55 | "blosc2>=3.0.0b1", 56 | "difference-weighting @ git+https://github.com/MIC-DKFZ/Longitudinal-Difference-Weighting.git" 57 | ] 58 | 59 | [project.urls] 60 | homepage = "https://github.com/MIC-DKFZ/LongiSeg" 61 | repository = "https://github.com/MIC-DKFZ/LongiSeg" 62 | 63 | [project.scripts] 64 | LongiSeg_find_best_configuration = "longiseg.evaluation.find_best_configuration:find_best_configuration_entry_point" 65 | LongiSeg_determine_postprocessing = "longiseg.postprocessing.remove_connected_components:entry_point_determine_postprocessing_folder" 66 | LongiSeg_apply_postprocessing = "longiseg.postprocessing.remove_connected_components:entry_point_apply_postprocessing" 67 | LongiSeg_ensemble = "longiseg.ensembling.ensemble:entry_point_ensemble_folders" 68 | LongiSeg_accumulate_crossval_results = "longiseg.evaluation.find_best_configuration:accumulate_crossval_results_entry_point" 69 | LongiSeg_plot_overlay_pngs = "longiseg.utilities.overlay_plots:entry_point_generate_overlay" 70 | LongiSeg_download_pretrained_model_by_url = "longiseg.model_sharing.entry_points:download_by_url" 71 | LongiSeg_install_pretrained_model_from_zip = "longiseg.model_sharing.entry_points:install_from_zip_entry_point" 72 | LongiSeg_export_model_to_zip = "longiseg.model_sharing.entry_points:export_pretrained_model_entry" 73 | LongiSeg_move_plans_between_datasets = 
"longiseg.experiment_planning.plans_for_pretraining.move_plans_between_datasets:entry_point_move_plans_between_datasets" 74 | LongiSeg_plan_and_preprocess = "longiseg.experiment_planning.plan_and_preprocess_longi_entrypoints:plan_and_preprocess_longi_entry" 75 | LongiSeg_extract_fingerprint = "longiseg.experiment_planning.plan_and_preprocess_longi_entrypoints:extract_fingerprint_longi_entry" 76 | LongiSeg_plan_experiment = "longiseg.experiment_planning.plan_and_preprocess_longi_entrypoints:plan_experiment_longi_entry" 77 | LongiSeg_preprocess = "longiseg.experiment_planning.plan_and_preprocess_longi_entrypoints:preprocess_longi_entry" 78 | LongiSeg_train = "longiseg.run.run_training:run_training_longi_entry" 79 | LongiSeg_predict_from_modelfolder = "longiseg.inference.predict_from_raw_data_longi:predict_longi_entry_point_modelfolder" 80 | LongiSeg_predict = "longiseg.inference.predict_from_raw_data_longi:predict_longi_entry_point" 81 | LongiSeg_evaluate_folder = "longiseg.evaluation.evaluate_predictions_longi:evaluate_longi_folder_entry_point" 82 | 83 | [project.optional-dependencies] 84 | dev = [ 85 | "black", 86 | "ruff", 87 | "pre-commit" 88 | ] 89 | 90 | [build-system] 91 | requires = ["setuptools>=67.8.0"] 92 | build-backend = "setuptools.build_meta" 93 | 94 | [tool.codespell] 95 | skip = '.git,*.pdf,*.svg' 96 | # 97 | # ignore-words-list = '' 98 | -------------------------------------------------------------------------------- /longiseg/preprocessing/normalization/default_normalization_schemes.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Type 3 | 4 | import numpy as np 5 | from numpy import number 6 | 7 | 8 | class ImageNormalization(ABC): 9 | leaves_pixels_outside_mask_at_zero_if_use_mask_for_norm_is_true = None 10 | 11 | def __init__(self, use_mask_for_norm: bool = None, intensityproperties: dict = None, 12 | target_dtype: Type[number] = np.float32): 13 | assert use_mask_for_norm is None or isinstance(use_mask_for_norm, bool) 14 | self.use_mask_for_norm = use_mask_for_norm 15 | assert isinstance(intensityproperties, dict) 16 | self.intensityproperties = intensityproperties 17 | self.target_dtype = target_dtype 18 | 19 | @abstractmethod 20 | def run(self, image: np.ndarray, seg: np.ndarray = None) -> np.ndarray: 21 | """ 22 | Image and seg must have the same shape. Seg is not always used 23 | """ 24 | pass 25 | 26 | 27 | class ZScoreNormalization(ImageNormalization): 28 | leaves_pixels_outside_mask_at_zero_if_use_mask_for_norm_is_true = True 29 | 30 | def run(self, image: np.ndarray, seg: np.ndarray = None) -> np.ndarray: 31 | """ 32 | here seg is used to store the zero valued region. The value for that region in the segmentation is -1 by 33 | default. 34 | """ 35 | image = image.astype(self.target_dtype, copy=False) 36 | if self.use_mask_for_norm is not None and self.use_mask_for_norm: 37 | # negative values in the segmentation encode the 'outside' region (think zero values around the brain as 38 | # in BraTS). We want to run the normalization only in the brain region, so we need to mask the image. 39 | # The default nnU-net sets use_mask_for_norm to True if cropping to the nonzero region substantially 40 | # reduced the image size. 
41 | mask = seg >= 0 42 | mean = image[mask].mean() 43 | std = image[mask].std() 44 | image[mask] = (image[mask] - mean) / max(std, 1e-8) 45 | else: 46 | mean = image.mean() 47 | std = image.std() 48 | image -= mean 49 | image /= max(std, 1e-8) 50 | return image 51 | 52 | 53 | class CTNormalization(ImageNormalization): 54 | leaves_pixels_outside_mask_at_zero_if_use_mask_for_norm_is_true = False 55 | 56 | def run(self, image: np.ndarray, seg: np.ndarray = None) -> np.ndarray: 57 | assert self.intensityproperties is not None, "CTNormalization requires intensity properties" 58 | mean_intensity = self.intensityproperties['mean'] 59 | std_intensity = self.intensityproperties['std'] 60 | lower_bound = self.intensityproperties['percentile_00_5'] 61 | upper_bound = self.intensityproperties['percentile_99_5'] 62 | 63 | image = image.astype(self.target_dtype, copy=False) 64 | np.clip(image, lower_bound, upper_bound, out=image) 65 | image -= mean_intensity 66 | image /= max(std_intensity, 1e-8) 67 | return image 68 | 69 | 70 | class NoNormalization(ImageNormalization): 71 | leaves_pixels_outside_mask_at_zero_if_use_mask_for_norm_is_true = False 72 | 73 | def run(self, image: np.ndarray, seg: np.ndarray = None) -> np.ndarray: 74 | return image.astype(self.target_dtype, copy=False) 75 | 76 | 77 | class RescaleTo01Normalization(ImageNormalization): 78 | leaves_pixels_outside_mask_at_zero_if_use_mask_for_norm_is_true = False 79 | 80 | def run(self, image: np.ndarray, seg: np.ndarray = None) -> np.ndarray: 81 | image = image.astype(self.target_dtype, copy=False) 82 | image -= image.min() 83 | image /= np.clip(image.max(), a_min=1e-8, a_max=None) 84 | return image 85 | 86 | 87 | class RGBTo01Normalization(ImageNormalization): 88 | leaves_pixels_outside_mask_at_zero_if_use_mask_for_norm_is_true = False 89 | 90 | def run(self, image: np.ndarray, seg: np.ndarray = None) -> np.ndarray: 91 | assert image.min() >= 0, "RGB images are uint8, but pixel values smaller than 0 were found. " \ 92 | "Your images do not seem to be RGB images" 93 | assert image.max() <= 255, "RGB images are uint8, but pixel values greater than 255 were found. " \ 94 | "Your images do not seem to be RGB images" 95 | image = image.astype(self.target_dtype, copy=False) 96 | image /= 255. 
97 | return image 98 | 99 | -------------------------------------------------------------------------------- /longiseg/experiment_planning/plans_for_pretraining/move_plans_between_datasets.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from typing import Union 3 | 4 | from batchgenerators.utilities.file_and_folder_operations import join, isdir, isfile, load_json, save_json 5 | 6 | from longiseg.imageio.reader_writer_registry import determine_reader_writer_from_dataset_json 7 | from longiseg.paths import LongiSeg_preprocessed, LongiSeg_raw 8 | from longiseg.utilities.file_path_utilities import maybe_convert_to_dataset_name 9 | 10 | from longiseg.utilities.utils import get_filenames_of_train_images_and_targets 11 | 12 | 13 | def move_plans_between_datasets( 14 | source_dataset_name_or_id: Union[int, str], 15 | target_dataset_name_or_id: Union[int, str], 16 | source_plans_identifier: str, 17 | target_plans_identifier: str = None): 18 | source_dataset_name = maybe_convert_to_dataset_name(source_dataset_name_or_id) 19 | target_dataset_name = maybe_convert_to_dataset_name(target_dataset_name_or_id) 20 | 21 | if target_plans_identifier is None: 22 | target_plans_identifier = source_plans_identifier 23 | 24 | source_folder = join(LongiSeg_preprocessed, source_dataset_name) 25 | assert isdir(source_folder), "Cannot move plans because the preprocessed directory of the source dataset is " \ 26 | "missing. Run LongiSeg_plan_and_preprocess for the source dataset first!" 27 | 28 | source_plans_file = join(source_folder, source_plans_identifier + '.json') 29 | assert isfile(source_plans_file), f"Source plans are missing. Run the corresponding experiment planning first! " \ 30 | f"Expected file: {source_plans_file}" 31 | 32 | source_plans = load_json(source_plans_file) 33 | source_plans['dataset_name'] = target_dataset_name 34 | 35 | # we need to change data_identifier to use target_plans_identifier 36 | if target_plans_identifier != source_plans_identifier: 37 | for c in source_plans['configurations'].keys(): 38 | if 'data_identifier' in source_plans['configurations'][c].keys(): 39 | old_identifier = source_plans['configurations'][c]["data_identifier"] 40 | if old_identifier.startswith(source_plans_identifier): 41 | new_identifier = target_plans_identifier + old_identifier[len(source_plans_identifier):] 42 | else: 43 | new_identifier = target_plans_identifier + '_' + old_identifier 44 | source_plans['configurations'][c]["data_identifier"] = new_identifier 45 | 46 | # the image reader/writer class may also need to change to match the target dataset 
47 | target_raw_data_dir = join(LongiSeg_raw, target_dataset_name) 48 | target_dataset_json = load_json(join(target_raw_data_dir, 'dataset.json')) 49 | 50 | 51 | # pick any training image from the target dataset to determine the appropriate reader/writer 52 | dataset = get_filenames_of_train_images_and_targets(target_raw_data_dir, target_dataset_json) 53 | example_image = dataset[next(iter(dataset.keys()))]['images'][0] 54 | rw = determine_reader_writer_from_dataset_json(target_dataset_json, example_image, allow_nonmatching_filename=True, 55 | verbose=False) 56 | 57 | source_plans["image_reader_writer"] = rw.__name__ 58 | # target_plans_identifier is guaranteed to be set at this point (it defaults to the source identifier) 59 | source_plans["plans_name"] = target_plans_identifier 60 | 61 | save_json(source_plans, join(LongiSeg_preprocessed, target_dataset_name, target_plans_identifier + '.json'), 62 | sort_keys=False) 63 | 64 | 65 | def entry_point_move_plans_between_datasets(): 66 | parser = argparse.ArgumentParser() 67 | parser.add_argument('-s', type=str, required=True, 68 | help='Source dataset name or id') 69 | parser.add_argument('-t', type=str, required=True, 70 | help='Target dataset name or id') 71 | parser.add_argument('-sp', type=str, required=True, 72 | help='Source plans identifier. If your plans are named "nnUNetPlans.json" then the ' 73 | 'identifier would be nnUNetPlans') 74 | parser.add_argument('-tp', type=str, required=False, default=None, 75 | help='Target plans identifier. Default is None meaning the source plans identifier will ' 76 | 'be kept. Not recommended if the source plans identifier is a default nnU-Net identifier ' 77 | 'such as nnUNetPlans!!!') 78 | args = parser.parse_args() 79 | move_plans_between_datasets(args.s, args.t, args.sp, args.tp) 80 | 81 | 82 | if __name__ == '__main__': 83 | move_plans_between_datasets(2, 4, 'nnUNetPlans', 'nnUNetPlansFrom2') 84 | -------------------------------------------------------------------------------- /longiseg/dataset_conversion/generate_dataset_json.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Union, List 2 | 3 | from batchgenerators.utilities.file_and_folder_operations import save_json, join 4 | 5 | 6 | def generate_dataset_json(output_folder: str, 7 | channel_names: dict, 8 | labels: dict, 9 | num_training_cases: int, 10 | file_ending: str, 11 | citation: Union[List[str], str] = None, 12 | regions_class_order: Tuple[int, ...] = None, 13 | dataset_name: str = None, 14 | reference: str = None, 15 | release: str = None, 16 | description: str = None, 17 | overwrite_image_reader_writer: str = None, 18 | license: str = 'Whoever converted this dataset was lazy and didn\'t look it up!', 19 | converted_by: str = "Please enter your name, especially when sharing datasets with others in a common infrastructure!", 20 | **kwargs): 21 | """ 22 | Generates a dataset.json file in the output folder 23 | 24 | channel_names: 25 | Channel names must map the index to the name of the channel, example: 26 | { 27 | 0: 'T1', 28 | 1: 'CT' 29 | } 30 | Note that the channel names may influence the normalization scheme!! Learn more in the documentation. 31 | 32 | labels: 33 | This will tell nnU-Net what labels to expect. Important: This will also determine whether you use region-based training or not. 
34 | Example regular labels: 35 | { 36 | 'background': 0, 37 | 'left atrium': 1, 38 | 'some other label': 2 39 | } 40 | Example region-based training: 41 | { 42 | 'background': 0, 43 | 'whole tumor': (1, 2, 3), 44 | 'tumor core': (2, 3), 45 | 'enhancing tumor': 3 46 | } 47 | 48 | Remember that nnU-Net expects consecutive values for labels! nnU-Net also expects 0 to be background! 49 | 50 | num_training_cases: is used to double check all cases are there! 51 | 52 | file_ending: needed for finding the files correctly. IMPORTANT! File endings must match between images and 53 | segmentations! 54 | 55 | dataset_name, reference, release, license, description: self-explanatory and not used by nnU-Net. Just for 56 | completeness and as a reminder that these would be great! 57 | 58 | overwrite_image_reader_writer: If you need a special IO class for your dataset you can derive it from 59 | BaseReaderWriter, place it into longiseg.imageio and reference it here by name 60 | 61 | kwargs: whatever you put here will be placed in the dataset.json as well 62 | 63 | """ 64 | has_regions: bool = any([isinstance(i, (tuple, list)) and len(i) > 1 for i in labels.values()]) 65 | if has_regions: 66 | assert regions_class_order is not None, "You have defined regions but regions_class_order is not set. " \ 67 | "You need that." 68 | # channel names need strings as keys 69 | keys = list(channel_names.keys()) 70 | for k in keys: 71 | if not isinstance(k, str): 72 | channel_names[str(k)] = channel_names[k] 73 | del channel_names[k] 74 | 75 | # labels need ints as values 76 | for l in labels.keys(): 77 | value = labels[l] 78 | if isinstance(value, (tuple, list)): 79 | value = tuple([int(i) for i in value]) 80 | labels[l] = value 81 | else: 82 | labels[l] = int(labels[l]) 83 | 84 | dataset_json = { 85 | 'channel_names': channel_names, # previously this was called 'modality'. I didn't like this so this is 86 | # channel_names now. Live with it. 87 | 'labels': labels, 88 | 'numTraining': num_training_cases, 89 | 'file_ending': file_ending, 90 | 'licence': license, 91 | 'converted_by': converted_by 92 | } 93 | 94 | if dataset_name is not None: 95 | dataset_json['name'] = dataset_name 96 | if reference is not None: 97 | dataset_json['reference'] = reference 98 | if release is not None: 99 | dataset_json['release'] = release 100 | if citation is not None: 101 | dataset_json['citation'] = citation 102 | if description is not None: 103 | dataset_json['description'] = description 104 | if overwrite_image_reader_writer is not None: 105 | dataset_json['overwrite_image_reader_writer'] = overwrite_image_reader_writer 106 | if regions_class_order is not None: 107 | dataset_json['regions_class_order'] = regions_class_order 108 | 109 | dataset_json.update(kwargs) 110 | 111 | save_json(dataset_json, join(output_folder, 'dataset.json'), sort_keys=False) 112 | -------------------------------------------------------------------------------- /documentation/installation_instructions.md: -------------------------------------------------------------------------------- 1 | # System requirements 2 | 3 | ## Operating System 4 | nnU-Net has been tested on Linux (Ubuntu 18.04, 20.04, 22.04; CentOS, RHEL), Windows and macOS! It should work out of the box! 5 | 6 | ## Hardware requirements 7 | We support GPU (recommended), CPU and Apple M1/M2 as devices (currently Apple mps does not implement 3D 8 | convolutions, so you might have to use the CPU on those devices). 
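As a point of reference, here is a minimal device-selection sketch for your own scripts (this helper is our illustration, not part of nnU-Net's API), falling back to CPU for workloads that need 3D convolutions:

```python
import torch

def pick_device(needs_3d_conv: bool = True) -> torch.device:
    # prefer CUDA whenever it is available
    if torch.cuda.is_available():
        return torch.device('cuda')
    # mps currently lacks 3D convolutions, so only use it for 2D workloads
    if torch.backends.mps.is_available() and not needs_3d_conv:
        return torch.device('mps')
    return torch.device('cpu')
```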
9 | 10 | ### Hardware requirements for Training 11 | We recommend you use a GPU for training as this will take a really long time on CPU or MPS (Apple M1/M2). 12 | For training, a GPU with at least 10 GB (popular non-datacenter options are the RTX 2080 Ti, RTX 3080/3090 or RTX 4080/4090) is 13 | required. We also recommend a strong CPU to go along with the GPU. 6 cores (12 threads) 14 | are the bare minimum! CPU requirements are mostly related to data augmentation and scale with the number of 15 | input channels and target structures. Plus, the faster the GPU, the better the CPU should be! 16 | 17 | ### Hardware requirements for inference 18 | Again we recommend a GPU to make predictions as this will be substantially faster than the other options. However, 19 | inference times are typically still manageable on CPU and MPS (Apple M1/M2). If using a GPU, it should have at least 20 | 4 GB of available (unused) VRAM. 21 | 22 | ### Example hardware configurations 23 | Example workstation configurations for training: 24 | - CPU: Ryzen 5800X - 5900X or 7900X would be even better! We have not yet tested Intel Alder/Raptor lake but they will likely work as well. 25 | - GPU: RTX 3090 or RTX 4090 26 | - RAM: 64GB 27 | - Storage: SSD (M.2 PCIe Gen 3 or better!) 28 | 29 | Example Server configuration for training: 30 | - CPU: 2x AMD EPYC 7763 for a total of 128C/256T. 16C/GPU are highly recommended for fast GPUs such as the A100! 31 | - GPU: 8xA100 PCIe (price/performance superior to SXM variant + they use less power) 32 | - RAM: 1 TB 33 | - Storage: local SSD storage (PCIe Gen 3 or better) or ultra fast network storage 34 | 35 | (nnU-Net by default uses one GPU per training. The server configuration can run up to 8 model trainings simultaneously) 36 | 37 | ### Setting the correct number of Workers for data augmentation (training only) 38 | Note that you will need to manually set the number of processes nnU-Net uses for data augmentation according to your 39 | CPU/GPU ratio. For the server above (256 threads for 8 GPUs), a good value would be 24-30. You can do this by 40 | setting the `nnUNet_n_proc_DA` environment variable (`export nnUNet_n_proc_DA=XX`). 41 | Recommended values (assuming a recent CPU with good IPC) are 10-12 for RTX 2080 Ti, 12 for a RTX 3090, 16-18 for 42 | RTX 4090, 28-32 for A100. Optimal values may vary depending on the number of input channels/modalities and number of classes. 43 | 44 | # Installation instructions 45 | We strongly recommend that you install LongiSeg in a virtual environment! Pip or anaconda are both fine. If you choose to 46 | compile PyTorch from source (see below), you will need to use conda instead of pip. 47 | 48 | Use a recent version of Python! 3.10 or newer is required (this matches the `requires-python` constraint in the [pyproject.toml](../pyproject.toml))! 49 | 50 | **nnU-Net v2 can coexist with nnU-Net v1! Both can be installed at the same time.** 51 | 52 | 1) Install [PyTorch](https://pytorch.org/get-started/locally/) as described on their website (conda/pip). Please 53 | install the latest version with support for your hardware (cuda, mps, cpu). 54 | **DO NOT JUST INSTALL LongiSeg WITHOUT PROPERLY INSTALLING PYTORCH FIRST**. For maximum speed, consider 55 | [compiling pytorch yourself](https://github.com/pytorch/pytorch#from-source) (experienced users only!). 
56 | 2) Install LongiSeg depending on your use case: 57 | 1) For use as **standardized baseline**, **out-of-the-box segmentation algorithm** or for running 58 | **inference with pretrained models**: 59 | 60 | ```pip install git+https://github.com/MIC-DKFZ/LongiSeg.git``` 61 | 62 | 2) For use as integrative **framework** (this will create a copy of the LongiSeg code on your computer so that you 63 | can modify it as needed): 64 | ```bash 65 | git clone https://github.com/MIC-DKFZ/LongiSeg.git 66 | cd LongiSeg 67 | pip install -e . 68 | ``` 69 | 3) LongiSeg needs to know where you intend to save raw data, preprocessed data and trained models. For this you need to 70 | set a few environment variables. Please follow the instructions [here](setting_up_paths.md). 71 | 72 | Installing LongiSeg will add several new commands to your terminal. These commands are used to run the entire LongiSeg 73 | pipeline. You can execute them from any location on your system. All LongiSeg commands have the prefix `LongiSeg_` for 74 | easy identification. 75 | 76 | Note that these commands simply execute python scripts. If you installed LongiSeg in a virtual environment, this 77 | environment must be activated when executing the commands. You can see what scripts/functions are executed by 78 | checking the project.scripts in the [pyproject.toml](../pyproject.toml) file. 79 | 80 | All LongiSeg commands have a `-h` option which gives information on how to use them. 81 | -------------------------------------------------------------------------------- /longiseg/utilities/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 HIP Applied Computer Vision Lab, Division of Medical Image Computing, German Cancer Research Center 2 | # (DKFZ), Heidelberg, Germany 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
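# Example of the expected naming scheme (hypothetical filenames): for
#   imagesTr/case01_0000.nii.gz, imagesTr/case01_0001.nii.gz, imagesTr/case02_0000.nii.gz
# and file_ending='.nii.gz', get_identifiers_from_splitted_dataset_folder below yields
# ['case01', 'case02'], and create_lists_from_splitted_dataset_folder groups the full paths
# per identifier: [['.../case01_0000.nii.gz', '.../case01_0001.nii.gz'], ['.../case02_0000.nii.gz']]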
15 | from typing import List, Union 16 | 17 | import os.path 18 | 19 | 20 | 21 | from batchgenerators.utilities.file_and_folder_operations import subfiles, join, load_json 22 | import numpy as np 23 | import re 24 | 25 | from longiseg.paths import LongiSeg_raw 26 | from multiprocessing import Pool 27 | 28 | 29 | def get_identifiers_from_splitted_dataset_folder(folder: str, file_ending: str): 30 | files = subfiles(folder, suffix=file_ending, join=False) 31 | # all files have a 4 digit channel index (_XXXX) 32 | crop = len(file_ending) + 5 33 | files = [i[:-crop] for i in files] 34 | # only unique image ids 35 | files = np.unique(files) 36 | return files 37 | 38 | 39 | def create_paths_fn(folder, files, file_ending, f): 40 | p = re.compile(re.escape(f) + r"_\d\d\d\d" + re.escape(file_ending))  # matches all channel files of identifier f 41 | return [join(folder, i) for i in files if p.fullmatch(i)] 42 | 43 | 44 | def create_lists_from_splitted_dataset_folder(folder: str, file_ending: str, identifiers: List[str] = None, num_processes: int = 12) -> List[ 45 | List[str]]: 46 | """ 47 | does not rely on dataset.json 48 | """ 49 | if identifiers is None: 50 | identifiers = get_identifiers_from_splitted_dataset_folder(folder, file_ending) 51 | files = subfiles(folder, suffix=file_ending, join=False, sort=True) 52 | 53 | 54 | params_list = [(folder, files, file_ending, f) for f in identifiers] 55 | with Pool(processes=num_processes) as pool: 56 | list_of_lists = pool.starmap(create_paths_fn, params_list) 57 | 58 | return list_of_lists 59 | 60 | 61 | def create_lists_from_splitted_dataset_folder_and_patients_json(folder: str, file_ending: str, patients_json: dict, identifiers: List[str] = None, 62 | num_processes: int = 12) -> List[List[str]]: 63 | """ 64 | uses patients_json to filter for scans with patient information 65 | """ 66 | if identifiers is None: 67 | identifiers = get_identifiers_from_splitted_dataset_folder(folder, file_ending) 68 | files = subfiles(folder, suffix=file_ending, join=False, sort=True) 69 | 70 | # filter patient_json with identifiers 71 | filtered_patients_json = {k: [i for i in v if i in identifiers] for k, v in patients_json.items()} 72 | 73 | # this runs without multiprocessing for now; each scan is paired with the preceding scan of the same patient (the first scan is paired with itself) 74 | list_of_lists = [] 75 | for k, v in filtered_patients_json.items(): 76 | if len(v) == 0: continue 77 | patient_files = [create_paths_fn(folder, files, file_ending, f) for f in v] 78 | list_of_lists.extend([[c, p] for c, p in zip(patient_files, patient_files[:1] + patient_files[:-1])]) 79 | 80 | return list_of_lists 81 | 82 | 83 | def get_filenames_of_train_images_and_targets(raw_dataset_folder: str, dataset_json: dict = None): 84 | if dataset_json is None: 85 | dataset_json = load_json(join(raw_dataset_folder, 'dataset.json')) 86 | 87 | if 'dataset' in dataset_json.keys(): 88 | dataset = dataset_json['dataset'] 89 | for k in dataset.keys(): 90 | expanded_label_file = os.path.expandvars(dataset[k]['label']) 91 | dataset[k]['label'] = os.path.abspath(join(raw_dataset_folder, expanded_label_file)) if not os.path.isabs(expanded_label_file) else expanded_label_file 92 | dataset[k]['images'] = [os.path.abspath(join(raw_dataset_folder, os.path.expandvars(i))) if not os.path.isabs(os.path.expandvars(i)) else os.path.expandvars(i) for i in dataset[k]['images']] 93 | else: 94 | identifiers = get_identifiers_from_splitted_dataset_folder(join(raw_dataset_folder, 'imagesTr'), 
dataset_json['file_ending']) 95 | images = create_lists_from_splitted_dataset_folder(join(raw_dataset_folder, 'imagesTr'), dataset_json['file_ending'], identifiers) 96 | segs = [join(raw_dataset_folder, 'labelsTr', i + dataset_json['file_ending']) for i in identifiers] 97 | dataset = {i: {'images': im, 'label': se} for i, im, se in zip(identifiers, images, segs)} 98 | return dataset 99 | 100 | 101 | if __name__ == '__main__': 102 | print(get_filenames_of_train_images_and_targets(join(LongiSeg_raw, 'Dataset002_Heart'))) 103 | -------------------------------------------------------------------------------- /longiseg/training/logging/nnunet_logger.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | from batchgenerators.utilities.file_and_folder_operations import join 3 | 4 | matplotlib.use('agg') 5 | import seaborn as sns 6 | import matplotlib.pyplot as plt 7 | 8 | 9 | class nnUNetLogger(object): 10 | """ 11 | This class is really trivial. Don't expect cool functionality here. This is my makeshift solution to problems 12 | arising from out-of-sync epoch numbers and numbers of logged loss values. It also simplifies the trainer class a 13 | little 14 | 15 | YOU MUST LOG EXACTLY ONE VALUE PER EPOCH FOR EACH OF THE LOGGING ITEMS! DONT FUCK IT UP 16 | """ 17 | def __init__(self, verbose: bool = False): 18 | self.my_fantastic_logging = { 19 | 'mean_fg_dice': list(), 20 | 'ema_fg_dice': list(), 21 | 'dice_per_class_or_region': list(), 22 | 'train_losses': list(), 23 | 'val_losses': list(), 24 | 'lrs': list(), 25 | 'epoch_start_timestamps': list(), 26 | 'epoch_end_timestamps': list() 27 | } 28 | self.verbose = verbose 29 | # shut up, this logging is great 30 | 31 | def log(self, key, value, epoch: int): 32 | """ 33 | sometimes shit gets messed up. We try to catch that here 34 | """ 35 | assert key in self.my_fantastic_logging.keys() and isinstance(self.my_fantastic_logging[key], list), \ 36 | 'This function is only intended to log stuff to lists and to have one entry per epoch' 37 | 38 | if self.verbose: print(f'logging {key}: {value} for epoch {epoch}') 39 | 40 | if len(self.my_fantastic_logging[key]) < (epoch + 1): 41 | self.my_fantastic_logging[key].append(value) 42 | else: 43 | assert len(self.my_fantastic_logging[key]) == (epoch + 1), 'something went horribly wrong. My logging ' \ 44 | 'lists length is off by more than 1' 45 | print(f'maybe some logging issue!? logging {key} and {value}') 46 | self.my_fantastic_logging[key][epoch] = value 47 | 48 | # handle the ema_fg_dice special case! 
It is automatically logged when we add a new mean_fg_dice 49 | if key == 'mean_fg_dice': 50 | new_ema_pseudo_dice = self.my_fantastic_logging['ema_fg_dice'][epoch - 1] * 0.9 + 0.1 * value \ 51 | if len(self.my_fantastic_logging['ema_fg_dice']) > 0 else value 52 | self.log('ema_fg_dice', new_ema_pseudo_dice, epoch) 53 | 54 | def plot_progress_png(self, output_folder): 55 | # we infer the epoch from our internal logging 56 | epoch = min([len(i) for i in self.my_fantastic_logging.values()]) - 1 # lists of epoch 0 have len 1 57 | sns.set(font_scale=2.5) 58 | fig, ax_all = plt.subplots(3, 1, figsize=(30, 54)) 59 | # regular progress.png as we are used to from previous nnU-Net versions 60 | ax = ax_all[0] 61 | ax2 = ax.twinx() 62 | x_values = list(range(epoch + 1)) 63 | ax.plot(x_values, self.my_fantastic_logging['train_losses'][:epoch + 1], color='b', ls='-', label="loss_tr", linewidth=4) 64 | ax.plot(x_values, self.my_fantastic_logging['val_losses'][:epoch + 1], color='r', ls='-', label="loss_val", linewidth=4) 65 | ax2.plot(x_values, self.my_fantastic_logging['mean_fg_dice'][:epoch + 1], color='g', ls='dotted', label="pseudo dice", 66 | linewidth=3) 67 | ax2.plot(x_values, self.my_fantastic_logging['ema_fg_dice'][:epoch + 1], color='g', ls='-', label="pseudo dice (mov. avg.)", 68 | linewidth=4) 69 | ax.set_xlabel("epoch") 70 | ax.set_ylabel("loss") 71 | ax2.set_ylabel("pseudo dice") 72 | ax.legend(loc=(0, 1)) 73 | ax2.legend(loc=(0.2, 1)) 74 | 75 | # epoch times to see whether the training speed is consistent (inconsistent means there are other jobs 76 | # clogging up the system) 77 | ax = ax_all[1] 78 | ax.plot(x_values, [i - j for i, j in zip(self.my_fantastic_logging['epoch_end_timestamps'][:epoch + 1], 79 | self.my_fantastic_logging['epoch_start_timestamps'][:epoch + 1])], color='b', 80 | ls='-', label="epoch duration", linewidth=4) 81 | ylim = [0] + [ax.get_ylim()[1]] 82 | ax.set(ylim=ylim) 83 | ax.set_xlabel("epoch") 84 | ax.set_ylabel("time [s]") 85 | ax.legend(loc=(0, 1)) 86 | 87 | # learning rate 88 | ax = ax_all[2] 89 | ax.plot(x_values, self.my_fantastic_logging['lrs'][:epoch + 1], color='b', ls='-', label="learning rate", linewidth=4) 90 | ax.set_xlabel("epoch") 91 | ax.set_ylabel("learning rate") 92 | ax.legend(loc=(0, 1)) 93 | 94 | plt.tight_layout() 95 | 96 | fig.savefig(join(output_folder, "progress.png")) 97 | plt.close() 98 | 99 | def get_checkpoint(self): 100 | return self.my_fantastic_logging 101 | 102 | def load_checkpoint(self, checkpoint: dict): 103 | self.my_fantastic_logging = checkpoint 104 | -------------------------------------------------------------------------------- /longiseg/imageio/tif_reader_writer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 HIP Applied Computer Vision Lab, Division of Medical Image Computing, German Cancer Research Center 2 | # (DKFZ), Heidelberg, Germany 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | import os.path 16 | from typing import Tuple, Union, List 17 | import numpy as np 18 | from longiseg.imageio.base_reader_writer import BaseReaderWriter 19 | import tifffile 20 | from batchgenerators.utilities.file_and_folder_operations import isfile, load_json, save_json, join 21 | 22 | 23 | class Tiff3DIO(BaseReaderWriter): 24 | """ 25 | reads and writes 3D tif(f) images. Uses tifffile package. Ignores metadata (for now)! 26 | 27 | If you have 2D tiffs, use NaturalImage2DIO 28 | 29 | Supports the use of auxiliary files for spacing information. If used, the auxiliary files are expected to end 30 | with .json and omit the channel identifier. So, for example, the auxiliary file corresponding to the image 31 | image1_0000.tif is expected to be image1.json! 32 | """ 33 | supported_file_endings = [ 34 | '.tif', 35 | '.tiff', 36 | ] 37 | 38 | def read_images(self, image_fnames: Union[List[str], Tuple[str, ...]]) -> Tuple[np.ndarray, dict]: 39 | # figure out file ending used here 40 | ending = '.' + image_fnames[0].split('.')[-1] 41 | assert ending.lower() in self.supported_file_endings, f'Ending {ending} not supported by {self.__class__.__name__}' 42 | ending_length = len(ending) 43 | truncate_length = ending_length + 5 # 5 comes from len(_0000) 44 | 45 | images = [] 46 | for f in image_fnames: 47 | image = tifffile.imread(f) 48 | if image.ndim != 3: 49 | raise RuntimeError(f"Only 3D images are supported! File: {f}") 50 | images.append(image[None]) 51 | 52 | # see if aux file can be found 53 | expected_aux_file = image_fnames[0][:-truncate_length] + '.json' 54 | if isfile(expected_aux_file): 55 | spacing = load_json(expected_aux_file)['spacing'] 56 | assert len(spacing) == 3, f'spacing must have 3 entries, one for each dimension of the image. File: {expected_aux_file}' 57 | else: 58 | print(f'WARNING no spacing file found for images {image_fnames}\nAssuming spacing (1, 1, 1).') 59 | spacing = (1, 1, 1) 60 | 61 | if not self._check_all_same([i.shape for i in images]): 62 | print('ERROR! Not all input images have the same shape!') 63 | print('Shapes:') 64 | print([i.shape for i in images]) 65 | print('Image files:') 66 | print(image_fnames) 67 | raise RuntimeError() 68 | 69 | return np.vstack(images, dtype=np.float32, casting='unsafe'), {'spacing': spacing} 70 | 71 | def write_seg(self, seg: np.ndarray, output_fname: str, properties: dict) -> None: 72 | # tif files have no standard way of storing spacing/resolution metadata, so we write it to a sidecar json instead 73 | tifffile.imwrite(output_fname, data=seg.astype(np.uint8 if np.max(seg) < 255 else np.uint16, copy=False), compression='zlib') 74 | file = os.path.basename(output_fname) 75 | out_dir = os.path.dirname(output_fname) 76 | ending = file.split('.')[-1] 77 | save_json({'spacing': properties['spacing']}, join(out_dir, file[:-(len(ending) + 1)] + '.json')) 78 | 79 | def read_seg(self, seg_fname: str) -> Tuple[np.ndarray, dict]: 80 | # figure out file ending used here 81 | ending = '.' + seg_fname.split('.')[-1] 82 | assert ending.lower() in self.supported_file_endings, f'Ending {ending} not supported by {self.__class__.__name__}' 83 | ending_length = len(ending) 84 | 85 | seg = tifffile.imread(seg_fname) 86 | if seg.ndim != 3: 87 | raise RuntimeError(f"Only 3D images are supported! 
File: {seg_fname}") 88 | seg = seg[None] 89 | 90 | # see if aux file can be found 91 | expected_aux_file = seg_fname[:-ending_length] + '.json' 92 | if isfile(expected_aux_file): 93 | spacing = load_json(expected_aux_file)['spacing'] 94 | assert len(spacing) == 3, f'spacing must have 3 entries, one for each dimension of the image. File: {expected_aux_file}' 95 | assert all([i > 0 for i in spacing]), f"Spacing must be > 0, spacing: {spacing}" 96 | else: 97 | print(f'WARNING no spacing file found for segmentation {seg_fname}\nAssuming spacing (1, 1, 1).') 98 | spacing = (1, 1, 1) 99 | 100 | return seg.astype(np.float32, copy=False), {'spacing': spacing} 101 | -------------------------------------------------------------------------------- /longiseg/utilities/file_path_utilities.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from multiprocessing.pool import Pool 4 | from typing import Union, Tuple 5 | import numpy as np 6 | import os 7 | from batchgenerators.utilities.file_and_folder_operations import split_path, join 8 | 9 | from longiseg.configuration import default_num_processes 10 | from longiseg.paths import LongiSeg_results 11 | from longiseg.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name 12 | 13 | 14 | def convert_trainer_plans_config_to_identifier(trainer_name, plans_identifier, configuration): 15 | return f'{trainer_name}__{plans_identifier}__{configuration}' 16 | 17 | 18 | def convert_identifier_to_trainer_plans_config(identifier: str): 19 | return os.path.basename(identifier).split('__') 20 | 21 | 22 | def get_output_folder(dataset_name_or_id: Union[str, int], trainer_name: str = 'nnUNetTrainer', 23 | plans_identifier: str = 'nnUNetPlans', configuration: str = '3d_fullres', 24 | fold: Union[str, int] = None) -> str: 25 | tmp = join(LongiSeg_results, maybe_convert_to_dataset_name(dataset_name_or_id), 26 | convert_trainer_plans_config_to_identifier(trainer_name, plans_identifier, configuration)) 27 | if fold is not None: 28 | tmp = join(tmp, f'fold_{fold}') 29 | return tmp 30 | 31 | 32 | def parse_dataset_trainer_plans_configuration_from_path(path: str): 33 | folders = split_path(path) 34 | # this here can be a little tricky because we are making assumptions. Let's hope this never fails lol 35 | 36 | # safer to make this depend on two conditions, the fold_x and the DatasetXXX 37 | # first let's see if some fold_X is present 38 | fold_x_present = [i.startswith('fold_') for i in folders] 39 | if any(fold_x_present): 40 | idx = fold_x_present.index(True) 41 | # OK now two entries before that there should be DatasetXXX 42 | assert len(folders[:idx]) >= 2, 'Bad path, cannot extract what I need. Your path needs to be at least ' \ 43 | 'DatasetXXX/MODULE__PLANS__CONFIGURATION for this to work' 44 | if folders[idx - 2].startswith('Dataset'): 45 | split = folders[idx - 1].split('__') 46 | assert len(split) == 3, 'Bad path, cannot extract what I need. Your path needs to be at least ' \ 47 | 'DatasetXXX/MODULE__PLANS__CONFIGURATION for this to work' 48 | return folders[idx - 2], *split 49 | else: 50 | # we can only check for dataset followed by a string that is separable into three strings by splitting with '__' 51 | # look for DatasetXXX 52 | dataset_folder = [i.startswith('Dataset') for i in folders] 53 | if any(dataset_folder): 54 | idx = dataset_folder.index(True) 55 | assert len(folders) >= (idx + 1), 'Bad path, cannot extract what I need. 
Your path needs to be at least ' \ 56 | 'DatasetXXX/MODULE__PLANS__CONFIGURATION for this to work' 57 | split = folders[idx + 1].split('__') 58 | assert len(split) == 3, 'Bad path, cannot extract what I need. Your path needs to be at least ' \ 59 | 'DatasetXXX/MODULE__PLANS__CONFIGURATION for this to work' 60 | return folders[idx], *split 61 | 62 | 63 | def get_ensemble_name(model1_folder, model2_folder, folds: Tuple[int, ...]): 64 | identifier = 'ensemble___' + os.path.basename(model1_folder) + '___' + \ 65 | os.path.basename(model2_folder) + '___' + folds_tuple_to_string(folds) 66 | return identifier 67 | 68 | 69 | def get_ensemble_name_from_d_tr_c(dataset, tr1, p1, c1, tr2, p2, c2, folds: Tuple[int, ...]): 70 | model1_folder = get_output_folder(dataset, tr1, p1, c1) 71 | model2_folder = get_output_folder(dataset, tr2, p2, c2) 72 | 73 | return get_ensemble_name(model1_folder, model2_folder, folds) 74 | 75 | 76 | def convert_ensemble_folder_to_model_identifiers_and_folds(ensemble_folder: str): 77 | prefix, *models, folds = os.path.basename(ensemble_folder).split('___') 78 | return models, folds 79 | 80 | 81 | def folds_tuple_to_string(folds: Union[List[int], Tuple[int, ...]]): 82 | s = str(folds[0]) 83 | for f in folds[1:]: 84 | s += f"_{f}" 85 | return s 86 | 87 | 88 | def folds_string_to_tuple(folds_string: str): 89 | folds = folds_string.split('_') 90 | res = [] 91 | for f in folds: 92 | try: 93 | res.append(int(f)) 94 | except ValueError: 95 | res.append(f) 96 | return res 97 | 98 | 99 | def check_workers_alive_and_busy(export_pool: Pool, worker_list: List, results_list: List, allowed_num_queued: int = 0): 100 | """ 101 | returns True if the number of results that are not ready is greater than the number of available workers + 102 | allowed_num_queued 103 | """ 104 | alive = [i.is_alive() for i in worker_list] 105 | if not all(alive): 106 | raise RuntimeError('Some background workers are no longer alive') 107 | 108 | not_ready = [not i.ready() for i in results_list] 109 | if sum(not_ready) >= (len(export_pool._pool) + allowed_num_queued): 110 | return True 111 | return False -------------------------------------------------------------------------------- /documentation/ignore_label.md: -------------------------------------------------------------------------------- 1 | # Ignore Label 2 | 3 | The _ignore label_ can be used to mark regions that should be ignored by nnU-Net. This can be used to 4 | learn from images where only sparse annotations are available, for example in the form of scribbles or a limited 5 | amount of annotated slices. Internally, this is accomplished by using partial losses, i.e. losses that are only 6 | computed on annotated pixels while ignoring the rest. Take a look at our 7 | [`DC_and_BCE_loss`](../longiseg/training/loss/compound_losses.py) to see how this is done. 8 | During inference (validation and prediction), nnU-Net will always predict dense segmentations. Metric computation in 9 | validation is of course only done on annotated pixels. 10 | 11 | Sparse annotations can be used either to train a model that is then applied to new, unseen images, or to 12 | autocomplete the provided training cases given the sparse labels. 
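To make the partial-loss idea concrete, here is a minimal sketch (a simplified stand-in for illustration, not the actual `DC_and_BCE_loss` implementation) of a cross-entropy loss that is only computed on annotated pixels, assuming the ignore label is the highest label value:

```python
import torch
import torch.nn.functional as F

def partial_ce_loss(logits: torch.Tensor, target: torch.Tensor, ignore_label: int) -> torch.Tensor:
    # pixels carrying the ignore label contribute nothing to the loss or its gradient
    return F.cross_entropy(logits, target, ignore_index=ignore_label)
```

nnU-Net's compound losses apply the same masking idea to the Dice term as well.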
13 | 14 | (See our [paper](https://arxiv.org/abs/2403.12834) for more information) 15 | 16 | Typical use-cases for the ignore label are: 17 | - Save annotation time through sparse annotation schemes 18 | - Annotation of all or a subset of slices with scribbles (Scribble Supervision) 19 | - Dense annotation of a subset of slices 20 | - Dense annotation of chosen patches/cubes within an image 21 | - Coarsely masking out faulty segmentations in the reference segmentations 22 | - Masking areas for other reasons 23 | 24 | If you are using nnU-Net's ignore label, please cite the following paper in addition to the original nnU-Net paper: 25 | 26 | ``` 27 | Gotkowski, K., Lüth, C., Jäger, P. F., Ziegler, S., Krämer, L., Denner, S., Xiao, S., Disch, N., Maier-Hein, K. H., & Isensee, F. 28 | (2024). Embarrassingly Simple Scribble Supervision for 3D Medical Segmentation. ArXiv. /abs/2403.12834 29 | ``` 30 | 31 | ## Use cases 32 | 33 | ### Scribble Supervision 34 | 35 | Scribbles are free-form drawings to coarsely annotate an image. As we have demonstrated in our recent [paper](https://arxiv.org/abs/2403.12834), nnU-Net's partial loss implementation enables state-of-the-art learning from partially annotated data and even surpasses many purpose-built methods for learning from scribbles. As a starting point, for each image slice and each class (including background), an interior and a border scribble should be generated: 36 | 37 | - Interior Scribble: A scribble placed randomly within the class interior of a class instance 38 | - Border Scribble: A scribble roughly delineating a small part of the class border of a class instance 39 | 40 | An example of such scribble annotations is depicted in Figure 1 and an animation in Animation 1. 41 | Depending on the availability of data and their variability it is also possible to annotate only a subset of selected slices. 42 | 43 |
44 | *Figure 1: Examples of segmentation types with (A) depicting a dense segmentation and (B) a scribble segmentation.* 45 | 46 | *Animation 1: Depiction of a dense segmentation and a scribble annotation. Background scribbles have been excluded for better visualization.* 47 |
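If you want to experiment with scribble supervision without manual scribbles, they can be crudely simulated from a dense reference segmentation. The following is a simplified sketch under our own assumptions (an inner contour serves as the 'interior scribble' and a fraction of the boundary as the 'border scribble'); it is not the generation code used in the paper:

```python
import numpy as np
from scipy.ndimage import binary_erosion

def simulate_scribbles_2d(dense_seg: np.ndarray, ignore_label: int) -> np.ndarray:
    # everything starts out as 'ignore'; only the simulated scribbles keep their class
    scribble = np.full_like(dense_seg, ignore_label)
    for cls in np.unique(dense_seg):
        m = dense_seg == cls
        eroded = binary_erosion(m, iterations=3)
        # interior scribble: a thin ring well inside the class instance
        scribble[eroded & ~binary_erosion(eroded)] = cls
        # border scribble: keep roughly a quarter of the class boundary pixels
        rr, cc = np.where(m & ~binary_erosion(m))
        keep = max(1, len(rr) // 4)
        scribble[rr[:keep], cc[:keep]] = cls
    return scribble
```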
54 | 55 | ### Dense annotation of a subset of slices 56 | 57 | Another form of sparse annotation is the dense annotation of a subset of slices. These slices should be selected by the user, either randomly, based on the visual class variation between slices, or in an active learning setting. An example with only 10% of slices annotated is depicted in Figure 2; a minimal sketch of how such a label can be constructed is shown below the figure. 58 | 59 |
60 | *Figure 2: Examples of a dense annotation of a subset of slices. The ignored areas are shown in red.* 61 |

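A minimal sketch of how such a sparse training label could be constructed from a dense segmentation (our assumptions: axis 0 is the slice axis and the ignore label is the highest label value, as required in the next section):

```python
import numpy as np

def keep_every_kth_slice(dense_seg: np.ndarray, ignore_label: int, k: int = 10) -> np.ndarray:
    # start from an all-ignore segmentation and copy over every k-th annotated slice
    sparse_seg = np.full_like(dense_seg, ignore_label)
    sparse_seg[::k] = dense_seg[::k]
    return sparse_seg
```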
65 | 66 | 67 | ## Usage within nnU-Net 68 | 69 | Usage of the ignore label in nnU-Net is straightforward and only requires the definition of an _ignore_ label in the _dataset.json_. 70 | This ignore label MUST be the highest integer label value in the segmentation. For example, given a background class and two foreground classes, the ignore label must have the integer value 3. The ignore label must be named _ignore_ in the _dataset.json_. Given the BraTS dataset as an example, the labels dict of the _dataset.json_ must look like this: 71 | 72 | ```python 73 | ... 74 | "labels": { 75 | "background": 0, 76 | "edema": 1, 77 | "non_enhancing_and_necrosis": 2, 78 | "enhancing_tumor": 3, 79 | "ignore": 4 80 | }, 81 | ... 82 | ``` 83 | 84 | Of course, the ignore label is compatible with [region-based training](region_based_training.md): 85 | 86 | ```python 87 | ... 88 | "labels": { 89 | "background": 0, 90 | "whole_tumor": (1, 2, 3), 91 | "tumor_core": (2, 3), 92 | "enhancing_tumor": 3, # or (3, ) 93 | "ignore": 4 94 | }, 95 | "regions_class_order": (1, 2, 3), # don't declare ignore label here! It is not predicted 96 | ... 97 | ``` 98 | 99 | Then use the dataset as you would any other. 100 | 101 | Remember that nnU-Net runs a cross-validation. Thus, it will also evaluate on your partially annotated data. This 102 | will of course work! If you wish to compare different sparse annotation strategies (through simulations for example), 103 | we recommend evaluating on densely annotated images by running inference and then using 104 | `LongiSeg_evaluate_folder`. --------------------------------------------------------------------------------