├── longiseg
│   ├── __init__.py
│   ├── imageio
│   │   ├── __init__.py
│   │   ├── readme.md
│   │   ├── natural_image_reader_writer.py
│   │   ├── reader_writer_registry.py
│   │   └── tif_reader_writer.py
│   ├── run
│   │   ├── __init__.py
│   │   └── load_pretrained_weights.py
│   ├── ensembling
│   │   └── __init__.py
│   ├── evaluation
│   │   ├── __init__.py
│   │   ├── metrics
│   │   │   ├── __init__.py
│   │   │   ├── volumetric_metrics.py
│   │   │   ├── distance_metrics.py
│   │   │   └── detection_metrics.py
│   │   └── accumulate_cv_results.py
│   ├── inference
│   │   ├── __init__.py
│   │   ├── sliding_window_prediction.py
│   │   └── export_utils.py
│   ├── training
│   │   ├── __init__.py
│   │   ├── logging
│   │   │   ├── __init__.py
│   │   │   └── nnunet_logger.py
│   │   ├── loss
│   │   │   ├── __init__.py
│   │   │   ├── robust_ce_loss.py
│   │   │   └── deep_supervision.py
│   │   ├── dataloading
│   │   │   ├── __init__.py
│   │   │   └── utils.py
│   │   ├── lr_scheduler
│   │   │   ├── __init__.py
│   │   │   └── polylr.py
│   │   ├── LongiSegTrainer
│   │   │   ├── __init__.py
│   │   │   └── variants
│   │   │       ├── __init__.py
│   │   │       ├── loss
│   │   │       │   ├── __init__.py
│   │   │       │   ├── nnUNetTrainerCELoss.py
│   │   │       │   ├── nnUNetTrainerDiceLoss.py
│   │   │       │   └── nnUNetTrainerTopkLoss.py
│   │   │       ├── sampling
│   │   │       │   ├── __init__.py
│   │   │       │   └── nnUNetTrainer_probabilisticOversampling.py
│   │   │       ├── benchmarking
│   │   │       │   ├── __init__.py
│   │   │       │   ├── nnUNetTrainerBenchmark_5epochs_noDataLoading.py
│   │   │       │   └── nnUNetTrainerBenchmark_5epochs.py
│   │   │       ├── competitions
│   │   │       │   ├── __init__.py
│   │   │       │   └── aortaseg24.py
│   │   │       ├── longitudinal
│   │   │       │   ├── __init__.py
│   │   │       │   └── LongiSegTrainerDiffWeighting.py
│   │   │       ├── lr_schedule
│   │   │       │   ├── __init__.py
│   │   │       │   └── nnUNetTrainerCosAnneal.py
│   │   │       ├── optimizer
│   │   │       │   ├── __init__.py
│   │   │       │   ├── nnUNetTrainerAdam.py
│   │   │       │   └── nnUNetTrainerAdan.py
│   │   │       ├── data_augmentation
│   │   │       │   ├── __init__.py
│   │   │       │   ├── nnUNetTrainer_noDummy2DDA.py
│   │   │       │   └── nnUNetTrainerNoDA.py
│   │   │       ├── training_length
│   │   │       │   ├── __init__.py
│   │   │       │   ├── nnUNetTrainer_Xepochs_NoMirroring.py
│   │   │       │   └── nnUNetTrainer_Xepochs.py
│   │   │       └── network_architecture
│   │   │           ├── __init__.py
│   │   │           ├── nnUNetTrainerNoDeepSupervision.py
│   │   │           └── nnUNetTrainerBN.py
│   │   └── data_augmentation
│   │       ├── __init__.py
│   │       ├── custom_transforms
│   │       │   ├── __init__.py
│   │       │   ├── masking.py
│   │       │   ├── region_based_training.py
│   │       │   ├── transforms_for_dummy_2d.py
│   │       │   ├── deep_supervision_donwsampling.py
│   │       │   └── longi_transforms.py
│   │       └── compute_initial_patch_size.py
│   ├── utilities
│   │   ├── __init__.py
│   │   ├── label_handling
│   │   │   └── __init__.py
│   │   ├── plans_handling
│   │   │   └── __init__.py
│   │   ├── network_initialization.py
│   │   ├── helpers.py
│   │   ├── find_class_by_name.py
│   │   ├── collate_outputs.py
│   │   ├── crossval_split.py
│   │   ├── ddp_allgather.py
│   │   ├── default_n_proc_DA.py
│   │   ├── json_export.py
│   │   ├── get_network_from_plans.py
│   │   ├── dataset_name_id_conversion.py
│   │   ├── utils.py
│   │   └── file_path_utilities.py
│   ├── model_sharing
│   │   ├── __init__.py
│   │   ├── model_import.py
│   │   ├── model_download.py
│   │   └── entry_points.py
│   ├── postprocessing
│   │   └── __init__.py
│   ├── preprocessing
│   │   ├── __init__.py
│   │   ├── cropping
│   │   │   ├── __init__.py
│   │   │   └── cropping.py
│   │   ├── normalization
│   │   │   ├── __init__.py
│   │   │   ├── readme.md
│   │   │   ├── map_channel_name_to_normalization.py
│   │   │   └── default_normalization_schemes.py
│   │   ├── preprocessors
│   │   │   └── __init__.py
│   │   └── resampling
│   │       ├── __init__.py
│   │       └── utils.py
│   ├── dataset_conversion
│   │   ├── __init__.py
│   │   └── generate_dataset_json.py
│   ├── experiment_planning
│   │   ├── __init__.py
│   │   ├── dataset_fingerprint
│   │   │   └── __init__.py
│   │   ├── experiment_planners
│   │   │   ├── __init__.py
│   │   │   ├── resampling
│   │   │   │   └── __init__.py
│   │   │   ├── residual_unets
│   │   │   │   └── __init__.py
│   │   │   └── network_topology.py
│   │   └── plans_for_pretraining
│   │       ├── __init__.py
│   │       └── move_plans_between_datasets.py
│   ├── configuration.py
│   └── paths.py
├── documentation
│   ├── __init__.py
│   ├── assets
│   │   ├── HI_Logo.png
│   │   ├── LongiSeg.jpg
│   │   ├── dkfz_logo.png
│   │   ├── time_series.jpg
│   │   ├── nnU-Net_overview.png
│   │   ├── regions_vs_labels.png
│   │   ├── scribble_example.png
│   │   ├── amos2022_sparseseg10.png
│   │   ├── HIDSS4Health_Logo_RGB.png
│   │   ├── sparse_annotation_amos.png
│   │   └── amos2022_sparseseg10_2d.png
│   ├── run_inference_with_pretrained_models.md
│   ├── setting_up_paths.md
│   ├── dataset_format_inference.md
│   ├── manual_data_splits.md
│   ├── explanation_normalization.md
│   ├── set_environment_variables.md
│   ├── extending_nnunet.md
│   ├── pretraining_and_finetuning.md
│   ├── how_to_use_longiseg.md
│   ├── region_based_training.md
│   ├── installation_instructions.md
│   └── ignore_label.md
├── setup.py
├── .github
│   └── workflows
│       └── codespell.yml
├── .gitignore
└── pyproject.toml
/longiseg/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/documentation/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/longiseg/imageio/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/longiseg/run/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/longiseg/ensembling/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/longiseg/evaluation/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/longiseg/inference/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/longiseg/training/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/longiseg/utilities/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/longiseg/model_sharing/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/longiseg/postprocessing/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/longiseg/preprocessing/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/longiseg/training/logging/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/longiseg/training/loss/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/longiseg/dataset_conversion/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/longiseg/evaluation/metrics/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/longiseg/experiment_planning/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/longiseg/preprocessing/cropping/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/longiseg/training/dataloading/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/longiseg/training/lr_scheduler/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/longiseg/preprocessing/normalization/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/longiseg/preprocessing/preprocessors/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/longiseg/preprocessing/resampling/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/longiseg/training/LongiSegTrainer/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/longiseg/training/data_augmentation/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/longiseg/utilities/label_handling/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/longiseg/utilities/plans_handling/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/longiseg/training/LongiSegTrainer/variants/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/longiseg/experiment_planning/dataset_fingerprint/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/longiseg/experiment_planning/experiment_planners/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/longiseg/experiment_planning/plans_for_pretraining/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/longiseg/training/LongiSegTrainer/variants/loss/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/longiseg/training/LongiSegTrainer/variants/sampling/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/longiseg/training/LongiSegTrainer/variants/benchmarking/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/longiseg/training/LongiSegTrainer/variants/competitions/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/longiseg/training/LongiSegTrainer/variants/longitudinal/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/longiseg/training/LongiSegTrainer/variants/lr_schedule/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/longiseg/training/LongiSegTrainer/variants/optimizer/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/longiseg/training/data_augmentation/custom_transforms/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/longiseg/experiment_planning/experiment_planners/resampling/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/longiseg/training/LongiSegTrainer/variants/data_augmentation/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/longiseg/training/LongiSegTrainer/variants/training_length/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/longiseg/experiment_planning/experiment_planners/residual_unets/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/longiseg/training/LongiSegTrainer/variants/network_architecture/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import setuptools
2 |
3 | if __name__ == "__main__":
4 | setuptools.setup()
5 |
--------------------------------------------------------------------------------
/documentation/assets/HI_Logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MIC-DKFZ/LongiSeg/HEAD/documentation/assets/HI_Logo.png
--------------------------------------------------------------------------------
/documentation/assets/LongiSeg.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MIC-DKFZ/LongiSeg/HEAD/documentation/assets/LongiSeg.jpg
--------------------------------------------------------------------------------
/documentation/assets/dkfz_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MIC-DKFZ/LongiSeg/HEAD/documentation/assets/dkfz_logo.png
--------------------------------------------------------------------------------
/documentation/assets/time_series.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MIC-DKFZ/LongiSeg/HEAD/documentation/assets/time_series.jpg
--------------------------------------------------------------------------------
/documentation/assets/nnU-Net_overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MIC-DKFZ/LongiSeg/HEAD/documentation/assets/nnU-Net_overview.png
--------------------------------------------------------------------------------
/documentation/assets/regions_vs_labels.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MIC-DKFZ/LongiSeg/HEAD/documentation/assets/regions_vs_labels.png
--------------------------------------------------------------------------------
/documentation/assets/scribble_example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MIC-DKFZ/LongiSeg/HEAD/documentation/assets/scribble_example.png
--------------------------------------------------------------------------------
/documentation/assets/amos2022_sparseseg10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MIC-DKFZ/LongiSeg/HEAD/documentation/assets/amos2022_sparseseg10.png
--------------------------------------------------------------------------------
/documentation/assets/HIDSS4Health_Logo_RGB.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MIC-DKFZ/LongiSeg/HEAD/documentation/assets/HIDSS4Health_Logo_RGB.png
--------------------------------------------------------------------------------
/documentation/assets/sparse_annotation_amos.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MIC-DKFZ/LongiSeg/HEAD/documentation/assets/sparse_annotation_amos.png
--------------------------------------------------------------------------------
/documentation/assets/amos2022_sparseseg10_2d.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MIC-DKFZ/LongiSeg/HEAD/documentation/assets/amos2022_sparseseg10_2d.png
--------------------------------------------------------------------------------
/longiseg/model_sharing/model_import.py:
--------------------------------------------------------------------------------
1 | import zipfile
2 |
3 | from longiseg.paths import LongiSeg_results
4 |
5 |
6 | def install_model_from_zip_file(zip_file: str):
7 | with zipfile.ZipFile(zip_file, 'r') as zip_ref:
8 | zip_ref.extractall(LongiSeg_results)
--------------------------------------------------------------------------------
/longiseg/imageio/readme.md:
--------------------------------------------------------------------------------
1 | - Derive your adapter from `BaseReaderWriter`.
2 | - Implement all abstract methods.
3 | - Make sure to support 2D and 3D input images (or raise an informative error).
4 | - Place it in this folder, or nnU-Net won't find it!
5 | - Add it to LIST_OF_IO_CLASSES in `reader_writer_registry.py`.
6 |
7 | Bam, you're done!
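8 |
9 | For orientation, here is a minimal sketch of such an adapter. It assumes the nnU-Net-style
10 | `BaseReaderWriter` interface (`supported_file_endings`, `read_images`, `read_seg`, `write_seg`) and a
11 | hypothetical file format with loader/saver helpers `load_myext`/`save_myext`; see
12 | `natural_image_reader_writer.py` for a real example:
13 |
14 | ```python
15 | from typing import List, Tuple, Union
16 |
17 | import numpy as np
18 |
19 | from longiseg.imageio.base_reader_writer import BaseReaderWriter  # assumption: adjust to the actual location
20 |
21 |
22 | class MyFormatReaderWriter(BaseReaderWriter):
23 |     supported_file_endings = ['.myext']
24 |
25 |     def read_images(self, image_fnames: Union[List[str], Tuple[str, ...]]) -> Tuple[np.ndarray, dict]:
26 |         # stack all channels into a 4d float32 array (c, x, y, z) and report the voxel spacing
27 |         images = np.stack([load_myext(f) for f in image_fnames])  # load_myext: hypothetical loader
28 |         return images.astype(np.float32), {'spacing': (1., 1., 1.)}
29 |
30 |     def read_seg(self, seg_fname: str) -> Tuple[np.ndarray, dict]:
31 |         # segmentations need a leading channel axis as well
32 |         return load_myext(seg_fname)[None].astype(np.float32), {'spacing': (1., 1., 1.)}
33 |
34 |     def write_seg(self, seg: np.ndarray, output_fname: str, properties: dict) -> None:
35 |         save_myext(seg, output_fname)  # save_myext: hypothetical writer
36 | ```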
--------------------------------------------------------------------------------
/longiseg/preprocessing/normalization/readme.md:
--------------------------------------------------------------------------------
1 | The channel_names entry in dataset.json only determines the normalization scheme. So if you want to use something
2 | different, you can just
3 | - create a new subclass of ImageNormalization
4 | - map your custom channel identifier to that subclass in channel_name_to_normalization_mapping
5 | - run plan and preprocess again with your custom normalization scheme
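6 |
7 | A minimal sketch of such a subclass (assuming the nnU-Net-style `ImageNormalization` interface from
8 | `default_normalization_schemes.py`, i.e. a `run(image, seg)` method and a `target_dtype` attribute;
9 | the percentile clipping shown here is just an illustrative choice):
10 |
11 | ```python
12 | import numpy as np
13 |
14 | from longiseg.preprocessing.normalization.default_normalization_schemes import ImageNormalization
15 |
16 |
17 | class PercentileClipNormalization(ImageNormalization):
18 |     def run(self, image: np.ndarray, seg: np.ndarray = None) -> np.ndarray:
19 |         # clip to the 5th/95th intensity percentile, then z-score normalize
20 |         image = image.astype(self.target_dtype, copy=False)
21 |         lower, upper = np.percentile(image, (5, 95))
22 |         image = np.clip(image, lower, upper)
23 |         return (image - image.mean()) / max(image.std(), 1e-8)
24 | ```
25 |
26 | Then map a channel identifier of your choice (for example `clip_zscore`) to this class in
27 | channel_name_to_normalization_mapping (see `map_channel_name_to_normalization.py`).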
--------------------------------------------------------------------------------
/longiseg/training/LongiSegTrainer/variants/competitions/aortaseg24.py:
--------------------------------------------------------------------------------
1 | from longiseg.training.LongiSegTrainer.variants.data_augmentation.nnUNetTrainerNoMirroring import nnUNetTrainer_onlyMirror01
2 | from longiseg.training.LongiSegTrainer.variants.data_augmentation.nnUNetTrainerDA5 import nnUNetTrainerDA5
3 |
4 | class nnUNetTrainer_onlyMirror01_DA5(nnUNetTrainer_onlyMirror01, nnUNetTrainerDA5):
5 | pass
6 |
--------------------------------------------------------------------------------
/longiseg/configuration.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from longiseg.utilities.default_n_proc_DA import get_allowed_n_proc_DA
4 |
5 | default_num_processes = 8 if 'nnUNet_def_n_proc' not in os.environ else int(os.environ['nnUNet_def_n_proc'])
6 |
7 | ANISO_THRESHOLD = 3 # determines when a sample is considered anisotropic (3 means that the spacing in the low
8 | # resolution axis must be 3x as large as the next largest spacing)
9 |
10 | default_n_proc_DA = get_allowed_n_proc_DA()
11 |
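12 | # Illustrative: export nnUNet_def_n_proc=4 (or nnUNet_n_proc_DA=4, see default_n_proc_DA.py) before
13 | # running to lower the number of worker processes on machines with few CPU cores.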
--------------------------------------------------------------------------------
/.github/workflows/codespell.yml:
--------------------------------------------------------------------------------
1 | ---
2 | name: Codespell
3 |
4 | on:
5 | push:
6 | branches: [master]
7 | pull_request:
8 | branches: [master]
9 |
10 | permissions:
11 | contents: read
12 |
13 | jobs:
14 | codespell:
15 | name: Check for spelling errors
16 | runs-on: ubuntu-latest
17 |
18 | steps:
19 | - name: Checkout
20 | uses: actions/checkout@v3
21 | - name: Codespell
22 | uses: codespell-project/actions-codespell@v2
23 |
--------------------------------------------------------------------------------
/documentation/run_inference_with_pretrained_models.md:
--------------------------------------------------------------------------------
1 | # How to run inference with pretrained models
2 | **Important:** Pretrained weights from nnU-Net v1 are NOT compatible with v2. You will need to retrain with the new
3 | version. But honestly, you already have a fully trained model with which you can run inference (in v1), so
4 | just continue using that!
5 |
6 | Not yet available for v2 :-(
7 | If you wish to run inference with pretrained models, check out the old nnU-Net for now. We are working on this at full steam!
8 |
--------------------------------------------------------------------------------
/longiseg/utilities/network_initialization.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 |
3 |
4 | class InitWeights_He(object):
5 | def __init__(self, neg_slope=1e-2):
6 | self.neg_slope = neg_slope
7 |
8 | def __call__(self, module):
9 | if isinstance(module, nn.Conv3d) or isinstance(module, nn.Conv2d) or isinstance(module, nn.ConvTranspose2d) or isinstance(module, nn.ConvTranspose3d):
10 | module.weight = nn.init.kaiming_normal_(module.weight, a=self.neg_slope)
11 | if module.bias is not None:
12 | module.bias = nn.init.constant_(module.bias, 0)
13 |
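14 |
15 | # Typical usage (illustrative): recursively apply the initializer to all conv layers of a new network
16 | # network.apply(InitWeights_He(1e-2))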
--------------------------------------------------------------------------------
/longiseg/training/LongiSegTrainer/variants/network_architecture/nnUNetTrainerNoDeepSupervision.py:
--------------------------------------------------------------------------------
1 | from longiseg.training.LongiSegTrainer.nnUNetTrainerLongi import nnUNetTrainerNoLongi
2 | import torch
3 |
4 |
5 | class nnUNetTrainerNoDeepSupervision(nnUNetTrainerNoLongi):
6 | def __init__(
7 | self,
8 | plans: dict,
9 | configuration: str,
10 | fold: int,
11 | dataset_json: dict,
12 | device: torch.device = torch.device("cuda"),
13 | ):
14 | super().__init__(plans, configuration, fold, dataset_json, device)
15 | self.enable_deep_supervision = False
16 |
--------------------------------------------------------------------------------
/longiseg/training/LongiSegTrainer/variants/longitudinal/LongiSegTrainerDiffWeighting.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from longiseg.training.LongiSegTrainer.LongiSegTrainer import LongiSegTrainer
4 |
5 |
6 | class LongiSegTrainerDiffWeighting(LongiSegTrainer):
7 | architecture_class_name = "LongiUNetDiffWeighting"
8 |
9 |
10 | class LongiSegTrainerDiffWeightingRP(LongiSegTrainerDiffWeighting):
11 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,
12 | device: torch.device = torch.device('cuda')):
13 | super().__init__(plans, configuration, fold, dataset_json, device)
14 | self.random_prior = True
--------------------------------------------------------------------------------
/longiseg/training/LongiSegTrainer/variants/lr_schedule/nnUNetTrainerCosAnneal.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.optim.lr_scheduler import CosineAnnealingLR
3 |
4 | from longiseg.training.LongiSegTrainer.nnUNetTrainerLongi import nnUNetTrainerNoLongi
5 |
6 |
7 | class nnUNetTrainerCosAnneal(nnUNetTrainerNoLongi):
8 | def configure_optimizers(self):
9 | optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay,
10 | momentum=0.99, nesterov=True)
11 | lr_scheduler = CosineAnnealingLR(optimizer, T_max=self.num_epochs)
12 | return optimizer, lr_scheduler
13 |
14 |
--------------------------------------------------------------------------------
/longiseg/utilities/helpers.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def softmax_helper_dim0(x: torch.Tensor) -> torch.Tensor:
5 | return torch.softmax(x, 0)
6 |
7 |
8 | def softmax_helper_dim1(x: torch.Tensor) -> torch.Tensor:
9 | return torch.softmax(x, 1)
10 |
11 |
12 | def empty_cache(device: torch.device):
13 | if device.type == 'cuda':
14 | torch.cuda.empty_cache()
15 | elif device.type == 'mps':
16 | from torch import mps
17 | mps.empty_cache()
18 | else:
19 | pass
20 |
21 |
22 | class dummy_context(object):
23 | def __enter__(self):
24 | pass
25 |
26 | def __exit__(self, exc_type, exc_val, exc_tb):
27 | pass
28 |
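29 |
30 | # Typical usage (a sketch): stand in for torch.autocast on devices where AMP is unavailable, e.g.
31 | # with torch.autocast(device.type) if device.type == 'cuda' else dummy_context():
32 | #     output = network(data)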
--------------------------------------------------------------------------------
/longiseg/preprocessing/resampling/utils.py:
--------------------------------------------------------------------------------
1 | from typing import Callable
2 |
3 | import longiseg
4 | from batchgenerators.utilities.file_and_folder_operations import join
5 | from longiseg.utilities.find_class_by_name import recursive_find_python_class
6 |
7 |
8 | def recursive_find_resampling_fn_by_name(resampling_fn: str) -> Callable:
9 | ret = recursive_find_python_class(join(longiseg.__path__[0], "preprocessing", "resampling"), resampling_fn,
10 | 'longiseg.preprocessing.resampling')
11 | if ret is None:
12 | raise RuntimeError("Unable to find resampling function named '%s'. Please make sure this fn is located in the "
13 | "longiseg.preprocessing.resampling module." % resampling_fn)
14 | else:
15 | return ret
16 |
--------------------------------------------------------------------------------
/longiseg/training/lr_scheduler/polylr.py:
--------------------------------------------------------------------------------
1 | from torch.optim.lr_scheduler import _LRScheduler
2 |
3 |
4 | class PolyLRScheduler(_LRScheduler):
5 | def __init__(self, optimizer, initial_lr: float, max_steps: int, exponent: float = 0.9, current_step: int = None):
6 | self.optimizer = optimizer
7 | self.initial_lr = initial_lr
8 | self.max_steps = max_steps
9 | self.exponent = exponent
10 | self.ctr = 0
11 | super().__init__(optimizer, current_step if current_step is not None else -1)
12 |
13 | def step(self, current_step=None):
14 | if current_step is None or current_step == -1:
15 | current_step = self.ctr
16 | self.ctr += 1
17 |
18 | new_lr = self.initial_lr * (1 - current_step / self.max_steps) ** self.exponent
19 | for param_group in self.optimizer.param_groups:
20 | param_group['lr'] = new_lr
21 |
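22 |
23 | # Illustrative usage: polynomial decay from 1e-2 towards 0 over 1000 steps
24 | # optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
25 | # scheduler = PolyLRScheduler(optimizer, initial_lr=1e-2, max_steps=1000)
26 | # for step in range(1000):
27 | #     ...  # forward, backward, optimizer.step()
28 | #     scheduler.step()  # sets lr = 1e-2 * (1 - step / 1000) ** 0.9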
--------------------------------------------------------------------------------
/longiseg/utilities/find_class_by_name.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import pkgutil
3 |
4 | from batchgenerators.utilities.file_and_folder_operations import join
5 |
6 |
7 | def recursive_find_python_class(folder: str, class_name: str, current_module: str):
8 | tr = None
9 | for importer, modname, ispkg in pkgutil.iter_modules([folder]):
10 | # print(modname, ispkg)
11 | if not ispkg:
12 | m = importlib.import_module(current_module + "." + modname)
13 | if hasattr(m, class_name):
14 | tr = getattr(m, class_name)
15 | break
16 |
17 | if tr is None:
18 | for importer, modname, ispkg in pkgutil.iter_modules([folder]):
19 | if ispkg:
20 | next_current_module = current_module + "." + modname
21 | tr = recursive_find_python_class(join(folder, modname), class_name, current_module=next_current_module)
22 | if tr is not None:
23 | break
24 | return tr
25 |
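26 | # Illustrative usage (mirrors preprocessing/resampling/utils.py):
27 | # cls = recursive_find_python_class(join(longiseg.__path__[0], "preprocessing", "resampling"),
28 | #                                   "my_resampling_fn", "longiseg.preprocessing.resampling")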
--------------------------------------------------------------------------------
/longiseg/utilities/collate_outputs.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import numpy as np
4 |
5 |
6 | def collate_outputs(outputs: List[dict]):
7 | """
8 | used to collate default train_step and validation_step outputs. If you want something different then you gotta
9 | extend this
10 |
11 | we expect outputs to be a list of dictionaries where each of the dict has the same set of keys
12 | """
13 | collated = {}
14 | for k in outputs[0].keys():
15 | if np.isscalar(outputs[0][k]):
16 | collated[k] = [o[k] for o in outputs]
17 | elif isinstance(outputs[0][k], np.ndarray):
18 | collated[k] = np.vstack([o[k][None] for o in outputs])
19 | elif isinstance(outputs[0][k], list):
20 | collated[k] = [item for o in outputs for item in o[k]]
21 | else:
22 | raise ValueError(f'Cannot collate input of type {type(outputs[0][k])}. '
23 | f'Modify collate_outputs to add this functionality')
24 | return collated
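25 |
26 |
27 | # Illustrative: collate_outputs([{'loss': 0.5, 'tp_hard': np.array([1, 2])},
28 | #                                {'loss': 0.3, 'tp_hard': np.array([3, 4])}])
29 | # -> {'loss': [0.5, 0.3], 'tp_hard': array([[1, 2], [3, 4]])}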
--------------------------------------------------------------------------------
/longiseg/evaluation/metrics/volumetric_metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def compute_tp_fp_fn_tn(mask_ref: np.ndarray, mask_pred: np.ndarray, ignore_mask: np.ndarray = None):
5 | if ignore_mask is None:
6 | use_mask = np.ones_like(mask_ref, dtype=bool)
7 | else:
8 | use_mask = ~ignore_mask
9 | tp = np.sum((mask_ref & mask_pred) & use_mask)
10 | fp = np.sum(((~mask_ref) & mask_pred) & use_mask)
11 | fn = np.sum((mask_ref & (~mask_pred)) & use_mask)
12 | tn = np.sum(((~mask_ref) & (~mask_pred)) & use_mask)
13 | return tp, fp, fn, tn
14 |
15 |
16 | def compute_volumetric_metrics(mask_ref, mask_pred, ignore_mask):
17 | tp, fp, fn, tn = compute_tp_fp_fn_tn(mask_ref, mask_pred, ignore_mask)
18 | dice = 2 * tp / (2 * tp + fp + fn) if tp + fp + fn > 0 else 1
19 | iou = tp / (tp + fp + fn) if tp + fp + fn > 0 else 1
20 | recall = tp / (tp + fn) if tp + fn > 0 else 1
21 | precision = tp / (tp + fp) if tp + fp > 0 else 1
22 | return dice, iou, recall, precision, tp, fp, fn, tn
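23 |
24 |
25 | # Worked example (illustrative): tp=8, fp=2, fn=2, tn=88 gives
26 | # dice = 2*8 / (2*8 + 2 + 2) = 0.8, iou = 8 / (8 + 2 + 2) ~ 0.667,
27 | # recall = 8 / (8 + 2) = 0.8, precision = 8 / (8 + 2) = 0.8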
--------------------------------------------------------------------------------
/longiseg/training/data_augmentation/custom_transforms/masking.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | from batchgenerators.transforms.abstract_transforms import AbstractTransform
4 |
5 |
6 | class MaskTransform(AbstractTransform):
7 | def __init__(self, apply_to_channels: List[int], mask_idx_in_seg: int = 0, set_outside_to: int = 0,
8 | data_key: str = "data", seg_key: str = "seg"):
9 | """
10 |         Sets everything outside the mask to set_outside_to (default 0). CAREFUL! Outside is defined as < 0, not == 0, in the mask channel of seg!
11 | """
12 | self.apply_to_channels = apply_to_channels
13 | self.seg_key = seg_key
14 | self.data_key = data_key
15 | self.set_outside_to = set_outside_to
16 | self.mask_idx_in_seg = mask_idx_in_seg
17 |
18 | def __call__(self, **data_dict):
19 | mask = data_dict[self.seg_key][:, self.mask_idx_in_seg] < 0
20 | for c in self.apply_to_channels:
21 | data_dict[self.data_key][:, c][mask] = self.set_outside_to
22 | return data_dict
23 |
--------------------------------------------------------------------------------
/longiseg/preprocessing/normalization/map_channel_name_to_normalization.py:
--------------------------------------------------------------------------------
1 | from typing import Type
2 |
3 | from longiseg.preprocessing.normalization.default_normalization_schemes import CTNormalization, NoNormalization, \
4 | ZScoreNormalization, RescaleTo01Normalization, RGBTo01Normalization, ImageNormalization
5 |
6 | channel_name_to_normalization_mapping = {
7 | 'ct': CTNormalization,
8 | 'nonorm': NoNormalization,
9 | 'zscore': ZScoreNormalization,
10 | 'rescale_to_0_1': RescaleTo01Normalization,
11 | 'rgb_to_0_1': RGBTo01Normalization
12 | }
13 |
14 |
15 | def get_normalization_scheme(channel_name: str) -> Type[ImageNormalization]:
16 | """
17 | If we find the channel_name in channel_name_to_normalization_mapping return the corresponding normalization. If it is
18 | not found, use the default (ZScoreNormalization)
19 | """
20 | norm_scheme = channel_name_to_normalization_mapping.get(channel_name.casefold())
21 | if norm_scheme is None:
22 | norm_scheme = ZScoreNormalization
23 | # print('Using %s for image normalization' % norm_scheme.__name__)
24 | return norm_scheme
25 |
--------------------------------------------------------------------------------
/longiseg/training/data_augmentation/compute_initial_patch_size.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def get_patch_size(final_patch_size, rot_x, rot_y, rot_z, scale_range):
5 | if isinstance(rot_x, (tuple, list)):
6 | rot_x = max(np.abs(rot_x))
7 | if isinstance(rot_y, (tuple, list)):
8 | rot_y = max(np.abs(rot_y))
9 | if isinstance(rot_z, (tuple, list)):
10 | rot_z = max(np.abs(rot_z))
11 | rot_x = min(90 / 360 * 2. * np.pi, rot_x)
12 | rot_y = min(90 / 360 * 2. * np.pi, rot_y)
13 | rot_z = min(90 / 360 * 2. * np.pi, rot_z)
14 | from batchgenerators.augmentations.utils import rotate_coords_3d, rotate_coords_2d
15 | coords = np.array(final_patch_size)
16 | final_shape = np.copy(coords)
17 | if len(coords) == 3:
18 | final_shape = np.max(np.vstack((np.abs(rotate_coords_3d(coords, rot_x, 0, 0)), final_shape)), 0)
19 | final_shape = np.max(np.vstack((np.abs(rotate_coords_3d(coords, 0, rot_y, 0)), final_shape)), 0)
20 | final_shape = np.max(np.vstack((np.abs(rotate_coords_3d(coords, 0, 0, rot_z)), final_shape)), 0)
21 | elif len(coords) == 2:
22 | final_shape = np.max(np.vstack((np.abs(rotate_coords_2d(coords, rot_x)), final_shape)), 0)
23 | final_shape /= min(scale_range)
24 | return final_shape.astype(int)
25 |
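26 | # Illustrative: get_patch_size((128, 128, 128), rot, rot, rot, (0.85, 1.25)) with rot = 30 degrees (in radians)
27 | # returns the enlarged patch size that must be sampled so that, after the most extreme rotation and
28 | # downscaling (division by min(scale_range) = 0.85), a (128, 128, 128) patch can still be cropped out.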
--------------------------------------------------------------------------------
/longiseg/training/loss/robust_ce_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, Tensor
3 | import numpy as np
4 |
5 |
6 | class RobustCrossEntropyLoss(nn.CrossEntropyLoss):
7 | """
8 | this is just a compatibility layer because my target tensor is float and has an extra dimension
9 |
10 | input must be logits, not probabilities!
11 | """
12 | def forward(self, input: Tensor, target: Tensor) -> Tensor:
13 | if target.ndim == input.ndim:
14 | assert target.shape[1] == 1
15 | target = target[:, 0]
16 | return super().forward(input, target.long())
17 |
18 |
19 | class TopKLoss(RobustCrossEntropyLoss):
20 | """
21 | input must be logits, not probabilities!
22 | """
23 | def __init__(self, weight=None, ignore_index: int = -100, k: float = 10, label_smoothing: float = 0):
24 | self.k = k
25 | super(TopKLoss, self).__init__(weight, False, ignore_index, reduce=False, label_smoothing=label_smoothing)
26 |
27 | def forward(self, inp, target):
28 | target = target[:, 0].long()
29 | res = super(TopKLoss, self).forward(inp, target)
30 | num_voxels = np.prod(res.shape, dtype=np.int64)
31 | res, _ = torch.topk(res.view((-1, )), int(num_voxels * self.k / 100), sorted=False)
32 | return res.mean()
33 |
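34 | # Illustrative: with k=10 (percent), only the 10% highest-loss voxels contribute to the mean,
35 | # focusing the gradient on the hardest voxels (a form of online hard example mining).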
--------------------------------------------------------------------------------
/longiseg/preprocessing/cropping/cropping.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.ndimage import binary_fill_holes
3 | from acvl_utils.cropping_and_padding.bounding_boxes import get_bbox_from_mask, bounding_box_to_slice
4 |
5 |
6 | def create_nonzero_mask(data):
7 |     """
8 |     Builds a mask that is True wherever any channel of the data is nonzero, then fills enclosed holes.
9 |     :param data: 4d array (c, x, y, z) or 3d array (c, x, y)
10 |     :return: the mask is True where the data is nonzero
11 |     """
12 | assert data.ndim in (3, 4), "data must have shape (C, X, Y, Z) or shape (C, X, Y)"
13 | nonzero_mask = data[0] != 0
14 | for c in range(1, data.shape[0]):
15 | nonzero_mask |= data[c] != 0
16 | return binary_fill_holes(nonzero_mask)
17 |
18 |
19 | def crop_to_nonzero(data, seg=None, bbox=None, nonzero_label=-1):
20 |     """
21 |     Crops data (and seg, if provided) to the bounding box of the nonzero region of data.
22 |     :param data: 4d array (c, x, y, z) or 3d array (c, x, y)
23 |     :param seg: optional segmentation map of the same spatial shape as data
24 |     :param nonzero_label: this will be written into the segmentation map wherever the data is zero
25 |     :return: cropped data, cropped (or newly created) seg and the bounding box that was used
26 |     """
27 | nonzero_mask = create_nonzero_mask(data)
28 | if bbox is None:
29 | bbox = get_bbox_from_mask(nonzero_mask)
30 | slicer = bounding_box_to_slice(bbox)
31 | nonzero_mask = nonzero_mask[slicer][None]
32 |
33 | slicer = (slice(None), ) + slicer
34 | data = data[slicer]
35 | if seg is not None:
36 | seg = seg[slicer]
37 | seg[(seg == 0) & (~nonzero_mask)] = nonzero_label
38 | else:
39 | seg = np.where(nonzero_mask, np.int8(0), np.int8(nonzero_label))
40 | return data, seg, bbox
41 |
42 |
43 |
--------------------------------------------------------------------------------
/longiseg/utilities/crossval_split.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import numpy as np
4 | from sklearn.model_selection import KFold, GroupKFold
5 |
6 |
7 | def generate_crossval_split(train_identifiers: List[str], seed=12345, n_splits=5) -> List[dict[str, List[str]]]:
8 | splits = []
9 | kfold = KFold(n_splits=n_splits, shuffle=True, random_state=seed)
10 | for i, (train_idx, test_idx) in enumerate(kfold.split(train_identifiers)):
11 | train_keys = np.array(train_identifiers)[train_idx]
12 | test_keys = np.array(train_identifiers)[test_idx]
13 | splits.append({})
14 | splits[-1]['train'] = list(train_keys)
15 | splits[-1]['val'] = list(test_keys)
16 | return splits
17 |
18 |
19 | def generate_crossval_split_longi(train_patients, seed=12345, n_splits=5):
20 | splits = []
21 | all_keys = []
22 | groups = []
23 | for idx, patient in enumerate(train_patients):
24 | all_keys = all_keys + train_patients[patient]
25 | groups = groups + [idx] * len(train_patients[patient])
26 | groups = np.array(groups)
27 | group_kfold = GroupKFold(n_splits=n_splits)
28 |
29 | for i, (train_idx, test_idx) in enumerate(group_kfold.split(all_keys, groups=groups)):
30 | train_keys = np.array(all_keys)[train_idx]
31 | test_keys = np.array(all_keys)[test_idx]
32 | splits.append({})
33 | splits[-1]['train'] = list(train_keys)
34 | splits[-1]['val'] = list(test_keys)
35 | return splits
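36 |
37 |
38 | # Illustrative input for generate_crossval_split_longi: a dict mapping each patient to its scan identifiers,
39 | # train_patients = {'patient_0': ['patient_0_scan_0', 'patient_0_scan_1'],
40 | #                   'patient_1': ['patient_1_scan_0', 'patient_1_scan_1', 'patient_1_scan_2']}
41 | # GroupKFold keeps all scans of a patient in the same fold, so no patient leaks across train/val.
42 | # Note that the seed argument is currently unused here because GroupKFold splits deterministically.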
--------------------------------------------------------------------------------
/longiseg/training/data_augmentation/custom_transforms/region_based_training.py:
--------------------------------------------------------------------------------
1 | from typing import List, Tuple, Union
2 |
3 | from batchgenerators.transforms.abstract_transforms import AbstractTransform
4 | import numpy as np
5 |
6 |
7 | class ConvertSegmentationToRegionsTransform(AbstractTransform):
8 | def __init__(self, regions: Union[List, Tuple],
9 | seg_key: str = "seg", output_key: str = "seg", seg_channel: int = 0):
10 | """
11 | regions are tuple of tuples where each inner tuple holds the class indices that are merged into one region,
12 | example:
13 | regions= ((1, 2), (2, )) will result in 2 regions: one covering the region of labels 1&2 and the other just 2
14 | :param regions:
15 | :param seg_key:
16 | :param output_key:
17 | """
18 | self.seg_channel = seg_channel
19 | self.output_key = output_key
20 | self.seg_key = seg_key
21 | self.regions = regions
22 |
23 | def __call__(self, **data_dict):
24 | seg = data_dict.get(self.seg_key)
25 | if seg is not None:
26 | b, c, *shape = seg.shape
27 | region_output = np.zeros((b, len(self.regions), *shape), dtype=bool)
28 | for region_id, region_labels in enumerate(self.regions):
29 | region_output[:, region_id] |= np.isin(seg[:, self.seg_channel], region_labels)
30 | data_dict[self.output_key] = region_output.astype(np.uint8, copy=False)
31 | return data_dict
32 |
33 |
--------------------------------------------------------------------------------
/longiseg/training/loss/deep_supervision.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class DeepSupervisionWrapper(nn.Module):
6 | def __init__(self, loss, weight_factors=None):
7 | """
8 | Wraps a loss function so that it can be applied to multiple outputs. Forward accepts an arbitrary number of
9 | inputs. Each input is expected to be a tuple/list. Each tuple/list must have the same length. The loss is then
10 | applied to each entry like this:
11 | l = w0 * loss(input0[0], input1[0], ...) + w1 * loss(input0[1], input1[1], ...) + ...
12 | If weights are None, all w will be 1.
13 | """
14 | super(DeepSupervisionWrapper, self).__init__()
15 |         assert weight_factors is None or any([x != 0 for x in weight_factors]), "At least one weight factor should be != 0.0"
16 |         self.weight_factors = tuple(weight_factors) if weight_factors is not None else None
17 | self.loss = loss
18 |
19 | def forward(self, *args):
20 | assert all([isinstance(i, (tuple, list)) for i in args]), \
21 | f"all args must be either tuple or list, got {[type(i) for i in args]}"
22 | # we could check for equal lengths here as well, but we really shouldn't overdo it with checks because
23 | # this code is executed a lot of times!
24 |
25 | if self.weight_factors is None:
26 | weights = (1, ) * len(args[0])
27 | else:
28 | weights = self.weight_factors
29 |
30 | return sum([weights[i] * self.loss(*inputs) for i, inputs in enumerate(zip(*args)) if weights[i] != 0.0])
31 |
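32 | # Illustrative: ds_loss = DeepSupervisionWrapper(loss, (1.0, 0.5, 0.25)) applied to three output
33 | # resolutions computes 1.0 * loss(out0, tgt0) + 0.5 * loss(out1, tgt1) + 0.25 * loss(out2, tgt2)
34 | # via ds_loss([out0, out1, out2], [tgt0, tgt1, tgt2]).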
--------------------------------------------------------------------------------
/documentation/setting_up_paths.md:
--------------------------------------------------------------------------------
1 | # Setting up Paths
2 |
3 | nnU-Net relies on environment variables to know where raw data, preprocessed data and trained model weights are stored.
4 | To use the full functionality of nnU-Net, the following three environment variables must be set:
5 |
6 | 1) `LongiSeg_raw`: This is where you place the raw datasets. This folder will have one subfolder for each dataset, named
7 | DatasetXXX_YYY, where XXX is a 3-digit identifier (such as 001, 002, 043, 999, ...) and YYY is the (unique)
8 | dataset name. The datasets must be in nnU-Net format, see [here](dataset_format.md).
9 |
10 | Example tree structure:
11 | ```
12 | LongiSeg_raw/Dataset001_NAME1
13 | ├── dataset.json
14 | ├── imagesTr
15 | │ ├── ...
16 | ├── imagesTs
17 | │ ├── ...
18 | └── labelsTr
19 | ├── ...
20 | LongiSeg_raw/Dataset002_NAME2
21 | ├── dataset.json
22 | ├── imagesTr
23 | │ ├── ...
24 | ├── imagesTs
25 | │ ├── ...
26 | └── labelsTr
27 | ├── ...
28 | ```
29 |
30 | 2) `LongiSeg_preprocessed`: This is the folder where the preprocessed data will be saved. The data will also be read from
31 | this folder during training. It is important that this folder is located on a drive with low access latency and high
32 | throughput (such as an NVMe SSD; PCIe gen 3 is sufficient).
33 |
34 | 3) `LongiSeg_results`: This specifies where nnU-Net will save the model weights. If pretrained models are downloaded, this
35 | is where they will be saved.
36 |
37 | ### How to set environment variables
38 | See [here](set_environment_variables.md).
--------------------------------------------------------------------------------
/longiseg/evaluation/metrics/distance_metrics.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | import numpy as np
3 | from scipy.spatial import KDTree
4 | from scipy import ndimage
5 |
6 |
7 | def SSD(gt: np.ndarray, pred: np.ndarray, spacing: tuple[float, float, float] = (1., 1., 1.), ignore_mask: Optional[np.ndarray] = None) -> tuple[float, float]:
8 | # slightly adapted from https://github.com/emrekavur/CHAOS-evaluation/blob/master/Python/CHAOSmetrics.py
9 | gt_mask = gt if ignore_mask is None else gt & ~ignore_mask
10 | pred_mask = pred if ignore_mask is None else pred & ~ignore_mask
11 | gt_sum = np.sum(gt_mask)
12 | pred_sum = np.sum(pred_mask)
13 | if gt_sum == 0 or pred_sum == 0:
14 | return (0, 0) if gt_sum == pred_sum else (1000, 1000) # maximum value chosen to fit the largest volumes
15 |
16 | struct = ndimage.generate_binary_structure(3, 1)
17 | spacing = np.array(spacing)
18 |
19 | gt_border = gt_mask ^ ndimage.binary_erosion(gt_mask, structure=struct, border_value=1)
20 | gt_border_voxels = np.array(np.where(gt_border)).T * spacing
21 |
22 | pred_border = pred_mask ^ ndimage.binary_erosion(pred_mask, structure=struct, border_value=1)
23 | pred_border_voxels = np.array(np.where(pred_border)).T * spacing
24 |
25 | tree_ref = KDTree(gt_border_voxels)
26 | dist_seg_to_ref, _ = tree_ref.query(pred_border_voxels)
27 | tree_seg = KDTree(pred_border_voxels)
28 | dist_ref_to_seg, _ = tree_seg.query(gt_border_voxels)
29 |
30 | assd = (dist_seg_to_ref.sum() + dist_ref_to_seg.sum()) / (len(dist_seg_to_ref) + len(dist_ref_to_seg))
31 | hd95 = np.percentile(np.concatenate((dist_seg_to_ref, dist_ref_to_seg)), 95)
32 | return assd, hd95
--------------------------------------------------------------------------------
/documentation/dataset_format_inference.md:
--------------------------------------------------------------------------------
1 | # Data format for Inference
2 | Read the documentation on the overall [data format](dataset_format.md) first!
3 |
4 | The data format for inference must match the one used for the raw data (**specifically, the images must be in exactly
5 | the same format as in the imagesTr folder**). As before, the filenames must start with a
6 | unique identifier, followed by a 4-digit modality identifier. Here is an example for two different datasets:
7 |
8 | 1) Task005_Prostate:
9 |
10 | This task has 2 modalities, so the files in the input folder must look like this:
11 |
12 | input_folder
13 | ├── prostate_03_0000.nii.gz
14 | ├── prostate_03_0001.nii.gz
15 | ├── prostate_05_0000.nii.gz
16 | ├── prostate_05_0001.nii.gz
17 | ├── prostate_08_0000.nii.gz
18 | ├── prostate_08_0001.nii.gz
19 | ├── ...
20 |
21 | _0000 has to be the T2 image and _0001 has to be the ADC image (as specified by 'channel_names' in the
22 | dataset.json), exactly the same as was used for training.
23 |
24 | 2) Task002_Heart:
25 |
26 | imagesTs
27 | ├── la_001_0000.nii.gz
28 | ├── la_002_0000.nii.gz
29 | ├── la_006_0000.nii.gz
30 | ├── ...
31 |
32 | Task002 only has one modality, so each case only has one _0000.nii.gz file.
33 |
34 |
35 | The segmentations in the output folder will be named {CASE_IDENTIFIER}.nii.gz (omitting the modality identifier).
36 |
37 | Remember that the file format used for inference (.nii.gz in this example) must be the same as was used for training
38 | (and as was specified in 'file_ending' in the dataset.json)!
39 |
--------------------------------------------------------------------------------
/longiseg/training/LongiSegTrainer/variants/data_augmentation/nnUNetTrainer_noDummy2DDA.py:
--------------------------------------------------------------------------------
1 | from longiseg.training.data_augmentation.compute_initial_patch_size import get_patch_size
2 | from longiseg.training.LongiSegTrainer.nnUNetTrainerLongi import nnUNetTrainerNoLongi
3 | import numpy as np
4 |
5 |
6 | class nnUNetTrainer_noDummy2DDA(nnUNetTrainerNoLongi):
7 | def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self):
8 | do_dummy_2d_data_aug = False
9 |
10 | patch_size = self.configuration_manager.patch_size
11 | dim = len(patch_size)
12 | if dim == 2:
13 | if max(patch_size) / min(patch_size) > 1.5:
14 | rotation_for_DA = (-15. / 360 * 2. * np.pi, 15. / 360 * 2. * np.pi)
15 | else:
16 | rotation_for_DA = (-180. / 360 * 2. * np.pi, 180. / 360 * 2. * np.pi)
17 | mirror_axes = (0, 1)
18 | elif dim == 3:
19 | rotation_for_DA = (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi)
20 | mirror_axes = (0, 1, 2)
21 | else:
22 | raise RuntimeError()
23 |
24 | # todo this function is stupid. It doesn't even use the correct scale range (we keep things as they were in the
25 | # old nnunet for now)
26 | initial_patch_size = get_patch_size(patch_size[-dim:],
27 | rotation_for_DA,
28 | rotation_for_DA,
29 | rotation_for_DA,
30 | (0.85, 1.25))
31 |
32 | self.print_to_log_file(f'do_dummy_2d_data_aug: {do_dummy_2d_data_aug}')
33 | self.inference_allowed_mirroring_axes = mirror_axes
34 |
35 | return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 |
27 | # PyInstaller
28 | # Usually these files are written by a python script from a template
29 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
30 | *.manifest
31 | *.spec
32 |
33 | # Installer logs
34 | pip-log.txt
35 | pip-delete-this-directory.txt
36 |
37 | # Unit test / coverage reports
38 | htmlcov/
39 | .tox/
40 | .coverage
41 | .coverage.*
42 | .cache
43 | nosetests.xml
44 | coverage.xml
45 | *,cover
46 | .hypothesis/
47 |
48 | # Translations
49 | *.mo
50 | *.pot
51 |
52 | # Django stuff:
53 | *.log
54 | local_settings.py
55 |
56 | # Flask stuff:
57 | instance/
58 | .webassets-cache
59 |
60 | # Scrapy stuff:
61 | .scrapy
62 |
63 | # Sphinx documentation
64 | docs/_build/
65 |
66 | # PyBuilder
67 | target/
68 |
69 | # IPython Notebook
70 | .ipynb_checkpoints
71 |
72 | # pyenv
73 | .python-version
74 |
75 | # celery beat schedule file
76 | celerybeat-schedule
77 |
78 | # dotenv
79 | .env
80 |
81 | # virtualenv
82 | venv/
83 | ENV/
84 |
85 | # Spyder project settings
86 | .spyderproject
87 |
88 | # Rope project settings
89 | .ropeproject
90 |
91 | *.memmap
92 | *.png
93 | *.zip
94 | *.npz
95 | *.npy
96 | *.jpg
97 | *.jpeg
98 | .idea
99 | *.txt
100 | .idea/*
101 | *.nii.gz
102 | *.nii
103 | *.tif
104 | *.bmp
105 | *.pkl
106 | *.xml
107 | *.pdf
108 |
109 | *.model
110 |
111 | !documentation/assets/*
--------------------------------------------------------------------------------
/longiseg/training/LongiSegTrainer/variants/data_augmentation/nnUNetTrainerNoDA.py:
--------------------------------------------------------------------------------
1 | from typing import Union, Tuple, List
2 |
3 | import numpy as np
4 | from batchgeneratorsv2.helpers.scalar_type import RandomScalar
5 | from batchgeneratorsv2.transforms.base.basic_transform import BasicTransform
6 |
7 | from longiseg.training.LongiSegTrainer.nnUNetTrainerLongi import nnUNetTrainerNoLongi
8 |
9 |
10 | class nnUNetTrainerNoDA(nnUNetTrainerNoLongi):
11 | @staticmethod
12 | def get_training_transforms(
13 | patch_size: Union[np.ndarray, Tuple[int]],
14 | rotation_for_DA: RandomScalar,
15 | deep_supervision_scales: Union[List, Tuple, None],
16 | mirror_axes: Tuple[int, ...],
17 | do_dummy_2d_data_aug: bool,
18 | use_mask_for_norm: List[bool] = None,
19 | is_cascaded: bool = False,
20 | foreground_labels: Union[Tuple[int, ...], List[int]] = None,
21 | regions: List[Union[List[int], Tuple[int, ...], int]] = None,
22 | ignore_label: int = None,
23 | ) -> BasicTransform:
24 |         return nnUNetTrainerNoLongi.get_validation_transforms(deep_supervision_scales, is_cascaded, foreground_labels,
25 |                                                                regions, ignore_label)
26 |
27 | def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self):
28 | # we need to disable mirroring here so that no mirroring will be applied in inference!
29 | rotation_for_DA, do_dummy_2d_data_aug, _, _ = \
30 | super().configure_rotation_dummyDA_mirroring_and_inital_patch_size()
31 | mirror_axes = None
32 | self.inference_allowed_mirroring_axes = None
33 | initial_patch_size = self.configuration_manager.patch_size
34 | return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes
35 |
36 |
--------------------------------------------------------------------------------
/longiseg/training/LongiSegTrainer/variants/loss/nnUNetTrainerCELoss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from longiseg.training.loss.deep_supervision import DeepSupervisionWrapper
3 | from longiseg.training.LongiSegTrainer.nnUNetTrainerLongi import nnUNetTrainerNoLongi
4 | from longiseg.training.loss.robust_ce_loss import RobustCrossEntropyLoss
5 | import numpy as np
6 |
7 |
8 | class nnUNetTrainerCELoss(nnUNetTrainerNoLongi):
9 | def _build_loss(self):
10 | assert not self.label_manager.has_regions, "regions not supported by this trainer"
11 | loss = RobustCrossEntropyLoss(
12 | weight=None, ignore_index=self.label_manager.ignore_label if self.label_manager.has_ignore_label else -100
13 | )
14 |
15 | # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases
16 | # this gives higher resolution outputs more weight in the loss
17 | if self.enable_deep_supervision:
18 | deep_supervision_scales = self._get_deep_supervision_scales()
19 | weights = np.array([1 / (2**i) for i in range(len(deep_supervision_scales))])
20 | weights[-1] = 0
21 |
22 |         # we don't use the lowest output. Normalize weights so that they sum to 1
23 | weights = weights / weights.sum()
24 | # now wrap the loss
25 | loss = DeepSupervisionWrapper(loss, weights)
26 | return loss
27 |
28 |
29 | class nnUNetTrainerCELoss_5epochs(nnUNetTrainerCELoss):
30 | def __init__(
31 | self,
32 | plans: dict,
33 | configuration: str,
34 | fold: int,
35 | dataset_json: dict,
36 | device: torch.device = torch.device("cuda"),
37 | ):
38 | """used for debugging plans etc"""
39 | super().__init__(plans, configuration, fold, dataset_json, device)
40 | self.num_epochs = 5
41 |
--------------------------------------------------------------------------------
/longiseg/training/LongiSegTrainer/variants/network_architecture/nnUNetTrainerBN.py:
--------------------------------------------------------------------------------
1 | from typing import Union, Tuple, List
2 | from dynamic_network_architectures.building_blocks.helper import get_matching_batchnorm
3 | from torch import nn
4 |
5 | from longiseg.training.LongiSegTrainer.nnUNetTrainerLongi import nnUNetTrainerNoLongi
6 |
7 |
8 | class nnUNetTrainerBN(nnUNetTrainerNoLongi):
9 | @staticmethod
10 | def build_network_architecture(architecture_class_name: str,
11 | arch_init_kwargs: dict,
12 | arch_init_kwargs_req_import: Union[List[str], Tuple[str, ...]],
13 | num_input_channels: int,
14 | num_output_channels: int,
15 | enable_deep_supervision: bool = True) -> nn.Module:
16 |
17 | if 'norm_op' not in arch_init_kwargs.keys():
18 | raise RuntimeError("'norm_op' not found in arch_init_kwargs. This does not look like an architecture "
19 | "I can hack BN into. This trainer only works with default nnU-Net architectures.")
20 |
21 | from pydoc import locate
22 | conv_op = locate(arch_init_kwargs['conv_op'])
23 | bn_class = get_matching_batchnorm(conv_op)
24 | arch_init_kwargs['norm_op'] = bn_class.__module__ + '.' + bn_class.__name__
25 | arch_init_kwargs['norm_op_kwargs'] = {'eps': 1e-5, 'affine': True}
26 |
27 |         return nnUNetTrainerNoLongi.build_network_architecture(architecture_class_name,
28 |                                                                 arch_init_kwargs,
29 |                                                                 arch_init_kwargs_req_import,
30 |                                                                 num_input_channels,
31 |                                                                 num_output_channels, enable_deep_supervision)
32 |
33 |
--------------------------------------------------------------------------------
/longiseg/utilities/ddp_allgather.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from typing import Any, Optional, Tuple
15 |
16 | import torch
17 | from torch import distributed
18 |
19 |
20 | def print_if_rank0(*args):
21 | if distributed.get_rank() == 0:
22 | print(*args)
23 |
24 |
25 | class AllGatherGrad(torch.autograd.Function):
26 | # stolen from pytorch lightning
27 | @staticmethod
28 | def forward(
29 | ctx: Any,
30 | tensor: torch.Tensor,
31 | group: Optional["torch.distributed.ProcessGroup"] = None,
32 | ) -> torch.Tensor:
33 | ctx.group = group
34 |
35 | gathered_tensor = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())]
36 |
37 | torch.distributed.all_gather(gathered_tensor, tensor, group=group)
38 | gathered_tensor = torch.stack(gathered_tensor, dim=0)
39 |
40 | return gathered_tensor
41 |
42 | @staticmethod
43 | def backward(ctx: Any, *grad_output: torch.Tensor) -> Tuple[torch.Tensor, None]:
44 | grad_output = torch.cat(grad_output)
45 |
46 | torch.distributed.all_reduce(grad_output, op=torch.distributed.ReduceOp.SUM, async_op=False, group=ctx.group)
47 |
48 | return grad_output[torch.distributed.get_rank()], None
49 |
50 |
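A minimal usage sketch for `AllGatherGrad` (assumes an initialized default process group; the tensor and loss are placeholders, not part of the file):

```python
import torch

# autograd.Functions are invoked via .apply; this gathers the local tensor from
# all ranks while keeping the result differentiable w.r.t. the local tensor
local_feats = torch.randn(8, 16, requires_grad=True, device='cuda')
gathered = AllGatherGrad.apply(local_feats)  # shape: (world_size, 8, 16)
loss = gathered.pow(2).mean()                # placeholder loss over all ranks' tensors
loss.backward()                              # gradient is routed back to local_feats
```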
--------------------------------------------------------------------------------
/longiseg/utilities/default_n_proc_DA.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | import os
3 |
4 |
5 | def get_allowed_n_proc_DA():
6 | """
7 | This function is used to set the number of processes used on different systems. It is specific to our cluster
8 | infrastructure at DKFZ. You can modify it to suit your needs. Everything is allowed.
9 |
10 | IMPORTANT: if the environment variable nnUNet_n_proc_DA is set it will overwrite anything in this script
11 | (see first line).
12 |
13 | Interpret the output as the number of processes used for data augmentation PER GPU.
14 |
15 | The way it is implemented here is simply a look up table. We know the hostnames, CPU and GPU configurations of our
16 | systems and set the numbers accordingly. For example, a system with 4 GPUs and 48 threads can use 12 threads per
17 | GPU without overloading the CPU (technically 11 because we have a main process as well), so that's what we use.
18 | """
19 |
20 | if 'nnUNet_n_proc_DA' in os.environ.keys():
21 | use_this = int(os.environ['nnUNet_n_proc_DA'])
22 | else:
23 | hostname = subprocess.getoutput('hostname')
24 | if hostname in ['Fabian', ]:
25 | use_this = 12
26 | elif hostname in ['hdf19-gpu16', 'hdf19-gpu17', 'hdf19-gpu18', 'hdf19-gpu19', 'e230-AMDworkstation']:
27 | use_this = 16
28 | elif hostname.startswith('e230-dgx1'):
29 | use_this = 10
30 | elif hostname.startswith('hdf18-gpu') or hostname.startswith('e132-comp'):
31 | use_this = 16
32 | elif hostname.startswith('e230-dgx2'):
33 | use_this = 6
34 | elif hostname.startswith('e230-dgxa100-'):
35 | use_this = 28
36 | elif hostname.startswith('lsf22-gpu'):
37 | use_this = 28
38 | elif hostname.startswith('hdf19-gpu') or hostname.startswith('e071-gpu'):
39 | use_this = 12
40 | else:
41 | use_this = 12 # default value
42 |
43 | use_this = min(use_this, os.cpu_count())
44 | return use_this
45 |
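As the docstring notes, the `nnUNet_n_proc_DA` environment variable takes precedence over the lookup table. A quick sketch (not part of the file):

```python
import os
from longiseg.utilities.default_n_proc_DA import get_allowed_n_proc_DA

os.environ['nnUNet_n_proc_DA'] = '4'  # overrides the hostname lookup table
print(get_allowed_n_proc_DA())        # -> 4 (still capped at os.cpu_count())
```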
--------------------------------------------------------------------------------
/longiseg/model_sharing/model_download.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import requests
4 | from batchgenerators.utilities.file_and_folder_operations import join, isfile
5 | from time import time
6 | from longiseg.model_sharing.model_import import install_model_from_zip_file
7 | from longiseg.paths import LongiSeg_results
8 | from tqdm import tqdm
9 |
10 |
11 | def download_and_install_from_url(url):
12 | assert LongiSeg_results is not None, "Cannot install model because LongiSeg_results is not " \
13 | "set (LongiSeg_results missing as environment variable, see " \
14 | "Installation instructions)"
15 | print('Downloading pretrained model from url:', url)
16 | import http.client
17 | http.client.HTTPConnection._http_vsn = 10
18 | http.client.HTTPConnection._http_vsn_str = 'HTTP/1.0'
19 |
20 | import os
21 | home = os.path.expanduser('~')
22 | random_number = int(time() * 1e7)
23 | tempfile = join(home, f'.nnunetdownload_{str(random_number)}')
24 |
25 | try:
26 | download_file(url=url, local_filename=tempfile, chunk_size=8192 * 16)
27 | print("Download finished. Extracting...")
28 | install_model_from_zip_file(tempfile)
29 | print("Done")
30 | except Exception as e:
31 | raise e
32 | finally:
33 | if isfile(tempfile):
34 | os.remove(tempfile)
35 |
36 |
37 | def download_file(url: str, local_filename: str, chunk_size: Optional[int] = 8192 * 16) -> str:
38 | # borrowed from https://stackoverflow.com/questions/16694907/download-large-file-in-python-with-requests
39 | # NOTE the stream=True parameter below
40 | with requests.get(url, stream=True, timeout=100) as r:
41 | r.raise_for_status()
42 | with tqdm.wrapattr(open(local_filename, 'wb'), "write", total=int(r.headers.get("Content-Length", 0))) as f:
43 | for chunk in r.iter_content(chunk_size=chunk_size):
44 | f.write(chunk)
45 | return local_filename
46 |
47 |
48 |
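A hedged usage sketch (the URL and target path are placeholders, not part of the file): `download_file` streams the file to disk with a progress bar and returns the local path, which can then be fed to `install_model_from_zip_file`:

```python
# hypothetical URL and temp file, for illustration only
path = download_file('https://example.com/pretrained_model.zip', '/tmp/pretrained_model.zip')
install_model_from_zip_file(path)
```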
--------------------------------------------------------------------------------
/documentation/manual_data_splits.md:
--------------------------------------------------------------------------------
1 | # How to generate custom splits in nnU-Net
2 |
3 | Sometimes, the default 5-fold cross-validation split by nnU-Net does not fit a project. Maybe you want to run 3-fold
4 | cross-validation instead? Or maybe your training cases cannot be split randomly and require careful stratification.
5 | Fear not, for nnU-Net has got you covered (it really can do anything <3).
6 |
7 | The splits nnU-Net uses are generated in the `do_split` function of nnUNetTrainer. This function will first look for
8 | existing splits, stored as a file, and if no split exists it will create one. So if you wish to influence the split,
9 | manually creating a split file that will then be recognized and used is the way to go!
10 |
11 | The split file is located in the `LongiSeg_preprocessed/DATASETXXX_NAME` folder. So it is best practice to first
12 | populate this folder by running `nnUNetv2_plan_and_preprocess`.
13 |
14 | Splits are stored as a .json file. They are a simple python list. The length of that list is the number of splits it
15 | contains (so it's 5 in the default nnU-Net). Each list entry is a dictionary with keys 'train' and 'val'. Values are
16 | again simply lists with the train identifiers in each set. To illustrate this, I am just messing with the Dataset002
17 | file as an example:
18 |
19 | ```commandline
20 | In [1]: from batchgenerators.utilities.file_and_folder_operations import load_json
21 |
22 | In [2]: splits = load_json('splits_final.json')
23 |
24 | In [3]: len(splits)
25 | Out[3]: 5
26 |
27 | In [4]: splits[0].keys()
28 | Out[4]: dict_keys(['train', 'val'])
29 |
30 | In [5]: len(splits[0]['train'])
31 | Out[5]: 16
32 |
33 | In [6]: len(splits[0]['val'])
34 | Out[6]: 4
35 |
36 | In [7]: print(splits[0])
37 | {'train': ['la_003', 'la_004', 'la_005', 'la_009', 'la_010', 'la_011', 'la_014', 'la_017', 'la_018', 'la_019', 'la_020', 'la_022', 'la_023', 'la_026', 'la_029', 'la_030'],
38 | 'val': ['la_007', 'la_016', 'la_021', 'la_024']}
39 | ```
40 |
41 | If you are still not sure what splits are supposed to look like, simply download some reference dataset from the
42 | [Medical Decathlon](http://medicaldecathlon.com/), start some training (to generate the splits) and manually inspect
43 | the .json file with your text editor of choice!
44 |
45 | In order to generate your custom splits, all you need to do is reproduce the data structure explained above and save it as
46 | `splits_final.json` in the `LongiSeg_preprocessed/DATASETXXX_NAME` folder. Then use `nnUNetv2_train` etc. as usual.
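
A minimal sketch for writing such a file (case identifiers and the path are placeholders; `save_json` is the counterpart of the `load_json` used above):

```python
from batchgenerators.utilities.file_and_folder_operations import save_json

# a hand-crafted 3-fold split; replace the placeholder identifiers with your cases
splits = [
    {'train': ['case_01', 'case_02', 'case_03', 'case_04'], 'val': ['case_05', 'case_06']},
    {'train': ['case_01', 'case_02', 'case_05', 'case_06'], 'val': ['case_03', 'case_04']},
    {'train': ['case_03', 'case_04', 'case_05', 'case_06'], 'val': ['case_01', 'case_02']},
]
save_json(splits, '/path/to/LongiSeg_preprocessed/DATASETXXX_NAME/splits_final.json')
```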
--------------------------------------------------------------------------------
/longiseg/paths.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 |
17 | """
18 | PLEASE READ paths.md FOR INFORMATION ON HOW TO SET THIS UP
19 | """
20 |
21 | LongiSeg_raw = os.environ.get('LongiSeg_raw')
22 | LongiSeg_preprocessed = os.environ.get('LongiSeg_preprocessed')
23 | LongiSeg_results = os.environ.get('LongiSeg_results')
24 |
25 | if LongiSeg_raw is None:
26 | print("Could not find LongiSeg_raw environment variable, falling back to nnUNet_raw")
27 | LongiSeg_raw = os.environ.get('nnUNet_raw')
28 | if LongiSeg_preprocessed is None:
29 | print("Could not find LongiSeg_preprocessed environment variable, falling back to nnUNet_preprocessed")
30 | LongiSeg_preprocessed = os.environ.get('nnUNet_preprocessed')
31 | if LongiSeg_results is None:
32 | print("Could not find LongiSeg_results environment variable, falling back to nnUNet_results")
33 | LongiSeg_results = os.environ.get('nnUNet_results')
34 |
35 | if LongiSeg_raw is None:
36 | print("LongiSeg_raw is not defined and LongiSeg can only be used on data for which preprocessed files "
37 | "are already present on your system. LongiSeg cannot be used for experiment planning and preprocessing like "
38 | "this. If this is not intended, please read documentation/setting_up_paths.md for information on how to set "
39 | "this up properly.")
40 |
41 | if LongiSeg_preprocessed is None:
42 | print("LongiSeg_preprocessed is not defined and LongiSeg can not be used for preprocessing "
43 | "or training. If this is not intended, please read documentation/setting_up_paths.md for information on how "
44 | "to set this up.")
45 |
46 | if LongiSeg_results is None:
47 | print("LongiSeg_results is not defined and LongiSeg cannot be used for training or "
48 | "inference. If this is not intended behavior, please read documentation/setting_up_paths.md for information "
49 | "on how to set this up.")
50 |
--------------------------------------------------------------------------------
/longiseg/utilities/json_export.py:
--------------------------------------------------------------------------------
1 | from collections.abc import Iterable
2 |
3 | import numpy as np
4 | import torch
5 |
6 |
7 | def recursive_fix_for_json_export(my_dict: dict):
8 | # json is ... a very nice thing to have
9 | # 'cannot serialize object of type bool_/int64/float64'. Apart from that of course...
10 | keys = list(my_dict.keys()) # cannot iterate over keys() if we change keys....
11 | for k in keys:
12 | if isinstance(k, (np.int64, np.int32, np.int8, np.uint8)):
13 | tmp = my_dict[k]
14 | del my_dict[k]
15 | my_dict[int(k)] = tmp
16 | del tmp
17 | k = int(k)
18 |
19 | if isinstance(my_dict[k], dict):
20 | recursive_fix_for_json_export(my_dict[k])
21 | elif isinstance(my_dict[k], np.ndarray):
22 | assert my_dict[k].ndim == 1, 'only 1d arrays are supported'
23 | my_dict[k] = fix_types_iterable(my_dict[k], output_type=list)
24 | elif isinstance(my_dict[k], (np.bool_,)):
25 | my_dict[k] = bool(my_dict[k])
26 | elif isinstance(my_dict[k], (np.int64, np.int32, np.int8, np.uint8)):
27 | my_dict[k] = int(my_dict[k])
28 | elif isinstance(my_dict[k], (np.float32, np.float64, np.float16)):
29 | my_dict[k] = float(my_dict[k])
30 | elif isinstance(my_dict[k], list):
31 | my_dict[k] = fix_types_iterable(my_dict[k], output_type=type(my_dict[k]))
32 | elif isinstance(my_dict[k], tuple):
33 | my_dict[k] = fix_types_iterable(my_dict[k], output_type=tuple)
34 | elif isinstance(my_dict[k], torch.device):
35 | my_dict[k] = str(my_dict[k])
36 | else:
37 | pass # pray it can be serialized
38 |
39 |
40 | def fix_types_iterable(iterable, output_type):
41 | # this sh!t is hacky as hell and will break if you use it for anything outside nnunet. Keep your hands off of this.
42 | out = []
43 | for i in iterable:
44 | if type(i) in (np.int64, np.int32, np.int8, np.uint8):
45 | out.append(int(i))
46 | elif isinstance(i, dict):
47 | recursive_fix_for_json_export(i)
48 | out.append(i)
49 | elif type(i) in (np.float32, np.float64, np.float16):
50 | out.append(float(i))
51 | elif type(i) in (np.bool_,):
52 | out.append(bool(i))
53 | elif isinstance(i, str):
54 | out.append(i)
55 | elif isinstance(i, Iterable):
56 | # print('recursive call on', i, type(i))
57 | out.append(fix_types_iterable(i, type(i)))
58 | else:
59 | out.append(i)
60 | return output_type(out)
61 |
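A small before/after sketch of `recursive_fix_for_json_export` (not part of the file):

```python
import json
import numpy as np

d = {'dice': np.float32(0.5), 'labels': np.array([1, 2, 3]), np.int64(4): 'spleen'}
recursive_fix_for_json_export(d)  # converts numpy scalars/arrays and keys in place
print(json.dumps(d))              # {"dice": 0.5, "labels": [1, 2, 3], "4": "spleen"}
```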
--------------------------------------------------------------------------------
/longiseg/training/data_augmentation/custom_transforms/transforms_for_dummy_2d.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple, Union, List
2 |
3 | from batchgenerators.transforms.abstract_transforms import AbstractTransform
4 |
5 |
6 | class Convert3DTo2DTransform(AbstractTransform):
7 | def __init__(self, apply_to_keys: Union[List[str], Tuple[str]] = ('data', 'seg')):
8 | """
9 | Transforms a 5D array (b, c, x, y, z) to a 4D array (b, c * x, y, z) by overloading the color channel
10 | """
11 | self.apply_to_keys = apply_to_keys
12 |
13 | def __call__(self, **data_dict):
14 | for k in self.apply_to_keys:
15 | shp = data_dict[k].shape
16 | assert len(shp) == 5, 'This transform only works on 3D data, so expects 5D tensor (b, c, x, y, z) as input.'
17 | data_dict[k] = data_dict[k].reshape((shp[0], shp[1] * shp[2], shp[3], shp[4]))
18 | shape_key = f'orig_shape_{k}'
19 | assert shape_key not in data_dict.keys(), f'Convert3DTo2DTransform needs to store the original shape. ' \
20 | f'It does that using the {shape_key} key. That key is ' \
21 | f'already taken. Bummer.'
22 | data_dict[shape_key] = shp
23 | return data_dict
24 |
25 |
26 | class Convert2DTo3DTransform(AbstractTransform):
27 | def __init__(self, apply_to_keys: Union[List[str], Tuple[str]] = ('data', 'seg')):
28 | """
29 | Reverts Convert3DTo2DTransform by transforming a 4D array (b, c * x, y, z) back to 5D (b, c, x, y, z)
30 | """
31 | self.apply_to_keys = apply_to_keys
32 |
33 | def __call__(self, **data_dict):
34 | for k in self.apply_to_keys:
35 | shape_key = f'orig_shape_{k}'
36 | assert shape_key in data_dict.keys(), f'Did not find key {shape_key} in data_dict. Shitty. ' \
37 | f'Convert2DTo3DTransform only works in tandem with ' \
38 | f'Convert3DTo2DTransform and you probably forgot to add ' \
39 | f'Convert3DTo2DTransform to your pipeline. (Convert3DTo2DTransform ' \
40 | f'is where the missing key is generated)'
41 | original_shape = data_dict[shape_key]
42 | current_shape = data_dict[k].shape
43 | data_dict[k] = data_dict[k].reshape((original_shape[0], original_shape[1], original_shape[2],
44 | current_shape[-2], current_shape[-1]))
45 | return data_dict
46 |
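A round-trip sketch of the two transforms (not part of the file):

```python
import numpy as np

fwd, bwd = Convert3DTo2DTransform(), Convert2DTo3DTransform()
batch = {'data': np.zeros((2, 1, 8, 16, 16)), 'seg': np.zeros((2, 1, 8, 16, 16))}
batch = fwd(**batch)  # 'data' is now (2, 8, 16, 16): c and x merged into the channel dim
batch = bwd(**batch)  # 'data' is restored to (2, 1, 8, 16, 16)
```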
--------------------------------------------------------------------------------
/longiseg/training/LongiSegTrainer/variants/benchmarking/nnUNetTrainerBenchmark_5epochs_noDataLoading.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from longiseg.training.LongiSegTrainer.variants.benchmarking.nnUNetTrainerBenchmark_5epochs import (
4 | nnUNetTrainerBenchmark_5epochs,
5 | )
6 | from longiseg.utilities.label_handling.label_handling import determine_num_input_channels
7 |
8 |
9 | class nnUNetTrainerBenchmark_5epochs_noDataLoading(nnUNetTrainerBenchmark_5epochs):
10 | def __init__(
11 | self,
12 | plans: dict,
13 | configuration: str,
14 | fold: int,
15 | dataset_json: dict,
16 | device: torch.device = torch.device("cuda"),
17 | ):
18 | super().__init__(plans, configuration, fold, dataset_json, device)
19 | self._set_batch_size_and_oversample()
20 | num_input_channels = determine_num_input_channels(
21 | self.plans_manager, self.configuration_manager, self.dataset_json
22 | )
23 | patch_size = self.configuration_manager.patch_size
24 | dummy_data = torch.rand((self.batch_size, num_input_channels, *patch_size), device=self.device)
25 | if self.enable_deep_supervision:
26 | dummy_target = [
27 | torch.round(
28 | torch.rand((self.batch_size, 1, *[int(i * j) for i, j in zip(patch_size, k)]), device=self.device)
29 | * max(self.label_manager.all_labels)
30 | )
31 | for k in self._get_deep_supervision_scales()
32 | ]
33 | else:
34 | raise NotImplementedError("This trainer requires deep supervision to be enabled")
35 | self.dummy_batch = {"data": dummy_data, "target": dummy_target}
36 |
37 | def get_dataloaders(self):
38 | return None, None
39 |
40 | def run_training(self):
41 | try:
42 | self.on_train_start()
43 |
44 | for epoch in range(self.current_epoch, self.num_epochs):
45 | self.on_epoch_start()
46 |
47 | self.on_train_epoch_start()
48 | train_outputs = []
49 | for batch_id in range(self.num_iterations_per_epoch):
50 | train_outputs.append(self.train_step(self.dummy_batch))
51 | self.on_train_epoch_end(train_outputs)
52 |
53 | with torch.no_grad():
54 | self.on_validation_epoch_start()
55 | val_outputs = []
56 | for batch_id in range(self.num_val_iterations_per_epoch):
57 | val_outputs.append(self.validation_step(self.dummy_batch))
58 | self.on_validation_epoch_end(val_outputs)
59 |
60 | self.on_epoch_end()
61 |
62 | self.on_train_end()
63 | except RuntimeError:
64 | self.crashed_with_runtime_error = True
65 |
--------------------------------------------------------------------------------
/longiseg/training/data_augmentation/custom_transforms/deep_supervision_donwsampling.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple, Union, List
2 |
3 | from batchgenerators.augmentations.utils import resize_segmentation
4 | from batchgenerators.transforms.abstract_transforms import AbstractTransform
5 | import numpy as np
6 |
7 |
8 | class DownsampleSegForDSTransform2(AbstractTransform):
9 | '''
10 | data_dict[output_key] will be a list of segmentations scaled according to ds_scales
11 | '''
12 | def __init__(self, ds_scales: Union[List, Tuple],
13 | order: int = 0, input_key: str = "seg",
14 | output_key: str = "seg", axes: Tuple[int] = None):
15 | """
16 | Downscales data_dict[input_key] according to ds_scales. Each entry in ds_scales specifies one deep supervision
17 | output and its resolution relative to the original data, for example 0.25 specifies 1/4 of the original shape.
18 | ds_scales can also be a tuple of tuples, for example ((1, 1, 1), (0.5, 0.5, 0.5)) to specify the downsampling
19 | for each axis independently
20 | """
21 | self.axes = axes
22 | self.output_key = output_key
23 | self.input_key = input_key
24 | self.order = order
25 | self.ds_scales = ds_scales
26 |
27 | def __call__(self, **data_dict):
28 | if self.axes is None:
29 | axes = list(range(2, data_dict[self.input_key].ndim))
30 | else:
31 | axes = self.axes
32 |
33 | output = []
34 | for s in self.ds_scales:
35 | if not isinstance(s, (tuple, list)):
36 | s = [s] * len(axes)
37 | else:
38 | assert len(s) == len(axes), f'If ds_scales is a tuple for each resolution (one downsampling factor ' \
39 | f'for each axis) then the number of entries in that tuple (here ' \
40 | f'{len(s)}) must be the same as the number of axes (here {len(axes)}).'
41 |
42 | if all([i == 1 for i in s]):
43 | output.append(data_dict[self.input_key])
44 | else:
45 | new_shape = np.array(data_dict[self.input_key].shape).astype(float)
46 | for i, a in enumerate(axes):
47 | new_shape[a] *= s[i]
48 | new_shape = np.round(new_shape).astype(int)
49 | out_seg = np.zeros(new_shape, dtype=data_dict[self.input_key].dtype)
50 | for b in range(data_dict[self.input_key].shape[0]):
51 | for c in range(data_dict[self.input_key].shape[1]):
52 | out_seg[b, c] = resize_segmentation(data_dict[self.input_key][b, c], new_shape[2:], self.order)
53 | output.append(out_seg)
54 | data_dict[self.output_key] = output
55 | return data_dict
56 |
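A quick sketch of the resulting output structure (not part of the file):

```python
import numpy as np

t = DownsampleSegForDSTransform2(ds_scales=(1, 0.5, 0.25))
seg = np.zeros((1, 1, 64, 64, 64), dtype=np.int16)  # (b, c, x, y, z)
out = t(seg=seg)['seg']
print([o.shape for o in out])  # [(1, 1, 64, 64, 64), (1, 1, 32, 32, 32), (1, 1, 16, 16, 16)]
```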
--------------------------------------------------------------------------------
/longiseg/training/LongiSegTrainer/variants/optimizer/nnUNetTrainerAdam.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.optim import Adam, AdamW
3 |
4 | from longiseg.training.lr_scheduler.polylr import PolyLRScheduler
5 | from longiseg.training.LongiSegTrainer.nnUNetTrainerLongi import nnUNetTrainerNoLongi
6 |
7 |
8 | class nnUNetTrainerAdam(nnUNetTrainerNoLongi):
9 | def configure_optimizers(self):
10 | optimizer = AdamW(self.network.parameters(),
11 | lr=self.initial_lr,
12 | weight_decay=self.weight_decay,
13 | amsgrad=True)
14 | # optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay,
15 | # momentum=0.99, nesterov=True)
16 | lr_scheduler = PolyLRScheduler(optimizer, self.initial_lr, self.num_epochs)
17 | return optimizer, lr_scheduler
18 |
19 |
20 | class nnUNetTrainerVanillaAdam(nnUNetTrainerNoLongi):
21 | def configure_optimizers(self):
22 | optimizer = Adam(self.network.parameters(),
23 | lr=self.initial_lr,
24 | weight_decay=self.weight_decay)
25 | # optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay,
26 | # momentum=0.99, nesterov=True)
27 | lr_scheduler = PolyLRScheduler(optimizer, self.initial_lr, self.num_epochs)
28 | return optimizer, lr_scheduler
29 |
30 |
31 | class nnUNetTrainerVanillaAdam1en3(nnUNetTrainerVanillaAdam):
32 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,
33 | device: torch.device = torch.device('cuda')):
34 | super().__init__(plans, configuration, fold, dataset_json, device)
35 | self.initial_lr = 1e-3
36 |
37 |
38 | class nnUNetTrainerVanillaAdam3en4(nnUNetTrainerVanillaAdam):
39 | # https://twitter.com/karpathy/status/801621764144971776?lang=en
40 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,
41 | device: torch.device = torch.device('cuda')):
42 | super().__init__(plans, configuration, fold, dataset_json, device)
43 | self.initial_lr = 3e-4
44 |
45 |
46 | class nnUNetTrainerAdam1en3(nnUNetTrainerAdam):
47 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,
48 | device: torch.device = torch.device('cuda')):
49 | super().__init__(plans, configuration, fold, dataset_json, device)
50 | self.initial_lr = 1e-3
51 |
52 |
53 | class nnUNetTrainerAdam3en4(nnUNetTrainerAdam):
54 | # https://twitter.com/karpathy/status/801621764144971776?lang=en
55 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,
56 | device: torch.device = torch.device('cuda')):
57 | super().__init__(plans, configuration, fold, dataset_json, device)
58 | self.initial_lr = 3e-4
59 |
--------------------------------------------------------------------------------
/documentation/explanation_normalization.md:
--------------------------------------------------------------------------------
1 | # Intensity normalization in nnU-Net
2 |
3 | The type of intensity normalization applied in nnU-Net can be controlled via the `channel_names` (former `modalities`)
4 | entry in the dataset.json. Just like the old nnU-Net, per-channel z-scoring as well as dataset-wide z-scoring based on
5 | foreground intensities are supported. However, there have been a few additions as well.
6 |
7 | Reminder: The `channel_names` entry typically looks like this:
8 |
9 | "channel_names": {
10 | "0": "T2",
11 | "1": "ADC"
12 | },
13 |
14 | It has as many entries as there are input channels for the given dataset.
15 |
16 | To tell you a secret, nnU-Net does not really care what your channels are called. We just use this to determine what normalization
17 | scheme will be used for the given dataset. nnU-Net requires you to specify a normalization strategy for each of your input channels!
18 | If you enter a channel name that is not in the following list, the default (`zscore`) will be used.
19 |
20 | Here is a list of currently available normalization schemes:
21 |
22 | - `CT`: Perform CT normalization. Specifically, collect intensity values from the foreground classes (all but the
23 | background and ignore) from all training cases, compute the mean, standard deviation as well as the 0.5 and
24 | 99.5 percentiles of the values. Then clip to these percentiles, followed by subtraction of the mean and division by the
25 | standard deviation. The normalization that is applied is the same for each training case (for this input channel).
26 | The values used by nnU-Net for normalization are stored in the `foreground_intensity_properties_per_channel` entry in the
27 | corresponding plans file. This normalization is suitable for modalities presenting physical quantities such as CT
28 | images and ADC maps.
29 | - `noNorm` : do not perform any normalization at all
30 | - `rescale_to_0_1`: rescale the intensities to [0, 1]
31 | - `rgb_to_0_1`: assumes uint8 inputs. Divides by 255 to rescale uint8 to [0, 1]
32 | - `zscore`/anything else: perform z-scoring (subtract the mean, then divide by the standard deviation) separately for each training case
33 |
34 | **Important:** The nnU-Net default is to perform 'CT' normalization for CT images and 'zscore' for everything else! If
35 | you deviate from that path, make sure to benchmark whether that actually improves results!
36 |
37 | # How to implement custom normalization strategies?
38 | - Head over to longiseg/preprocessing/normalization
39 | - implement a new image normalization class by deriving from ImageNormalization
40 | - register it in longiseg/preprocessing/normalization/map_channel_name_to_normalization.py:channel_name_to_normalization_mapping.
41 | This is where you specify a channel name that should be associated with it
42 | - use it by specifying the correct channel_name
43 |
44 | Normalization can only be applied to one channel at a time. There is currently no way of implementing a normalization scheme
45 | that gets multiple channels as input to be used jointly!
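
A hedged sketch of such a class, assuming the `ImageNormalization` base class in `default_normalization_schemes.py` exposes a `run(image, seg)` method (check that file for the exact signature and any extra class attributes):

```python
import numpy as np
from longiseg.preprocessing.normalization.default_normalization_schemes import ImageNormalization


class ClipTo0And100Normalization(ImageNormalization):
    """Hypothetical example scheme: clip to a fixed window, then rescale to [0, 1]."""

    def run(self, image: np.ndarray, seg: np.ndarray = None) -> np.ndarray:
        image = np.clip(image.astype(np.float32), 0, 100)
        return image / 100.0
```

It would then be registered in `map_channel_name_to_normalization.py` under whatever channel name should trigger it.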
--------------------------------------------------------------------------------
/longiseg/inference/sliding_window_prediction.py:
--------------------------------------------------------------------------------
1 | from functools import lru_cache
2 |
3 | import numpy as np
4 | import torch
5 | from typing import Union, Tuple, List
6 | from acvl_utils.cropping_and_padding.padding import pad_nd_image
7 | from scipy.ndimage import gaussian_filter
8 |
9 |
10 | @lru_cache(maxsize=2)
11 | def compute_gaussian(tile_size: Union[Tuple[int, ...], List[int]], sigma_scale: float = 1. / 8,
12 | value_scaling_factor: float = 1, dtype=torch.float16, device=torch.device('cuda', 0)) \
13 | -> torch.Tensor:
14 | tmp = np.zeros(tile_size)
15 | center_coords = [i // 2 for i in tile_size]
16 | sigmas = [i * sigma_scale for i in tile_size]
17 | tmp[tuple(center_coords)] = 1
18 | gaussian_importance_map = gaussian_filter(tmp, sigmas, 0, mode='constant', cval=0)
19 |
20 | gaussian_importance_map = torch.from_numpy(gaussian_importance_map)
21 |
22 | gaussian_importance_map /= (torch.max(gaussian_importance_map) / value_scaling_factor)
23 | gaussian_importance_map = gaussian_importance_map.to(device=device, dtype=dtype)
24 | # gaussian_importance_map cannot be 0, otherwise we may end up with nans!
25 | mask = gaussian_importance_map == 0
26 | gaussian_importance_map[mask] = torch.min(gaussian_importance_map[~mask])
27 | return gaussian_importance_map
28 |
29 |
30 | def compute_steps_for_sliding_window(image_size: Tuple[int, ...], tile_size: Tuple[int, ...], tile_step_size: float) -> \
31 | List[List[int]]:
32 | assert all(i >= j for i, j in zip(image_size, tile_size)), "image size must be as large or larger than patch_size"
33 | assert 0 < tile_step_size <= 1, 'step_size must be larger than 0 and smaller or equal to 1'
34 |
35 | # our step width is patch_size*step_size at most, but can be narrower. For example if we have image size of
36 | # 110, patch size of 64 and step_size of 0.5, then we want to make 3 steps starting at coordinate 0, 23, 46
37 | target_step_sizes_in_voxels = [i * tile_step_size for i in tile_size]
38 |
39 | num_steps = [int(np.ceil((i - k) / j)) + 1 for i, j, k in zip(image_size, target_step_sizes_in_voxels, tile_size)]
40 |
41 | steps = []
42 | for dim in range(len(tile_size)):
43 | # the highest step value for this dimension is
44 | max_step_value = image_size[dim] - tile_size[dim]
45 | if num_steps[dim] > 1:
46 | actual_step_size = max_step_value / (num_steps[dim] - 1)
47 | else:
48 | actual_step_size = 99999999999 # does not matter because there is only one step at 0
49 |
50 | steps_here = [int(np.round(actual_step_size * i)) for i in range(num_steps[dim])]
51 |
52 | steps.append(steps_here)
53 |
54 | return steps
55 |
56 |
57 | if __name__ == '__main__':
58 | a = torch.rand((4, 2, 32, 23))
59 | a_npy = a.numpy()
60 |
61 | a_padded = pad_nd_image(a, new_shape=(48, 27))
62 | a_npy_padded = pad_nd_image(a_npy, new_shape=(48, 27))
63 | assert all([i == j for i, j in zip(a_padded.shape, (4, 2, 48, 27))])
64 | assert all([i == j for i, j in zip(a_npy_padded.shape, (4, 2, 48, 27))])
65 | assert np.all(a_padded.numpy() == a_npy_padded)
66 |
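To make the step computation concrete, a small sketch reproducing the example from the comment above (image size 110, tile size 64, step size 0.5):

```python
steps = compute_steps_for_sliding_window((110,), (64,), 0.5)
print(steps)  # [[0, 23, 46]] -> three tiles, evenly spread, the last ending exactly at 110
```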
--------------------------------------------------------------------------------
/longiseg/training/dataloading/utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import multiprocessing
3 | import os
4 | from typing import List
5 | from pathlib import Path
6 | from warnings import warn
7 |
8 | import numpy as np
9 | from batchgenerators.utilities.file_and_folder_operations import isfile, subfiles
10 | from longiseg.configuration import default_num_processes
11 |
12 |
13 | def _convert_to_npy(npz_file: str, unpack_segmentation: bool = True, overwrite_existing: bool = False,
14 | verify_npy: bool = False, fail_ctr: int = 0) -> None:
15 | data_npy = npz_file[:-3] + "npy"
16 | seg_npy = npz_file[:-4] + "_seg.npy"
17 | try:
18 | npz_content = None # will only be opened on demand
19 |
20 | if overwrite_existing or not isfile(data_npy):
21 | try:
22 | npz_content = np.load(npz_file) if npz_content is None else npz_content
23 | except Exception as e:
24 | print(f"Unable to open preprocessed file {npz_file}. Rerun nnUNetv2_preprocess!")
25 | raise e
26 | np.save(data_npy, npz_content['data'])
27 |
28 | if unpack_segmentation and (overwrite_existing or not isfile(seg_npy)):
29 | try:
30 | npz_content = np.load(npz_file) if npz_content is None else npz_content
31 | except Exception as e:
32 | print(f"Unable to open preprocessed file {npz_file}. Rerun nnUNetv2_preprocess!")
33 | raise e
34 | np.save(seg_npy, npz_content['seg'])
35 |
36 | if verify_npy:
37 | try:
38 | np.load(data_npy, mmap_mode='r')
39 | if isfile(seg_npy):
40 | np.load(seg_npy, mmap_mode='r')
41 | except ValueError:
42 | os.remove(data_npy)
43 | if isfile(seg_npy): os.remove(seg_npy)
44 | print(f"Error when checking {data_npy} and {seg_npy}, fixing...")
45 | if fail_ctr < 2:
46 | _convert_to_npy(npz_file, unpack_segmentation, overwrite_existing, verify_npy, fail_ctr+1)
47 | else:
48 | raise RuntimeError("Unable to fix unpacking. Please check your system or rerun nnUNetv2_preprocess")
49 |
50 | except KeyboardInterrupt:
51 | if isfile(data_npy):
52 | os.remove(data_npy)
53 | if isfile(seg_npy):
54 | os.remove(seg_npy)
55 | raise KeyboardInterrupt
56 |
57 |
58 | def unpack_dataset(folder: str, unpack_segmentation: bool = True, overwrite_existing: bool = False,
59 | num_processes: int = default_num_processes,
60 | verify: bool = False):
61 | """
62 | all npz files in this folder belong to the dataset, unpack them all
63 | """
64 | with multiprocessing.get_context("spawn").Pool(num_processes) as p:
65 | npz_files = subfiles(folder, True, None, ".npz", True)
66 | p.starmap(_convert_to_npy, zip(npz_files,
67 | [unpack_segmentation] * len(npz_files),
68 | [overwrite_existing] * len(npz_files),
69 | [verify] * len(npz_files))
70 | )
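A hedged usage sketch (the folder path is a placeholder): unpack a preprocessed dataset once before training so the dataloaders can read (and memory-map) plain .npy files:

```python
# hypothetical preprocessed data folder, for illustration only
unpack_dataset('/path/to/LongiSeg_preprocessed/Dataset001_Example/nnUNetPlans_3d_fullres',
               unpack_segmentation=True, overwrite_existing=False, num_processes=8)
```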
--------------------------------------------------------------------------------
/longiseg/training/LongiSegTrainer/variants/training_length/nnUNetTrainer_Xepochs_NoMirroring.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from longiseg.training.LongiSegTrainer.nnUNetTrainerLongi import nnUNetTrainerNoLongi
4 |
5 |
6 | class nnUNetTrainer_250epochs_NoMirroring(nnUNetTrainerNoLongi):
7 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,
8 | device: torch.device = torch.device('cuda')):
9 | super().__init__(plans, configuration, fold, dataset_json, device)
10 | self.num_epochs = 250
11 |
12 | def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self):
13 | rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \
14 | super().configure_rotation_dummyDA_mirroring_and_inital_patch_size()
15 | mirror_axes = None
16 | self.inference_allowed_mirroring_axes = None
17 | return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes
18 |
19 |
20 | class nnUNetTrainer_2000epochs_NoMirroring(nnUNetTrainerNoLongi):
21 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,
22 | device: torch.device = torch.device('cuda')):
23 | super().__init__(plans, configuration, fold, dataset_json, device)
24 | self.num_epochs = 2000
25 |
26 | def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self):
27 | rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \
28 | super().configure_rotation_dummyDA_mirroring_and_inital_patch_size()
29 | mirror_axes = None
30 | self.inference_allowed_mirroring_axes = None
31 | return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes
32 |
33 |
34 | class nnUNetTrainer_4000epochs_NoMirroring(nnUNetTrainerNoLongi):
35 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,
36 | device: torch.device = torch.device('cuda')):
37 | super().__init__(plans, configuration, fold, dataset_json, device)
38 | self.num_epochs = 4000
39 |
40 | def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self):
41 | rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \
42 | super().configure_rotation_dummyDA_mirroring_and_inital_patch_size()
43 | mirror_axes = None
44 | self.inference_allowed_mirroring_axes = None
45 | return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes
46 |
47 |
48 | class nnUNetTrainer_8000epochs_NoMirroring(nnUNetTrainerNoLongi):
49 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,
50 | device: torch.device = torch.device('cuda')):
51 | super().__init__(plans, configuration, fold, dataset_json, device)
52 | self.num_epochs = 8000
53 |
54 | def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self):
55 | rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \
56 | super().configure_rotation_dummyDA_mirroring_and_inital_patch_size()
57 | mirror_axes = None
58 | self.inference_allowed_mirroring_axes = None
59 | return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes
60 |
61 |
--------------------------------------------------------------------------------
/longiseg/imageio/natural_image_reader_writer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 HIP Applied Computer Vision Lab, Division of Medical Image Computing, German Cancer Research Center
2 | # (DKFZ), Heidelberg, Germany
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from typing import Tuple, Union, List
17 | import numpy as np
18 | from longiseg.imageio.base_reader_writer import BaseReaderWriter
19 | from skimage import io
20 |
21 |
22 | class NaturalImage2DIO(BaseReaderWriter):
23 | """
24 | ONLY SUPPORTS 2D IMAGES!!!
25 | """
26 |
27 | # there are surely more we could add here. Everything that can be read by skimage.io should be supported
28 | supported_file_endings = [
29 | '.png',
30 | # '.jpg',
31 | # '.jpeg', # jpg not supported because we cannot allow lossy compression! segmentation maps!
32 | '.bmp',
33 | '.tif'
34 | ]
35 |
36 | def read_images(self, image_fnames: Union[List[str], Tuple[str, ...]]) -> Tuple[np.ndarray, dict]:
37 | images = []
38 | for f in image_fnames:
39 | npy_img = io.imread(f)
40 | if npy_img.ndim == 3:
41 | # rgb image, last dimension should be the color channel and the size of that channel should be 3
42 | # (or 4 if we have alpha)
43 | assert npy_img.shape[-1] == 3 or npy_img.shape[-1] == 4, "If image has three dimensions then the last " \
44 | "dimension must have shape 3 or 4 " \
45 | f"(RGB or RGBA). Image shape here is {npy_img.shape}"
46 | # move RGB(A) to front, add additional dim so that we have shape (c, 1, X, Y), where c is either 3 or 4
47 | images.append(npy_img.transpose((2, 0, 1))[:, None])
48 | elif npy_img.ndim == 2:
49 | # grayscale image
50 | images.append(npy_img[None, None])
51 |
52 | if not self._check_all_same([i.shape for i in images]):
53 | print('ERROR! Not all input images have the same shape!')
54 | print('Shapes:')
55 | print([i.shape for i in images])
56 | print('Image files:')
57 | print(image_fnames)
58 | raise RuntimeError()
59 | return np.vstack(images, dtype=np.float32, casting='unsafe'), {'spacing': (999, 1, 1)}
60 |
61 | def read_seg(self, seg_fname: str) -> Tuple[np.ndarray, dict]:
62 | return self.read_images((seg_fname, ))
63 |
64 | def write_seg(self, seg: np.ndarray, output_fname: str, properties: dict) -> None:
65 | io.imsave(output_fname, seg[0].astype(np.uint8 if np.max(seg) < 255 else np.uint16, copy=False), check_contrast=False)
--------------------------------------------------------------------------------
/longiseg/training/LongiSegTrainer/variants/benchmarking/nnUNetTrainerBenchmark_5epochs.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 |
3 | import torch
4 | from batchgenerators.utilities.file_and_folder_operations import save_json, join, isfile, load_json
5 |
6 | from longiseg.training.LongiSegTrainer.nnUNetTrainerLongi import nnUNetTrainerNoLongi
7 | from torch import distributed as dist
8 |
9 |
10 | class nnUNetTrainerBenchmark_5epochs(nnUNetTrainerNoLongi):
11 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,
12 | device: torch.device = torch.device('cuda')):
13 | super().__init__(plans, configuration, fold, dataset_json, device)
14 | assert self.fold == 0, "It makes absolutely no sense to specify a certain fold. Stick with 0 so that we can parse the results."
15 | self.disable_checkpointing = True
16 | self.num_epochs = 5
17 | assert torch.cuda.is_available(), "This only works on GPU"
18 | self.crashed_with_runtime_error = False
19 |
20 | def perform_actual_validation(self, save_probabilities: bool = False):
21 | pass
22 |
23 | def save_checkpoint(self, filename: str) -> None:
24 | # do not trust people to remember that self.disable_checkpointing must be True for this trainer
25 | pass
26 |
27 | def run_training(self):
28 | try:
29 | super().run_training()
30 | except RuntimeError:
31 | self.crashed_with_runtime_error = True
32 | self.on_train_end()
33 |
34 | def on_train_end(self):
35 | super().on_train_end()
36 |
37 | if not self.is_ddp or self.local_rank == 0:
38 | torch_version = torch.__version__
39 | cudnn_version = torch.backends.cudnn.version()
40 | gpu_name = torch.cuda.get_device_name()
41 | if self.crashed_with_runtime_error:
42 | fastest_epoch = 'Not enough VRAM!'
43 | else:
44 | epoch_times = [i - j for i, j in zip(self.logger.my_fantastic_logging['epoch_end_timestamps'],
45 | self.logger.my_fantastic_logging['epoch_start_timestamps'])]
46 | fastest_epoch = min(epoch_times)
47 |
48 | if self.is_ddp:
49 | num_gpus = dist.get_world_size()
50 | else:
51 | num_gpus = 1
52 |
53 | benchmark_result_file = join(self.output_folder, 'benchmark_result.json')
54 | if isfile(benchmark_result_file):
55 | old_results = load_json(benchmark_result_file)
56 | else:
57 | old_results = {}
58 | # generate some unique key
59 | hostname = subprocess.getoutput('hostname')
60 | my_key = f"{hostname}__{cudnn_version}__{torch_version.replace(' ', '')}__{gpu_name.replace(' ', '')}__num_gpus_{num_gpus}"
61 | old_results[my_key] = {
62 | 'torch_version': torch_version,
63 | 'cudnn_version': cudnn_version,
64 | 'gpu_name': gpu_name,
65 | 'fastest_epoch': fastest_epoch,
66 | 'num_gpus': num_gpus,
67 | 'hostname': hostname
68 | }
69 | save_json(old_results,
70 | join(self.output_folder, 'benchmark_result.json'))
71 |
--------------------------------------------------------------------------------
/longiseg/training/LongiSegTrainer/variants/loss/nnUNetTrainerDiceLoss.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | from longiseg.training.loss.compound_losses import DC_and_BCE_loss, DC_and_CE_loss
5 | from longiseg.training.loss.deep_supervision import DeepSupervisionWrapper
6 | from longiseg.training.loss.dice import MemoryEfficientSoftDiceLoss
7 | from longiseg.training.LongiSegTrainer.nnUNetTrainerLongi import nnUNetTrainerNoLongi
8 | from longiseg.utilities.helpers import softmax_helper_dim1
9 |
10 |
11 | class nnUNetTrainerDiceLoss(nnUNetTrainerNoLongi):
12 | def _build_loss(self):
13 | loss = MemoryEfficientSoftDiceLoss(**{'batch_dice': self.configuration_manager.batch_dice,
14 | 'do_bg': self.label_manager.has_regions, 'smooth': 1e-5, 'ddp': self.is_ddp},
15 | apply_nonlin=torch.sigmoid if self.label_manager.has_regions else softmax_helper_dim1)
16 |
17 | if self.enable_deep_supervision:
18 | deep_supervision_scales = self._get_deep_supervision_scales()
19 |
20 | # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases
21 | # this gives higher resolution outputs more weight in the loss
22 | weights = np.array([1 / (2 ** i) for i in range(len(deep_supervision_scales))])
23 | weights[-1] = 0
24 |
25 | # we don't use the lowest output. Normalize weights so that they sum to 1
26 | weights = weights / weights.sum()
27 | # now wrap the loss
28 | loss = DeepSupervisionWrapper(loss, weights)
29 | return loss
30 |
31 |
32 | class nnUNetTrainerDiceCELoss_noSmooth(nnUNetTrainerNoLongi):
33 | def _build_loss(self):
34 | # set smooth to 0
35 | if self.label_manager.has_regions:
36 | loss = DC_and_BCE_loss({},
37 | {'batch_dice': self.configuration_manager.batch_dice,
38 | 'do_bg': True, 'smooth': 0, 'ddp': self.is_ddp},
39 | use_ignore_label=self.label_manager.ignore_label is not None,
40 | dice_class=MemoryEfficientSoftDiceLoss)
41 | else:
42 | loss = DC_and_CE_loss({'batch_dice': self.configuration_manager.batch_dice,
43 | 'smooth': 0, 'do_bg': False, 'ddp': self.is_ddp}, {}, weight_ce=1, weight_dice=1,
44 | ignore_label=self.label_manager.ignore_label,
45 | dice_class=MemoryEfficientSoftDiceLoss)
46 |
47 | if self.enable_deep_supervision:
48 | deep_supervision_scales = self._get_deep_supervision_scales()
49 |
50 | # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases
51 | # this gives higher resolution outputs more weight in the loss
52 | weights = np.array([1 / (2 ** i) for i in range(len(deep_supervision_scales))])
53 | weights[-1] = 0
54 |
55 | # we don't use the lowest output. Normalize weights so that they sum to 1
56 | weights = weights / weights.sum()
57 | # now wrap the loss
58 | loss = DeepSupervisionWrapper(loss, weights)
59 | return loss
60 |
61 |
--------------------------------------------------------------------------------
/longiseg/training/LongiSegTrainer/variants/optimizer/nnUNetTrainerAdan.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from longiseg.training.lr_scheduler.polylr import PolyLRScheduler
4 | from longiseg.training.LongiSegTrainer.nnUNetTrainerLongi import nnUNetTrainerNoLongi
5 | from torch.optim.lr_scheduler import CosineAnnealingLR
6 | try:
7 | from adan_pytorch import Adan
8 | except ImportError:
9 | Adan = None
10 |
11 |
12 | class nnUNetTrainerAdan(nnUNetTrainerNoLongi):
13 | def configure_optimizers(self):
14 | if Adan is None:
15 | raise RuntimeError('This trainer requires adan_pytorch to be installed, install with "pip install adan-pytorch"')
16 | optimizer = Adan(self.network.parameters(),
17 | lr=self.initial_lr,
18 | # betas=(0.02, 0.08, 0.01), defaults
19 | weight_decay=self.weight_decay)
20 | # optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay,
21 | # momentum=0.99, nesterov=True)
22 | lr_scheduler = PolyLRScheduler(optimizer, self.initial_lr, self.num_epochs)
23 | return optimizer, lr_scheduler
24 |
25 |
26 | class nnUNetTrainerAdan1en3(nnUNetTrainerAdan):
27 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,
28 | device: torch.device = torch.device('cuda')):
29 | super().__init__(plans, configuration, fold, dataset_json, device)
30 | self.initial_lr = 1e-3
31 |
32 |
33 | class nnUNetTrainerAdan3en4(nnUNetTrainerAdan):
34 | # https://twitter.com/karpathy/status/801621764144971776?lang=en
35 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,
36 | device: torch.device = torch.device('cuda')):
37 | super().__init__(plans, configuration, fold, dataset_json, device)
38 | self.initial_lr = 3e-4
39 |
40 |
41 | class nnUNetTrainerAdan1en1(nnUNetTrainerAdan):
42 | # this trainer makes no sense -> nan!
43 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,
44 | device: torch.device = torch.device('cuda')):
45 | super().__init__(plans, configuration, fold, dataset_json, device)
46 | self.initial_lr = 1e-1
47 |
48 |
49 | class nnUNetTrainerAdanCosAnneal(nnUNetTrainerAdan):
50 | # def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,
51 | # device: torch.device = torch.device('cuda')):
52 | # super().__init__(plans, configuration, fold, dataset_json, device)
53 | # self.num_epochs = 15
54 |
55 | def configure_optimizers(self):
56 | if Adan is None:
57 | raise RuntimeError('This trainer requires adan_pytorch to be installed, install with "pip install adan-pytorch"')
58 | optimizer = Adan(self.network.parameters(),
59 | lr=self.initial_lr,
60 | # betas=(0.02, 0.08, 0.01), defaults
61 | weight_decay=self.weight_decay)
62 | # optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay,
63 | # momentum=0.99, nesterov=True)
64 | lr_scheduler = CosineAnnealingLR(optimizer, T_max=self.num_epochs)
65 | return optimizer, lr_scheduler
66 |
67 |
--------------------------------------------------------------------------------
/documentation/set_environment_variables.md:
--------------------------------------------------------------------------------
1 | # How to set environment variables
2 |
3 | nnU-Net requires some environment variables so that it always knows where the raw data, preprocessed data and trained
4 | models are. Depending on the operating system, these environment variables need to be set in different ways.
5 |
6 | Variables can either be set permanently (recommended!) or you can decide to set them every time you call nnU-Net.
7 |
8 | # Linux & MacOS
9 |
10 | ## Permanent
11 | Locate the `.bashrc` file in your home folder and add the following lines to the bottom:
12 |
13 | ```bash
14 | export LongiSeg_raw="/path/to/LongiSeg_raw"
15 | export LongiSeg_preprocessed="/path/to/LongiSeg_preprocessed"
16 | export LongiSeg_results="/path/to/LongiSeg_results"
17 | ```
18 |
19 | (Of course you need to adapt the paths to the actual folders you intend to use).
20 | If you are using a different shell, such as zsh, you will need to find the correct script for it. For zsh this is `.zshrc`.
21 |
22 | ## Temporary
23 | Just execute the following lines whenever you run nnU-Net:
24 | ```bash
25 | export LongiSeg_raw="/path/to/LongiSeg_raw"
26 | export LongiSeg_preprocessed="/path/to/LongiSeg_preprocessed"
27 | export LongiSeg_results="/path/to/LongiSeg_results"
28 | ```
29 | (Of course you need to adapt the paths to the actual folders you intend to use).
30 |
31 | Important: These variables will be deleted if you close your terminal! They will also only apply to the current
32 | terminal window and DO NOT transfer to other terminals!
33 |
34 | Alternatively you can also just prefix them to your nnU-Net commands:
35 |
36 | `LongiSeg_results="/path/to/LongiSeg_results" LongiSeg_preprocessed="/path/to/LongiSeg_preprocessed" nnUNetv2_train[...]`
37 |
38 | ## Verify that environment parameters are set
39 | You can always execute `echo ${LongiSeg_raw}` etc to print the environment variables. This will return an empty string if
40 | they were not set.
41 |
42 | # Windows
43 | Useful links:
44 | - [https://www3.ntu.edu.sg](https://www3.ntu.edu.sg/home/ehchua/programming/howto/Environment_Variables.html#:~:text=To%20set%20(or%20change)%20a,it%20to%20an%20empty%20string.)
45 | - [https://phoenixnap.com](https://phoenixnap.com/kb/windows-set-environment-variable)
46 |
47 | ## Permanent
48 | See `Set Environment Variable in Windows via GUI` [here](https://phoenixnap.com/kb/windows-set-environment-variable).
49 | Or read about setx (command prompt).
50 |
51 | ## Temporary
52 | Just execute the following before you run nnU-Net:
53 |
54 | (PowerShell)
55 | ```PowerShell
56 | $Env:LongiSeg_raw = "C:/path/to/LongiSeg_raw"
57 | $Env:LongiSeg_preprocessed = "C:/path/to/LongiSeg_preprocessed"
58 | $Env:LongiSeg_results = "C:/path/to/LongiSeg_results"
59 | ```
60 |
61 | (Command Prompt)
62 | ```Command Prompt
63 | set LongiSeg_raw=C:/path/to/LongiSeg_raw
64 | set LongiSeg_preprocessed=C:/path/to/LongiSeg_preprocessed
65 | set LongiSeg_results=C:/path/to/LongiSeg_results
66 | ```
67 |
68 | (Of course you need to adapt the paths to the actual folders you intend to use).
69 |
70 | Important: These variables will be deleted if you close your session! They will also only apply to the current
71 | window and DO NOT transfer to other sessions!
72 |
73 | ## Verify that environment parameters are set
74 | Printing in Windows works differently depending on the environment you are in:
75 |
76 | PowerShell: `echo $Env:[variable_name]`
77 |
78 | Command Prompt: `echo %[variable_name]%`
79 |
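Independent of the shell, a quick Python sketch (not part of the original document) to verify all three variables at once:

```python
import os

for v in ('LongiSeg_raw', 'LongiSeg_preprocessed', 'LongiSeg_results'):
    print(v, '=', os.environ.get(v, '<not set>'))
```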
--------------------------------------------------------------------------------
/longiseg/utilities/get_network_from_plans.py:
--------------------------------------------------------------------------------
1 | import pydoc
2 | import warnings
3 | from typing import Union
4 |
5 | from longiseg.utilities.find_class_by_name import recursive_find_python_class
6 | from batchgenerators.utilities.file_and_folder_operations import join
7 |
8 |
9 | def get_network_from_plans(arch_class_name, arch_kwargs, arch_kwargs_req_import, input_channels, output_channels,
10 | allow_init=True, deep_supervision: Union[bool, None] = None):
11 | network_class = arch_class_name
12 | architecture_kwargs = dict(**arch_kwargs)
13 | for ri in arch_kwargs_req_import:
14 | if architecture_kwargs[ri] is not None:
15 | architecture_kwargs[ri] = pydoc.locate(architecture_kwargs[ri])
16 |
17 | nw_class = pydoc.locate(network_class)
18 | # sometimes things move around, this makes it so that we can at least recover some of that
19 | if nw_class is None:
20 | warnings.warn(f'Network class {network_class} not found. Attempting to locate it within '
21 | f'dynamic_network_architectures.architectures...')
22 | import dynamic_network_architectures
23 | nw_class = recursive_find_python_class(join(dynamic_network_architectures.__path__[0], "architectures"),
24 | network_class.split(".")[-1],
25 | 'dynamic_network_architectures.architectures')
26 | if nw_class is not None:
27 | print(f'FOUND IT: {nw_class}')
28 | else:
29 | raise ImportError('Network class could not be found, please check/correct your plans file')
30 |
31 | if deep_supervision is not None:
32 | architecture_kwargs['deep_supervision'] = deep_supervision
33 |
34 | network = nw_class(
35 | input_channels=input_channels,
36 | num_classes=output_channels,
37 | **architecture_kwargs
38 | )
39 |
40 | if hasattr(network, 'initialize') and allow_init:
41 | network.apply(network.initialize)
42 |
43 | return network
44 |
45 | if __name__ == "__main__":
46 | import torch
47 |
48 | model = get_network_from_plans(
49 | arch_class_name="dynamic_network_architectures.architectures.unet.ResidualEncoderUNet",
50 | arch_kwargs={
51 | "n_stages": 7,
52 | "features_per_stage": [32, 64, 128, 256, 512, 512, 512],
53 | "conv_op": "torch.nn.modules.conv.Conv2d",
54 | "kernel_sizes": [[3, 3], [3, 3], [3, 3], [3, 3], [3, 3], [3, 3], [3, 3]],
55 | "strides": [[1, 1], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2]],
56 | "n_blocks_per_stage": [1, 3, 4, 6, 6, 6, 6],
57 | "n_conv_per_stage_decoder": [1, 1, 1, 1, 1, 1],
58 | "conv_bias": True,
59 | "norm_op": "torch.nn.modules.instancenorm.InstanceNorm2d",
60 | "norm_op_kwargs": {"eps": 1e-05, "affine": True},
61 | "dropout_op": None,
62 | "dropout_op_kwargs": None,
63 | "nonlin": "torch.nn.LeakyReLU",
64 | "nonlin_kwargs": {"inplace": True},
65 | },
66 | arch_kwargs_req_import=["conv_op", "norm_op", "dropout_op", "nonlin"],
67 | input_channels=1,
68 | output_channels=4,
69 | allow_init=True,
70 | deep_supervision=True,
71 | )
72 | data = torch.rand((8, 1, 256, 256))
73 | target = torch.rand(size=(8, 1, 256, 256))
74 | outputs = model(data) # this should be a list of torch.Tensor
--------------------------------------------------------------------------------
/longiseg/evaluation/accumulate_cv_results.py:
--------------------------------------------------------------------------------
1 | import shutil
2 | from typing import Union, List, Tuple
3 |
4 | from batchgenerators.utilities.file_and_folder_operations import load_json, join, isdir, maybe_mkdir_p, subfiles, isfile
5 |
6 | from longiseg.configuration import default_num_processes
7 | from longiseg.evaluation.evaluate_predictions import compute_metrics_on_folder
8 | from longiseg.paths import LongiSeg_raw, LongiSeg_preprocessed
9 | from longiseg.utilities.plans_handling.plans_handler import PlansManager
10 |
11 |
12 | def accumulate_cv_results(trained_model_folder,
13 | merged_output_folder: str,
14 | folds: Union[List[int], Tuple[int, ...]],
15 | num_processes: int = default_num_processes,
16 | overwrite: bool = True):
17 | """
18 | There are a lot of things that can get fucked up, so the simplest way to deal with potential problems is to
19 | collect the cv results into a separate folder and then evaluate them again. No messing with summary_json files!
20 | """
21 |
22 | if overwrite and isdir(merged_output_folder):
23 | shutil.rmtree(merged_output_folder)
24 | maybe_mkdir_p(merged_output_folder)
25 |
26 | dataset_json = load_json(join(trained_model_folder, 'dataset.json'))
27 | plans_manager = PlansManager(join(trained_model_folder, 'plans.json'))
28 | rw = plans_manager.image_reader_writer_class()
29 | shutil.copy(join(trained_model_folder, 'dataset.json'), join(merged_output_folder, 'dataset.json'))
30 | shutil.copy(join(trained_model_folder, 'plans.json'), join(merged_output_folder, 'plans.json'))
31 |
32 | did_we_copy_something = False
33 | for f in folds:
34 | expected_validation_folder = join(trained_model_folder, f'fold_{f}', 'validation')
35 | if not isdir(expected_validation_folder):
36 | raise RuntimeError(f"fold {f} of model {trained_model_folder} is missing. Please train it!")
37 | predicted_files = subfiles(expected_validation_folder, suffix=dataset_json['file_ending'], join=False)
38 | for pf in predicted_files:
39 | if overwrite and isfile(join(merged_output_folder, pf)):
40 | raise RuntimeError(f'More than one of your folds has a prediction for case {pf}')
41 | if overwrite or not isfile(join(merged_output_folder, pf)):
42 | shutil.copy(join(expected_validation_folder, pf), join(merged_output_folder, pf))
43 | did_we_copy_something = True
44 |
45 | if did_we_copy_something or not isfile(join(merged_output_folder, 'summary.json')):
46 | label_manager = plans_manager.get_label_manager(dataset_json)
47 | gt_folder = join(LongiSeg_raw, plans_manager.dataset_name, 'labelsTr')
48 | if not isdir(gt_folder):
49 | gt_folder = join(LongiSeg_preprocessed, plans_manager.dataset_name, 'gt_segmentations')
50 | compute_metrics_on_folder(gt_folder,
51 | merged_output_folder,
52 | join(merged_output_folder, 'summary.json'),
53 | rw,
54 | dataset_json['file_ending'],
55 | label_manager.foreground_regions if label_manager.has_regions else
56 | label_manager.foreground_labels,
57 | label_manager.ignore_label,
58 | num_processes)
59 |
--------------------------------------------------------------------------------
/longiseg/run/load_pretrained_weights.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch._dynamo import OptimizedModule
3 | from torch.nn.parallel import DistributedDataParallel as DDP
4 | import torch.distributed as dist
5 |
6 |
7 | def load_pretrained_weights(network, fname, verbose=False):
8 | """
9 |     Transfers all weights between matching keys in state_dicts. Matching is done by name and we only transfer if the
10 |     shape is also the same. Segmentation layers (the 1x1(x1) layers that produce the segmentation maps,
11 |     identified by keys containing '.seg_layers.') are not transferred!
12 |
13 |     If the pretrained weights were obtained with a training outside nnU-Net and DDP or torch.compile was used,
14 |     you need to change the keys of the pretrained state_dict. DDP adds a 'module.' prefix and torch.compile adds
15 |     '_orig_mod.'. You DO NOT need to worry about this if pretraining was done with nnU-Net, as
16 | nnUNetTrainer.save_checkpoint takes care of that!
17 |
18 | """
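    # Illustrative note (not executed here): if your checkpoint comes from an external DDP or
    # torch.compile training, you could strip the prefixes before handing it to this function, e.g.
    #   sd = {k.removeprefix('module.').removeprefix('_orig_mod.'): v for k, v in sd.items()}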
19 | if dist.is_initialized():
20 | saved_model = torch.load(fname, map_location=torch.device('cuda', dist.get_rank()), weights_only=False)
21 | else:
22 | saved_model = torch.load(fname, weights_only=False)
23 | pretrained_dict = saved_model['network_weights']
24 |
25 | skip_strings_in_pretrained = [
26 | '.seg_layers.',
27 | ]
28 |
29 | if isinstance(network, DDP):
30 | mod = network.module
31 | else:
32 | mod = network
33 | if isinstance(mod, OptimizedModule):
34 | mod = mod._orig_mod
35 |
36 | model_dict = mod.state_dict()
37 | # verify that all but the segmentation layers have the same shape
38 | for key, _ in model_dict.items():
39 | if all([i not in key for i in skip_strings_in_pretrained]):
40 | assert key in pretrained_dict, \
41 | f"Key {key} is missing in the pretrained model weights. The pretrained weights do not seem to be " \
42 | f"compatible with your network."
43 | assert model_dict[key].shape == pretrained_dict[key].shape, \
44 | f"The shape of the parameters of key {key} is not the same. Pretrained model: " \
45 |                 f"{pretrained_dict[key].shape}; your network: {model_dict[key].shape}. The pretrained model " \
46 | f"does not seem to be compatible with your network."
47 |
48 | # fun fact: in principle this allows loading from parameters that do not cover the entire network. For example pretrained
49 | # encoders. Not supported by this function though (see assertions above)
50 |
51 | # commenting out this abomination of a dict comprehension for preservation in the archives of 'what not to do'
52 | # pretrained_dict = {'module.' + k if is_ddp else k: v
53 | # for k, v in pretrained_dict.items()
54 | # if (('module.' + k if is_ddp else k) in model_dict) and
55 | # all([i not in k for i in skip_strings_in_pretrained])}
56 |
57 | pretrained_dict = {k: v for k, v in pretrained_dict.items()
58 | if k in model_dict.keys() and all([i not in k for i in skip_strings_in_pretrained])}
59 |
60 | model_dict.update(pretrained_dict)
61 |
62 | print("################### Loading pretrained weights from file ", fname, '###################')
63 | if verbose:
64 | print("Below is the list of overlapping blocks in pretrained model and nnUNet architecture:")
65 | for key, value in pretrained_dict.items():
66 | print(key, 'shape', value.shape)
67 | print("################### Done ###################")
68 | mod.load_state_dict(model_dict)
69 |
70 |
71 |
--------------------------------------------------------------------------------
/documentation/extending_nnunet.md:
--------------------------------------------------------------------------------
1 | # Extending nnU-Net
2 | We hope that the new structure of nnU-Net v2 makes it much more intuitive to modify! We cannot give an
3 | extensive tutorial on how each and every bit of it can be modified. It is better for you to search for the position
4 | in the repository where the thing you intend to change is implemented and start working your way through the code from
5 | there. Setting breakpoints and debugging into nnU-Net really helps in understanding it and thus will help you make the
6 | necessary modifications!
7 |
8 | Here are some things you might want to read before you start:
9 | - Editing nnU-Net configurations through plans files is really powerful now and allows you to change a lot of things regarding
10 | preprocessing, resampling, network topology etc. Read [this](explanation_plans_files.md)!
11 | - [Image normalization](explanation_normalization.md) and [i/o formats](dataset_format.md#supported-file-formats) are easy to extend!
12 | - Manual data splits can be defined as described [here](manual_data_splits.md)
13 | - You can chain arbitrary configurations together into cascades, see [this again](explanation_plans_files.md)
14 | - Read about our support for [region-based training](region_based_training.md)
15 | - If you intend to modify the training procedure (loss, sampling, data augmentation, lr scheduler, etc.) then you need
16 |   to implement your own trainer class. Best practice is to create a class that inherits from nnUNetTrainer and
17 |   implements the necessary changes. Head over to our [trainer classes folder](../longiseg/training/LongiSegTrainer) for
18 |   inspiration! There will likely already be a trainer similar to what you intend to change that you can use as a guide.
19 |   nnUNetTrainer classes are structured similarly to PyTorch Lightning trainers, which should also make things easier!
20 | - Integrating new network architectures can be done in two ways:
21 |   - Quick and dirty: implement a new nnUNetTrainer class and overwrite its `build_network_architecture` function (a sketch follows this list).
22 | Make sure your architecture is compatible with deep supervision (if not, use `nnUNetTrainerNoDeepSupervision`
23 | as basis!) and that it can handle the patch sizes that are thrown at it! Your architecture should NOT apply any
24 | nonlinearities at the end (softmax, sigmoid etc). nnU-Net does that!
25 |   - The 'proper' (but difficult) way: Build a dynamically configurable architecture such as the `PlainConvUNet` class
26 |     used by default. It needs to have some sort of GPU memory estimation method that can be used to evaluate whether
27 |     certain patch sizes and topologies fit into a specified GPU memory target. Build a new `ExperimentPlanner` that
28 |     can configure your new class and communicate with its memory budget estimation. Run `nnUNetv2_plan_and_preprocess`
29 |     while specifying your custom `ExperimentPlanner` and a custom `plans_name`. Implement an nnUNetTrainer that can
30 |     use the plans generated by your `ExperimentPlanner` to instantiate the network architecture. Specify your plans
31 |     and trainer when running `nnUNetv2_train`. It always pays off to first read and understand the corresponding
32 |     nnU-Net code and use it as a template for your implementation!
33 | - Remember that multi-GPU training, region-based training, ignore label and cascaded training are now simply integrated
34 |   into one unified nnUNetTrainer class. No separate classes are needed. Keep that in mind when implementing your own
35 |   trainer classes: either ensure support for all of these features or raise `NotImplementedError`!
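
A minimal sketch of the 'quick and dirty' route, assuming the newer trainer interface in which
`build_network_architecture` is a static method receiving the architecture class name and kwargs (check the base
class in your installed version; `MyUNet` and its import are hypothetical stand-ins):

```python
from torch import nn

from longiseg.training.LongiSegTrainer.nnUNetTrainerLongi import nnUNetTrainerNoLongi


class nnUNetTrainerMyArch(nnUNetTrainerNoLongi):
    @staticmethod
    def build_network_architecture(architecture_class_name: str,
                                   arch_init_kwargs: dict,
                                   arch_init_kwargs_req_import,
                                   num_input_channels: int,
                                   num_output_channels: int,
                                   enable_deep_supervision: bool = True) -> nn.Module:
        # MyUNet is a placeholder for your own architecture. It must not apply a final
        # nonlinearity and, with deep supervision enabled, must return one output per resolution.
        from my_package.my_unet import MyUNet  # hypothetical import
        return MyUNet(in_channels=num_input_channels,
                      out_channels=num_output_channels,
                      deep_supervision=enable_deep_supervision)
```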
36 |
37 | [//]: # (- Read about our support for [ignore label](ignore_label.md) and [region-based training](region_based_training.md))
38 |
--------------------------------------------------------------------------------
/longiseg/training/LongiSegTrainer/variants/sampling/nnUNetTrainer_probabilisticOversampling.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import distributed as dist
4 |
5 | from longiseg.training.LongiSegTrainer.nnUNetTrainerLongi import nnUNetTrainerNoLongi
6 |
7 |
8 | class nnUNetTrainer_probabilisticOversampling(nnUNetTrainerNoLongi):
9 | """
10 |     Foreground oversampling happens probabilistically per sample instead of deterministically for the last 33% of
11 |     samples in a batch. Since most trainings happen with batch size 2 and nnU-Net guarantees at least one fg sample
12 |     per batch, the effective rate of the default scheme can be 50%.
13 |     Here we compute the actual oversampling percentage used by nnUNetTrainer in order to be as consistent as possible.
14 |     If we switch to this probabilistic oversampling we can keep it at a constant 0.33 or any other value.
15 | """
16 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,
17 | device: torch.device = torch.device('cuda')):
18 | super().__init__(plans, configuration, fold, dataset_json, device)
19 | self.probabilistic_oversampling = True
20 | self.oversample_foreground_percent = float(np.mean(
21 | [not sample_idx < round(self.configuration_manager.batch_size * (1 - self.oversample_foreground_percent))
22 | for sample_idx in range(self.configuration_manager.batch_size)]))
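        # Worked example (illustrative): with batch_size 2 and oversample_foreground_percent 0.33,
        # round(2 * (1 - 0.33)) = 1, so the flags above are [not (0 < 1), not (1 < 1)] = [False, True]
        # and the computed effective rate is 0.5 -- matching the default deterministic scheme.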
23 | self.print_to_log_file(f"self.oversample_foreground_percent {self.oversample_foreground_percent}")
24 |
25 | def _set_batch_size_and_oversample(self):
26 | if not self.is_ddp:
27 | # set batch size to what the plan says, leave oversample untouched
28 | self.batch_size = self.configuration_manager.batch_size
29 | else:
30 |             # the batch size is distributed over the DDP workers; since oversampling is probabilistic per sample, the rate needs no per-worker adjustment
31 |
32 | world_size = dist.get_world_size()
33 | my_rank = dist.get_rank()
34 |
35 | global_batch_size = self.configuration_manager.batch_size
36 |             assert global_batch_size >= world_size, 'Cannot run DDP if the global batch size is smaller than ' \
37 |                                                     'the number of GPUs.'
38 |
39 | batch_size_per_GPU = [global_batch_size // world_size] * world_size
40 | batch_size_per_GPU = [batch_size_per_GPU[i] + 1
41 | if (batch_size_per_GPU[i] * world_size + i) < global_batch_size
42 | else batch_size_per_GPU[i]
43 | for i in range(len(batch_size_per_GPU))]
44 | assert sum(batch_size_per_GPU) == global_batch_size
45 | print("worker", my_rank, "batch_size", batch_size_per_GPU[my_rank])
46 | print("worker", my_rank, "oversample", self.oversample_foreground_percent)
47 |
48 | self.batch_size = batch_size_per_GPU[my_rank]
49 |
50 |
51 | class nnUNetTrainer_probabilisticOversampling_033(nnUNetTrainer_probabilisticOversampling):
52 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,
53 | device: torch.device = torch.device('cuda')):
54 | super().__init__(plans, configuration, fold, dataset_json, device)
55 | self.oversample_foreground_percent = 0.33
56 |
57 |
58 | class nnUNetTrainer_probabilisticOversampling_010(nnUNetTrainer_probabilisticOversampling):
59 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,
60 | device: torch.device = torch.device('cuda')):
61 | super().__init__(plans, configuration, fold, dataset_json, device)
62 | self.oversample_foreground_percent = 0.1
63 |
--------------------------------------------------------------------------------
/longiseg/model_sharing/entry_points.py:
--------------------------------------------------------------------------------
1 | from longiseg.model_sharing.model_download import download_and_install_from_url
2 | from longiseg.model_sharing.model_export import export_pretrained_model
3 | from longiseg.model_sharing.model_import import install_model_from_zip_file
4 |
5 |
6 | def print_license_warning():
7 | print('')
8 | print('######################################################')
9 | print('!!!!!!!!!!!!!!!!!!!!!!!!WARNING!!!!!!!!!!!!!!!!!!!!!!!')
10 | print('######################################################')
11 | print("Using the pretrained model weights is subject to the license of the dataset they were trained on. Some "
12 | "allow commercial use, others don't. It is your responsibility to make sure you use them appropriately! Use "
13 | "nnUNet_print_pretrained_model_info(task_name) to see a summary of the dataset and where to find its license!")
14 | print('######################################################')
15 | print('')
16 |
17 |
18 | def download_by_url():
19 | import argparse
20 | parser = argparse.ArgumentParser(
21 | description="Use this to download pretrained models. This script is intended to download models via url only. "
22 | "CAREFUL: This script will overwrite "
23 | "existing models (if they share the same trainer class and plans as "
24 |                     "the pretrained model).")
25 | parser.add_argument("url", type=str, help='URL of the pretrained model')
26 | args = parser.parse_args()
27 | url = args.url
28 | download_and_install_from_url(url)
29 |
30 |
31 | def install_from_zip_entry_point():
32 | import argparse
33 | parser = argparse.ArgumentParser(
34 | description="Use this to install a zip file containing a pretrained model.")
35 | parser.add_argument("zip", type=str, help='zip file')
36 | args = parser.parse_args()
37 |     zip_file = args.zip
38 |     install_model_from_zip_file(zip_file)
39 |
40 |
41 | def export_pretrained_model_entry():
42 | import argparse
43 | parser = argparse.ArgumentParser(
44 | description="Use this to export a trained model as a zip file.")
45 | parser.add_argument('-d', type=str, required=True, help='Dataset name or id')
46 | parser.add_argument('-o', type=str, required=True, help='Output file name')
47 | parser.add_argument('-c', nargs='+', type=str, required=False,
48 | default=('3d_lowres', '3d_fullres', '2d', '3d_cascade_fullres'),
49 | help="List of configuration names")
50 | parser.add_argument('-tr', required=False, type=str, default='nnUNetTrainer', help='Trainer class')
51 | parser.add_argument('-p', required=False, type=str, default='nnUNetPlans', help='plans identifier')
52 | parser.add_argument('-f', required=False, nargs='+', type=str, default=(0, 1, 2, 3, 4), help='list of fold ids')
53 | parser.add_argument('-chk', required=False, nargs='+', type=str, default=('checkpoint_final.pth', ),
54 |                         help='List of checkpoint names to export. Default: checkpoint_final.pth')
55 |     parser.add_argument('--not_strict', action='store_true', default=False, required=False, help='Set this to allow missing folds and/or configurations')
56 | parser.add_argument('--exp_cv_preds', action='store_true', required=False, help='Set this to export the cross-validation predictions as well')
57 | args = parser.parse_args()
58 |
59 | export_pretrained_model(dataset_name_or_id=args.d, output_file=args.o, configurations=args.c, trainer=args.tr,
60 | plans_identifier=args.p, folds=args.f, strict=not args.not_strict, save_checkpoints=args.chk,
61 | export_crossval_predictions=args.exp_cv_preds)
62 |
--------------------------------------------------------------------------------
/longiseg/training/data_augmentation/custom_transforms/longi_transforms.py:
--------------------------------------------------------------------------------
1 | from typing import Union, List, Tuple
2 |
3 | from batchgeneratorsv2.transforms.base.basic_transform import BasicTransform
4 | from batchgeneratorsv2.transforms.utils.deep_supervision_downsampling import DownsampleSegForDSTransform
5 | import torch
6 |
7 |
8 | class MergeTransform(BasicTransform):
9 | def apply(self, data_dict, **params):
10 | if data_dict.get('image_current') is not None and data_dict.get('image_prior') is not None:
11 | data_dict['image'] = self._apply_to_tensor(data_dict['image_current'], data_dict['image_prior'], **params)
12 | del data_dict['image_current'], data_dict['image_prior']
13 | else:
14 | raise RuntimeError("MergeTransform requires 'image_current' and 'image_prior' in data_dict")
15 | if data_dict.get('segmentation_current') is not None and data_dict.get('segmentation_prior') is not None:
16 | data_dict['segmentation'] = self._apply_to_tensor(data_dict['segmentation_current'], data_dict['segmentation_prior'], **params)
17 | del data_dict['segmentation_current'], data_dict['segmentation_prior']
18 | else:
19 | raise RuntimeError("MergeTransform requires 'segmentation_current' and 'segmentation_prior' in data_dict")
20 | return data_dict
21 |
22 | def _apply_to_tensor(self, current: torch.Tensor, prior: torch.Tensor, **params) -> torch.Tensor:
23 | merged = torch.cat([current, prior], dim=0)
24 | return merged
25 |
26 |
27 | class SplitTransform(BasicTransform):
28 | def apply(self, data_dict, **params):
29 | if data_dict.get('image') is not None:
30 | data_dict['image_current'], data_dict['image_prior'] = self._apply_to_tensor(data_dict['image'], **params)
31 | del data_dict['image']
32 | else:
33 | raise RuntimeError("SplitTransform requires 'image' in data_dict")
34 | if data_dict.get('segmentation') is not None:
35 | data_dict['segmentation_current'], data_dict['segmentation_prior'] = self._apply_to_tensor(data_dict['segmentation'], **params)
36 | del data_dict['segmentation']
37 | else:
38 | raise RuntimeError("SplitTransform requires 'segmentation' in data_dict")
39 | return data_dict
40 |
41 |     def _apply_to_tensor(self, tensor: torch.Tensor, **params) -> Tuple[torch.Tensor, torch.Tensor]:
42 | channels = tensor.shape[0]
43 | current = tensor[:channels//2]
44 | prior = tensor[channels//2:]
45 | return current, prior
46 |
47 |
48 | class ConvertSegToOneHot(BasicTransform):
49 | def __init__(self, foreground_labels: Union[Tuple[int, ...], List[int]], key: str = "segmentation_prior",
50 | dtype_key: str = "image_prior"):
51 | super().__init__()
52 | self.foreground_labels = foreground_labels
53 | self.key = key
54 | self.dtype_key = dtype_key
55 |
56 | def apply(self, data_dict, **params):
57 | seg = data_dict[self.key]
58 | seg_onehot = torch.zeros((len(self.foreground_labels), *seg.shape), dtype=data_dict[self.dtype_key].dtype)
59 | for i, l in enumerate(self.foreground_labels):
60 | seg_onehot[i][seg == l] = 1
61 |         data_dict[self.key] = seg_onehot.squeeze_(1)  # drop seg's channel dim -> (num_fg_labels, *spatial)
62 | return data_dict
63 |
64 |
65 | class DownsampleSegForDSTransformLongi(DownsampleSegForDSTransform):
66 | def __init__(self, ds_scales: Union[List, Tuple], key: str = "segmentation_current"):
67 | super().__init__(ds_scales)
68 | self.key = key
69 |
70 | def apply(self, data_dict, **params):
71 | data_dict[self.key] = self._apply_to_segmentation(data_dict[self.key], **params)
72 | return data_dict
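
# Illustrative usage (assuming batchgeneratorsv2's transform(**data_dict) calling convention):
# merge current and prior along the channel dim, apply shared spatial transforms, then split again:
#   data_dict = MergeTransform()(**data_dict)
#   ... spatial transforms operating on 'image'/'segmentation' ...
#   data_dict = SplitTransform()(**data_dict)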
--------------------------------------------------------------------------------
/documentation/pretraining_and_finetuning.md:
--------------------------------------------------------------------------------
1 | # Pretraining with nnU-Net
2 |
3 | ## Intro
4 |
5 | So far nnU-Net only supports supervised pre-training, meaning that you train a regular nnU-Net on some pretraining dataset
6 | and then use the final network weights as initialization for your target dataset.
7 |
8 | As a reminder, many training hyperparameters such as patch size and network topology differ between datasets as a
9 | result of the automated dataset analysis and experiment planning nnU-Net is known for. So, out of the box, it is not
10 | possible to simply take the network weights from some dataset and then reuse them for another.
11 |
12 | Consequently, the plans need to be aligned between the two tasks. In this README we show how this can be achieved and
13 | how the resulting weights can then be used for initialization.
14 |
15 | ### Terminology
16 |
17 | Throughout this README we use the following terminology:
18 |
19 | - `pretraining dataset` is the dataset you intend to run the pretraining on
20 | - `finetuning dataset` is the dataset you are interested in; the one you wish to fine tune on
21 |
22 |
23 | ## Training on the pretraining dataset
24 |
25 | In order to obtain matching network topologies we need to transfer the plans from one dataset to another. Since we are
26 | only interested in the finetuning dataset, we first need to run experiment planning (and preprocessing) for it:
27 |
28 | ```bash
29 | nnUNetv2_plan_and_preprocess -d FINETUNING_DATASET
30 | ```
31 |
32 | Then we need to extract the dataset fingerprint of the pretraining dataset, if not yet available:
33 |
34 | ```bash
35 | nnUNetv2_extract_fingerprint -d PRETRAINING_DATASET
36 | ```
37 |
38 | Now we can take the plans from the finetuning dataset and transfer them to the pretraining dataset:
39 |
40 | ```bash
41 | nnUNetv2_move_plans_between_datasets -s FINETUNING_DATASET -t PRETRAINING_DATASET -sp FINETUNING_PLANS_IDENTIFIER -tp PRETRAINING_PLANS_IDENTIFIER
42 | ```
43 |
44 | `FINETUNING_PLANS_IDENTIFIER` will most likely be nnUNetPlans, unless you changed the experiment planner in
45 | nnUNetv2_plan_and_preprocess. For `PRETRAINING_PLANS_IDENTIFIER` we recommend you set something custom in order not
46 | to overwrite the default plans.
47 |
48 | Note that EVERYTHING is transferred between the datasets. Not just the network topology, batch size and patch size but
49 | also the normalization scheme! Therefore, a transfer between datasets that use different normalization schemes may not
50 | work well (but it could, depending on the schemes!).
51 |
52 | Note on CT normalization: Yes, also the clip values, mean and std are transferred!
53 |
54 | Now you can run the preprocessing on the pretraining dataset:
55 |
56 | ```bash
57 | nnUNetv2_preprocess -d PRETRAINING_DATASET -plans_name PRETRAINING_PLANS_IDENTIFIER
58 | ```
59 |
60 | And run the training as usual:
61 |
62 | ```bash
63 | nnUNetv2_train PRETRAINING_DATASET CONFIG all -p PRETRAINING_PLANS_IDENTIFIER
64 | ```
65 |
66 | Note how we use the 'all' fold to train on all available data. For pretraining it does not make sense to split the data.
67 |
68 | ## Using pretrained weights
69 |
70 | Once pretraining is completed (or you obtain compatible weights by other means) you can use them to initialize your model:
71 |
72 | ```bash
73 | nnUNetv2_train FINETUNING_DATASET CONFIG FOLD -pretrained_weights PATH_TO_CHECKPOINT
74 | ```
75 |
76 | Specify the path to the pretrained checkpoint in PATH_TO_CHECKPOINT (typically the checkpoint_final.pth of the pretraining run).
77 |
78 | When loading pretrained weights, all layers except the segmentation layers will be used!
79 |
80 | So far there are no specific nnUNet trainers for fine tuning, so the current recommendation is to just use
81 | nnUNetTrainer. You can however easily write your own trainers with learning rate ramp up, fine-tuning of segmentation
82 | heads or shorter training time.
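
As an illustration of that last point, a fine-tuning trainer with a linear learning rate warm-up could look roughly
like the sketch below (assumptions: the trainer exposes `configure_optimizers` returning `(optimizer, lr_scheduler)`
as in nnU-Net v2; the class name and the 50-epoch warm-up are arbitrary choices):

```python
from torch.optim.lr_scheduler import LambdaLR

from longiseg.training.LongiSegTrainer.nnUNetTrainerLongi import nnUNetTrainerNoLongi


class nnUNetTrainerWarmup(nnUNetTrainerNoLongi):
    """Sketch: linearly ramp the lr over the first 50 epochs, then keep the base lr."""
    warmup_epochs = 50

    def configure_optimizers(self):
        optimizer, _ = super().configure_optimizers()  # reuse the default optimizer, swap the scheduler
        lr_scheduler = LambdaLR(optimizer, lambda epoch: min(1.0, (epoch + 1) / self.warmup_epochs))
        return optimizer, lr_scheduler
```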
--------------------------------------------------------------------------------
/longiseg/training/LongiSegTrainer/variants/loss/nnUNetTrainerTopkLoss.py:
--------------------------------------------------------------------------------
1 | from longiseg.training.loss.compound_losses import DC_and_topk_loss
2 | from longiseg.training.loss.deep_supervision import DeepSupervisionWrapper
3 | from longiseg.training.LongiSegTrainer.nnUNetTrainerLongi import nnUNetTrainerNoLongi
4 | import numpy as np
5 | from longiseg.training.loss.robust_ce_loss import TopKLoss
6 |
7 |
8 | class nnUNetTrainerTopk10Loss(nnUNetTrainerNoLongi):
9 | def _build_loss(self):
10 | assert not self.label_manager.has_regions, "regions not supported by this trainer"
11 | loss = TopKLoss(
12 | ignore_index=self.label_manager.ignore_label if self.label_manager.has_ignore_label else -100, k=10
13 | )
14 |
15 | if self.enable_deep_supervision:
16 | deep_supervision_scales = self._get_deep_supervision_scales()
17 |
18 | # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases
19 | # this gives higher resolution outputs more weight in the loss
20 | weights = np.array([1 / (2**i) for i in range(len(deep_supervision_scales))])
21 | weights[-1] = 0
22 |
23 |             # the lowest resolution output is not used (weight 0). Normalize weights so that they sum to 1
24 | weights = weights / weights.sum()
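            # e.g. with 5 outputs: [1, .5, .25, .125, .0625] -> last set to 0 -> normalized ~[.53, .27, .13, .07, 0]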
25 | # now wrap the loss
26 | loss = DeepSupervisionWrapper(loss, weights)
27 | return loss
28 |
29 |
30 | class nnUNetTrainerTopk10LossLS01(nnUNetTrainerNoLongi):
31 | def _build_loss(self):
32 | assert not self.label_manager.has_regions, "regions not supported by this trainer"
33 | loss = TopKLoss(
34 | ignore_index=self.label_manager.ignore_label if self.label_manager.has_ignore_label else -100,
35 | k=10,
36 | label_smoothing=0.1,
37 | )
38 |
39 | if self.enable_deep_supervision:
40 | deep_supervision_scales = self._get_deep_supervision_scales()
41 |
42 | # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases
43 | # this gives higher resolution outputs more weight in the loss
44 | weights = np.array([1 / (2**i) for i in range(len(deep_supervision_scales))])
45 | weights[-1] = 0
46 |
47 |             # the lowest resolution output is not used (weight 0). Normalize weights so that they sum to 1
48 | weights = weights / weights.sum()
49 | # now wrap the loss
50 | loss = DeepSupervisionWrapper(loss, weights)
51 | return loss
52 |
53 |
54 | class nnUNetTrainerDiceTopK10Loss(nnUNetTrainerNoLongi):
55 | def _build_loss(self):
56 | assert not self.label_manager.has_regions, "regions not supported by this trainer"
57 | loss = DC_and_topk_loss(
58 | {"batch_dice": self.configuration_manager.batch_dice, "smooth": 1e-5, "do_bg": False, "ddp": self.is_ddp},
59 | {"k": 10, "label_smoothing": 0.0},
60 | weight_ce=1,
61 | weight_dice=1,
62 | ignore_label=self.label_manager.ignore_label,
63 | )
64 | if self.enable_deep_supervision:
65 | deep_supervision_scales = self._get_deep_supervision_scales()
66 |
67 | # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases
68 | # this gives higher resolution outputs more weight in the loss
69 | weights = np.array([1 / (2**i) for i in range(len(deep_supervision_scales))])
70 | weights[-1] = 0
71 |
72 |             # the lowest resolution output is not used (weight 0). Normalize weights so that they sum to 1
73 | weights = weights / weights.sum()
74 | # now wrap the loss
75 | loss = DeepSupervisionWrapper(loss, weights)
76 | return loss
77 |
--------------------------------------------------------------------------------
/documentation/how_to_use_longiseg.md:
--------------------------------------------------------------------------------
1 | # How to use LongiSeg
2 | LongiSeg inherits nnU-Net’s self-configuring capabilities, allowing it to easily adapt to new datasets. It extends nnU-Net to handle medical image time series, leveraging temporal relationships to improve segmentation accuracy.
3 |
4 | ## Dataset Format
5 | LongiSeg expects the same structured [dataset format](dataset_format.md) as nnU-Net. Datasets can either be saved in the same
6 | folders as nnU-Net or in LongiSeg's own folder (`LongiSeg_raw`, `LongiSeg_preprocessed` and `LongiSeg_results`).
7 | In contrast to nnU-Net, LongiSeg expects an additional `patientsTr.json` file in the dataset folder. This file lists patient IDs and their corresponding scans in chronological order.
8 |
9 | Dataset001_BrainTumour/
10 | ├── dataset.json
11 | ├── patientsTr.json
12 | ├── imagesTr
13 | ├── imagesTs # optional
14 | └── labelsTr
15 |
16 | This json file should have the following structure:
17 |
18 | {
19 | "patient_1": [
20 | "patient_1_scan_1",
21 | "patient_1_scan_2",
22 | ...
23 | ],
24 | "patient_2": [
25 | "patient_2_scan_1",
26 | "patient_2_scan_2",
27 | ...
28 | ],
29 | ...
30 | }
31 |
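If your scan identifiers encode the patient and the acquisition index, a small script can generate this file. A
minimal sketch, assuming a hypothetical `<patient>_scan_<n>` naming scheme (adapt the grouping and ordering to your
own naming):

```python
import json
import re
from collections import defaultdict

# hypothetical identifiers; in practice, derive them from the files in imagesTr
identifiers = ["patient_1_scan_1", "patient_1_scan_2", "patient_2_scan_1", "patient_2_scan_2"]

patients = defaultdict(list)
for ident in identifiers:
    patient_id = re.sub(r"_scan_\d+$", "", ident)  # strip the scan suffix to get the patient id
    patients[patient_id].append(ident)

for scans in patients.values():
    scans.sort()  # chronological order; here the scan index encodes acquisition order

with open("Dataset001_BrainTumour/patientsTr.json", "w") as f:
    json.dump(patients, f, indent=4, sort_keys=False)
```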
32 | ## Experiment planning and preprocessing
33 | To run experiment planning and preprocessing of a new dataset, simply run
34 | ```bash
35 | LongiSeg_plan_and_preprocess -d DATASET_ID
36 | ```
37 | or if you prefer to keep things separate, you can also use `LongiSeg_extract_fingerprint`, `LongiSeg_plan_experiment`
38 | and `LongiSeg_preprocess` (in that order). We refer to the [nnU-Net documentation](how_to_use_nnunet.md#experiment-planning-and-preprocessing) for additional details on experiment planning and preprocessing.
39 |
40 | ## Training
41 | To train a model using LongiSeg, simply run
42 | ```bash
43 | LongiSeg_train DATASET_NAME_OR_ID UNET_CONFIGURATION FOLD
44 | ```
45 |
46 | By default, LongiSeg uses the `LongiSegTrainer`, which integrates the temporal dimension by concatenating multi-timepoint images as additional input channels. Other trainers are available via the `-tr` option:
47 |
48 | - `nnUNetTrainerNoLongi`: A modified `nnUNetTrainer` where training data is split **at the patient level** instead of per scan. (*non-longitudinal baseline*)
49 | - `LongiSegTrainerDiffWeighting`: A longitudinal trainer that incorporates the **Difference Weighting Block** for temporal feature fusion.
50 | - `LongiSegTrainerRP`, `LongiSegTrainerDiffWeightingRP`: longitudinal trainers that randomly sample the prior scan of the same patient instead of using a fixed one (cf. [longi_dataset](../longiseg/training/dataloading/longi_dataset.py#L149-L153))
51 |
52 | Other options for training are available as well (`LongiSeg_train -h`).
53 |
54 | ## Inference
55 | Inference with LongiSeg works similarly to [nnU-Net inference](how_to_use_nnunet.md#run-inference), with the added requirement of specifying a patient file (`-pat`) for either the `LongiSeg_predict` or `LongiSeg_predict_from_modelfolder` command. The patient file must describe the patient structure in the same way as during training. Only cases present in both the input folder **and** `patients.json` will be processed during inference!
56 | ```bash
57 | LongiSeg_predict -i INPUT_FOLDER -o OUTPUT_FOLDER -pat /path/to/patients.json -d DATASET_ID
58 | ```
59 |
60 | ## Evaluation
61 | By default LongiSeg performs nnU-Net's standard evaluation on the 5-fold cross validation. This, however, does not account for individual patients and only calculates a handful of metrics. LongiSeg extends nnU-Net’s evaluation by incorporating additional metrics that provide a more comprehensive assessment of segmentation performance, including volumetric, surface-based, distance-based, and detection metrics.
62 | To run evaluation with LongiSeg, use
63 | ```bash
64 | LongiSeg_evaluate_folder GT_FOLDER PRED_FOLDER -djfile /path/to/dataset.json -pfile /path/to/plans.json -patfile /path/to/patients.json
65 | ```
66 |
67 | If no patient file is provided, LongiSeg will default to standard nnU-Net evaluation while incorporating the additional metrics.
--------------------------------------------------------------------------------
/longiseg/imageio/reader_writer_registry.py:
--------------------------------------------------------------------------------
1 | import traceback
2 | from typing import Type
3 |
4 | from batchgenerators.utilities.file_and_folder_operations import join
5 |
6 | import longiseg
7 | from longiseg.imageio.natural_image_reader_writer import NaturalImage2DIO
8 | from longiseg.imageio.nibabel_reader_writer import NibabelIO, NibabelIOWithReorient
9 | from longiseg.imageio.simpleitk_reader_writer import SimpleITKIO, SimpleITKIOWithReorient
10 | from longiseg.imageio.tif_reader_writer import Tiff3DIO
11 | from longiseg.imageio.base_reader_writer import BaseReaderWriter
12 | from longiseg.utilities.find_class_by_name import recursive_find_python_class
13 |
14 | LIST_OF_IO_CLASSES = [
15 | NaturalImage2DIO,
16 | SimpleITKIOWithReorient,
17 | SimpleITKIO,
18 | Tiff3DIO,
19 | NibabelIO,
20 | NibabelIOWithReorient
21 | ]
22 |
23 |
24 | def determine_reader_writer_from_dataset_json(dataset_json_content: dict, example_file: str = None,
25 | allow_nonmatching_filename: bool = False, verbose: bool = True
26 | ) -> Type[BaseReaderWriter]:
27 | if 'overwrite_image_reader_writer' in dataset_json_content.keys() and \
28 | dataset_json_content['overwrite_image_reader_writer'] != 'None':
29 | ioclass_name = dataset_json_content['overwrite_image_reader_writer']
30 | # trying to find that class in the longiseg.imageio module
31 | try:
32 | ret = recursive_find_reader_writer_by_name(ioclass_name)
33 | if verbose: print(f'Using {ret} reader/writer')
34 | return ret
35 | except RuntimeError:
36 | if verbose: print(f'Warning: Unable to find ioclass specified in dataset.json: {ioclass_name}')
37 | if verbose: print('Trying to automatically determine desired class')
38 | return determine_reader_writer_from_file_ending(dataset_json_content['file_ending'], example_file,
39 | allow_nonmatching_filename, verbose)
40 |
41 |
42 | def determine_reader_writer_from_file_ending(file_ending: str, example_file: str = None, allow_nonmatching_filename: bool = False,
43 | verbose: bool = True):
44 | for rw in LIST_OF_IO_CLASSES:
45 | if file_ending.lower() in rw.supported_file_endings:
46 | if example_file is not None:
47 |                 # if an example file is provided, check whether we can actually read it; if not, move on to the next reader
48 | try:
49 | tmp = rw()
50 | _ = tmp.read_images((example_file,))
51 | if verbose: print(f'Using {rw} as reader/writer')
52 | return rw
53 |                 except Exception:
54 | if verbose: print(f'Failed to open file {example_file} with reader {rw}:')
55 | traceback.print_exc()
56 | pass
57 | else:
58 | if verbose: print(f'Using {rw} as reader/writer')
59 | return rw
60 | else:
61 | if allow_nonmatching_filename and example_file is not None:
62 | try:
63 | tmp = rw()
64 | _ = tmp.read_images((example_file,))
65 | if verbose: print(f'Using {rw} as reader/writer')
66 | return rw
67 |                     except Exception:
68 | if verbose: print(f'Failed to open file {example_file} with reader {rw}:')
69 | if verbose: traceback.print_exc()
70 | pass
71 |     raise RuntimeError(f"Unable to determine a reader for file ending {file_ending} and file {example_file} (example_file=None means no file was provided).")
72 |
73 |
74 | def recursive_find_reader_writer_by_name(rw_class_name: str) -> Type[BaseReaderWriter]:
75 | ret = recursive_find_python_class(join(longiseg.__path__[0], "imageio"), rw_class_name, 'longiseg.imageio')
76 | if ret is None:
77 | raise RuntimeError("Unable to find reader writer class '%s'. Please make sure this class is located in the "
78 | "longiseg.imageio module." % rw_class_name)
79 | else:
80 | return ret
81 |
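
# Example usage (illustrative; the file names are placeholders):
#   rw_class = determine_reader_writer_from_dataset_json(load_json('dataset.json'), example_file='case_0000.nii.gz')
#   images, props = rw_class().read_images(('case_0000.nii.gz',))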
--------------------------------------------------------------------------------
/longiseg/utilities/dataset_name_id_conversion.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from typing import Union
15 |
16 | import os
17 | import numpy as np
18 |
19 | from longiseg.paths import LongiSeg_preprocessed, LongiSeg_raw, LongiSeg_results
20 | from batchgenerators.utilities.file_and_folder_operations import subdirs, isdir
21 |
22 |
23 | def find_candidate_datasets(dataset_id: int):
24 | startswith = "Dataset%03.0d" % dataset_id
25 | if LongiSeg_preprocessed is not None and isdir(LongiSeg_preprocessed):
26 | candidates_preprocessed = subdirs(LongiSeg_preprocessed, prefix=startswith, join=False)
27 | else:
28 | candidates_preprocessed = []
29 |
30 | if LongiSeg_raw is not None and isdir(LongiSeg_raw):
31 | candidates_raw = subdirs(LongiSeg_raw, prefix=startswith, join=False)
32 | else:
33 | candidates_raw = []
34 |
35 | candidates_trained_models = []
36 | if LongiSeg_results is not None and isdir(LongiSeg_results):
37 | candidates_trained_models += subdirs(LongiSeg_results, prefix=startswith, join=False)
38 |
39 | all_candidates = candidates_preprocessed + candidates_raw + candidates_trained_models
40 | unique_candidates = np.unique(all_candidates)
41 | return unique_candidates
42 |
43 |
44 | def convert_id_to_dataset_name(dataset_id: int):
45 | unique_candidates = find_candidate_datasets(dataset_id)
46 | if len(unique_candidates) > 1:
47 | raise RuntimeError("More than one dataset name found for dataset id %d. Please correct that. (I looked in the "
48 |                            "following folders:\n%s\n%s\n%s)" % (dataset_id, LongiSeg_raw, LongiSeg_preprocessed, LongiSeg_results))
49 | if len(unique_candidates) == 0:
50 | raise RuntimeError(f"Could not find a dataset with the ID {dataset_id}. Make sure the requested dataset ID "
51 | f"exists and that nnU-Net knows where raw and preprocessed data are located "
52 | f"(see Documentation - Installation). Here are your currently defined folders:\n"
53 | f"LongiSeg_preprocessed={os.environ.get('LongiSeg_preprocessed') if os.environ.get('LongiSeg_preprocessed') is not None else 'None'}\n"
54 | f"LongiSeg_results={os.environ.get('LongiSeg_results') if os.environ.get('LongiSeg_results') is not None else 'None'}\n"
55 | f"LongiSeg_raw={os.environ.get('LongiSeg_raw') if os.environ.get('LongiSeg_raw') is not None else 'None'}\n"
56 | f"If something is not right, adapt your environment variables.")
57 | return unique_candidates[0]
58 |
59 |
60 | def convert_dataset_name_to_id(dataset_name: str):
61 | assert dataset_name.startswith("Dataset")
62 | dataset_id = int(dataset_name[7:10])
63 | return dataset_id
64 |
65 |
66 | def maybe_convert_to_dataset_name(dataset_name_or_id: Union[int, str]) -> str:
67 | if isinstance(dataset_name_or_id, str) and dataset_name_or_id.startswith("Dataset"):
68 | return dataset_name_or_id
69 | if isinstance(dataset_name_or_id, str):
70 | try:
71 | dataset_name_or_id = int(dataset_name_or_id)
72 | except ValueError:
73 | raise ValueError("dataset_name_or_id was a string and did not start with 'Dataset' so we tried to "
74 | "convert it to a dataset ID (int). That failed, however. Please give an integer number "
75 | "('1', '2', etc) or a correct dataset name. Your input: %s" % dataset_name_or_id)
76 | return convert_id_to_dataset_name(dataset_name_or_id)
77 |
--------------------------------------------------------------------------------
/longiseg/training/LongiSegTrainer/variants/training_length/nnUNetTrainer_Xepochs.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from longiseg.training.LongiSegTrainer.nnUNetTrainerLongi import nnUNetTrainerNoLongi
4 |
5 |
6 | class nnUNetTrainer_5epochs(nnUNetTrainerNoLongi):
7 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,
8 | device: torch.device = torch.device('cuda')):
9 | """used for debugging plans etc"""
10 | super().__init__(plans, configuration, fold, dataset_json, device)
11 | self.num_epochs = 5
12 |
13 |
14 | class nnUNetTrainer_1epoch(nnUNetTrainerNoLongi):
15 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,
16 | device: torch.device = torch.device('cuda')):
17 | """used for debugging plans etc"""
18 | super().__init__(plans, configuration, fold, dataset_json, device)
19 | self.num_epochs = 1
20 |
21 |
22 | class nnUNetTrainer_10epochs(nnUNetTrainerNoLongi):
23 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,
24 | device: torch.device = torch.device('cuda')):
25 | """used for debugging plans etc"""
26 | super().__init__(plans, configuration, fold, dataset_json, device)
27 | self.num_epochs = 10
28 |
29 |
30 | class nnUNetTrainer_20epochs(nnUNetTrainerNoLongi):
31 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,
32 | device: torch.device = torch.device('cuda')):
33 | super().__init__(plans, configuration, fold, dataset_json, device)
34 | self.num_epochs = 20
35 |
36 |
37 | class nnUNetTrainer_50epochs(nnUNetTrainerNoLongi):
38 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,
39 | device: torch.device = torch.device('cuda')):
40 | super().__init__(plans, configuration, fold, dataset_json, device)
41 | self.num_epochs = 50
42 |
43 |
44 | class nnUNetTrainer_100epochs(nnUNetTrainerNoLongi):
45 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,
46 | device: torch.device = torch.device('cuda')):
47 | super().__init__(plans, configuration, fold, dataset_json, device)
48 | self.num_epochs = 100
49 |
50 |
51 | class nnUNetTrainer_250epochs(nnUNetTrainerNoLongi):
52 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,
53 | device: torch.device = torch.device('cuda')):
54 | super().__init__(plans, configuration, fold, dataset_json, device)
55 | self.num_epochs = 250
56 |
57 |
58 | class nnUNetTrainer_500epochs(nnUNetTrainerNoLongi):
59 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,
60 | device: torch.device = torch.device('cuda')):
61 | super().__init__(plans, configuration, fold, dataset_json, device)
62 | self.num_epochs = 500
63 |
64 |
65 | class nnUNetTrainer_750epochs(nnUNetTrainerNoLongi):
66 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,
67 | device: torch.device = torch.device('cuda')):
68 | super().__init__(plans, configuration, fold, dataset_json, device)
69 | self.num_epochs = 750
70 |
71 |
72 | class nnUNetTrainer_2000epochs(nnUNetTrainerNoLongi):
73 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,
74 | device: torch.device = torch.device('cuda')):
75 | super().__init__(plans, configuration, fold, dataset_json, device)
76 | self.num_epochs = 2000
77 |
78 |
79 | class nnUNetTrainer_4000epochs(nnUNetTrainerNoLongi):
80 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,
81 | device: torch.device = torch.device('cuda')):
82 | super().__init__(plans, configuration, fold, dataset_json, device)
83 | self.num_epochs = 4000
84 |
85 |
86 | class nnUNetTrainer_8000epochs(nnUNetTrainerNoLongi):
87 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,
88 | device: torch.device = torch.device('cuda')):
89 | super().__init__(plans, configuration, fold, dataset_json, device)
90 | self.num_epochs = 8000
91 |
--------------------------------------------------------------------------------
/longiseg/evaluation/metrics/detection_metrics.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | import numpy as np
3 | import cc3d
4 | from skimage.morphology import dilation
5 | from scipy import sparse
6 |
7 |
8 | def get_instances(gt: np.ndarray, pred: np.ndarray, footprint: int = 0, spacing: tuple[float, ...] = (1., 1., 1.)):
9 | if footprint:
10 | x, y, z = np.ceil(np.divide(footprint, spacing)).astype(int)
11 | struct = np.ones((x, y, z), dtype=np.uint8)
12 | dilated_gt = dilation(gt.astype(np.uint8), struct)
13 | dilated_pred = dilation(pred.astype(np.uint8), struct)
14 | else:
15 | dilated_gt = gt.astype(np.uint8)
16 | dilated_pred = pred.astype(np.uint8)
17 | gt_inst, gt_inst_num = cc3d.connected_components(dilated_gt, return_N=True)
18 | pred_inst, pred_inst_num = cc3d.connected_components(dilated_pred, return_N=True)
19 | gt_inst[(gt!=1)] = 0
20 | pred_inst[(pred!=1)] = 0
21 | return gt_inst, gt_inst_num, pred_inst, pred_inst_num
22 |
23 |
24 | def get_inst_TPFPFN(gt_inst, gt_inst_num, pred_inst, pred_inst_num):
25 | M, N = gt_inst_num, pred_inst_num
26 | if M == 0 or N == 0:
27 | return 0, 0, M, N
28 |
29 | gt_inst_flat = gt_inst.flatten()
30 | pred_inst_flat = pred_inst.flatten()
31 |
32 | gt_masks = gt_inst_flat==np.arange(1, M+1)[:, None]
33 | pred_masks = pred_inst_flat==np.arange(1, N+1)[:, None]
34 |
35 | gt_size = np.bincount(gt_inst_flat)
36 | pred_size = np.bincount(pred_inst_flat)
37 |
38 | gt_masks_sparse = [sparse.coo_matrix(mask) for mask in gt_masks]
39 | pred_masks_sparse = [sparse.coo_matrix(mask) for mask in pred_masks]
40 |
41 | iou_data = []
42 | rows, cols = [], []
43 | for i in range(M):
44 | for j in range(N):
45 | # Efficient intersection computation for sparse matrices
46 | gt_mask, pred_mask = gt_masks_sparse[i], pred_masks_sparse[j]
47 | intersection = len(set(gt_mask.col) & set(pred_mask.col))
48 | if intersection > 0:
49 | union = gt_size[i + 1] + pred_size[j + 1] - intersection
50 | iou_data.append(intersection / union)
51 | rows.append(i)
52 | cols.append(j)
53 |
54 | iou_data_gt = np.array(iou_data)
55 | rows_gt = np.array(rows)
56 | cols_gt = np.array(cols)
57 | iou_data_pred = np.array(iou_data)
58 | rows_pred = np.array(rows)
59 | cols_pred = np.array(cols)
60 |
61 |     TP_gt, FN = 0, 0  # greedily match each gt instance to its best-IoU prediction; a match requires IoU > 0.1
62 | for i in range(M):
63 |         if i not in rows_gt:
64 | FN += 1
65 | continue
66 | iou_i = iou_data_gt[rows_gt == i]
67 | col_i = cols_gt[rows_gt == i]
68 | argmax_iou = np.argmax(iou_i)
69 | max_iou = iou_i[argmax_iou]
70 | if max_iou > 0.1:
71 | TP_gt += 1
72 | iou_data_gt[cols_gt==col_i[argmax_iou]] = 0
73 | else:
74 | FN += 1
75 |
76 |     TP_pred, FP = 0, 0  # same greedy matching, now from the prediction side
77 | for j in range(N):
78 |         if j not in cols_pred:
79 | FP += 1
80 | continue
81 | iou_j = iou_data_pred[cols_pred == j]
82 | row_j = rows_pred[cols_pred == j]
83 | argmax_iou = np.argmax(iou_j)
84 | max_iou = iou_j[argmax_iou]
85 | if max_iou > 0.1:
86 | TP_pred += 1
87 | iou_data_pred[rows_pred==row_j[argmax_iou]] = 0
88 | else:
89 | FP += 1
90 |
91 | return TP_gt, TP_pred, FN, FP
92 |
93 |
94 | def compute_detection_metrics(mask_ref: np.ndarray, mask_pred: np.ndarray, footprint: int = 0,
95 |                               spacing: tuple[float, ...] = (1., 1., 1.), ignore_mask: Optional[np.ndarray] = None):
96 | gt_mask = mask_ref if ignore_mask is None else mask_ref & ~ignore_mask
97 | pred_mask = mask_pred if ignore_mask is None else mask_pred & ~ignore_mask
98 |
99 | gt_inst, gt_inst_num, pred_inst, pred_inst_num = get_instances(gt_mask, pred_mask, footprint=footprint, spacing=spacing)
100 | TP_gt, TP_pred, FN, FP = get_inst_TPFPFN(gt_inst, gt_inst_num, pred_inst, pred_inst_num)
101 | recall = TP_gt / (TP_gt + FN) if TP_gt + FN > 0 else 1
102 | precision = TP_pred / (TP_pred + FP) if TP_pred + FP > 0 else 1
103 | F1 = (2 * recall * precision) / (recall + precision) if recall + precision > 0 else 0
104 | return F1, recall, precision, TP_gt, TP_pred, FP, FN
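
# Example usage (illustrative): one gt lesion, one overlapping prediction -> F1 = recall = precision = 1.0
#   gt = np.zeros((32, 32, 32), dtype=np.uint8); gt[4:8, 4:8, 4:8] = 1
#   pred = np.zeros_like(gt); pred[5:9, 5:9, 5:9] = 1
#   F1, recall, precision, TP_gt, TP_pred, FP, FN = compute_detection_metrics(gt, pred)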
--------------------------------------------------------------------------------
/documentation/region_based_training.md:
--------------------------------------------------------------------------------
1 | # Region-based training
2 |
3 | ## What is this about?
4 | In some segmentation tasks, most prominently the
5 | [Brain Tumor Segmentation Challenge](http://braintumorsegmentation.org/), the target areas (based on which the metric
6 | will be computed) are different from the labels provided in the training data. This is the case because for some
7 | clinical applications, it is more relevant to detect the whole tumor, tumor core and enhancing tumor instead of the
8 | individual labels (edema, necrosis and non-enhancing tumor, enhancing tumor).
9 |
10 |
11 |
12 | The figure shows an example BraTS case along with label-based representation of the task (top) and region-based
13 | representation (bottom). The challenge evaluation is done on the regions. As we have shown in our
14 | [BraTS 2018 contribution](https://arxiv.org/abs/1809.10483), directly optimizing those
15 | overlapping areas over the individual labels yields better scoring models!
16 |
17 | ## What can nnU-Net do?
18 | nnU-Net's region-based training allows you to learn areas that are constructed by merging individual labels. For
19 | some segmentation tasks this provides a benefit, as this shifts the importance allocated to different labels during training.
20 | Most prominently, this feature can be used to represent **hierarchical classes**, for example when organs +
21 | substructures are to be segmented. Imagine a liver segmentation problem, where vessels and tumors are also to be
22 | segmented. The first target region could thus be the entire liver (including the substructures), while the remaining
23 | targets are the individual substructures.
24 |
25 | Important: nnU-Net still requires integer label maps as input and will produce integer label maps as output!
26 | Region-based training can be used to learn overlapping labels, but there must be a way to model these overlaps
27 | for nnU-Net to work (see below how this is done).
28 |
29 | ## How do you use it?
30 |
31 | When declaring the labels in the `dataset.json` file, BraTS would typically look like this:
32 |
33 | ```python
34 | ...
35 | "labels": {
36 | "background": 0,
37 | "edema": 1,
38 | "non_enhancing_and_necrosis": 2,
39 | "enhancing_tumor": 3
40 | },
41 | ...
42 | ```
43 | (we use different int values than the challenge because nnU-Net needs consecutive integers!)
44 |
45 | This representation corresponds to the upper row in the figure above.
46 |
47 | For region-based training, the labels need to be changed to the following:
48 |
49 | ```python
50 | ...
51 | "labels": {
52 | "background": 0,
53 | "whole_tumor": [1, 2, 3],
54 | "tumor_core": [2, 3],
55 | "enhancing_tumor": 3 # or [3]
56 | },
57 | "regions_class_order": [1, 2, 3],
58 | ...
59 | ```
60 | This corresponds to the bottom row in the figure above. Note how an additional entry in the dataset.json is
61 | required: `regions_class_order`. This tells nnU-Net how to convert the region representations back to an integer map.
62 | It essentially just tells nnU-Net what labels to place for which region in what order. The length of the
63 | list here needs to be the same as the number of regions (excl background). Each element in the list corresponds
64 | to the label that is placed instead of the region into the final segmentation. Later entries will overwrite earlier ones!
65 | Concretely, for the example given here, nnU-Net
66 | will firstly place the label 1 (edema) where the 'whole_tumor' region was predicted, then place the label 2
67 | (non-enhancing tumor and necrosis) where the "tumor_core" was predicted and finally place the label 3 in the
68 | predicted 'enhancing_tumor' area. With each step, part of the previously set pixels
69 | will be overwritten with the new label! So when setting your `regions_class_order`, place encompassing regions
70 | (like whole tumor etc) first, followed by substructures.
71 |
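To make the overwrite order concrete, here is a small illustrative sketch (not the actual nnU-Net implementation) of
how region predictions are collapsed back into an integer map:

```python
import numpy as np

# dummy boolean region predictions on a 4x4 slice; whole_tumor contains tumor_core contains enhancing_tumor
whole_tumor = np.zeros((4, 4), dtype=bool); whole_tumor[0:3, 0:3] = True
tumor_core = np.zeros((4, 4), dtype=bool); tumor_core[1:3, 1:3] = True
enhancing_tumor = np.zeros((4, 4), dtype=bool); enhancing_tumor[2, 2] = True

regions = [whole_tumor, tumor_core, enhancing_tumor]
regions_class_order = [1, 2, 3]

seg = np.zeros((4, 4), dtype=np.uint8)
for region_mask, label in zip(regions, regions_class_order):
    seg[region_mask] = label  # later entries overwrite earlier ones

print(seg)  # 1 where only whole_tumor, 2 inside tumor_core, 3 at the enhancing voxel
```
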
72 | **IMPORTANT** Because the conversion back to a segmentation map is sensitive to the order in which the regions are
73 | declared ("place label X in the first region") you need to make sure that this order is not perturbed! When
74 | automatically generating the dataset.json, make sure the dictionary keys do not get sorted alphabetically! Set
75 | `sort_keys=False` in `json.dump()`!!!
76 |
77 | nnU-Net will perform the evaluation + model selection also on the regions, not the individual labels!
78 |
79 | That's all. Easy, huh?
--------------------------------------------------------------------------------
/longiseg/inference/export_utils.py:
--------------------------------------------------------------------------------
1 | from typing import List, Union
2 | import itertools
3 | import torch
4 | import numpy as np
5 |
6 | from acvl_utils.cropping_and_padding.bounding_boxes import int_bbox
7 |
8 |
9 | # adapted from https://github.com/MIC-DKFZ/acvl_utils/blob/master/acvl_utils/cropping_and_padding/bounding_boxes.py#L357
10 | # to work with None entries in bbox
11 | def insert_crop_into_image(
12 | image: Union[torch.Tensor, np.ndarray],
13 | crop: Union[torch.Tensor, np.ndarray],
14 | bbox: List[List[int]]
15 | ) -> Union[torch.Tensor, np.ndarray]:
16 | """
17 | Inserts a cropped patch back into the original image at the position specified by bbox.
18 | If the bounding box extends beyond the image boundaries, only the valid portions are inserted.
19 | If the bounding box lies entirely outside the image, the original image is returned.
20 |
21 | Parameters:
22 | - image: Original N-dimensional torch.Tensor or np.ndarray to which the crop will be inserted.
23 | - crop: Cropped patch of the image to be reinserted. May have additional dimensions compared to bbox.
24 | - bbox: List of [[dim_min, dim_max], ...] defining the bounding box for the last dimensions of the crop in the original image.
25 |
26 | Returns:
27 | - image: The original image with the crop reinserted at the specified location (modified in-place).
28 | """
29 | # If the bounding box is None and shapes of image and crop are the same, return the crop directly
30 | if all([b is None for b in itertools.chain(*bbox)]) and image.shape==crop.shape:
31 | return crop
32 |
33 | # make sure bounding boxes are int and not uint. Otherwise we may get underflow
34 | bbox = int_bbox(bbox)
35 |
36 | # Ensure that bbox only applies to the last len(bbox) dimensions of crop and image
37 | num_dims = len(image.shape)
38 | crop_dims = len(crop.shape)
39 | bbox_dims = len(bbox)
40 |
41 | if crop_dims < bbox_dims:
42 | raise ValueError("Bounding box dimensions cannot exceed crop dimensions.")
43 |
44 | # Validate that non-cropped leading dimensions match between image and crop
45 | leading_dims = num_dims - bbox_dims
46 | if image.shape[:leading_dims] != crop.shape[:leading_dims]:
47 | raise ValueError("Leading dimensions of crop and image must match.")
48 |
49 | # Check if the bounding box lies completely outside the image bounds for each cropped dimension
50 | for i in range(bbox_dims):
51 | min_val, max_val = bbox[i]
52 | dim_idx = leading_dims + i # Corresponding dimension in the image
53 |
54 | if max_val <= 0 or min_val >= image.shape[dim_idx]:
55 | # If completely out of bounds in any dimension, return the original image
56 | return image
57 |
58 | # Prepare slices for inserting the crop into the original image
59 | image_slices = []
60 | crop_slices = []
61 |
62 | # Iterate over all dimensions, applying bbox only to the last len(bbox) dimensions
63 | for i in range(num_dims):
64 | if i < leading_dims:
65 | # For leading dimensions, use entire dimension (slice(None)) and validate shape
66 | image_slices.append(slice(None))
67 | crop_slices.append(slice(None))
68 | else:
69 | # For dimensions specified by bbox, calculate the intersection with image bounds
70 | dim_idx = i - leading_dims
71 | min_val, max_val = bbox[dim_idx]
72 |
73 | crop_start = max(0, -min_val) # Start of the crop within the valid area
74 | image_start = max(0, min_val) # Start of the image where the crop will be inserted
75 | image_end = min(max_val, image.shape[i]) # Exclude upper bound by using max_val directly
76 |
77 | # Adjusted range for insertion
78 | crop_end = crop_start + (image_end - image_start)
79 |
80 | # Append slices for both image and crop insertion ranges
81 | image_slices.append(slice(image_start, image_end))
82 | crop_slices.append(slice(crop_start, crop_end))
83 |
84 | # Insert the valid part of the crop back into the original image
85 |     # torch.Tensor and np.ndarray share the same indexing/assignment syntax,
86 |     # so a single branch covers both supported types
87 |     if isinstance(image, (torch.Tensor, np.ndarray)):
88 |         image[tuple(image_slices)] = crop[tuple(crop_slices)]
89 |     else:
90 |         raise ValueError(f"Unsupported image type {type(image)}")
91 |
92 | return image
--------------------------------------------------------------------------------
/longiseg/experiment_planning/experiment_planners/network_topology.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 | import numpy as np
3 |
4 |
5 | def get_shape_must_be_divisible_by(net_numpool_per_axis):
6 | return 2 ** np.array(net_numpool_per_axis)
7 |
8 |
9 | def pad_shape(shape, must_be_divisible_by):
10 | """
11 | pads shape so that it is divisible by must_be_divisible_by
12 | :param shape:
13 | :param must_be_divisible_by:
14 | :return:
15 | """
16 | if not isinstance(must_be_divisible_by, (tuple, list, np.ndarray)):
17 | must_be_divisible_by = [must_be_divisible_by] * len(shape)
18 | else:
19 | assert len(must_be_divisible_by) == len(shape)
20 |
21 | new_shp = [shape[i] + must_be_divisible_by[i] - shape[i] % must_be_divisible_by[i] for i in range(len(shape))]
22 |
23 | for i in range(len(shape)):
24 | if shape[i] % must_be_divisible_by[i] == 0:
25 | new_shp[i] -= must_be_divisible_by[i]
26 | new_shp = np.array(new_shp).astype(int)
27 | return new_shp
28 |
29 |
30 | def get_pool_and_conv_props(spacing, patch_size, min_feature_map_size, max_numpool):
31 | """
32 | this is the same as get_pool_and_conv_props_v2 from old nnunet
33 |
34 | :param spacing:
35 | :param patch_size:
36 | :param min_feature_map_size: min edge length of feature maps in bottleneck
37 | :param max_numpool:
38 | :return:
39 | """
40 | # todo review this code
41 | dim = len(spacing)
42 |
43 | current_spacing = deepcopy(list(spacing))
44 | current_size = deepcopy(list(patch_size))
45 |
46 | pool_op_kernel_sizes = [[1] * len(spacing)]
47 | conv_kernel_sizes = []
48 |
49 | num_pool_per_axis = [0] * dim
50 | kernel_size = [1] * dim
51 |
52 | while True:
53 | # exclude axes that we cannot pool further because of min_feature_map_size constraint
54 | valid_axes_for_pool = [i for i in range(dim) if current_size[i] >= 2*min_feature_map_size]
55 | if len(valid_axes_for_pool) < 1:
56 | break
57 |
58 | spacings_of_axes = [current_spacing[i] for i in valid_axes_for_pool]
59 |
60 |         # find axes whose spacing is within a factor of 2 of the smallest spacing
61 | min_spacing_of_valid = min(spacings_of_axes)
62 | valid_axes_for_pool = [i for i in valid_axes_for_pool if current_spacing[i] / min_spacing_of_valid < 2]
63 |
64 | # max_numpool constraint
65 | valid_axes_for_pool = [i for i in valid_axes_for_pool if num_pool_per_axis[i] < max_numpool]
66 |
67 |         if len(valid_axes_for_pool) == 1:
68 |             # a single remaining axis is only pooled further if enough size headroom remains
69 |             if current_size[valid_axes_for_pool[0]] < 3 * min_feature_map_size:
70 |                 break
71 |         if len(valid_axes_for_pool) < 1:
72 |             break
74 |
75 | # now we need to find kernel sizes
76 | # kernel sizes are initialized to 1. They are successively set to 3 when their associated axis becomes within
77 | # factor 2 of min_spacing. Once they are 3 they remain 3
78 | for d in range(dim):
79 | if kernel_size[d] == 3:
80 | continue
81 | else:
82 | if current_spacing[d] / min(current_spacing) < 2:
83 | kernel_size[d] = 3
84 |
85 | other_axes = [i for i in range(dim) if i not in valid_axes_for_pool]
86 |
87 | pool_kernel_sizes = [0] * dim
88 | for v in valid_axes_for_pool:
89 | pool_kernel_sizes[v] = 2
90 | num_pool_per_axis[v] += 1
91 | current_spacing[v] *= 2
92 | current_size[v] = np.ceil(current_size[v] / 2)
93 | for nv in other_axes:
94 | pool_kernel_sizes[nv] = 1
95 |
96 | pool_op_kernel_sizes.append(pool_kernel_sizes)
97 | conv_kernel_sizes.append(deepcopy(kernel_size))
98 | #print(conv_kernel_sizes)
99 |
100 | must_be_divisible_by = get_shape_must_be_divisible_by(num_pool_per_axis)
101 | patch_size = pad_shape(patch_size, must_be_divisible_by)
102 |
103 | def _to_tuple(lst):
104 | return tuple(_to_tuple(i) if isinstance(i, list) else i for i in lst)
105 |
106 | # we need to add one more conv_kernel_size for the bottleneck. We always use 3x3(x3) conv here
107 | conv_kernel_sizes.append([3]*dim)
108 | return num_pool_per_axis, _to_tuple(pool_op_kernel_sizes), _to_tuple(conv_kernel_sizes), tuple(patch_size), must_be_divisible_by
109 |
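# Minimal usage sketch (spacing/patch values are illustrative assumptions, not
# taken from any real plans file): derive the pooling and kernel configuration
# for an anisotropic 3D dataset.
if __name__ == '__main__':
    num_pool, pool_ks, conv_ks, padded_patch, divisor = get_pool_and_conv_props(
        spacing=(3.0, 0.78, 0.78), patch_size=(40, 224, 224),
        min_feature_map_size=4, max_numpool=999)
    print(num_pool)      # fewer poolings along the coarse first axis
    print(padded_patch)  # patch size padded to be divisible by 2 ** num_pool per axis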
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "longiseg"
3 | version = "1.0.0"
4 | requires-python = ">=3.10"
5 | description = "LongiSeg is a framework for longitudinal medical image segmentation built on nnU-Net."
6 | readme = "readme.md"
7 | license = { file = "LICENSE" }
8 | authors = [
9 | { name = "Yannick Kirchhoff", email = "yannick.kirchhoff@dkfz-heidelberg.de"}
10 | ]
11 | classifiers = [
12 | "Development Status :: 5 - Production/Stable",
13 | "Intended Audience :: Developers",
14 | "Intended Audience :: Science/Research",
15 | "Intended Audience :: Healthcare Industry",
16 | "Programming Language :: Python :: 3",
17 | "License :: OSI Approved :: Apache Software License",
18 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
19 | "Topic :: Scientific/Engineering :: Image Recognition",
20 | "Topic :: Scientific/Engineering :: Medical Science Apps.",
21 | ]
22 | keywords = [
23 | 'longitudinal medical image segmentation',
24 | 'deep learning',
25 | 'image segmentation',
26 | 'semantic segmentation',
27 | 'medical image analysis',
28 | 'medical image segmentation',
29 | 'nnU-Net',
30 | 'nnunet'
31 | ]
32 | dependencies = [
33 | "torch>=2.1.2",
34 | "acvl-utils>=0.2.3,<0.3", # 0.3 may bring breaking changes. Careful!
35 | "dynamic-network-architectures>=0.3.1,<0.4", # 0.3.1 and lower are supported, 0.4 may have breaking changes. Let's be careful here
36 | "tqdm",
37 | "dicom2nifti",
38 | "scipy",
39 | "batchgenerators>=0.25.1",
40 | "numpy>=1.24",
41 | "scikit-learn",
42 | "scikit-image>=0.19.3",
43 | "SimpleITK>=2.2.1",
44 | "pandas",
45 | "graphviz",
46 | 'tifffile',
47 | 'requests',
48 | "nibabel",
49 | "matplotlib",
50 | "seaborn",
51 | "imagecodecs",
52 | "yacs",
53 | "batchgeneratorsv2>=0.2",
54 | "einops",
55 | "blosc2>=3.0.0b1",
56 | "difference-weighting @ git+https://github.com/MIC-DKFZ/Longitudinal-Difference-Weighting.git"
57 | ]
58 |
59 | [project.urls]
60 | homepage = "https://github.com/MIC-DKFZ/LongiSeg"
61 | repository = "https://github.com/MIC-DKFZ/LongiSeg"
62 |
63 | [project.scripts]
64 | LongiSeg_find_best_configuration = "longiseg.evaluation.find_best_configuration:find_best_configuration_entry_point"
65 | LongiSeg_determine_postprocessing = "longiseg.postprocessing.remove_connected_components:entry_point_determine_postprocessing_folder"
66 | LongiSeg_apply_postprocessing = "longiseg.postprocessing.remove_connected_components:entry_point_apply_postprocessing"
67 | LongiSeg_ensemble = "longiseg.ensembling.ensemble:entry_point_ensemble_folders"
68 | LongiSeg_accumulate_crossval_results = "longiseg.evaluation.find_best_configuration:accumulate_crossval_results_entry_point"
69 | LongiSeg_plot_overlay_pngs = "longiseg.utilities.overlay_plots:entry_point_generate_overlay"
70 | LongiSeg_download_pretrained_model_by_url = "longiseg.model_sharing.entry_points:download_by_url"
71 | LongiSeg_install_pretrained_model_from_zip = "longiseg.model_sharing.entry_points:install_from_zip_entry_point"
72 | LongiSeg_export_model_to_zip = "longiseg.model_sharing.entry_points:export_pretrained_model_entry"
73 | LongiSeg_move_plans_between_datasets = "longiseg.experiment_planning.plans_for_pretraining.move_plans_between_datasets:entry_point_move_plans_between_datasets"
74 | LongiSeg_plan_and_preprocess = "longiseg.experiment_planning.plan_and_preprocess_longi_entrypoints:plan_and_preprocess_longi_entry"
75 | LongiSeg_extract_fingerprint = "longiseg.experiment_planning.plan_and_preprocess_longi_entrypoints:extract_fingerprint_longi_entry"
76 | LongiSeg_plan_experiment = "longiseg.experiment_planning.plan_and_preprocess_longi_entrypoints:plan_experiment_longi_entry"
77 | LongiSeg_preprocess = "longiseg.experiment_planning.plan_and_preprocess_longi_entrypoints:preprocess_longi_entry"
78 | LongiSeg_train = "longiseg.run.run_training:run_training_longi_entry"
79 | LongiSeg_predict_from_modelfolder = "longiseg.inference.predict_from_raw_data_longi:predict_longi_entry_point_modelfolder"
80 | LongiSeg_predict = "longiseg.inference.predict_from_raw_data_longi:predict_longi_entry_point"
81 | LongiSeg_evaluate_folder = "longiseg.evaluation.evaluate_predictions_longi:evaluate_longi_folder_entry_point"
82 |
83 | [project.optional-dependencies]
84 | dev = [
85 | "black",
86 | "ruff",
87 | "pre-commit"
88 | ]
89 |
90 | [build-system]
91 | requires = ["setuptools>=67.8.0"]
92 | build-backend = "setuptools.build_meta"
93 |
94 | [tool.codespell]
95 | skip = '.git,*.pdf,*.svg'
96 | #
97 | # ignore-words-list = ''
98 |
--------------------------------------------------------------------------------
/longiseg/preprocessing/normalization/default_normalization_schemes.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Type
3 |
4 | import numpy as np
5 | from numpy import number
6 |
7 |
8 | class ImageNormalization(ABC):
9 | leaves_pixels_outside_mask_at_zero_if_use_mask_for_norm_is_true = None
10 |
11 | def __init__(self, use_mask_for_norm: bool = None, intensityproperties: dict = None,
12 | target_dtype: Type[number] = np.float32):
13 | assert use_mask_for_norm is None or isinstance(use_mask_for_norm, bool)
14 | self.use_mask_for_norm = use_mask_for_norm
15 |         assert intensityproperties is None or isinstance(intensityproperties, dict)
16 | self.intensityproperties = intensityproperties
17 | self.target_dtype = target_dtype
18 |
19 | @abstractmethod
20 | def run(self, image: np.ndarray, seg: np.ndarray = None) -> np.ndarray:
21 | """
22 | Image and seg must have the same shape. Seg is not always used
23 | """
24 | pass
25 |
26 |
27 | class ZScoreNormalization(ImageNormalization):
28 | leaves_pixels_outside_mask_at_zero_if_use_mask_for_norm_is_true = True
29 |
30 | def run(self, image: np.ndarray, seg: np.ndarray = None) -> np.ndarray:
31 | """
32 | here seg is used to store the zero valued region. The value for that region in the segmentation is -1 by
33 | default.
34 | """
35 | image = image.astype(self.target_dtype, copy=False)
36 | if self.use_mask_for_norm is not None and self.use_mask_for_norm:
37 | # negative values in the segmentation encode the 'outside' region (think zero values around the brain as
38 | # in BraTS). We want to run the normalization only in the brain region, so we need to mask the image.
39 | # The default nnU-net sets use_mask_for_norm to True if cropping to the nonzero region substantially
40 | # reduced the image size.
41 | mask = seg >= 0
42 | mean = image[mask].mean()
43 | std = image[mask].std()
44 | image[mask] = (image[mask] - mean) / (max(std, 1e-8))
45 | else:
46 | mean = image.mean()
47 | std = image.std()
48 | image -= mean
49 | image /= (max(std, 1e-8))
50 | return image
51 |
52 |
53 | class CTNormalization(ImageNormalization):
54 | leaves_pixels_outside_mask_at_zero_if_use_mask_for_norm_is_true = False
55 |
56 | def run(self, image: np.ndarray, seg: np.ndarray = None) -> np.ndarray:
57 | assert self.intensityproperties is not None, "CTNormalization requires intensity properties"
58 | mean_intensity = self.intensityproperties['mean']
59 | std_intensity = self.intensityproperties['std']
60 | lower_bound = self.intensityproperties['percentile_00_5']
61 | upper_bound = self.intensityproperties['percentile_99_5']
62 |
63 | image = image.astype(self.target_dtype, copy=False)
64 | np.clip(image, lower_bound, upper_bound, out=image)
65 | image -= mean_intensity
66 | image /= max(std_intensity, 1e-8)
67 | return image
68 |
69 |
70 | class NoNormalization(ImageNormalization):
71 | leaves_pixels_outside_mask_at_zero_if_use_mask_for_norm_is_true = False
72 |
73 | def run(self, image: np.ndarray, seg: np.ndarray = None) -> np.ndarray:
74 | return image.astype(self.target_dtype, copy=False)
75 |
76 |
77 | class RescaleTo01Normalization(ImageNormalization):
78 | leaves_pixels_outside_mask_at_zero_if_use_mask_for_norm_is_true = False
79 |
80 | def run(self, image: np.ndarray, seg: np.ndarray = None) -> np.ndarray:
81 | image = image.astype(self.target_dtype, copy=False)
82 | image -= image.min()
83 | image /= np.clip(image.max(), a_min=1e-8, a_max=None)
84 | return image
85 |
86 |
87 | class RGBTo01Normalization(ImageNormalization):
88 | leaves_pixels_outside_mask_at_zero_if_use_mask_for_norm_is_true = False
89 |
90 | def run(self, image: np.ndarray, seg: np.ndarray = None) -> np.ndarray:
91 |         assert image.min() >= 0, "RGB images are expected to be uint8 with values in [0, 255], but pixel values " \
92 |                                  "smaller than 0 were found. Your images do not seem to be RGB images"
93 |         assert image.max() <= 255, "RGB images are expected to be uint8 with values in [0, 255], but pixel values " \
94 |                                    "greater than 255 were found. Your images do not seem to be RGB images"
95 | image = image.astype(self.target_dtype, copy=False)
96 | image /= 255.
97 | return image
98 |
99 |
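# Hedged usage sketch (random demo data, not a real preprocessing pipeline):
# with use_mask_for_norm=True, only voxels where seg >= 0 are normalized;
# voxels marked -1 (the cropped-away 'outside' region) keep their values.
if __name__ == '__main__':
    rng = np.random.default_rng(0)
    img = rng.random((1, 8, 8, 8), dtype=np.float32) * 100
    seg = np.where(img > 50, 0, -1)
    normalizer = ZScoreNormalization(use_mask_for_norm=True, intensityproperties={})
    out = normalizer.run(img, seg)
    print(out[seg >= 0].mean(), out[seg >= 0].std())  # approx. 0 and 1 inside the mask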
--------------------------------------------------------------------------------
/longiseg/experiment_planning/plans_for_pretraining/move_plans_between_datasets.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from typing import Union
3 |
4 | from batchgenerators.utilities.file_and_folder_operations import join, isdir, isfile, load_json, subfiles, save_json
5 |
6 | from longiseg.imageio.reader_writer_registry import determine_reader_writer_from_dataset_json
7 | from longiseg.paths import LongiSeg_preprocessed, LongiSeg_raw
8 | from longiseg.utilities.file_path_utilities import maybe_convert_to_dataset_name
9 | from longiseg.utilities.plans_handling.plans_handler import PlansManager
10 | from longiseg.utilities.utils import get_filenames_of_train_images_and_targets
11 |
12 |
13 | def move_plans_between_datasets(
14 | source_dataset_name_or_id: Union[int, str],
15 | target_dataset_name_or_id: Union[int, str],
16 | source_plans_identifier: str,
17 | target_plans_identifier: str = None):
18 | source_dataset_name = maybe_convert_to_dataset_name(source_dataset_name_or_id)
19 | target_dataset_name = maybe_convert_to_dataset_name(target_dataset_name_or_id)
20 |
21 | if target_plans_identifier is None:
22 | target_plans_identifier = source_plans_identifier
23 |
24 | source_folder = join(LongiSeg_preprocessed, source_dataset_name)
25 |     assert isdir(source_folder), f"Cannot move plans because preprocessed directory of source dataset is missing. " \
26 |                                  f"Run LongiSeg_plan_and_preprocess for source dataset first!"
27 |
28 | source_plans_file = join(source_folder, source_plans_identifier + '.json')
29 | assert isfile(source_plans_file), f"Source plans are missing. Run the corresponding experiment planning first! " \
30 | f"Expected file: {source_plans_file}"
31 |
32 | source_plans = load_json(source_plans_file)
33 | source_plans['dataset_name'] = target_dataset_name
34 |
35 | # we need to change data_identifier to use target_plans_identifier
36 | if target_plans_identifier != source_plans_identifier:
37 | for c in source_plans['configurations'].keys():
38 | if 'data_identifier' in source_plans['configurations'][c].keys():
39 | old_identifier = source_plans['configurations'][c]["data_identifier"]
40 | if old_identifier.startswith(source_plans_identifier):
41 | new_identifier = target_plans_identifier + old_identifier[len(source_plans_identifier):]
42 | else:
43 | new_identifier = target_plans_identifier + '_' + old_identifier
44 | source_plans['configurations'][c]["data_identifier"] = new_identifier
45 |
46 | # we need to change the reader writer class!
47 | target_raw_data_dir = join(LongiSeg_raw, target_dataset_name)
48 | target_dataset_json = load_json(join(target_raw_data_dir, 'dataset.json'))
49 |
50 | # we may need to change the reader/writer
51 |     # pick any file from the target dataset
52 | dataset = get_filenames_of_train_images_and_targets(target_raw_data_dir, target_dataset_json)
53 |     example_image = dataset[next(iter(dataset.keys()))]['images'][0]
54 | rw = determine_reader_writer_from_dataset_json(target_dataset_json, example_image, allow_nonmatching_filename=True,
55 | verbose=False)
56 |
57 | source_plans["image_reader_writer"] = rw.__name__
58 | if target_plans_identifier is not None:
59 | source_plans["plans_name"] = target_plans_identifier
60 |
61 | save_json(source_plans, join(LongiSeg_preprocessed, target_dataset_name, target_plans_identifier + '.json'),
62 | sort_keys=False)
63 |
64 |
65 | def entry_point_move_plans_between_datasets():
66 | parser = argparse.ArgumentParser()
67 | parser.add_argument('-s', type=str, required=True,
68 | help='Source dataset name or id')
69 | parser.add_argument('-t', type=str, required=True,
70 | help='Target dataset name or id')
71 | parser.add_argument('-sp', type=str, required=True,
72 | help='Source plans identifier. If your plans are named "nnUNetPlans.json" then the '
73 | 'identifier would be nnUNetPlans')
74 | parser.add_argument('-tp', type=str, required=False, default=None,
75 | help='Target plans identifier. Default is None meaning the source plans identifier will '
76 | 'be kept. Not recommended if the source plans identifier is a default nnU-Net identifier '
77 | 'such as nnUNetPlans!!!')
78 | args = parser.parse_args()
79 | move_plans_between_datasets(args.s, args.t, args.sp, args.tp)
80 |
81 |
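# Example CLI invocation (the entry point name is registered in this repo's
# pyproject.toml; dataset ids and plans identifiers are placeholders):
#   LongiSeg_move_plans_between_datasets -s 2 -t 4 -sp nnUNetPlans -tp nnUNetPlansFrom2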
82 | if __name__ == '__main__':
83 | move_plans_between_datasets(2, 4, 'nnUNetPlans', 'nnUNetPlansFrom2')
84 |
--------------------------------------------------------------------------------
/longiseg/dataset_conversion/generate_dataset_json.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple, Union, List
2 |
3 | from batchgenerators.utilities.file_and_folder_operations import save_json, join
4 |
5 |
6 | def generate_dataset_json(output_folder: str,
7 | channel_names: dict,
8 | labels: dict,
9 | num_training_cases: int,
10 | file_ending: str,
11 | citation: Union[List[str], str] = None,
12 | regions_class_order: Tuple[int, ...] = None,
13 | dataset_name: str = None,
14 | reference: str = None,
15 | release: str = None,
16 | description: str = None,
17 | overwrite_image_reader_writer: str = None,
18 | license: str = 'Whoever converted this dataset was lazy and didn\'t look it up!',
19 | converted_by: str = "Please enter your name, especially when sharing datasets with others in a common infrastructure!",
20 | **kwargs):
21 | """
22 | Generates a dataset.json file in the output folder
23 |
24 | channel_names:
25 | Channel names must map the index to the name of the channel, example:
26 | {
27 | 0: 'T1',
28 | 1: 'CT'
29 | }
30 | Note that the channel names may influence the normalization scheme!! Learn more in the documentation.
31 |
32 | labels:
33 | This will tell nnU-Net what labels to expect. Important: This will also determine whether you use region-based training or not.
34 | Example regular labels:
35 | {
36 | 'background': 0,
37 | 'left atrium': 1,
38 | 'some other label': 2
39 | }
40 | Example region-based training:
41 | {
42 | 'background': 0,
43 | 'whole tumor': (1, 2, 3),
44 | 'tumor core': (2, 3),
45 | 'enhancing tumor': 3
46 | }
47 |
48 | Remember that nnU-Net expects consecutive values for labels! nnU-Net also expects 0 to be background!
49 |
50 | num_training_cases: is used to double check all cases are there!
51 |
52 | file_ending: needed for finding the files correctly. IMPORTANT! File endings must match between images and
53 | segmentations!
54 |
55 | dataset_name, reference, release, license, description: self-explanatory and not used by nnU-Net. Just for
56 | completeness and as a reminder that these would be great!
57 |
58 | overwrite_image_reader_writer: If you need a special IO class for your dataset you can derive it from
59 | BaseReaderWriter, place it into nnunet.imageio and reference it here by name
60 |
61 | kwargs: whatever you put here will be placed in the dataset.json as well
62 |
63 | """
64 | has_regions: bool = any([isinstance(i, (tuple, list)) and len(i) > 1 for i in labels.values()])
65 | if has_regions:
66 | assert regions_class_order is not None, f"You have defined regions but regions_class_order is not set. " \
67 | f"You need that."
68 | # channel names need strings as keys
69 | keys = list(channel_names.keys())
70 | for k in keys:
71 | if not isinstance(k, str):
72 | channel_names[str(k)] = channel_names[k]
73 | del channel_names[k]
74 |
75 | # labels need ints as values
76 | for l in labels.keys():
77 | value = labels[l]
78 | if isinstance(value, (tuple, list)):
79 | value = tuple([int(i) for i in value])
80 | labels[l] = value
81 | else:
82 | labels[l] = int(labels[l])
83 |
84 | dataset_json = {
85 | 'channel_names': channel_names, # previously this was called 'modality'. I didn't like this so this is
86 | # channel_names now. Live with it.
87 | 'labels': labels,
88 | 'numTraining': num_training_cases,
89 | 'file_ending': file_ending,
90 | 'licence': license,
91 | 'converted_by': converted_by
92 | }
93 |
94 | if dataset_name is not None:
95 | dataset_json['name'] = dataset_name
96 | if reference is not None:
97 | dataset_json['reference'] = reference
98 | if release is not None:
99 | dataset_json['release'] = release
100 | if citation is not None:
101 |         dataset_json['citation'] = citation
102 | if description is not None:
103 | dataset_json['description'] = description
104 | if overwrite_image_reader_writer is not None:
105 | dataset_json['overwrite_image_reader_writer'] = overwrite_image_reader_writer
106 | if regions_class_order is not None:
107 | dataset_json['regions_class_order'] = regions_class_order
108 |
109 | dataset_json.update(kwargs)
110 |
111 | save_json(dataset_json, join(output_folder, 'dataset.json'), sort_keys=False)
112 |
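# Hedged usage sketch (folder and names are placeholders; the output folder is
# assumed to already exist):
if __name__ == '__main__':
    generate_dataset_json(output_folder='/tmp/Dataset999_Demo',
                          channel_names={0: 'T1'},
                          labels={'background': 0, 'lesion': 1},
                          num_training_cases=10,
                          file_ending='.nii.gz',
                          dataset_name='Dataset999_Demo',
                          converted_by='Jane Doe')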
--------------------------------------------------------------------------------
/documentation/installation_instructions.md:
--------------------------------------------------------------------------------
1 | # System requirements
2 |
3 | ## Operating System
4 | nnU-Net has been tested on Linux (Ubuntu 18.04, 20.04, 22.04; CentOS, RHEL), Windows and macOS! It should work out of the box!
5 |
6 | ## Hardware requirements
7 | We support GPU (recommended), CPU and Apple M1/M2 as devices (currently Apple mps does not implement 3D
8 | convolutions, so you might have to use the CPU on those devices).
9 |
10 | ### Hardware requirements for training
11 | We recommend you use a GPU for training as this will take a really long time on CPU or MPS (Apple M1/M2).
12 | For training a GPU with at least 10 GB (popular non-datacenter options are the RTX 2080ti, RTX 3080/3090 or RTX 4080/4090) is
13 | required. We also recommend a strong CPU to go along with the GPU. 6 cores (12 threads)
14 | are the bare minimum! CPU requirements are mostly related to data augmentation and scale with the number of
15 | input channels and target structures. Plus, the faster the GPU, the better the CPU should be!
16 |
17 | ### Hardware requirements for inference
18 | Again we recommend a GPU to make predictions as this will be substantially faster than the other options. However,
19 | inference times are typically still manageable on CPU and MPS (Apple M1/M2). If using a GPU, it should have at least
20 | 4 GB of available (unused) VRAM.
21 |
22 | ### Example hardware configurations
23 | Example workstation configurations for training:
24 | - CPU: Ryzen 5800X - 5900X or 7900X would be even better! We have not yet tested Intel Alder/Raptor lake but they will likely work as well.
25 | - GPU: RTX 3090 or RTX 4090
26 | - RAM: 64GB
27 | - Storage: SSD (M.2 PCIe Gen 3 or better!)
28 |
29 | Example Server configuration for training:
30 | - CPU: 2x AMD EPYC7763 for a total of 128C/256T. 16C/GPU are highly recommended for fast GPUs such as the A100!
31 | - GPU: 8xA100 PCIe (price/performance superior to SXM variant + they use less power)
32 | - RAM: 1 TB
33 | - Storage: local SSD storage (PCIe Gen 3 or better) or ultra fast network storage
34 |
35 | (nnU-Net by default uses one GPU per training. The server configuration can run up to 8 model trainings simultaneously)
36 |
37 | ### Setting the correct number of Workers for data augmentation (training only)
38 | Note that you will need to manually set the number of processes nnU-Net uses for data augmentation according to your
39 | CPU/GPU ratio. For the server above (256 threads for 8 GPUs), a good value would be 24-30. You can do this by
40 | setting the `nnUNet_n_proc_DA` environment variable (`export nnUNet_n_proc_DA=XX`).
41 | Recommended values (assuming a recent CPU with good IPC) are 10-12 for RTX 2080 ti, 12 for a RTX 3090, 16-18 for
42 | RTX 4090, 28-32 for A100. Optimal values may vary depending on the number of input channels/modalities and number of classes.
43 |
44 | # Installation instructions
45 | We strongly recommend that you install nnU-Net in a virtual environment! Pip or anaconda are both fine. If you choose to
46 | compile PyTorch from source (see below), you will need to use conda instead of pip.
47 |
48 | Use a recent version of Python! 3.10 or newer is required (see `requires-python` in [pyproject.toml](../pyproject.toml))!
49 |
50 | **nnU-Net v2 can coexist with nnU-Net v1! Both can be installed at the same time.**
51 |
52 | 1) Install [PyTorch](https://pytorch.org/get-started/locally/) as described on their website (conda/pip). Please
53 | install the latest version with support for your hardware (cuda, mps, cpu).
54 | **DO NOT JUST `pip install nnunetv2` WITHOUT PROPERLY INSTALLING PYTORCH FIRST**. For maximum speed, consider
55 | [compiling pytorch yourself](https://github.com/pytorch/pytorch#from-source) (experienced users only!).
56 | 2) Install nnU-Net depending on your use case:
57 | 1) For use as **standardized baseline**, **out-of-the-box segmentation algorithm** or for running
58 | **inference with pretrained models**:
59 |
60 | ```pip install nnunetv2```
61 |
62 |    2) For use as integrative **framework** (this will create a copy of the LongiSeg code on your computer so that you
63 |    can modify it as needed):
64 | ```bash
65 |           git clone https://github.com/MIC-DKFZ/LongiSeg.git
66 |           cd LongiSeg
67 | pip install -e .
68 | ```
69 | 3) LongiSeg needs to know where you intend to save raw data, preprocessed data and trained models. For this you need to
70 | set a few environment variables. Please follow the instructions [here](setting_up_paths.md).
71 |
72 | Installing LongiSeg will add several new commands to your terminal. These commands are used to run the entire
73 | LongiSeg pipeline. You can execute them from any location on your system. All LongiSeg commands have the prefix
74 | `LongiSeg_` for easy identification.
75 |
76 | Note that these commands simply execute python scripts. If you installed LongiSeg in a virtual environment, this
77 | environment must be activated when executing the commands. You can see what scripts/functions are executed by
78 | checking the project.scripts in the [pyproject.toml](../pyproject.toml) file.
79 |
80 | All LongiSeg commands have a `-h` option which gives information on how to use them.
81 |
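For example, to show the help of the training command (entry point names are listed under `project.scripts` in
[pyproject.toml](../pyproject.toml)):

```bash
LongiSeg_train -h
```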
--------------------------------------------------------------------------------
/longiseg/utilities/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 HIP Applied Computer Vision Lab, Division of Medical Image Computing, German Cancer Research Center
2 | # (DKFZ), Heidelberg, Germany
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from typing import List
16 |
17 | import os.path
18 | from functools import lru_cache
19 | from typing import Union
20 |
21 | from batchgenerators.utilities.file_and_folder_operations import subfiles, subdirs, isdir, join, load_json
22 | import numpy as np
23 | import re
24 |
25 | from longiseg.paths import LongiSeg_raw
26 | from multiprocessing import Pool
27 |
28 |
29 | def get_identifiers_from_splitted_dataset_folder(folder: str, file_ending: str):
30 | files = subfiles(folder, suffix=file_ending, join=False)
31 | # all files have a 4 digit channel index (_XXXX)
32 | crop = len(file_ending) + 5
33 | files = [i[:-crop] for i in files]
34 | # only unique image ids
35 | files = np.unique(files)
36 | return files
37 |
38 |
39 | def create_paths_fn(folder, files, file_ending, f):
40 | p = re.compile(re.escape(f) + r"_\d\d\d\d" + re.escape(file_ending))
41 | return [join(folder, i) for i in files if p.fullmatch(i)]
42 |
43 |
44 | def create_lists_from_splitted_dataset_folder(folder: str, file_ending: str, identifiers: List[str] = None, num_processes: int = 12) -> List[
45 | List[str]]:
46 | """
47 | does not rely on dataset.json
48 | """
49 | if identifiers is None:
50 | identifiers = get_identifiers_from_splitted_dataset_folder(folder, file_ending)
51 | files = subfiles(folder, suffix=file_ending, join=False, sort=True)
52 | list_of_lists = []
53 |
54 | params_list = [(folder, files, file_ending, f) for f in identifiers]
55 | with Pool(processes=num_processes) as pool:
56 | list_of_lists = pool.starmap(create_paths_fn, params_list)
57 |
58 | return list_of_lists
59 |
60 |
61 | def create_lists_from_splitted_dataset_folder_and_patients_json(folder: str, file_ending: str, patients_json: dict, identifiers: List[str] = None,
62 | num_processes: int = 12) -> List[List[str]]:
63 | """
64 | uses patients_json to filter for scans with patient information
65 | """
66 | if identifiers is None:
67 | identifiers = get_identifiers_from_splitted_dataset_folder(folder, file_ending)
68 | files = subfiles(folder, suffix=file_ending, join=False, sort=True)
69 |
70 | # filter patient_json with identifiers
71 | filtered_patients_json = {k: [i for i in v if i in identifiers] for k, v in patients_json.items()}
72 |
73 |     # for now this has to work without multiprocessing; the current implementation is not optimal anyway
74 | list_of_lists = []
75 | for k, v in filtered_patients_json.items():
76 | if len(v) == 0: continue
77 | patient_files = [create_paths_fn(folder, files, file_ending, f) for f in v]
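        # pair every timepoint with its predecessor; the first timepoint has no
        # prior scan and is therefore paired with itself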
78 | list_of_lists.extend([[c, p] for c, p in zip(patient_files, patient_files[:1] + patient_files[:-1])])
79 |
80 | return list_of_lists
81 |
82 |
83 | def get_filenames_of_train_images_and_targets(raw_dataset_folder: str, dataset_json: dict = None):
84 | if dataset_json is None:
85 | dataset_json = load_json(join(raw_dataset_folder, 'dataset.json'))
86 |
87 | if 'dataset' in dataset_json.keys():
88 | dataset = dataset_json['dataset']
89 | for k in dataset.keys():
90 | expanded_label_file = os.path.expandvars(dataset[k]['label'])
91 | dataset[k]['label'] = os.path.abspath(join(raw_dataset_folder, expanded_label_file)) if not os.path.isabs(expanded_label_file) else expanded_label_file
92 | dataset[k]['images'] = [os.path.abspath(join(raw_dataset_folder, os.path.expandvars(i))) if not os.path.isabs(os.path.expandvars(i)) else os.path.expandvars(i) for i in dataset[k]['images']]
93 | else:
94 | identifiers = get_identifiers_from_splitted_dataset_folder(join(raw_dataset_folder, 'imagesTr'), dataset_json['file_ending'])
95 | images = create_lists_from_splitted_dataset_folder(join(raw_dataset_folder, 'imagesTr'), dataset_json['file_ending'], identifiers)
96 | segs = [join(raw_dataset_folder, 'labelsTr', i + dataset_json['file_ending']) for i in identifiers]
97 | dataset = {i: {'images': im, 'label': se} for i, im, se in zip(identifiers, images, segs)}
98 | return dataset
99 |
100 |
101 | if __name__ == '__main__':
102 | print(get_filenames_of_train_images_and_targets(join(LongiSeg_raw, 'Dataset002_Heart')))
103 |
--------------------------------------------------------------------------------
/longiseg/training/logging/nnunet_logger.py:
--------------------------------------------------------------------------------
1 | import matplotlib
2 | from batchgenerators.utilities.file_and_folder_operations import join
3 |
4 | matplotlib.use('agg')
5 | import seaborn as sns
6 | import matplotlib.pyplot as plt
7 |
8 |
9 | class nnUNetLogger(object):
10 | """
11 | This class is really trivial. Don't expect cool functionality here. This is my makeshift solution to problems
12 | arising from out-of-sync epoch numbers and numbers of logged loss values. It also simplifies the trainer class a
13 | little
14 |
15 | YOU MUST LOG EXACTLY ONE VALUE PER EPOCH FOR EACH OF THE LOGGING ITEMS! DONT FUCK IT UP
16 | """
17 | def __init__(self, verbose: bool = False):
18 | self.my_fantastic_logging = {
19 | 'mean_fg_dice': list(),
20 | 'ema_fg_dice': list(),
21 | 'dice_per_class_or_region': list(),
22 | 'train_losses': list(),
23 | 'val_losses': list(),
24 | 'lrs': list(),
25 | 'epoch_start_timestamps': list(),
26 | 'epoch_end_timestamps': list()
27 | }
28 | self.verbose = verbose
29 | # shut up, this logging is great
30 |
31 | def log(self, key, value, epoch: int):
32 | """
33 | sometimes shit gets messed up. We try to catch that here
34 | """
35 | assert key in self.my_fantastic_logging.keys() and isinstance(self.my_fantastic_logging[key], list), \
36 | 'This function is only intended to log stuff to lists and to have one entry per epoch'
37 |
38 | if self.verbose: print(f'logging {key}: {value} for epoch {epoch}')
39 |
40 | if len(self.my_fantastic_logging[key]) < (epoch + 1):
41 | self.my_fantastic_logging[key].append(value)
42 | else:
43 | assert len(self.my_fantastic_logging[key]) == (epoch + 1), 'something went horribly wrong. My logging ' \
44 | 'lists length is off by more than 1'
45 | print(f'maybe some logging issue!? logging {key} and {value}')
46 | self.my_fantastic_logging[key][epoch] = value
47 |
48 | # handle the ema_fg_dice special case! It is automatically logged when we add a new mean_fg_dice
49 | if key == 'mean_fg_dice':
50 | new_ema_pseudo_dice = self.my_fantastic_logging['ema_fg_dice'][epoch - 1] * 0.9 + 0.1 * value \
51 | if len(self.my_fantastic_logging['ema_fg_dice']) > 0 else value
52 | self.log('ema_fg_dice', new_ema_pseudo_dice, epoch)
53 |
54 | def plot_progress_png(self, output_folder):
55 |         # we infer the epoch from our internal logging
56 | epoch = min([len(i) for i in self.my_fantastic_logging.values()]) - 1 # lists of epoch 0 have len 1
57 | sns.set(font_scale=2.5)
58 | fig, ax_all = plt.subplots(3, 1, figsize=(30, 54))
59 | # regular progress.png as we are used to from previous nnU-Net versions
60 | ax = ax_all[0]
61 | ax2 = ax.twinx()
62 | x_values = list(range(epoch + 1))
63 | ax.plot(x_values, self.my_fantastic_logging['train_losses'][:epoch + 1], color='b', ls='-', label="loss_tr", linewidth=4)
64 | ax.plot(x_values, self.my_fantastic_logging['val_losses'][:epoch + 1], color='r', ls='-', label="loss_val", linewidth=4)
65 | ax2.plot(x_values, self.my_fantastic_logging['mean_fg_dice'][:epoch + 1], color='g', ls='dotted', label="pseudo dice",
66 | linewidth=3)
67 | ax2.plot(x_values, self.my_fantastic_logging['ema_fg_dice'][:epoch + 1], color='g', ls='-', label="pseudo dice (mov. avg.)",
68 | linewidth=4)
69 | ax.set_xlabel("epoch")
70 | ax.set_ylabel("loss")
71 | ax2.set_ylabel("pseudo dice")
72 | ax.legend(loc=(0, 1))
73 | ax2.legend(loc=(0.2, 1))
74 |
75 | # epoch times to see whether the training speed is consistent (inconsistent means there are other jobs
76 | # clogging up the system)
77 | ax = ax_all[1]
78 |         ax.plot(x_values, [i - j for i, j in zip(self.my_fantastic_logging['epoch_end_timestamps'][:epoch + 1],
79 |                                                   self.my_fantastic_logging['epoch_start_timestamps'][:epoch + 1])],
80 |                 color='b', ls='-', label="epoch duration", linewidth=4)
81 | ylim = [0] + [ax.get_ylim()[1]]
82 | ax.set(ylim=ylim)
83 | ax.set_xlabel("epoch")
84 | ax.set_ylabel("time [s]")
85 | ax.legend(loc=(0, 1))
86 |
87 | # learning rate
88 | ax = ax_all[2]
89 | ax.plot(x_values, self.my_fantastic_logging['lrs'][:epoch + 1], color='b', ls='-', label="learning rate", linewidth=4)
90 | ax.set_xlabel("epoch")
91 | ax.set_ylabel("learning rate")
92 | ax.legend(loc=(0, 1))
93 |
94 | plt.tight_layout()
95 |
96 | fig.savefig(join(output_folder, "progress.png"))
97 | plt.close()
98 |
99 | def get_checkpoint(self):
100 | return self.my_fantastic_logging
101 |
102 | def load_checkpoint(self, checkpoint: dict):
103 | self.my_fantastic_logging = checkpoint
104 |
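# Hedged usage sketch (values invented for illustration):
#   logger = nnUNetLogger()
#   logger.log('train_losses', 0.7, epoch=0)
#   logger.log('mean_fg_dice', 0.5, epoch=0)  # also updates 'ema_fg_dice' automatically
#   ... log one value per key per epoch, then:
#   logger.plot_progress_png('/path/to/output_folder')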
--------------------------------------------------------------------------------
/longiseg/imageio/tif_reader_writer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 HIP Applied Computer Vision Lab, Division of Medical Image Computing, German Cancer Research Center
2 | # (DKFZ), Heidelberg, Germany
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | import os.path
16 | from typing import Tuple, Union, List
17 | import numpy as np
18 | from longiseg.imageio.base_reader_writer import BaseReaderWriter
19 | import tifffile
20 | from batchgenerators.utilities.file_and_folder_operations import isfile, load_json, save_json, split_path, join
21 |
22 |
23 | class Tiff3DIO(BaseReaderWriter):
24 | """
25 | reads and writes 3D tif(f) images. Uses tifffile package. Ignores metadata (for now)!
26 |
27 | If you have 2D tiffs, use NaturalImage2DIO
28 |
29 |     Supports the use of auxiliary files for spacing information. If used, the auxiliary files are expected to end
30 |     with .json and omit the channel identifier. So, for example, the auxiliary file corresponding to image
31 |     image1_0000.tif is expected to be image1.json!
32 | """
33 | supported_file_endings = [
34 | '.tif',
35 | '.tiff',
36 | ]
37 |
38 | def read_images(self, image_fnames: Union[List[str], Tuple[str, ...]]) -> Tuple[np.ndarray, dict]:
39 | # figure out file ending used here
40 | ending = '.' + image_fnames[0].split('.')[-1]
41 | assert ending.lower() in self.supported_file_endings, f'Ending {ending} not supported by {self.__class__.__name__}'
42 | ending_length = len(ending)
43 | truncate_length = ending_length + 5 # 5 comes from len(_0000)
44 |
45 | images = []
46 | for f in image_fnames:
47 | image = tifffile.imread(f)
48 | if image.ndim != 3:
49 | raise RuntimeError(f"Only 3D images are supported! File: {f}")
50 | images.append(image[None])
51 |
52 | # see if aux file can be found
53 | expected_aux_file = image_fnames[0][:-truncate_length] + '.json'
54 | if isfile(expected_aux_file):
55 | spacing = load_json(expected_aux_file)['spacing']
56 | assert len(spacing) == 3, f'spacing must have 3 entries, one for each dimension of the image. File: {expected_aux_file}'
57 | else:
58 | print(f'WARNING no spacing file found for images {image_fnames}\nAssuming spacing (1, 1, 1).')
59 | spacing = (1, 1, 1)
60 |
61 | if not self._check_all_same([i.shape for i in images]):
62 | print('ERROR! Not all input images have the same shape!')
63 | print('Shapes:')
64 | print([i.shape for i in images])
65 | print('Image files:')
66 | print(image_fnames)
67 | raise RuntimeError()
68 |
69 | return np.vstack(images, dtype=np.float32, casting='unsafe'), {'spacing': spacing}
70 |
71 | def write_seg(self, seg: np.ndarray, output_fname: str, properties: dict) -> None:
72 | # not ideal but I really have no clue how to set spacing/resolution information properly in tif files haha
73 | tifffile.imwrite(output_fname, data=seg.astype(np.uint8 if np.max(seg) < 255 else np.uint16, copy=False), compression='zlib')
74 | file = os.path.basename(output_fname)
75 | out_dir = os.path.dirname(output_fname)
76 | ending = file.split('.')[-1]
77 | save_json({'spacing': properties['spacing']}, join(out_dir, file[:-(len(ending) + 1)] + '.json'))
78 |
79 | def read_seg(self, seg_fname: str) -> Tuple[np.ndarray, dict]:
80 | # figure out file ending used here
81 | ending = '.' + seg_fname.split('.')[-1]
82 | assert ending.lower() in self.supported_file_endings, f'Ending {ending} not supported by {self.__class__.__name__}'
83 | ending_length = len(ending)
84 |
85 | seg = tifffile.imread(seg_fname)
86 | if seg.ndim != 3:
87 | raise RuntimeError(f"Only 3D images are supported! File: {seg_fname}")
88 | seg = seg[None]
89 |
90 | # see if aux file can be found
91 | expected_aux_file = seg_fname[:-ending_length] + '.json'
92 | if isfile(expected_aux_file):
93 | spacing = load_json(expected_aux_file)['spacing']
94 | assert len(spacing) == 3, f'spacing must have 3 entries, one for each dimension of the image. File: {expected_aux_file}'
95 | assert all([i > 0 for i in spacing]), f"Spacing must be > 0, spacing: {spacing}"
96 | else:
97 | print(f'WARNING no spacing file found for segmentation {seg_fname}\nAssuming spacing (1, 1, 1).')
98 | spacing = (1, 1, 1)
99 |
100 | return seg.astype(np.float32, copy=False), {'spacing': spacing}
101 |
--------------------------------------------------------------------------------
/longiseg/utilities/file_path_utilities.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | from multiprocessing.pool import Pool
4 | from typing import Union, Tuple
5 | import numpy as np
6 | import os
7 | from batchgenerators.utilities.file_and_folder_operations import split_path, join
8 |
9 | from longiseg.configuration import default_num_processes
10 | from longiseg.paths import LongiSeg_results
11 | from longiseg.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name
12 |
13 |
14 | def convert_trainer_plans_config_to_identifier(trainer_name, plans_identifier, configuration):
15 | return f'{trainer_name}__{plans_identifier}__{configuration}'
16 |
17 |
18 | def convert_identifier_to_trainer_plans_config(identifier: str):
19 | return os.path.basename(identifier).split('__')
20 |
21 |
22 | def get_output_folder(dataset_name_or_id: Union[str, int], trainer_name: str = 'nnUNetTrainer',
23 | plans_identifier: str = 'nnUNetPlans', configuration: str = '3d_fullres',
24 | fold: Union[str, int] = None) -> str:
25 | tmp = join(LongiSeg_results, maybe_convert_to_dataset_name(dataset_name_or_id),
26 | convert_trainer_plans_config_to_identifier(trainer_name, plans_identifier, configuration))
27 | if fold is not None:
28 | tmp = join(tmp, f'fold_{fold}')
29 | return tmp
30 |
31 |
32 | def parse_dataset_trainer_plans_configuration_from_path(path: str):
33 | folders = split_path(path)
34 | # this here can be a little tricky because we are making assumptions. Let's hope this never fails lol
35 |
36 | # safer to make this depend on two conditions, the fold_x and the DatasetXXX
37 | # first let's see if some fold_X is present
38 | fold_x_present = [i.startswith('fold_') for i in folders]
39 | if any(fold_x_present):
40 | idx = fold_x_present.index(True)
41 | # OK now two entries before that there should be DatasetXXX
42 | assert len(folders[:idx]) >= 2, 'Bad path, cannot extract what I need. Your path needs to be at least ' \
43 | 'DatasetXXX/MODULE__PLANS__CONFIGURATION for this to work'
44 | if folders[idx - 2].startswith('Dataset'):
45 | split = folders[idx - 1].split('__')
46 | assert len(split) == 3, 'Bad path, cannot extract what I need. Your path needs to be at least ' \
47 | 'DatasetXXX/MODULE__PLANS__CONFIGURATION for this to work'
48 | return folders[idx - 2], *split
49 | else:
50 | # we can only check for dataset followed by a string that is separable into three strings by splitting with '__'
51 | # look for DatasetXXX
52 | dataset_folder = [i.startswith('Dataset') for i in folders]
53 | if any(dataset_folder):
54 | idx = dataset_folder.index(True)
55 | assert len(folders) >= (idx + 1), 'Bad path, cannot extract what I need. Your path needs to be at least ' \
56 | 'DatasetXXX/MODULE__PLANS__CONFIGURATION for this to work'
57 | split = folders[idx + 1].split('__')
58 | assert len(split) == 3, 'Bad path, cannot extract what I need. Your path needs to be at least ' \
59 | 'DatasetXXX/MODULE__PLANS__CONFIGURATION for this to work'
60 | return folders[idx], *split
61 |
62 |
63 | def get_ensemble_name(model1_folder, model2_folder, folds: Tuple[int, ...]):
64 | identifier = 'ensemble___' + os.path.basename(model1_folder) + '___' + \
65 | os.path.basename(model2_folder) + '___' + folds_tuple_to_string(folds)
66 | return identifier
67 |
68 |
69 | def get_ensemble_name_from_d_tr_c(dataset, tr1, p1, c1, tr2, p2, c2, folds: Tuple[int, ...]):
70 | model1_folder = get_output_folder(dataset, tr1, p1, c1)
71 | model2_folder = get_output_folder(dataset, tr2, p2, c2)
72 |
73 |     return get_ensemble_name(model1_folder, model2_folder, folds)
74 |
75 |
76 | def convert_ensemble_folder_to_model_identifiers_and_folds(ensemble_folder: str):
77 | prefix, *models, folds = os.path.basename(ensemble_folder).split('___')
78 | return models, folds
79 |
80 |
81 | def folds_tuple_to_string(folds: Union[List[int], Tuple[int, ...]]):
82 | s = str(folds[0])
83 | for f in folds[1:]:
84 | s += f"_{f}"
85 | return s
86 |
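# e.g. folds_tuple_to_string((0, 1, 2, 3, 4)) -> '0_1_2_3_4' and, inversely,
# folds_string_to_tuple('0_1_2_3_4') -> [0, 1, 2, 3, 4] (illustrative values)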
87 |
88 | def folds_string_to_tuple(folds_string: str):
89 | folds = folds_string.split('_')
90 | res = []
91 | for f in folds:
92 | try:
93 | res.append(int(f))
94 | except ValueError:
95 | res.append(f)
96 | return res
97 |
98 |
99 | def check_workers_alive_and_busy(export_pool: Pool, worker_list: List, results_list: List, allowed_num_queued: int = 0):
100 | """
101 |
102 | returns True if the number of results that are not ready is greater than the number of available workers + allowed_num_queued
103 | """
104 | alive = [i.is_alive() for i in worker_list]
105 | if not all(alive):
106 | raise RuntimeError('Some background workers are no longer alive')
107 |
108 | not_ready = [not i.ready() for i in results_list]
109 | if sum(not_ready) >= (len(export_pool._pool) + allowed_num_queued):
110 | return True
111 | return False
--------------------------------------------------------------------------------
/documentation/ignore_label.md:
--------------------------------------------------------------------------------
1 | # Ignore Label
2 |
3 | The _ignore label_ can be used to mark regions that should be ignored by nnU-Net. This can be used to
4 | learn from images where only sparse annotations are available, for example in the form of scribbles or a limited
5 | amount of annotated slices. Internally, this is accomplished by using partial losses, i.e. losses that are only
6 | computed on annotated pixels while ignoring the rest. Take a look at our
7 | [`DC_and_BCE_loss`](../longiseg/training/loss/compound_losses.py) to see how this is done.
8 | During inference (validation and prediction), nnU-Net will always predict dense segmentations. Metric computation in
9 | validation is of course only done on annotated pixels.
10 |
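As a minimal sketch (label names other than `ignore` are placeholders), the ignore label is declared like any other
label, e.g. via [`generate_dataset_json`](../longiseg/dataset_conversion/generate_dataset_json.py). In nnU-Net, it must
be named `ignore` and carry the highest label value:

```python
from longiseg.dataset_conversion.generate_dataset_json import generate_dataset_json

generate_dataset_json(output_folder='...',  # placeholder path
                      channel_names={0: 'MRI'},
                      labels={'background': 0, 'lesion': 1, 'ignore': 2},
                      num_training_cases=42,
                      file_ending='.nii.gz')
```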
11 | Sparse annotations can be used to train a model for application to new, unseen images or to autocomplete the
12 | provided training cases given the sparse labels.
13 |
14 | (See our [paper](https://arxiv.org/abs/2403.12834) for more information)
15 |
16 | Typical use-cases for the ignore label are:
17 | - Save annotation time through sparse annotation schemes
18 | - Annotation of all or a subset of slices with scribbles (Scribble Supervision)
19 | - Dense annotation of a subset of slices
20 | - Dense annotation of chosen patches/cubes within an image
21 |   - Coarsely masking out faulty segmentations in the reference segmentations
22 | - Masking areas for other reasons
23 |
24 | If you are using nnU-Net's ignore label, please cite the following paper in addition to the original nnU-net paper:
25 |
26 | ```
27 | Gotkowski, K., Lüth, C., Jäger, P. F., Ziegler, S., Krämer, L., Denner, S., Xiao, S., Disch, N., H., K., & Isensee, F.
28 | (2024). Embarrassingly Simple Scribble Supervision for 3D Medical Segmentation. ArXiv. /abs/2403.12834
29 | ```
30 |
31 | ## Use-cases
32 |
33 | ### Scribble Supervision
34 |
35 | Scribbles are free-form drawings to coarsly annotate an image. As we have demonstrated in our recent [paper](https://arxiv.org/abs/2403.12834), nnU-Net's partial loss implementation enables state-of-the-art learning from partially annotated data and even surpasses many purpose-built methods for learning from scribbles. As a starting point, for each image slice and each class (including background), an interior and a border scribble should be generated:
36 |
37 | - Interior Scribble: A scribble placed randomly within the class interior of a class instance
38 | - Border Scribble: A scribble roughly delineating a small part of the class border of a class instance
39 |
40 | An example of such scribble annotations is depicted in Figure 1, with an animated version in Animation 1.
41 | Depending on the availability of data and their variability, it is also possible to annotate only a subset of selected slices.
42 |
43 | *Figure 1: Example scribble annotations.*
44 |
45 | *Animation 1: Scribble annotations, animated.*