├── LICENSE ├── README.md ├── analysis ├── bootstrap_ranking.py ├── check_training_progress.ipynb ├── paper_analysis.py └── utils.py ├── assets ├── fd_figure1.svg ├── paper_table.png └── table_big_comparison_multimetric.html ├── dataset_splits ├── Dataset500_simple_fets_corruptions │ ├── domain_mapping_00.json │ └── splits_final.json ├── Dataset503_BraTS19 │ ├── domain_mapping_00.json │ └── splits_final.json ├── Dataset511_MnM_VendorB_train │ ├── domain_mapping_00.json │ └── splits_final.json ├── Dataset514_MVSeg23 │ ├── domain_mapping_00.json │ └── splits_final.json ├── Dataset515_KiTS23 │ ├── domain_mapping_00.json │ └── splits_final.json ├── Dataset520_CovidLungCT │ ├── domain_mapping_00.json │ └── splits_final.json ├── Dataset521_ProstateGonzalez │ ├── domain_mapping_00.json │ └── splits_final.json ├── Dataset531_DoFEOpticDiscCup_512 │ ├── domain_mapping_00.json │ └── splits_final.json ├── Dataset540_RETOUCH_ood=Cirrus │ ├── domain_mapping_00.json │ └── splits_final.json ├── Dataset560_OCTA500_pathology_split │ ├── domain_mapping_00.json │ └── splits_final.json └── readme.md ├── pyproject.toml ├── src └── segmentation_failures │ ├── callbacks │ ├── batch_logging.py │ ├── confidence_map_writer.py │ ├── network_graph_tb.py │ ├── prediction_writer.py │ ├── quality_estimator_monitor.py │ └── results_writer.py │ ├── conf │ ├── analysis │ │ ├── fail_thresholds │ │ │ └── fail_thresholds_general.yaml │ │ ├── failure_detection.yaml │ │ └── failure_detection_with_thresholds.yaml │ ├── backbone │ │ ├── dynamic_resencunet.yaml │ │ ├── dynamic_resencunet_deepsup.yaml │ │ ├── dynamic_resencunet_dropout.yaml │ │ ├── dynamic_unet.yaml │ │ ├── dynamic_unet_deepsup.yaml │ │ ├── dynamic_unet_dropout.yaml │ │ ├── dynamic_wideunet.yaml │ │ ├── dynamic_wideunet_dropout.yaml │ │ ├── monai_unet.yaml │ │ └── monai_unet_dropout.yaml │ ├── callbacks │ │ ├── default.yaml │ │ ├── none.yaml │ │ ├── test │ │ │ ├── confidence_saver.yaml │ │ │ ├── dummy.yaml │ │ │ ├── ensemble_prediction_saver.yaml │ │ │ ├── prediction_saver.yaml │ │ │ └── results_saver.yaml │ │ ├── train │ │ │ ├── batch_logging.yaml │ │ │ ├── model_checkpoint.yaml │ │ │ ├── prediction_saver.yaml │ │ │ ├── quality_training_monitor.yaml │ │ │ └── save_val_predictions.yaml │ │ ├── train_seg │ │ │ ├── batch_logging.yaml │ │ │ ├── model_checkpoint.yaml │ │ │ ├── prediction_saver.yaml │ │ │ ├── quality_training_monitor.yaml │ │ │ └── save_val_predictions.yaml │ │ └── validate │ │ │ ├── batch_logging.yaml │ │ │ ├── confidence_saver.yaml │ │ │ └── prediction_saver.yaml │ ├── config.yaml │ ├── csf_aggregation │ │ ├── all_simple.yaml │ │ ├── heuristic.yaml │ │ ├── radiomics.yaml │ │ └── simple_aggs │ │ │ ├── distance_weighting.yaml │ │ │ ├── foreground.yaml │ │ │ ├── mean.yaml │ │ │ ├── only_non_boundary.yaml │ │ │ ├── pairwise_gen_dice.yaml │ │ │ ├── pairwise_mean_dice.yaml │ │ │ └── patch_based.yaml │ ├── csf_image │ │ ├── mahalanobis.yaml │ │ ├── mahalanobis_gonzalez.yaml │ │ ├── quality_regression.yaml │ │ ├── vae_image_and_mask.yaml │ │ ├── vae_image_only.yaml │ │ ├── vae_iterative_surrogate.yaml │ │ └── vae_mask_only.yaml │ ├── csf_pixel │ │ ├── baseline.yaml │ │ ├── deep_ensemble.yaml │ │ └── mcdropout.yaml │ ├── datamodule │ │ ├── acdc_nnunet.yaml │ │ ├── brats19_lhgg_nnunet.yaml │ │ ├── covid_nnunet.yaml │ │ ├── dummy.yaml │ │ ├── heuristic_radiomics.yaml │ │ ├── kits23_nnunet.yaml │ │ ├── mnms_nnunet.yaml │ │ ├── mvseg23_nnunet.yaml │ │ ├── nnunet.yaml │ │ ├── octa500_nnunet.yaml │ │ ├── prostate_nnunet.yaml │ │ ├── quality_regression.yaml │ │ ├── retina_nnunet.yaml │ │ ├── retouch_cirrus_nnunet.yaml │ │ ├── retouch_spectralis_nnunet.yaml │ │ ├── retouch_topcon_nnunet.yaml │ │ ├── simple_agg.yaml │ │ ├── simple_fets22_corrupted.yaml │ │ └── vae.yaml │ ├── dataset │ │ ├── abstract.yaml │ │ ├── acdc.yaml │ │ ├── brats19_lhgg.yaml │ │ ├── covid_gonzalez.yaml │ │ ├── kits23.yaml │ │ ├── mnms.yaml │ │ ├── mvseg23.yaml │ │ ├── octa500.yaml │ │ ├── prostate_gonzalez.yaml │ │ ├── retina.yaml │ │ ├── retouch_cirrus.yaml │ │ ├── retouch_spectralis.yaml │ │ ├── retouch_topcon.yaml │ │ └── simple_fets22_corrupted.yaml │ ├── debug │ │ ├── default.yaml │ │ ├── limit_batches.yaml │ │ ├── lrfind.yaml │ │ ├── overfit.yaml │ │ └── step.yaml │ ├── hydra │ │ ├── cluster.yaml │ │ ├── debug.yaml │ │ └── local.yaml │ ├── logger │ │ ├── csv.yaml │ │ ├── default.yaml │ │ └── tensorboard.yaml │ ├── paths │ │ └── default.yaml │ ├── segmentation │ │ ├── baseline.yaml │ │ └── dynunet.yaml │ └── trainer │ │ ├── cpu.yaml │ │ └── single_gpu.yaml │ ├── data │ ├── __init__.py │ ├── corruptions │ │ ├── corrupt_data_torchio.py │ │ └── image_corruptions_tio.py │ ├── datamodules │ │ ├── __init__.py │ │ ├── additional_readers.py │ │ ├── dummy_modules.py │ │ ├── monai_modules.py │ │ ├── nnunet_module.py │ │ ├── nnunet_utils.py │ │ ├── quality_regression.py │ │ ├── simple_agg.py │ │ └── vae.py │ └── dataset_conversion │ │ ├── nnunet_acdc.py │ │ ├── nnunet_brats19.py │ │ ├── nnunet_covid.py │ │ ├── nnunet_fets22.py │ │ ├── nnunet_kits23.py │ │ ├── nnunet_mnm.py │ │ ├── nnunet_mnm2.py │ │ ├── nnunet_mvseg23.py │ │ ├── nnunet_octa500.py │ │ ├── nnunet_prostate.py │ │ ├── nnunet_retina.py │ │ ├── nnunet_retouch.py │ │ └── nnunet_simple_fets_corruptions.py │ ├── evaluation │ ├── __init__.py │ ├── experiment_data.py │ ├── failure_detection │ │ ├── __init__.py │ │ ├── fd_analysis.py │ │ └── metrics.py │ ├── ood_detection │ │ ├── __init__.py │ │ ├── metrics.py │ │ └── ood_analysis.py │ └── segmentation │ │ ├── __init__.py │ │ ├── compute_seg_metrics.py │ │ ├── custom_metrics │ │ ├── hausdorff.py │ │ └── surface_distance.py │ │ ├── distance_thresholds.py │ │ └── segmentation_metrics.py │ ├── experiments │ ├── __init__.py │ ├── cluster.py │ ├── experiment.py │ ├── experiments_paper.sh │ ├── experiments_revision_arch.sh │ ├── experiments_revision_ds_size.sh │ ├── experiments_revision_newdata.sh │ ├── launcher.py │ ├── nnunet_cluster.py │ └── prepare_auxdata.py │ ├── models │ ├── __init__.py │ ├── confidence_aggregation │ │ ├── __init__.py │ │ ├── base.py │ │ ├── heuristic.py │ │ ├── radiomics.py │ │ └── simple_agg.py │ ├── image_confidence │ │ ├── __init__.py │ │ ├── mahalanobis.py │ │ ├── regression_network.py │ │ └── vae_estimator.py │ ├── pixel_confidence │ │ ├── __init__.py │ │ ├── ensemble.py │ │ ├── posthoc.py │ │ └── scores.py │ └── segmentation │ │ ├── __init__.py │ │ ├── dynunet_module.py │ │ └── monai_segmenter.py │ ├── networks │ ├── __init__.py │ ├── dynunet.py │ ├── nnunet │ │ └── __init__.py │ └── vae │ │ ├── __init__.py │ │ ├── encoder_decoder.py │ │ ├── utils.py │ │ └── vae.py │ ├── scripts │ ├── check_dataset_splits.py │ ├── evaluate_experiment.py │ ├── prepare_data_quality_regression.py │ ├── test_fd.py │ ├── test_pixel_csf.py │ ├── train_image_csf.py │ ├── train_seg.py │ └── validate_pixel_csf.py │ └── utils │ ├── __init__.py │ ├── checkpointing.py │ ├── config_handling.py │ ├── data.py │ ├── dice_bce_loss.py │ ├── feature_extraction.py │ ├── io.py │ ├── label_handling.py │ ├── network.py │ ├── view_images_napari.py │ └── visualization.py └── tests ├── callbacks └── test_results_writer.py ├── evaluation ├── test_ood_metrics.py └── test_seg_metrics.py └── models └── confidence_scoring ├── test_aggregation.py ├── test_mahalanobis.py └── test_posthoc.py /analysis/utils.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Tuple 3 | 4 | import pandas as pd 5 | from loguru import logger 6 | from omegaconf import DictConfig 7 | 8 | from segmentation_failures.evaluation.experiment_data import ExperimentData 9 | from segmentation_failures.utils.io import load_expt_config 10 | 11 | 12 | def load_raw_results(expt_dir): 13 | expt_data = ExperimentData.from_experiment_dir(Path(expt_dir)) 14 | expt_config = expt_data.config 15 | expt_df = expt_data.to_dataframe() 16 | expt_df["expt_name"] = expt_data.config.expt_name 17 | if "hparams" in expt_config.datamodule: 18 | expt_df["fold"] = expt_config.datamodule.hparams.fold 19 | else: 20 | expt_df["fold"] = expt_config.datamodule.fold 21 | expt_df["seed"] = expt_config.seed 22 | return expt_df 23 | 24 | 25 | def load_fd_results_hydra( 26 | expt_root_dir: Path, 27 | csv_name="fd_metrics.csv", 28 | legacy_structure=False, 29 | ) -> Tuple[pd.DataFrame, DictConfig]: 30 | # assumes expt_root_dir structure: $expt_root_dir/experiment_name/test_fd/ 31 | # TODO maybe select the csv file with highest count (e.g. fd_metrics_1.csv, fd_metrics_2.csv, ...) 32 | # because this is how it is saved when rerunning the evaluation 33 | all_results = [] 34 | all_configs = {} 35 | expt_id = 0 36 | for expt_dir in expt_root_dir.iterdir(): 37 | test_runs_root = expt_dir / "test_fd" 38 | if legacy_structure: 39 | test_runs_root = expt_dir 40 | if not test_runs_root.exists(): 41 | continue 42 | for run_version_dir in test_runs_root.iterdir(): 43 | logger.debug(run_version_dir) 44 | try: 45 | results, config = load_single_fd_result_hydra(run_version_dir, csv_name) 46 | except FileNotFoundError: 47 | logger.warning( 48 | f"Could not find results file {csv_name} in {run_version_dir}. Ignoring this!" 49 | ) 50 | else: 51 | results["expt_id"] = expt_id 52 | all_results.append(results) 53 | all_configs[expt_id] = config 54 | expt_id += 1 55 | # it could happen that we don't find the csv for this run, but we still want to count it, because there might be other csvs 56 | if len(all_results) > 0: 57 | return pd.concat(all_results, ignore_index=True), all_configs 58 | else: 59 | logger.warning(f"Couldn't find experiment runs here: {expt_root_dir}") 60 | return pd.DataFrame(), DictConfig({}) 61 | 62 | 63 | def load_single_fd_result_hydra(expt_dir: Path, csv_name: str): 64 | if not csv_name.lower().endswith(".csv"): 65 | csv_name += ".csv" 66 | # load run configuration 67 | config = load_expt_config(expt_dir, resolve=False) 68 | # I set the output dir manually here because we already know it (don't need hydra). 69 | # Otherwise, there can be issues when running on cluster/analysing on workstation 70 | config.paths.output_dir = str(expt_dir) 71 | if "analysis_dir" not in config.paths: 72 | raise RuntimeError( 73 | "No analysis_dir entry in config.paths found. Seems like the directory follows another storing convention." 74 | ) 75 | results_csv = Path(config.paths.analysis_dir) / csv_name 76 | if not results_csv.exists(): 77 | raise FileNotFoundError 78 | results = pd.read_csv(results_csv, index_col=0) 79 | results["root_dir"] = str(expt_dir.absolute()) 80 | return results, config 81 | -------------------------------------------------------------------------------- /assets/paper_table.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MIC-DKFZ/segmentation_failures_benchmark/a1af98be0f93c2bdc30ffe5bb6eda531c485d87d/assets/paper_table.png -------------------------------------------------------------------------------- /dataset_splits/Dataset503_BraTS19/domain_mapping_00.json: -------------------------------------------------------------------------------- 1 | { 2 | "BraTS19_CBICA_ALU_1": "HGG", 3 | "BraTS19_CBICA_ANP_1": "HGG", 4 | "BraTS19_CBICA_ANZ_1": "HGG", 5 | "BraTS19_CBICA_AOH_1": "HGG", 6 | "BraTS19_CBICA_APY_1": "HGG", 7 | "BraTS19_CBICA_AQZ_1": "HGG", 8 | "BraTS19_CBICA_ARF_1": "HGG", 9 | "BraTS19_CBICA_ARZ_1": "HGG", 10 | "BraTS19_CBICA_ASW_1": "HGG", 11 | "BraTS19_CBICA_ATD_1": "HGG", 12 | "BraTS19_CBICA_ATP_1": "HGG", 13 | "BraTS19_CBICA_ATV_1": "HGG", 14 | "BraTS19_CBICA_AUQ_1": "HGG", 15 | "BraTS19_CBICA_AWH_1": "HGG", 16 | "BraTS19_CBICA_AWI_1": "HGG", 17 | "BraTS19_CBICA_AXL_1": "HGG", 18 | "BraTS19_CBICA_AXO_1": "HGG", 19 | "BraTS19_CBICA_BFB_1": "HGG", 20 | "BraTS19_CBICA_AOS_1": "HGG", 21 | "BraTS19_CBICA_ASR_1": "HGG", 22 | "BraTS19_CBICA_ATN_1": "HGG", 23 | "BraTS19_CBICA_AUX_1": "HGG", 24 | "BraTS19_CBICA_AVB_1": "HGG", 25 | "BraTS19_CBICA_BGO_1": "HGG", 26 | "BraTS19_CBICA_BGT_1": "HGG", 27 | "BraTS19_2013_14_1": "HGG", 28 | "BraTS19_2013_23_1": "HGG", 29 | "BraTS19_TMC_06643_1": "HGG", 30 | "BraTS19_TMC_27374_1": "HGG", 31 | "BraTS19_TCIA01_201_1": "HGG", 32 | "BraTS19_TCIA01_378_1": "HGG", 33 | "BraTS19_TCIA01_401_1": "HGG", 34 | "BraTS19_TCIA01_235_1": "HGG", 35 | "BraTS19_TCIA01_131_1": "HGG", 36 | "BraTS19_TCIA01_203_1": "HGG", 37 | "BraTS19_TCIA01_221_1": "HGG", 38 | "BraTS19_TCIA01_411_1": "HGG", 39 | "BraTS19_TCIA02_491_1": "HGG", 40 | "BraTS19_TCIA02_331_1": "HGG", 41 | "BraTS19_TCIA02_321_1": "HGG", 42 | "BraTS19_TCIA02_430_1": "HGG", 43 | "BraTS19_TCIA02_374_1": "HGG", 44 | "BraTS19_TCIA02_370_1": "HGG", 45 | "BraTS19_TCIA02_135_1": "HGG", 46 | "BraTS19_TCIA03_121_1": "HGG", 47 | "BraTS19_TCIA03_498_1": "HGG", 48 | "BraTS19_TCIA03_133_1": "HGG", 49 | "BraTS19_TCIA05_396_1": "HGG", 50 | "BraTS19_TCIA06_165_1": "HGG", 51 | "BraTS19_TCIA06_332_1": "HGG", 52 | "BraTS19_2013_0_1": "LGG", 53 | "BraTS19_2013_6_1": "LGG", 54 | "BraTS19_2013_16_1": "LGG", 55 | "BraTS19_2013_15_1": "LGG", 56 | "BraTS19_2013_8_1": "LGG", 57 | "BraTS19_2013_24_1": "LGG", 58 | "BraTS19_2013_29_1": "LGG", 59 | "BraTS19_2013_9_1": "LGG", 60 | "BraTS19_TMC_09043_1": "LGG", 61 | "BraTS19_TCIA09_451_1": "LGG", 62 | "BraTS19_TCIA09_254_1": "LGG", 63 | "BraTS19_TCIA09_255_1": "LGG", 64 | "BraTS19_TCIA09_141_1": "LGG", 65 | "BraTS19_TCIA09_312_1": "LGG", 66 | "BraTS19_TCIA09_620_1": "LGG", 67 | "BraTS19_TCIA09_462_1": "LGG", 68 | "BraTS19_TCIA10_152_1": "LGG", 69 | "BraTS19_TCIA10_299_1": "LGG", 70 | "BraTS19_TCIA10_393_1": "LGG", 71 | "BraTS19_TCIA10_266_1": "LGG", 72 | "BraTS19_TCIA10_449_1": "LGG", 73 | "BraTS19_TCIA10_408_1": "LGG", 74 | "BraTS19_TCIA10_109_1": "LGG", 75 | "BraTS19_TCIA10_330_1": "LGG", 76 | "BraTS19_TCIA10_307_1": "LGG", 77 | "BraTS19_TCIA10_130_1": "LGG", 78 | "BraTS19_TCIA10_413_1": "LGG", 79 | "BraTS19_TCIA10_276_1": "LGG", 80 | "BraTS19_TCIA10_282_1": "LGG", 81 | "BraTS19_TCIA10_442_1": "LGG", 82 | "BraTS19_TCIA10_420_1": "LGG", 83 | "BraTS19_TCIA10_310_1": "LGG", 84 | "BraTS19_TCIA10_325_1": "LGG", 85 | "BraTS19_TCIA10_261_1": "LGG", 86 | "BraTS19_TCIA10_640_1": "LGG", 87 | "BraTS19_TCIA10_639_1": "LGG", 88 | "BraTS19_TCIA10_629_1": "LGG", 89 | "BraTS19_TCIA12_470_1": "LGG", 90 | "BraTS19_TCIA12_101_1": "LGG", 91 | "BraTS19_TCIA12_249_1": "LGG", 92 | "BraTS19_TCIA13_623_1": "LGG", 93 | "BraTS19_TCIA13_650_1": "LGG", 94 | "BraTS19_TCIA13_654_1": "LGG", 95 | "BraTS19_TCIA13_633_1": "LGG", 96 | "BraTS19_TCIA13_645_1": "LGG", 97 | "BraTS19_TCIA13_630_1": "LGG", 98 | "BraTS19_TCIA13_621_1": "LGG", 99 | "BraTS19_TCIA13_642_1": "LGG", 100 | "BraTS19_TCIA13_624_1": "LGG", 101 | "BraTS19_TCIA13_634_1": "LGG" 102 | } 103 | -------------------------------------------------------------------------------- /dataset_splits/Dataset514_MVSeg23/domain_mapping_00.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_044": "ID", 3 | "train_079": "ID", 4 | "train_053": "ID", 5 | "train_088": "ID", 6 | "val_014": "ID", 7 | "train_004": "ID", 8 | "train_039": "ID", 9 | "train_025": "ID", 10 | "val_023": "ID", 11 | "val_021": "ID", 12 | "train_064": "ID", 13 | "train_082": "ID", 14 | "train_032": "ID", 15 | "train_078": "ID", 16 | "train_048": "ID", 17 | "train_066": "ID", 18 | "train_034": "ID", 19 | "train_045": "ID", 20 | "train_086": "ID", 21 | "train_063": "ID", 22 | "train_060": "ID", 23 | "train_099": "ID", 24 | "val_012": "ID", 25 | "train_089": "ID", 26 | "train_014": "ID", 27 | "train_084": "ID", 28 | "train_100": "ID", 29 | "train_040": "ID", 30 | "val_030": "ID", 31 | "val_020": "ID", 32 | "train_012": "ID", 33 | "train_027": "ID", 34 | "train_057": "ID", 35 | "train_098": "ID", 36 | "train_065": "ID", 37 | "train_067": "ID", 38 | "train_035": "ID", 39 | "train_073": "ID", 40 | "train_020": "ID", 41 | "train_041": "ID", 42 | "val_007": "ID", 43 | "val_008": "ID", 44 | "train_030": "ID", 45 | "val_016": "ID", 46 | "train_093": "ID", 47 | "val_015": "ID", 48 | "train_015": "ID", 49 | "train_092": "ID", 50 | "train_102": "ID", 51 | "train_031": "ID", 52 | "train_026": "ID", 53 | "train_008": "ID", 54 | "train_077": "ID", 55 | "train_070": "ID", 56 | "val_009": "ID" 57 | } 58 | -------------------------------------------------------------------------------- /dataset_splits/Dataset515_KiTS23/domain_mapping_00.json: -------------------------------------------------------------------------------- 1 | { 2 | "case_00566": "ID", 3 | "case_00190": "ID", 4 | "case_00413": "ID", 5 | "case_00137": "ID", 6 | "case_00565": "ID", 7 | "case_00174": "ID", 8 | "case_00063": "ID", 9 | "case_00255": "ID", 10 | "case_00125": "ID", 11 | "case_00197": "ID", 12 | "case_00163": "ID", 13 | "case_00170": "ID", 14 | "case_00537": "ID", 15 | "case_00153": "ID", 16 | "case_00184": "ID", 17 | "case_00497": "ID", 18 | "case_00525": "ID", 19 | "case_00447": "ID", 20 | "case_00058": "ID", 21 | "case_00576": "ID", 22 | "case_00270": "ID", 23 | "case_00448": "ID", 24 | "case_00264": "ID", 25 | "case_00512": "ID", 26 | "case_00483": "ID", 27 | "case_00403": "ID", 28 | "case_00051": "ID", 29 | "case_00194": "ID", 30 | "case_00457": "ID", 31 | "case_00424": "ID", 32 | "case_00579": "ID", 33 | "case_00456": "ID", 34 | "case_00498": "ID", 35 | "case_00446": "ID", 36 | "case_00294": "ID", 37 | "case_00031": "ID", 38 | "case_00441": "ID", 39 | "case_00425": "ID", 40 | "case_00097": "ID", 41 | "case_00575": "ID", 42 | "case_00234": "ID", 43 | "case_00087": "ID", 44 | "case_00064": "ID", 45 | "case_00204": "ID", 46 | "case_00104": "ID", 47 | "case_00009": "ID", 48 | "case_00117": "ID", 49 | "case_00442": "ID", 50 | "case_00521": "ID", 51 | "case_00121": "ID", 52 | "case_00287": "ID", 53 | "case_00196": "ID", 54 | "case_00459": "ID", 55 | "case_00211": "ID", 56 | "case_00266": "ID", 57 | "case_00022": "ID", 58 | "case_00573": "ID", 59 | "case_00143": "ID", 60 | "case_00469": "ID", 61 | "case_00148": "ID", 62 | "case_00527": "ID", 63 | "case_00154": "ID", 64 | "case_00405": "ID", 65 | "case_00123": "ID", 66 | "case_00508": "ID", 67 | "case_00438": "ID", 68 | "case_00076": "ID", 69 | "case_00096": "ID", 70 | "case_00219": "ID", 71 | "case_00149": "ID", 72 | "case_00428": "ID", 73 | "case_00437": "ID", 74 | "case_00001": "ID", 75 | "case_00550": "ID", 76 | "case_00260": "ID", 77 | "case_00012": "ID", 78 | "case_00435": "ID", 79 | "case_00247": "ID", 80 | "case_00103": "ID", 81 | "case_00275": "ID", 82 | "case_00066": "ID", 83 | "case_00056": "ID", 84 | "case_00155": "ID", 85 | "case_00499": "ID", 86 | "case_00161": "ID", 87 | "case_00280": "ID", 88 | "case_00585": "ID", 89 | "case_00167": "ID", 90 | "case_00233": "ID", 91 | "case_00408": "ID", 92 | "case_00470": "ID", 93 | "case_00250": "ID", 94 | "case_00244": "ID", 95 | "case_00182": "ID", 96 | "case_00549": "ID", 97 | "case_00222": "ID", 98 | "case_00050": "ID", 99 | "case_00151": "ID", 100 | "case_00000": "ID", 101 | "case_00491": "ID", 102 | "case_00246": "ID", 103 | "case_00503": "ID", 104 | "case_00502": "ID", 105 | "case_00493": "ID", 106 | "case_00518": "ID", 107 | "case_00559": "ID", 108 | "case_00431": "ID", 109 | "case_00242": "ID", 110 | "case_00540": "ID", 111 | "case_00042": "ID", 112 | "case_00047": "ID", 113 | "case_00513": "ID", 114 | "case_00003": "ID", 115 | "case_00093": "ID", 116 | "case_00078": "ID", 117 | "case_00173": "ID", 118 | "case_00422": "ID", 119 | "case_00231": "ID", 120 | "case_00400": "ID", 121 | "case_00463": "ID", 122 | "case_00189": "ID", 123 | "case_00243": "ID" 124 | } 125 | -------------------------------------------------------------------------------- /dataset_splits/Dataset520_CovidLungCT/domain_mapping_00.json: -------------------------------------------------------------------------------- 1 | { 2 | "challenge_volume-covid19-A-0329": "ID", 3 | "challenge_volume-covid19-A-0247": "ID", 4 | "challenge_volume-covid19-A-0498": "ID", 5 | "challenge_volume-covid19-A-0576": "ID", 6 | "coronacases_007": "radiopaedia", 7 | "mosmed_study_0256": "mosmed", 8 | "coronacases_008": "radiopaedia", 9 | "mosmed_study_0277": "mosmed", 10 | "mosmed_study_0275": "mosmed", 11 | "mosmed_study_0263": "mosmed", 12 | "mosmed_study_0281": "mosmed", 13 | "mosmed_study_0259": "mosmed", 14 | "challenge_volume-covid19-A-0199": "ID", 15 | "challenge_volume-covid19-A-0599": "ID", 16 | "mosmed_study_0267": "mosmed", 17 | "radiopaedia_40_86625_0": "radiopaedia", 18 | "mosmed_study_0270": "mosmed", 19 | "mosmed_study_0303": "mosmed", 20 | "mosmed_study_0288": "mosmed", 21 | "radiopaedia_7_85703_0": "radiopaedia", 22 | "challenge_volume-covid19-A-0585": "ID", 23 | "mosmed_study_0287": "mosmed", 24 | "coronacases_003": "radiopaedia", 25 | "challenge_volume-covid19-A-0388": "ID", 26 | "challenge_volume-covid19-A-0665": "ID", 27 | "mosmed_study_0289": "mosmed", 28 | "mosmed_study_0273": "mosmed", 29 | "coronacases_006": "radiopaedia", 30 | "mosmed_study_0300": "mosmed", 31 | "mosmed_study_0278": "mosmed", 32 | "challenge_volume-covid19-A-0522": "ID", 33 | "radiopaedia_4_85506_1": "radiopaedia", 34 | "challenge_volume-covid19-A-0263": "ID", 35 | "mosmed_study_0291": "mosmed", 36 | "radiopaedia_29_86491_1": "radiopaedia", 37 | "mosmed_study_0302": "mosmed", 38 | "challenge_volume-covid19-A-0473": "ID", 39 | "mosmed_study_0274": "mosmed", 40 | "challenge_volume-covid19-A-0636": "ID", 41 | "coronacases_009": "radiopaedia", 42 | "mosmed_study_0285": "mosmed", 43 | "challenge_volume-covid19-A-0096": "ID", 44 | "coronacases_002": "radiopaedia", 45 | "radiopaedia_29_86490_1": "radiopaedia", 46 | "mosmed_study_0297": "mosmed", 47 | "challenge_volume-covid19-A-0034": "ID", 48 | "mosmed_study_0266": "mosmed", 49 | "radiopaedia_36_86526_0": "radiopaedia", 50 | "mosmed_study_0279": "mosmed", 51 | "mosmed_study_0276": "mosmed", 52 | "challenge_volume-covid19-A-0307": "ID", 53 | "mosmed_study_0271": "mosmed", 54 | "mosmed_study_0255": "mosmed", 55 | "mosmed_study_0292": "mosmed", 56 | "coronacases_001": "radiopaedia", 57 | "mosmed_study_0264": "mosmed", 58 | "mosmed_study_0261": "mosmed", 59 | "challenge_volume-covid19-A-0164": "ID", 60 | "challenge_volume-covid19-A-0167_1": "ID", 61 | "radiopaedia_10_85902_3": "radiopaedia", 62 | "challenge_volume-covid19-A-0355": "ID", 63 | "mosmed_study_0301": "mosmed", 64 | "challenge_volume-covid19-A-0187": "ID", 65 | "challenge_volume-covid19-A-0360": "ID", 66 | "coronacases_004": "radiopaedia", 67 | "challenge_volume-covid19-A-0626": "ID", 68 | "challenge_volume-covid19-A-0031": "ID", 69 | "coronacases_010": "radiopaedia", 70 | "challenge_volume-covid19-A-0070": "ID", 71 | "coronacases_005": "radiopaedia", 72 | "mosmed_study_0280": "mosmed", 73 | "mosmed_study_0257": "mosmed", 74 | "challenge_volume-covid19-A-0112": "ID", 75 | "mosmed_study_0284": "mosmed", 76 | "mosmed_study_0262": "mosmed", 77 | "challenge_volume-covid19-A-0267": "ID", 78 | "challenge_volume-covid19-A-0627": "ID", 79 | "radiopaedia_14_85914_0": "radiopaedia", 80 | "mosmed_study_0294": "mosmed", 81 | "challenge_volume-covid19-A-0463": "ID", 82 | "radiopaedia_10_85902_1": "radiopaedia", 83 | "mosmed_study_0286": "mosmed", 84 | "mosmed_study_0260": "mosmed", 85 | "challenge_volume-covid19-A-0114": "ID", 86 | "mosmed_study_0299": "mosmed", 87 | "mosmed_study_0268": "mosmed", 88 | "mosmed_study_0296": "mosmed", 89 | "mosmed_study_0293": "mosmed", 90 | "challenge_volume-covid19-A-0623": "ID", 91 | "challenge_volume-covid19-A-0319": "ID", 92 | "mosmed_study_0269": "mosmed", 93 | "challenge_volume-covid19-A-0251": "ID", 94 | "challenge_volume-covid19-A-0567": "ID", 95 | "radiopaedia_27_86410_0": "radiopaedia", 96 | "mosmed_study_0304": "mosmed", 97 | "challenge_volume-covid19-A-0421": "ID", 98 | "challenge_volume-covid19-A-0320": "ID", 99 | "mosmed_study_0265": "mosmed", 100 | "mosmed_study_0272": "mosmed", 101 | "mosmed_study_0295": "mosmed", 102 | "mosmed_study_0282": "mosmed", 103 | "challenge_volume-covid19-A-0046": "ID", 104 | "mosmed_study_0298": "mosmed", 105 | "mosmed_study_0290": "mosmed", 106 | "mosmed_study_0283": "mosmed", 107 | "challenge_volume-covid19-A-0537": "ID", 108 | "challenge_volume-covid19-A-0237": "ID", 109 | "challenge_volume-covid19-A-0366": "ID", 110 | "mosmed_study_0258": "mosmed" 111 | } 112 | -------------------------------------------------------------------------------- /dataset_splits/Dataset521_ProstateGonzalez/domain_mapping_00.json: -------------------------------------------------------------------------------- 1 | { 2 | "I2CVB_Case14": "I2CVB", 3 | "BMC_Case05": "BMC", 4 | "UCL_Case28": "UCL", 5 | "I2CVB_Case12": "I2CVB", 6 | "HK_Case46": "HK", 7 | "BMC_Case01": "BMC", 8 | "ID_prostate_16": "ID", 9 | "I2CVB_Case17": "I2CVB", 10 | "UCL_Case26": "UCL", 11 | "BMC_Case28": "BMC", 12 | "I2CVB_Case13": "I2CVB", 13 | "HK_Case48": "HK", 14 | "ID_prostate_07": "ID", 15 | "I2CVB_Case02": "I2CVB", 16 | "BIDMC_Case09": "BIDMC", 17 | "I2CVB_Case18": "I2CVB", 18 | "BIDMC_Case05": "BIDMC", 19 | "I2CVB_Case00": "I2CVB", 20 | "UCL_Case01": "UCL", 21 | "BIDMC_Case11": "BIDMC", 22 | "BMC_Case13": "BMC", 23 | "BIDMC_Case08": "BIDMC", 24 | "BMC_Case24": "BMC", 25 | "UCL_Case35": "UCL", 26 | "ID_prostate_40": "ID", 27 | "UCL_Case37": "UCL", 28 | "HK_Case44": "HK", 29 | "HK_Case45": "HK", 30 | "HK_Case47": "HK", 31 | "HK_Case49": "HK", 32 | "ID_prostate_02": "ID", 33 | "I2CVB_Case04": "I2CVB", 34 | "UCL_Case30": "UCL", 35 | "BIDMC_Case06": "BIDMC", 36 | "HK_Case42": "HK", 37 | "I2CVB_Case07": "I2CVB", 38 | "BIDMC_Case04": "BIDMC", 39 | "UCL_Case36": "UCL", 40 | "BMC_Case23": "BMC", 41 | "BMC_Case20": "BMC", 42 | "BMC_Case00": "BMC", 43 | "BIDMC_Case07": "BIDMC", 44 | "BMC_Case17": "BMC", 45 | "BMC_Case10": "BMC", 46 | "BMC_Case29": "BMC", 47 | "BMC_Case14": "BMC", 48 | "UCL_Case33": "UCL", 49 | "BMC_Case02": "BMC", 50 | "HK_Case38": "HK", 51 | "BMC_Case22": "BMC", 52 | "HK_Case40": "HK", 53 | "I2CVB_Case11": "I2CVB", 54 | "BMC_Case08": "BMC", 55 | "BMC_Case21": "BMC", 56 | "BIDMC_Case03": "BIDMC", 57 | "HK_Case39": "HK", 58 | "BIDMC_Case12": "BIDMC", 59 | "UCL_Case27": "UCL", 60 | "HK_Case41": "HK", 61 | "I2CVB_Case01": "I2CVB", 62 | "BMC_Case27": "BMC", 63 | "BMC_Case26": "BMC", 64 | "BIDMC_Case02": "BIDMC", 65 | "UCL_Case31": "UCL", 66 | "BMC_Case04": "BMC", 67 | "BIDMC_Case00": "BIDMC", 68 | "I2CVB_Case06": "I2CVB", 69 | "I2CVB_Case05": "I2CVB", 70 | "I2CVB_Case09": "I2CVB", 71 | "BMC_Case18": "BMC", 72 | "I2CVB_Case16": "I2CVB", 73 | "BIDMC_Case10": "BIDMC", 74 | "I2CVB_Case15": "I2CVB", 75 | "ID_prostate_35": "ID", 76 | "BMC_Case11": "BMC", 77 | "BMC_Case03": "BMC", 78 | "UCL_Case32": "UCL", 79 | "UCL_Case34": "UCL", 80 | "BMC_Case25": "BMC", 81 | "HK_Case43": "HK", 82 | "I2CVB_Case03": "I2CVB", 83 | "I2CVB_Case08": "I2CVB", 84 | "BMC_Case06": "BMC", 85 | "BMC_Case12": "BMC", 86 | "ID_prostate_00": "ID", 87 | "BMC_Case16": "BMC", 88 | "BMC_Case09": "BMC", 89 | "I2CVB_Case10": "I2CVB", 90 | "BMC_Case15": "BMC", 91 | "BMC_Case19": "BMC", 92 | "UCL_Case29": "UCL", 93 | "BMC_Case07": "BMC" 94 | } 95 | -------------------------------------------------------------------------------- /dataset_splits/Dataset521_ProstateGonzalez/splits_final.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "train": [ 4 | "prostate_01", 5 | "prostate_04", 6 | "prostate_06", 7 | "prostate_13", 8 | "prostate_14", 9 | "prostate_17", 10 | "prostate_18", 11 | "prostate_21", 12 | "prostate_24", 13 | "prostate_25", 14 | "prostate_28", 15 | "prostate_31", 16 | "prostate_32", 17 | "prostate_37", 18 | "prostate_38", 19 | "prostate_39", 20 | "prostate_42", 21 | "prostate_44", 22 | "prostate_46", 23 | "prostate_47" 24 | ], 25 | "val": [ 26 | "prostate_10", 27 | "prostate_20", 28 | "prostate_29", 29 | "prostate_34", 30 | "prostate_41", 31 | "prostate_43" 32 | ] 33 | }, 34 | { 35 | "train": [ 36 | "prostate_01", 37 | "prostate_04", 38 | "prostate_06", 39 | "prostate_10", 40 | "prostate_13", 41 | "prostate_14", 42 | "prostate_17", 43 | "prostate_18", 44 | "prostate_20", 45 | "prostate_21", 46 | "prostate_25", 47 | "prostate_29", 48 | "prostate_31", 49 | "prostate_34", 50 | "prostate_37", 51 | "prostate_41", 52 | "prostate_42", 53 | "prostate_43", 54 | "prostate_44", 55 | "prostate_46", 56 | "prostate_47" 57 | ], 58 | "val": [ 59 | "prostate_24", 60 | "prostate_28", 61 | "prostate_32", 62 | "prostate_38", 63 | "prostate_39" 64 | ] 65 | }, 66 | { 67 | "train": [ 68 | "prostate_04", 69 | "prostate_06", 70 | "prostate_10", 71 | "prostate_13", 72 | "prostate_14", 73 | "prostate_20", 74 | "prostate_21", 75 | "prostate_24", 76 | "prostate_28", 77 | "prostate_29", 78 | "prostate_31", 79 | "prostate_32", 80 | "prostate_34", 81 | "prostate_37", 82 | "prostate_38", 83 | "prostate_39", 84 | "prostate_41", 85 | "prostate_43", 86 | "prostate_44", 87 | "prostate_46", 88 | "prostate_47" 89 | ], 90 | "val": [ 91 | "prostate_01", 92 | "prostate_17", 93 | "prostate_18", 94 | "prostate_25", 95 | "prostate_42" 96 | ] 97 | }, 98 | { 99 | "train": [ 100 | "prostate_01", 101 | "prostate_04", 102 | "prostate_06", 103 | "prostate_10", 104 | "prostate_13", 105 | "prostate_14", 106 | "prostate_17", 107 | "prostate_18", 108 | "prostate_20", 109 | "prostate_21", 110 | "prostate_24", 111 | "prostate_25", 112 | "prostate_28", 113 | "prostate_29", 114 | "prostate_32", 115 | "prostate_34", 116 | "prostate_38", 117 | "prostate_39", 118 | "prostate_41", 119 | "prostate_42", 120 | "prostate_43" 121 | ], 122 | "val": [ 123 | "prostate_31", 124 | "prostate_37", 125 | "prostate_44", 126 | "prostate_46", 127 | "prostate_47" 128 | ] 129 | }, 130 | { 131 | "train": [ 132 | "prostate_01", 133 | "prostate_10", 134 | "prostate_17", 135 | "prostate_18", 136 | "prostate_20", 137 | "prostate_24", 138 | "prostate_25", 139 | "prostate_28", 140 | "prostate_29", 141 | "prostate_31", 142 | "prostate_32", 143 | "prostate_34", 144 | "prostate_37", 145 | "prostate_38", 146 | "prostate_39", 147 | "prostate_41", 148 | "prostate_42", 149 | "prostate_43", 150 | "prostate_44", 151 | "prostate_46", 152 | "prostate_47" 153 | ], 154 | "val": [ 155 | "prostate_04", 156 | "prostate_06", 157 | "prostate_13", 158 | "prostate_14", 159 | "prostate_21" 160 | ] 161 | } 162 | ] 163 | -------------------------------------------------------------------------------- /dataset_splits/Dataset540_RETOUCH_ood=Cirrus/domain_mapping_00.json: -------------------------------------------------------------------------------- 1 | { 2 | "TRAIN014": "Cirrus", 3 | "TRAIN015": "Cirrus", 4 | "TRAIN005": "Cirrus", 5 | "TRAIN002": "Cirrus", 6 | "TRAIN010": "Cirrus", 7 | "TRAIN001": "Cirrus", 8 | "TRAIN022": "Cirrus", 9 | "TRAIN016": "Cirrus", 10 | "TRAIN017": "Cirrus", 11 | "TRAIN021": "Cirrus", 12 | "TRAIN018": "Cirrus", 13 | "TRAIN012": "Cirrus", 14 | "TRAIN009": "Cirrus", 15 | "TRAIN003": "Cirrus", 16 | "TRAIN008": "Cirrus", 17 | "TRAIN006": "Cirrus", 18 | "TRAIN019": "Cirrus", 19 | "TRAIN007": "Cirrus", 20 | "TRAIN023": "Cirrus", 21 | "TRAIN004": "Cirrus", 22 | "TRAIN011": "Cirrus", 23 | "TRAIN020": "Cirrus", 24 | "TRAIN024": "Cirrus", 25 | "TRAIN013": "Cirrus", 26 | "TRAIN045": "Spectralis", 27 | "TRAIN043": "Spectralis", 28 | "TRAIN032": "Spectralis", 29 | "TRAIN035": "Spectralis", 30 | "TRAIN040": "Spectralis", 31 | "TRAIN039": "Spectralis", 32 | "TRAIN063": "Topcon", 33 | "TRAIN049": "Topcon", 34 | "TRAIN068": "Topcon", 35 | "TRAIN056": "Topcon", 36 | "TRAIN058": "Topcon", 37 | "TRAIN051": "Topcon" 38 | } 39 | -------------------------------------------------------------------------------- /dataset_splits/Dataset540_RETOUCH_ood=Cirrus/splits_final.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "train": [ 4 | "TRAIN026", 5 | "TRAIN027", 6 | "TRAIN029", 7 | "TRAIN030", 8 | "TRAIN031", 9 | "TRAIN033", 10 | "TRAIN034", 11 | "TRAIN036", 12 | "TRAIN037", 13 | "TRAIN038", 14 | "TRAIN041", 15 | "TRAIN042", 16 | "TRAIN044", 17 | "TRAIN047", 18 | "TRAIN048", 19 | "TRAIN050", 20 | "TRAIN052", 21 | "TRAIN053", 22 | "TRAIN054", 23 | "TRAIN055", 24 | "TRAIN057", 25 | "TRAIN061", 26 | "TRAIN062", 27 | "TRAIN064", 28 | "TRAIN065", 29 | "TRAIN066", 30 | "TRAIN069" 31 | ], 32 | "val": [ 33 | "TRAIN025", 34 | "TRAIN028", 35 | "TRAIN046", 36 | "TRAIN059", 37 | "TRAIN060", 38 | "TRAIN067", 39 | "TRAIN070" 40 | ] 41 | }, 42 | { 43 | "train": [ 44 | "TRAIN025", 45 | "TRAIN026", 46 | "TRAIN027", 47 | "TRAIN028", 48 | "TRAIN029", 49 | "TRAIN030", 50 | "TRAIN036", 51 | "TRAIN037", 52 | "TRAIN038", 53 | "TRAIN042", 54 | "TRAIN044", 55 | "TRAIN046", 56 | "TRAIN047", 57 | "TRAIN048", 58 | "TRAIN050", 59 | "TRAIN055", 60 | "TRAIN057", 61 | "TRAIN059", 62 | "TRAIN060", 63 | "TRAIN061", 64 | "TRAIN062", 65 | "TRAIN064", 66 | "TRAIN065", 67 | "TRAIN066", 68 | "TRAIN067", 69 | "TRAIN069", 70 | "TRAIN070" 71 | ], 72 | "val": [ 73 | "TRAIN031", 74 | "TRAIN033", 75 | "TRAIN034", 76 | "TRAIN041", 77 | "TRAIN052", 78 | "TRAIN053", 79 | "TRAIN054" 80 | ] 81 | }, 82 | { 83 | "train": [ 84 | "TRAIN025", 85 | "TRAIN026", 86 | "TRAIN027", 87 | "TRAIN028", 88 | "TRAIN029", 89 | "TRAIN030", 90 | "TRAIN031", 91 | "TRAIN033", 92 | "TRAIN034", 93 | "TRAIN036", 94 | "TRAIN041", 95 | "TRAIN044", 96 | "TRAIN046", 97 | "TRAIN047", 98 | "TRAIN048", 99 | "TRAIN052", 100 | "TRAIN053", 101 | "TRAIN054", 102 | "TRAIN055", 103 | "TRAIN059", 104 | "TRAIN060", 105 | "TRAIN061", 106 | "TRAIN065", 107 | "TRAIN066", 108 | "TRAIN067", 109 | "TRAIN069", 110 | "TRAIN070" 111 | ], 112 | "val": [ 113 | "TRAIN037", 114 | "TRAIN038", 115 | "TRAIN042", 116 | "TRAIN050", 117 | "TRAIN057", 118 | "TRAIN062", 119 | "TRAIN064" 120 | ] 121 | }, 122 | { 123 | "train": [ 124 | "TRAIN025", 125 | "TRAIN026", 126 | "TRAIN027", 127 | "TRAIN028", 128 | "TRAIN029", 129 | "TRAIN030", 130 | "TRAIN031", 131 | "TRAIN033", 132 | "TRAIN034", 133 | "TRAIN036", 134 | "TRAIN037", 135 | "TRAIN038", 136 | "TRAIN041", 137 | "TRAIN042", 138 | "TRAIN046", 139 | "TRAIN050", 140 | "TRAIN052", 141 | "TRAIN053", 142 | "TRAIN054", 143 | "TRAIN057", 144 | "TRAIN059", 145 | "TRAIN060", 146 | "TRAIN062", 147 | "TRAIN064", 148 | "TRAIN065", 149 | "TRAIN067", 150 | "TRAIN070" 151 | ], 152 | "val": [ 153 | "TRAIN044", 154 | "TRAIN047", 155 | "TRAIN048", 156 | "TRAIN055", 157 | "TRAIN061", 158 | "TRAIN066", 159 | "TRAIN069" 160 | ] 161 | }, 162 | { 163 | "train": [ 164 | "TRAIN025", 165 | "TRAIN028", 166 | "TRAIN031", 167 | "TRAIN033", 168 | "TRAIN034", 169 | "TRAIN037", 170 | "TRAIN038", 171 | "TRAIN041", 172 | "TRAIN042", 173 | "TRAIN044", 174 | "TRAIN046", 175 | "TRAIN047", 176 | "TRAIN048", 177 | "TRAIN050", 178 | "TRAIN052", 179 | "TRAIN053", 180 | "TRAIN054", 181 | "TRAIN055", 182 | "TRAIN057", 183 | "TRAIN059", 184 | "TRAIN060", 185 | "TRAIN061", 186 | "TRAIN062", 187 | "TRAIN064", 188 | "TRAIN066", 189 | "TRAIN067", 190 | "TRAIN069", 191 | "TRAIN070" 192 | ], 193 | "val": [ 194 | "TRAIN026", 195 | "TRAIN027", 196 | "TRAIN029", 197 | "TRAIN030", 198 | "TRAIN036", 199 | "TRAIN065" 200 | ] 201 | } 202 | ] 203 | -------------------------------------------------------------------------------- /dataset_splits/Dataset560_OCTA500_pathology_split/domain_mapping_00.json: -------------------------------------------------------------------------------- 1 | { 2 | "10065": "OTHERS", 3 | "10226": "OTHERS", 4 | "10150": "NORMAL", 5 | "10145": "NORMAL", 6 | "10248": "AMD", 7 | "10180": "OTHERS", 8 | "10064": "OTHERS", 9 | "10225": "OTHERS", 10 | "10008": "CNV", 11 | "10011": "OTHERS", 12 | "10212": "OTHERS", 13 | "10070": "CSC", 14 | "10235": "DR", 15 | "10162": "OTHERS", 16 | "10144": "OTHERS", 17 | "10219": "OTHERS", 18 | "10241": "AMD", 19 | "10002": "OTHERS", 20 | "10097": "OTHERS", 21 | "10289": "RVO", 22 | "10142": "OTHERS", 23 | "10297": "DR", 24 | "10090": "OTHERS", 25 | "10239": "CSC", 26 | "10027": "RVO", 27 | "10269": "OTHERS", 28 | "10181": "DR", 29 | "10109": "OTHERS", 30 | "10063": "DR", 31 | "10135": "DR", 32 | "10223": "OTHERS", 33 | "10098": "CSC", 34 | "10216": "DR", 35 | "10093": "OTHERS", 36 | "10292": "OTHERS", 37 | "10132": "OTHERS", 38 | "10156": "NORMAL", 39 | "10165": "OTHERS", 40 | "10116": "OTHERS", 41 | "10028": "OTHERS", 42 | "10129": "OTHERS", 43 | "10249": "NORMAL", 44 | "10024": "DR", 45 | "10005": "DR", 46 | "10074": "RVO", 47 | "10277": "OTHERS", 48 | "10147": "CSC", 49 | "10151": "DR", 50 | "10158": "RVO", 51 | "10054": "DR", 52 | "10007": "NORMAL", 53 | "10189": "DR", 54 | "10190": "OTHERS", 55 | "10266": "OTHERS", 56 | "10202": "OTHERS", 57 | "10133": "OTHERS", 58 | "10119": "OTHERS", 59 | "10174": "NORMAL", 60 | "10081": "OTHERS", 61 | "10214": "NORMAL", 62 | "10293": "CSC", 63 | "10230": "OTHERS", 64 | "10032": "DR", 65 | "10242": "NORMAL", 66 | "10175": "OTHERS", 67 | "10294": "RVO", 68 | "10192": "OTHERS", 69 | "10006": "OTHERS", 70 | "10288": "AMD", 71 | "10059": "OTHERS", 72 | "10264": "CSC", 73 | "10203": "DR", 74 | "10157": "DR", 75 | "10131": "OTHERS", 76 | "10069": "OTHERS", 77 | "10258": "OTHERS", 78 | "10136": "OTHERS", 79 | "10184": "DR", 80 | "10155": "DR", 81 | "10208": "OTHERS", 82 | "10300": "DR", 83 | "10218": "OTHERS", 84 | "10099": "OTHERS", 85 | "10167": "NORMAL", 86 | "10094": "CSC", 87 | "10204": "OTHERS", 88 | "10075": "CSC", 89 | "10057": "OTHERS", 90 | "10013": "NORMAL", 91 | "10272": "NORMAL", 92 | "10056": "DR", 93 | "10209": "OTHERS", 94 | "10023": "OTHERS", 95 | "10103": "OTHERS", 96 | "10247": "OTHERS", 97 | "10073": "OTHERS", 98 | "10275": "OTHERS", 99 | "10274": "CNV", 100 | "10003": "OTHERS", 101 | "10077": "DR", 102 | "10051": "OTHERS", 103 | "10280": "CSC", 104 | "10260": "CSC", 105 | "10186": "OTHERS", 106 | "10195": "OTHERS", 107 | "10089": "OTHERS", 108 | "10221": "DR", 109 | "10298": "OTHERS", 110 | "10263": "DR", 111 | "10296": "OTHERS", 112 | "10062": "OTHERS", 113 | "10199": "OTHERS", 114 | "10236": "OTHERS", 115 | "10206": "CSC", 116 | "10066": "DR", 117 | "10205": "RVO", 118 | "10282": "NORMAL", 119 | "10092": "RVO", 120 | "10281": "DR", 121 | "10177": "OTHERS", 122 | "10233": "OTHERS", 123 | "10193": "OTHERS", 124 | "10210": "OTHERS", 125 | "10091": "DR", 126 | "10211": "OTHERS", 127 | "10169": "RVO", 128 | "10068": "DR", 129 | "10245": "AMD", 130 | "10111": "DR", 131 | "10041": "AMD", 132 | "10262": "DR", 133 | "10250": "OTHERS", 134 | "10217": "NORMAL", 135 | "10130": "CSC", 136 | "10238": "OTHERS", 137 | "10016": "NORMAL", 138 | "10139": "NORMAL", 139 | "10198": "OTHERS", 140 | "10030": "OTHERS", 141 | "10148": "RVO", 142 | "10244": "AMD", 143 | "10038": "NORMAL", 144 | "10286": "DR", 145 | "10265": "OTHERS", 146 | "10237": "OTHERS", 147 | "10113": "OTHERS", 148 | "10259": "OTHERS", 149 | "10050": "DR", 150 | "10022": "OTHERS", 151 | "10040": "DR", 152 | "10196": "DR", 153 | "10106": "OTHERS", 154 | "10046": "CSC", 155 | "10253": "CSC", 156 | "10020": "OTHERS", 157 | "10019": "OTHERS", 158 | "10240": "OTHERS", 159 | "10083": "NORMAL", 160 | "10012": "OTHERS", 161 | "10187": "OTHERS", 162 | "10114": "AMD", 163 | "10261": "DR", 164 | "10015": "OTHERS", 165 | "10299": "RVO", 166 | "10105": "OTHERS", 167 | "10295": "OTHERS", 168 | "10039": "NORMAL", 169 | "10164": "OTHERS", 170 | "10268": "DR", 171 | "10086": "DR", 172 | "10256": "OTHERS", 173 | "10072": "OTHERS", 174 | "10125": "OTHERS", 175 | "10067": "OTHERS", 176 | "10031": "OTHERS", 177 | "10152": "DR", 178 | "10128": "OTHERS", 179 | "10138": "AMD", 180 | "10134": "OTHERS", 181 | "10166": "OTHERS", 182 | "10088": "OTHERS", 183 | "10033": "OTHERS", 184 | "10182": "OTHERS", 185 | "10365": "DR", 186 | "10491": "NORMAL", 187 | "10383": "NORMAL", 188 | "10454": "DR", 189 | "10353": "DR", 190 | "10485": "NORMAL", 191 | "10345": "DR", 192 | "10311": "DR", 193 | "10479": "DR", 194 | "10391": "DR", 195 | "10388": "NORMAL", 196 | "10380": "DR", 197 | "10446": "NORMAL", 198 | "10436": "DR", 199 | "10374": "NORMAL", 200 | "10373": "DR", 201 | "10421": "DR", 202 | "10306": "DR", 203 | "10499": "DR", 204 | "10348": "NORMAL", 205 | "10349": "DR", 206 | "10317": "DR", 207 | "10439": "CNV", 208 | "10339": "NORMAL", 209 | "10315": "DR", 210 | "10451": "NORMAL", 211 | "10426": "NORMAL", 212 | "10366": "NORMAL", 213 | "10304": "DR", 214 | "10307": "NORMAL", 215 | "10389": "NORMAL", 216 | "10408": "NORMAL", 217 | "10490": "NORMAL", 218 | "10464": "DR", 219 | "10319": "NORMAL", 220 | "10488": "NORMAL", 221 | "10456": "AMD", 222 | "10316": "NORMAL", 223 | "10346": "NORMAL", 224 | "10447": "DR", 225 | "10335": "DR", 226 | "10432": "DR", 227 | "10414": "DR", 228 | "10406": "NORMAL", 229 | "10430": "NORMAL", 230 | "10379": "DR", 231 | "10423": "NORMAL", 232 | "10442": "NORMAL", 233 | "10497": "NORMAL", 234 | "10433": "DR", 235 | "10375": "NORMAL", 236 | "10312": "NORMAL", 237 | "10324": "DR", 238 | "10364": "DR", 239 | "10489": "DR", 240 | "10337": "NORMAL", 241 | "10387": "NORMAL", 242 | "10334": "NORMAL", 243 | "10318": "NORMAL", 244 | "10384": "DR", 245 | "10355": "NORMAL", 246 | "10434": "NORMAL", 247 | "10467": "DR" 248 | } 249 | -------------------------------------------------------------------------------- /dataset_splits/readme.md: -------------------------------------------------------------------------------- 1 | # Guide to split files with case IDs 2 | 3 | To guarantee reproducibility, this folder contains the split files used for all experiments in this paper. For each dataset, there are two files: 4 | 5 | - `splits_final.json`: This file defines the training-validation splits (folds). It is produced during preprocessing by nnU-Net. 6 | - `domain_mapping_00.json`: This file lists all test cases and their "domain". It is produced during dataset conversion [(scripts here)](../src/segmentation_failures/data/dataset_conversion/). 7 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools >= 61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "segmentation_failures" 7 | version = "2024.0.1" 8 | authors = [ 9 | {name = "Maximilian Zenk", email = "m.zenk@dkfz-heidelberg.de"}, 10 | ] 11 | description = "Code for the medical image segmentation failure detection benchmark" 12 | readme = "README.md" 13 | license = {file = "LICENSE"} 14 | requires-python = ">=3.11" 15 | 16 | dependencies = [ 17 | "torch==2.0.1", 18 | "torchvision==0.15.2", 19 | "pytorch-lightning==2.0.1.post0", 20 | "numpy==1.24.2", 21 | "torchio==0.18.91", 22 | "hydra-core==1.3.2", 23 | "timm==0.6.13", 24 | "pyradiomics==3.0.1", 25 | "scikit-learn==1.2.2", 26 | "SimpleITK==2.2.1", 27 | "nnunetv2==2.2.1", 28 | "monai[nibabel]==1.3.0", 29 | "scikit-image", 30 | "loguru", 31 | "python-dotenv", 32 | "pandas", 33 | "rich", 34 | "tqdm", 35 | "dynamic-network-architectures", 36 | "tensorboard", 37 | "natsort", 38 | ] 39 | 40 | [project.optional-dependencies] 41 | dev = [ 42 | "pytest", 43 | "pytest-cov", 44 | "pre-commit", 45 | "matplotlib", 46 | "seaborn>=0.13.2", 47 | "flake8", 48 | "flake8-bugbear", 49 | "black", 50 | "pytest", 51 | "ipython", 52 | "ipykernel", 53 | ] 54 | launcher = [ 55 | "parallel-ssh" 56 | ] 57 | 58 | [tool.licensecheck] 59 | using = "PEP631" 60 | -------------------------------------------------------------------------------- /src/segmentation_failures/callbacks/batch_logging.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import monai 4 | import pytorch_lightning as pl 5 | import torch 6 | import torch.nn.functional as F 7 | from loguru import logger 8 | from pytorch_lightning import Callback 9 | from torchvision.transforms import ToPILImage 10 | 11 | from segmentation_failures.utils.visualization import make_image_mask_grid 12 | 13 | 14 | class BatchVisualization(Callback): 15 | def __init__( 16 | self, 17 | num_classes: int, 18 | log_dir=None, 19 | every_n_steps=250, 20 | max_num_images=16, 21 | ): 22 | self.num_classes = num_classes 23 | self.log_dir = log_dir 24 | self.every_n_batches = every_n_steps 25 | self.max_num_images = max_num_images 26 | 27 | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx: int) -> None: 28 | # Log a subset of the training images and their labels as images to tensorboard 29 | # if batch_idx > 0: 30 | # # Log only once per epoch 31 | if trainer.global_step % self.every_n_batches == 0: 32 | self.log_batch(trainer, batch, outputs) 33 | 34 | def on_validation_batch_end( 35 | self, 36 | trainer, 37 | pl_module, 38 | outputs, 39 | batch, 40 | batch_idx: int, 41 | dataloader_idx: int = 0, 42 | ) -> None: 43 | if not trainer.sanity_checking and batch_idx == 0: 44 | self.log_batch(trainer, batch, outputs, mode="val") 45 | 46 | def log_batch(self, trainer, batch, outputs, mode="train"): 47 | # optional: log to tensorboard 48 | for tb_logger in trainer.loggers: 49 | if isinstance(tb_logger, pl.loggers.tensorboard.TensorBoardLogger): 50 | break 51 | self._log_batch(batch, outputs, mode, step=trainer.global_step, tb_logger=tb_logger) 52 | 53 | def _log_batch(self, batch, outputs=None, mode="train", step=0, tb_logger=None): 54 | images = batch.get("data") 55 | labels = batch.get("target") 56 | reverse_spatial = "properties" in batch # nnunet dataloader 57 | if images is None and labels is None: 58 | logger.warning("No images or labels found in batch. Skipping logging.") 59 | return 60 | elif images is None or not isinstance(images, torch.Tensor): 61 | images = torch.zeros_like(labels) 62 | logger.warning("No images found in batch. Logging only labels.") 63 | elif labels is None or not isinstance(images, torch.Tensor): 64 | labels = torch.zeros_like(images) 65 | logger.warning("No labels found in batch. Logging only images.") 66 | if isinstance(labels, list): 67 | labels = labels[0] # deep supervision 68 | images = images.detach().cpu() 69 | labels = labels.detach().cpu() 70 | num_spatial = len(images.shape[2:]) 71 | if isinstance(images, monai.data.MetaTensor): 72 | images = images.as_tensor() 73 | if isinstance(labels, monai.data.MetaTensor): 74 | labels = labels.as_tensor() 75 | exclusive_labels = labels.shape[1] == 1 76 | if exclusive_labels: 77 | # exclusive labels 78 | # FIXME the maximum in here is due to bad configuration files: see num_fg_classes vs num_classes for brats or kits 79 | labels = F.one_hot( 80 | labels.squeeze(1).to(torch.long), 81 | num_classes=max(self.num_classes, labels.max() + 1), 82 | ) 83 | labels = labels.permute(0, -1, *range(1, num_spatial + 1)) 84 | slice_indices = None 85 | if num_spatial == 3: 86 | # select slice with largest foreground fraction; BCHWD -> B 87 | sum_dims = (1, 3, 4) if reverse_spatial else (1, 2, 3) 88 | slice_indices = torch.argmax( 89 | torch.sum(labels[:, int(exclusive_labels) :], dim=sum_dims), dim=1 90 | ).tolist() 91 | rgb_image = make_image_mask_grid( 92 | image_batch=images, 93 | mask_list=[labels], 94 | max_images=self.max_num_images, 95 | slice_idx=slice_indices, 96 | slice_dim=0 if reverse_spatial else 2, 97 | ) 98 | if self.log_dir is not None: 99 | fpath = Path(self.log_dir) / f"{mode}_batch_{step}.png" 100 | # save tensor as png 101 | img = ToPILImage()(rgb_image) 102 | img.save(fpath) 103 | if tb_logger is not None: 104 | tb_logger.experiment.add_image(f"{mode}_batch", rgb_image, step, dataformats="CHW") 105 | -------------------------------------------------------------------------------- /src/segmentation_failures/callbacks/network_graph_tb.py: -------------------------------------------------------------------------------- 1 | import monai 2 | import pytorch_lightning as pl 3 | 4 | 5 | class NetworkGraphViz(pl.Callback): 6 | def on_train_batch_start( 7 | self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch, batch_idx: int 8 | ) -> None: 9 | if trainer.global_step == 0: 10 | input_array = batch["data"] 11 | if isinstance(input_array, monai.data.MetaTensor): 12 | input_array = input_array.as_tensor() 13 | self.logger.log_graph(pl_module, input_array=input_array) 14 | -------------------------------------------------------------------------------- /src/segmentation_failures/callbacks/quality_estimator_monitor.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import numpy as np 4 | import torch 5 | from pytorch_lightning.callbacks import Callback 6 | 7 | 8 | class TrainingTargetMonitor(Callback): 9 | def __init__( 10 | self, 11 | output_dir: str, 12 | save_every_n_epochs: int = 1, 13 | ): 14 | """This callback should save the predicted and target quality values after each training epoch""" 15 | self.output_path = Path(output_dir) 16 | self.output_path.mkdir() 17 | self.save_every_n_epochs = save_every_n_epochs 18 | self._quality_true_buffer = [] # this is cleared after each epoch 19 | self._quality_pred_buffer = [] # this is cleared after each epoch 20 | 21 | def on_train_batch_end( 22 | self, trainer, pl_module, outputs, batch, batch_idx: int, unused: int = 0 23 | ) -> None: 24 | if trainer.current_epoch % self.save_every_n_epochs != 0: 25 | return 26 | self._quality_pred_buffer.append(outputs["quality_pred"]) # tensor of shape BM 27 | self._quality_true_buffer.append(outputs["quality_true"]) # tensor of shape BM 28 | 29 | def on_train_epoch_end(self, trainer, pl_module) -> None: 30 | if trainer.current_epoch % self.save_every_n_epochs != 0: 31 | return 32 | # log true quality values (just for sanity checking) 33 | all_quality_true = torch.cat(self._quality_true_buffer, dim=0).cpu().numpy() 34 | all_quality_pred = torch.cat(self._quality_pred_buffer, dim=0).cpu().numpy() 35 | if hasattr(trainer.datamodule, "metric_target_names"): 36 | quality_names = trainer.datamodule.metric_target_names 37 | np.savez( 38 | self.output_path / f"quality_targets_epoch={pl_module.current_epoch}.npz", 39 | quality_true=all_quality_true, 40 | quality_pred=all_quality_pred, 41 | names=quality_names, 42 | ) 43 | self._quality_true_buffer = [] 44 | self._quality_pred_buffer = [] 45 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/analysis/fail_thresholds/fail_thresholds_general.yaml: -------------------------------------------------------------------------------- 1 | # These are very arbitrary 2 | # I need to override these values for each dataset and maybe even specialize for different classes 3 | dice: 0.75 4 | hausdorff95: 30 5 | generalized_dice: 0.75 6 | surface_dice: 0.75 7 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/analysis/failure_detection.yaml: -------------------------------------------------------------------------------- 1 | fail_thresholds: null 2 | save_curves: true 3 | fd_metrics: 4 | - aurc 5 | - e-aurc 6 | - norm-aurc 7 | - opt-aurc 8 | - rand-aurc 9 | - spearman 10 | - pearson 11 | ood_metrics: 12 | - ood_auc 13 | id_domain: ${dataset.id_domain} 14 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/analysis/failure_detection_with_thresholds.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - fail_thresholds: fail_thresholds_general 3 | 4 | save_curves: true 5 | fd_metrics: 6 | - aurc 7 | - e-aurc 8 | - norm-aurc 9 | - opt-aurc 10 | - rand-aurc 11 | - failauc 12 | - failap_suc 13 | - failap_err 14 | - fpr@95tpr 15 | ood_metrics: 16 | - ood_auc 17 | id_domain: ${dataset.id_domain} 18 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/backbone/dynamic_resencunet.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dynamic_unet 3 | 4 | hparams: 5 | res_block: true 6 | blocks_per_stage: 3 7 | checkpoint: null 8 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/backbone/dynamic_resencunet_deepsup.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dynamic_resencunet 3 | 4 | hparams: 5 | deep_supervision: True 6 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/backbone/dynamic_resencunet_dropout.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dynamic_resencunet 3 | 4 | hparams: 5 | dropout: 0.5 6 | num_dropout_units: 5 7 | append_dropout: True 8 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/backbone/dynamic_unet.yaml: -------------------------------------------------------------------------------- 1 | hparams: 2 | _target_: segmentation_failures.networks.dynunet.get_network 3 | spatial_dims: ${dataset.img_dim} 4 | in_channels: ${dataset.img_channels} 5 | out_channels: ${dataset.num_classes} 6 | patch_size: ${datamodule.patch_size} 7 | spacings: ${datamodule.spacing} 8 | dropout: 0.0 9 | num_dropout_units: 0 10 | deep_supervision: False 11 | blocks_per_stage: 1 12 | checkpoint: null 13 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/backbone/dynamic_unet_deepsup.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dynamic_unet 3 | 4 | hparams: 5 | deep_supervision: True 6 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/backbone/dynamic_unet_dropout.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dynamic_unet 3 | 4 | hparams: 5 | dropout: 0.5 6 | num_dropout_units: 5 7 | append_dropout: True 8 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/backbone/dynamic_wideunet.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dynamic_unet 3 | 4 | hparams: 5 | filter_scaling: 2 6 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/backbone/dynamic_wideunet_dropout.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dynamic_wideunet 3 | 4 | hparams: 5 | dropout: 0.5 6 | num_dropout_units: 5 7 | append_dropout: True 8 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/backbone/monai_unet.yaml: -------------------------------------------------------------------------------- 1 | hparams: 2 | _target_: monai.networks.nets.UNet 3 | in_channels: ${dataset.img_channels} 4 | out_channels: ${dataset.num_classes} 5 | spatial_dims: ${dataset.img_dim} 6 | dropout: 0 7 | channels: [16, 32, 64, 128, 256] 8 | strides: [2, 2, 2, 2] 9 | num_res_units: 2 10 | norm: instance 11 | checkpoint: null 12 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/backbone/monai_unet_dropout.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - monai_unet 3 | 4 | hparams: 5 | dropout: 0.3 6 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/callbacks/default.yaml: -------------------------------------------------------------------------------- 1 | # Note if you want to adapt this from the command line: 2 | # - to add a callback use +checkpoint/train=quality_training_monitor 3 | # - to redefine the whole list use checkpoint/train='[quality_training_monitor,batch_logging]' 4 | defaults: 5 | - train_seg: 6 | - model_checkpoint 7 | # - batch_logging # slow 8 | - prediction_saver 9 | # prediction saver only logs in case of trainer.validate call 10 | - train: 11 | - model_checkpoint 12 | # - batch_logging # slow 13 | - validate: 14 | # - batch_logging # slow 15 | - prediction_saver 16 | - test: 17 | - results_saver 18 | - prediction_saver 19 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/callbacks/none.yaml: -------------------------------------------------------------------------------- 1 | train: null 2 | validate: null 3 | test: null 4 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/callbacks/test/confidence_saver.yaml: -------------------------------------------------------------------------------- 1 | confidence_saver: 2 | _target_: segmentation_failures.callbacks.confidence_map_writer.PixelConfidenceWriter 3 | output_dir: ${paths.pixel_confid_dir} 4 | num_export_workers: 3 5 | confid_name: null 6 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/callbacks/test/dummy.yaml: -------------------------------------------------------------------------------- 1 | dummy: 2 | do_nothing: true 3 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/callbacks/test/ensemble_prediction_saver.yaml: -------------------------------------------------------------------------------- 1 | prediction_distr_saver: 2 | _target_: segmentation_failures.callbacks.prediction_writer.MultiPredictionWriter 3 | pred_key: prediction_distr 4 | output_dir: ${paths.prediction_samples_dir} 5 | save_probabilities: false 6 | num_export_workers: 2 7 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/callbacks/test/prediction_saver.yaml: -------------------------------------------------------------------------------- 1 | prediction_saver: 2 | _target_: segmentation_failures.callbacks.prediction_writer.PredictionWriter 3 | output_dir: ${paths.predictions_dir} 4 | save_probabilities: false 5 | num_export_workers: 3 6 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/callbacks/test/results_saver.yaml: -------------------------------------------------------------------------------- 1 | results_saver: 2 | _target_: segmentation_failures.callbacks.results_writer.ExperimentDataWriter 3 | output_dir: ${paths.results_dir} 4 | prediction_dir: ${paths.predictions_dir} 5 | num_classes: ${dataset.num_classes} 6 | region_based_eval: ${dataset.overlapping_classes} 7 | num_processes: 6 8 | previous_stage_results_path: null # set only for two-stage models 9 | metric_list: 10 | - dice 11 | - generalized_dice 12 | - surface_dice 13 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/callbacks/train/batch_logging.yaml: -------------------------------------------------------------------------------- 1 | batch_logging: 2 | _target_: segmentation_failures.callbacks.batch_logging.BatchVisualization 3 | num_classes: ${dataset.num_classes} 4 | max_num_images: 8 5 | every_n_steps: 250 6 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/callbacks/train/model_checkpoint.yaml: -------------------------------------------------------------------------------- 1 | model_checkpoint: 2 | _target_: pytorch_lightning.callbacks.ModelCheckpoint 3 | monitor: val_loss_epoch # name of the logged metric which determines when model is improving 4 | mode: min # "max" means higher metric value is better, can be also "min" 5 | save_top_k: 1 # save k best models (determined by above metric) 6 | save_last: True # additionaly always save model from last epoch 7 | verbose: False 8 | dirpath: ${paths.checkpoint_dir} 9 | filename: "epoch_{epoch:03d}" 10 | auto_insert_metric_name: False 11 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/callbacks/train/prediction_saver.yaml: -------------------------------------------------------------------------------- 1 | prediction_saver: 2 | _target_: segmentation_failures.callbacks.prediction_writer.PredictionWriter 3 | output_dir: ${paths.predictions_dir} 4 | save_probabilities: false 5 | num_export_workers: 1 6 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/callbacks/train/quality_training_monitor.yaml: -------------------------------------------------------------------------------- 1 | quality_monitor: 2 | _target_: segmentation_failures.callbacks.quality_estimator_monitor.TrainingTargetMonitor 3 | output_dir: ${paths.extras_dir} 4 | save_every_n_epochs: 10 5 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/callbacks/train/save_val_predictions.yaml: -------------------------------------------------------------------------------- 1 | train: 2 | prediction_saver: 3 | _target_: segmentation_failures.callbacks.prediction_writer.PredictionWriterWithBalancing 4 | output_dir: ${paths.predictions_dir} 5 | num_fg_classes: ${dataset.num_fg_classes} 6 | num_export_workers: 1 7 | num_bins: 20 8 | max_num_per_bin: 2 9 | randomize_bins: true 10 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/callbacks/train_seg/batch_logging.yaml: -------------------------------------------------------------------------------- 1 | batch_logging: 2 | _target_: segmentation_failures.callbacks.batch_logging.BatchVisualization 3 | num_classes: ${dataset.num_classes} 4 | max_num_images: 8 5 | every_n_steps: 250 6 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/callbacks/train_seg/model_checkpoint.yaml: -------------------------------------------------------------------------------- 1 | model_checkpoint: 2 | _target_: pytorch_lightning.callbacks.ModelCheckpoint 3 | monitor: val_loss_epoch # name of the logged metric which determines when model is improving 4 | mode: min # "max" means higher metric value is better, can be also "min" 5 | save_top_k: 1 # save k best models (determined by above metric) 6 | save_last: True # additionaly always save model from last epoch 7 | verbose: False 8 | dirpath: ${paths.checkpoint_dir} 9 | filename: "epoch_{epoch:03d}" 10 | auto_insert_metric_name: False 11 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/callbacks/train_seg/prediction_saver.yaml: -------------------------------------------------------------------------------- 1 | prediction_saver: 2 | _target_: segmentation_failures.callbacks.prediction_writer.PredictionWriter 3 | output_dir: ${paths.predictions_dir} 4 | save_probabilities: false 5 | num_export_workers: 1 6 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/callbacks/train_seg/quality_training_monitor.yaml: -------------------------------------------------------------------------------- 1 | quality_monitor: 2 | _target_: segmentation_failures.callbacks.quality_estimator_monitor.TrainingTargetMonitor 3 | output_dir: ${paths.extras_dir} 4 | save_every_n_epochs: 10 5 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/callbacks/train_seg/save_val_predictions.yaml: -------------------------------------------------------------------------------- 1 | train: 2 | prediction_saver: 3 | _target_: segmentation_failures.callbacks.prediction_writer.PredictionWriterWithBalancing 4 | output_dir: ${paths.predictions_dir} 5 | num_fg_classes: ${dataset.num_fg_classes} 6 | num_export_workers: 1 7 | num_bins: 20 8 | max_num_per_bin: 2 9 | randomize_bins: true 10 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/callbacks/validate/batch_logging.yaml: -------------------------------------------------------------------------------- 1 | batch_logging: 2 | _target_: segmentation_failures.callbacks.batch_logging.BatchVisualization 3 | num_classes: ${dataset.num_classes} 4 | max_num_images: 8 5 | every_n_steps: 250 6 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/callbacks/validate/confidence_saver.yaml: -------------------------------------------------------------------------------- 1 | confidence_saver: 2 | _target_: segmentation_failures.callbacks.confidence_map_writer.PixelConfidenceWriter 3 | output_dir: ${paths.pixel_confid_dir} 4 | num_export_workers: 3 5 | confid_name: null 6 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/callbacks/validate/prediction_saver.yaml: -------------------------------------------------------------------------------- 1 | prediction_saver: 2 | _target_: segmentation_failures.callbacks.prediction_writer.PredictionWriter 3 | output_dir: ${paths.predictions_dir} 4 | save_probabilities: false 5 | num_export_workers: 3 6 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/config.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # specify here default training configuration 4 | defaults: 5 | - _self_ # see https://hydra.cc/docs/tutorials/basic/your_first_app/defaults/#composition-order-of-primary-config 6 | - hydra: local 7 | - paths: default 8 | - logger: default 9 | - trainer: single_gpu 10 | - callbacks: default 11 | - analysis: failure_detection 12 | - dataset: null 13 | - datamodule: null 14 | - backbone: null 15 | - segmentation: null 16 | - csf_pixel: null 17 | - csf_image: null 18 | - csf_aggregation: null 19 | # debugging config (enable through command line, e.g. `python train.py debug=default) 20 | - debug: null 21 | # optional local config for machine/user specific settings 22 | # it's optional since it doesn't need to exist and is excluded from version control 23 | - optional local: default 24 | 25 | # I currently use loguru 26 | loguru: 27 | level: INFO 28 | file: main.log # relative to hydra:run.dir 29 | 30 | # seed for random number generators in pytorch, numpy and python.random 31 | seed: 32586152 32 | 33 | test: 34 | last_ckpt: true 35 | # this applies only to the image CSF; for segmentation, we always use the last checkpoint 36 | 37 | resume_from_checkpoint: 38 | path: null 39 | load_expt_config: false 40 | 41 | # Sometimes this causes errors on my workstation. See this discussion: 42 | # https://github.com/pytorch/pytorch/issues/973#issuecomment-459398189 43 | # Unfortunately, I wasn't able to find the cause of the issue, so 44 | # the workaround is to set this to `file_system` instead. 45 | mp_sharing_strategy: file_descriptor 46 | 47 | expt_group: default 48 | expt_name: ??? 49 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/csf_aggregation/all_simple.yaml: -------------------------------------------------------------------------------- 1 | # TODO I don't know how to configure trainer=cpu here... Could do it in the experiments.__init__.py but here would be better. 2 | defaults: 3 | - simple_aggs@hparams.aggregation_methods: 4 | - mean 5 | - only_non_boundary 6 | - foreground 7 | - patch_based 8 | - pairwise_gen_dice 9 | - pairwise_mean_dice 10 | 11 | hparams: 12 | _target_: segmentation_failures.models.confidence_aggregation.SimpleAggModule 13 | dataset_id: ${dataset.dataset_id} 14 | trainable: false 15 | checkpoint: null 16 | twostage: true # implement later 17 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/csf_aggregation/heuristic.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /trainer: cpu 4 | 5 | csf_aggregation: 6 | hparams: 7 | _target_: segmentation_failures.models.confidence_aggregation.HeuristicAggregationModule 8 | regression_model: regression_forest 9 | dataset_id: ${dataset.dataset_id} 10 | confid_name: ??? 11 | target_metrics: null # inferred automatically 12 | trainable: true 13 | checkpoint: null 14 | twostage: true 15 | 16 | trainer: 17 | max_epochs: 1 18 | check_val_every_n_epoch: 1 19 | 20 | callbacks: 21 | train: 22 | model_checkpoint: 23 | save_top_k: 0 24 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/csf_aggregation/radiomics.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /trainer: cpu 5 | 6 | csf_aggregation: 7 | hparams: 8 | _target_: segmentation_failures.models.confidence_aggregation.RadiomicsAggregationModule 9 | regression_model: regression_forest 10 | dataset_id: ${dataset.dataset_id} 11 | target_metrics: null # inferred automatically 12 | image_dim: ${dataset.img_dim} 13 | confid_name: ??? 14 | confid_threshold: null # if not set/null, this will be computed on the validation set as in Jungo et al. 15 | trainable: true 16 | checkpoint: null 17 | twostage: true 18 | 19 | trainer: 20 | max_epochs: 1 21 | check_val_every_n_epoch: 1 22 | 23 | callbacks: 24 | train: 25 | model_checkpoint: 26 | save_top_k: 0 27 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/csf_aggregation/simple_aggs/distance_weighting.yaml: -------------------------------------------------------------------------------- 1 | distance_weighted: 2 | _target_: segmentation_failures.models.confidence_aggregation.get_aggregator 3 | name: distance_weighted 4 | agg_fn: mean 5 | saturate: 4.0 6 | spacing: ${datamodule.spacing} 7 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/csf_aggregation/simple_aggs/foreground.yaml: -------------------------------------------------------------------------------- 1 | foreground: 2 | _target_: segmentation_failures.models.confidence_aggregation.get_aggregator 3 | name: foreground_weighted 4 | boundary_width: 4 5 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/csf_aggregation/simple_aggs/mean.yaml: -------------------------------------------------------------------------------- 1 | mean: 2 | _target_: segmentation_failures.models.confidence_aggregation.get_aggregator 3 | name: simple 4 | agg_fn: mean 5 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/csf_aggregation/simple_aggs/only_non_boundary.yaml: -------------------------------------------------------------------------------- 1 | only_non_boundary: 2 | _target_: segmentation_failures.models.confidence_aggregation.get_aggregator 3 | name: boundary_weighted 4 | invert: true 5 | boundary_width: 4 6 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/csf_aggregation/simple_aggs/pairwise_gen_dice.yaml: -------------------------------------------------------------------------------- 1 | pairwise_gen_dice: 2 | _target_: segmentation_failures.models.confidence_aggregation.get_aggregator 3 | name: pairwise_dice 4 | include_zero_label: false 5 | gen_dice_weight: square 6 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/csf_aggregation/simple_aggs/pairwise_mean_dice.yaml: -------------------------------------------------------------------------------- 1 | pairwise_mean_dice: 2 | _target_: segmentation_failures.models.confidence_aggregation.get_aggregator 3 | name: pairwise_dice 4 | include_zero_label: false 5 | use_mean_dice: true 6 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/csf_aggregation/simple_aggs/patch_based.yaml: -------------------------------------------------------------------------------- 1 | patch_min: 2 | _target_: segmentation_failures.models.confidence_aggregation.get_aggregator 3 | name: patch_min 4 | patch_size: null # will result in [10]**D patch 5 | mean: true 6 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/csf_image/mahalanobis.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | csf_image: 4 | hparams: 5 | _target_: segmentation_failures.models.image_confidence.SingleGaussianOODDetector 6 | feature_path: ??? 7 | sw_patch_size: ${datamodule.patch_size} # can be null, in which case no sliding window will be used 8 | sw_batch_size: ${datamodule.batch_size} 9 | sw_overlap: 0.5 10 | sw_training: False 11 | max_feature_size: 10000 12 | store_precision: True 13 | assume_centered: False 14 | trainable: true 15 | checkpoint: null 16 | needs_pretrained_segmentation: true 17 | twostage: false 18 | 19 | trainer: 20 | max_epochs: 1 21 | precision: 32 22 | check_val_every_n_epoch: 1 23 | 24 | callbacks: 25 | train: 26 | model_checkpoint: 27 | save_top_k: 0 28 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/csf_image/mahalanobis_gonzalez.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - mahalanobis 4 | 5 | csf_image: 6 | hparams: 7 | sw_training: True 8 | 9 | trainer: 10 | max_epochs: 1 11 | precision: 32 12 | check_val_every_n_epoch: 1 13 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/csf_image/quality_regression.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | csf_image: 4 | hparams: 5 | _target_: segmentation_failures.models.image_confidence.regression_network.QualityRegressionNet 6 | output_names: null # inferred automatically 7 | img_channels: ${dataset.img_channels} 8 | img_dim: ${dataset.img_dim} 9 | num_classes: ${dataset.num_fg_classes} 10 | confid_name: null 11 | loss: l2 # used in Robinson et al 12 | lr: 2e-4 13 | weight_decay: 1e-4 14 | cosine_annealing: true 15 | voxel_spacing: null # set dynamically in training script 16 | img_size: null # set dynamically in training script 17 | blocks_per_stage: null # set dynamically in training script 18 | trainable: true 19 | checkpoint: null 20 | needs_pretrained_segmentation: false 21 | twostage: true 22 | 23 | trainer: 24 | max_epochs: 1000 25 | check_val_every_n_epoch: 5 26 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/csf_image/vae_image_and_mask.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | csf_image: 4 | hparams: 5 | _target_: segmentation_failures.models.image_confidence.vae_estimator.SimpleVAEmodule 6 | img_dim: ${dataset.img_dim} 7 | img_channels: ${dataset.img_channels} 8 | seg_channels: ${dataset.num_fg_classes} 9 | img_size: null # currently set as a python constant in the datamodule 10 | lr: 1e-4 11 | z_dim: 256 12 | model_h_size: [32, 64, 128, 256, 512] 13 | liu_architecture: false 14 | to_1x1: True 15 | beta: 0.001 16 | normalization_op: none 17 | recon_loss_img: l1 18 | recon_loss_seg: bce 19 | log_n_samples: 0 20 | log_train_recons: false 21 | log_val_recons: true 22 | trainable: true 23 | needs_pretrained_segmentation: false 24 | checkpoint: null 25 | twostage: true 26 | 27 | trainer: 28 | max_epochs: 1000 29 | check_val_every_n_epoch: 1 30 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/csf_image/vae_image_only.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - vae_image_and_mask 4 | 5 | csf_image: 6 | hparams: 7 | img_channels: ${dataset.img_channels} 8 | seg_channels: 0 9 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/csf_image/vae_iterative_surrogate.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - vae_image_and_mask 4 | 5 | csf_image: 6 | hparams: 7 | _target_: segmentation_failures.models.image_confidence.vae_estimator.IterativeSurrogateVAEmodule 8 | surrogate_lr: 1e-3 9 | quality_metric: generalized_dice 10 | convergence_thresh: 1e-2 11 | 12 | trainer: 13 | # As I do some optimization at test-time, I can't use inference_mode (only no_grad, which I disable locally) 14 | # I experienced issues with the requires_grad attribute not being inherited when using inference_mode 15 | inference_mode: false 16 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/csf_image/vae_mask_only.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - vae_image_and_mask 4 | 5 | csf_image: 6 | hparams: 7 | img_channels: 0 8 | seg_channels: ${dataset.num_fg_classes} 9 | 10 | datamodule: 11 | load_images: False 12 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/csf_pixel/baseline.yaml: -------------------------------------------------------------------------------- 1 | hparams: 2 | _target_: segmentation_failures.models.pixel_confidence.PosthocMultiConfidenceSegmenter 3 | csf_names: 4 | - maxsoftmax 5 | - predictive_entropy 6 | num_mcd_samples: 0 7 | overlapping_classes: ${dataset.overlapping_classes} 8 | everything_on_gpu: false 9 | trainable: false 10 | checkpoint: null 11 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/csf_pixel/deep_ensemble.yaml: -------------------------------------------------------------------------------- 1 | hparams: 2 | _target_: segmentation_failures.models.pixel_confidence.DeepEnsembleMultiConfidenceSegmenter 3 | csf_names: 4 | # - maxsoftmax 5 | - predictive_entropy 6 | - mutual_information 7 | overlapping_classes: ${dataset.overlapping_classes} 8 | everything_on_gpu: false 9 | num_models: 5 10 | trainable: false 11 | checkpoint: null 12 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/csf_pixel/mcdropout.yaml: -------------------------------------------------------------------------------- 1 | hparams: 2 | _target_: segmentation_failures.models.pixel_confidence.PosthocMultiConfidenceSegmenter 3 | csf_names: 4 | # - maxsoftmax 5 | - predictive_entropy 6 | - mutual_information 7 | num_mcd_samples: 10 8 | overlapping_classes: ${dataset.overlapping_classes} 9 | everything_on_gpu: false 10 | trainable: false 11 | checkpoint: null 12 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/datamodule/acdc_nnunet.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - nnunet 3 | 4 | batch_size: 3 5 | patch_size: [20, 256, 224] 6 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/datamodule/brats19_lhgg_nnunet.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - nnunet 3 | 4 | batch_size: 2 5 | patch_size: [128, 128, 128] 6 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/datamodule/covid_nnunet.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - nnunet 3 | 4 | batch_size: 2 5 | patch_size: [28, 256, 256] 6 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/datamodule/dummy.yaml: -------------------------------------------------------------------------------- 1 | # How to use this: just use the datamodule you want to simulate as default here and add the arguments indicated below 2 | 3 | defaults: 4 | - kits23_nnunet 5 | 6 | # IMPORTANT: Add these arguments below; they determine the actual behavior of the dataloader 7 | hparams: 8 | _target_: segmentation_failures.data.datamodules.dummy_modules.DummyNNunetDataModule 9 | dummy_num_samples: 1 10 | dummy_num_channels: 1 11 | dummy_img_size: [625, 625, 625] 12 | dummy_batch_size: 1 13 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/datamodule/heuristic_radiomics.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - simple_agg 3 | 4 | metric_targets: 5 | - generalized_dice 6 | - dice 7 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/datamodule/kits23_nnunet.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - nnunet 3 | 4 | batch_size: 2 5 | patch_size: [128, 128, 128] 6 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/datamodule/mnms_nnunet.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - nnunet 3 | 4 | batch_size: 4 5 | patch_size: [10, 320, 320] 6 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/datamodule/mvseg23_nnunet.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - nnunet 3 | 4 | batch_size: 2 5 | patch_size: [112, 128, 128] 6 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/datamodule/nnunet.yaml: -------------------------------------------------------------------------------- 1 | _target_: segmentation_failures.data.datamodules.nnunet_module.NNunetDataModule 2 | dataset_id: ${dataset.dataset_id} 3 | fold: ??? 4 | device: ${trainer.accelerator} 5 | test_data_root: ${paths.data_root_dir} 6 | batch_size: ??? 7 | patch_size: ??? # ZYX 8 | spacing: null # obtained automatically from nnunet trainer 9 | nnunet_config: 3d_fullres 10 | nnunet_plans_id: nnUNetPlans 11 | deep_supervision: ${backbone.hparams.deep_supervision} 12 | num_workers: null # (training) automatically set by nnunet 13 | num_workers_preproc: 3 14 | domain_mapping: 0 15 | preproc_only: false 16 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/datamodule/octa500_nnunet.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - nnunet 3 | 4 | batch_size: 2 5 | # patch_size: [112, 112, 112] 6 | patch_size: [32, 256, 256] 7 | nnunet_config: 3d_fullres_custompatch_largerspacing 8 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/datamodule/prostate_nnunet.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - nnunet 3 | 4 | batch_size: 2 5 | patch_size: [20, 320, 256] 6 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/datamodule/quality_regression.yaml: -------------------------------------------------------------------------------- 1 | _target_: segmentation_failures.data.datamodules.quality_regression.QualityRegressionDataModule 2 | metric_targets: 3 | - generalized_dice 4 | - dice 5 | prediction_dir: ${paths.predictions_dir} 6 | dataset_id: ${dataset.dataset_id} 7 | fold: ??? 8 | test_data_root: ${paths.data_root_dir} 9 | confid_name: null 10 | confid_dir: ${paths.pixel_confid_dir} 11 | batch_size: 2 12 | num_workers: null 13 | pin_memory: true 14 | domain_mapping: 0 15 | preproc_only: false 16 | cache_num: 1.0 17 | use_metatensor: true 18 | randomize_prediction: 0.33 19 | include_background: false 20 | expt_group: ${expt_group} 21 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/datamodule/retina_nnunet.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - nnunet 3 | 4 | nnunet_config: 2d 5 | batch_size: 12 6 | patch_size: [512, 512] 7 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/datamodule/retouch_cirrus_nnunet.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - nnunet 3 | 4 | batch_size: 2 5 | # patch_size: [48, 160, 160] 6 | patch_size: [16, 384, 256] 7 | nnunet_config: 3d_fullres_custompatch 8 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/datamodule/retouch_spectralis_nnunet.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - nnunet 3 | 4 | batch_size: 2 5 | patch_size: [32, 256, 112] 6 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/datamodule/retouch_topcon_nnunet.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - nnunet 3 | 4 | batch_size: 2 5 | patch_size: [32, 192, 128] 6 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/datamodule/simple_agg.yaml: -------------------------------------------------------------------------------- 1 | _target_: segmentation_failures.data.datamodules.simple_agg.SimpleAggDataModule 2 | dataset_id: ${dataset.dataset_id} 3 | fold: null 4 | prediction_dir: ${paths.predictions_dir} 5 | prediction_samples_dir: ${paths.prediction_samples_dir} 6 | confid_dir: ${paths.pixel_confid_dir} 7 | confid_name: null 8 | num_workers: 3 9 | pin_memory: true 10 | metric_targets: null 11 | expt_group: ${expt_group} 12 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/datamodule/simple_fets22_corrupted.yaml: -------------------------------------------------------------------------------- 1 | _target_: segmentation_failures.data.datamodules.SimpleBraTS 2 | train_data_dir: ${oc.env:nnUNet_raw}/Dataset500_simple_fets_corruptions 3 | test_data_dir: ${paths.data_root_dir}/Dataset500_simple_fets_corruptions 4 | dataset_id: ${dataset.dataset_id} 5 | patch_size: null # not used 6 | spacing: null # not used 7 | batch_size: 32 8 | batch_size_inference: 32 9 | fold: 0 10 | num_workers: 4 11 | pin_memory: True 12 | cache_num: 1.0 13 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/datamodule/vae.yaml: -------------------------------------------------------------------------------- 1 | _target_: segmentation_failures.data.datamodules.vae.VAEdataModule 2 | dataset_id: ${dataset.dataset_id} 3 | fold: ??? 4 | test_data_root: ${paths.data_root_dir} 5 | prediction_dir: ${paths.predictions_dir} 6 | batch_size: 6 7 | num_workers: null 8 | pin_memory: true 9 | fixed_steps_per_epoch: 250 10 | domain_mapping: 0 11 | preprocess_only: false 12 | cache_num: 1.0 13 | use_metatensor: true 14 | clip_values: 2.0 15 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/dataset/abstract.yaml: -------------------------------------------------------------------------------- 1 | dataset_id: ??? 2 | img_dim: ??? 3 | img_channels: ??? 4 | num_classes: ??? 5 | num_fg_classes: ??? # TODO This is a workaround. Some other configs need this value 6 | overlapping_classes: ??? 7 | id_domain: ??? 8 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/dataset/acdc.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - abstract 3 | 4 | dataset_id: "510" 5 | img_dim: 3 6 | img_channels: 1 7 | num_classes: 4 8 | num_fg_classes: 3 # TODO This is a workaround. Some other configs need this value 9 | overlapping_classes: false 10 | id_domain: ID 11 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/dataset/brats19_lhgg.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - abstract 3 | 4 | dataset_id: "503" 5 | img_dim: 3 6 | img_channels: 4 7 | num_classes: 3 8 | # TODO num_classes=3 is necessary because it's used to set the number of network outputs 9 | # This is not nice, is there another way? Need to check also the other uses of this variable... 10 | num_fg_classes: 3 # TODO This is a workaround. Some other configs need this value 11 | overlapping_classes: true 12 | id_domain: ["HGG", "LGG"] 13 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/dataset/covid_gonzalez.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - abstract 3 | 4 | dataset_id: "520" 5 | img_dim: 3 6 | img_channels: 1 7 | num_classes: 2 8 | num_fg_classes: 1 # TODO This is a workaround. Some other configs need this value 9 | overlapping_classes: false 10 | id_domain: ID 11 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/dataset/kits23.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - abstract 3 | 4 | dataset_id: "515" 5 | img_dim: 3 6 | img_channels: 1 7 | num_classes: 3 8 | # TODO num_classes=3 is necessary because it's used to set the number of network outputs 9 | # This is not nice, is there another way? Need to check also the other uses of this variable... 10 | num_fg_classes: 3 # TODO This is a workaround. Some other configs need this value 11 | overlapping_classes: true 12 | id_domain: ID 13 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/dataset/mnms.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - abstract 3 | 4 | dataset_id: "511" 5 | img_dim: 3 6 | img_channels: 1 7 | num_classes: 4 8 | num_fg_classes: 3 # TODO This is a workaround. Some other configs need this value 9 | overlapping_classes: false 10 | id_domain: B 11 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/dataset/mvseg23.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - abstract 3 | 4 | dataset_id: "514" 5 | img_dim: 3 6 | img_channels: 1 7 | num_classes: 3 8 | num_fg_classes: 2 # TODO This is a workaround. Some other configs need this value 9 | overlapping_classes: false 10 | id_domain: ID 11 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/dataset/octa500.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - abstract 3 | 4 | dataset_id: "560" 5 | img_dim: 3 6 | img_channels: 1 7 | num_classes: 6 8 | num_fg_classes: 5 # TODO This is a workaround. Some other configs need this value 9 | overlapping_classes: false 10 | id_domain: ["NORMAL", "AMD", "CNV"] 11 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/dataset/prostate_gonzalez.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - abstract 3 | 4 | dataset_id: "521" 5 | img_dim: 3 6 | img_channels: 1 7 | num_classes: 2 8 | num_fg_classes: 1 # TODO This is a workaround. Some other configs need this value 9 | overlapping_classes: false 10 | id_domain: ID 11 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/dataset/retina.yaml: -------------------------------------------------------------------------------- 1 | dataset_id: 531 2 | img_dim: 2 3 | img_channels: 3 4 | num_classes: 2 5 | # TODO num_classes=2 is necessary because it's used to set the number of network outputs 6 | # This is not nice, is there another way? Need to check also the other uses of this variable... 7 | num_fg_classes: 2 # TODO This is a workaround. Some other configs need this value 8 | overlapping_classes: true 9 | id_domain: ID 10 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/dataset/retouch_cirrus.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - abstract 3 | 4 | dataset_id: "540" 5 | img_dim: 3 6 | img_channels: 1 7 | num_classes: 4 8 | num_fg_classes: 3 # TODO This is a workaround. Some other configs need this value 9 | overlapping_classes: false 10 | id_domain: ["Spectralis", "Topcon"] 11 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/dataset/retouch_spectralis.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - abstract 3 | 4 | dataset_id: "541" 5 | img_dim: 3 6 | img_channels: 1 7 | num_classes: 4 8 | num_fg_classes: 3 # TODO This is a workaround. Some other configs need this value 9 | overlapping_classes: false 10 | id_domain: ["Cirrus", "Topcon"] 11 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/dataset/retouch_topcon.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - abstract 3 | 4 | dataset_id: "542" 5 | img_dim: 3 6 | img_channels: 1 7 | num_classes: 4 8 | num_fg_classes: 3 # TODO This is a workaround. Some other configs need this value 9 | overlapping_classes: false 10 | id_domain: ["Spectralis", "Cirrus"] 11 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/dataset/simple_fets22_corrupted.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - abstract 3 | 4 | dataset_id: "500" 5 | img_dim: 2 6 | img_channels: 4 7 | num_classes: 2 8 | num_fg_classes: 1 9 | overlapping_classes: false 10 | id_domain: noshift 11 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/debug/default.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # default debugging setup, runs 1 full epoch 4 | # other debugging configs can inherit from this one 5 | 6 | defaults: 7 | - override /hydra: debug.yaml 8 | 9 | trainer: 10 | max_epochs: 1 11 | precision: 32 12 | accelerator: cpu # debuggers don't like gpus 13 | detect_anomaly: true # raise exception if NaN or +/-inf is detected in any tensor 14 | # track_grad_norm: 2 # track gradient norm with loggers 15 | deterministic: true 16 | check_val_every_n_epoch: 1 17 | 18 | datamodule: 19 | num_workers: 0 # debuggers don't like multiprocessing 20 | # pin_memory: False # disable gpu memory pin 21 | 22 | # sets level of all command line loggers to 'DEBUG' 23 | # https://hydra.cc/docs/tutorials/basic/running_your_app/logging/ 24 | hydra: 25 | verbose: True 26 | 27 | # use this to set level of only chosen command line loggers to 'DEBUG': 28 | # verbose: [src.train, src.utils] 29 | 30 | # config is already printed by hydra when `hydra/verbose: True` 31 | print_config: False 32 | 33 | loguru: 34 | level: DEBUG 35 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/debug/limit_batches.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - default.yaml 4 | 5 | trainer: 6 | max_epochs: 2 7 | limit_train_batches: 3 8 | limit_val_batches: 5 9 | limit_test_batches: 5 10 | # datamodule: 11 | # hparams: 12 | # cache_num: 0 13 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/debug/lrfind.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # run learning rate finder 4 | 5 | defaults: 6 | - default.yaml 7 | 8 | trainer: 9 | auto_lr_find: true 10 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/debug/overfit.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # overfits to 3 batches 4 | 5 | defaults: 6 | - default.yaml 7 | 8 | trainer: 9 | max_epochs: 20 10 | overfit_batches: 3 11 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/debug/step.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # runs 1 train, 1 validation and 1 test step 4 | 5 | defaults: 6 | - default.yaml 7 | 8 | trainer: 9 | fast_dev_run: true 10 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/hydra/cluster.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - local 3 | 4 | run: 5 | dir: ${paths.log_dir}/${expt_group}/Dataset${dataset.dataset_id}/runs/${expt_name}/${hydra.job.name}/${oc.env:LSB_JOBID} 6 | sweep: 7 | dir: ${paths.log_dir}/${expt_group}/Dataset${dataset.dataset_id}/multiruns/${expt_name}/${hydra.job.name}/${oc.env:LSB_JOBID} 8 | subdir: ${hydra.job.num} 9 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/hydra/debug.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - local 3 | 4 | run: 5 | dir: ${paths.log_dir}/${expt_group}/Dataset${dataset.dataset_id}/debug_runs/${expt_name}/${hydra.job.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 6 | sweep: 7 | dir: ${paths.log_dir}/${expt_group}/Dataset${dataset.dataset_id}/debug_multiruns/${expt_name}/${hydra.job.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 8 | subdir: ${hydra.job.num} 9 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/hydra/local.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | # # enable color logging 3 | # - override hydra_logging: colorlog 4 | # - override job_logging: colorlog 5 | - override job_logging: disabled 6 | 7 | run: 8 | dir: ${paths.log_dir}/${expt_group}/Dataset${dataset.dataset_id}/runs/${expt_name}/${hydra.job.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 9 | sweep: 10 | dir: ${paths.log_dir}/${expt_group}/Dataset${dataset.dataset_id}/multiruns/${expt_name}/${hydra.job.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 11 | subdir: ${hydra.job.num} 12 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/logger/csv.yaml: -------------------------------------------------------------------------------- 1 | # csv logger built in lightning 2 | 3 | csv: 4 | _target_: pytorch_lightning.loggers.csv_logs.CSVLogger 5 | save_dir: ${paths.output_dir} 6 | name: "csv_logs" 7 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/logger/default.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - csv 3 | - tensorboard 4 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/logger/tensorboard.yaml: -------------------------------------------------------------------------------- 1 | # https://www.tensorflow.org/tensorboard/ 2 | 3 | tensorboard: 4 | _target_: pytorch_lightning.loggers.tensorboard.TensorBoardLogger 5 | save_dir: ${paths.output_dir}/tensorboard 6 | name: null 7 | version: ${expt_name} 8 | log_graph: false 9 | default_hp_metric: True 10 | prefix: "" 11 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/paths/default.yaml: -------------------------------------------------------------------------------- 1 | # path to directory with testing data for all datasets 2 | data_root_dir: ${oc.env:TESTDATA_ROOT_DIR} 3 | 4 | # path to logging directory 5 | log_dir: ${oc.env:EXPERIMENT_ROOT_DIR} 6 | 7 | # path to output directory, created dynamically by hydra 8 | # path generation pattern is specified in `conf/hydra/default.yaml` 9 | # use it to store all files generated during the run, like ckpts and metrics 10 | output_dir: ${hydra:runtime.output_dir} 11 | # Below are subdirs of output_dir!! 12 | results_dir: ${paths.output_dir}/results 13 | analysis_dir: ${paths.output_dir}/analysis 14 | pixel_confid_dir: ${paths.output_dir}/confidence_maps 15 | predictions_dir: ${paths.output_dir}/predictions 16 | prediction_samples_dir: ${paths.output_dir}/prediction_samples 17 | checkpoint_dir: ${paths.output_dir}/checkpoints 18 | extras_dir: ${paths.output_dir}/extras 19 | 20 | # path to working directory 21 | work_dir: ${hydra:runtime.cwd} 22 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/segmentation/baseline.yaml: -------------------------------------------------------------------------------- 1 | hparams: 2 | _target_: segmentation_failures.models.segmentation.UNet_segmenter 3 | num_classes: ${dataset.num_classes} 4 | lr: 1e-3 5 | weight_decay: 1e-5 6 | checkpoint: null 7 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/segmentation/dynunet.yaml: -------------------------------------------------------------------------------- 1 | hparams: 2 | _target_: segmentation_failures.models.segmentation.DynUnetModule 3 | num_classes: ${dataset.num_classes} 4 | patch_size: ${datamodule.patch_size} 5 | batch_dice: false 6 | sw_batch_size: ${datamodule.batch_size} 7 | lr: 1e-2 8 | weight_decay: 3e-5 9 | overlapping_classes: ${dataset.overlapping_classes} 10 | checkpoint: null 11 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/trainer/cpu.yaml: -------------------------------------------------------------------------------- 1 | _target_: pytorch_lightning.Trainer 2 | 3 | accelerator: cpu 4 | devices: 1 5 | precision: 32 6 | check_val_every_n_epoch: 1 7 | -------------------------------------------------------------------------------- /src/segmentation_failures/conf/trainer/single_gpu.yaml: -------------------------------------------------------------------------------- 1 | _target_: pytorch_lightning.Trainer 2 | 3 | accelerator: gpu 4 | devices: 1 5 | 6 | min_epochs: 1 7 | max_epochs: 10 8 | precision: 16-mixed 9 | 10 | # number of validation steps to execute at the beginning of the training 11 | # num_sanity_val_steps: 0 12 | log_every_n_steps: 20 13 | check_val_every_n_epoch: 1 14 | # # for full reproducibility, set this to true; default is false 15 | # deterministic: false 16 | -------------------------------------------------------------------------------- /src/segmentation_failures/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MIC-DKFZ/segmentation_failures_benchmark/a1af98be0f93c2bdc30ffe5bb6eda531c485d87d/src/segmentation_failures/data/__init__.py -------------------------------------------------------------------------------- /src/segmentation_failures/data/corruptions/image_corruptions_tio.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | import torchio as tio 4 | 5 | 6 | class TransformMagnitude(Enum): 7 | LOW = 1 8 | MEDIUM = 2 9 | HIGH = 3 10 | 11 | 12 | class TransformRegistry: 13 | transforms = {} 14 | 15 | @classmethod 16 | def register(cls, name: str): 17 | if name in cls.transforms: 18 | print("WARNING: Transform %s already registered. Will replace it", name) 19 | 20 | def inner(func): 21 | cls.transforms[name] = func 22 | return func 23 | 24 | return inner 25 | 26 | @classmethod 27 | def get_transform(cls, name: str, magnitude: str, **kwargs): 28 | magnitude = TransformMagnitude[magnitude.upper()] 29 | # kwargs can overwrite the magnitude settings 30 | if name not in cls.transforms: 31 | raise ValueError("Transform %s not registered" % name) 32 | transform, trf_settings = cls.transforms[name](magnitude) 33 | trf_settings.update(kwargs) 34 | return transform(**trf_settings), trf_settings 35 | 36 | @classmethod 37 | def list_transforms(cls): 38 | return list(cls.transforms) 39 | 40 | 41 | @TransformRegistry.register("biasfield") 42 | def biasfield(magnitude: TransformMagnitude): 43 | default_settings = { 44 | "order": 3, 45 | } 46 | magnitude_overwrites = { 47 | TransformMagnitude.LOW: {"coefficients": 0.2}, 48 | TransformMagnitude.MEDIUM: {"coefficients": 0.6}, 49 | TransformMagnitude.HIGH: {"order": 5, "coefficients": 0.8}, 50 | } 51 | default_settings.update(magnitude_overwrites[magnitude]) 52 | return tio.RandomBiasField, default_settings 53 | 54 | 55 | @TransformRegistry.register("ghosting") 56 | def ghosting(magnitude: TransformMagnitude): 57 | default_settings = { 58 | "num_ghosts": (4, 8), 59 | } 60 | magnitude_overwrites = { 61 | TransformMagnitude.LOW: {"intensity": (0.2, 0.4)}, 62 | TransformMagnitude.MEDIUM: {"intensity": (0.5, 0.7)}, 63 | TransformMagnitude.HIGH: {"intensity": (0.8, 1.0)}, 64 | } 65 | default_settings.update(magnitude_overwrites[magnitude]) 66 | return tio.RandomGhosting, default_settings 67 | 68 | 69 | @TransformRegistry.register("spike") 70 | def spike(magnitude: TransformMagnitude): 71 | default_settings = { 72 | "num_spikes": 1, 73 | } 74 | magnitude_overwrites = { 75 | TransformMagnitude.LOW: {"intensity": (0.1, 0.2)}, 76 | TransformMagnitude.MEDIUM: {"intensity": (0.3, 0.5)}, 77 | TransformMagnitude.HIGH: {"intensity": (0.7, 0.9)}, 78 | } 79 | # negative intensities aren't used here because tio would then sample from [-I, I] and 80 | # I want the artefacts to be always visible 81 | default_settings.update(magnitude_overwrites[magnitude]) 82 | return tio.RandomSpike, default_settings 83 | 84 | 85 | @TransformRegistry.register("affine") 86 | def affine(magnitude: TransformMagnitude): 87 | default_settings = { 88 | "translation": 0, 89 | } 90 | magnitude_overwrites = { 91 | TransformMagnitude.LOW: {"degrees": 5, "scales": (0.9, 1.4)}, 92 | TransformMagnitude.MEDIUM: {"degrees": (5, 15), "scales": (0.7, 1.8)}, 93 | TransformMagnitude.HIGH: {"degrees": (15, 30), "scales": (0.6, 2.0)}, 94 | } 95 | default_settings.update(magnitude_overwrites[magnitude]) 96 | return tio.RandomAffine, default_settings 97 | 98 | 99 | # TODO for 2D this transform is not working as expected. 100 | # The results look just like an affine transform 101 | # @TransformRegistry.register("motion") 102 | # def motion(magnitude: TransformMagnitude): 103 | # default_settings = { 104 | # } 105 | # magnitude_overwrites = { 106 | # TransformMagnitude.LOW: {}, 107 | # TransformMagnitude.MEDIUM: {}, 108 | # TransformMagnitude.HIGH: {}, 109 | # } 110 | # default_settings.update(magnitude_overwrites[magnitude]) 111 | # return tio.RandomMotion, default_settings 112 | 113 | 114 | # TODO This transform depends on the image size and spacing, so it's hard to set general values here 115 | # @TransformRegistry.register("elastic") 116 | # def elastic(magnitude: TransformMagnitude): 117 | # default_settings = { 118 | # "num_control_points": 7, 119 | # } 120 | # magnitude_overwrites = { 121 | # TransformMagnitude.LOW: {"max_displacement": 2}, 122 | # TransformMagnitude.MEDIUM: {"max_displacement": 3.5}, 123 | # TransformMagnitude.HIGH: {"max_displacement": 4.5}, 124 | # } 125 | # default_settings.update(magnitude_overwrites[magnitude]) 126 | # return tio.RandomElasticDeformation, default_settings 127 | -------------------------------------------------------------------------------- /src/segmentation_failures/data/datamodules/__init__.py: -------------------------------------------------------------------------------- 1 | from .monai_modules import SimpleBraTS 2 | from .nnunet_module import NNunetDataModule 3 | -------------------------------------------------------------------------------- /src/segmentation_failures/data/datamodules/additional_readers.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from collections.abc import Sequence 4 | from typing import Any 5 | 6 | import numpy as np 7 | from monai.config import PathLike 8 | from monai.data.image_reader import ImageReader 9 | from monai.data.utils import is_supported_format 10 | from monai.utils import ensure_tuple 11 | from tifffile import TiffFile 12 | 13 | 14 | class TiffReader(ImageReader): 15 | def __init__(self, rgb: bool, **kwargs): 16 | self.rgb = rgb # assume 2d rgb image 17 | self.kwargs = kwargs 18 | 19 | def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool: 20 | """ 21 | Verify whether the specified `filename` is supported by the current reader. 22 | This method should return True if the reader is able to read the format suggested by the 23 | `filename`. 24 | 25 | Args: 26 | filename: file name or a list of file names to read. 27 | if a list of files, verify all the suffixes. 28 | 29 | """ 30 | return is_supported_format(filename, ["tif", "tiff"]) 31 | 32 | def read(self, data: Sequence[PathLike] | PathLike, **kwargs) -> Sequence[Any] | Any: 33 | """ 34 | Read image data from specified file or files. 35 | Note that it returns a data object or a sequence of data objects. 36 | 37 | Args: 38 | data: file name or a list of file names to read. 39 | kwargs: additional args for actual `read` API of 3rd party libs. 40 | 41 | """ 42 | img_: list[TiffFile] = [] 43 | 44 | filenames: Sequence[PathLike] = ensure_tuple(data) 45 | kwargs_ = self.kwargs.copy() 46 | kwargs_.update(kwargs) 47 | for name in filenames: 48 | img_.append(TiffFile(name, **kwargs_)) # type: ignore 49 | return img_ if len(filenames) > 1 else img_[0] 50 | 51 | def get_data(self, img) -> tuple[np.ndarray, dict]: 52 | """ 53 | Extract data array and metadata from loaded image and return them. 54 | This function must return two objects, the first is a numpy array of image data, 55 | the second is a dictionary of metadata. 56 | 57 | Args: 58 | img: an image object loaded from an image file or a list of image objects. 59 | 60 | """ 61 | img_array: list[np.ndarray] = [] 62 | meta_dict = {"original_channel_dim": float("nan")} 63 | 64 | for curr_img in ensure_tuple(img): 65 | curr_arr = curr_img.asarray() 66 | assert len(curr_arr.shape) == 2 67 | img_array.append(curr_arr) 68 | curr_img.close() 69 | # not sure if necessary, but TifFile is used as a context manager in the examples 70 | if len(img_array) == 1: 71 | img_array = img_array[0] 72 | else: 73 | img_array = np.stack(img_array, axis=0) 74 | meta_dict["original_channel_dim"] = 0 75 | return img_array, meta_dict 76 | -------------------------------------------------------------------------------- /src/segmentation_failures/data/datamodules/dummy_modules.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.utils.data import DataLoader, Dataset, default_collate 4 | 5 | from segmentation_failures.data.datamodules.nnunet_module import NNunetDataModule 6 | 7 | 8 | class DummyDataset(Dataset): 9 | def __init__(self, data, target_data, spacing=(1, 1, 1)): 10 | self.data = data 11 | self.target_data = target_data 12 | self.spacing = spacing 13 | 14 | def __len__(self): 15 | return len(self.data) 16 | 17 | def __getitem__(self, idx): 18 | img_size = self.data[idx].shape[1:] 19 | return { 20 | "data": self.data[idx], 21 | "target": self.target_data[idx], 22 | "keys": f"sample_{idx}", 23 | "properties": [ 24 | { 25 | "spacing": np.array(self.spacing), 26 | "shape_after_cropping_and_before_resampling": img_size, 27 | "shape_before_cropping": img_size, 28 | "bbox_used_for_cropping": [[i, img_size[i]] for i in range(len(img_size))], 29 | } 30 | ], 31 | } 32 | 33 | 34 | class DummyNNunetDataModule(NNunetDataModule): 35 | # NOTE: the batch collation by pytorch converts properties to tensors, which causes problems with some callbacks; 36 | # currently need to disable them when using this module 37 | def __init__( 38 | self, 39 | dummy_num_samples: int, 40 | dummy_num_channels: int, 41 | dummy_img_size: list[int], 42 | dummy_batch_size: int, 43 | **kwargs, 44 | ): 45 | super().__init__(**kwargs) 46 | self.dummy_batch_size = dummy_batch_size 47 | xs = torch.randn(dummy_num_samples, dummy_num_channels, *dummy_img_size) 48 | if self.nnunet_trainer.label_manager.has_regions: 49 | num_classes = len(self.nnunet_trainer.label_manager.foreground_regions) 50 | ys = torch.randint(0, 2, size=(dummy_num_samples, num_classes, *dummy_img_size)) 51 | else: 52 | num_classes = len(self.nnunet_trainer.label_manager.foreground_labels) 53 | ys = torch.randint(0, num_classes, size=(dummy_num_samples, 1, *dummy_img_size)) 54 | self.dummy_dataset = DummyDataset(xs, ys, spacing=self.preprocess_info["spacing"]) 55 | 56 | def prepare_data(self): 57 | pass 58 | 59 | def setup(self, stage=None): 60 | pass 61 | 62 | def train_dataloader(self): 63 | return DataLoader(self.dummy_dataset, batch_size=self.dummy_batch_size) 64 | 65 | def val_dataloader(self): 66 | return DataLoader(self.dummy_dataset, batch_size=self.dummy_batch_size) 67 | 68 | def test_dataloader(self): 69 | return DataLoader(self.dummy_dataset, batch_size=self.dummy_batch_size) 70 | -------------------------------------------------------------------------------- /src/segmentation_failures/data/dataset_conversion/nnunet_fets22.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | from pathlib import Path 5 | 6 | import pandas as pd 7 | import SimpleITK as sitk 8 | from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json 9 | from sklearn.model_selection import train_test_split 10 | from tqdm import tqdm 11 | 12 | from segmentation_failures.utils.io import save_json 13 | 14 | TASK_NAME = "Dataset502_FeTS22" 15 | 16 | 17 | def convert_labels_to_nnunet( 18 | orig_label_file: Path, target_label_file: Path, mapping: dict[int, int] 19 | ): 20 | # convert labels from brats format to nnunet format 21 | # nnUNet has a different label convention than BraTS; convert back here 22 | seg_sitk = sitk.ReadImage(str(orig_label_file)) 23 | seg = sitk.GetArrayFromImage(seg_sitk) 24 | new_seg = seg.copy() 25 | for orig_label, new_label in mapping.items(): 26 | new_seg[seg == orig_label] = new_label 27 | new_seg_sitk = sitk.GetImageFromArray(new_seg) 28 | new_seg_sitk.CopyInformation(seg_sitk) 29 | sitk.WriteImage(new_seg_sitk, str(target_label_file)) 30 | 31 | 32 | def default_split(split_csv: str, seed: int = 0, test_size_inst1=0.2): 33 | split_df = pd.read_csv(split_csv) 34 | all_cases = split_df.Subject_ID.tolist() 35 | inst1_cases = split_df.loc[split_df["Partition_ID"] == 1, "Subject_ID"].tolist() 36 | train_cases, _ = train_test_split(inst1_cases, test_size=test_size_inst1, random_state=seed) 37 | test_cases = list(set(all_cases).difference(train_cases)) 38 | domain_dict = split_df.set_index("Subject_ID")["Partition_ID"].to_dict() 39 | return train_cases, test_cases, domain_dict 40 | 41 | 42 | def copy_case( 43 | case_dir: Path, 44 | image_target_dir: Path, 45 | label_target_dir: Path, 46 | modalities: dict[int, str], 47 | seg_suffix: str = "seg", 48 | ): 49 | assert case_dir.is_dir() 50 | case_id = case_dir.name 51 | for modality_id, modality_name in modalities.items(): 52 | image_file = case_dir / f"{case_id}_{modality_name}.nii.gz" 53 | shutil.copy(image_file, image_target_dir / f"{case_id}_{modality_id:04d}.nii.gz") 54 | label_file = case_dir / f"{case_id}_{seg_suffix}.nii.gz" 55 | shutil.copy(label_file, label_target_dir / f"{case_id}.nii.gz") 56 | 57 | 58 | def main(): 59 | parser = argparse.ArgumentParser() 60 | parser.add_argument("raw_data_dir", type=str, help="Path to raw data directory") 61 | args = parser.parse_args() 62 | seed = 42 63 | MODALITIES = { 64 | 0: "t1", 65 | 1: "t1ce", 66 | 2: "t2", 67 | 3: "flair", 68 | } 69 | perc_id_test_cases = 0.15 70 | source_dir = Path(args.raw_data_dir) 71 | target_root_dir = Path(os.environ["nnUNet_raw"]) / TASK_NAME 72 | target_root_dir.mkdir() 73 | 74 | images_train_dir = target_root_dir / "imagesTr" 75 | images_test_dir = target_root_dir / "imagesTs" 76 | labels_train_dir = target_root_dir / "labelsTr" 77 | labels_test_dir = target_root_dir / "labelsTs" 78 | 79 | # split cases into train/test 80 | # default split: Use the partition1.csv from the data directory. 81 | # - site 1 (largest) is split into training/testing 82 | # - other sites are 100% testing. 83 | train_cases, test_cases, domain_mapping = default_split( 84 | source_dir / "partitioning_1.csv", 85 | seed=seed, 86 | test_size_inst1=perc_id_test_cases, 87 | ) 88 | # copy all images and labels to correct locations 89 | images_train_dir.mkdir() 90 | labels_train_dir.mkdir() 91 | images_test_dir.mkdir() 92 | labels_test_dir.mkdir() 93 | print("Copying cases...") 94 | for case_id in tqdm(train_cases): 95 | copy_case(source_dir / case_id, images_train_dir, labels_train_dir, MODALITIES) 96 | for case_id in tqdm(test_cases): 97 | copy_case(source_dir / case_id, images_test_dir, labels_test_dir, MODALITIES) 98 | # save only domains of test cases 99 | domain_mapping = {k: v for k, v in domain_mapping.items() if k in test_cases} 100 | save_json(domain_mapping, target_root_dir / "domain_mapping_00.json") 101 | 102 | # map labels to nnunet format 103 | print("Converting labels to nnUNet format...") 104 | all_label_files = list(labels_train_dir.glob("*.nii.gz")) + list( 105 | labels_test_dir.glob("*.nii.gz") 106 | ) 107 | for label_file in tqdm(all_label_files): 108 | convert_labels_to_nnunet( 109 | label_file, 110 | label_file, 111 | mapping={ 112 | 1: 2, # necrosis 113 | 2: 1, # edema 114 | 4: 3, # enhancing 115 | }, 116 | ) 117 | generate_dataset_json( 118 | output_folder=str(target_root_dir), 119 | channel_names=MODALITIES, 120 | labels={ 121 | "background": 0, 122 | "whole_tumor": [1, 2, 3], 123 | "tumor_core": [2, 3], 124 | "enhancing_tumor": 3, 125 | }, 126 | regions_class_order=[1, 2, 3], 127 | num_training_cases=len(train_cases), 128 | file_ending=".nii.gz", 129 | dataset_name=TASK_NAME, 130 | dim=3, 131 | ) 132 | 133 | 134 | if __name__ == "__main__": 135 | main() 136 | # no need for special train-val splits here 137 | -------------------------------------------------------------------------------- /src/segmentation_failures/data/dataset_conversion/nnunet_kits23.py: -------------------------------------------------------------------------------- 1 | """ 2 | Standard nnunet conversion: 3 | - Copy all images to the training directory 4 | - Convert naming 5 | - splits are generated automatically by nnunet 6 | """ 7 | 8 | import argparse 9 | import os 10 | import random 11 | import shutil 12 | from pathlib import Path 13 | 14 | from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json 15 | from tqdm import tqdm 16 | 17 | from segmentation_failures.utils.io import load_json, save_json 18 | 19 | TASK_NAME = "Dataset515_KiTS23" 20 | 21 | 22 | def copy_case( 23 | case_dir: Path, 24 | image_target_dir: Path, 25 | label_target_dir: Path, 26 | ): 27 | assert case_dir.is_dir() 28 | case_id = case_dir.name 29 | shutil.copy(case_dir / "imaging.nii.gz", image_target_dir / f"{case_id}_0000.nii.gz") 30 | shutil.copy(case_dir / "segmentation.nii.gz", label_target_dir / f"{case_id}.nii.gz") 31 | 32 | 33 | def main(): 34 | parser = argparse.ArgumentParser() 35 | parser.add_argument("raw_data_dir", type=str, help="Path to raw data directory") 36 | parser.add_argument( 37 | "--use_default_splits", action="store_true", help="Use default train/test split." 38 | ) 39 | args = parser.parse_args() 40 | source_dir = Path(args.raw_data_dir) 41 | target_root_dir = Path(os.environ["nnUNet_raw"]) / TASK_NAME 42 | target_root_dir.mkdir() 43 | default_split_path = None 44 | if args.use_default_splits: 45 | default_split_path = ( 46 | Path(__file__).resolve().parents[4] 47 | / "dataset_splits" 48 | / TASK_NAME 49 | / "splits_final.json" 50 | ) 51 | images_train_dir = target_root_dir / "imagesTr" 52 | images_test_dir = target_root_dir / "imagesTs" 53 | labels_train_dir = target_root_dir / "labelsTr" 54 | labels_test_dir = target_root_dir / "labelsTs" 55 | 56 | # copy all images and labels to correct locations 57 | images_train_dir.mkdir() 58 | labels_train_dir.mkdir() 59 | images_test_dir.mkdir() 60 | labels_test_dir.mkdir() 61 | print("Copying cases...") 62 | # random split into train and test 63 | case_ids = [x.name for x in source_dir.iterdir() if x.is_dir()] 64 | if default_split_path is not None: 65 | all_splits = load_json(default_split_path) 66 | train_cases = all_splits[0]["train"] + all_splits[0]["val"] 67 | test_cases = [x for x in case_ids if x not in train_cases] 68 | else: 69 | random.seed(420000) 70 | test_cases = random.sample(case_ids, k=int(len(case_ids) * 0.25)) 71 | train_cases = list(set(case_ids) - set(test_cases)) 72 | for case in tqdm(train_cases): 73 | case_dir = source_dir / case 74 | if case_dir.is_dir(): 75 | copy_case(case_dir, images_train_dir, labels_train_dir) 76 | for case in tqdm(test_cases): 77 | case_dir = source_dir / case 78 | if case_dir.is_dir(): 79 | copy_case(case_dir, images_test_dir, labels_test_dir) 80 | # save only domains of test cases 81 | domain_mapping = {p: "ID" for p in test_cases} 82 | save_json(domain_mapping, target_root_dir / "domain_mapping_00.json") 83 | 84 | generate_dataset_json( 85 | output_folder=str(target_root_dir), 86 | channel_names={0: "CT"}, 87 | labels={"background": 0, "kidney_and_masses": (1, 2, 3), "masses": (2, 3), "tumor": 2}, 88 | num_training_cases=len(train_cases), 89 | file_ending=".nii.gz", 90 | regions_class_order=(1, 3, 2), 91 | # order has to be this because tumor is the last region (== label 2) 92 | dataset_name=TASK_NAME, 93 | overwrite_image_reader_writer="NibabelIOWithReorient", 94 | description="KiTS2023", 95 | dim=3, 96 | ) 97 | 98 | 99 | if __name__ == "__main__": 100 | main() 101 | # no need for special train-val splits here 102 | -------------------------------------------------------------------------------- /src/segmentation_failures/data/dataset_conversion/nnunet_mvseg23.py: -------------------------------------------------------------------------------- 1 | """ 2 | Standard nnunet conversion: 3 | - Copy all images to the training directory 4 | - Convert naming 5 | - splits are generated automatically by nnunet 6 | """ 7 | 8 | import argparse 9 | import os 10 | import random 11 | import shutil 12 | from pathlib import Path 13 | 14 | from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json 15 | from tqdm import tqdm 16 | 17 | from segmentation_failures.utils.io import load_json, save_json 18 | 19 | TASK_NAME = "Dataset514_MVSeg23" 20 | 21 | 22 | def main(): 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument("raw_data_dir", type=str, help="Path to raw data directory") 25 | parser.add_argument( 26 | "--use_default_splits", action="store_true", help="Use default train/test split." 27 | ) 28 | args = parser.parse_args() 29 | source_dir = Path(args.raw_data_dir) 30 | target_root_dir = Path(os.environ["nnUNet_raw"]) / TASK_NAME 31 | target_root_dir.mkdir() 32 | default_split_path = None 33 | if args.use_default_splits: 34 | default_split_path = ( 35 | Path(__file__).resolve().parents[4] 36 | / "dataset_splits" 37 | / TASK_NAME 38 | / "splits_final.json" 39 | ) 40 | images_train_dir = target_root_dir / "imagesTr" 41 | images_test_dir = target_root_dir / "imagesTs" 42 | labels_train_dir = target_root_dir / "labelsTr" 43 | labels_test_dir = target_root_dir / "labelsTs" 44 | 45 | # copy all images and labels to correct locations 46 | images_train_dir.mkdir() 47 | labels_train_dir.mkdir() 48 | images_test_dir.mkdir() 49 | labels_test_dir.mkdir() 50 | # split into train and test 51 | case_dict = {} 52 | for lab_file in (source_dir / "val").glob("*-label.nii.gz"): 53 | case_id = lab_file.name.removesuffix("-label.nii.gz") 54 | img_file = source_dir / "val" / f"{case_id}-US.nii.gz" 55 | assert case_id not in case_dict 56 | case_dict[case_id] = {"img": img_file, "lab": lab_file} 57 | for lab_file in (source_dir / "train").glob("*-label.nii.gz"): 58 | case_id = lab_file.name.removesuffix("-label.nii.gz") 59 | img_file = source_dir / "train" / f"{case_id}-US.nii.gz" 60 | assert case_id not in case_dict 61 | case_dict[case_id] = {"img": img_file, "lab": lab_file} 62 | if default_split_path is not None: 63 | all_splits = load_json(default_split_path) 64 | train_cases = all_splits[0]["train"] + all_splits[0]["val"] 65 | test_cases = [x for x in case_dict if x not in train_cases] 66 | else: 67 | random.seed(420000) 68 | test_cases = random.sample(list(case_dict.keys()), k=55) 69 | train_cases = list(set(case_dict.keys()) - set(test_cases)) 70 | # copying 71 | print(f"Copying {len(train_cases)} training cases and {len(test_cases)} test cases...") 72 | for case_id in tqdm(train_cases): 73 | shutil.copy(case_dict[case_id]["img"], images_train_dir / f"{case_id}_0000.nii.gz") 74 | shutil.copy(case_dict[case_id]["lab"], labels_train_dir / f"{case_id}.nii.gz") 75 | for case_id in tqdm(test_cases): 76 | shutil.copy(case_dict[case_id]["img"], images_test_dir / f"{case_id}_0000.nii.gz") 77 | shutil.copy(case_dict[case_id]["lab"], labels_test_dir / f"{case_id}.nii.gz") 78 | # save only domains of test cases 79 | domain_mapping = {p: "ID" for p in test_cases} 80 | save_json(domain_mapping, target_root_dir / "domain_mapping_00.json") 81 | 82 | generate_dataset_json( 83 | output_folder=str(target_root_dir), 84 | channel_names={0: "3D-US"}, 85 | labels={ 86 | "background": 0, 87 | "Posterior leaflet": 1, 88 | "Anterior leaflet": 2, 89 | }, 90 | num_training_cases=len(train_cases), 91 | file_ending=".nii.gz", 92 | # order has to be this because tumor is the last region (== label 2) 93 | dataset_name=TASK_NAME, 94 | description="MVSeg2023", 95 | dim=3, 96 | ) 97 | 98 | 99 | if __name__ == "__main__": 100 | main() 101 | # no need for special train-val splits here 102 | -------------------------------------------------------------------------------- /src/segmentation_failures/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from .experiment_data import ExperimentData 2 | from .failure_detection.fd_analysis import evaluate_failures 3 | from .ood_detection.ood_analysis import evaluate_ood 4 | -------------------------------------------------------------------------------- /src/segmentation_failures/evaluation/failure_detection/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MIC-DKFZ/segmentation_failures_benchmark/a1af98be0f93c2bdc30ffe5bb6eda531c485d87d/src/segmentation_failures/evaluation/failure_detection/__init__.py -------------------------------------------------------------------------------- /src/segmentation_failures/evaluation/ood_detection/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MIC-DKFZ/segmentation_failures_benchmark/a1af98be0f93c2bdc30ffe5bb6eda531c485d87d/src/segmentation_failures/evaluation/ood_detection/__init__.py -------------------------------------------------------------------------------- /src/segmentation_failures/evaluation/ood_detection/metrics.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/IML-DKFZ/fd-shifts/tree/main 2 | from __future__ import annotations 3 | 4 | from dataclasses import dataclass 5 | from functools import cached_property 6 | from typing import Any, Callable, TypeVar, cast 7 | 8 | import numpy as np 9 | import numpy.typing as npt 10 | from sklearn import metrics as skm 11 | from typing_extensions import ParamSpec 12 | 13 | _metric_funcs = {} 14 | 15 | T = TypeVar("T") 16 | P = ParamSpec("P") 17 | 18 | 19 | def may_raise_sklearn_exception(func: Callable[P, T]) -> Callable[P, T]: 20 | def _inner_wrapper(*args: P.args, **kwargs: P.kwargs) -> T: 21 | try: 22 | return func(*args, **kwargs) 23 | except ValueError: 24 | return cast(T, np.nan) 25 | 26 | return _inner_wrapper 27 | 28 | 29 | @dataclass 30 | class StatsCache: 31 | """Cache for stats computed by scikit used by multiple metrics. 32 | 33 | Attributes: 34 | confids (array_like): Confidence values associated with the predictions 35 | risks (array_like): Risk values associated with the predictions 36 | """ 37 | 38 | scores: npt.NDArray[Any] 39 | ood_labels: npt.NDArray[Any] 40 | 41 | @cached_property 42 | def roc_curve_stats(self) -> tuple[npt.NDArray[Any], npt.NDArray[Any]]: 43 | fpr, tpr, _ = skm.roc_curve(self.ood_labels, self.scores, pos_label=1) 44 | return fpr, tpr 45 | 46 | # maybe add PR curve stats here 47 | 48 | 49 | def register_metric_func(name: str) -> Callable: 50 | def _inner_wrapper(func: Callable) -> Callable: 51 | _metric_funcs[name] = func 52 | return func 53 | 54 | return _inner_wrapper 55 | 56 | 57 | def get_metric_function(metric_name: str) -> Callable[[StatsCache], float]: 58 | return _metric_funcs[metric_name] 59 | 60 | 61 | @register_metric_func("ood_auc") 62 | @may_raise_sklearn_exception 63 | def failauc(stats_cache: StatsCache) -> float: 64 | fpr, tpr = stats_cache.roc_curve_stats 65 | return skm.auc(fpr, tpr) 66 | 67 | 68 | @register_metric_func("ood_fpr@95tpr") 69 | @may_raise_sklearn_exception 70 | def fpr_at_95_tpr(stats_cache: StatsCache) -> float: 71 | fpr, tpr = stats_cache.roc_curve_stats 72 | return np.min(fpr[np.argwhere(tpr >= 0.9495)]) 73 | 74 | 75 | @register_metric_func("ood_detection_error@95tpr") 76 | @may_raise_sklearn_exception 77 | def deterror_at_95_tpr(stats_cache: StatsCache) -> float: 78 | fpr, tpr = stats_cache.roc_curve_stats 79 | tpr_mask = np.argwhere(tpr >= 0.9495) 80 | fpr95 = np.min(fpr[tpr_mask]) 81 | tpr95 = np.min(tpr[tpr_mask]) 82 | return 0.5 * (1 - tpr95 + fpr95) 83 | -------------------------------------------------------------------------------- /src/segmentation_failures/evaluation/ood_detection/ood_analysis.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Dict, List 3 | 4 | import numpy as np 5 | import pandas as pd 6 | from loguru import logger 7 | from omegaconf import OmegaConf 8 | 9 | import segmentation_failures.evaluation.ood_detection.metrics as ood_metrics 10 | from segmentation_failures.evaluation.experiment_data import ExperimentData 11 | 12 | 13 | def compute_ood_scores( 14 | confid_arr: np.ndarray, 15 | ood_labels: np.ndarray, 16 | query_metrics: List[str], 17 | ) -> Dict[str, float]: 18 | """Based on confidence values and OOD labels, compute some OOD detection metrics. 19 | 20 | Args: 21 | confid_arr (np.ndarray): 1D-Array with confidences. 22 | ood_labels (np.ndarray): 1D-Array with OOD-labels (binary). 23 | query_metrics (List[str]): List of OOD metrics to compute for the data. 24 | 25 | Returns: 26 | Dict[str, float]: Dictionary of scalar OOD metrics. 27 | """ 28 | # arrays should be 1D and have the same length 29 | assert len(confid_arr.shape) == len(ood_labels.shape) == 1 30 | assert len(confid_arr) == len(ood_labels) 31 | ood_scores: Dict[str, float] = {} 32 | if np.any(np.isnan(confid_arr)): 33 | logger.warning("NaN values in confidence scores. Inserting NaN in metrics.") 34 | for score in query_metrics: 35 | ood_scores[score] = np.nan 36 | return ood_scores 37 | stats = ood_metrics.StatsCache( 38 | scores=-confid_arr, # higher confidence -> lower ood score 39 | ood_labels=ood_labels, 40 | ) 41 | # scores 42 | for score in query_metrics: 43 | score_fn = ood_metrics.get_metric_function(score) 44 | ood_scores[score] = score_fn(stats) 45 | # curves: maybe later 46 | return ood_scores 47 | 48 | 49 | def evaluate_ood(expt_data: ExperimentData, output_dir: Path, config: OmegaConf): 50 | # this should compute different FD-metrics and save them as a dataframe to the output_dir 51 | # I just compute one risk for every segmentation metric present in the dataframe. 52 | id_domains = config.id_domain 53 | if id_domains is None: 54 | logger.warning("No ID domain specified. Skipping OOD analysis.") 55 | return 56 | if isinstance(id_domains, str): 57 | id_domains = [id_domains] 58 | domains = np.unique(expt_data.domain_names).tolist() 59 | if set(domains) == set(id_domains): 60 | logger.warning("All domains are ID domains. Skipping OOD analysis.") 61 | return 62 | domains.append("all_ood_") # also evaluate on all ood domains together 63 | output_dir.mkdir(exist_ok=True) 64 | output_file = output_dir / "ood_metrics.csv" 65 | if output_file.exists(): 66 | # get an alternative file name 67 | i = 1 68 | while output_file.exists(): 69 | output_file = output_dir / f"ood_metrics_{i}.csv" 70 | i += 1 71 | logger.warning( 72 | f"Output file {output_dir / 'ood_metrics.csv'} already exists. Saving to {output_file} instead." 73 | ) 74 | 75 | # Also compute OOD scores 76 | analysis_results = [] 77 | ood_metrics = config.ood_metrics 78 | if len(ood_metrics) == 0: 79 | return 80 | if not set(id_domains).issubset(domains): 81 | logger.warning( 82 | f"ID domain(s) {id_domains} not found in experiment data. Maybe it is misconfigured?" 83 | ) 84 | for curr_domain in domains: 85 | if curr_domain in id_domains: 86 | continue 87 | for confid_idx, confid_name in enumerate(expt_data.confid_scores_names): 88 | id_mask = np.isin(np.array(expt_data.domain_names), id_domains) 89 | if curr_domain == "all_ood_": 90 | ood_mask = np.logical_not(id_mask) 91 | else: 92 | ood_mask = np.array(expt_data.domain_names) == curr_domain 93 | testset_mask = np.logical_or(ood_mask, id_mask) 94 | subset_confid = expt_data.confid_scores[testset_mask, confid_idx] 95 | subset_labels = ood_mask[testset_mask].astype(int) 96 | 97 | scores = compute_ood_scores( 98 | confid_arr=subset_confid, 99 | ood_labels=subset_labels, 100 | query_metrics=ood_metrics, 101 | ) 102 | result_row = { 103 | "confid_name": confid_name, 104 | "domain": curr_domain, 105 | "n_cases_id": np.sum(id_mask), 106 | "n_cases_ood": np.sum(ood_mask), 107 | } 108 | result_row.update(scores) 109 | analysis_results.append(result_row) 110 | 111 | pd.DataFrame(analysis_results).to_csv(output_file) 112 | -------------------------------------------------------------------------------- /src/segmentation_failures/evaluation/segmentation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MIC-DKFZ/segmentation_failures_benchmark/a1af98be0f93c2bdc30ffe5bb6eda531c485d87d/src/segmentation_failures/evaluation/segmentation/__init__.py -------------------------------------------------------------------------------- /src/segmentation_failures/evaluation/segmentation/custom_metrics/hausdorff.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union 2 | 3 | import numpy as np 4 | import torch 5 | from monai.metrics import HausdorffDistanceMetric as MonaiHDMetric 6 | from monai.metrics.utils import ( 7 | get_mask_edges, 8 | get_surface_distance, 9 | ignore_background, 10 | is_binary_tensor, 11 | ) 12 | 13 | 14 | class HausdorffDistanceMetric(MonaiHDMetric): 15 | """ 16 | The only difference to MONAI's implementation is that I set fixed values for the empty GT/pred cases. 17 | """ 18 | 19 | def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor): # type: ignore 20 | """ 21 | Args: 22 | y_pred: input data to compute, typical segmentation model output. 23 | It must be one-hot format and first dim is batch, example shape: [16, 3, 32, 32]. The values 24 | should be binarized. 25 | y: ground truth to compute the distance. It must be one-hot format and first dim is batch. 26 | The values should be binarized. 27 | 28 | Raises: 29 | ValueError: when `y` is not a binarized tensor. 30 | ValueError: when `y_pred` has less than three dimensions. 31 | """ 32 | is_binary_tensor(y_pred, "y_pred") 33 | is_binary_tensor(y, "y") 34 | 35 | dims = y_pred.ndimension() 36 | if dims < 3: 37 | raise ValueError("y_pred should have at least three dimensions.") 38 | # compute (BxC) for each channel for each batch 39 | return compute_hausdorff_distance( 40 | y_pred=y_pred, 41 | y=y, 42 | include_background=self.include_background, 43 | distance_metric=self.distance_metric, 44 | percentile=self.percentile, 45 | directed=self.directed, 46 | ) 47 | 48 | 49 | def compute_hausdorff_distance( 50 | y_pred: Union[np.ndarray, torch.Tensor], 51 | y: Union[np.ndarray, torch.Tensor], 52 | include_background: bool = False, 53 | distance_metric: str = "euclidean", 54 | percentile: Optional[float] = None, 55 | directed: bool = False, 56 | ): 57 | """ 58 | Compute the Hausdorff distance. 59 | 60 | Args: 61 | y_pred: input data to compute, typical segmentation model output. 62 | It must be one-hot format and first dim is batch, example shape: [16, 3, 32, 32]. The values 63 | should be binarized. 64 | y: ground truth to compute mean the distance. It must be one-hot format and first dim is batch. 65 | The values should be binarized. 66 | include_background: whether to skip distance computation on the first channel of 67 | the predicted output. Defaults to ``False``. 68 | distance_metric: : [``"euclidean"``, ``"chessboard"``, ``"taxicab"``] 69 | the metric used to compute surface distance. Defaults to ``"euclidean"``. 70 | percentile: an optional float number between 0 and 100. If specified, the corresponding 71 | percentile of the Hausdorff Distance rather than the maximum result will be achieved. 72 | Defaults to ``None``. 73 | directed: whether to calculate directed Hausdorff distance. Defaults to ``False``. 74 | """ 75 | 76 | if not include_background: 77 | y_pred, y = ignore_background(y_pred=y_pred, y=y) 78 | if isinstance(y, torch.Tensor): 79 | y = y.float() 80 | if isinstance(y_pred, torch.Tensor): 81 | y_pred = y_pred.float() 82 | 83 | if y.shape != y_pred.shape: 84 | raise ValueError( 85 | f"y_pred and y should have same shapes, got {y_pred.shape} and {y.shape}." 86 | ) 87 | 88 | batch_size, n_class = y_pred.shape[:2] 89 | hd = np.empty((batch_size, n_class)) 90 | # This is arbitrary: I use half the diagonal as the worst HD value. 91 | WORST_VAL = 0.5 * np.linalg.norm(y_pred.shape[2:]) 92 | for b, c in np.ndindex(batch_size, n_class): 93 | (edges_pred, edges_gt) = get_mask_edges(y_pred[b, c], y[b, c]) 94 | if not np.any(edges_gt) and not np.any(edges_pred): 95 | # empty gt and empty pred => perfect score 0 96 | hd[b, c] = 0 97 | elif not np.any(edges_gt) or not np.any(edges_pred): 98 | # one of them is zero => worst score. 99 | hd[b, c] = WORST_VAL 100 | else: 101 | distance_1 = compute_percent_hausdorff_distance( 102 | edges_pred, edges_gt, distance_metric, percentile 103 | ) 104 | if directed: 105 | hd[b, c] = distance_1 106 | else: 107 | distance_2 = compute_percent_hausdorff_distance( 108 | edges_gt, edges_pred, distance_metric, percentile 109 | ) 110 | hd[b, c] = max(distance_1, distance_2) 111 | return torch.from_numpy(hd) 112 | 113 | 114 | def compute_percent_hausdorff_distance( 115 | edges_pred: np.ndarray, 116 | edges_gt: np.ndarray, 117 | distance_metric: str = "euclidean", 118 | percentile: Optional[float] = None, 119 | ): 120 | """ 121 | This function is used to compute the directed Hausdorff distance. 122 | """ 123 | 124 | surface_distance = get_surface_distance(edges_pred, edges_gt, distance_metric=distance_metric) 125 | 126 | # for both pred and gt do not have foreground 127 | if surface_distance.shape == (0,): 128 | return np.nan 129 | 130 | if not percentile: 131 | return surface_distance.max() 132 | 133 | if 0 <= percentile <= 100: 134 | return np.percentile(surface_distance, percentile) 135 | raise ValueError(f"percentile should be a value between 0 and 100, get {percentile}.") 136 | -------------------------------------------------------------------------------- /src/segmentation_failures/evaluation/segmentation/custom_metrics/surface_distance.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from monai.metrics import HausdorffDistanceMetric, SurfaceDiceMetric 4 | 5 | 6 | class SurfaceDiceEmptyHandlingMetric(SurfaceDiceMetric): 7 | # I think MONAI's class_thresholds argument documentation is wrong; it has units of mm (same as spacing) 8 | def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs) -> torch.Tensor: 9 | # tensor shapes BCHW[D] 10 | result_sd = super()._compute_tensor(y_pred, y, **kwargs) 11 | # handle empty GT case 12 | empty_gt = y.sum(dim=list(range(2, y_pred.ndim))) == 0 13 | empty_pred = y_pred.sum(dim=list(range(2, y_pred.ndim))) == 0 14 | if not self.include_background: 15 | # remove the first class 16 | empty_gt = empty_gt[:, 1:] 17 | empty_pred = empty_pred[:, 1:] 18 | empty_both = torch.logical_and(empty_gt, empty_pred) 19 | empty_gt_only = torch.logical_and(empty_gt, torch.logical_not(empty_pred)) 20 | result_sd[empty_both] = 1.0 21 | result_sd[empty_gt_only] = 0.0 22 | return result_sd 23 | 24 | 25 | # same for hausdorff distance 26 | class HausdorffDistanceEmptyHandlingMetric(HausdorffDistanceMetric): 27 | def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs) -> torch.Tensor: 28 | # tensor shapes BCHW[D] 29 | result_hd = super()._compute_tensor(y_pred, y, **kwargs) 30 | # handle empty GT case 31 | empty_gt = y.sum(dim=list(range(2, y_pred.ndim))) == 0 32 | empty_pred = y_pred.sum(dim=list(range(2, y_pred.ndim))) == 0 33 | if not self.include_background: 34 | # remove the first class 35 | empty_gt = empty_gt[:, 1:] 36 | empty_pred = empty_pred[:, 1:] 37 | empty_both = torch.logical_and(empty_gt, empty_pred) 38 | empty_gt_only = torch.logical_and(empty_gt, torch.logical_not(empty_pred)) 39 | result_hd[empty_both] = 0.0 40 | result_hd[empty_gt_only] = 0.5 * np.linalg.norm(y_pred.shape[2:]) 41 | # This is arbitrary: I use half the diagonal as the worst HD value. 42 | return result_hd 43 | -------------------------------------------------------------------------------- /src/segmentation_failures/evaluation/segmentation/distance_thresholds.py: -------------------------------------------------------------------------------- 1 | # I simply chose them as roughly the smallest median spacing 2 | DISTANCE_THRESHOLDS = { 3 | # dataset_id: [thresholds] (one for each class) 4 | "500": [1.0], 5 | "503": [1.0, 1.0, 1.0], 6 | "510": [1.5, 1.5, 1.5], 7 | "511": [1.25, 1.25, 1.25], 8 | "514": [1.0, 1.0], 9 | "515": [1.0330772532390826, 1.1328796488598762, 1.1498198361434828], 10 | # from https://github.com/neheller/kits23/blob/main/kits23/configuration/labels.py 11 | "520": [1.0], 12 | "521": [1.0], 13 | "531": [2.0, 2.0], 14 | "540": [0.05, 0.01, 0.02], 15 | "541": [0.05, 0.01, 0.02], 16 | "542": [0.05, 0.01, 0.02], 17 | "560": [0.02, 0.02, 0.02, 0.02, 0.02], 18 | } 19 | 20 | 21 | def get_distance_thresholds(dataset_id: int | str): 22 | if isinstance(dataset_id, int): 23 | dataset_id = f"{dataset_id:03d}" 24 | return DISTANCE_THRESHOLDS[dataset_id] 25 | -------------------------------------------------------------------------------- /src/segmentation_failures/evaluation/segmentation/segmentation_metrics.py: -------------------------------------------------------------------------------- 1 | # from abc import ABC 2 | from dataclasses import dataclass 3 | from typing import Callable, Dict, Tuple 4 | 5 | import monai.metrics as mn_metrics 6 | import numpy as np 7 | 8 | from segmentation_failures.evaluation.segmentation.custom_metrics.surface_distance import ( 9 | HausdorffDistanceEmptyHandlingMetric, 10 | SurfaceDiceEmptyHandlingMetric, 11 | ) 12 | 13 | 14 | @dataclass 15 | class MetricsInfo: 16 | higher_better: bool = True 17 | min_value: float = 0.0 18 | max_value: float = 0.0 19 | classwise: bool = True 20 | 21 | 22 | _metric_factories = {} 23 | 24 | 25 | def register_metric(name: str) -> Callable: 26 | def _inner_wrapper(metric_factory: Callable) -> Callable: 27 | _metric_factories[name] = metric_factory 28 | return metric_factory 29 | 30 | return _inner_wrapper 31 | 32 | 33 | def get_metrics( 34 | metric_list: list | str | None = None, **metric_kwargs 35 | ) -> Dict[str, mn_metrics.CumulativeIterationMetric]: 36 | return get_metrics_and_info(metric_list, **metric_kwargs)[0] 37 | 38 | 39 | def get_metrics_info( 40 | metric_list: list | str | None = None, **metric_kwargs 41 | ) -> Dict[str, MetricsInfo]: 42 | return get_metrics_and_info(metric_list, **metric_kwargs)[1] 43 | 44 | 45 | def get_metrics_and_info( 46 | metric_list: list | str | None = None, **metric_kwargs 47 | ) -> Tuple[dict[str, mn_metrics.CumulativeIterationMetric], dict[str, MetricsInfo]]: 48 | if isinstance(metric_list, str): 49 | metric_list = [metric_list] 50 | if metric_list is None: 51 | # default: get all 52 | metric_list = list(_metric_factories.keys()) 53 | metrics = {} 54 | infos = {} 55 | for k in metric_list: 56 | metrics[k], infos[k] = _metric_factories[k](**metric_kwargs) 57 | return metrics, infos 58 | 59 | 60 | @register_metric("dice") 61 | def dice_score(include_background=False): 62 | # ignore_empty=False makes sure that cases with empty GT and pred receive a score of 1 63 | metric = mn_metrics.DiceMetric( 64 | include_background=include_background, reduction="none", ignore_empty=False 65 | ) 66 | info = MetricsInfo(True, 0, 1, True) 67 | return metric, info 68 | 69 | 70 | @register_metric("hausdorff95") 71 | def hd95_score(include_background=False): 72 | metric = HausdorffDistanceEmptyHandlingMetric( 73 | include_background=include_background, percentile=95, reduction="none" 74 | ) 75 | info = MetricsInfo(False, 0, np.inf, True) 76 | return metric, info 77 | 78 | 79 | @register_metric("generalized_dice") 80 | def gen_dice_score(include_background=False, weight_type="square"): 81 | metric = mn_metrics.GeneralizedDiceScore( 82 | include_background=include_background, 83 | reduction="none", 84 | weight_type=weight_type, 85 | ) 86 | info = MetricsInfo(True, 0, 1, False) 87 | return metric, info 88 | 89 | 90 | @register_metric("surface_dice") 91 | def surface_dice_score(include_background=False, class_thresholds=1.0): 92 | if not isinstance(class_thresholds, (list, tuple)): 93 | class_thresholds = [class_thresholds] 94 | metric = SurfaceDiceEmptyHandlingMetric( 95 | class_thresholds=class_thresholds, 96 | include_background=include_background, 97 | reduction="none", 98 | use_subvoxels=True, 99 | ) 100 | info = MetricsInfo(True, 0, 1, True) 101 | return metric, info 102 | -------------------------------------------------------------------------------- /src/segmentation_failures/experiments/cluster.py: -------------------------------------------------------------------------------- 1 | # import json 2 | import re 3 | import shlex 4 | import subprocess 5 | from pathlib import Path 6 | 7 | from pssh.clients import SSHClient 8 | from pssh.exceptions import Timeout 9 | from rich import print 10 | from rich.syntax import Syntax 11 | 12 | from segmentation_failures.experiments.experiment import Experiment 13 | 14 | BASH_BSUB_COMMAND = r""" 15 | bsub -gpu num=1:j_exclusive=yes:gmem={gmem}\ 16 | -L /bin/bash \ 17 | -R "select[hname!='e230-dgx2-1']" \ 18 | -q gpu \ 19 | -u 'm.zenk@dkfz-heidelberg.de' \ 20 | -N \ 21 | -J "{name}" \ 22 | -o $HOME/job_outputs/%J.out \ 23 | -g /m167k/default_limited \ 24 | bash -li -c 'set -o pipefail; echo $LSB_JOBID && source $HOME/.bashrc_segfail_new && {command}' 25 | """ 26 | # I use a .bashrc file here, which is easier on the cluster. dotenv won't override this 27 | 28 | BASH_BSUB_COMMAND_CPU = r""" 29 | bsub -q long \ 30 | -n 8 \ 31 | -R "rusage[mem=100G]" \ 32 | -L /bin/bash \ 33 | -u 'm.zenk@dkfz-heidelberg.de' \ 34 | -N \ 35 | -J "{name}" \ 36 | -o $HOME/job_outputs/%J.out \ 37 | bash -li -c 'set -o pipefail; echo $LSB_JOBID && source $HOME/.bashrc_segfail_new && {command}' 38 | """ 39 | 40 | BASH_BASE_COMMAND = r""" 41 | HYDRA_FULL_ERROR=1 python $HOME/rsynced_code/segfail_project_new/src/segmentation_failures/scripts/{task}.py {overwrites} 42 | """ 43 | # task can be: train_seg, train_image_csf, test_fd 44 | 45 | # Run before executing anything on the cluster 46 | # later, it may be better to switch to a git repository and also log the commit hash 47 | RSYNC_CODE_COMMAND = r""" 48 | rsync -rtvu --delete --stats -f'- __pycache__/' -f'+ src/***' -f'+ pyproject.toml' -f'- *' {source_dir}/ m167k@odcf-worker02.dkfz.de:/home/m167k/rsynced_code/segfail_project_new 49 | """ 50 | 51 | 52 | def submit( 53 | _experiments: list[Experiment], 54 | task: str, 55 | dry_run: bool, 56 | user_overwrites: dict, 57 | cpu=False, 58 | gmem="10.7G", 59 | ): 60 | if len(_experiments) == 0: 61 | print("Nothing to run") 62 | return 63 | 64 | rsync_cmd = RSYNC_CODE_COMMAND.format(source_dir=Path(__file__).parents[3].absolute()) 65 | print( 66 | Syntax( 67 | rsync_cmd.strip(), 68 | "bash", 69 | word_wrap=True, 70 | background_color="default", 71 | ) 72 | ) 73 | if dry_run: 74 | rsync_cmd = rsync_cmd.replace("rsync", "rsync -n", 1) 75 | subprocess.run(shlex.split(rsync_cmd), check=True) 76 | 77 | client = SSHClient("odcf-worker02.dkfz.de") 78 | for experiment in _experiments: 79 | # Compile overwrites. Precedence: cmdline > experiment > global 80 | final_overwrites = experiment.overwrites() 81 | final_overwrites["hydra"] = "cluster" # affects expt folder naming 82 | final_overwrites.update(user_overwrites) 83 | overwrites = " ".join([f"'{k}={v}'" for k, v in final_overwrites.items()]) 84 | cmd = BASH_BASE_COMMAND.format( 85 | task=task, 86 | overwrites=overwrites, 87 | ).strip() 88 | 89 | print( 90 | Syntax( 91 | re.sub(r"([^,]) ", "\\1 \\\n\t", cmd), 92 | "bash", 93 | word_wrap=True, 94 | background_color="default", 95 | ) 96 | ) 97 | 98 | if cpu: 99 | cmd = BASH_BSUB_COMMAND_CPU.format( 100 | name=f"{experiment.task} {experiment.dataset}_{experiment.name}", 101 | command=cmd, 102 | ).strip() 103 | else: 104 | cmd = BASH_BSUB_COMMAND.format( 105 | name=f"{experiment.task} {experiment.dataset}_{experiment.name}", 106 | command=cmd, 107 | gmem=gmem, 108 | ).strip() 109 | 110 | print( 111 | Syntax( 112 | cmd, 113 | "bash", 114 | word_wrap=True, 115 | background_color="default", 116 | ) 117 | ) 118 | 119 | if dry_run: 120 | return 121 | try: 122 | with client.open_shell(read_timeout=1) as shell: 123 | shell.run(cmd) 124 | 125 | try: 126 | for line in shell.stdout: 127 | print(line) 128 | except Timeout: 129 | pass 130 | 131 | try: 132 | for line in shell.stderr: 133 | print(line) 134 | except Timeout: 135 | pass 136 | except subprocess.CalledProcessError: 137 | continue 138 | try: 139 | for line in shell.stdout: 140 | print(line) 141 | except Timeout: 142 | pass 143 | 144 | try: 145 | for line in shell.stderr: 146 | print(line) 147 | except Timeout: 148 | pass 149 | -------------------------------------------------------------------------------- /src/segmentation_failures/experiments/experiments_revision_ds_size.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | maybe_cluster="--cluster" 4 | 5 | # ================================================================================================= 6 | # INVESTIGATING dataset size 7 | group=revision_datasize_2408 8 | dataset=mnms 9 | 10 | # # SEGMENTATION 11 | # base_epochs=2500 12 | # for fold in 5 10 15 20; do 13 | # for seed in 0; do 14 | # # MnMs 15 | # if (( fold >= 20 )); then 16 | # check_every=80 17 | # epochs=$(( base_epochs * 16 )) 18 | # elif (( fold >= 15 )); then 19 | # check_every=40 20 | # epochs=$(( base_epochs * 8 )) 21 | # elif (( fold >= 10 )); then 22 | # check_every=20 23 | # epochs=$(( base_epochs * 4 )) 24 | # elif (( fold >= 5 )); then 25 | # check_every=10 26 | # epochs=$(( base_epochs * 2 )) 27 | # else 28 | # epochs=$base_epochs 29 | # check_every=5 30 | # fi 31 | # echo "epochs: $epochs" 32 | # python launcher.py --task train_seg --dataset $dataset --fold $fold --seed $seed --backbone dynamic_unet_dropout --group $group $maybe_cluster \ 33 | # --overwrites trainer.check_val_every_n_epoch=$check_every trainer.max_epochs=$epochs 34 | # done 35 | # done 36 | 37 | 38 | # # ================================================================================================= 39 | # # CROSS-VALIDATION PIXEL CSF (prepare regression methods) 40 | 41 | # seed=0 42 | # for fold in {5..24}; do 43 | # python launcher.py --task validate_pixel_csf --dataset $dataset --fold $fold --seed $seed --backbone dynamic_unet_dropout --csf_pixel baseline \ 44 | # --group $group $maybe_cluster 45 | # done 46 | 47 | # # THEN 48 | # for fold in 5 10 15 20; do 49 | # python prepare_auxdata.py --expt_group $group --start_fold $fold $maybe_cluster 50 | # done 51 | 52 | 53 | 54 | # # ================================================================================================= 55 | # # STAGE 2 TRAINING 56 | 57 | # seed=0 58 | # base_epochs=1000 59 | # for fold in 10; do 60 | # # python launcher.py --task train_image_csf --dataset $dataset --fold $fold --seed $seed --backbone dynamic_unet_dropout --csf_image mahalanobis_gonzalez \ 61 | # # --group $group $maybe_cluster 62 | # if (( fold >= 20 )); then 63 | # check_every=80 64 | # epochs=$(( base_epochs * 16 )) 65 | # elif (( fold >= 15 )); then 66 | # check_every=40 67 | # epochs=$(( base_epochs * 8 )) 68 | # elif (( fold >= 10 )); then 69 | # check_every=20 70 | # epochs=$(( base_epochs * 4 )) 71 | # elif (( fold >= 5 )); then 72 | # check_every=10 73 | # epochs=$(( base_epochs * 2 )) 74 | # else 75 | # epochs=$base_epochs 76 | # check_every=5 77 | # fi 78 | # echo "epochs: $epochs" 79 | # python launcher.py --task train_image_csf --dataset $dataset --fold $fold --seed $seed --backbone dynamic_unet_dropout --csf_pixel baseline --csf_image quality_regression \ 80 | # --group $group $maybe_cluster --overwrites trainer.check_val_every_n_epoch=$check_every trainer.max_epochs=$epochs 81 | # # python launcher.py --task train_image_csf --dataset $dataset --fold $fold --seed $seed --backbone dynamic_unet_dropout --csf_pixel baseline --csf_aggregation predictive_entropy+heuristic \ 82 | # # --group $group $maybe_cluster --cpu 83 | # done 84 | 85 | 86 | # # ================================================================================================= 87 | # # INFERENCE PIXEL CSF 88 | # seed=0 89 | # for fold in 5 10 15 20; do 90 | # python launcher.py --task test_pixel_csf --dataset $dataset --fold $fold --seed $seed --backbone dynamic_unet_dropout --csf_pixel baseline \ 91 | # --group $group $maybe_cluster 92 | # # python launcher.py --task test_pixel_csf --dataset $dataset --fold $fold --seed $seed --backbone dynamic_unet_dropout --csf_pixel mcdropout \ 93 | # # --group $group $maybe_cluster 94 | # # python launcher.py --task test_pixel_csf --dataset $dataset --fold $fold --seed $seed --backbone dynamic_unet_dropout --csf_pixel deep_ensemble \ 95 | # # --group $group $maybe_cluster 96 | # done 97 | 98 | 99 | # # ================================================================================================= 100 | # # FAILURE DETECTION TESTING 101 | # seed=0 102 | # for fold in 5 10 15 20; do 103 | # python launcher.py --task test_fd --dataset $dataset --fold $fold --seed $seed --backbone dynamic_unet_dropout --csf_pixel baseline --csf_aggregation all_simple \ 104 | # --group $group $maybe_cluster --cpu 105 | # python launcher.py --task test_fd --dataset $dataset --fold $fold --seed $seed --backbone dynamic_unet_dropout --csf_pixel mcdropout --csf_aggregation all_simple \ 106 | # --group $group $maybe_cluster --cpu 107 | # python launcher.py --task test_fd --dataset $dataset --fold $fold --seed $seed --backbone dynamic_unet_dropout --csf_pixel deep_ensemble --csf_aggregation all_simple \ 108 | # --group $group $maybe_cluster --cpu 109 | # python launcher.py --task test_fd --dataset $dataset --fold $fold --seed $seed --backbone dynamic_unet_dropout --csf_image mahalanobis_gonzalez \ 110 | # --group $group $maybe_cluster 111 | # python launcher.py --task test_fd --dataset $dataset --fold $fold --seed $seed --backbone dynamic_unet_dropout --csf_pixel deep_ensemble --csf_image quality_regression \ 112 | # --group $group $maybe_cluster 113 | # python launcher.py --task test_fd --dataset $dataset --fold $fold --seed $seed --backbone dynamic_unet_dropout --csf_pixel baseline --csf_aggregation predictive_entropy+heuristic \ 114 | # --group $group $maybe_cluster --cpu 115 | # done 116 | -------------------------------------------------------------------------------- /src/segmentation_failures/experiments/nnunet_cluster.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | from pssh.clients import SSHClient 4 | from pssh.exceptions import Timeout 5 | from rich import print 6 | from rich.syntax import Syntax 7 | 8 | BASH_BSUB_COMMAND = r""" 9 | bsub -gpu num=1:j_exclusive=yes:gmem={gmem}\ 10 | -L /bin/bash \ 11 | -q gpu \ 12 | -u 'm.zenk@dkfz-heidelberg.de' \ 13 | -B -N \ 14 | -R "select[hname!='e230-dgx2-2']" \ 15 | -J "{name}" \ 16 | -o $HOME/job_outputs/%J.out \ 17 | bash -li -c 'set -o pipefail; echo $LSB_JOBID && source $HOME/.bashrc_nnunetV2 && {command}' 18 | """ 19 | # I use a .bashrc file here, which is easier on the cluster. dotenv won't override this 20 | 21 | 22 | def get_gmem(): 23 | return "10.7G" 24 | 25 | 26 | def submit(nnunet_cmd: str, dry_run: bool): 27 | client = SSHClient("odcf-worker01.inet.dkfz-heidelberg.de") 28 | # Compile overwrites. Precedence: cmdline > experiment > global 29 | 30 | print( 31 | Syntax( 32 | re.sub(r"([^,]) ", "\\1 \\\n\t", nnunet_cmd), 33 | "bash", 34 | word_wrap=True, 35 | background_color="default", 36 | ) 37 | ) 38 | 39 | cmd = BASH_BSUB_COMMAND.format( 40 | name=nnunet_cmd, 41 | command=nnunet_cmd, 42 | gmem=get_gmem(), 43 | ).strip() 44 | 45 | print( 46 | Syntax( 47 | cmd, 48 | "bash", 49 | word_wrap=True, 50 | background_color="default", 51 | ) 52 | ) 53 | 54 | if dry_run: 55 | return 56 | with client.open_shell(read_timeout=1) as shell: 57 | shell.run(cmd) 58 | 59 | try: 60 | for line in shell.stdout: 61 | print(line) 62 | except Timeout: 63 | pass 64 | 65 | try: 66 | for line in shell.stderr: 67 | print(line) 68 | except Timeout: 69 | pass 70 | 71 | 72 | if __name__ == "__main__": 73 | import argparse 74 | 75 | parser = argparse.ArgumentParser() 76 | parser.add_argument( 77 | "nnunet_cmd", help="The nnUNet command to run. You need to pass it in quotes." 78 | ) 79 | parser.add_argument("--dry-run", action="store_true") 80 | args = parser.parse_args() 81 | 82 | submit(args.nnunet_cmd, args.dry_run) 83 | -------------------------------------------------------------------------------- /src/segmentation_failures/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MIC-DKFZ/segmentation_failures_benchmark/a1af98be0f93c2bdc30ffe5bb6eda531c485d87d/src/segmentation_failures/models/__init__.py -------------------------------------------------------------------------------- /src/segmentation_failures/models/confidence_aggregation/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import ( 2 | BackgroundAggregator, 3 | BoundaryAggregator, 4 | EuclideanDistanceMapAggregator, 5 | ForegroundAggregator, 6 | ForegroundSizeAggregator, 7 | ForegroundSizeWeightedAggregator, 8 | SimpleAggregator, 9 | get_aggregator, 10 | ) 11 | from .heuristic import HeuristicAggregationModule 12 | from .radiomics import RadiomicsAggregationModule 13 | from .simple_agg import SimpleAggModule 14 | -------------------------------------------------------------------------------- /src/segmentation_failures/models/confidence_aggregation/simple_agg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from loguru import logger 3 | from pytorch_lightning import LightningModule 4 | 5 | from segmentation_failures.models.confidence_aggregation.base import ( 6 | AbstractAggregator, 7 | AbstractEnsembleAggregator, 8 | ) 9 | from segmentation_failures.utils.data import load_dataset_json 10 | from segmentation_failures.utils.label_handling import convert_to_onehot_batch 11 | 12 | 13 | class SimpleAggModule(LightningModule): 14 | def __init__( 15 | self, 16 | aggregation_methods: dict[str, AbstractAggregator], 17 | dataset_id: int, 18 | ) -> None: 19 | super().__init__() 20 | if not hasattr(aggregation_methods, "items"): 21 | raise ValueError("Aggregation methods must be a dict.") 22 | self.aggregators = { 23 | k: am for k, am in aggregation_methods.items() if isinstance(am, AbstractAggregator) 24 | } 25 | self.ensemble_aggregators = { 26 | k: am 27 | for k, am in aggregation_methods.items() 28 | if isinstance(am, AbstractEnsembleAggregator) 29 | } 30 | dataset_json = load_dataset_json(dataset_id) 31 | self.regions_or_labels = [v for _, v in dataset_json["labels"].items()] 32 | 33 | def test_step(self, batch, batch_idx): 34 | # assume batch keys ["confid", "pred", "pred_samples", "confid_names"] 35 | # predictions are label maps here, not logits! 36 | aggregated_confid = {} 37 | assert batch["pred"].shape[1] == 1, batch["pred"].shape 38 | prediction = batch["pred"].squeeze(1) # labels, shape BHW[D] 39 | prediction_distr = batch.get("pred_samples", None) # shape BKHW[D], K = samples 40 | for confid_idx, pxl_confid in enumerate(batch["confid_names"]): 41 | if isinstance(pxl_confid, (list, tuple)): 42 | # artefact of default collation 43 | assert len(pxl_confid) == 1 44 | pxl_confid = pxl_confid[0] 45 | # confid shape BCHW[D] 46 | confid_map = batch["confid"][:, confid_idx] 47 | # NOTE prediction should be identical in each iteration 48 | for agg_name, agg_fn in self.aggregators.items(): 49 | confid_name = f"{pxl_confid}_{agg_name}" 50 | aggregated_confid[confid_name] = agg_fn(prediction, confid_map).cpu() 51 | if len(self.ensemble_aggregators) > 0 and prediction_distr is None: 52 | logger.warning( 53 | f"The aggregations {list(self.ensemble_aggregators)} were configured, but no ensemble predictions are available." 54 | ) 55 | elif len(self.ensemble_aggregators) > 0 and prediction_distr.shape[1] > 1: 56 | # these aggregators need a one/multi-hot label map 57 | n = prediction_distr.shape[1] 58 | batch_size = prediction_distr.shape[0] 59 | # Note: shape KBCHW[D] 60 | onehot_lab = torch.zeros( 61 | (n, batch_size, len(self.regions_or_labels)) + prediction.shape[1:] 62 | ) 63 | for b in range(batch_size): 64 | onehot_lab[:, b] = convert_to_onehot_batch( 65 | prediction_distr[b].unsqueeze(1), self.regions_or_labels 66 | ) 67 | onehot_lab = onehot_lab.to(torch.bool) 68 | # can also be multihot 69 | for agg_name, agg_fn in self.ensemble_aggregators.items(): 70 | aggregated_confid[agg_name] = agg_fn(onehot_lab).cpu() 71 | out_dict = { 72 | "prediction": prediction, 73 | "confidence": aggregated_confid, 74 | "prediction_distr": prediction_distr, # this might cause memory problems 75 | } 76 | return out_dict 77 | 78 | def training_step(self, *args, **kwargs): 79 | raise NotImplementedError("This confidence score is not trained.") 80 | -------------------------------------------------------------------------------- /src/segmentation_failures/models/image_confidence/__init__.py: -------------------------------------------------------------------------------- 1 | from .mahalanobis import SingleGaussianOODDetector 2 | from .regression_network import QualityRegressionNet 3 | from .vae_estimator import SimpleVAEmodule 4 | -------------------------------------------------------------------------------- /src/segmentation_failures/models/pixel_confidence/__init__.py: -------------------------------------------------------------------------------- 1 | from .ensemble import DeepEnsembleMultiConfidenceSegmenter 2 | from .posthoc import PosthocMultiConfidenceSegmenter 3 | from .scores import get_pixel_csf 4 | -------------------------------------------------------------------------------- /src/segmentation_failures/models/pixel_confidence/ensemble.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import pytorch_lightning as pl 4 | import torch 5 | from loguru import logger 6 | 7 | from segmentation_failures.models.pixel_confidence.posthoc import ( 8 | compute_confidence_map, 9 | compute_mean_prediction, 10 | ) 11 | from segmentation_failures.models.pixel_confidence.scores import get_pixel_csf 12 | 13 | 14 | class DeepEnsembleMultiConfidenceSegmenter(pl.LightningModule): 15 | def __init__( 16 | self, 17 | segmentation_net: List[torch.nn.Module], 18 | csf_names: str | list[str], 19 | overlapping_classes: bool = False, 20 | everything_on_gpu=False, 21 | confidence_precision="32", 22 | num_models: int = 5, 23 | ) -> None: 24 | """NOTE: segmentation net is not a single network but a list! 25 | 26 | It's just named this way for compatibility with the testing pipeline conventions.""" 27 | super().__init__() 28 | if isinstance(csf_names, str): 29 | csf_names = [csf_names] 30 | self.csf_dict = {} 31 | for csf in csf_names: 32 | self.csf_dict[csf] = get_pixel_csf(csf) 33 | self.model_list = torch.nn.ModuleList(segmentation_net) 34 | if num_models <= 1: 35 | raise ValueError("Number of models in ensemble must be greater than 1") 36 | if len(segmentation_net) < num_models: 37 | raise ValueError("Number of models in ensemble is less than specified") 38 | if len(segmentation_net) > num_models: 39 | logger.info( 40 | f"Number of models in ensemble is greater than specified ({len(segmentation_net)}). " 41 | f"Using only first {num_models}." 42 | ) 43 | self.model_list = self.model_list[:num_models] 44 | self.overlapping_classes = overlapping_classes 45 | self.everything_on_gpu = everything_on_gpu 46 | if confidence_precision == "32": 47 | self.confid_dtype = torch.float32 48 | elif confidence_precision == "64": 49 | self.confid_dtype = torch.float64 50 | else: 51 | raise ValueError(f"Unknown precision {confidence_precision}") 52 | 53 | def forward(self, x: torch.Tensor, query_confids=None): 54 | if query_confids is None: 55 | query_confids = self.csf_dict.keys() 56 | if isinstance(query_confids, str): 57 | query_confids = [query_confids] 58 | logger.debug( 59 | f"Starting inference of segmentation model ({len(self.model_list)} in ensemble)" 60 | ) 61 | device = self.device if self.everything_on_gpu else torch.device("cpu") 62 | logits_distr = self.segmentation_inference(x).to(device) 63 | with torch.autocast(device_type=device.type, enabled=False): 64 | logits_distr = logits_distr.to(dtype=self.confid_dtype) 65 | # prediction shape KBCHW[D], K=#models in ensemble 66 | logits = compute_mean_prediction(logits_distr, self.overlapping_classes, mc_dim=0) 67 | for curr_name in query_confids: 68 | logger.debug(f"Computing confidence map {curr_name}") 69 | csf_fn = self.csf_dict[curr_name] 70 | confid = compute_confidence_map( 71 | logits_distr, csf_fn, self.overlapping_classes, mc_dim=0 72 | ) 73 | # generator is used for memory saving, but it probably breaks backpropagation (which I don't need) 74 | yield { 75 | "csf": curr_name, 76 | "logits": logits, 77 | "confid": confid, 78 | "logits_distr": logits_distr, 79 | } 80 | 81 | def segmentation_inference(self, x): 82 | # assume x shape BCHW[D] 83 | device = self.device if self.everything_on_gpu else "cpu" 84 | logits_distr = torch.zeros( 85 | ( 86 | len(self.model_list), 87 | x.shape[0], 88 | self.model_list[0].hparams.num_classes, 89 | *x.shape[2:], 90 | ), 91 | device=device, 92 | ) 93 | for i, model in enumerate(self.model_list): 94 | logits_distr[i] = model(x).to(device=device) 95 | return logits_distr 96 | 97 | def test_step(self, batch, batch_idx): 98 | logger.debug(f"Current case ID: {batch['keys']}") 99 | confid_dict = {} 100 | confid_map_generator = self(batch["data"]) 101 | for output in confid_map_generator: 102 | # prediction doesn't change between iterations 103 | prediction, confid = output["logits"], output["confid"] 104 | prediction_distr = output["logits_distr"] # this might cause memory problems 105 | confid_dict[output["csf"]] = confid 106 | return { 107 | "prediction": prediction, 108 | "confidence_pixel": confid_dict, 109 | "prediction_distr": prediction_distr, 110 | } 111 | 112 | def validation_step(self, batch, batch_idx): 113 | return self.test_step(batch, batch_idx) 114 | 115 | def training_step(self, *args, **kwargs): 116 | raise NotImplementedError("This module is not trained.") 117 | -------------------------------------------------------------------------------- /src/segmentation_failures/models/pixel_confidence/scores.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import torch 4 | 5 | _scores_factories = {} 6 | 7 | 8 | def register_pixel_csf(name): 9 | def decorator_(cls): 10 | _scores_factories[name] = cls 11 | return cls 12 | 13 | return decorator_ 14 | 15 | 16 | def get_pixel_csf(name, **kwargs): 17 | return _scores_factories[name](**kwargs) 18 | 19 | 20 | class PixelConfidenceScore(ABC): 21 | LABEL_DIM = 1 22 | 23 | @abstractmethod 24 | def __call__( 25 | self, prediction: torch.Tensor, mc_samples_dim: int | None = None 26 | ) -> torch.Tensor: 27 | """ 28 | Arguments: 29 | prediction: expected shape is [num_samples, num_classes, *spatial_dims] 30 | mc_samples_dim: Dimension where the MC samples are located. For example, based on the expected `prediction` shape, this 31 | results in a shape of [num_samples, num_mc_samples, num_classes, *spatial_dims] for `mc_samples_dim=1`. 32 | returns: 33 | confidences: shape [num_samples, *pixel_shape] 34 | """ 35 | 36 | 37 | @register_pixel_csf("maxsoftmax") 38 | class MaximumSoftmaxScore(PixelConfidenceScore): 39 | def __call__(self, softmax: torch.Tensor, mc_samples_dim: int | None = None) -> torch.Tensor: 40 | label_dim = self.LABEL_DIM 41 | if mc_samples_dim is not None: 42 | label_dim = self.LABEL_DIM + (mc_samples_dim <= self.LABEL_DIM) 43 | if mc_samples_dim is not None: 44 | # average MC-samples first 45 | mean_softmax = softmax.mean(dim=mc_samples_dim, keepdim=True) 46 | return torch.amax(mean_softmax, dim=label_dim).squeeze(dim=mc_samples_dim) 47 | return torch.amax(softmax, label_dim) 48 | 49 | 50 | @register_pixel_csf("predictive_entropy") 51 | class PredictiveEntropyScore(PixelConfidenceScore): 52 | # this computes a shifted and negative entropy to make it a confidence score 53 | def __call__(self, softmax: torch.Tensor, mc_samples_dim: int | None = None) -> torch.Tensor: 54 | label_dim = self.LABEL_DIM 55 | if mc_samples_dim is None: 56 | entropy = compute_entropy(softmax) 57 | else: 58 | label_dim = self.LABEL_DIM + (mc_samples_dim <= self.LABEL_DIM) 59 | entropy = predictive_entropy(softmax, dim=mc_samples_dim) 60 | num_classes = softmax.shape[label_dim] 61 | max_entropy = torch.log(num_classes * torch.ones_like(entropy)) 62 | return max_entropy - entropy # higher score -> higher confidence 63 | 64 | 65 | @register_pixel_csf("expected_entropy") 66 | class ExpectedEntropyScore(PixelConfidenceScore): 67 | # this computes a shifted and negative entropy to make it a confidence score 68 | def __call__(self, softmax: torch.Tensor, mc_samples_dim: int) -> torch.Tensor: 69 | # for mc_samples_dim == 0, expected prediction shape is [n_mc_samples, n_batch, n_classes, *spatial_dims] 70 | # returns confidence of shape [n_batch, *spatial_dims] 71 | label_dim = self.LABEL_DIM + (mc_samples_dim <= self.LABEL_DIM) 72 | if mc_samples_dim is None: 73 | raise ValueError("Expected entropy is only defined over multiple samples") 74 | entropy = expected_entropy(softmax, dim=mc_samples_dim) 75 | num_classes = softmax.shape[label_dim] 76 | max_entropy = torch.log(num_classes * torch.ones_like(entropy)) 77 | return max_entropy - entropy # higher score -> higher confidence 78 | 79 | 80 | @register_pixel_csf("mutual_information") 81 | class MutualInformationScore(PixelConfidenceScore): 82 | def __call__(self, softmax: torch.Tensor, mc_samples_dim: int) -> torch.Tensor: 83 | if mc_samples_dim is None: 84 | raise ValueError("Expected entropy is only defined over multiple samples") 85 | mutual_info = predictive_entropy(softmax, dim=mc_samples_dim) - expected_entropy( 86 | softmax, dim=mc_samples_dim 87 | ) 88 | return -mutual_info 89 | 90 | 91 | def predictive_entropy(prob_distr: torch.Tensor, dim: int) -> torch.Tensor: 92 | return compute_entropy(prob_distr.mean(dim=dim)) 93 | 94 | 95 | def expected_entropy(prob_distr: torch.Tensor, dim: int) -> torch.Tensor: 96 | return torch.concatenate( 97 | [compute_entropy(p).unsqueeze(dim) for p in list(torch.unbind(prob_distr, dim=dim))] 98 | ).mean(dim=dim) 99 | 100 | 101 | def compute_entropy(prob: torch.Tensor): 102 | # assumed shape: NxD -> entropy along axis 1 103 | logzero_fix = torch.zeros_like(prob) 104 | logzero_fix[prob == 0] = torch.finfo(prob.dtype).eps 105 | return torch.sum(prob * (-torch.log(prob + logzero_fix)), dim=1) 106 | -------------------------------------------------------------------------------- /src/segmentation_failures/models/segmentation/__init__.py: -------------------------------------------------------------------------------- 1 | from .dynunet_module import DynUnetModule 2 | from .monai_segmenter import UNet_segmenter 3 | -------------------------------------------------------------------------------- /src/segmentation_failures/networks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MIC-DKFZ/segmentation_failures_benchmark/a1af98be0f93c2bdc30ffe5bb6eda531c485d87d/src/segmentation_failures/networks/__init__.py -------------------------------------------------------------------------------- /src/segmentation_failures/networks/nnunet/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from nnunetv2.utilities.get_network_from_plans import get_network_from_plans 4 | from nnunetv2.utilities.plans_handling.plans_handler import ( 5 | ConfigurationManager, 6 | PlansManager, 7 | ) 8 | 9 | from segmentation_failures.utils.data import get_dataset_dir 10 | 11 | 12 | def build_network( 13 | dataset_id, 14 | plans_name, 15 | config_name, 16 | input_channels, 17 | num_outputs, 18 | allow_init=True, 19 | deep_supervision=False, 20 | ): 21 | preproc_data_base = get_dataset_dir(dataset_id, os.environ["nnUNet_preprocessed"]) 22 | plans_manager = PlansManager(preproc_data_base / f"{plans_name}.json") 23 | configuration_manager: ConfigurationManager = plans_manager.get_configuration(config_name) 24 | return get_network_from_plans( 25 | arch_class_name=configuration_manager.network_arch_class_name, 26 | arch_kwargs=configuration_manager.network_arch_init_kwargs, 27 | arch_kwargs_req_import=configuration_manager.network_arch_init_kwargs_req_import, 28 | input_channels=input_channels, 29 | output_channels=num_outputs, 30 | allow_init=allow_init, 31 | deep_supervision=deep_supervision, 32 | ) 33 | -------------------------------------------------------------------------------- /src/segmentation_failures/networks/vae/__init__.py: -------------------------------------------------------------------------------- 1 | from .vae import VAE, VAE3d 2 | -------------------------------------------------------------------------------- /src/segmentation_failures/networks/vae/vae.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is adapted from David Zimmerer's VAE implementation. Credits to him. 3 | """ 4 | 5 | import torch 6 | import torch.distributions as dist 7 | 8 | from segmentation_failures.networks.vae.encoder_decoder import ( 9 | Encoder, 10 | EncoderLiu, 11 | Generator, 12 | GeneratorLiu, 13 | ) 14 | 15 | 16 | class VAE(torch.nn.Module): 17 | def __init__( 18 | self, 19 | input_size, 20 | z_dim=512, 21 | h_sizes=(16, 32, 64, 256, 1024), 22 | kernel_size=3, 23 | to_1x1=True, 24 | conv_op=torch.nn.Conv2d, 25 | upsample_op=torch.nn.ConvTranspose2d, 26 | normalization_op=None, 27 | activation_op=torch.nn.LeakyReLU, 28 | conv_params=None, 29 | activation_params=None, 30 | block_op=None, 31 | block_params=None, 32 | output_channels=None, 33 | additional_input_slices=None, 34 | symmetric_decoder=True, 35 | liu_architecture=False, 36 | *args, 37 | **kwargs, 38 | ): 39 | 40 | super(VAE, self).__init__() 41 | 42 | input_size_enc = list(input_size) 43 | input_size_dec = list(input_size) 44 | if output_channels is not None: 45 | input_size_dec[0] = output_channels 46 | if additional_input_slices is not None: 47 | input_size_enc[0] += additional_input_slices * 2 48 | 49 | min_spatial_size = 4 50 | strides = [] 51 | curr_sizes = input_size[1:] 52 | for layer in h_sizes: 53 | curr_stride = [1 if 0.5 * sz < min_spatial_size else 2 for sz in curr_sizes] 54 | if any([sz % st != 0 for sz, st in zip(curr_sizes, curr_stride)]): 55 | raise ValueError( 56 | f"Spatial size {curr_sizes} for layer {layer} not divisible by stride {curr_stride}" 57 | ) 58 | curr_sizes = [sz // st for sz, st in zip(curr_sizes, curr_stride)] 59 | strides.append(curr_stride) 60 | strides = strides[::-1] # the strided convolutions should rather come at the end 61 | if liu_architecture: 62 | enc_cls = EncoderLiu 63 | dec_cls = GeneratorLiu 64 | else: 65 | enc_cls = Encoder 66 | dec_cls = Generator 67 | self.enc = enc_cls( 68 | image_size=input_size_enc, 69 | h_size=h_sizes, 70 | z_dim=z_dim * 2, 71 | kernel_size=kernel_size, 72 | strides=strides, 73 | normalization_op=normalization_op, 74 | to_1x1=to_1x1, 75 | conv_op=conv_op, 76 | conv_params=conv_params, 77 | activation_op=activation_op, 78 | activation_params=activation_params, 79 | block_op=block_op, 80 | block_params=block_params, 81 | ) 82 | self.dec = dec_cls( 83 | image_size=input_size_dec, 84 | h_size=h_sizes[::-1] if symmetric_decoder else h_sizes, 85 | # David uses an asymmetric decoder 86 | # the model has more parameters with symmetric decoder and also I observed checkerboard patterns 87 | z_dim=z_dim, 88 | kernel_size=kernel_size, 89 | strides=strides[::-1] if symmetric_decoder else strides, 90 | normalization_op=normalization_op, 91 | to_1x1=to_1x1, 92 | upsample_op=upsample_op, 93 | conv_params=conv_params, 94 | activation_op=activation_op, 95 | activation_params=activation_params, 96 | block_op=block_op, 97 | block_params=block_params, 98 | ) 99 | self.hidden_size = self.enc.output_size 100 | 101 | def forward(self, inpt, sample=True, ret_y=False, **kwargs): 102 | y1 = self.enc(inpt, **kwargs) 103 | 104 | mu, log_std = torch.chunk(y1, 2, dim=1) 105 | 106 | std = torch.exp(log_std) 107 | z_dist = dist.Normal(mu, std) 108 | if sample: 109 | z_sample = z_dist.rsample() 110 | else: 111 | z_sample = mu 112 | 113 | x_rec = self.dec(z_sample) 114 | 115 | if ret_y: 116 | return x_rec, y1 117 | else: 118 | return x_rec, z_dist 119 | 120 | def generate_samples(self, num_samples, device): 121 | latent_size = (int(0.5 * self.hidden_size[0]), *self.hidden_size[1:]) 122 | z = torch.randn(num_samples, *latent_size, device=device) 123 | return self.dec(z) 124 | 125 | def encode(self, inpt, **kwargs): 126 | enc = self.enc(inpt, **kwargs) 127 | mu, log_std = torch.chunk(enc, 2, dim=1) 128 | return mu, log_std 129 | 130 | def decode(self, inpt, **kwargs): 131 | x_rec = self.dec(inpt, **kwargs) 132 | return x_rec 133 | 134 | 135 | class VAE3d(VAE): 136 | def __init__( 137 | self, conv_op=torch.nn.Conv3d, upsample_op=torch.nn.ConvTranspose3d, *args, **kwargs 138 | ): 139 | super().__init__(conv_op=conv_op, upsample_op=upsample_op, *args, **kwargs) 140 | -------------------------------------------------------------------------------- /src/segmentation_failures/scripts/check_dataset_splits.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | from pathlib import Path 5 | 6 | import dotenv 7 | 8 | from segmentation_failures.utils.data import get_dataset_dir 9 | from segmentation_failures.utils.io import load_json 10 | 11 | # load environment variables from `.env` file if it exists 12 | dotenv.load_dotenv(Path(__file__).resolve().parents[1] / ".env", override=False, verbose=True) 13 | 14 | 15 | def check_trainval_split(dataset_name: str, ref_split_path: Path, dry_run=False): 16 | # check if the splits are the same 17 | split_path = Path(os.environ["nnUNet_preprocessed"]) / dataset_name / "splits_final.json" 18 | if dataset_name.startswith("Dataset500"): 19 | split_path = Path(os.environ["nnUNet_raw"]) / dataset_name / "splits_final.json" 20 | if not split_path.exists(): 21 | print("No split found. Copying reference split...") 22 | if not dry_run: 23 | shutil.copy(ref_split_path, split_path) 24 | return 25 | ref_split = load_json(split_path) 26 | curr_split = load_json(ref_split_path) 27 | # compare the two splits. Each should be a list of dictionaries with keys "train" and "val" 28 | for fold_idx in range(len(ref_split)): 29 | ref_fold = ref_split[fold_idx] 30 | curr_fold = curr_split[fold_idx] 31 | assert ref_fold.keys() == curr_fold.keys() 32 | for key in ref_fold: 33 | # should be lists 34 | assert isinstance(ref_fold[key], list) 35 | assert isinstance(curr_fold[key], list) 36 | if set(ref_fold[key]) != set(curr_fold[key]): 37 | raise ValueError(f"TRAIN/VAL: Splits are different for fold {fold_idx}.") 38 | print("TRAIN/VAL: Splits are OK.") 39 | 40 | 41 | def check_test_cases(dataset_name: str, ref_split_path: Path, dry_run=False): 42 | test_data_dir = Path(os.environ["TESTDATA_ROOT_DIR"]) / dataset_name 43 | # check if the domain mapping is the same 44 | domain_mapping_path = test_data_dir / "domain_mapping_00.json" 45 | if not domain_mapping_path.exists(): 46 | print("No domain mapping found. Copying reference domain mapping...") 47 | if not dry_run: 48 | shutil.copy(ref_split_path, domain_mapping_path) 49 | return 50 | ref_domain_mapping = load_json(ref_split_path) 51 | curr_domain_mapping = load_json(domain_mapping_path) 52 | if ref_domain_mapping.keys() != curr_domain_mapping.keys(): 53 | raise ValueError("TEST: Test cases are different.") 54 | else: 55 | print("TEST: Split is OK.") 56 | wrong_entries = [] 57 | for key in ref_domain_mapping: 58 | if ref_domain_mapping[key] != curr_domain_mapping[key]: 59 | wrong_entries.append(key) 60 | if len(wrong_entries) > 0: 61 | raise ValueError(f"TEST: Domain mapping is different for {wrong_entries}.") 62 | print("TEST: Domain mapping is OK.") 63 | 64 | 65 | def main(): 66 | parser = argparse.ArgumentParser() 67 | parser.add_argument("--dataset_ids", nargs="+", help="Dataset IDs to check.", default=None) 68 | parser.add_argument( 69 | "--dry_run", action="store_true", help="Do not copy train/val split files." 70 | ) 71 | args = parser.parse_args() 72 | split_dir = Path(__file__).resolve().parents[3] / "dataset_splits" 73 | if args.dataset_ids is not None: 74 | datasets_ids = args.dataset_ids 75 | else: 76 | datasets_ids = [x.name for x in split_dir.iterdir() if x.is_dir()] 77 | for ds in datasets_ids: 78 | dataset_dir = get_dataset_dir(ds, split_dir) 79 | print("=" * 10) 80 | print(dataset_dir.name) 81 | dataset_name = dataset_dir.name 82 | check_trainval_split(dataset_name, dataset_dir / "splits_final.json") 83 | check_test_cases(dataset_name, dataset_dir / "domain_mapping_00.json") 84 | 85 | 86 | if __name__ == "__main__": 87 | main() 88 | -------------------------------------------------------------------------------- /src/segmentation_failures/scripts/test_pixel_csf.py: -------------------------------------------------------------------------------- 1 | """ 2 | Similar to validate_pixel_csf.py, but execute the testing loop. 3 | """ 4 | 5 | import sys 6 | from pathlib import Path 7 | 8 | import dotenv 9 | import hydra 10 | import pytorch_lightning as pl 11 | import torch 12 | import yaml 13 | from loguru import logger 14 | from omegaconf import DictConfig, OmegaConf 15 | 16 | from segmentation_failures.scripts.train_image_csf import setup_segmentation_model 17 | from segmentation_failures.scripts.validate_pixel_csf import setup_model 18 | 19 | # load environment variables from `.env` file if it exists 20 | dotenv.load_dotenv(Path(__file__).absolute().parents[1] / ".env", override=False, verbose=True) 21 | 22 | 23 | @hydra.main(config_path="../conf", config_name="config", version_base="1.2") 24 | def main(config: DictConfig): 25 | torch.multiprocessing.set_sharing_strategy(config.mp_sharing_strategy) 26 | logger.remove() # Remove default 'stderr' handler 27 | logger.add(sys.stderr, level=config.loguru.level) 28 | logger.add(Path(config.paths.output_dir) / config.loguru.file, level=config.loguru.level) 29 | 30 | if config.get("seed"): 31 | pl.seed_everything(config.seed, workers=True) 32 | 33 | if config.get("image_csf") is not None or config.get("csf_aggregation") is not None: 34 | raise ValueError("This script is only for methods without image-csf or csf-aggregation") 35 | 36 | logger.info(f"Experiment directory: {config.paths.output_dir}") 37 | # ------------ 38 | # data 39 | # ------------ 40 | logger.info(f"Instantiating datamodule <{config.datamodule['_target_']}>") 41 | data_module: pl.LightningDataModule = hydra.utils.instantiate(config.datamodule) 42 | data_module.prepare_data() 43 | if hasattr(data_module, "preprocess_info"): 44 | # workaround. I dislike this solution 45 | config.datamodule.spacing = data_module.preprocess_info["spacing"] 46 | # ------------ 47 | # model 48 | # ------------ 49 | logger.info("Instantiating model") 50 | seg_model = setup_segmentation_model(config, load_best_ckpt=False) 51 | model = setup_model(config, seg_model) 52 | 53 | # ------------ 54 | # testing 55 | # ------------ 56 | # Init callbacks 57 | callbacks = [] 58 | if "callbacks" in config: 59 | for _, cb_conf in config.callbacks.test.items(): 60 | if "_target_" in cb_conf: 61 | logger.info(f"Instantiating callback <{cb_conf['_target_']}>") 62 | callbacks.append(hydra.utils.instantiate(cb_conf)) 63 | 64 | # Init lightning loggers 65 | expt_logger = [] 66 | if "logger" in config: 67 | for _, lg_conf in config.logger.items(): 68 | if "_target_" in lg_conf: 69 | logger.info(f"Instantiating logger <{lg_conf['_target_']}>") 70 | expt_logger.append(hydra.utils.instantiate(lg_conf)) 71 | 72 | logger.info(f"Instantiating trainer <{config.trainer['_target_']}>") 73 | trainer: pl.Trainer = hydra.utils.instantiate( 74 | config.trainer, 75 | _convert_="partial", 76 | callbacks=callbacks, 77 | logger=expt_logger, 78 | ) 79 | 80 | logger.info("Starting validation...") 81 | trainer.test(model, datamodule=data_module) 82 | 83 | # Save configuration diff at the end to capture any runtime changes 84 | final_config_yaml = yaml.dump(OmegaConf.to_container(config), sort_keys=False) 85 | hydra_config_path = Path(config.paths.output_dir) / ".hydra/config.yaml" 86 | hydra_config_path.rename(hydra_config_path.parent / "initial_config.yaml") 87 | with open(hydra_config_path, "w") as file: 88 | file.write(final_config_yaml) 89 | with open(Path(config.paths.output_dir) / "COMPLETED", "w") as file: 90 | file.write("") 91 | logger.info("Finished successfully.") 92 | 93 | 94 | if __name__ == "__main__": 95 | main() 96 | -------------------------------------------------------------------------------- /src/segmentation_failures/scripts/validate_pixel_csf.py: -------------------------------------------------------------------------------- 1 | """ 2 | Similar to evaluate.py, but I don't use pytorch-lightning here. This script is the one I currently use most. 3 | 4 | Inputs: test data, model checkpoint, confidence scoring function, flag which of prediction/evaluation to do 5 | Outputs: Saves segmentation mask + per-sample confidences 6 | 7 | - Loop over test set and predict each sample 8 | - For methods that compute confidences from softmax output: just implement it as a simple function/class? 9 | - For methods that compute confidences based on intermediate network activations: 10 | not sure what's the best way to extract feature maps 11 | - For methods that include confidence scoring components in training: need to modify lightning module. 12 | Maybe the predict_step method can be used to separate the inference behavior? 13 | 14 | """ 15 | 16 | import sys 17 | from pathlib import Path 18 | 19 | import dotenv 20 | import hydra 21 | import pytorch_lightning as pl 22 | import torch 23 | import yaml 24 | from loguru import logger 25 | from omegaconf import DictConfig, OmegaConf 26 | 27 | from segmentation_failures.scripts.train_image_csf import setup_segmentation_model 28 | 29 | # load environment variables from `.env` file if it exists 30 | dotenv.load_dotenv(Path(__file__).absolute().parents[1] / ".env", override=False, verbose=True) 31 | 32 | 33 | def setup_model(cfg: DictConfig, seg_model) -> pl.LightningModule: 34 | if cfg.get("csf_pixel") is None: 35 | logger.info("No pixel confidence configured. Continuing with segmentation model") 36 | return seg_model 37 | # initialize pixel csf using the segmentation network 38 | pixel_csf = hydra.utils.instantiate(cfg.csf_pixel.hparams, segmentation_net=seg_model) 39 | if cfg.csf_pixel.checkpoint is not None: 40 | # here I need to extract the network from the lightning checkpoint. 41 | raise NotImplementedError("So far I don't have any methods with trained pixel csf.") 42 | return pixel_csf 43 | 44 | 45 | @hydra.main(config_path="../conf", config_name="config", version_base="1.2") 46 | def main(config: DictConfig): 47 | torch.multiprocessing.set_sharing_strategy(config.mp_sharing_strategy) 48 | logger.remove() # Remove default 'stderr' handler 49 | logger.add(sys.stderr, level=config.loguru.level) 50 | logger.add(Path(config.paths.output_dir) / config.loguru.file, level=config.loguru.level) 51 | 52 | if config.get("seed"): 53 | pl.seed_everything(config.seed, workers=True) 54 | 55 | if config.get("image_csf") is not None or config.get("csf_aggregation") is not None: 56 | raise ValueError("This script is only for methods without image-csf or csf-aggregation") 57 | 58 | logger.info(f"Experiment directory: {config.paths.output_dir}") 59 | # ------------ 60 | # data 61 | # ------------ 62 | logger.info(f"Instantiating datamodule <{config.datamodule['_target_']}>") 63 | if ( 64 | config.datamodule["_target_"] 65 | == "segmentation_failures.data.datamodules.nnunet_module.NNunetDataModule" 66 | ): 67 | # inference-style validation 68 | config.datamodule.preproc_only = True 69 | config.datamodule.batch_size = 1 70 | data_module: pl.LightningDataModule = hydra.utils.instantiate(config.datamodule) 71 | data_module.prepare_data() 72 | if hasattr(data_module, "preprocess_info"): 73 | # workaround. I dislike this solution 74 | config.datamodule.spacing = data_module.preprocess_info["spacing"] 75 | # ------------ 76 | # model 77 | # ------------ 78 | logger.info("Instantiating model") 79 | seg_model = setup_segmentation_model(config, load_best_ckpt=False) 80 | model = setup_model(config, seg_model) 81 | 82 | # ------------ 83 | # validation 84 | # ------------ 85 | # Init callbacks 86 | callbacks = [] 87 | if "callbacks" in config: 88 | for _, cb_conf in config.callbacks.validate.items(): 89 | if "_target_" in cb_conf: 90 | logger.info(f"Instantiating callback <{cb_conf['_target_']}>") 91 | callbacks.append(hydra.utils.instantiate(cb_conf)) 92 | 93 | # Init lightning loggers 94 | expt_logger = [] 95 | if "logger" in config: 96 | for _, lg_conf in config.logger.items(): 97 | if "_target_" in lg_conf: 98 | logger.info(f"Instantiating logger <{lg_conf['_target_']}>") 99 | expt_logger.append(hydra.utils.instantiate(lg_conf)) 100 | 101 | logger.info(f"Instantiating trainer <{config.trainer['_target_']}>") 102 | trainer: pl.Trainer = hydra.utils.instantiate( 103 | config.trainer, 104 | _convert_="partial", 105 | callbacks=callbacks, 106 | logger=expt_logger, 107 | ) 108 | 109 | logger.info("Starting validation...") 110 | trainer.validate(model, datamodule=data_module) 111 | 112 | # Save configuration diff at the end to capture any runtime changes 113 | final_config_yaml = yaml.dump(OmegaConf.to_container(config), sort_keys=False) 114 | hydra_config_path = Path(config.paths.output_dir) / ".hydra/config.yaml" 115 | hydra_config_path.rename(hydra_config_path.parent / "initial_config.yaml") 116 | with open(hydra_config_path, "w") as file: 117 | file.write(final_config_yaml) 118 | with open(Path(config.paths.output_dir) / "COMPLETED", "w") as file: 119 | file.write("") 120 | logger.info("Finished successfully.") 121 | 122 | 123 | if __name__ == "__main__": 124 | main() 125 | -------------------------------------------------------------------------------- /src/segmentation_failures/utils/__init__.py: -------------------------------------------------------------------------------- 1 | GLOBAL_SEEDS = { 2 | 0: 988515, 3 | 1: 719121, 4 | 2: 278828, 5 | 3: 860931, 6 | 4: 173255, 7 | # 5: 655061, 8 | # 6: 700111, 9 | # 7: 754949, 10 | # 8: 166207, 11 | # 9: 308839, 12 | # 10: 420430, 13 | # 11: 659182, 14 | # 12: 126543, 15 | # 13: 713150, 16 | # 14: 672869, 17 | } 18 | -------------------------------------------------------------------------------- /src/segmentation_failures/utils/checkpointing.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import List 3 | 4 | from segmentation_failures.utils.io import load_expt_config 5 | 6 | 7 | def get_experiments_for_seed_fold( 8 | search_dir: str | Path, seed: int | list[int], fold: int | list[int] 9 | ) -> List[Path]: 10 | if isinstance(seed, int): 11 | seed = [seed] 12 | if isinstance(fold, int): 13 | fold = [fold] 14 | # Search in search_dir for the experiment with the given seed and fold. 15 | search_dir = Path(search_dir) 16 | if not search_dir.exists(): 17 | raise FileNotFoundError( 18 | f"Could not find directory {search_dir} for automatic segmentation checkpoint selection." 19 | ) 20 | matches = [] 21 | for rundir in search_dir.iterdir(): 22 | if not rundir.is_dir(): 23 | continue 24 | # need to load the config to get the seed and fold 25 | seg_config = load_expt_config(rundir) 26 | if "fold" in seg_config.datamodule: 27 | curr_fold = seg_config.datamodule.fold 28 | else: 29 | # fold location in config changed at some point; 30 | # this is for compatibility with older experiments 31 | curr_fold = seg_config.datamodule.hparams.fold 32 | if seg_config.seed in seed and curr_fold in fold: 33 | matches.append(rundir) 34 | if len(matches) == 0: 35 | raise FileNotFoundError( 36 | f"Could not find any experiment with seed {seed} and fold {fold} in {search_dir}." 37 | ) 38 | return matches 39 | 40 | 41 | def get_checkpoint_from_experiment(expt_dir: str, last_ckpt: bool) -> Path: 42 | checkpoint_dir = expt_dir / "checkpoints" 43 | if not checkpoint_dir.exists(): 44 | raise FileNotFoundError(f"Could not find directory {expt_dir}/checkpoints.") 45 | checkpoints_files = [x for x in Path(checkpoint_dir).iterdir() if x.suffix == ".ckpt"] 46 | not_last_ckpts = [] 47 | last_ckpt_file = None 48 | for ckpt in checkpoints_files: 49 | if ckpt.stem == "last": 50 | last_ckpt_file = ckpt 51 | if last_ckpt: 52 | return ckpt 53 | if ckpt.stem != "last": 54 | if not ckpt.name.startswith("epoch"): 55 | # Just a check 56 | raise ValueError("Expected checkpoints to start with 'epoch'") 57 | not_last_ckpts.append(ckpt) 58 | if len(not_last_ckpts) == 0: 59 | return last_ckpt_file # better than nothing 60 | return sorted(not_last_ckpts)[-1] 61 | -------------------------------------------------------------------------------- /src/segmentation_failures/utils/config_handling.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from omegaconf import DictConfig, ListConfig 4 | 5 | 6 | def flatten(cfg: Any, resolve: bool = False) -> list[tuple[str, Any]]: 7 | ret = [] 8 | 9 | def handle_dict(key: Any, value: Any, resolve: bool) -> list[tuple[str, Any]]: 10 | return [(f"{key}.{k1}", v1) for k1, v1 in flatten(value, resolve=resolve)] 11 | 12 | def handle_list(key: Any, value: Any, resolve: bool) -> list[tuple[str, Any]]: 13 | return [(f"{key}.{idx}", v1) for idx, v1 in flatten(value, resolve=resolve)] 14 | 15 | if isinstance(cfg, DictConfig): 16 | for k, v in cfg.items_ex(resolve=resolve): 17 | if isinstance(v, DictConfig): 18 | ret.extend(handle_dict(k, v, resolve=resolve)) 19 | elif isinstance(v, ListConfig): 20 | ret.extend(handle_list(k, v, resolve=resolve)) 21 | else: 22 | ret.append((str(k), v)) 23 | elif isinstance(cfg, ListConfig): 24 | for idx, v in enumerate(cfg._iter_ex(resolve=resolve)): 25 | if isinstance(v, DictConfig): 26 | ret.extend(handle_dict(idx, v, resolve=resolve)) 27 | elif isinstance(v, ListConfig): 28 | ret.extend(handle_list(idx, v, resolve=resolve)) 29 | else: 30 | ret.append((str(idx), v)) 31 | else: 32 | assert False 33 | 34 | return ret 35 | 36 | 37 | def compare_configs(cfg1, cfg2, resolve=False): 38 | flat1 = dict(flatten(cfg1, resolve=resolve)) 39 | flat2 = dict(flatten(cfg2, resolve=resolve)) 40 | keys_only_in1 = set(flat1.keys()).difference(set(flat2.keys())) 41 | keys_only_in2 = set(flat2.keys()).difference(set(flat1.keys())) 42 | if len(keys_only_in1) > 0: 43 | print(f"Keys only in 1: {keys_only_in1}") 44 | if len(keys_only_in2) > 0: 45 | print(f"Keys only in 2: {keys_only_in2}") 46 | for k in flat1.keys(): 47 | if flat1[k] != flat2[k]: 48 | print(f"DIFF {k}: {flat1[k]} != {flat2[k]}") 49 | -------------------------------------------------------------------------------- /src/segmentation_failures/utils/feature_extraction.py: -------------------------------------------------------------------------------- 1 | # credit: https://github.com/MECLabTUDA/Lifelong-nnUNet/blob/dev-ood_detection/nnunet_ext/calibration/mahalanobis/ActivationSeeker.py 2 | 3 | import torch.nn as nn 4 | 5 | 6 | class ActivationSeeker: 7 | def __init__(self, to_cpu=False): 8 | self.activation = {} 9 | self.handles = [] 10 | self.to_cpu = to_cpu 11 | 12 | def get_activation(self, name): 13 | def hook(model, input, output): 14 | activation = output.detach() 15 | if self.to_cpu: 16 | activation = activation.cpu() 17 | self.activation[name] = activation 18 | 19 | return hook 20 | 21 | def attach_hooks(self, model, hook_name_paths_dict): 22 | for param_name, param_path in hook_name_paths_dict.items(): 23 | child, child_path = get_module_recursive(model, param_path) 24 | handle = child.register_forward_hook(self.get_activation(param_name)) 25 | self.handles.append(handle) 26 | 27 | def get_data_activations(self, model=None, inputs=None): 28 | if model is not None and inputs is not None: 29 | model(inputs) 30 | activation_dict = dict(self.activation) 31 | return activation_dict 32 | 33 | def get_dl_activations(self, agent, dl): 34 | ix = 0 35 | dl_activations = [] 36 | for data in dl: 37 | dl_activations.append(self.get_data_activations(agent, data)) 38 | ix += 1 39 | if ix == 2: 40 | break 41 | return dl_activations 42 | 43 | def remove_handles(self): 44 | while len(self.handles) > 0: 45 | handle = self.handles.pop() 46 | handle.remove() 47 | 48 | 49 | def get_module_recursive(parent_module, child_path): 50 | r"""Extracts a specific module from a model and the module's name. Also 51 | returns the path that is a module (as lower-level paths may be passed) 52 | Args: 53 | parent_module (torch.nn.Module): a PyTorch model or parent module 54 | child_path (str): the name of a module as extracted from the parent 55 | when named_parameters() is called recursively, i.e. the module is not a 56 | direct child, and therefore not an attribute, of the parent module 57 | """ 58 | module = parent_module 59 | new_child_path = [] 60 | for module_name in child_path.split("."): 61 | child_module = getattr(module, module_name) 62 | if isinstance(child_module, nn.Module): 63 | module = child_module 64 | new_child_path.append(module_name) 65 | else: 66 | break 67 | return module, ".".join(new_child_path) 68 | -------------------------------------------------------------------------------- /src/segmentation_failures/utils/io.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from pathlib import Path 4 | from typing import Union 5 | 6 | from omegaconf import OmegaConf 7 | 8 | 9 | def save_json(content, target_name): 10 | target_name = str(target_name) 11 | if not str(target_name).endswith(".json"): 12 | target_name += ".json" 13 | with open(target_name, "w", encoding="utf-8") as f: 14 | json.dump(content, f, indent=2) 15 | 16 | 17 | def load_json(target_name): 18 | target_name = str(target_name) 19 | if not str(target_name).endswith(".json"): 20 | target_name += ".json" 21 | with open(target_name, encoding="utf-8") as f: 22 | content = json.load(f) 23 | return content 24 | 25 | 26 | def load_expt_config(expt_dir: Union[str, Path], resolve=False): 27 | if isinstance(expt_dir, str): 28 | expt_dir = Path(expt_dir) 29 | if os.environ.get("EXPERIMENT_ROOT_DIR", None) is None: 30 | os.environ["EXPERIMENT_ROOT_DIR"] = "dummy_expt_root_dir" 31 | if os.environ.get("TESTDATA_ROOT_DIR", None) is None: 32 | os.environ["TESTDATA_ROOT_DIR"] = "dummy_ds_root_dir" 33 | 34 | if resolve: 35 | hydra_config = OmegaConf.load(expt_dir / ".hydra" / "hydra.yaml").hydra 36 | # We need the hydra resolver because the run directory is interpolated using hydra... Not sure if there's a better way 37 | OmegaConf.register_new_resolver( 38 | "hydra", 39 | lambda path: OmegaConf.select(hydra_config, path), 40 | replace=False, # for safety 41 | ) 42 | config = OmegaConf.load(expt_dir / ".hydra" / "config.yaml") 43 | # I had problems with resolution later, so I just do it here. 44 | # Not sure how it works under the hood. 45 | if resolve: 46 | OmegaConf.resolve(config) 47 | OmegaConf.clear_resolver("hydra") 48 | return config 49 | -------------------------------------------------------------------------------- /src/segmentation_failures/utils/label_handling.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import monai.transforms as trf 4 | import numpy as np 5 | import torch 6 | 7 | 8 | def convert_to_onehot( 9 | segmentation: np.ndarray | torch.Tensor, all_labels_or_regions: Union[int, tuple[int, ...]] 10 | ) -> np.ndarray | torch.Tensor: 11 | if isinstance(segmentation, np.ndarray): 12 | return convert_to_onehot_np(segmentation, all_labels_or_regions) 13 | else: 14 | return convert_to_onehot_torch(segmentation, all_labels_or_regions) 15 | 16 | 17 | def convert_to_onehot_np( 18 | segmentation: np.ndarray, all_labels_or_regions: Union[int, tuple[int, ...]] 19 | ) -> np.ndarray: 20 | # assume shape HW[D] 21 | if isinstance(all_labels_or_regions, int): 22 | all_labels_or_regions = list(range(all_labels_or_regions)) 23 | result = np.zeros((len(all_labels_or_regions), *segmentation.shape), dtype=np.uint8) 24 | for i, l in enumerate(all_labels_or_regions): 25 | if np.isscalar(l): 26 | result[i] = segmentation == l 27 | else: 28 | result[i] = np.isin(segmentation, l) 29 | return result 30 | 31 | 32 | def convert_to_onehot_torch( 33 | segmentation: torch.Tensor, all_labels_or_regions: Union[int, tuple[int, ...]] 34 | ) -> torch.Tensor: 35 | return convert_to_onehot_batch(segmentation.unsqueeze(0), all_labels_or_regions).squeeze(0) 36 | 37 | 38 | def convert_to_onehot_batch( 39 | segmentation: torch.Tensor, all_labels_or_regions: Union[int, tuple[int, ...]] 40 | ) -> torch.Tensor: 41 | # assume shape B1HW[D] 42 | batch_size = segmentation.shape[0] 43 | assert segmentation.shape[1] == 1 44 | if isinstance(all_labels_or_regions, int): 45 | all_labels_or_regions = list(range(all_labels_or_regions)) 46 | result = torch.zeros( 47 | (batch_size, len(all_labels_or_regions), *segmentation.shape[2:]), 48 | dtype=segmentation.dtype, 49 | device=segmentation.device, 50 | ) 51 | for i, l in enumerate(all_labels_or_regions): 52 | if np.isscalar(l): 53 | result[:, i] = segmentation[:, 0] == l 54 | else: 55 | result[:, i] = torch.isin( 56 | segmentation[:, 0], torch.tensor(l, device=segmentation.device) 57 | ) 58 | return result 59 | 60 | 61 | def convert_nnunet_regions_to_labels(region_map, region_class_order: list): 62 | # assume shape B,C,*spatial 63 | if isinstance(region_map, np.ndarray): 64 | assert region_map.dtype == bool 65 | label_map = np.zeros((region_map.shape[0], *region_map.shape[2:]), dtype=np.uint16) 66 | else: 67 | assert region_map.dtype == torch.bool 68 | # no uint16 in torch 69 | label_map = torch.zeros( 70 | (region_map.shape[0], *region_map.shape[2:]), 71 | dtype=torch.int16, 72 | device=region_map.device, 73 | ) 74 | for i, c in enumerate(region_class_order): 75 | label_map[region_map[:, i]] = c 76 | return label_map 77 | 78 | 79 | def discretize_softmax(probs: np.ndarray, overlapping_classes: bool) -> np.ndarray: 80 | """ 81 | This is a helper function to convert logits to a discrete segmentation. 82 | 83 | Args: 84 | logits: (C, H, W[, D]) array of logits 85 | overlapping_classes: whether the logits are for overlapping classes or not 86 | """ 87 | if overlapping_classes: 88 | return (probs > 0.5).astype(np.uint8) 89 | else: 90 | return np.argmax(probs, axis=0).astype(np.uint8) 91 | 92 | 93 | class ConvertSegToRegions(trf.MapTransform): 94 | def __init__(self, keys, class_or_regions_defs, include_background=True): 95 | super().__init__(keys) 96 | self.class_or_regions_defs = class_or_regions_defs 97 | if not include_background: 98 | self.class_or_regions_defs = { 99 | k: v 100 | for k, v in class_or_regions_defs.items() 101 | if k.lower() not in ["background", "bg"] 102 | } 103 | 104 | def __call__(self, data): 105 | for key in self.keys: 106 | # assume shape 1HW[D] 107 | assert data[key].shape[0] == 1 108 | data[key] = convert_to_onehot(data[key], list(self.class_or_regions_defs.values())) 109 | return data 110 | -------------------------------------------------------------------------------- /src/segmentation_failures/utils/network.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from loguru import logger 3 | 4 | 5 | def disable_dropout(nn_module): 6 | found_dropout = False 7 | for layer in nn_module.named_modules(): 8 | if isinstance(layer[1], nn.modules.dropout.Dropout): 9 | layer[1].eval() 10 | found_dropout = True 11 | if not found_dropout: 12 | logger.warning("No dropout layers found in model. Cannot disable dropout.") 13 | 14 | 15 | def enable_dropout(nn_module): 16 | found_dropout = False 17 | for layer in nn_module.named_modules(): 18 | if isinstance(layer[1], nn.modules.dropout.Dropout): 19 | layer[1].train() 20 | found_dropout = True 21 | if not found_dropout: 22 | logger.warning("No dropout layers found in model. Cannot enable dropout.") 23 | -------------------------------------------------------------------------------- /src/segmentation_failures/utils/view_images_napari.py: -------------------------------------------------------------------------------- 1 | """ 2 | Trying out napari with the FeTS data. I like it :) 3 | """ 4 | 5 | # from argparse import ArgumentParser 6 | from argparse import ArgumentParser 7 | from pathlib import Path 8 | 9 | import napari 10 | import numpy as np 11 | import SimpleITK as sitk 12 | from loguru import logger 13 | 14 | from segmentation_failures.utils.io import load_json 15 | 16 | 17 | def main(data_root: Path, train=True, shift=None): 18 | # read info file 19 | dataset_dict = load_json(data_root / "dataset.json") 20 | if train: 21 | image_path = data_root / "imagesTr" 22 | label_path = data_root / "labelsTr" 23 | else: 24 | if shift: 25 | shift = "_" + shift.lower() 26 | else: 27 | shift = "" 28 | image_path = data_root / f"imagesTs{shift}" 29 | label_path = data_root / f"labelsTs{shift}" 30 | 31 | color_list = ["blue", "green", "red", "cyan", "yellow"] 32 | label_colors = {int(k): color_list[i] for i, k in enumerate(dataset_dict["labels"])} 33 | label_colors.update({0: None}) 34 | # 0: None, 35 | # 1: "royalblue", 36 | # 2: "orange", 37 | # } 38 | for case in label_path.iterdir(): 39 | case_name = case.name.split(".")[0] 40 | logger.info(case_name) 41 | # load data (flair) 42 | img_npy = [] 43 | for mod in dataset_dict["modalities"]: 44 | file_path = str(image_path / f"{case_name}_{int(mod):04d}.nii.gz") 45 | logger.debug(f"Loading image {file_path}") 46 | img = sitk.ReadImage(file_path) 47 | img_npy.append(sitk.GetArrayFromImage(img)) 48 | img_npy = np.stack(img_npy) 49 | 50 | img_viewer = napari.view_image( 51 | img_npy, 52 | rgb=False, 53 | channel_axis=0, 54 | name=list(dataset_dict["modalities"].values()), 55 | visible=[i == 0 for i in range(len(img_npy))], 56 | colormap="gray", 57 | ) 58 | 59 | seg = sitk.ReadImage(str(case)) 60 | seg_npy = sitk.GetArrayFromImage(seg).astype(int) 61 | img_viewer.add_labels(seg_npy, name="segmentation", opacity=0.5, color=label_colors) 62 | try: 63 | napari.run() 64 | except KeyboardInterrupt: 65 | break 66 | 67 | 68 | if __name__ == "__main__": 69 | parser = ArgumentParser() 70 | parser.add_argument("data_root", type=str) 71 | parser.add_argument("--testset", action="store_true") 72 | parser.add_argument("--testshift", type=str, default=None, required=False) 73 | args = parser.parse_args() 74 | data_root = Path( 75 | # "/Users/e290-mb003-wl/tmp_sshfs/Datasets/segmentation_failures/Task000D2_Example" 76 | "/Users/e290-mb003-wl/datasets/Task001_simple_brats" 77 | ) 78 | # shift = "biasfield" 79 | shift = None 80 | main(data_root, train=not args.testset, shift=args.testshift) 81 | -------------------------------------------------------------------------------- /src/segmentation_failures/utils/visualization.py: -------------------------------------------------------------------------------- 1 | import seaborn as sns 2 | import torch 3 | import torchvision 4 | 5 | 6 | def make_image_mask_grid( 7 | image_batch: torch.Tensor, 8 | mask_list: torch.Tensor | list[torch.Tensor], 9 | max_images=-1, 10 | alpha=0.5, 11 | slice_idx: list[int] | None = None, 12 | slice_dim: int = 2, 13 | ): 14 | """Produce image grid from images and predictions. 15 | 16 | Args: 17 | image_batch (torch.Tensor): shape [batch, modality, *spatial_dims] 18 | pred_batch_list (torch.Tensor): list of tensors with shape [batch, class, *spatial_dims], one hot encoded! 19 | max_images (int, optional): limit number of images in output. Defaults to 5. 20 | alpha (float, optional): alpha value for mask overlay. Defaults to 0.5. 21 | slice_idx (list[int], optional): slice index for 3D images. Defaults to None. 22 | slice_dim (int, optional): dimension to slice for 3D images (only counting spatial dims). Defaults to 2. 23 | 24 | Returns: 25 | torch.Tensor: RGB image grid 26 | """ 27 | if max_images == -1: 28 | max_images = image_batch.shape[0] 29 | if not isinstance(mask_list, (list, tuple)): 30 | mask_list = [mask_list] 31 | all_data = [image_batch] + mask_list 32 | for i, batch_data in enumerate(all_data): 33 | batch_data = batch_data.detach().cpu() 34 | batch_data = batch_data[:max_images] 35 | if batch_data.ndim == 5: 36 | # for 3D just take one slice 37 | if slice_idx is None: 38 | slice_idx = [batch_data.shape[-1] // 2] * batch_data.shape[0] 39 | # if the slice_idx is a list, select slice slice_idx[i] for batch i 40 | assert isinstance(slice_idx, (list, tuple)) 41 | batch_data = torch.stack( 42 | [ 43 | batch.select(dim=1 + slice_dim, index=slice_idx[i]) 44 | for i, batch in enumerate(batch_data) 45 | ] 46 | ) 47 | if i > 0: 48 | batch_data = batch_data.to(dtype=torch.bool) 49 | all_data[i] = batch_data 50 | # NOTE I could remove the background class by doing pred_batch = pred_batch[1:] 51 | # but then I would need to define colors (the torchvision function color palette starts with black) 52 | image_batch, mask_list = all_data[0], all_data[1:] 53 | # pick only first modality in image# roi is binary mask 54 | image_batch = image_batch[:, 0] 55 | # normalize image and convert to RGB 56 | tmp_min = image_batch.flatten(start_dim=1).min() 57 | tmp_max = image_batch.flatten(start_dim=1).max() 58 | image_batch = (image_batch - tmp_min) / (tmp_max - tmp_min) 59 | image_batch = torch.stack([image_batch, image_batch, image_batch], dim=1) 60 | if image_batch.is_floating_point(): 61 | image_batch = (image_batch * 255).to(dtype=torch.uint8) 62 | grid_list = [] 63 | 64 | for idx, img in enumerate(image_batch): 65 | grid_list.append(img) 66 | for mask in mask_list: 67 | colors = sns.color_palette(n_colors=mask.shape[1]).as_hex() 68 | colors.insert(0, "#000000") 69 | grid_list.append( 70 | # torchvision.utils.draw_segmentation_masks(img, mask[idx], alpha=alpha) 71 | torchvision.utils.draw_segmentation_masks( 72 | img, mask[idx], alpha=alpha, colors=colors 73 | ) 74 | ) 75 | return torchvision.utils.make_grid(grid_list, nrow=len(mask_list) + 1, normalize=False) 76 | -------------------------------------------------------------------------------- /tests/callbacks/test_results_writer.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import nibabel as nib 4 | import numpy as np 5 | import torch 6 | from nnunetv2.imageio.nibabel_reader_writer import NibabelIO 7 | 8 | from segmentation_failures.evaluation.segmentation.compute_seg_metrics import ( 9 | compute_metrics_for_file, 10 | ) 11 | from segmentation_failures.evaluation.segmentation.segmentation_metrics import ( 12 | get_metrics, 13 | ) 14 | 15 | 16 | def test_compute_metrics_for_file(tmp_path: Path): 17 | NUM_LABELS = 3 18 | BATCH_SIZE = 1 19 | region_based = False 20 | metric_multi_fn = get_metrics(["dice"]) 21 | # metric_single_fn = get_metrics(["generalized_dice"]) 22 | pred_dir = tmp_path / "preds" 23 | lab_dir = tmp_path / "labels" 24 | pred_dir.mkdir() 25 | lab_dir.mkdir() 26 | if region_based: 27 | all_labels = [0] 28 | for i in range(1, NUM_LABELS): 29 | all_labels.insert(1, tuple(range(i, NUM_LABELS))) 30 | else: 31 | all_labels = list(range(NUM_LABELS)) 32 | # save dummy label and prediction for evaluation 33 | preds = torch.randn(BATCH_SIZE, NUM_LABELS, 3, 3, 1) # logits 34 | targets = torch.argmax(preds, dim=1).to(torch.uint8) 35 | outputs = { 36 | "confidence": torch.zeros(len(preds)), 37 | "prediction": preds, 38 | } 39 | batch = { 40 | "target": targets, 41 | "keys": [f"case_{idx:03d}" for idx in range(BATCH_SIZE)], 42 | "properties": [{"spacing": [1, 1, 1]}] * BATCH_SIZE, 43 | } 44 | if region_based: 45 | raise NotImplementedError 46 | else: 47 | class_prediction = torch.argmax(outputs["prediction"], dim=1).numpy().astype(np.uint8) 48 | reader_fn = NibabelIO().read_seg 49 | for i, case_id in enumerate(batch["keys"]): 50 | # I just save predictions here like this because the callback needs too much nnunet stuff. 51 | label_file = lab_dir / f"{case_id}.nii.gz" 52 | pred_file = pred_dir / f"{case_id}.nii.gz" 53 | nib.save(nib.Nifti1Image(class_prediction[i], np.eye(4)), pred_file) 54 | nib.save( 55 | nib.Nifti1Image(batch["target"][i].numpy(), np.eye(4)), 56 | label_file, 57 | ) 58 | metrics, metrics_multi = compute_metrics_for_file( 59 | metric_multi_fn, label_file, pred_file, all_labels, seg_reader_fn=reader_fn 60 | ) 61 | assert all([np.size(x) == 1 for x in metrics.values()]) 62 | if region_based: 63 | assert all([np.size(x) == len(all_labels) for x in metrics_multi.values()]) 64 | else: 65 | # no background class 66 | assert all([np.size(x) == (len(all_labels) - 1) for x in metrics_multi.values()]) 67 | -------------------------------------------------------------------------------- /tests/evaluation/test_ood_metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from segmentation_failures.evaluation.ood_detection.metrics import ( 4 | StatsCache, 5 | get_metric_function, 6 | ) 7 | 8 | 9 | def test_auroc(): 10 | # test this with classification 11 | np.random.seed(42) 12 | pseudo_ood_labels = np.random.randint(0, 2, size=100000) 13 | rand_confids = np.random.rand(100000) 14 | perfect_confids = 1 - pseudo_ood_labels 15 | auroc = get_metric_function("ood_auc") 16 | result1 = auroc(StatsCache(scores=-rand_confids, ood_labels=pseudo_ood_labels)) 17 | result2 = auroc(StatsCache(scores=-perfect_confids, ood_labels=pseudo_ood_labels)) 18 | # Since the true_eaurc is only an approximation, the results are not exactly equal 19 | assert np.isclose(result1, 0.5, atol=1e-2) 20 | assert result2 == 1.0 21 | -------------------------------------------------------------------------------- /tests/evaluation/test_seg_metrics.py: -------------------------------------------------------------------------------- 1 | # test cases taken from deepmind's repo 2 | import torch 3 | 4 | from segmentation_failures.evaluation.segmentation.segmentation_metrics import ( 5 | get_metrics, 6 | ) 7 | 8 | 9 | def test_distance_metrics_cube(): 10 | sd_metric = get_metrics("surface_dice", class_thresholds=[1], include_background=True)[ 11 | "surface_dice" 12 | ] 13 | hd95_metric = get_metrics("hausdorff95", include_background=True)["hausdorff95"] 14 | mask_gt = torch.zeros(100, 100, 100, dtype=float) 15 | mask_pred = torch.zeros(100, 100, 100, dtype=float) 16 | mask_gt[0:50, :, :] = 1 17 | mask_pred[0:51, :, :] = 1 18 | mask_gt = mask_gt.reshape((1, 1, *mask_gt.shape)) 19 | mask_pred = mask_pred.reshape((1, 1, *mask_pred.shape)) 20 | surf_dice = sd_metric(mask_pred, mask_gt, spacing=(2, 1, 1)) 21 | hd95 = hd95_metric(mask_pred, mask_gt, spacing=(2, 1, 1)) 22 | assert round(surf_dice.item(), 3) == 0.836 23 | assert hd95 == 2.0 24 | 25 | 26 | def test_distance_metrics_two_points(): 27 | sd_metric = get_metrics("surface_dice", class_thresholds=[1], include_background=True)[ 28 | "surface_dice" 29 | ] 30 | hd95_metric = get_metrics("hausdorff95", include_background=True)["hausdorff95"] 31 | mask_gt = torch.zeros(100, 100, 100, dtype=float) 32 | mask_pred = torch.zeros(100, 100, 100, dtype=float) 33 | mask_gt[50, 60, 70] = 1 34 | mask_pred[50, 60, 72] = 1 35 | mask_gt = mask_gt.reshape((1, 1, *mask_gt.shape)) 36 | mask_pred = mask_pred.reshape((1, 1, *mask_pred.shape)) 37 | surf_dice = sd_metric(mask_pred, mask_gt, spacing=(3, 2, 1)) 38 | hd95 = hd95_metric(mask_pred, mask_gt, spacing=(3, 2, 1)) 39 | assert surf_dice.item() == 0.5 40 | assert hd95.item() == 2 41 | 42 | 43 | def test_distance_metrics_empty_gt(): 44 | sd_metric = get_metrics("surface_dice", class_thresholds=[1], include_background=True)[ 45 | "surface_dice" 46 | ] 47 | hd95_metric = get_metrics("hausdorff95", include_background=True)["hausdorff95"] 48 | mask_gt = torch.zeros(100, 100, 100, dtype=float) 49 | mask_pred = torch.zeros(100, 100, 100, dtype=float) 50 | # mask_gt[50, 60, 70] = 1 # same happens for empty pred 51 | mask_pred[50, 60, 72] = 1 52 | mask_gt = mask_gt.reshape((1, 1, *mask_gt.shape)) 53 | mask_pred = mask_pred.reshape((1, 1, *mask_pred.shape)) 54 | surf_dice = sd_metric(mask_pred, mask_gt, spacing=(3, 2, 1)) 55 | hd95 = hd95_metric(mask_pred, mask_gt, spacing=(3, 2, 1)) 56 | assert surf_dice.item() == 0.0 57 | assert not torch.isnan(hd95) 58 | 59 | 60 | def test_distance_metrics_empty_both(): 61 | sd_metric = get_metrics("surface_dice", class_thresholds=[1], include_background=True)[ 62 | "surface_dice" 63 | ] 64 | hd95_metric = get_metrics("hausdorff95", include_background=True)["hausdorff95"] 65 | mask_gt = torch.zeros(100, 100, 100, dtype=float) 66 | mask_pred = torch.zeros(100, 100, 100, dtype=float) 67 | mask_gt = mask_gt.reshape((1, 1, *mask_gt.shape)) 68 | mask_pred = mask_pred.reshape((1, 1, *mask_pred.shape)) 69 | surf_dice = sd_metric(mask_pred, mask_gt, spacing=(3, 2, 1)) 70 | hd95 = hd95_metric(mask_pred, mask_gt, spacing=(3, 2, 1)) 71 | assert surf_dice.item() == 1.0 72 | assert hd95.item() == 0.0 73 | -------------------------------------------------------------------------------- /tests/models/confidence_scoring/test_aggregation.py: -------------------------------------------------------------------------------- 1 | """How to test? 2 | - Test that rejector model works as expected: Generate dummy data (arbitrary regression task), run fit and predict with pipeline 3 | - Test extract_features 4 | """ 5 | 6 | import time 7 | 8 | import pytest 9 | import torch 10 | 11 | from segmentation_failures.models.confidence_aggregation import ( 12 | ForegroundAggregator, 13 | ForegroundSizeAggregator, 14 | HeuristicAggregationModule, 15 | RadiomicsAggregationModule, 16 | ) 17 | from segmentation_failures.models.confidence_aggregation.base import ( 18 | PairwiseDiceAggregator, 19 | ) 20 | 21 | 22 | # TODO this fails because one environment variable isn't set 23 | def test_extract_features(): 24 | dummy_module = HeuristicAggregationModule( 25 | regression_model="regression_forest", 26 | dataset_id=500, 27 | confid_name="dummy_confid", 28 | target_metrics=["generalized_dice"], 29 | heuristic_list=[ForegroundAggregator(), ForegroundSizeAggregator()], 30 | ) 31 | # region based is a bit hard to simulate here 32 | dummy_prediction = torch.tensor( 33 | [ 34 | [0, 0, 2, 2], 35 | [0, 1, 2, 2], 36 | [0, 1, 1, 0], 37 | [0, 0, 0, 0], 38 | ], 39 | ) 40 | dummy_prediction = dummy_prediction.reshape(1, 1, *dummy_prediction.shape) 41 | dummy_confid = torch.rand_like(dummy_prediction[:, 0], dtype=float) 42 | dummy_image = torch.rand_like(dummy_prediction, dtype=float) 43 | features = dummy_module.extract_features(dummy_image, dummy_prediction, dummy_confid) 44 | assert features.shape == (len(dummy_prediction), len(dummy_module.aggregator_list)) 45 | 46 | 47 | # TODO outdated test 48 | # radiomics requires a trainer mock, which I don't want to implement. 49 | @pytest.mark.parametrize( 50 | "method", 51 | [ 52 | "heuristic", 53 | ], 54 | ) 55 | @pytest.mark.parametrize("img_dim", [2, 3]) 56 | def test_multiclass_aggregation(method: str, img_dim: int): 57 | NUM_BATCH = 2 58 | NUM_CLASSES = 4 59 | IMG_SIZE = 20 60 | IMG_SHAPE = [IMG_SIZE] * img_dim 61 | 62 | class SimulateModel: 63 | def eval(self): 64 | pass 65 | 66 | def requires_grad_(self, val): 67 | pass 68 | 69 | def __call__(self, x, confid_name): 70 | dummy_prediction = torch.randn(size=(NUM_BATCH, NUM_CLASSES, *IMG_SHAPE)) 71 | dummy_confid = torch.rand(NUM_BATCH, *IMG_SHAPE) 72 | yield {"logits": dummy_prediction, "confid": dummy_confid} 73 | 74 | def forward(self, batch): 75 | return self(batch) 76 | 77 | if method == "heuristic": 78 | dummy_module = HeuristicAggregationModule( 79 | SimulateModel(), 80 | num_classes=NUM_CLASSES, 81 | target_metric="generalized_dice", 82 | confid_name="dummy_confid", 83 | ) 84 | elif method == "radiomics": 85 | dummy_module = RadiomicsAggregationModule( 86 | image_dim=img_dim, 87 | pixel_csf=SimulateModel(), 88 | num_classes=NUM_CLASSES, 89 | target_metric="generalized_dice", 90 | confid_threshold=0.7, 91 | confid_name="dummy_confid", 92 | ) 93 | else: 94 | raise ValueError 95 | outputs = [] 96 | for i in range(3): 97 | batch = { 98 | "data": torch.rand(NUM_BATCH, 1, *IMG_SHAPE), # 1 is modality 99 | "target": torch.randint(NUM_CLASSES, size=(NUM_BATCH, 1, *IMG_SHAPE)), 100 | } 101 | outputs.append(dummy_module.training_step(batch, i)) 102 | assert outputs[-1]["quality_true"].shape == (NUM_BATCH,) 103 | dummy_module.on_train_epoch_end() 104 | 105 | 106 | def test_pairwise_dice_agg(num_batch=4, img_size=(5, 5), region_based=True): 107 | NUM_CLASSES = 2 108 | NUM_SAMPLES = 4 109 | consensus_pred = torch.zeros(NUM_SAMPLES, num_batch, NUM_CLASSES, *img_size) 110 | start_x = 1 111 | start_y = 1 112 | size = 2 113 | consensus_pred[:, :, 1, start_x : start_x + size, start_y : start_y + size] = 1 114 | if region_based: 115 | consensus_pred[:, :, 0, 0:size, -size:] = 1 116 | consensus_pred = consensus_pred.to(dtype=torch.bool) 117 | score = PairwiseDiceAggregator(include_zero_label=region_based) 118 | start = time.time() 119 | result = score.aggregate(consensus_pred) 120 | end = time.time() 121 | print(f"Time taken for pairwise dice: {end - start} seconds") 122 | assert torch.allclose(result, torch.ones_like(result)) 123 | 124 | disjoint_pred = torch.zeros(NUM_SAMPLES, num_batch, NUM_CLASSES, *img_size) 125 | for i in range(len(disjoint_pred)): 126 | disjoint_pred[i, :, 1, i] = 1 127 | disjoint_pred[i, :, 0] = 1 - disjoint_pred[i, :, 1] 128 | if region_based: 129 | disjoint_pred[i, :, 0] = disjoint_pred[i, :, 1] 130 | 131 | disjoint_pred = disjoint_pred.to(dtype=torch.bool) 132 | result = score.aggregate(disjoint_pred) 133 | assert torch.allclose(result, torch.zeros_like(result)) 134 | -------------------------------------------------------------------------------- /tests/models/confidence_scoring/test_mahalanobis.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from segmentation_failures.models.image_confidence import SingleGaussianOODDetector 5 | 6 | 7 | def test_fit_gaussian(): 8 | toy_features = torch.randn(100, 2) + 42 9 | dummy_module = torch.nn.Module() 10 | dummy_module.add_module( 11 | "model", torch.nn.Sequential(torch.nn.Linear(3, 3), torch.nn.Linear(3, 2)) 12 | ) 13 | # Prepare mock objects 14 | my_ood_detector = SingleGaussianOODDetector( 15 | segmentation_net=dummy_module, 16 | feature_path="0", 17 | ) 18 | my_ood_detector.training_epoch_end(outputs=[{"features": toy_features}]) 19 | 20 | assert np.all( 21 | my_ood_detector.gaussian_estimator.location_ == toy_features.numpy().mean(axis=0) 22 | ) 23 | 24 | 25 | # def test_save_load_model(): 26 | # feature_path = "dummy" 27 | # toy_data = torch.randn(100, 2) + 42 28 | # expected_location = toy_data.numpy().mean(axis=0) 29 | 30 | # # Prepare mock objects 31 | # my_ood_detector = SingleGaussianOODDetector(feature_path=feature_path) 32 | # my_ood_detector.training_epoch_end(outputs=[{"features": toy_data}]) 33 | 34 | # assert np.all(my_ood_detector.gaussian_estimator.location_ == expected_location) 35 | 36 | # # # TODO idk how to save/load manually 37 | # # SingleGaussianOODDetector.save -> need lightning trainer for this... 38 | # # SingleGaussianOODDetector.load_from_checkpoint() 39 | --------------------------------------------------------------------------------