├── LICENSE ├── README.md ├── assets ├── README.md ├── U-Mamba-network.png ├── create_visualization_video.py ├── original_img.png ├── seg_mask_overlay.png └── visual_seg.mp4 ├── data ├── README.md └── nnUNet_raw │ ├── Dataset701_AbdomenCT │ └── dataset.json │ ├── Dataset702_AbdomenMR │ └── dataset.json │ ├── Dataset703_NeurIPSCell │ └── dataset.json │ └── Dataset704_Endovis17 │ └── dataset.json ├── evaluation ├── SurfaceDice.py ├── __pycache__ │ └── SurfaceDice.cpython-310.pyc ├── abdomen_DSC_Eval.py ├── abdomen_NSD_Eval.py ├── compute_cell_metric.py ├── endoscopy_DSC_Eval.py └── endoscopy_NSD_Eval.py └── umamba ├── .gitignore ├── LICENSE ├── nnunetv2 ├── __init__.py ├── batch_running │ ├── __init__.py │ ├── benchmarking │ │ ├── __init__.py │ │ ├── generate_benchmarking_commands.py │ │ └── summarize_benchmark_results.py │ ├── collect_results_custom_Decathlon.py │ ├── collect_results_custom_Decathlon_2d.py │ ├── generate_lsf_runs_customDecathlon.py │ └── release_trainings │ │ ├── __init__.py │ │ └── nnunetv2_v1 │ │ ├── __init__.py │ │ ├── collect_results.py │ │ └── generate_lsf_commands.py ├── configuration.py ├── dataset_conversion │ ├── Dataset027_ACDC.py │ ├── Dataset073_Fluo_C3DH_A549_SIM.py │ ├── Dataset114_MNMs.py │ ├── Dataset115_EMIDEC.py │ ├── Dataset120_RoadSegmentation.py │ ├── Dataset137_BraTS21.py │ ├── Dataset218_Amos2022_task1.py │ ├── Dataset219_Amos2022_task2.py │ ├── Dataset220_KiTS2023.py │ ├── Dataset221_AutoPETII_2023.py │ ├── Dataset988_dummyDataset4.py │ ├── __init__.py │ ├── convert_MSD_dataset.py │ ├── convert_raw_dataset_from_old_nnunet_format.py │ ├── datasets_for_integration_tests │ │ ├── Dataset996_IntegrationTest_Hippocampus_regions_ignore.py │ │ ├── Dataset997_IntegrationTest_Hippocampus_regions.py │ │ ├── Dataset998_IntegrationTest_Hippocampus_ignore.py │ │ ├── Dataset999_IntegrationTest_Hippocampus.py │ │ └── __init__.py │ └── generate_dataset_json.py ├── ensembling │ ├── __init__.py │ └── ensemble.py ├── evaluation │ ├── __init__.py │ ├── accumulate_cv_results.py │ ├── evaluate_predictions.py │ └── find_best_configuration.py ├── experiment_planning │ ├── __init__.py │ ├── dataset_fingerprint │ │ ├── __init__.py │ │ └── fingerprint_extractor.py │ ├── experiment_planners │ │ ├── __init__.py │ │ ├── default_experiment_planner.py │ │ ├── network_topology.py │ │ ├── readme.md │ │ └── resencUNet_planner.py │ ├── plan_and_preprocess_api.py │ ├── plan_and_preprocess_entrypoints.py │ ├── plans_for_pretraining │ │ ├── __init__.py │ │ └── move_plans_between_datasets.py │ └── verify_dataset_integrity.py ├── imageio │ ├── __init__.py │ ├── base_reader_writer.py │ ├── natural_image_reader_writer.py │ ├── nibabel_reader_writer.py │ ├── reader_writer_registry.py │ ├── readme.md │ ├── simpleitk_reader_writer.py │ └── tif_reader_writer.py ├── inference │ ├── __init__.py │ ├── data_iterators.py │ ├── examples.py │ ├── export_prediction.py │ ├── predict_from_raw_data.py │ ├── readme.md │ └── sliding_window_prediction.py ├── model_sharing │ ├── __init__.py │ ├── entry_points.py │ ├── model_download.py │ ├── model_export.py │ └── model_import.py ├── nets │ ├── UMambaBot_2d.py │ ├── UMambaBot_3d.py │ ├── UMambaEnc_2d.py │ └── UMambaEnc_3d.py ├── paths.py ├── postprocessing │ ├── __init__.py │ └── remove_connected_components.py ├── preprocessing │ ├── __init__.py │ ├── cropping │ │ ├── __init__.py │ │ └── cropping.py │ ├── normalization │ │ ├── __init__.py │ │ ├── default_normalization_schemes.py │ │ ├── map_channel_name_to_normalization.py │ │ └── readme.md │ ├── preprocessors │ 
│ ├── __init__.py │ │ └── default_preprocessor.py │ └── resampling │ │ ├── __init__.py │ │ ├── default_resampling.py │ │ └── utils.py ├── run │ ├── __init__.py │ ├── load_pretrained_weights.py │ └── run_training.py ├── tests │ ├── __init__.py │ └── integration_tests │ │ ├── __init__.py │ │ ├── add_lowres_and_cascade.py │ │ ├── cleanup_integration_test.py │ │ ├── lsf_commands.sh │ │ ├── prepare_integration_tests.sh │ │ ├── readme.md │ │ ├── run_integration_test.sh │ │ ├── run_integration_test_bestconfig_inference.py │ │ └── run_integration_test_trainingOnly_DDP.sh ├── training │ ├── __init__.py │ ├── data_augmentation │ │ ├── __init__.py │ │ ├── compute_initial_patch_size.py │ │ └── custom_transforms │ │ │ ├── __init__.py │ │ │ ├── cascade_transforms.py │ │ │ ├── deep_supervision_donwsampling.py │ │ │ ├── limited_length_multithreaded_augmenter.py │ │ │ ├── manipulating_data_dict.py │ │ │ ├── masking.py │ │ │ ├── region_based_training.py │ │ │ └── transforms_for_dummy_2d.py │ ├── dataloading │ │ ├── __init__.py │ │ ├── base_data_loader.py │ │ ├── data_loader_2d.py │ │ ├── data_loader_3d.py │ │ ├── nnunet_dataset.py │ │ └── utils.py │ ├── logging │ │ ├── __init__.py │ │ └── nnunet_logger.py │ ├── loss │ │ ├── __init__.py │ │ ├── compound_losses.py │ │ ├── deep_supervision.py │ │ ├── dice.py │ │ └── robust_ce_loss.py │ ├── lr_scheduler │ │ ├── __init__.py │ │ └── polylr.py │ └── nnUNetTrainer │ │ ├── __init__.py │ │ ├── nnUNetTrainer.py │ │ ├── nnUNetTrainerSegResNet.py │ │ ├── nnUNetTrainerSegResNet_2xFeat.py │ │ ├── nnUNetTrainerSegResNet_2xFeat_2xDepth.py │ │ ├── nnUNetTrainerSegResNet_2xFeat_4xDepth.py │ │ ├── nnUNetTrainerSwinUNETR.py │ │ ├── nnUNetTrainerSwinUNETR_Tiny.py │ │ ├── nnUNetTrainerUMambaBot.py │ │ ├── nnUNetTrainerUMambaEnc.py │ │ ├── nnUNetTrainerUMambaEncNoAMP.py │ │ ├── nnUNetTrainerUNETR.py │ │ └── variants │ │ ├── __init__.py │ │ ├── benchmarking │ │ ├── __init__.py │ │ ├── nnUNetTrainerBenchmark_5epochs.py │ │ └── nnUNetTrainerBenchmark_5epochs_noDataLoading.py │ │ ├── data_augmentation │ │ ├── __init__.py │ │ ├── nnUNetTrainerDA5.py │ │ ├── nnUNetTrainerDAOrd0.py │ │ ├── nnUNetTrainerNoDA.py │ │ └── nnUNetTrainerNoMirroring.py │ │ ├── loss │ │ ├── __init__.py │ │ ├── nnUNetTrainerCELoss.py │ │ ├── nnUNetTrainerDiceLoss.py │ │ └── nnUNetTrainerTopkLoss.py │ │ ├── lr_schedule │ │ ├── __init__.py │ │ └── nnUNetTrainerCosAnneal.py │ │ ├── network_architecture │ │ ├── __init__.py │ │ ├── nnUNetTrainerBN.py │ │ └── nnUNetTrainerNoDeepSupervision.py │ │ ├── optimizer │ │ ├── __init__.py │ │ ├── nnUNetTrainerAdam.py │ │ └── nnUNetTrainerAdan.py │ │ ├── sampling │ │ ├── __init__.py │ │ └── nnUNetTrainer_probabilisticOversampling.py │ │ └── training_length │ │ ├── __init__.py │ │ ├── nnUNetTrainer_Xepochs.py │ │ └── nnUNetTrainer_Xepochs_NoMirroring.py └── utilities │ ├── __init__.py │ ├── collate_outputs.py │ ├── dataset_name_id_conversion.py │ ├── ddp_allgather.py │ ├── default_n_proc_DA.py │ ├── file_path_utilities.py │ ├── find_class_by_name.py │ ├── get_network_from_plans.py │ ├── helpers.py │ ├── json_export.py │ ├── label_handling │ ├── __init__.py │ └── label_handling.py │ ├── network_initialization.py │ ├── overlay_plots.py │ ├── plans_handling │ ├── __init__.py │ └── plans_handler.py │ └── utils.py └── setup.py /assets/README.md: -------------------------------------------------------------------------------- 1 | - network architecture: `U-Mamba-network.png` 2 | - segmentation demo: `original_img.png` `seg_mask_overlay.png` 3 | - create_visualization_video.py: script 
to generate the video visualization demo 4 | -------------------------------------------------------------------------------- /assets/U-Mamba-network.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/U-Mamba/28459e33ca03769800dd35e23c6e62491d1925b5/assets/U-Mamba-network.png -------------------------------------------------------------------------------- /assets/create_visualization_video.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Tue Jan 2 20:27:52 2024 5 | 6 | This script is used to generate the visualization video for the segmentation results. 7 | 8 | author (@GitHub ID): Zhihe Wang (@h1shen), Feifei Li (@ff98li), Jun Ma (@JunMa11) 9 | """ 10 | 11 | import cv2 12 | import os 13 | join = os.path.join 14 | listdir = os.listdir 15 | makedirs = os.makedirs 16 | remove = os.remove 17 | isfile = os.path.isfile 18 | isdir = os.path.isdir 19 | basename = os.path.basename 20 | from tqdm import trange 21 | 22 | def slide_image( 23 | img1, 24 | img2, 25 | now_step, 26 | slide_step, 27 | target_size, 28 | line_thick = 5 29 | ): 30 | slide_length = target_size[0] 31 | slide_unit = slide_length / slide_step 32 | 33 | start = int(slide_unit*now_step + 0.5) 34 | slide_img = img1.copy() 35 | slide_img[:, start:] = img2[:, start:] 36 | cv2.rectangle(slide_img, (start+line_thick, 0), (start, target_size[1]), (255, 255, 255), -1) 37 | 38 | return slide_img 39 | 40 | def generate_video( 41 | fg_img_path, 42 | bg_img_path, 43 | slide_step, 44 | save_name = None, 45 | line_thick = 5, 46 | save_video_dir = './', 47 | ): 48 | if save_name is None: 49 | save_vid_path = join(save_video_dir, basename(fg_img_path).replace('.png', '.mp4')) 50 | else: 51 | save_vid_path = join(save_video_dir, save_name) 52 | if isfile(save_vid_path): 53 | print(f"Video {save_vid_path} already exists.
Skipping...") 54 | return 55 | 56 | bg_img = cv2.imread(bg_img_path, cv2.IMREAD_UNCHANGED) 57 | fg_img = cv2.imread(fg_img_path, cv2.IMREAD_UNCHANGED) 58 | 59 | fourcc = cv2.VideoWriter_fourcc(*'mp4v') 60 | output_fps = 60 61 | output_width = fg_img.shape[1] 62 | output_height = fg_img.shape[0] 63 | target_size = (output_width, output_height) 64 | vizwriter = cv2.VideoWriter( 65 | save_vid_path, 66 | fourcc, 67 | output_fps, 68 | (output_width, output_height), 69 | True 70 | ) 71 | total_frames = slide_step 72 | 73 | for item in trange(0, slide_step): 74 | frame = slide_image(fg_img, bg_img, item, slide_step, target_size, line_thick) 75 | vizwriter.write(cv2.cvtColor(frame, cv2.COLOR_BGRA2BGR)) 76 | if vizwriter: 77 | vizwriter.release() 78 | 79 | 80 | if __name__ == '__main__': 81 | bg_img_path = 'original_img.png' 82 | fg_img_path = 'seg_mask_overlay.png' 83 | slide_step = 200 # the number of frames in the video 84 | save_name = 'visual_seg.mp4' 85 | generate_video(fg_img_path, bg_img_path, slide_step, save_name) -------------------------------------------------------------------------------- /assets/original_img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/U-Mamba/28459e33ca03769800dd35e23c6e62491d1925b5/assets/original_img.png -------------------------------------------------------------------------------- /assets/seg_mask_overlay.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/U-Mamba/28459e33ca03769800dd35e23c6e62491d1925b5/assets/seg_mask_overlay.png -------------------------------------------------------------------------------- /assets/visual_seg.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/U-Mamba/28459e33ca03769800dd35e23c6e62491d1925b5/assets/visual_seg.mp4 -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | Download dataset [here](https://drive.google.com/drive/folders/1DmyIye4Gc9wwaA7MVKFVi-bWD2qQb-qN?usp=sharing) 2 | 3 | Please organize the dataset as follows: 4 | 5 | ``` 6 | data/ 7 | ├── nnUNet_raw/ 8 | │ ├── Dataset701_AbdomenCT/ 9 | │ │ ├── imagesTr 10 | │ │ │ ├── FLARE22_Tr_0001_0000.nii.gz 11 | │ │ │ ├── FLARE22_Tr_0002_0000.nii.gz 12 | │ │ │ ├── ... 13 | │ │ ├── labelsTr 14 | │ │ │ ├── FLARE22_Tr_0001.nii.gz 15 | │ │ │ ├── FLARE22_Tr_0002.nii.gz 16 | │ │ │ ├── ... 17 | │ │ ├── dataset.json 18 | │ ├── Dataset702_AbdomenMR/ 19 | │ │ ├── imagesTr 20 | │ │ │ ├── amos_0507_0000.nii.gz 21 | │ │ │ ├── amos_0508_0000.nii.gz 22 | │ │ │ ├── ... 23 | │ │ ├── labelsTr 24 | │ │ │ ├── amos_0507.nii.gz 25 | │ │ │ ├── amos_0508.nii.gz 26 | │ │ │ ├── ... 27 | │ │ ├── dataset.json 28 | │ ├── ... 
29 | ``` 30 | -------------------------------------------------------------------------------- /data/nnUNet_raw/Dataset701_AbdomenCT/dataset.json: -------------------------------------------------------------------------------- 1 | { 2 | "channel_names": { 3 | "0": "CT" 4 | }, 5 | "labels": { 6 | "background": 0, 7 | "liver": 1, 8 | "right kidney": 2, 9 | "spleen": 3, 10 | "pancreas": 4, 11 | "aorta": 5, 12 | "inferior vena cava": 6, 13 | "right adrenal gland": 7, 14 | "left adrenal gland": 8, 15 | "gallbladder": 9, 16 | "esophagus": 10, 17 | "stomach": 11, 18 | "duodenum": 12, 19 | "left kidney": 13 20 | }, 21 | "numTraining": 50, 22 | "file_ending": ".nii.gz", 23 | "name": "Dataset701_AbdomenCT", 24 | "description": "This dataset was from MICCAI FLARE 2022 Challenge. The training set contained 50 CT scans that were from the MSD Pancreas dataset and the annotations were from AbdomenCT-1K. Another 50 validation cases were from TCIA and the annotations were provided by the challenge organizers." 25 | } -------------------------------------------------------------------------------- /data/nnUNet_raw/Dataset702_AbdomenMR/dataset.json: -------------------------------------------------------------------------------- 1 | { 2 | "channel_names": { 3 | "0": "MR" 4 | }, 5 | "labels": { 6 | "background": 0, 7 | "liver": 1, 8 | "right kidney": 2, 9 | "spleen": 3, 10 | "pancreas": 4, 11 | "aorta": 5, 12 | "inferior vena cava": 6, 13 | "right adrenal gland": 7, 14 | "left adrenal gland": 8, 15 | "gallbladder": 9, 16 | "esophagus": 10, 17 | "stomach": 11, 18 | "duodenum": 12, 19 | "left kidney": 13 20 | }, 21 | "numTraining": 50, 22 | "file_ending": ".nii.gz", 23 | "name": "Dataset702_AbdomenMR", 24 | "description": "This dataset was from MICCAI AMOS 2022 Challenge. The original dataset contained 60 annotation cases. We annotated another 50 MRI scans as the testing set. The annotations were generated by radiologists with the assistance of MedSAM and ITK-SNAP." 25 | } -------------------------------------------------------------------------------- /data/nnUNet_raw/Dataset703_NeurIPSCell/dataset.json: -------------------------------------------------------------------------------- 1 | { 2 | "channel_names": { 3 | "0": "R", 4 | "1": "G", 5 | "2": "B" 6 | }, 7 | "labels": { 8 | "background": 0, 9 | "interior": 1, 10 | "boundary": 2 11 | }, 12 | "numTraining": 1000, 13 | "file_ending": ".png", 14 | "name": "Dataset703_NeurIPSCell", 15 | "description": "This dataset was from the NeurIPS 2022 Cell Segmentation Challenge https://neurips22-cellseg.grand-challenge.org/. Please note that this is an instance segmentation task." 
16 | } -------------------------------------------------------------------------------- /data/nnUNet_raw/Dataset704_Endovis17/dataset.json: -------------------------------------------------------------------------------- 1 | { 2 | "channel_names": { 3 | "0": "R", 4 | "1": "G", 5 | "2": "B" 6 | }, 7 | "labels": { 8 | "background": 0, 9 | "Bipolar Forceps": 1, 10 | "Prograsp Forceps": 2, 11 | "Large Needle Driver": 3, 12 | "Vessel Sealer": 4, 13 | "Grasping Retractor": 5, 14 | "Monopolar Curved Scissors": 6, 15 | "Ultrasound Probe": 7 16 | }, 17 | "numTraining": 1800, 18 | "file_ending": ".png", 19 | "name": "Dataset704_Endovis17", 20 | "description": "This dataset was from the MICCAI 2017 Robotic Instrument Segmentation Challenge https://endovissub2017-roboticinstrumentsegmentation.grand-challenge.org/Home/" 21 | } 22 | -------------------------------------------------------------------------------- /evaluation/__pycache__/SurfaceDice.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/U-Mamba/28459e33ca03769800dd35e23c6e62491d1925b5/evaluation/__pycache__/SurfaceDice.cpython-310.pyc -------------------------------------------------------------------------------- /evaluation/abdomen_DSC_Eval.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Fri Apr 15 12:59:48 2022 4 | 5 | @author: 12593 6 | """ 7 | 8 | import numpy as np 9 | import nibabel as nb 10 | import os 11 | from collections import OrderedDict 12 | import pandas as pd 13 | from SurfaceDice import compute_surface_distances, compute_surface_dice_at_tolerance, compute_dice_coefficient 14 | join = os.path.join 15 | basename = os.path.basename 16 | from tqdm import tqdm 17 | 18 | import argparse 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument( 21 | '--gt_path', 22 | type=str, 23 | default='' 24 | ) 25 | parser.add_argument( 26 | '--seg_path', 27 | type=str, 28 | default='' 29 | ) 30 | parser.add_argument( 31 | '--save_path', 32 | type=str, 33 | default='' 34 | ) 35 | 36 | args = parser.parse_args() 37 | 38 | gt_path = args.gt_path 39 | seg_path = args.seg_path 40 | save_path = args.save_path 41 | 42 | filenames = os.listdir(seg_path) 43 | filenames = [x for x in filenames if x.endswith('.nii.gz')] 44 | filenames = [x for x in filenames if os.path.exists(join(seg_path, x))] 45 | filenames.sort() 46 | 47 | seg_metrics = OrderedDict() 48 | seg_metrics['Name'] = list() 49 | label_tolerance = OrderedDict({'Liver': 5, 'RK':3, 'Spleen':3, 'Pancreas':5, 50 | 'Aorta': 2, 'IVC':2, 'RAG':2, 'LAG':2, 'Gallbladder': 2, 51 | 'Esophagus':3, 'Stomach': 5, 'Duodenum': 7, 'LK':3}) 52 | for organ in label_tolerance.keys(): 53 | seg_metrics['{}_DSC'.format(organ)] = list() 54 | # for organ in label_tolerance.keys(): 55 | # seg_metrics['{}_NSD'.format(organ)] = list() 56 | 57 | def find_lower_upper_zbound(organ_mask): 58 | """ 59 | Parameters 60 | ---------- 61 | seg : TYPE 62 | DESCRIPTION. 
63 | 64 | Returns 65 | ------- 66 | z_lower: lower bound in z axis: int 67 | z_upper: upper bound in z axis: int 68 | 69 | """ 70 | organ_mask = np.uint8(organ_mask) 71 | assert np.max(organ_mask) == 1, 'mask label error!' 72 | z_index = np.where(organ_mask>0)[2] 73 | z_lower = np.min(z_index) 74 | z_upper = np.max(z_index) 75 | 76 | return z_lower, z_upper 77 | 78 | 79 | 80 | for name in tqdm(filenames): 81 | seg_metrics['Name'].append(name) 82 | # load ground truth and segmentation 83 | gt_nii = nb.load(join(gt_path, name)) 84 | case_spacing = gt_nii.header.get_zooms() 85 | gt_data = np.uint8(gt_nii.get_fdata()) 86 | seg_data = np.uint8(nb.load(join(seg_path, name)).get_fdata()) 87 | 88 | for i, organ in enumerate(label_tolerance.keys(),1): 89 | if np.sum(gt_data==i)==0 and np.sum(seg_data==i)==0: 90 | DSC_i = 1 91 | NSD_i = 1 92 | elif np.sum(gt_data==i)==0 and np.sum(seg_data==i)>0: 93 | DSC_i = 0 94 | NSD_i = 0 95 | else: 96 | if i==5 or i==6 or i==10: # for Aorta, IVC, and Esophagus, only evaluate the labelled slices in ground truth 97 | z_lower, z_upper = find_lower_upper_zbound(gt_data==i) 98 | organ_i_gt, organ_i_seg = gt_data[:,:,z_lower:z_upper]==i, seg_data[:,:,z_lower:z_upper]==i 99 | else: 100 | organ_i_gt, organ_i_seg = gt_data==i, seg_data==i 101 | 102 | DSC_i = compute_dice_coefficient(organ_i_gt, organ_i_seg) 103 | # surface_distances = compute_surface_distances(organ_i_gt, organ_i_seg, case_spacing) 104 | # NSD_i = compute_surface_dice_at_tolerance(surface_distances, label_tolerance[organ]) 105 | seg_metrics['{}_DSC'.format(organ)].append(round(DSC_i, 4)) 106 | 107 | dataframe = pd.DataFrame(seg_metrics) 108 | #dataframe.to_csv(seg_path + '_DSC.csv', index=False) 109 | dataframe.to_csv(save_path, index=False) 110 | 111 | case_avg_DSC = dataframe.mean(axis=0, numeric_only=True) 112 | print(20 * '>') 113 | print(f'Average DSC for {basename(seg_path)}: {case_avg_DSC.mean()}') 114 | print(20 * '<') -------------------------------------------------------------------------------- /evaluation/abdomen_NSD_Eval.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Fri Apr 15 12:59:48 2022 4 | 5 | @author: 12593 6 | """ 7 | import sys 8 | import numpy as np 9 | import nibabel as nb 10 | import os 11 | from collections import OrderedDict 12 | import pandas as pd 13 | from SurfaceDice import compute_surface_distances, compute_surface_dice_at_tolerance, compute_dice_coefficient 14 | join = os.path.join 15 | basename = os.path.basename 16 | from tqdm import tqdm 17 | 18 | import argparse 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument( 21 | '--gt_path', 22 | type=str, 23 | default='' 24 | ) 25 | parser.add_argument( 26 | '--seg_path', 27 | type=str, 28 | default='' 29 | ) 30 | parser.add_argument( 31 | '--save_path', 32 | type=str, 33 | default='' 34 | ) 35 | 36 | args = parser.parse_args() 37 | 38 | gt_path = args.gt_path 39 | seg_path = args.seg_path 40 | save_path = args.save_path 41 | 42 | filenames = os.listdir(seg_path) 43 | filenames = [x for x in filenames if x.endswith('.nii.gz')] 44 | filenames = [x for x in filenames if os.path.exists(join(seg_path, x))] 45 | filenames.sort() 46 | 47 | seg_metrics = OrderedDict() 48 | seg_metrics['Name'] = list() 49 | label_tolerance = OrderedDict({'Liver': 5, 'RK':3, 'Spleen':3, 'Pancreas':5, 50 | 'Aorta': 2, 'IVC':2, 'RAG':2, 'LAG':2, 'Gallbladder': 2, 51 | 'Esophagus':3, 'Stomach': 5, 'Duodenum': 7, 'LK':3}) 52 | 53 | for organ in
label_tolerance.keys(): 54 | seg_metrics['{}_NSD'.format(organ)] = list() 55 | 56 | def find_lower_upper_zbound(organ_mask): 57 | """ 58 | Parameters 59 | ---------- 60 | seg : TYPE 61 | DESCRIPTION. 62 | 63 | Returns 64 | ------- 65 | z_lower: lower bound in z axis: int 66 | z_upper: upper bound in z axis: int 67 | 68 | """ 69 | organ_mask = np.uint8(organ_mask) 70 | assert np.max(organ_mask) == 1, 'mask label error!' 71 | z_index = np.where(organ_mask>0)[2] 72 | z_lower = np.min(z_index) 73 | z_upper = np.max(z_index) 74 | 75 | return z_lower, z_upper 76 | 77 | 78 | 79 | for name in tqdm(filenames): 80 | seg_metrics['Name'].append(name) 81 | # load ground truth and segmentation 82 | gt_nii = nb.load(join(gt_path, name)) 83 | case_spacing = gt_nii.header.get_zooms() 84 | gt_data = np.uint8(gt_nii.get_fdata()) 85 | seg_data = np.uint8(nb.load(join(seg_path, name)).get_fdata()) 86 | 87 | for i, organ in enumerate(label_tolerance.keys(),1): 88 | if np.sum(gt_data==i)==0 and np.sum(seg_data==i)==0: 89 | DSC_i = 1 90 | NSD_i = 1 91 | elif np.sum(gt_data==i)==0 and np.sum(seg_data==i)>0: 92 | DSC_i = 0 93 | NSD_i = 0 94 | else: 95 | if i==5 or i==6 or i==10: # for Aorta, IVC, and Esophagus, only evaluate the labelled slices in ground truth 96 | z_lower, z_upper = find_lower_upper_zbound(gt_data==i) 97 | organ_i_gt, organ_i_seg = gt_data[:,:,z_lower:z_upper]==i, seg_data[:,:,z_lower:z_upper]==i 98 | else: 99 | organ_i_gt, organ_i_seg = gt_data==i, seg_data==i 100 | 101 | #DSC_i = compute_dice_coefficient(organ_i_gt, organ_i_seg) 102 | surface_distances = compute_surface_distances(organ_i_gt, organ_i_seg, case_spacing) 103 | NSD_i = compute_surface_dice_at_tolerance(surface_distances, label_tolerance[organ]) 104 | #seg_metrics['{}_DSC'.format(organ)].append(round(DSC_i, 4)) 105 | seg_metrics['{}_NSD'.format(organ)].append(round(NSD_i, 4)) 106 | 107 | dataframe = pd.DataFrame(seg_metrics) 108 | #dataframe.to_csv(seg_path + '_DSC.csv', index=False) 109 | dataframe.to_csv(save_path, index=False) 110 | 111 | case_avg_NSD = dataframe.mean(axis=0, numeric_only=True) 112 | print(20 * '>') 113 | print(f'Average NSD for {basename(seg_path)}: {case_avg_NSD.mean()}') 114 | print(20 * '<') -------------------------------------------------------------------------------- /evaluation/endoscopy_DSC_Eval.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Fri Apr 15 12:59:48 2022 4 | 5 | @author: 12593 6 | """ 7 | 8 | import numpy as np 9 | #import nibabel as nb 10 | import cv2 11 | import os 12 | from collections import OrderedDict 13 | import pandas as pd 14 | from SurfaceDice import compute_surface_distances, compute_surface_dice_at_tolerance, compute_dice_coefficient 15 | join = os.path.join 16 | basename = os.path.basename 17 | from tqdm import tqdm 18 | 19 | import argparse 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument( 22 | '--gt_path', 23 | type=str, 24 | default='' 25 | ) 26 | parser.add_argument( 27 | '--seg_path', 28 | type=str, 29 | default='' 30 | ) 31 | parser.add_argument( 32 | '--save_path', 33 | type=str, 34 | default='' 35 | ) 36 | 37 | args = parser.parse_args() 38 | 39 | gt_path = args.gt_path 40 | seg_path = args.seg_path 41 | save_path = args.save_path 42 | 43 | filenames = os.listdir(seg_path) 44 | filenames = [x for x in filenames if x.endswith('.png')] 45 | filenames = [x for x in filenames if os.path.exists(join(seg_path, x))] 46 | filenames.sort() 47 | 48 | seg_metrics =
OrderedDict( 49 | Name = list(), 50 | DSC = list(), 51 | ) 52 | 53 | for name in tqdm(filenames): 54 | seg_metrics['Name'].append(name) 55 | 56 | gt_mask = cv2.imread(join(gt_path, name), cv2.IMREAD_UNCHANGED) 57 | seg_mask = cv2.imread(join(seg_path, name), cv2.IMREAD_UNCHANGED) 58 | case_spacing = [1,1,1] 59 | gt_data = np.uint8(gt_mask) 60 | seg_data = np.uint8(seg_mask) 61 | 62 | gt_labels = np.unique(gt_data)[1:] 63 | seg_labels = np.unique(seg_data)[1:] 64 | labels = np.union1d(gt_labels, seg_labels) 65 | 66 | assert len(labels) > 0, 'Ground truth mask max: {}'.format(gt_data.max()) 67 | 68 | #DSC_arr = np.zeros(len(labels)) 69 | DSC_arr = [] 70 | for i in labels: 71 | if np.sum(gt_data==i)==0 and np.sum(seg_data==i)==0: 72 | DSC_i = 1 73 | NSD_i = 1 74 | elif np.sum(gt_data==i)==0 and np.sum(seg_data==i)>0: 75 | DSC_i = 0 76 | NSD_i = 0 77 | else: 78 | tool_i_gt, tool_i_seg = gt_data==i, seg_data==i 79 | DSC_i = compute_dice_coefficient(tool_i_gt, tool_i_seg) 80 | 81 | DSC_arr.append(DSC_i) 82 | 83 | DSC = np.mean(DSC_arr) 84 | seg_metrics['DSC'].append(round(DSC, 4)) 85 | 86 | dataframe = pd.DataFrame(seg_metrics) 87 | #dataframe.to_csv(seg_path + '_DSC.csv', index=False) 88 | dataframe.to_csv(save_path, index=False) 89 | 90 | case_avg_DSC = dataframe.mean(axis=0, numeric_only=True) 91 | print(20 * '>') 92 | print(f'Average DSC for {basename(seg_path)}: {case_avg_DSC.mean()}') 93 | print(20 * '<') -------------------------------------------------------------------------------- /evaluation/endoscopy_NSD_Eval.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Fri Apr 15 12:59:48 2022 4 | 5 | @author: 12593 6 | """ 7 | 8 | import numpy as np 9 | #import nibabel as nb 10 | import cv2 11 | import os 12 | from collections import OrderedDict 13 | import pandas as pd 14 | from SurfaceDice import compute_surface_distances, compute_surface_dice_at_tolerance, compute_dice_coefficient 15 | join = os.path.join 16 | basename = os.path.basename 17 | from tqdm import tqdm 18 | 19 | import argparse 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument( 22 | '--gt_path', 23 | type=str, 24 | default='' 25 | ) 26 | parser.add_argument( 27 | '--seg_path', 28 | type=str, 29 | default='' 30 | ) 31 | parser.add_argument( 32 | '--save_path', 33 | type=str, 34 | default='' 35 | ) 36 | 37 | args = parser.parse_args() 38 | 39 | gt_path = args.gt_path 40 | seg_path = args.seg_path 41 | save_path = args.save_path 42 | 43 | filenames = os.listdir(seg_path) 44 | filenames = [x for x in filenames if x.endswith('.png')] 45 | filenames = [x for x in filenames if os.path.exists(join(seg_path, x))] 46 | filenames.sort() 47 | 48 | seg_metrics = OrderedDict( 49 | Name = list(), 50 | NSD = list(), 51 | gt_labels = list(), 52 | seg_labels = list(), 53 | union = list(), 54 | ) 55 | 56 | for name in tqdm(filenames): 57 | seg_metrics['Name'].append(name) 58 | 59 | gt_mask = cv2.imread(join(gt_path, name), cv2.IMREAD_UNCHANGED) 60 | seg_mask = cv2.imread(join(seg_path, name), cv2.IMREAD_UNCHANGED) 61 | case_spacing = [1,1,1] 62 | gt_data = np.uint8(gt_mask) 63 | seg_data = np.uint8(seg_mask) 64 | 65 | gt_labels = np.unique(gt_data)[1:] 66 | seg_metrics['gt_labels'].append(gt_labels.tolist()) 67 | seg_labels = np.unique(seg_data)[1:] 68 | seg_metrics['seg_labels'].append(seg_labels.tolist()) 69 | labels = np.union1d(gt_labels, seg_labels) 70 | seg_metrics['union'].append(labels.tolist()) 71 | 72 | assert len(labels) > 0, 'Ground 
truth mask max: {}'.format(gt_data.max()) 73 | 74 | #DSC_arr = np.zeros(len(labels)) 75 | #DSC_arr = [] 76 | NSD_arr = [] 77 | for i in labels: 78 | if np.sum(gt_data==i)==0 and np.sum(seg_data==i)==0: 79 | NSD_i = 1 80 | elif np.sum(gt_data==i)==0 and np.sum(seg_data==i)>0: 81 | NSD_i = 0 82 | else: 83 | tool_i_gt, tool_i_seg = gt_data==i, seg_data==i 84 | surface_distances = compute_surface_distances(tool_i_gt[..., None], tool_i_seg[..., None], case_spacing) 85 | NSD_i = compute_surface_dice_at_tolerance(surface_distances, 2) 86 | 87 | NSD_arr.append(NSD_i) 88 | 89 | #DSC = np.mean(DSC_arr) 90 | #seg_metrics['DSC'].append(round(DSC, 4)) 91 | NSD = np.mean(NSD_arr) 92 | seg_metrics['NSD'].append(round(NSD, 4)) 93 | 94 | dataframe = pd.DataFrame(seg_metrics) 95 | #dataframe.to_csv(seg_path + '_DSC.csv', index=False) 96 | dataframe.to_csv(save_path, index=False) 97 | 98 | case_avg_NSD = dataframe.mean(axis=0, numeric_only=True) 99 | print(20 * '>') 100 | print(f'Average NSD for {basename(seg_path)}: {case_avg_NSD.mean()}') 101 | print(20 * '<') -------------------------------------------------------------------------------- /umamba/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | local_settings.py 55 | 56 | # Flask stuff: 57 | instance/ 58 | .webassets-cache 59 | 60 | # Scrapy stuff: 61 | .scrapy 62 | 63 | # Sphinx documentation 64 | docs/_build/ 65 | 66 | # PyBuilder 67 | target/ 68 | 69 | # IPython Notebook 70 | .ipynb_checkpoints 71 | 72 | # pyenv 73 | .python-version 74 | 75 | # celery beat schedule file 76 | celerybeat-schedule 77 | 78 | # dotenv 79 | .env 80 | 81 | # virtualenv 82 | venv/ 83 | ENV/ 84 | 85 | # Spyder project settings 86 | .spyderproject 87 | 88 | # Rope project settings 89 | .ropeproject 90 | 91 | *.memmap 92 | *.png 93 | *.zip 94 | *.npz 95 | *.npy 96 | *.jpg 97 | *.jpeg 98 | .idea 99 | *.txt 100 | .idea/* 101 | *.png 102 | *.nii.gz 103 | *.nii 104 | *.tif 105 | *.bmp 106 | *.pkl 107 | *.xml 108 | *.pkl 109 | *.pdf 110 | *.png 111 | *.jpg 112 | *.jpeg 113 | 114 | *.model 115 | 116 | -------------------------------------------------------------------------------- /umamba/nnunetv2/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/U-Mamba/28459e33ca03769800dd35e23c6e62491d1925b5/umamba/nnunetv2/__init__.py -------------------------------------------------------------------------------- /umamba/nnunetv2/batch_running/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/U-Mamba/28459e33ca03769800dd35e23c6e62491d1925b5/umamba/nnunetv2/batch_running/__init__.py -------------------------------------------------------------------------------- /umamba/nnunetv2/batch_running/benchmarking/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/U-Mamba/28459e33ca03769800dd35e23c6e62491d1925b5/umamba/nnunetv2/batch_running/benchmarking/__init__.py -------------------------------------------------------------------------------- /umamba/nnunetv2/batch_running/benchmarking/generate_benchmarking_commands.py: -------------------------------------------------------------------------------- 1 | if __name__ == '__main__': 2 | """ 3 | This code probably only works within the DKFZ infrastructure (using LSF). You will need to adapt it to your scheduler! 
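Each line written to the output file below is a complete `bsub` submission wrapping an `nnUNetv2_train` call; adjust the queue, resource and host strings for your own cluster.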
4 | """ 5 | gpu_models = [#'NVIDIAA100_PCIE_40GB', 'NVIDIAGeForceRTX2080Ti', 'NVIDIATITANRTX', 'TeslaV100_SXM2_32GB', 6 | 'NVIDIAA100_SXM4_40GB']#, 'TeslaV100_PCIE_32GB'] 7 | datasets = [2, 3, 4, 5] 8 | trainers = ['nnUNetTrainerBenchmark_5epochs', 'nnUNetTrainerBenchmark_5epochs_noDataLoading'] 9 | plans = ['nnUNetPlans'] 10 | configs = ['2d', '2d_bs3x', '2d_bs6x', '3d_fullres', '3d_fullres_bs3x', '3d_fullres_bs6x'] 11 | num_gpus = 1 12 | 13 | benchmark_configurations = {d: configs for d in datasets} 14 | 15 | exclude_hosts = "-R \"select[hname!='e230-dgxa100-1']'\"" 16 | resources = "-R \"tensorcore\"" 17 | queue = "-q gpu" 18 | preamble = "-L /bin/bash \"source ~/load_env_torch210.sh && " 19 | train_command = 'nnUNet_compile=False nnUNet_results=/dkfz/cluster/gpu/checkpoints/OE0441/isensee/nnUNet_results_remake_benchmark nnUNetv2_train' 20 | 21 | folds = (0, ) 22 | 23 | use_these_modules = { 24 | tr: plans for tr in trainers 25 | } 26 | 27 | additional_arguments = f' -num_gpus {num_gpus}' # '' 28 | 29 | output_file = "/home/isensee/deleteme.txt" 30 | with open(output_file, 'w') as f: 31 | for g in gpu_models: 32 | gpu_requirements = f"-gpu num={num_gpus}:j_exclusive=yes:gmodel={g}" 33 | for tr in use_these_modules.keys(): 34 | for p in use_these_modules[tr]: 35 | for dataset in benchmark_configurations.keys(): 36 | for config in benchmark_configurations[dataset]: 37 | for fl in folds: 38 | command = f'bsub {exclude_hosts} {resources} {queue} {gpu_requirements} {preamble} {train_command} {dataset} {config} {fl} -tr {tr} -p {p}' 39 | if additional_arguments is not None and len(additional_arguments) > 0: 40 | command += f' {additional_arguments}' 41 | f.write(f'{command}\"\n') -------------------------------------------------------------------------------- /umamba/nnunetv2/batch_running/benchmarking/summarize_benchmark_results.py: -------------------------------------------------------------------------------- 1 | from batchgenerators.utilities.file_and_folder_operations import join, load_json, isfile 2 | from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name 3 | from nnunetv2.paths import nnUNet_results 4 | from nnunetv2.utilities.file_path_utilities import get_output_folder 5 | 6 | if __name__ == '__main__': 7 | trainers = ['nnUNetTrainerBenchmark_5epochs', 'nnUNetTrainerBenchmark_5epochs_noDataLoading'] 8 | datasets = [2, 3, 4, 5] 9 | plans = ['nnUNetPlans'] 10 | configs = ['2d', '2d_bs3x', '2d_bs6x', '3d_fullres', '3d_fullres_bs3x', '3d_fullres_bs6x'] 11 | output_file = join(nnUNet_results, 'benchmark_results.csv') 12 | 13 | torch_version = '2.1.0.dev20230330'#"2.0.0"#"2.1.0.dev20230328" #"1.11.0a0+gitbc2c6ed" # 14 | cudnn_version = 8700 # 8302 # 15 | num_gpus = 1 16 | 17 | unique_gpus = set() 18 | 19 | # collect results in the most janky way possible. Amazing coding skills! 
20 | all_results = {} 21 | for tr in trainers: 22 | all_results[tr] = {} 23 | for p in plans: 24 | all_results[tr][p] = {} 25 | for c in configs: 26 | all_results[tr][p][c] = {} 27 | for d in datasets: 28 | dataset_name = maybe_convert_to_dataset_name(d) 29 | output_folder = get_output_folder(dataset_name, tr, p, c, fold=0) 30 | expected_benchmark_file = join(output_folder, 'benchmark_result.json') 31 | all_results[tr][p][c][d] = {} 32 | if isfile(expected_benchmark_file): 33 | # filter results for what we want 34 | results = [i for i in load_json(expected_benchmark_file).values() 35 | if i['num_gpus'] == num_gpus and i['cudnn_version'] == cudnn_version and 36 | i['torch_version'] == torch_version] 37 | for r in results: 38 | all_results[tr][p][c][d][r['gpu_name']] = r 39 | unique_gpus.add(r['gpu_name']) 40 | 41 | # haha. Fuck this. Collect GPUs in the code above. 42 | # unique_gpus = np.unique([i["gpu_name"] for tr in trainers for p in plans for c in configs for d in datasets for i in all_results[tr][p][c][d]]) 43 | 44 | unique_gpus = list(unique_gpus) 45 | unique_gpus.sort() 46 | 47 | with open(output_file, 'w') as f: 48 | f.write('Dataset,Trainer,Plans,Config') 49 | for g in unique_gpus: 50 | f.write(f",{g}") 51 | f.write("\n") 52 | for d in datasets: 53 | for tr in trainers: 54 | for p in plans: 55 | for c in configs: 56 | gpu_results = [] 57 | for g in unique_gpus: 58 | if g in all_results[tr][p][c][d].keys(): 59 | gpu_results.append(round(all_results[tr][p][c][d][g]["fastest_epoch"], ndigits=2)) 60 | else: 61 | gpu_results.append("MISSING") 62 | # skip if all are missing 63 | if all([i == 'MISSING' for i in gpu_results]): 64 | continue 65 | f.write(f"{d},{tr},{p},{c}") 66 | for g in gpu_results: 67 | f.write(f",{g}") 68 | f.write("\n") 69 | f.write("\n") 70 | 71 | -------------------------------------------------------------------------------- /umamba/nnunetv2/batch_running/collect_results_custom_Decathlon_2d.py: -------------------------------------------------------------------------------- 1 | from batchgenerators.utilities.file_and_folder_operations import * 2 | 3 | from nnunetv2.batch_running.collect_results_custom_Decathlon import collect_results, summarize 4 | from nnunetv2.paths import nnUNet_results 5 | 6 | if __name__ == '__main__': 7 | use_these_trainers = { 8 | 'nnUNetTrainer': ('nnUNetPlans', ), 9 | } 10 | all_results_file = join(nnUNet_results, 'hrnet_results.csv') 11 | datasets = [2, 3, 4, 17, 20, 24, 27, 38, 55, 64, 82] 12 | collect_results(use_these_trainers, datasets, all_results_file) 13 | 14 | folds = (0, ) 15 | configs = ('2d', ) 16 | output_file = join(nnUNet_results, 'hrnet_results_summary_fold0.csv') 17 | summarize(all_results_file, output_file, folds, configs, datasets, use_these_trainers) 18 | 19 | -------------------------------------------------------------------------------- /umamba/nnunetv2/batch_running/generate_lsf_runs_customDecathlon.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | import numpy as np 3 | 4 | 5 | def merge(dict1, dict2): 6 | keys = np.unique(list(dict1.keys()) + list(dict2.keys())) 7 | keys = np.unique(keys) 8 | res = {} 9 | for k in keys: 10 | all_configs = [] 11 | if dict1.get(k) is not None: 12 | all_configs += list(dict1[k]) 13 | if dict2.get(k) is not None: 14 | all_configs += list(dict2[k]) 15 | if len(all_configs) > 0: 16 | res[k] = tuple(np.unique(all_configs)) 17 | return res 18 | 19 | 20 | if __name__ == "__main__": 21 | # after the Nature Methods paper we 
switch our evaluation to a different (more stable/high quality) set of 22 | # datasets for evaluation and future development 23 | configurations_all = { 24 | 2: ("3d_fullres", "2d"), 25 | 3: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), 26 | 4: ("2d", "3d_fullres"), 27 | 17: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), 28 | 20: ("2d", "3d_fullres"), 29 | 24: ("2d", "3d_fullres"), 30 | 27: ("2d", "3d_fullres"), 31 | 38: ("2d", "3d_fullres"), 32 | 55: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), 33 | 64: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), 34 | 82: ("2d", "3d_fullres"), 35 | # 83: ("2d", "3d_fullres"), 36 | } 37 | 38 | configurations_3d_fr_only = { 39 | i: ("3d_fullres", ) for i in configurations_all if "3d_fullres" in configurations_all[i] 40 | } 41 | 42 | configurations_3d_c_only = { 43 | i: ("3d_cascade_fullres", ) for i in configurations_all if "3d_cascade_fullres" in configurations_all[i] 44 | } 45 | 46 | configurations_3d_lr_only = { 47 | i: ("3d_lowres", ) for i in configurations_all if "3d_lowres" in configurations_all[i] 48 | } 49 | 50 | configurations_2d_only = { 51 | i: ("2d", ) for i in configurations_all if "2d" in configurations_all[i] 52 | } 53 | 54 | num_gpus = 1 55 | exclude_hosts = "-R \"select[hname!='e230-dgx2-2']\" -R \"select[hname!='e230-dgx2-1']\" -R \"select[hname!='e230-dgx1-1']\" -R \"select[hname!='e230-dgxa100-1']\" -R \"select[hname!='e230-dgxa100-2']\" -R \"select[hname!='e230-dgxa100-3']\" -R \"select[hname!='e230-dgxa100-4']\"" 56 | resources = "-R \"tensorcore\"" 57 | gpu_requirements = f"-gpu num={num_gpus}:j_exclusive=yes:gmem=33G" 58 | queue = "-q gpu-lowprio" 59 | preamble = "-L /bin/bash \"source ~/load_env_cluster4.sh && " 60 | train_command = 'nnUNet_results=/dkfz/cluster/gpu/checkpoints/OE0441/isensee/nnUNet_results_remake_release nnUNetv2_train' 61 | 62 | folds = (0, ) 63 | # use_this = configurations_2d_only 64 | use_this = merge(configurations_3d_fr_only, configurations_3d_lr_only) 65 | # use_this = merge(use_this, configurations_3d_c_only) 66 | 67 | use_these_modules = { 68 | 'nnUNetTrainer': ('nnUNetPlans',), 69 | 'nnUNetTrainerDiceCELoss_noSmooth': ('nnUNetPlans',), 70 | # 'nnUNetTrainer_DASegOrd0': ('nnUNetPlans',), 71 | } 72 | 73 | additional_arguments = f'--disable_checkpointing -num_gpus {num_gpus}' # '' 74 | 75 | output_file = "/home/isensee/deleteme.txt" 76 | with open(output_file, 'w') as f: 77 | for tr in use_these_modules.keys(): 78 | for p in use_these_modules[tr]: 79 | for dataset in use_this.keys(): 80 | for config in use_this[dataset]: 81 | for fl in folds: 82 | command = f'bsub {exclude_hosts} {resources} {queue} {gpu_requirements} {preamble} {train_command} {dataset} {config} {fl} -tr {tr} -p {p}' 83 | if additional_arguments is not None and len(additional_arguments) > 0: 84 | command += f' {additional_arguments}' 85 | f.write(f'{command}\"\n') 86 | 87 | -------------------------------------------------------------------------------- /umamba/nnunetv2/batch_running/release_trainings/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/U-Mamba/28459e33ca03769800dd35e23c6e62491d1925b5/umamba/nnunetv2/batch_running/release_trainings/__init__.py -------------------------------------------------------------------------------- /umamba/nnunetv2/batch_running/release_trainings/nnunetv2_v1/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/bowang-lab/U-Mamba/28459e33ca03769800dd35e23c6e62491d1925b5/umamba/nnunetv2/batch_running/release_trainings/nnunetv2_v1/__init__.py -------------------------------------------------------------------------------- /umamba/nnunetv2/batch_running/release_trainings/nnunetv2_v1/generate_lsf_commands.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | import numpy as np 3 | 4 | 5 | def merge(dict1, dict2): 6 | keys = np.unique(list(dict1.keys()) + list(dict2.keys())) 7 | keys = np.unique(keys) 8 | res = {} 9 | for k in keys: 10 | all_configs = [] 11 | if dict1.get(k) is not None: 12 | all_configs += list(dict1[k]) 13 | if dict2.get(k) is not None: 14 | all_configs += list(dict2[k]) 15 | if len(all_configs) > 0: 16 | res[k] = tuple(np.unique(all_configs)) 17 | return res 18 | 19 | 20 | if __name__ == "__main__": 21 | # after the Nature Methods paper we switch our evaluation to a different (more stable/high quality) set of 22 | # datasets for evaluation and future development 23 | configurations_all = { 24 | # 1: ("3d_fullres", "2d"), 25 | 2: ("3d_fullres", "2d"), 26 | # 3: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), 27 | # 4: ("2d", "3d_fullres"), 28 | 5: ("2d", "3d_fullres"), 29 | # 6: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), 30 | # 7: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), 31 | # 8: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), 32 | # 9: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), 33 | # 10: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), 34 | # 17: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), 35 | 20: ("2d", "3d_fullres"), 36 | 24: ("2d", "3d_fullres"), 37 | 27: ("2d", "3d_fullres"), 38 | 35: ("2d", "3d_fullres"), 39 | 38: ("2d", "3d_fullres"), 40 | # 55: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), 41 | # 64: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), 42 | # 82: ("2d", "3d_fullres"), 43 | # 83: ("2d", "3d_fullres"), 44 | } 45 | 46 | configurations_3d_fr_only = { 47 | i: ("3d_fullres", ) for i in configurations_all if "3d_fullres" in configurations_all[i] 48 | } 49 | 50 | configurations_3d_c_only = { 51 | i: ("3d_cascade_fullres", ) for i in configurations_all if "3d_cascade_fullres" in configurations_all[i] 52 | } 53 | 54 | configurations_3d_lr_only = { 55 | i: ("3d_lowres", ) for i in configurations_all if "3d_lowres" in configurations_all[i] 56 | } 57 | 58 | configurations_2d_only = { 59 | i: ("2d", ) for i in configurations_all if "2d" in configurations_all[i] 60 | } 61 | 62 | num_gpus = 1 63 | exclude_hosts = "-R \"select[hname!='e230-dgx2-2']\" -R \"select[hname!='e230-dgx2-1']\"" 64 | resources = "-R \"tensorcore\"" 65 | gpu_requirements = f"-gpu num={num_gpus}:j_exclusive=yes:gmem=1G" 66 | queue = "-q gpu-lowprio" 67 | preamble = "-L /bin/bash \"source ~/load_env_cluster4.sh && " 68 | train_command = 'nnUNet_keep_files_open=True nnUNet_results=/dkfz/cluster/gpu/data/OE0441/isensee/nnUNet_results_remake_release_normfix nnUNetv2_train' 69 | 70 | folds = (0, 1, 2, 3, 4) 71 | # use_this = configurations_2d_only 72 | # use_this = merge(configurations_3d_fr_only, configurations_3d_lr_only) 73 | # use_this = merge(use_this, configurations_3d_c_only) 74 | use_this = configurations_all 75 | 76 | use_these_modules = { 77 | 'nnUNetTrainer': ('nnUNetPlans',), 78 | } 79 | 80 | additional_arguments = f'--disable_checkpointing -num_gpus {num_gpus}' # '' 81 | 82 | output_file = 
"/home/isensee/deleteme.txt" 83 | with open(output_file, 'w') as f: 84 | for tr in use_these_modules.keys(): 85 | for p in use_these_modules[tr]: 86 | for dataset in use_this.keys(): 87 | for config in use_this[dataset]: 88 | for fl in folds: 89 | command = f'bsub {exclude_hosts} {resources} {queue} {gpu_requirements} {preamble} {train_command} {dataset} {config} {fl} -tr {tr} -p {p}' 90 | if additional_arguments is not None and len(additional_arguments) > 0: 91 | command += f' {additional_arguments}' 92 | f.write(f'{command}\"\n') 93 | 94 | -------------------------------------------------------------------------------- /umamba/nnunetv2/configuration.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from nnunetv2.utilities.default_n_proc_DA import get_allowed_n_proc_DA 4 | 5 | default_num_processes = 8 if 'nnUNet_def_n_proc' not in os.environ else int(os.environ['nnUNet_def_n_proc']) 6 | 7 | ANISO_THRESHOLD = 3 # determines when a sample is considered anisotropic (3 means that the spacing in the low 8 | # resolution axis must be 3x as large as the next largest spacing) 9 | 10 | default_n_proc_DA = get_allowed_n_proc_DA() 11 | -------------------------------------------------------------------------------- /umamba/nnunetv2/dataset_conversion/Dataset027_ACDC.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from pathlib import Path 4 | 5 | from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json 6 | from nnunetv2.paths import nnUNet_raw 7 | 8 | 9 | def make_out_dirs(dataset_id: int, task_name="ACDC"): 10 | dataset_name = f"Dataset{dataset_id:03d}_{task_name}" 11 | 12 | out_dir = Path(nnUNet_raw.replace('"', "")) / dataset_name 13 | out_train_dir = out_dir / "imagesTr" 14 | out_labels_dir = out_dir / "labelsTr" 15 | out_test_dir = out_dir / "imagesTs" 16 | 17 | os.makedirs(out_dir, exist_ok=True) 18 | os.makedirs(out_train_dir, exist_ok=True) 19 | os.makedirs(out_labels_dir, exist_ok=True) 20 | os.makedirs(out_test_dir, exist_ok=True) 21 | 22 | return out_dir, out_train_dir, out_labels_dir, out_test_dir 23 | 24 | 25 | def copy_files(src_data_folder: Path, train_dir: Path, labels_dir: Path, test_dir: Path): 26 | """Copy files from the ACDC dataset to the nnUNet dataset folder. Returns the number of training cases.""" 27 | patients_train = sorted([f for f in (src_data_folder / "training").iterdir() if f.is_dir()]) 28 | patients_test = sorted([f for f in (src_data_folder / "testing").iterdir() if f.is_dir()]) 29 | 30 | num_training_cases = 0 31 | # Copy training files and corresponding labels. 32 | for patient_dir in patients_train: 33 | for file in patient_dir.iterdir(): 34 | if file.suffix == ".gz" and "_gt" not in file.name and "_4d" not in file.name: 35 | # The stem is 'patient.nii', and the suffix is '.gz'. 36 | # We split the stem and append _0000 to the patient part. 37 | shutil.copy(file, train_dir / f"{file.stem.split('.')[0]}_0000.nii.gz") 38 | num_training_cases += 1 39 | elif file.suffix == ".gz" and "_gt" in file.name: 40 | shutil.copy(file, labels_dir / file.name.replace("_gt", "")) 41 | 42 | # Copy test files. 
43 | for patient_dir in patients_test: 44 | for file in patient_dir.iterdir(): 45 | if file.suffix == ".gz" and "_gt" not in file.name and "_4d" not in file.name: 46 | shutil.copy(file, test_dir / f"{file.stem.split('.')[0]}_0000.nii.gz") 47 | 48 | return num_training_cases 49 | 50 | 51 | def convert_acdc(src_data_folder: str, dataset_id=27): 52 | out_dir, train_dir, labels_dir, test_dir = make_out_dirs(dataset_id=dataset_id) 53 | num_training_cases = copy_files(Path(src_data_folder), train_dir, labels_dir, test_dir) 54 | 55 | generate_dataset_json( 56 | str(out_dir), 57 | channel_names={ 58 | 0: "cineMRI", 59 | }, 60 | labels={ 61 | "background": 0, 62 | "RV": 1, 63 | "MLV": 2, 64 | "LVC": 3, 65 | }, 66 | file_ending=".nii.gz", 67 | num_training_cases=num_training_cases, 68 | ) 69 | 70 | 71 | if __name__ == "__main__": 72 | import argparse 73 | 74 | parser = argparse.ArgumentParser() 75 | parser.add_argument( 76 | "-i", 77 | "--input_folder", 78 | type=str, 79 | help="The downloaded ACDC dataset dir. Should contain extracted 'training' and 'testing' folders.", 80 | ) 81 | parser.add_argument( 82 | "-d", "--dataset_id", required=False, type=int, default=27, help="nnU-Net Dataset ID, default: 27" 83 | ) 84 | args = parser.parse_args() 85 | print("Converting...") 86 | convert_acdc(args.input_folder, args.dataset_id) 87 | print("Done!") 88 | -------------------------------------------------------------------------------- /umamba/nnunetv2/dataset_conversion/Dataset115_EMIDEC.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | from pathlib import Path 3 | 4 | from nnunetv2.dataset_conversion.Dataset027_ACDC import make_out_dirs 5 | from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json 6 | 7 | 8 | def copy_files(src_data_dir: Path, src_test_dir: Path, train_dir: Path, labels_dir: Path, test_dir: Path): 9 | """Copy files from the EMIDEC dataset to the nnUNet dataset folder. Returns the number of training cases.""" 10 | patients_train = sorted([f for f in src_data_dir.iterdir() if f.is_dir()]) 11 | patients_test = sorted([f for f in src_test_dir.iterdir() if f.is_dir()]) 12 | 13 | # Copy training files and corresponding labels. 14 | for patient in patients_train: 15 | train_file = patient / "Images" / f"{patient.name}.nii.gz" 16 | label_file = patient / "Contours" / f"{patient.name}.nii.gz" 17 | shutil.copy(train_file, train_dir / f"{train_file.stem.split('.')[0]}_0000.nii.gz") 18 | shutil.copy(label_file, labels_dir) 19 | 20 | # Copy test files. 
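# note: as in the ACDC converter, test cases have no contour files, so only the image volume is copied with the _0000 channel suffix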
21 | for patient in patients_test: 22 | test_file = patient / "Images" / f"{patient.name}.nii.gz" 23 | shutil.copy(test_file, test_dir / f"{test_file.stem.split('.')[0]}_0000.nii.gz") 24 | 25 | return len(patients_train) 26 | 27 | 28 | def convert_emidec(src_data_dir: str, src_test_dir: str, dataset_id=27): 29 | out_dir, train_dir, labels_dir, test_dir = make_out_dirs(dataset_id=dataset_id, task_name="EMIDEC") 30 | num_training_cases = copy_files(Path(src_data_dir), Path(src_test_dir), train_dir, labels_dir, test_dir) 31 | 32 | generate_dataset_json( 33 | str(out_dir), 34 | channel_names={ 35 | 0: "cineMRI", 36 | }, 37 | labels={ 38 | "background": 0, 39 | "cavity": 1, 40 | "normal_myocardium": 2, 41 | "myocardial_infarction": 3, 42 | "no_reflow": 4, 43 | }, 44 | file_ending=".nii.gz", 45 | num_training_cases=num_training_cases, 46 | ) 47 | 48 | 49 | if __name__ == "__main__": 50 | import argparse 51 | 52 | parser = argparse.ArgumentParser() 53 | parser.add_argument("-i", "--input_dir", type=str, help="The EMIDEC dataset directory.") 54 | parser.add_argument("-t", "--test_dir", type=str, help="The EMIDEC test set directory.") 55 | parser.add_argument( 56 | "-d", "--dataset_id", required=False, type=int, default=115, help="nnU-Net Dataset ID, default: 115" 57 | ) 58 | args = parser.parse_args() 59 | print("Converting...") 60 | convert_emidec(args.input_dir, args.test_dir, args.dataset_id) 61 | print("Done!") 62 | -------------------------------------------------------------------------------- /umamba/nnunetv2/dataset_conversion/Dataset120_RoadSegmentation.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | import shutil 3 | from multiprocessing import Pool 4 | 5 | from batchgenerators.utilities.file_and_folder_operations import * 6 | 7 | from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json 8 | from nnunetv2.paths import nnUNet_raw 9 | from skimage import io 10 | from acvl_utils.morphology.morphology_helper import generic_filter_components 11 | from scipy.ndimage import binary_fill_holes 12 | 13 | 14 | def load_and_covnert_case(input_image: str, input_seg: str, output_image: str, output_seg: str, 15 | min_component_size: int = 50): 16 | seg = io.imread(input_seg) 17 | seg[seg == 255] = 1 18 | image = io.imread(input_image) 19 | image = image.sum(2) 20 | mask = image == (3 * 255) 21 | # the dataset has large white areas in which road segmentations can exist but no image information is available. 
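# (a pixel counts as 'no data' when all three RGB channels are 255, i.e. the summed image equals 3 * 255)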
22 | # Remove the road label in these areas 23 | mask = generic_filter_components(mask, filter_fn=lambda ids, sizes: [i for j, i in enumerate(ids) if 24 | sizes[j] > min_component_size]) 25 | mask = binary_fill_holes(mask) 26 | seg[mask] = 0 27 | io.imsave(output_seg, seg, check_contrast=False) 28 | shutil.copy(input_image, output_image) 29 | 30 | 31 | if __name__ == "__main__": 32 | # extracted archive from https://www.kaggle.com/datasets/insaff/massachusetts-roads-dataset?resource=download 33 | source = '/media/fabian/data/raw_datasets/Massachussetts_road_seg/road_segmentation_ideal' 34 | 35 | dataset_name = 'Dataset120_RoadSegmentation' 36 | 37 | imagestr = join(nnUNet_raw, dataset_name, 'imagesTr') 38 | imagests = join(nnUNet_raw, dataset_name, 'imagesTs') 39 | labelstr = join(nnUNet_raw, dataset_name, 'labelsTr') 40 | labelsts = join(nnUNet_raw, dataset_name, 'labelsTs') 41 | maybe_mkdir_p(imagestr) 42 | maybe_mkdir_p(imagests) 43 | maybe_mkdir_p(labelstr) 44 | maybe_mkdir_p(labelsts) 45 | 46 | train_source = join(source, 'training') 47 | test_source = join(source, 'testing') 48 | 49 | with multiprocessing.get_context("spawn").Pool(8) as p: 50 | 51 | # not all training images have a segmentation 52 | valid_ids = subfiles(join(train_source, 'output'), join=False, suffix='png') 53 | num_train = len(valid_ids) 54 | r = [] 55 | for v in valid_ids: 56 | r.append( 57 | p.starmap_async( 58 | load_and_covnert_case, 59 | (( 60 | join(train_source, 'input', v), 61 | join(train_source, 'output', v), 62 | join(imagestr, v[:-4] + '_0000.png'), 63 | join(labelstr, v), 64 | 50 65 | ),) 66 | ) 67 | ) 68 | 69 | # test set 70 | valid_ids = subfiles(join(test_source, 'output'), join=False, suffix='png') 71 | for v in valid_ids: 72 | r.append( 73 | p.starmap_async( 74 | load_and_covnert_case, 75 | (( 76 | join(test_source, 'input', v), 77 | join(test_source, 'output', v), 78 | join(imagests, v[:-4] + '_0000.png'), 79 | join(labelsts, v), 80 | 50 81 | ),) 82 | ) 83 | ) 84 | _ = [i.get() for i in r] 85 | 86 | generate_dataset_json(join(nnUNet_raw, dataset_name), {0: 'R', 1: 'G', 2: 'B'}, {'background': 0, 'road': 1}, 87 | num_train, '.png', dataset_name=dataset_name) 88 | -------------------------------------------------------------------------------- /umamba/nnunetv2/dataset_conversion/Dataset218_Amos2022_task1.py: -------------------------------------------------------------------------------- 1 | from batchgenerators.utilities.file_and_folder_operations import * 2 | import shutil 3 | from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json 4 | from nnunetv2.paths import nnUNet_raw 5 | 6 | 7 | def convert_amos_task1(amos_base_dir: str, nnunet_dataset_id: int = 218): 8 | """ 9 | AMOS doesn't say anything about how the validation set is supposed to be used. So we just incorporate that into 10 | the train set. 
--------------------------------------------------------------------------------
/umamba/nnunetv2/dataset_conversion/Dataset218_Amos2022_task1.py:
--------------------------------------------------------------------------------
1 | from batchgenerators.utilities.file_and_folder_operations import *
2 | import shutil
3 | from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json
4 | from nnunetv2.paths import nnUNet_raw
5 | 
6 | 
7 | def convert_amos_task1(amos_base_dir: str, nnunet_dataset_id: int = 218):
8 |     """
9 |     AMOS doesn't say anything about how the validation set is supposed to be used. So we just incorporate that into
10 |     the train set. A 5-fold cross-validation is superior to a single train:val split.
11 |     """
12 |     task_name = "AMOS2022_postChallenge_task1"
13 | 
14 |     foldername = "Dataset%03.0d_%s" % (nnunet_dataset_id, task_name)
15 | 
16 |     # setting up nnU-Net folders
17 |     out_base = join(nnUNet_raw, foldername)
18 |     imagestr = join(out_base, "imagesTr")
19 |     imagests = join(out_base, "imagesTs")
20 |     labelstr = join(out_base, "labelsTr")
21 |     maybe_mkdir_p(imagestr)
22 |     maybe_mkdir_p(imagests)
23 |     maybe_mkdir_p(labelstr)
24 | 
25 |     dataset_json_source = load_json(join(amos_base_dir, 'dataset.json'))
26 | 
27 |     training_identifiers = [i['image'].split('/')[-1][:-7] for i in dataset_json_source['training']]
28 |     tr_ctr = 0
29 |     for tr in training_identifiers:
30 |         if int(tr.split("_")[-1]) <= 410:  # these are the CT images
31 |             tr_ctr += 1
32 |             shutil.copy(join(amos_base_dir, 'imagesTr', tr + '.nii.gz'), join(imagestr, f'{tr}_0000.nii.gz'))
33 |             shutil.copy(join(amos_base_dir, 'labelsTr', tr + '.nii.gz'), join(labelstr, f'{tr}.nii.gz'))
34 | 
35 |     test_identifiers = [i['image'].split('/')[-1][:-7] for i in dataset_json_source['test']]
36 |     for ts in test_identifiers:
37 |         if int(ts.split("_")[-1]) <= 500:  # these are the CT images
38 |             shutil.copy(join(amos_base_dir, 'imagesTs', ts + '.nii.gz'), join(imagests, f'{ts}_0000.nii.gz'))
39 | 
40 |     val_identifiers = [i['image'].split('/')[-1][:-7] for i in dataset_json_source['validation']]
41 |     for vl in val_identifiers:
42 |         if int(vl.split("_")[-1]) <= 409:  # these are the CT images
43 |             tr_ctr += 1
44 |             shutil.copy(join(amos_base_dir, 'imagesVa', vl + '.nii.gz'), join(imagestr, f'{vl}_0000.nii.gz'))
45 |             shutil.copy(join(amos_base_dir, 'labelsVa', vl + '.nii.gz'), join(labelstr, f'{vl}.nii.gz'))
46 | 
47 |     generate_dataset_json(out_base, {0: "CT"}, labels={v: int(k) for k, v in dataset_json_source['labels'].items()},
48 |                           num_training_cases=tr_ctr, file_ending='.nii.gz',
49 |                           dataset_name=task_name, reference='https://amos22.grand-challenge.org/',
50 |                           release='https://zenodo.org/record/7262581',
51 |                           overwrite_image_reader_writer='NibabelIOWithReorient',
52 |                           description="This is the dataset as released AFTER the challenge event. It has the "
53 |                                       "validation set gt in it! We just use the validation images as additional "
54 |                                       "training cases because AMOS doesn't specify how they should be used. nnU-Net's"
55 |                                       " 5-fold CV is better than some random train:val split.")
56 | 
57 | 
58 | if __name__ == '__main__':
59 |     import argparse
60 |     parser = argparse.ArgumentParser()
61 |     parser.add_argument('input_folder', type=str,
62 |                         help="The downloaded and extracted AMOS2022 (https://amos22.grand-challenge.org/) data. "
63 |                              "Use this link: https://zenodo.org/record/7262581. "
64 |                              "You need to specify the folder with the imagesTr, imagesVal, labelsTr etc subfolders here!")
65 |     parser.add_argument('-d', required=False, type=int, default=218, help='nnU-Net Dataset ID, default: 218')
66 |     args = parser.parse_args()
67 |     amos_base = args.input_folder
68 |     convert_amos_task1(amos_base, args.d)
69 | 
70 | 
71 | 
--------------------------------------------------------------------------------
/umamba/nnunetv2/dataset_conversion/Dataset219_Amos2022_task2.py:
--------------------------------------------------------------------------------
1 | from batchgenerators.utilities.file_and_folder_operations import *
2 | import shutil
3 | from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json
4 | from nnunetv2.paths import nnUNet_raw
5 | 
6 | 
7 | def convert_amos_task2(amos_base_dir: str, nnunet_dataset_id: int = 219):
8 |     """
9 |     AMOS doesn't say anything about how the validation set is supposed to be used. So we just incorporate that into
10 |     the train set. A 5-fold cross-validation is superior to a single train:val split.
11 |     """
12 |     task_name = "AMOS2022_postChallenge_task2"
13 | 
14 |     foldername = "Dataset%03.0d_%s" % (nnunet_dataset_id, task_name)
15 | 
16 |     # setting up nnU-Net folders
17 |     out_base = join(nnUNet_raw, foldername)
18 |     imagestr = join(out_base, "imagesTr")
19 |     imagests = join(out_base, "imagesTs")
20 |     labelstr = join(out_base, "labelsTr")
21 |     maybe_mkdir_p(imagestr)
22 |     maybe_mkdir_p(imagests)
23 |     maybe_mkdir_p(labelstr)
24 | 
25 |     dataset_json_source = load_json(join(amos_base_dir, 'dataset.json'))
26 | 
27 |     training_identifiers = [i['image'].split('/')[-1][:-7] for i in dataset_json_source['training']]
28 |     for tr in training_identifiers:
29 |         shutil.copy(join(amos_base_dir, 'imagesTr', tr + '.nii.gz'), join(imagestr, f'{tr}_0000.nii.gz'))
30 |         shutil.copy(join(amos_base_dir, 'labelsTr', tr + '.nii.gz'), join(labelstr, f'{tr}.nii.gz'))
31 | 
32 |     test_identifiers = [i['image'].split('/')[-1][:-7] for i in dataset_json_source['test']]
33 |     for ts in test_identifiers:
34 |         shutil.copy(join(amos_base_dir, 'imagesTs', ts + '.nii.gz'), join(imagests, f'{ts}_0000.nii.gz'))
35 | 
36 |     val_identifiers = [i['image'].split('/')[-1][:-7] for i in dataset_json_source['validation']]
37 |     for vl in val_identifiers:
38 |         shutil.copy(join(amos_base_dir, 'imagesVa', vl + '.nii.gz'), join(imagestr, f'{vl}_0000.nii.gz'))
39 |         shutil.copy(join(amos_base_dir, 'labelsVa', vl + '.nii.gz'), join(labelstr, f'{vl}.nii.gz'))
40 | 
41 |     generate_dataset_json(out_base, {0: "either_CT_or_MR"}, labels={v: int(k) for k, v in dataset_json_source['labels'].items()},
42 |                           num_training_cases=len(training_identifiers) + len(val_identifiers), file_ending='.nii.gz',
43 |                           dataset_name=task_name, reference='https://amos22.grand-challenge.org/',
44 |                           release='https://zenodo.org/record/7262581',
45 |                           overwrite_image_reader_writer='NibabelIOWithReorient',
46 |                           description="This is the dataset as released AFTER the challenge event. It has the "
47 |                                       "validation set gt in it! We just use the validation images as additional "
48 |                                       "training cases because AMOS doesn't specify how they should be used. nnU-Net's"
49 |                                       " 5-fold CV is better than some random train:val split.")
50 | 
51 | 
52 | if __name__ == '__main__':
53 |     import argparse
54 |     parser = argparse.ArgumentParser()
55 |     parser.add_argument('input_folder', type=str,
56 |                         help="The downloaded and extracted AMOS2022 (https://amos22.grand-challenge.org/) data. "
57 |                              "Use this link: https://zenodo.org/record/7262581. "
58 |                              "You need to specify the folder with the imagesTr, imagesVal, labelsTr etc subfolders here!")
59 |     parser.add_argument('-d', required=False, type=int, default=219, help='nnU-Net Dataset ID, default: 219')
60 |     args = parser.parse_args()
61 |     amos_base = args.input_folder
62 |     convert_amos_task2(amos_base, args.d)
63 | 
64 | # /home/isensee/Downloads/amos22/amos22/
65 | 
66 | 
--------------------------------------------------------------------------------
/umamba/nnunetv2/dataset_conversion/Dataset220_KiTS2023.py:
--------------------------------------------------------------------------------
1 | from batchgenerators.utilities.file_and_folder_operations import *
2 | import shutil
3 | from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json
4 | from nnunetv2.paths import nnUNet_raw
5 | 
6 | 
7 | def convert_kits2023(kits_base_dir: str, nnunet_dataset_id: int = 220):
8 |     task_name = "KiTS2023"
9 | 
10 |     foldername = "Dataset%03.0d_%s" % (nnunet_dataset_id, task_name)
11 | 
12 |     # setting up nnU-Net folders
13 |     out_base = join(nnUNet_raw, foldername)
14 |     imagestr = join(out_base, "imagesTr")
15 |     labelstr = join(out_base, "labelsTr")
16 |     maybe_mkdir_p(imagestr)
17 |     maybe_mkdir_p(labelstr)
18 | 
19 |     cases = subdirs(kits_base_dir, prefix='case_', join=False)
20 |     for tr in cases:
21 |         shutil.copy(join(kits_base_dir, tr, 'imaging.nii.gz'), join(imagestr, f'{tr}_0000.nii.gz'))
22 |         shutil.copy(join(kits_base_dir, tr, 'segmentation.nii.gz'), join(labelstr, f'{tr}.nii.gz'))
23 | 
24 |     generate_dataset_json(out_base, {0: "CT"},
25 |                           labels={
26 |                               "background": 0,
27 |                               "kidney": (1, 2, 3),
28 |                               "masses": (2, 3),
29 |                               "tumor": 2
30 |                           },
31 |                           regions_class_order=(1, 3, 2),
32 |                           num_training_cases=len(cases), file_ending='.nii.gz',
33 |                           dataset_name=task_name, reference='none',
34 |                           release='prerelease',
35 |                           overwrite_image_reader_writer='NibabelIOWithReorient',
36 |                           description="KiTS2023")
37 | 
38 | 
39 | if __name__ == '__main__':
40 |     import argparse
41 |     parser = argparse.ArgumentParser()
42 |     parser.add_argument('input_folder', type=str,
43 |                         help="The downloaded and extracted KiTS2023 dataset (must have case_XXXXX subfolders)")
44 |     parser.add_argument('-d', required=False, type=int, default=220, help='nnU-Net Dataset ID, default: 220')
45 |     args = parser.parse_args()
46 |     kits_base = args.input_folder
47 |     convert_kits2023(kits_base, args.d)
48 | 
49 | # /media/isensee/raw_data/raw_datasets/kits23/dataset
50 | 
51 | 
--------------------------------------------------------------------------------
/umamba/nnunetv2/dataset_conversion/Dataset221_AutoPETII_2023.py:
--------------------------------------------------------------------------------
1 | from batchgenerators.utilities.file_and_folder_operations import *
2 | import shutil
3 | from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json
4 | from nnunetv2.paths import nnUNet_raw, nnUNet_preprocessed
5 | 
6 | 
7 | def convert_autopet(autopet_base_dir: str = '/media/isensee/My Book1/AutoPET/nifti/FDG-PET-CT-Lesions',
8 |                     nnunet_dataset_id: int = 221):
9 |     task_name = "AutoPETII_2023"
10 | 
11 |     foldername = "Dataset%03.0d_%s" % (nnunet_dataset_id, task_name)
12 | 
13 |     # setting up nnU-Net folders
14 |     out_base = join(nnUNet_raw, foldername)
15 |     imagestr = join(out_base, "imagesTr")
16 |     labelstr = join(out_base, "labelsTr")
17 |     maybe_mkdir_p(imagestr)
18 |     maybe_mkdir_p(labelstr)
19 | 
20 |     patients = subdirs(autopet_base_dir, prefix='PETCT', join=False)
21 |     n = 0
22 |     identifiers = []
23 |     for pat in patients:
24 |         patient_acquisitions = subdirs(join(autopet_base_dir, pat), join=False)
25 |         for pa in patient_acquisitions:
26 |             n += 1
27 |             identifier = f"{pat}_{pa}"
28 |             identifiers.append(identifier)
29 |             if not isfile(join(imagestr, f'{identifier}_0000.nii.gz')):
30 |                 shutil.copy(join(autopet_base_dir, pat, pa, 'CTres.nii.gz'), join(imagestr, f'{identifier}_0000.nii.gz'))
31 |             if not isfile(join(imagestr, f'{identifier}_0001.nii.gz')):
32 |                 shutil.copy(join(autopet_base_dir, pat, pa, 'SUV.nii.gz'), join(imagestr, f'{identifier}_0001.nii.gz'))
33 |             if not isfile(join(labelstr, f'{identifier}.nii.gz')):  # check labelstr, where the file is actually written
34 |                 shutil.copy(join(autopet_base_dir, pat, pa, 'SEG.nii.gz'), join(labelstr, f'{identifier}.nii.gz'))
35 | 
36 |     generate_dataset_json(out_base, {0: "CT", 1: "CT"},
37 |                           labels={
38 |                               "background": 0,
39 |                               "tumor": 1
40 |                           },
41 |                           num_training_cases=n, file_ending='.nii.gz',
42 |                           dataset_name=task_name, reference='https://autopet-ii.grand-challenge.org/',
43 |                           release='release',
44 |                           # overwrite_image_reader_writer='NibabelIOWithReorient',
45 |                           description=task_name)
46 | 
47 |     # manual split
48 |     splits = []
49 |     for fold in range(5):
50 |         val_patients = patients[fold :: 5]
51 |         splits.append(
52 |             {
53 |                 'train': [i for i in identifiers if not any([i.startswith(v) for v in val_patients])],
54 |                 'val': [i for i in identifiers if any([i.startswith(v) for v in val_patients])],
55 |             }
56 |         )
57 |     pp_out_dir = join(nnUNet_preprocessed, foldername)
58 |     maybe_mkdir_p(pp_out_dir)
59 |     save_json(splits, join(pp_out_dir, 'splits_final.json'), sort_keys=False)
60 | 
61 | 
62 | if __name__ == '__main__':
63 |     import argparse
64 |     parser = argparse.ArgumentParser()
65 |     parser.add_argument('input_folder', type=str,
66 |                         help="The downloaded and extracted autopet dataset (must have PETCT_XXX subfolders)")
67 |     parser.add_argument('-d', required=False, type=int, default=221, help='nnU-Net Dataset ID, default: 221')
68 |     args = parser.parse_args()
69 |     autopet_base = args.input_folder
70 |     convert_autopet(autopet_base, args.d)
71 | 
--------------------------------------------------------------------------------
/umamba/nnunetv2/dataset_conversion/Dataset988_dummyDataset4.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | from batchgenerators.utilities.file_and_folder_operations import *
4 | 
5 | from nnunetv2.paths import nnUNet_raw
6 | from nnunetv2.utilities.utils import get_filenames_of_train_images_and_targets
7 | 
8 | if __name__ == '__main__':
9 |     # creates a dummy dataset where there are no files in imagestr and labelstr
10 |     source_dataset = 'Dataset004_Hippocampus'
11 | 
12 |     target_dataset = 'Dataset988_dummyDataset4'  # id matches this file's name
13 |     target_dataset_dir = join(nnUNet_raw, target_dataset)
14 |     maybe_mkdir_p(target_dataset_dir)
15 | 
16 |     dataset = get_filenames_of_train_images_and_targets(join(nnUNet_raw, source_dataset))
17 | 
18 |     # the returned dataset will have absolute paths. We should use relative paths so that you can freely copy
19 |     # datasets around between systems. As long as the source dataset is there, it will continue working even if
20 |     # nnUNet_raw is in a different location
21 | 
22 |     # paths must be relative to target_dataset_dir!!!
23 |     for k in dataset.keys():
24 |         dataset[k]['label'] = os.path.relpath(dataset[k]['label'], target_dataset_dir)
25 |         dataset[k]['images'] = [os.path.relpath(i, target_dataset_dir) for i in dataset[k]['images']]
26 | 
27 |     # load old dataset.json
28 |     dataset_json = load_json(join(nnUNet_raw, source_dataset, 'dataset.json'))
29 |     dataset_json['dataset'] = dataset
30 | 
31 |     # save
32 |     save_json(dataset_json, join(target_dataset_dir, 'dataset.json'), sort_keys=False)
33 | 
--------------------------------------------------------------------------------
/umamba/nnunetv2/dataset_conversion/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bowang-lab/U-Mamba/28459e33ca03769800dd35e23c6e62491d1925b5/umamba/nnunetv2/dataset_conversion/__init__.py
--------------------------------------------------------------------------------
/umamba/nnunetv2/dataset_conversion/convert_raw_dataset_from_old_nnunet_format.py:
--------------------------------------------------------------------------------
1 | import shutil
2 | from copy import deepcopy
3 | 
4 | from batchgenerators.utilities.file_and_folder_operations import join, maybe_mkdir_p, isdir, load_json, save_json
5 | from nnunetv2.paths import nnUNet_raw
6 | 
7 | 
8 | def convert(source_folder, target_dataset_name):
9 |     """
10 |     remember that old tasks were called TaskXXX_YYY and new ones are called DatasetXXX_YYY
11 |     :param source_folder: path to the old TaskXXX_YYY folder (with imagesTr, labelsTr etc inside)
12 |     """
13 |     if isdir(join(nnUNet_raw, target_dataset_name)):
14 |         raise RuntimeError(f'Target dataset name {target_dataset_name} already exists. Aborting... '
15 |                            f'(we might break something). If you are sure you want to proceed, please manually '
16 |                            f'delete {join(nnUNet_raw, target_dataset_name)}')
17 |     maybe_mkdir_p(join(nnUNet_raw, target_dataset_name))
18 |     shutil.copytree(join(source_folder, 'imagesTr'), join(nnUNet_raw, target_dataset_name, 'imagesTr'))
19 |     shutil.copytree(join(source_folder, 'labelsTr'), join(nnUNet_raw, target_dataset_name, 'labelsTr'))
20 |     if isdir(join(source_folder, 'imagesTs')):
21 |         shutil.copytree(join(source_folder, 'imagesTs'), join(nnUNet_raw, target_dataset_name, 'imagesTs'))
22 |     if isdir(join(source_folder, 'labelsTs')):
23 |         shutil.copytree(join(source_folder, 'labelsTs'), join(nnUNet_raw, target_dataset_name, 'labelsTs'))
24 |     if isdir(join(source_folder, 'imagesVal')):
25 |         shutil.copytree(join(source_folder, 'imagesVal'), join(nnUNet_raw, target_dataset_name, 'imagesVal'))
26 |     if isdir(join(source_folder, 'labelsVal')):
27 |         shutil.copytree(join(source_folder, 'labelsVal'), join(nnUNet_raw, target_dataset_name, 'labelsVal'))
28 |     shutil.copy(join(source_folder, 'dataset.json'), join(nnUNet_raw, target_dataset_name))
29 | 
30 |     dataset_json = load_json(join(nnUNet_raw, target_dataset_name, 'dataset.json'))
31 |     del dataset_json['tensorImageSize']
32 |     del dataset_json['numTest']
33 |     del dataset_json['training']
34 |     del dataset_json['test']
35 |     dataset_json['channel_names'] = deepcopy(dataset_json['modality'])
36 |     del dataset_json['modality']
37 | 
38 |     dataset_json['labels'] = {j: int(i) for i, j in dataset_json['labels'].items()}
39 |     dataset_json['file_ending'] = ".nii.gz"
40 |     save_json(dataset_json, join(nnUNet_raw, target_dataset_name, 'dataset.json'), sort_keys=False)
41 | 
42 | 
43 | def convert_entry_point():
44 |     import argparse
45 |     parser = argparse.ArgumentParser()
46 |     parser.add_argument("input_folder", type=str,
47 |                         help='Raw old nnUNet dataset. This must be the folder with imagesTr, labelsTr etc subfolders! '
48 |                              'Please provide the PATH to the old Task, not just the task name. nnU-Net V2 does not '
49 |                              'know where v1 tasks are.')
50 |     parser.add_argument("output_dataset_name", type=str,
51 |                         help='New dataset NAME (not path!). Must follow the DatasetXXX_NAME convention!')
52 |     args = parser.parse_args()
53 |     convert(args.input_folder, args.output_dataset_name)
54 | 
--------------------------------------------------------------------------------
/umamba/nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset996_IntegrationTest_Hippocampus_regions_ignore.py:
--------------------------------------------------------------------------------
1 | import SimpleITK as sitk
2 | import shutil
3 | 
4 | import numpy as np
5 | from batchgenerators.utilities.file_and_folder_operations import isdir, join, load_json, save_json, nifti_files
6 | 
7 | from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name
8 | from nnunetv2.paths import nnUNet_raw
9 | from nnunetv2.utilities.label_handling.label_handling import LabelManager
10 | from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager
11 | 
12 | 
13 | def sparsify_segmentation(seg: np.ndarray, label_manager: LabelManager, percent_of_slices: float) -> np.ndarray:
14 |     assert label_manager.has_ignore_label, "This preprocessor only works with datasets that have an ignore label!"
15 |     seg_new = np.ones_like(seg) * label_manager.ignore_label
16 |     x, y, z = seg.shape
17 |     # x
18 |     num_slices = max(1, round(x * percent_of_slices))
19 |     selected_slices = np.random.choice(x, num_slices, replace=False)
20 |     seg_new[selected_slices] = seg[selected_slices]
21 |     # y
22 |     num_slices = max(1, round(y * percent_of_slices))
23 |     selected_slices = np.random.choice(y, num_slices, replace=False)
24 |     seg_new[:, selected_slices] = seg[:, selected_slices]
25 |     # z
26 |     num_slices = max(1, round(z * percent_of_slices))
27 |     selected_slices = np.random.choice(z, num_slices, replace=False)
28 |     seg_new[:, :, selected_slices] = seg[:, :, selected_slices]
29 |     return seg_new
30 | 
31 | 
32 | if __name__ == '__main__':
33 |     dataset_name = 'IntegrationTest_Hippocampus_regions_ignore'
34 |     dataset_id = 996
35 |     dataset_name = f"Dataset{dataset_id:03d}_{dataset_name}"
36 | 
37 |     try:
38 |         existing_dataset_name = maybe_convert_to_dataset_name(dataset_id)
39 |         if existing_dataset_name != dataset_name:
40 |             raise FileExistsError(f"A different dataset with id {dataset_id} already exists :-(: {existing_dataset_name}. If "
41 |                                   f"you intend to delete it, remember to also remove it in nnUNet_preprocessed and "
42 |                                   f"nnUNet_results!")
43 |     except RuntimeError:
44 |         pass
45 | 
46 |     if isdir(join(nnUNet_raw, dataset_name)):
47 |         shutil.rmtree(join(nnUNet_raw, dataset_name))
48 | 
49 |     source_dataset = maybe_convert_to_dataset_name(4)
50 |     shutil.copytree(join(nnUNet_raw, source_dataset), join(nnUNet_raw, dataset_name))
51 | 
52 |     # additionally optimize entire hippocampus region, remove Posterior
53 |     dj = load_json(join(nnUNet_raw, dataset_name, 'dataset.json'))
54 |     dj['labels'] = {
55 |         'background': 0,
56 |         'hippocampus': (1, 2),
57 |         'anterior': 1,
58 |         'ignore': 3
59 |     }
60 |     dj['regions_class_order'] = (2, 1)
61 |     save_json(dj, join(nnUNet_raw, dataset_name, 'dataset.json'), sort_keys=False)
62 | 
63 |     # now add ignore label to segmentation images
64 |     np.random.seed(1234)
65 |     lm = LabelManager(label_dict=dj['labels'], regions_class_order=dj.get('regions_class_order'))
66 | 
67 |     segs = nifti_files(join(nnUNet_raw, dataset_name, 'labelsTr'))
68 |     for s in segs:
69 |         seg_itk = sitk.ReadImage(s)
70 |         seg_npy = sitk.GetArrayFromImage(seg_itk)
71 |         seg_npy = sparsify_segmentation(seg_npy, lm, 0.1 / 3)
72 |         seg_itk_new = sitk.GetImageFromArray(seg_npy)
73 |         seg_itk_new.CopyInformation(seg_itk)
74 |         sitk.WriteImage(seg_itk_new, s)
75 | 
76 | 
--------------------------------------------------------------------------------
/umamba/nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset997_IntegrationTest_Hippocampus_regions.py:
--------------------------------------------------------------------------------
1 | import shutil
2 | 
3 | from batchgenerators.utilities.file_and_folder_operations import isdir, join, load_json, save_json
4 | 
5 | from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name
6 | from nnunetv2.paths import nnUNet_raw
7 | 
8 | if __name__ == '__main__':
9 |     dataset_name = 'IntegrationTest_Hippocampus_regions'
10 |     dataset_id = 997
11 |     dataset_name = f"Dataset{dataset_id:03d}_{dataset_name}"
12 | 
13 |     try:
14 |         existing_dataset_name = maybe_convert_to_dataset_name(dataset_id)
15 |         if existing_dataset_name != dataset_name:
16 |             raise FileExistsError(
17 |                 f"A different dataset with id {dataset_id} already exists :-(: {existing_dataset_name}. If "
18 |                 f"you intend to delete it, remember to also remove it in nnUNet_preprocessed and "
19 |                 f"nnUNet_results!")
20 |     except RuntimeError:
21 |         pass
22 | 
23 |     if isdir(join(nnUNet_raw, dataset_name)):
24 |         shutil.rmtree(join(nnUNet_raw, dataset_name))
25 | 
26 |     source_dataset = maybe_convert_to_dataset_name(4)
27 |     shutil.copytree(join(nnUNet_raw, source_dataset), join(nnUNet_raw, dataset_name))
28 | 
29 |     # additionally optimize entire hippocampus region, remove Posterior
30 |     dj = load_json(join(nnUNet_raw, dataset_name, 'dataset.json'))
31 |     dj['labels'] = {
32 |         'background': 0,
33 |         'hippocampus': (1, 2),
34 |         'anterior': 1
35 |     }
36 |     dj['regions_class_order'] = (2, 1)
37 |     save_json(dj, join(nnUNet_raw, dataset_name, 'dataset.json'), sort_keys=False)
--------------------------------------------------------------------------------
/umamba/nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset998_IntegrationTest_Hippocampus_ignore.py:
--------------------------------------------------------------------------------
1 | import shutil
2 | 
3 | from batchgenerators.utilities.file_and_folder_operations import isdir, join, load_json, save_json
4 | 
5 | from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name
6 | from nnunetv2.paths import nnUNet_raw
7 | 
8 | 
9 | if __name__ == '__main__':
10 |     dataset_name = 'IntegrationTest_Hippocampus_ignore'
11 |     dataset_id = 998
12 |     dataset_name = f"Dataset{dataset_id:03d}_{dataset_name}"
13 | 
14 |     try:
15 |         existing_dataset_name = maybe_convert_to_dataset_name(dataset_id)
16 |         if existing_dataset_name != dataset_name:
17 |             raise FileExistsError(f"A different dataset with id {dataset_id} already exists :-(: {existing_dataset_name}. If "
18 |                                   f"you intend to delete it, remember to also remove it in nnUNet_preprocessed and "
19 |                                   f"nnUNet_results!")
20 |     except RuntimeError:
21 |         pass
22 | 
23 |     if isdir(join(nnUNet_raw, dataset_name)):
24 |         shutil.rmtree(join(nnUNet_raw, dataset_name))
25 | 
26 |     source_dataset = maybe_convert_to_dataset_name(4)
27 |     shutil.copytree(join(nnUNet_raw, source_dataset), join(nnUNet_raw, dataset_name))
28 | 
29 |     # set class 2 to ignore label
30 |     dj = load_json(join(nnUNet_raw, dataset_name, 'dataset.json'))
31 |     dj['labels']['ignore'] = 2
32 |     del dj['labels']['Posterior']
33 |     save_json(dj, join(nnUNet_raw, dataset_name, 'dataset.json'), sort_keys=False)
34 | 
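Two label mechanics worth spelling out here. First, Dataset996/997 above (and the earlier KiTS2023 converter) use region-based labels: each region maps to a tuple of class ids, and regions_class_order controls the order in which region predictions are painted back into a label map, later entries overwriting earlier ones. A minimal sketch of that painting step (the helper name and threshold are illustrative, not nnU-Net's actual export code):

    import numpy as np

    def regions_to_label_map(region_probs: np.ndarray, regions_class_order, threshold: float = 0.5) -> np.ndarray:
        # region_probs: (num_regions, x, y, z) sigmoid outputs, one channel per (possibly overlapping) region
        seg = np.zeros(region_probs.shape[1:], dtype=np.uint8)
        for region_idx, class_value in enumerate(regions_class_order):
            seg[region_probs[region_idx] > threshold] = class_value  # later regions overwrite earlier ones
        return seg

With regions_class_order = (2, 1) as above, the whole-hippocampus region is written as 2 first and the anterior region then overwrites it with 1. Second, Dataset998 declares class 2 as an ignore label: such voxels are excluded from loss and metric computation rather than predicted. The masking idea in rough form (a hypothetical helper, not the trainer's real loss code):

    import torch

    def mask_ignore_voxels(target: torch.Tensor, ignore_label: int):
        # voxels carrying the ignore label contribute to neither loss numerator nor denominator
        mask = (target != ignore_label).float()
        cleaned = torch.where(target == ignore_label, torch.zeros_like(target), target)
        return cleaned, mask  # multiply per-voxel loss terms by `mask` before reduction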
--------------------------------------------------------------------------------
/umamba/nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset999_IntegrationTest_Hippocampus.py:
--------------------------------------------------------------------------------
1 | import shutil
2 | 
3 | from batchgenerators.utilities.file_and_folder_operations import isdir, join
4 | 
5 | from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name
6 | from nnunetv2.paths import nnUNet_raw
7 | 
8 | 
9 | if __name__ == '__main__':
10 |     dataset_name = 'IntegrationTest_Hippocampus'
11 |     dataset_id = 999
12 |     dataset_name = f"Dataset{dataset_id:03d}_{dataset_name}"
13 | 
14 |     try:
15 |         existing_dataset_name = maybe_convert_to_dataset_name(dataset_id)
16 |         if existing_dataset_name != dataset_name:
17 |             raise FileExistsError(f"A different dataset with id {dataset_id} already exists :-(: {existing_dataset_name}. If "
18 |                                   f"you intend to delete it, remember to also remove it in nnUNet_preprocessed and "
19 |                                   f"nnUNet_results!")
20 |     except RuntimeError:
21 |         pass
22 | 
23 |     if isdir(join(nnUNet_raw, dataset_name)):
24 |         shutil.rmtree(join(nnUNet_raw, dataset_name))
25 | 
26 |     source_dataset = maybe_convert_to_dataset_name(4)
27 |     shutil.copytree(join(nnUNet_raw, source_dataset), join(nnUNet_raw, dataset_name))
28 | 
--------------------------------------------------------------------------------
/umamba/nnunetv2/dataset_conversion/datasets_for_integration_tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bowang-lab/U-Mamba/28459e33ca03769800dd35e23c6e62491d1925b5/umamba/nnunetv2/dataset_conversion/datasets_for_integration_tests/__init__.py
--------------------------------------------------------------------------------
/umamba/nnunetv2/ensembling/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bowang-lab/U-Mamba/28459e33ca03769800dd35e23c6e62491d1925b5/umamba/nnunetv2/ensembling/__init__.py
--------------------------------------------------------------------------------
/umamba/nnunetv2/evaluation/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bowang-lab/U-Mamba/28459e33ca03769800dd35e23c6e62491d1925b5/umamba/nnunetv2/evaluation/__init__.py
--------------------------------------------------------------------------------
/umamba/nnunetv2/evaluation/accumulate_cv_results.py:
--------------------------------------------------------------------------------
1 | import shutil
2 | from typing import Union, List, Tuple
3 | 
4 | from batchgenerators.utilities.file_and_folder_operations import load_json, join, isdir, maybe_mkdir_p, subfiles, isfile
5 | 
6 | from nnunetv2.configuration import default_num_processes
7 | from nnunetv2.evaluation.evaluate_predictions import compute_metrics_on_folder
8 | from nnunetv2.paths import nnUNet_raw, nnUNet_preprocessed
9 | from nnunetv2.utilities.plans_handling.plans_handler import PlansManager
10 | 
11 | 
12 | def accumulate_cv_results(trained_model_folder,
13 |                           merged_output_folder: str,
14 |                           folds: Union[List[int], Tuple[int, ...]],
15 |                           num_processes: int = default_num_processes,
16 |                           overwrite: bool = True):
17 |     """
18 |     There is a lot that can go wrong here, so the simplest way to deal with potential problems is to
19 |     collect the cv results into a separate folder and then evaluate them again. No messing with summary_json files!
20 |     """
21 | 
22 |     if overwrite and isdir(merged_output_folder):
23 |         shutil.rmtree(merged_output_folder)
24 |     maybe_mkdir_p(merged_output_folder)
25 | 
26 |     dataset_json = load_json(join(trained_model_folder, 'dataset.json'))
27 |     plans_manager = PlansManager(join(trained_model_folder, 'plans.json'))
28 |     rw = plans_manager.image_reader_writer_class()
29 |     shutil.copy(join(trained_model_folder, 'dataset.json'), join(merged_output_folder, 'dataset.json'))
30 |     shutil.copy(join(trained_model_folder, 'plans.json'), join(merged_output_folder, 'plans.json'))
31 | 
32 |     did_we_copy_something = False
33 |     for f in folds:
34 |         expected_validation_folder = join(trained_model_folder, f'fold_{f}', 'validation')
35 |         if not isdir(expected_validation_folder):
36 |             raise RuntimeError(f"fold {f} of model {trained_model_folder} is missing. Please train it!")
37 |         predicted_files = subfiles(expected_validation_folder, suffix=dataset_json['file_ending'], join=False)
38 |         for pf in predicted_files:
39 |             if overwrite and isfile(join(merged_output_folder, pf)):
40 |                 raise RuntimeError(f'More than one of your folds has a prediction for case {pf}')
41 |             if overwrite or not isfile(join(merged_output_folder, pf)):
42 |                 shutil.copy(join(expected_validation_folder, pf), join(merged_output_folder, pf))
43 |                 did_we_copy_something = True
44 | 
45 |     if did_we_copy_something or not isfile(join(merged_output_folder, 'summary.json')):
46 |         label_manager = plans_manager.get_label_manager(dataset_json)
47 |         gt_folder = join(nnUNet_raw, plans_manager.dataset_name, 'labelsTr')
48 |         if not isdir(gt_folder):
49 |             gt_folder = join(nnUNet_preprocessed, plans_manager.dataset_name, 'gt_segmentations')
50 |         compute_metrics_on_folder(gt_folder,
51 |                                   merged_output_folder,
52 |                                   join(merged_output_folder, 'summary.json'),
53 |                                   rw,
54 |                                   dataset_json['file_ending'],
55 |                                   label_manager.foreground_regions if label_manager.has_regions else
56 |                                   label_manager.foreground_labels,
57 |                                   label_manager.ignore_label,
58 |                                   num_processes)
59 | 
--------------------------------------------------------------------------------
/umamba/nnunetv2/experiment_planning/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bowang-lab/U-Mamba/28459e33ca03769800dd35e23c6e62491d1925b5/umamba/nnunetv2/experiment_planning/__init__.py
--------------------------------------------------------------------------------
/umamba/nnunetv2/experiment_planning/dataset_fingerprint/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bowang-lab/U-Mamba/28459e33ca03769800dd35e23c6e62491d1925b5/umamba/nnunetv2/experiment_planning/dataset_fingerprint/__init__.py
--------------------------------------------------------------------------------
/umamba/nnunetv2/experiment_planning/experiment_planners/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bowang-lab/U-Mamba/28459e33ca03769800dd35e23c6e62491d1925b5/umamba/nnunetv2/experiment_planning/experiment_planners/__init__.py
--------------------------------------------------------------------------------
/umamba/nnunetv2/experiment_planning/experiment_planners/network_topology.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 | import numpy as np
3 | 
4 | 
5 | def get_shape_must_be_divisible_by(net_numpool_per_axis):
6 |     return 2 ** np.array(net_numpool_per_axis)
7 | 
8 | 
9 | def pad_shape(shape, must_be_divisible_by):
10 |     """
11 |     pads shape so that it is divisible by must_be_divisible_by
12 |     :param shape:
13 |     :param must_be_divisible_by:
14 |     :return:
15 |     """
16 |     if not isinstance(must_be_divisible_by, (tuple, list, np.ndarray)):
17 |         must_be_divisible_by = [must_be_divisible_by] * len(shape)
18 |     else:
19 |         assert len(must_be_divisible_by) == len(shape)
20 | 
21 |     new_shp = [shape[i] + must_be_divisible_by[i] - shape[i] % must_be_divisible_by[i] for i in range(len(shape))]
22 | 
23 |     for i in range(len(shape)):
24 |         if shape[i] % must_be_divisible_by[i] == 0:
25 |             new_shp[i] -= must_be_divisible_by[i]
26 |     new_shp = np.array(new_shp).astype(int)
27 |     return new_shp
28 | 
29 | 
30 | def get_pool_and_conv_props(spacing, patch_size, min_feature_map_size, max_numpool):
31 |     """
32 |     this is the same as get_pool_and_conv_props_v2 from old nnunet
33 | 
34 |     :param spacing:
35 |     :param patch_size:
36 |     :param min_feature_map_size: min edge length of feature maps in bottleneck
37 |     :param max_numpool:
38 |     :return:
39 |     """
40 |     # todo review this code
41 |     dim = len(spacing)
42 | 
43 |     current_spacing = deepcopy(list(spacing))
44 |     current_size = deepcopy(list(patch_size))
45 | 
46 |     pool_op_kernel_sizes = [[1] * len(spacing)]
47 |     conv_kernel_sizes = []
48 | 
49 |     num_pool_per_axis = [0] * dim
50 |     kernel_size = [1] * dim
51 | 
52 |     while True:
53 |         # exclude axes that we cannot pool further because of min_feature_map_size constraint
54 |         valid_axes_for_pool = [i for i in range(dim) if current_size[i] >= 2 * min_feature_map_size]
55 |         if len(valid_axes_for_pool) < 1:
56 |             break
57 | 
58 |         spacings_of_axes = [current_spacing[i] for i in valid_axes_for_pool]
59 | 
60 |         # find axes that are within a factor of 2 of the smallest spacing
61 |         min_spacing_of_valid = min(spacings_of_axes)
62 |         valid_axes_for_pool = [i for i in valid_axes_for_pool if current_spacing[i] / min_spacing_of_valid < 2]
63 | 
64 |         # max_numpool constraint
65 |         valid_axes_for_pool = [i for i in valid_axes_for_pool if num_pool_per_axis[i] < max_numpool]
66 | 
67 |         if len(valid_axes_for_pool) == 1:
68 |             if current_size[valid_axes_for_pool[0]] >= 3 * min_feature_map_size:
69 |                 pass
70 |             else:
71 |                 break
72 |         if len(valid_axes_for_pool) < 1:
73 |             break
74 | 
75 |         # now we need to find kernel sizes
76 |         # kernel sizes are initialized to 1. They are successively set to 3 when their associated axis becomes within
77 |         # factor 2 of min_spacing. Once they are 3 they remain 3
78 |         for d in range(dim):
79 |             if kernel_size[d] == 3:
80 |                 continue
81 |             else:
82 |                 if current_spacing[d] / min(current_spacing) < 2:
83 |                     kernel_size[d] = 3
84 | 
85 |         other_axes = [i for i in range(dim) if i not in valid_axes_for_pool]
86 | 
87 |         pool_kernel_sizes = [0] * dim
88 |         for v in valid_axes_for_pool:
89 |             pool_kernel_sizes[v] = 2
90 |             num_pool_per_axis[v] += 1
91 |             current_spacing[v] *= 2
92 |             current_size[v] = np.ceil(current_size[v] / 2)
93 |         for nv in other_axes:
94 |             pool_kernel_sizes[nv] = 1
95 | 
96 |         pool_op_kernel_sizes.append(pool_kernel_sizes)
97 |         conv_kernel_sizes.append(deepcopy(kernel_size))
98 |         # print(conv_kernel_sizes)
99 | 
100 |     must_be_divisible_by = get_shape_must_be_divisible_by(num_pool_per_axis)
101 |     patch_size = pad_shape(patch_size, must_be_divisible_by)
102 | 
103 |     # we need to add one more conv_kernel_size for the bottleneck. We always use 3x3(x3) conv here
104 |     conv_kernel_sizes.append([3] * dim)
105 |     return num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, patch_size, must_be_divisible_by
106 | 
--------------------------------------------------------------------------------
/umamba/nnunetv2/experiment_planning/experiment_planners/readme.md:
--------------------------------------------------------------------------------
1 | What do experiment planners need to do (these are notes for myself while rewriting nnU-Net; they are provided as-is
2 | without further explanation and also include new features):
3 | - (done) preprocessor name should be configurable via cli
4 | - (done) gpu memory target should be configurable via cli
5 | - (done) plans name should be configurable via cli
6 | - (done) data name should be specified in plans (plans specify the data they want to use, this will allow us to manually
7 |   edit plans files without having to copy the data folders)
8 | - plans must contain:
9 |   - (done) transpose forward/backward
10 |   - (done) preprocessor name (can differ for each config)
11 |   - (done) spacing
12 |   - (done) normalization scheme
13 |   - (done) target spacing
14 |   - (done) conv and pool op kernel sizes
15 |   - (done) base num features for architecture
16 |   - (done) data identifier
17 |   - num conv per stage?
18 |   - (done) use mask for norm
19 |   - [NO. Handled by LabelManager & dataset.json] num segmentation outputs
20 |   - [NO. Handled by LabelManager & dataset.json] ignore class
21 |   - [NO. Handled by LabelManager & dataset.json] list of regions or classes
22 |   - [NO. Handled by LabelManager & dataset.json] regions class order, if applicable
23 |   - (done) resampling function to be used
24 |   - (done) the image reader writer class that should be used
25 | 
26 | 
27 | dataset.json
28 | mandatory:
29 | - numTraining
30 | - labels (value 'ignore' has special meaning. Cannot have more than one ignore_label)
31 | - modalities
32 | - file_ending
33 | 
34 | optional
35 | - overwrite_image_reader_writer (if absent, auto)
36 | - regions
37 | - region_class_order
38 | 
--------------------------------------------------------------------------------
/umamba/nnunetv2/experiment_planning/experiment_planners/resencUNet_planner.py:
--------------------------------------------------------------------------------
1 | from typing import Union, List, Tuple
2 | 
3 | from torch import nn
4 | 
5 | from nnunetv2.experiment_planning.experiment_planners.default_experiment_planner import ExperimentPlanner
6 | from dynamic_network_architectures.architectures.unet import ResidualEncoderUNet
7 | 
8 | 
9 | class ResEncUNetPlanner(ExperimentPlanner):
10 |     def __init__(self, dataset_name_or_id: Union[str, int],
11 |                  gpu_memory_target_in_gb: float = 8,
12 |                  preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResEncUNetPlans',
13 |                  overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None,
14 |                  suppress_transpose: bool = False):
15 |         super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name,
16 |                          overwrite_target_spacing, suppress_transpose)
17 | 
18 |         self.UNet_base_num_features = 32
19 |         self.UNet_class = ResidualEncoderUNet
20 |         # the following two numbers are really arbitrary and were set to reproduce default nnU-Net's configurations as
21 |         # much as possible
22 |         self.UNet_reference_val_3d = 680000000
23 |         self.UNet_reference_val_2d = 135000000
24 |         self.UNet_reference_com_nfeatures = 32
25 |         self.UNet_reference_val_corresp_GB = 8
26 |         self.UNet_reference_val_corresp_bs_2d = 12
27 |         self.UNet_reference_val_corresp_bs_3d = 2
28 |         self.UNet_featuremap_min_edge_length = 4
29 |         self.UNet_blocks_per_stage_encoder = (1, 3, 4, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6)
30 |         self.UNet_blocks_per_stage_decoder = (1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1)
31 |         self.UNet_min_batch_size = 2
32 |         self.UNet_max_features_2d = 512
33 |         self.UNet_max_features_3d = 320
34 | 
35 | 
36 | if __name__ == '__main__':
37 |     # we know both of these networks run with batch size 2 and 12 on ~8-10GB, respectively
38 |     net = ResidualEncoderUNet(input_channels=1, n_stages=6, features_per_stage=(32, 64, 128, 256, 320, 320),
39 |                               conv_op=nn.Conv3d, kernel_sizes=3, strides=(1, 2, 2, 2, 2, 2),
40 |                               n_blocks_per_stage=(1, 3, 4, 6, 6, 6), num_classes=3,
41 |                               n_conv_per_stage_decoder=(1, 1, 1, 1, 1),
42 |                               conv_bias=True, norm_op=nn.InstanceNorm3d, norm_op_kwargs={}, dropout_op=None,
43 |                               nonlin=nn.LeakyReLU, nonlin_kwargs={'inplace': True}, deep_supervision=True)
44 |     print(net.compute_conv_feature_map_size((128, 128, 128)))  # -> 558319104. The value you see above was finetuned
45 |     # from this one to match the regular nnunetplans more closely
46 | 
47 |     net = ResidualEncoderUNet(input_channels=1, n_stages=7, features_per_stage=(32, 64, 128, 256, 512, 512, 512),
48 |                               conv_op=nn.Conv2d, kernel_sizes=3, strides=(1, 2, 2, 2, 2, 2, 2),
49 |                               n_blocks_per_stage=(1, 3, 4, 6, 6, 6, 6), num_classes=3,
50 |                               n_conv_per_stage_decoder=(1, 1, 1, 1, 1, 1),
51 |                               conv_bias=True, norm_op=nn.InstanceNorm2d, norm_op_kwargs={}, dropout_op=None,
52 |                               nonlin=nn.LeakyReLU, nonlin_kwargs={'inplace': True}, deep_supervision=True)
53 |     print(net.compute_conv_feature_map_size((512, 512)))  # -> 129793792
54 | 
55 | 
--------------------------------------------------------------------------------
/umamba/nnunetv2/experiment_planning/plans_for_pretraining/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bowang-lab/U-Mamba/28459e33ca03769800dd35e23c6e62491d1925b5/umamba/nnunetv2/experiment_planning/plans_for_pretraining/__init__.py
--------------------------------------------------------------------------------
/umamba/nnunetv2/imageio/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bowang-lab/U-Mamba/28459e33ca03769800dd35e23c6e62491d1925b5/umamba/nnunetv2/imageio/__init__.py
--------------------------------------------------------------------------------
/umamba/nnunetv2/imageio/natural_image_reader_writer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 HIP Applied Computer Vision Lab, Division of Medical Image Computing, German Cancer Research Center
2 | # (DKFZ), Heidelberg, Germany
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | from typing import Tuple, Union, List
17 | import numpy as np
18 | from nnunetv2.imageio.base_reader_writer import BaseReaderWriter
19 | from skimage import io
20 | 
21 | 
22 | class NaturalImage2DIO(BaseReaderWriter):
23 |     """
24 |     ONLY SUPPORTS 2D IMAGES!!!
25 |     """
26 | 
27 |     # there are surely more we could add here. Everything that can be read by skimage.io should be supported
28 |     supported_file_endings = [
29 |         '.png',
30 |         # '.jpg',
31 |         # '.jpeg',  # jpg not supported because we cannot allow lossy compression! segmentation maps!
32 |         '.bmp',
33 |         '.tif'
34 |     ]
35 | 
36 |     def read_images(self, image_fnames: Union[List[str], Tuple[str, ...]]) -> Tuple[np.ndarray, dict]:
37 |         images = []
38 |         for f in image_fnames:
39 |             npy_img = io.imread(f)
40 |             if npy_img.ndim == 3:
41 |                 # rgb image, last dimension should be the color channel and the size of that channel should be 3
42 |                 # (or 4 if we have alpha)
43 |                 assert npy_img.shape[-1] == 3 or npy_img.shape[-1] == 4, "If image has three dimensions then the last " \
44 |                                                                          "dimension must have shape 3 or 4 " \
45 |                                                                          f"(RGB or RGBA). Image shape here is {npy_img.shape}"
46 |                 # move RGB(A) to front, add additional dim so that we have shape (1, c, X, Y), where c is either 3 or 4
47 |                 images.append(npy_img.transpose((2, 0, 1))[:, None])
48 |             elif npy_img.ndim == 2:
49 |                 # grayscale image
50 |                 images.append(npy_img[None, None])
51 | 
52 |         if not self._check_all_same([i.shape for i in images]):
53 |             print('ERROR! Not all input images have the same shape!')
54 |             print('Shapes:')
55 |             print([i.shape for i in images])
56 |             print('Image files:')
57 |             print(image_fnames)
58 |             raise RuntimeError()
59 |         return np.vstack(images).astype(np.float32), {'spacing': (999, 1, 1)}
60 | 
61 |     def read_seg(self, seg_fname: str) -> Tuple[np.ndarray, dict]:
62 |         return self.read_images((seg_fname, ))
63 | 
64 |     def write_seg(self, seg: np.ndarray, output_fname: str, properties: dict) -> None:
65 |         io.imsave(output_fname, seg[0].astype(np.uint8), check_contrast=False)
66 | 
67 | 
68 | if __name__ == '__main__':
69 |     images = ('/media/fabian/data/nnUNet_raw/Dataset120_RoadSegmentation/imagesTr/img-11_0000.png',)
70 |     segmentation = '/media/fabian/data/nnUNet_raw/Dataset120_RoadSegmentation/labelsTr/img-11.png'
71 |     imgio = NaturalImage2DIO()
72 |     img, props = imgio.read_images(images)
73 |     seg, segprops = imgio.read_seg(segmentation)
--------------------------------------------------------------------------------
/umamba/nnunetv2/imageio/reader_writer_registry.py:
--------------------------------------------------------------------------------
1 | import traceback
2 | from typing import Type
3 | 
4 | from batchgenerators.utilities.file_and_folder_operations import join
5 | 
6 | import nnunetv2
7 | from nnunetv2.imageio.natural_image_reader_writer import NaturalImage2DIO
8 | from nnunetv2.imageio.nibabel_reader_writer import NibabelIO, NibabelIOWithReorient
9 | from nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO
10 | from nnunetv2.imageio.tif_reader_writer import Tiff3DIO
11 | from nnunetv2.imageio.base_reader_writer import BaseReaderWriter
12 | from nnunetv2.utilities.find_class_by_name import recursive_find_python_class
13 | 
14 | LIST_OF_IO_CLASSES = [
15 |     NaturalImage2DIO,
16 |     SimpleITKIO,
17 |     Tiff3DIO,
18 |     NibabelIO,
19 |     NibabelIOWithReorient
20 | ]
21 | 
22 | 
23 | def determine_reader_writer_from_dataset_json(dataset_json_content: dict, example_file: str = None,
24 |                                               allow_nonmatching_filename: bool = False, verbose: bool = True
25 |                                               ) -> Type[BaseReaderWriter]:
26 |     if 'overwrite_image_reader_writer' in dataset_json_content.keys() and \
27 |             dataset_json_content['overwrite_image_reader_writer'] != 'None':
28 |         ioclass_name = dataset_json_content['overwrite_image_reader_writer']
29 |         # trying to find that class in the nnunetv2.imageio module
30 |         try:
31 |             ret = recursive_find_reader_writer_by_name(ioclass_name)
32 |             if verbose: print(f'Using {ret} reader/writer')
33 |             return ret
34 |         except RuntimeError:
35 |             if verbose: print(f'Warning: Unable to find ioclass specified in dataset.json: {ioclass_name}')
36 |             if verbose: print('Trying to automatically determine desired class')
37 |     return determine_reader_writer_from_file_ending(dataset_json_content['file_ending'], example_file,
38 |                                                     allow_nonmatching_filename, verbose)
39 | 
40 | 
41 | def determine_reader_writer_from_file_ending(file_ending: str, example_file: str = None, allow_nonmatching_filename: bool = False,
42 |                                              verbose: bool = True):
43 |     for rw in LIST_OF_IO_CLASSES:
44 |         if file_ending.lower() in rw.supported_file_endings:
45 |             if example_file is not None:
46 |                 # if an example file is provided, try if we can actually read it. If not move on to the next reader
47 |                 try:
48 |                     tmp = rw()
49 |                     _ = tmp.read_images((example_file,))
50 |                     if verbose: print(f'Using {rw} as reader/writer')
51 |                     return rw
52 |                 except Exception:
53 |                     if verbose: print(f'Failed to open file {example_file} with reader {rw}:')
54 |                     traceback.print_exc()
55 |                     pass
56 |             else:
57 |                 if verbose: print(f'Using {rw} as reader/writer')
58 |                 return rw
59 |         else:
60 |             if allow_nonmatching_filename and example_file is not None:
61 |                 try:
62 |                     tmp = rw()
63 |                     _ = tmp.read_images((example_file,))
64 |                     if verbose: print(f'Using {rw} as reader/writer')
65 |                     return rw
66 |                 except Exception:
67 |                     if verbose: print(f'Failed to open file {example_file} with reader {rw}:')
68 |                     if verbose: traceback.print_exc()
69 |                     pass
70 |     raise RuntimeError(f"Unable to determine a reader for file ending {file_ending} and file {example_file} (file None means no file provided).")
71 | 
72 | 
73 | def recursive_find_reader_writer_by_name(rw_class_name: str) -> Type[BaseReaderWriter]:
74 |     ret = recursive_find_python_class(join(nnunetv2.__path__[0], "imageio"), rw_class_name, 'nnunetv2.imageio')
75 |     if ret is None:
76 |         raise RuntimeError("Unable to find reader writer class '%s'. Please make sure this class is located in the "
77 |                            "nnunetv2.imageio module." % rw_class_name)
78 |     else:
79 |         return ret
80 | 
--------------------------------------------------------------------------------
/umamba/nnunetv2/imageio/readme.md:
--------------------------------------------------------------------------------
1 | - Derive your adapter from `BaseReaderWriter`.
2 | - Reimplement all abstractmethods.
3 | - make sure to support 2d and 3d input images (or raise some error).
4 | - place it in this folder or nnU-Net won't find it!
5 | - add it to LIST_OF_IO_CLASSES in `reader_writer_registry.py`
6 | 
7 | Bam, you're done!
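To make the checklist above concrete, here is a minimal sketch of a custom adapter in the spirit of NaturalImage2DIO. The `.npy` file ending and class name are hypothetical; only the base class and the three methods mirror the interface actually used elsewhere in this folder:

    from typing import Tuple
    import numpy as np
    from nnunetv2.imageio.base_reader_writer import BaseReaderWriter

    class Numpy2DIO(BaseReaderWriter):
        supported_file_endings = ['.npy']  # hypothetical use case

        def read_images(self, image_fnames) -> Tuple[np.ndarray, dict]:
            # stack per-channel 2d arrays to shape (c, 1, X, Y); 2d images get a dummy spacing
            images = [np.load(f)[None, None] for f in image_fnames]
            return np.vstack(images).astype(np.float32), {'spacing': (999, 1, 1)}

        def read_seg(self, seg_fname: str) -> Tuple[np.ndarray, dict]:
            return self.read_images((seg_fname,))

        def write_seg(self, seg: np.ndarray, output_fname: str, properties: dict) -> None:
            np.save(output_fname, seg[0].astype(np.uint8))

Remember to append the new class to LIST_OF_IO_CLASSES in reader_writer_registry.py, otherwise automatic detection will never consider it.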
--------------------------------------------------------------------------------
/umamba/nnunetv2/inference/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bowang-lab/U-Mamba/28459e33ca03769800dd35e23c6e62491d1925b5/umamba/nnunetv2/inference/__init__.py
--------------------------------------------------------------------------------
/umamba/nnunetv2/inference/sliding_window_prediction.py:
--------------------------------------------------------------------------------
1 | from functools import lru_cache
2 | 
3 | import numpy as np
4 | import torch
5 | from typing import Union, Tuple, List
6 | from acvl_utils.cropping_and_padding.padding import pad_nd_image
7 | from scipy.ndimage import gaussian_filter
8 | 
9 | 
10 | @lru_cache(maxsize=2)
11 | def compute_gaussian(tile_size: Union[Tuple[int, ...], List[int]], sigma_scale: float = 1. / 8,
12 |                      value_scaling_factor: float = 1, dtype=torch.float16, device=torch.device('cuda', 0)) \
13 |         -> torch.Tensor:
14 |     tmp = np.zeros(tile_size)
15 |     center_coords = [i // 2 for i in tile_size]
16 |     sigmas = [i * sigma_scale for i in tile_size]
17 |     tmp[tuple(center_coords)] = 1
18 |     gaussian_importance_map = gaussian_filter(tmp, sigmas, 0, mode='constant', cval=0)
19 | 
20 |     gaussian_importance_map = torch.from_numpy(gaussian_importance_map)
21 | 
22 |     gaussian_importance_map = gaussian_importance_map / torch.max(gaussian_importance_map) * value_scaling_factor
23 |     gaussian_importance_map = gaussian_importance_map.type(dtype).to(device)
24 | 
25 |     # gaussian_importance_map cannot be 0, otherwise we may end up with nans!
26 |     gaussian_importance_map[gaussian_importance_map == 0] = torch.min(
27 |         gaussian_importance_map[gaussian_importance_map != 0])
28 | 
29 |     return gaussian_importance_map
30 | 
31 | 
32 | def compute_steps_for_sliding_window(image_size: Tuple[int, ...], tile_size: Tuple[int, ...], tile_step_size: float) -> \
33 |         List[List[int]]:
34 |     assert all(i >= j for i, j in zip(image_size, tile_size)), "image size must be as large or larger than patch_size"
35 |     assert 0 < tile_step_size <= 1, 'step_size must be larger than 0 and smaller or equal to 1'
36 | 
37 |     # our step width is patch_size*step_size at most, but can be narrower. For example if we have image size of
38 |     # 110, patch size of 64 and step_size of 0.5, then we want to make 3 steps starting at coordinate 0, 23, 46
39 |     target_step_sizes_in_voxels = [i * tile_step_size for i in tile_size]
40 | 
41 |     num_steps = [int(np.ceil((i - k) / j)) + 1 for i, j, k in zip(image_size, target_step_sizes_in_voxels, tile_size)]
42 | 
43 |     steps = []
44 |     for dim in range(len(tile_size)):
45 |         # the highest step value for this dimension is
46 |         max_step_value = image_size[dim] - tile_size[dim]
47 |         if num_steps[dim] > 1:
48 |             actual_step_size = max_step_value / (num_steps[dim] - 1)
49 |         else:
50 |             actual_step_size = 99999999999  # does not matter because there is only one step at 0
51 | 
52 |         steps_here = [int(np.round(actual_step_size * i)) for i in range(num_steps[dim])]
53 | 
54 |         steps.append(steps_here)
55 | 
56 |     return steps
57 | 
58 | 
59 | if __name__ == '__main__':
60 |     a = torch.rand((4, 2, 32, 23))
61 |     a_npy = a.numpy()
62 | 
63 |     a_padded = pad_nd_image(a, new_shape=(48, 27))
64 |     a_npy_padded = pad_nd_image(a_npy, new_shape=(48, 27))
65 |     assert all([i == j for i, j in zip(a_padded.shape, (4, 2, 48, 27))])
66 |     assert all([i == j for i, j in zip(a_npy_padded.shape, (4, 2, 48, 27))])
67 |     assert np.all(a_padded.numpy() == a_npy_padded)
68 | 
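As a quick sanity check of compute_steps_for_sliding_window above, the scenario from its inline comment (image length 110, tile 64, step size 0.5) can be reproduced directly:

    from nnunetv2.inference.sliding_window_prediction import compute_steps_for_sliding_window

    # ceil((110 - 64) / 32) + 1 = 3 steps, spread evenly over [0, 46]
    print(compute_steps_for_sliding_window((110,), (64,), 0.5))  # [[0, 23, 46]]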
--------------------------------------------------------------------------------
/umamba/nnunetv2/model_sharing/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bowang-lab/U-Mamba/28459e33ca03769800dd35e23c6e62491d1925b5/umamba/nnunetv2/model_sharing/__init__.py
--------------------------------------------------------------------------------
/umamba/nnunetv2/model_sharing/entry_points.py:
--------------------------------------------------------------------------------
1 | from nnunetv2.model_sharing.model_download import download_and_install_from_url
2 | from nnunetv2.model_sharing.model_export import export_pretrained_model
3 | from nnunetv2.model_sharing.model_import import install_model_from_zip_file
4 | 
5 | 
6 | def print_license_warning():
7 |     print('')
8 |     print('######################################################')
9 |     print('!!!!!!!!!!!!!!!!!!!!!!!!WARNING!!!!!!!!!!!!!!!!!!!!!!!')
10 |     print('######################################################')
11 |     print("Using the pretrained model weights is subject to the license of the dataset they were trained on. Some "
12 |           "allow commercial use, others don't. It is your responsibility to make sure you use them appropriately! Use "
13 |           "nnUNet_print_pretrained_model_info(task_name) to see a summary of the dataset and where to find its license!")
14 |     print('######################################################')
15 |     print('')
16 | 
17 | 
18 | def download_by_url():
19 |     import argparse
20 |     parser = argparse.ArgumentParser(
21 |         description="Use this to download pretrained models. This script is intended to download models via url only. "
22 |                     "CAREFUL: This script will overwrite "
23 |                     "existing models (if they share the same trainer class and plans as "
24 |                     "the pretrained model).")
25 |     parser.add_argument("url", type=str, help='URL of the pretrained model')
26 |     args = parser.parse_args()
27 |     url = args.url
28 |     download_and_install_from_url(url)
29 | 
30 | 
31 | def install_from_zip_entry_point():
32 |     import argparse
33 |     parser = argparse.ArgumentParser(
34 |         description="Use this to install a zip file containing a pretrained model.")
35 |     parser.add_argument("zip", type=str, help='zip file')
36 |     args = parser.parse_args()
37 |     zip = args.zip
38 |     install_model_from_zip_file(zip)
39 | 
40 | 
41 | def export_pretrained_model_entry():
42 |     import argparse
43 |     parser = argparse.ArgumentParser(
44 |         description="Use this to export a trained model as a zip file.")
45 |     parser.add_argument('-d', type=str, required=True, help='Dataset name or id')
46 |     parser.add_argument('-o', type=str, required=True, help='Output file name')
47 |     parser.add_argument('-c', nargs='+', type=str, required=False,
48 |                         default=('3d_lowres', '3d_fullres', '2d', '3d_cascade_fullres'),
49 |                         help="List of configuration names")
50 |     parser.add_argument('-tr', required=False, type=str, default='nnUNetTrainer', help='Trainer class')
51 |     parser.add_argument('-p', required=False, type=str, default='nnUNetPlans', help='plans identifier')
52 |     parser.add_argument('-f', required=False, nargs='+', type=str, default=(0, 1, 2, 3, 4), help='list of fold ids')
53 |     parser.add_argument('-chk', required=False, nargs='+', type=str, default=('checkpoint_final.pth', ),
54 |                         help='List of checkpoint names to export. Default: checkpoint_final.pth')
55 |     parser.add_argument('--not_strict', action='store_true', default=False, required=False, help='Set this to allow missing folds and/or configurations')
56 |     parser.add_argument('--exp_cv_preds', action='store_true', required=False, help='Set this to export the cross-validation predictions as well')
57 |     args = parser.parse_args()
58 | 
59 |     export_pretrained_model(dataset_name_or_id=args.d, output_file=args.o, configurations=args.c, trainer=args.tr,
60 |                             plans_identifier=args.p, folds=args.f, strict=not args.not_strict, save_checkpoints=args.chk,
61 |                             export_crossval_predictions=args.exp_cv_preds)
62 | 
--------------------------------------------------------------------------------
/umamba/nnunetv2/model_sharing/model_download.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | 
3 | import requests
4 | from batchgenerators.utilities.file_and_folder_operations import *
5 | from time import time
6 | from nnunetv2.model_sharing.model_import import install_model_from_zip_file
7 | from nnunetv2.paths import nnUNet_results
8 | from tqdm import tqdm
9 | 
10 | 
11 | def download_and_install_from_url(url):
12 |     assert nnUNet_results is not None, "Cannot install model because nnUNet_results is not " \
13 |                                        "set (nnUNet_results missing as environment variable, see " \
14 |                                        "Installation instructions)"
15 |     print('Downloading pretrained model from url:', url)
16 |     import http.client
17 |     http.client.HTTPConnection._http_vsn = 10
18 |     http.client.HTTPConnection._http_vsn_str = 'HTTP/1.0'
19 | 
20 |     import os
21 |     home = os.path.expanduser('~')
22 |     random_number = int(time() * 1e7)
23 |     tempfile = join(home, f'.nnunetdownload_{str(random_number)}')
24 | 
25 |     try:
26 |         download_file(url=url, local_filename=tempfile, chunk_size=8192 * 16)
27 |         print("Download finished. Extracting...")
28 |         install_model_from_zip_file(tempfile)
29 |         print("Done")
30 |     except Exception as e:
31 |         raise e
32 |     finally:
33 |         if isfile(tempfile):
34 |             os.remove(tempfile)
35 | 
36 | 
37 | def download_file(url: str, local_filename: str, chunk_size: Optional[int] = 8192 * 16) -> str:
38 |     # borrowed from https://stackoverflow.com/questions/16694907/download-large-file-in-python-with-requests
39 |     # NOTE the stream=True parameter below
40 |     with requests.get(url, stream=True, timeout=100) as r:
41 |         r.raise_for_status()
42 |         with tqdm.wrapattr(open(local_filename, 'wb'), "write", total=int(r.headers.get("Content-Length"))) as f:
43 |             for chunk in r.iter_content(chunk_size=chunk_size):
44 |                 f.write(chunk)
45 |     return local_filename
46 | 
47 | 
48 | 
--------------------------------------------------------------------------------
/umamba/nnunetv2/model_sharing/model_import.py:
--------------------------------------------------------------------------------
1 | import zipfile
2 | 
3 | from nnunetv2.paths import nnUNet_results
4 | 
5 | 
6 | def install_model_from_zip_file(zip_file: str):
7 |     with zipfile.ZipFile(zip_file, 'r') as zip_ref:
8 |         zip_ref.extractall(nnUNet_results)
--------------------------------------------------------------------------------
/umamba/nnunetv2/paths.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | join = os.path.join 17 | """ 18 | Please make sure your data is organized as follows: 19 | 20 | data/ 21 | ├── nnUNet_raw/ 22 | │ ├── Dataset701_AbdomenCT/ 23 | │ │ ├── imagesTr 24 | │ │ │ ├── FLARE22_Tr_0001_0000.nii.gz 25 | │ │ │ ├── FLARE22_Tr_0002_0000.nii.gz 26 | │ │ │ ├── ... 27 | │ │ ├── labelsTr 28 | │ │ │ ├── FLARE22_Tr_0001.nii.gz 29 | │ │ │ ├── FLARE22_Tr_0002.nii.gz 30 | │ │ │ ├── ... 31 | │ │ ├── dataset.json 32 | │ ├── Dataset702_AbdomenMR/ 33 | │ │ ├── imagesTr 34 | │ │ │ ├── amos_0507_0000.nii.gz 35 | │ │ │ ├── amos_0508_0000.nii.gz 36 | │ │ │ ├── ... 37 | │ │ ├── labelsTr 38 | │ │ │ ├── amos_0507.nii.gz 39 | │ │ │ ├── amos_0508.nii.gz 40 | │ │ │ ├── ... 41 | │ │ ├── dataset.json 42 | │ ├── ... 43 | """ 44 | base = join(os.sep.join(__file__.split(os.sep)[:-3]), 'data') 45 | # or you can set your own path, e.g., base = '/home/user_name/Documents/U-Mamba/data' 46 | nnUNet_raw = join(base, 'nnUNet_raw') # os.environ.get('nnUNet_raw') 47 | nnUNet_preprocessed = join(base, 'nnUNet_preprocessed') # os.environ.get('nnUNet_preprocessed') 48 | nnUNet_results = join(base, 'nnUNet_results') # os.environ.get('nnUNet_results') 49 | 50 | if nnUNet_raw is None: 51 | print("nnUNet_raw is not defined and nnU-Net can only be used on data for which preprocessed files " 52 | "are already present on your system. nnU-Net cannot be used for experiment planning and preprocessing like " 53 | "this. If this is not intended, please read documentation/setting_up_paths.md for information on how to set " 54 | "this up properly.") 55 | 56 | if nnUNet_preprocessed is None: 57 | print("nnUNet_preprocessed is not defined and nnU-Net can not be used for preprocessing " 58 | "or training. If this is not intended, please read documentation/setting_up_paths.md for information on how " 59 | "to set this up.") 60 | 61 | if nnUNet_results is None: 62 | print("nnUNet_results is not defined and nnU-Net cannot be used for training or " 63 | "inference. 
If this is not intended behavior, please read documentation/setting_up_paths.md for information " 64 | "on how to set this up.") 65 | -------------------------------------------------------------------------------- /umamba/nnunetv2/postprocessing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/U-Mamba/28459e33ca03769800dd35e23c6e62491d1925b5/umamba/nnunetv2/postprocessing/__init__.py -------------------------------------------------------------------------------- /umamba/nnunetv2/preprocessing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/U-Mamba/28459e33ca03769800dd35e23c6e62491d1925b5/umamba/nnunetv2/preprocessing/__init__.py -------------------------------------------------------------------------------- /umamba/nnunetv2/preprocessing/cropping/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/U-Mamba/28459e33ca03769800dd35e23c6e62491d1925b5/umamba/nnunetv2/preprocessing/cropping/__init__.py -------------------------------------------------------------------------------- /umamba/nnunetv2/preprocessing/cropping/cropping.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | # Hello! crop_to_nonzero is the function you are looking for. Ignore the rest. 5 | from acvl_utils.cropping_and_padding.bounding_boxes import get_bbox_from_mask, crop_to_bbox, bounding_box_to_slice 6 | 7 | 8 | def create_nonzero_mask(data): 9 | """ 10 | 11 | :param data: array of shape (C, X, Y, Z) or (C, X, Y) 12 | :return: the mask is True where the data is nonzero 13 | """ 14 | from scipy.ndimage import binary_fill_holes 15 | assert data.ndim in (3, 4), "data must have shape (C, X, Y, Z) or shape (C, X, Y)" 16 | nonzero_mask = np.zeros(data.shape[1:], dtype=bool) 17 | for c in range(data.shape[0]): 18 | this_mask = data[c] != 0 19 | nonzero_mask = nonzero_mask | this_mask 20 | nonzero_mask = binary_fill_holes(nonzero_mask) 21 | return nonzero_mask 22 | 23 | 24 | def crop_to_nonzero(data, seg=None, nonzero_label=-1): 25 | """ 26 | 27 | :param data: array of shape (C, X, Y(, Z)); it is cropped to the bounding box of its nonzero region 28 | :param seg: optional segmentation map that is cropped alongside data 29 | :param nonzero_label: this will be written into the segmentation map outside the nonzero region of the data 30 | :return: cropped data, cropped seg (created from the nonzero mask if seg was None) and the bounding box 31 | """ 32 | nonzero_mask = create_nonzero_mask(data) 33 | bbox = get_bbox_from_mask(nonzero_mask) 34 | 35 | slicer = bounding_box_to_slice(bbox) 36 | data = data[tuple([slice(None), *slicer])] 37 | 38 | if seg is not None: 39 | seg = seg[tuple([slice(None), *slicer])] 40 | 41 | nonzero_mask = nonzero_mask[slicer][None] 42 | if seg is not None: 43 | seg[(seg == 0) & (~nonzero_mask)] = nonzero_label 44 | else: 45 | nonzero_mask = nonzero_mask.astype(np.int8) 46 | nonzero_mask[nonzero_mask == 0] = nonzero_label 47 | nonzero_mask[nonzero_mask > 0] = 0 48 | seg = nonzero_mask 49 | return data, seg, bbox 50 | 51 | 52 | -------------------------------------------------------------------------------- /umamba/nnunetv2/preprocessing/normalization/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/U-Mamba/28459e33ca03769800dd35e23c6e62491d1925b5/umamba/nnunetv2/preprocessing/normalization/__init__.py -------------------------------------------------------------------------------- /umamba/nnunetv2/preprocessing/normalization/map_channel_name_to_normalization.py:
-------------------------------------------------------------------------------- 1 | from typing import Type 2 | 3 | from nnunetv2.preprocessing.normalization.default_normalization_schemes import CTNormalization, NoNormalization, \ 4 | ZScoreNormalization, RescaleTo01Normalization, RGBTo01Normalization, ImageNormalization 5 | 6 | channel_name_to_normalization_mapping = { 7 | 'CT': CTNormalization, 8 | 'noNorm': NoNormalization, 9 | 'zscore': ZScoreNormalization, 10 | 'rescale_to_0_1': RescaleTo01Normalization, 11 | 'rgb_to_0_1': RGBTo01Normalization 12 | } 13 | 14 | 15 | def get_normalization_scheme(channel_name: str) -> Type[ImageNormalization]: 16 | """ 17 | If we find the channel_name in channel_name_to_normalization_mapping, return the corresponding normalization. If it is 18 | not found, use the default (ZScoreNormalization) 19 | """ 20 | norm_scheme = channel_name_to_normalization_mapping.get(channel_name) 21 | if norm_scheme is None: 22 | norm_scheme = ZScoreNormalization 23 | # print('Using %s for image normalization' % norm_scheme.__name__) 24 | return norm_scheme 25 | -------------------------------------------------------------------------------- /umamba/nnunetv2/preprocessing/normalization/readme.md: -------------------------------------------------------------------------------- 1 | The channel_names entry in dataset.json only determines the normalization scheme. So if you want to use something different, 2 | you can just 3 | - create a new subclass of ImageNormalization 4 | - map your custom channel identifier to that subclass in channel_name_to_normalization_mapping 5 | - run plan and preprocess again with your custom normalization scheme -------------------------------------------------------------------------------- /umamba/nnunetv2/preprocessing/preprocessors/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/U-Mamba/28459e33ca03769800dd35e23c6e62491d1925b5/umamba/nnunetv2/preprocessing/preprocessors/__init__.py -------------------------------------------------------------------------------- /umamba/nnunetv2/preprocessing/resampling/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/U-Mamba/28459e33ca03769800dd35e23c6e62491d1925b5/umamba/nnunetv2/preprocessing/resampling/__init__.py -------------------------------------------------------------------------------- /umamba/nnunetv2/preprocessing/resampling/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | import nnunetv2 4 | from batchgenerators.utilities.file_and_folder_operations import join 5 | from nnunetv2.utilities.find_class_by_name import recursive_find_python_class 6 | 7 | 8 | def recursive_find_resampling_fn_by_name(resampling_fn: str) -> Callable: 9 | ret = recursive_find_python_class(join(nnunetv2.__path__[0], "preprocessing", "resampling"), resampling_fn, 10 | 'nnunetv2.preprocessing.resampling') 11 | if ret is None: 12 | raise RuntimeError("Unable to find resampling function named '%s'. Please make sure this fn is located in the " 13 | "nnunetv2.preprocessing.resampling module."
% resampling_fn) 14 | else: 15 | return ret 16 | -------------------------------------------------------------------------------- /umamba/nnunetv2/run/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/U-Mamba/28459e33ca03769800dd35e23c6e62491d1925b5/umamba/nnunetv2/run/__init__.py -------------------------------------------------------------------------------- /umamba/nnunetv2/run/load_pretrained_weights.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch._dynamo import OptimizedModule 3 | from torch.nn.parallel import DistributedDataParallel as DDP 4 | 5 | 6 | def load_pretrained_weights(network, fname, verbose=False): 7 | """ 8 | Transfers all weights between matching keys in state_dicts. Matching is done by name and we only transfer if the 9 | shape is also the same. Segmentation layers (the 1x1(x1) layers that produce the segmentation maps, 10 | identified by keys ending with '.seg_layers') are not transferred! 11 | 12 | If the pretrained weights were obtained with a training outside nnU-Net and DDP or torch.compile was used, 13 | you need to change the keys of the pretrained state_dict. DDP adds a 'module.' prefix and torch.compile adds 14 | '_orig_mod.'. You DO NOT need to worry about this if pretraining was done with nnU-Net as 15 | nnUNetTrainer.save_checkpoint takes care of that! 16 | 17 | """ 18 | saved_model = torch.load(fname, map_location=torch.device('cpu')) 19 | pretrained_dict = saved_model['network_weights'] 20 | 21 | skip_strings_in_pretrained = [ 22 | '.seg_layers.', 23 | ] 24 | 25 | if isinstance(network, DDP): 26 | mod = network.module 27 | else: 28 | mod = network 29 | if isinstance(mod, OptimizedModule): 30 | mod = mod._orig_mod 31 | 32 | model_dict = mod.state_dict() 33 | # verify that all but the segmentation layers have the same shape 34 | for key, _ in model_dict.items(): 35 | if all([i not in key for i in skip_strings_in_pretrained]): 36 | assert key in pretrained_dict, \ 37 | f"Key {key} is missing in the pretrained model weights. The pretrained weights do not seem to be " \ 38 | f"compatible with your network." 39 | assert model_dict[key].shape == pretrained_dict[key].shape, \ 40 | f"The shape of the parameters of key {key} is not the same. Pretrained model: " \ 41 | f"{pretrained_dict[key].shape}; your network: {model_dict[key].shape}. The pretrained model " \ 42 | f"does not seem to be compatible with your network." 43 | 44 | # fun fact: in principle this allows loading from parameters that do not cover the entire network. For example pretrained 45 | # encoders. Not supported by this function though (see assertions above) 46 | 47 | # commenting out this abomination of a dict comprehension for preservation in the archives of 'what not to do' 48 | # pretrained_dict = {'module.' + k if is_ddp else k: v 49 | # for k, v in pretrained_dict.items() 50 | # if (('module.'
+ k if is_ddp else k) in model_dict) and 51 | # all([i not in k for i in skip_strings_in_pretrained])} 52 | 53 | pretrained_dict = {k: v for k, v in pretrained_dict.items() 54 | if k in model_dict.keys() and all([i not in k for i in skip_strings_in_pretrained])} 55 | 56 | model_dict.update(pretrained_dict) 57 | 58 | print("################### Loading pretrained weights from file ", fname, '###################') 59 | if verbose: 60 | print("Below is the list of overlapping blocks in pretrained model and nnUNet architecture:") 61 | for key, value in pretrained_dict.items(): 62 | print(key, 'shape', value.shape) 63 | print("################### Done ###################") 64 | mod.load_state_dict(model_dict) 65 | 66 | 67 | -------------------------------------------------------------------------------- /umamba/nnunetv2/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/U-Mamba/28459e33ca03769800dd35e23c6e62491d1925b5/umamba/nnunetv2/tests/__init__.py -------------------------------------------------------------------------------- /umamba/nnunetv2/tests/integration_tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/U-Mamba/28459e33ca03769800dd35e23c6e62491d1925b5/umamba/nnunetv2/tests/integration_tests/__init__.py -------------------------------------------------------------------------------- /umamba/nnunetv2/tests/integration_tests/add_lowres_and_cascade.py: -------------------------------------------------------------------------------- 1 | from batchgenerators.utilities.file_and_folder_operations import * 2 | 3 | from nnunetv2.paths import nnUNet_preprocessed 4 | from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name 5 | 6 | if __name__ == '__main__': 7 | import argparse 8 | 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('-d', nargs='+', type=int, help='List of dataset ids') 11 | args = parser.parse_args() 12 | 13 | for d in args.d: 14 | dataset_name = maybe_convert_to_dataset_name(d) 15 | plans = load_json(join(nnUNet_preprocessed, dataset_name, 'nnUNetPlans.json')) 16 | plans['configurations']['3d_lowres'] = { 17 | "data_identifier": "nnUNetPlans_3d_lowres", # do not be a dumbo and forget this. I was a dumbo. 
And I paid dearly with ~10 min debugging time 18 | 'inherits_from': '3d_fullres', 19 | "patch_size": [20, 28, 20], 20 | "median_image_size_in_voxels": [18.0, 25.0, 18.0], 21 | "spacing": [2.0, 2.0, 2.0], 22 | "n_conv_per_stage_encoder": [2, 2, 2], 23 | "n_conv_per_stage_decoder": [2, 2], 24 | "num_pool_per_axis": [2, 2, 2], 25 | "pool_op_kernel_sizes": [[1, 1, 1], [2, 2, 2], [2, 2, 2]], 26 | "conv_kernel_sizes": [[3, 3, 3], [3, 3, 3], [3, 3, 3]], 27 | "next_stage": "3d_cascade_fullres" 28 | } 29 | plans['configurations']['3d_cascade_fullres'] = { 30 | 'inherits_from': '3d_fullres', 31 | "previous_stage": "3d_lowres" 32 | } 33 | save_json(plans, join(nnUNet_preprocessed, dataset_name, 'nnUNetPlans.json'), sort_keys=False) -------------------------------------------------------------------------------- /umamba/nnunetv2/tests/integration_tests/cleanup_integration_test.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | 3 | from batchgenerators.utilities.file_and_folder_operations import isdir, join 4 | 5 | from nnunetv2.paths import nnUNet_raw, nnUNet_results, nnUNet_preprocessed 6 | 7 | if __name__ == '__main__': 8 | # deletes everything! 9 | dataset_names = [ 10 | 'Dataset996_IntegrationTest_Hippocampus_regions_ignore', 11 | 'Dataset997_IntegrationTest_Hippocampus_regions', 12 | 'Dataset998_IntegrationTest_Hippocampus_ignore', 13 | 'Dataset999_IntegrationTest_Hippocampus', 14 | ] 15 | for fld in [nnUNet_raw, nnUNet_preprocessed, nnUNet_results]: 16 | for d in dataset_names: 17 | if isdir(join(fld, d)): 18 | shutil.rmtree(join(fld, d)) 19 | 20 | -------------------------------------------------------------------------------- /umamba/nnunetv2/tests/integration_tests/lsf_commands.sh: -------------------------------------------------------------------------------- 1 | bsub -q gpu.legacy -gpu num=1:j_exclusive=yes:gmem=1G -L /bin/bash ". /home/isensee/load_env_cluster4.sh && cd /home/isensee/git_repos/nnunet_remake && export nnUNet_keep_files_open=True && . nnunetv2/tests/integration_tests/run_integration_test.sh 996" 2 | bsub -q gpu.legacy -gpu num=1:j_exclusive=yes:gmem=1G -L /bin/bash ". /home/isensee/load_env_cluster4.sh && cd /home/isensee/git_repos/nnunet_remake && export nnUNet_keep_files_open=True && . nnunetv2/tests/integration_tests/run_integration_test.sh 997" 3 | bsub -q gpu.legacy -gpu num=1:j_exclusive=yes:gmem=1G -L /bin/bash ". /home/isensee/load_env_cluster4.sh && cd /home/isensee/git_repos/nnunet_remake && export nnUNet_keep_files_open=True && . nnunetv2/tests/integration_tests/run_integration_test.sh 998" 4 | bsub -q gpu.legacy -gpu num=1:j_exclusive=yes:gmem=1G -L /bin/bash ". /home/isensee/load_env_cluster4.sh && cd /home/isensee/git_repos/nnunet_remake && export nnUNet_keep_files_open=True && . nnunetv2/tests/integration_tests/run_integration_test.sh 999" 5 | 6 | 7 | bsub -q gpu.legacy -gpu num=2:j_exclusive=yes:gmem=1G -L /bin/bash ". /home/isensee/load_env_cluster4.sh && cd /home/isensee/git_repos/nnunet_remake && export nnUNet_keep_files_open=True && . nnunetv2/tests/integration_tests/run_integration_test_trainingOnly_DDP.sh 996" 8 | bsub -q gpu.legacy -gpu num=2:j_exclusive=yes:gmem=1G -L /bin/bash ". /home/isensee/load_env_cluster4.sh && cd /home/isensee/git_repos/nnunet_remake && export nnUNet_keep_files_open=True && . nnunetv2/tests/integration_tests/run_integration_test_trainingOnly_DDP.sh 997" 9 | bsub -q gpu.legacy -gpu num=2:j_exclusive=yes:gmem=1G -L /bin/bash ". 
/home/isensee/load_env_cluster4.sh && cd /home/isensee/git_repos/nnunet_remake && export nnUNet_keep_files_open=True && . nnunetv2/tests/integration_tests/run_integration_test_trainingOnly_DDP.sh 998" 10 | bsub -q gpu.legacy -gpu num=2:j_exclusive=yes:gmem=1G -L /bin/bash ". /home/isensee/load_env_cluster4.sh && cd /home/isensee/git_repos/nnunet_remake && export nnUNet_keep_files_open=True && . nnunetv2/tests/integration_tests/run_integration_test_trainingOnly_DDP.sh 999" 11 | -------------------------------------------------------------------------------- /umamba/nnunetv2/tests/integration_tests/prepare_integration_tests.sh: -------------------------------------------------------------------------------- 1 | # assumes you are in the nnunet repo! 2 | 3 | # prepare raw datasets 4 | python nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset999_IntegrationTest_Hippocampus.py 5 | python nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset998_IntegrationTest_Hippocampus_ignore.py 6 | python nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset997_IntegrationTest_Hippocampus_regions.py 7 | python nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset996_IntegrationTest_Hippocampus_regions_ignore.py 8 | 9 | # now run experiment planning without preprocessing 10 | nnUNetv2_plan_and_preprocess -d 996 997 998 999 --no_pp 11 | 12 | # now add 3d lowres and cascade 13 | python nnunetv2/tests/integration_tests/add_lowres_and_cascade.py -d 996 997 998 999 14 | 15 | # now preprocess everything 16 | nnUNetv2_preprocess -d 996 997 998 999 -c 2d 3d_lowres 3d_fullres -np 8 8 8 # no need to preprocess cascade as it's the same data as 3d_fullres 17 | 18 | # done -------------------------------------------------------------------------------- /umamba/nnunetv2/tests/integration_tests/readme.md: -------------------------------------------------------------------------------- 1 | # Preface 2 | 3 | I am just a mortal with many tasks and limited time. Ain't nobody got time for unit tests. 4 | 5 | HOWEVER, at least some integration tests should be performed to test nnU-Net from start to finish. 6 | 7 | # Introduction - What the heck is happening? 8 | This test covers all possible labeling scenarios (standard labels, regions, ignore labels and regions with 9 | ignore labels). It runs the entire nnU-Net pipeline from start to finish: 10 | 11 | - fingerprint extraction 12 | - experiment planning 13 | - preprocessing 14 | - train all 4 configurations (2d, 3d_lowres, 3d_fullres, 3d_cascade_fullres) as 5-fold CV 15 | - automatically find the best model or ensemble 16 | - determine the postprocessing used for this 17 | - predict some test set 18 | - apply postprocessing to the test set 19 | 20 | To speed things up, we do the following: 21 | - pick Dataset004_Hippocampus because it is quadratisch praktisch gut. The MNIST of medical image segmentation 22 | - by default this dataset does not have 3d_lowres or cascade. We just manually add them (cool new feature, eh?). See `add_lowres_and_cascade.py` to learn more! 23 | - we use nnUNetTrainer_5epochs for a short training 24 | 25 | # How to run it? 26 | 27 | Set your pwd to be the nnunet repo folder (the one where the `nnunetv2` folder and the `setup.py` are located!) 28 | 29 | Now generate the 4 dummy datasets (ids 996, 997, 998, 999) from dataset 4. This will crash if you don't have Dataset004!
30 | ```commandline 31 | bash nnunetv2/tests/integration_tests/prepare_integration_tests.sh 32 | ``` 33 | 34 | Now you can run the integration test for each of the datasets: 35 | ```commandline 36 | bash nnunetv2/tests/integration_tests/run_integration_test.sh DATASET_ID 37 | ``` 38 | Use DATASET_ID 996, 997, 998 and 999. You can run these independently on different GPUs/systems to speed things up. 39 | This will take, I dunno, like 10-30 minutes!? 40 | 41 | Also run 42 | ```commandline 43 | bash nnunetv2/tests/integration_tests/run_integration_test_trainingOnly_DDP.sh DATASET_ID 44 | ``` 45 | to verify DDP is working (needs 2 GPUs!) 46 | 47 | # How to check if the test was successful? 48 | If I were not as lazy as I am, I would have programmed some automatism that checks whether Dice scores etc. are in an acceptable range. 49 | So you need to do the following: 50 | 1) check that none of your runs crashed (duh) 51 | 2) for each run, navigate to `nnUNet_results/DATASET_NAME` and take a look at the `inference_information.json` file. 52 | Does it make sense? If so: NICE! 53 | 54 | Once the integration test is completed you can delete all the temporary files associated with it by running: 55 | 56 | ```commandline 57 | python nnunetv2/tests/integration_tests/cleanup_integration_test.py 58 | ``` -------------------------------------------------------------------------------- /umamba/nnunetv2/tests/integration_tests/run_integration_test.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | nnUNetv2_train $1 3d_fullres 0 -tr nnUNetTrainer_5epochs --npz 4 | nnUNetv2_train $1 3d_fullres 1 -tr nnUNetTrainer_5epochs --npz 5 | nnUNetv2_train $1 3d_fullres 2 -tr nnUNetTrainer_5epochs --npz 6 | nnUNetv2_train $1 3d_fullres 3 -tr nnUNetTrainer_5epochs --npz 7 | nnUNetv2_train $1 3d_fullres 4 -tr nnUNetTrainer_5epochs --npz 8 | 9 | nnUNetv2_train $1 2d 0 -tr nnUNetTrainer_5epochs --npz 10 | nnUNetv2_train $1 2d 1 -tr nnUNetTrainer_5epochs --npz 11 | nnUNetv2_train $1 2d 2 -tr nnUNetTrainer_5epochs --npz 12 | nnUNetv2_train $1 2d 3 -tr nnUNetTrainer_5epochs --npz 13 | nnUNetv2_train $1 2d 4 -tr nnUNetTrainer_5epochs --npz 14 | 15 | nnUNetv2_train $1 3d_lowres 0 -tr nnUNetTrainer_5epochs --npz 16 | nnUNetv2_train $1 3d_lowres 1 -tr nnUNetTrainer_5epochs --npz 17 | nnUNetv2_train $1 3d_lowres 2 -tr nnUNetTrainer_5epochs --npz 18 | nnUNetv2_train $1 3d_lowres 3 -tr nnUNetTrainer_5epochs --npz 19 | nnUNetv2_train $1 3d_lowres 4 -tr nnUNetTrainer_5epochs --npz 20 | 21 | nnUNetv2_train $1 3d_cascade_fullres 0 -tr nnUNetTrainer_5epochs --npz 22 | nnUNetv2_train $1 3d_cascade_fullres 1 -tr nnUNetTrainer_5epochs --npz 23 | nnUNetv2_train $1 3d_cascade_fullres 2 -tr nnUNetTrainer_5epochs --npz 24 | nnUNetv2_train $1 3d_cascade_fullres 3 -tr nnUNetTrainer_5epochs --npz 25 | nnUNetv2_train $1 3d_cascade_fullres 4 -tr nnUNetTrainer_5epochs --npz 26 | 27 | python nnunetv2/tests/integration_tests/run_integration_test_bestconfig_inference.py -d $1 -------------------------------------------------------------------------------- /umamba/nnunetv2/tests/integration_tests/run_integration_test_trainingOnly_DDP.sh: -------------------------------------------------------------------------------- 1 | nnUNetv2_train $1 3d_fullres 0 -tr nnUNetTrainer_10epochs -num_gpus 2 2 | -------------------------------------------------------------------------------- /umamba/nnunetv2/training/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/bowang-lab/U-Mamba/28459e33ca03769800dd35e23c6e62491d1925b5/umamba/nnunetv2/training/__init__.py -------------------------------------------------------------------------------- /umamba/nnunetv2/training/data_augmentation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/U-Mamba/28459e33ca03769800dd35e23c6e62491d1925b5/umamba/nnunetv2/training/data_augmentation/__init__.py -------------------------------------------------------------------------------- /umamba/nnunetv2/training/data_augmentation/compute_initial_patch_size.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def get_patch_size(final_patch_size, rot_x, rot_y, rot_z, scale_range): 5 | if isinstance(rot_x, (tuple, list)): 6 | rot_x = max(np.abs(rot_x)) 7 | if isinstance(rot_y, (tuple, list)): 8 | rot_y = max(np.abs(rot_y)) 9 | if isinstance(rot_z, (tuple, list)): 10 | rot_z = max(np.abs(rot_z)) 11 | rot_x = min(90 / 360 * 2. * np.pi, rot_x) 12 | rot_y = min(90 / 360 * 2. * np.pi, rot_y) 13 | rot_z = min(90 / 360 * 2. * np.pi, rot_z) 14 | from batchgenerators.augmentations.utils import rotate_coords_3d, rotate_coords_2d 15 | coords = np.array(final_patch_size) 16 | final_shape = np.copy(coords) 17 | if len(coords) == 3: 18 | final_shape = np.max(np.vstack((np.abs(rotate_coords_3d(coords, rot_x, 0, 0)), final_shape)), 0) 19 | final_shape = np.max(np.vstack((np.abs(rotate_coords_3d(coords, 0, rot_y, 0)), final_shape)), 0) 20 | final_shape = np.max(np.vstack((np.abs(rotate_coords_3d(coords, 0, 0, rot_z)), final_shape)), 0) 21 | elif len(coords) == 2: 22 | final_shape = np.max(np.vstack((np.abs(rotate_coords_2d(coords, rot_x)), final_shape)), 0) 23 | final_shape /= min(scale_range) 24 | return final_shape.astype(int) 25 | -------------------------------------------------------------------------------- /umamba/nnunetv2/training/data_augmentation/custom_transforms/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/U-Mamba/28459e33ca03769800dd35e23c6e62491d1925b5/umamba/nnunetv2/training/data_augmentation/custom_transforms/__init__.py -------------------------------------------------------------------------------- /umamba/nnunetv2/training/data_augmentation/custom_transforms/deep_supervision_donwsampling.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Union, List 2 | 3 | from batchgenerators.augmentations.utils import resize_segmentation 4 | from batchgenerators.transforms.abstract_transforms import AbstractTransform 5 | import numpy as np 6 | 7 | 8 | class DownsampleSegForDSTransform2(AbstractTransform): 9 | ''' 10 | data_dict['output_key'] will be a list of segmentations scaled according to ds_scales 11 | ''' 12 | def __init__(self, ds_scales: Union[List, Tuple], 13 | order: int = 0, input_key: str = "seg", 14 | output_key: str = "seg", axes: Tuple[int] = None): 15 | """ 16 | Downscales data_dict[input_key] according to ds_scales. Each entry in ds_scales specifies one deep supervision 17 | output and its resolution relative to the original data, for example 0.25 specifies 1/4 of the original shape.
18 | ds_scales can also be a tuple of tuples, for example ((1, 1, 1), (0.5, 0.5, 0.5)) to specify the downsampling 19 | for each axis independently 20 | """ 21 | self.axes = axes 22 | self.output_key = output_key 23 | self.input_key = input_key 24 | self.order = order 25 | self.ds_scales = ds_scales 26 | 27 | def __call__(self, **data_dict): 28 | if self.axes is None: 29 | axes = list(range(2, data_dict[self.input_key].ndim)) 30 | else: 31 | axes = self.axes 32 | 33 | output = [] 34 | for s in self.ds_scales: 35 | if not isinstance(s, (tuple, list)): 36 | s = [s] * len(axes) 37 | else: 38 | assert len(s) == len(axes), f'If ds_scales is a tuple for each resolution (one downsampling factor ' \ 39 | f'for each axis) then the number of entries in that tuple (here ' \ 40 | f'{len(s)}) must be the same as the number of axes (here {len(axes)}).' 41 | 42 | if all([i == 1 for i in s]): 43 | output.append(data_dict[self.input_key]) 44 | else: 45 | new_shape = np.array(data_dict[self.input_key].shape).astype(float) 46 | for i, a in enumerate(axes): 47 | new_shape[a] *= s[i] 48 | new_shape = np.round(new_shape).astype(int) 49 | out_seg = np.zeros(new_shape, dtype=data_dict[self.input_key].dtype) 50 | for b in range(data_dict[self.input_key].shape[0]): 51 | for c in range(data_dict[self.input_key].shape[1]): 52 | out_seg[b, c] = resize_segmentation(data_dict[self.input_key][b, c], new_shape[2:], self.order) 53 | output.append(out_seg) 54 | data_dict[self.output_key] = output 55 | return data_dict 56 | -------------------------------------------------------------------------------- /umamba/nnunetv2/training/data_augmentation/custom_transforms/limited_length_multithreaded_augmenter.py: -------------------------------------------------------------------------------- 1 | from batchgenerators.dataloading.nondet_multi_threaded_augmenter import NonDetMultiThreadedAugmenter 2 | 3 | 4 | class LimitedLenWrapper(NonDetMultiThreadedAugmenter): 5 | def __init__(self, my_imaginary_length, *args, **kwargs): 6 | super().__init__(*args, **kwargs) 7 | self.len = my_imaginary_length 8 | 9 | def __len__(self): 10 | return self.len 11 | -------------------------------------------------------------------------------- /umamba/nnunetv2/training/data_augmentation/custom_transforms/manipulating_data_dict.py: -------------------------------------------------------------------------------- 1 | from batchgenerators.transforms.abstract_transforms import AbstractTransform 2 | 3 | 4 | class RemoveKeyTransform(AbstractTransform): 5 | def __init__(self, key_to_remove: str): 6 | self.key_to_remove = key_to_remove 7 | 8 | def __call__(self, **data_dict): 9 | _ = data_dict.pop(self.key_to_remove, None) 10 | return data_dict 11 | -------------------------------------------------------------------------------- /umamba/nnunetv2/training/data_augmentation/custom_transforms/masking.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from batchgenerators.transforms.abstract_transforms import AbstractTransform 4 | 5 | 6 | class MaskTransform(AbstractTransform): 7 | def __init__(self, apply_to_channels: List[int], mask_idx_in_seg: int = 0, set_outside_to: int = 0, 8 | data_key: str = "data", seg_key: str = "seg"): 9 | """ 10 | Sets everything outside the mask to 0. CAREFUL! 'outside' is defined as < 0, not == 0, in the mask channel of seg!!!
11 | """ 12 | self.apply_to_channels = apply_to_channels 13 | self.seg_key = seg_key 14 | self.data_key = data_key 15 | self.set_outside_to = set_outside_to 16 | self.mask_idx_in_seg = mask_idx_in_seg 17 | 18 | def __call__(self, **data_dict): 19 | mask = data_dict[self.seg_key][:, self.mask_idx_in_seg] < 0 20 | for c in self.apply_to_channels: 21 | data_dict[self.data_key][:, c][mask] = self.set_outside_to 22 | return data_dict 23 | -------------------------------------------------------------------------------- /umamba/nnunetv2/training/data_augmentation/custom_transforms/region_based_training.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Union 2 | 3 | from batchgenerators.transforms.abstract_transforms import AbstractTransform 4 | import numpy as np 5 | 6 | 7 | class ConvertSegmentationToRegionsTransform(AbstractTransform): 8 | def __init__(self, regions: Union[List, Tuple], 9 | seg_key: str = "seg", output_key: str = "seg", seg_channel: int = 0): 10 | """ 11 | regions are tuple of tuples where each inner tuple holds the class indices that are merged into one region, 12 | example: 13 | regions= ((1, 2), (2, )) will result in 2 regions: one covering the region of labels 1&2 and the other just 2 14 | :param regions: 15 | :param seg_key: 16 | :param output_key: 17 | """ 18 | self.seg_channel = seg_channel 19 | self.output_key = output_key 20 | self.seg_key = seg_key 21 | self.regions = regions 22 | 23 | def __call__(self, **data_dict): 24 | seg = data_dict.get(self.seg_key) 25 | num_regions = len(self.regions) 26 | if seg is not None: 27 | seg_shp = seg.shape 28 | output_shape = list(seg_shp) 29 | output_shape[1] = num_regions 30 | region_output = np.zeros(output_shape, dtype=seg.dtype) 31 | for b in range(seg_shp[0]): 32 | for region_id, region_source_labels in enumerate(self.regions): 33 | if not isinstance(region_source_labels, (list, tuple)): 34 | region_source_labels = (region_source_labels, ) 35 | for label_value in region_source_labels: 36 | region_output[b, region_id][seg[b, self.seg_channel] == label_value] = 1 37 | data_dict[self.output_key] = region_output 38 | return data_dict 39 | -------------------------------------------------------------------------------- /umamba/nnunetv2/training/data_augmentation/custom_transforms/transforms_for_dummy_2d.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Union, List 2 | 3 | from batchgenerators.transforms.abstract_transforms import AbstractTransform 4 | 5 | 6 | class Convert3DTo2DTransform(AbstractTransform): 7 | def __init__(self, apply_to_keys: Union[List[str], Tuple[str]] = ('data', 'seg')): 8 | """ 9 | Transforms a 5D array (b, c, x, y, z) to a 4D array (b, c * x, y, z) by overloading the color channel 10 | """ 11 | self.apply_to_keys = apply_to_keys 12 | 13 | def __call__(self, **data_dict): 14 | for k in self.apply_to_keys: 15 | shp = data_dict[k].shape 16 | assert len(shp) == 5, 'This transform only works on 3D data, so expects 5D tensor (b, c, x, y, z) as input.' 17 | data_dict[k] = data_dict[k].reshape((shp[0], shp[1] * shp[2], shp[3], shp[4])) 18 | shape_key = f'orig_shape_{k}' 19 | assert shape_key not in data_dict.keys(), f'Convert3DTo2DTransform needs to store the original shape. ' \ 20 | f'It does that using the {shape_key} key. That key is ' \ 21 | f'already taken. Bummer.' 
22 | data_dict[shape_key] = shp 23 | return data_dict 24 | 25 | 26 | class Convert2DTo3DTransform(AbstractTransform): 27 | def __init__(self, apply_to_keys: Union[List[str], Tuple[str]] = ('data', 'seg')): 28 | """ 29 | Reverts Convert3DTo2DTransform by transforming a 4D array (b, c * x, y, z) back to 5D (b, c, x, y, z) 30 | """ 31 | self.apply_to_keys = apply_to_keys 32 | 33 | def __call__(self, **data_dict): 34 | for k in self.apply_to_keys: 35 | shape_key = f'orig_shape_{k}' 36 | assert shape_key in data_dict.keys(), f'Did not find key {shape_key} in data_dict. Shitty. ' \ 37 | f'Convert2DTo3DTransform only works in tandem with ' \ 38 | f'Convert3DTo2DTransform and you probably forgot to add ' \ 39 | f'Convert3DTo2DTransform to your pipeline. (Convert3DTo2DTransform ' \ 40 | f'is where the missing key is generated)' 41 | original_shape = data_dict[shape_key] 42 | current_shape = data_dict[k].shape 43 | data_dict[k] = data_dict[k].reshape((original_shape[0], original_shape[1], original_shape[2], 44 | current_shape[-2], current_shape[-1])) 45 | return data_dict 46 | -------------------------------------------------------------------------------- /umamba/nnunetv2/training/dataloading/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/U-Mamba/28459e33ca03769800dd35e23c6e62491d1925b5/umamba/nnunetv2/training/dataloading/__init__.py -------------------------------------------------------------------------------- /umamba/nnunetv2/training/dataloading/data_loader_3d.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from nnunetv2.training.dataloading.base_data_loader import nnUNetDataLoaderBase 3 | from nnunetv2.training.dataloading.nnunet_dataset import nnUNetDataset 4 | 5 | 6 | class nnUNetDataLoader3D(nnUNetDataLoaderBase): 7 | def generate_train_batch(self): 8 | selected_keys = self.get_indices() 9 | # preallocate memory for data and seg 10 | data_all = np.zeros(self.data_shape, dtype=np.float32) 11 | seg_all = np.zeros(self.seg_shape, dtype=np.int16) 12 | case_properties = [] 13 | 14 | for j, i in enumerate(selected_keys): 15 | # oversampling foreground will improve stability of model training, especially if many patches are empty 16 | # (Lung for example) 17 | force_fg = self.get_do_oversample(j) 18 | 19 | data, seg, properties = self._data.load_case(i) 20 | case_properties.append(properties) 21 | 22 | # If we are doing the cascade then the segmentation from the previous stage will already have been loaded by 23 | # self._data.load_case(i) (see nnUNetDataset.load_case) 24 | shape = data.shape[1:] 25 | dim = len(shape) 26 | bbox_lbs, bbox_ubs = self.get_bbox(shape, force_fg, properties['class_locations']) 27 | 28 | # whoever wrote this knew what he was doing (hint: it was me). We first crop the data to the region of the 29 | # bbox that actually lies within the data. This will result in a smaller array which is then faster to pad. 30 | # valid_bbox is just the coords that lie within the data cube. It will be padded to match the patch size 31 | # later 32 | valid_bbox_lbs = [max(0, bbox_lbs[i]) for i in range(dim)] 33 | valid_bbox_ubs = [min(shape[i], bbox_ubs[i]) for i in range(dim)] 34 | 35 | # At this point you might ask yourself why we would treat seg differently from seg_from_previous_stage. 36 | # Why not just concatenate them here and forget about the if statements?
Well that's because seg needs to 37 | # be padded with -1 constant whereas seg_from_previous_stage needs to be padded with 0s (we could also 38 | # remove label -1 in the data augmentation but this way it is less error prone) 39 | this_slice = tuple([slice(0, data.shape[0])] + [slice(i, j) for i, j in zip(valid_bbox_lbs, valid_bbox_ubs)]) 40 | data = data[this_slice] 41 | 42 | this_slice = tuple([slice(0, seg.shape[0])] + [slice(i, j) for i, j in zip(valid_bbox_lbs, valid_bbox_ubs)]) 43 | seg = seg[this_slice] 44 | 45 | padding = [(-min(0, bbox_lbs[i]), max(bbox_ubs[i] - shape[i], 0)) for i in range(dim)] 46 | data_all[j] = np.pad(data, ((0, 0), *padding), 'constant', constant_values=0) 47 | seg_all[j] = np.pad(seg, ((0, 0), *padding), 'constant', constant_values=-1) 48 | 49 | return {'data': data_all, 'seg': seg_all, 'properties': case_properties, 'keys': selected_keys} 50 | 51 | 52 | if __name__ == '__main__': 53 | folder = '/media/fabian/data/nnUNet_preprocessed/Dataset002_Heart/3d_fullres' 54 | ds = nnUNetDataset(folder, 0) # this should not load the properties! 55 | dl = nnUNetDataLoader3D(ds, 5, (16, 16, 16), (16, 16, 16), 0.33, None, None) 56 | a = next(dl) 57 | -------------------------------------------------------------------------------- /umamba/nnunetv2/training/logging/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/U-Mamba/28459e33ca03769800dd35e23c6e62491d1925b5/umamba/nnunetv2/training/logging/__init__.py -------------------------------------------------------------------------------- /umamba/nnunetv2/training/loss/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/U-Mamba/28459e33ca03769800dd35e23c6e62491d1925b5/umamba/nnunetv2/training/loss/__init__.py -------------------------------------------------------------------------------- /umamba/nnunetv2/training/loss/deep_supervision.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class DeepSupervisionWrapper(nn.Module): 6 | def __init__(self, loss, weight_factors=None): 7 | """ 8 | Wraps a loss function so that it can be applied to multiple outputs. Forward accepts an arbitrary number of 9 | inputs. Each input is expected to be a tuple/list. Each tuple/list must have the same length. The loss is then 10 | applied to each entry like this: 11 | l = w0 * loss(input0[0], input1[0], ...) + w1 * loss(input0[1], input1[1], ...) + ... 12 | If weights are None, all w will be 1. 13 | """ 14 | super(DeepSupervisionWrapper, self).__init__() 15 | assert weight_factors is None or any([x != 0 for x in weight_factors]), "At least one weight factor should be != 0.0" 16 | self.weight_factors = tuple(weight_factors) if weight_factors is not None else None 17 | self.loss = loss 18 | 19 | def forward(self, *args): 20 | assert all([isinstance(i, (tuple, list)) for i in args]), \ 21 | f"all args must be either tuple or list, got {[type(i) for i in args]}" 22 | # we could check for equal lengths here as well, but we really shouldn't overdo it with checks because 23 | # this code is executed a lot of times!
24 | 25 | if self.weight_factors is None: 26 | weights = (1, ) * len(args[0]) 27 | else: 28 | weights = self.weight_factors 29 | 30 | return sum([weights[i] * self.loss(*inputs) for i, inputs in enumerate(zip(*args)) if weights[i] != 0.0]) 31 | -------------------------------------------------------------------------------- /umamba/nnunetv2/training/loss/robust_ce_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, Tensor 3 | import numpy as np 4 | 5 | 6 | class RobustCrossEntropyLoss(nn.CrossEntropyLoss): 7 | """ 8 | this is just a compatibility layer because my target tensor is float and has an extra dimension 9 | 10 | input must be logits, not probabilities! 11 | """ 12 | def forward(self, input: Tensor, target: Tensor) -> Tensor: 13 | if target.ndim == input.ndim: 14 | assert target.shape[1] == 1 15 | target = target[:, 0] 16 | return super().forward(input, target.long()) 17 | 18 | 19 | class TopKLoss(RobustCrossEntropyLoss): 20 | """ 21 | input must be logits, not probabilities! 22 | """ 23 | def __init__(self, weight=None, ignore_index: int = -100, k: float = 10, label_smoothing: float = 0): 24 | self.k = k 25 | super(TopKLoss, self).__init__(weight, False, ignore_index, reduce=False, label_smoothing=label_smoothing) 26 | 27 | def forward(self, inp, target): 28 | target = target[:, 0].long() 29 | res = super(TopKLoss, self).forward(inp, target) 30 | num_voxels = np.prod(res.shape, dtype=np.int64) 31 | res, _ = torch.topk(res.view((-1, )), int(num_voxels * self.k / 100), sorted=False) 32 | return res.mean() 33 | -------------------------------------------------------------------------------- /umamba/nnunetv2/training/lr_scheduler/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/U-Mamba/28459e33ca03769800dd35e23c6e62491d1925b5/umamba/nnunetv2/training/lr_scheduler/__init__.py -------------------------------------------------------------------------------- /umamba/nnunetv2/training/lr_scheduler/polylr.py: -------------------------------------------------------------------------------- 1 | from torch.optim.lr_scheduler import _LRScheduler 2 | 3 | 4 | class PolyLRScheduler(_LRScheduler): 5 | def __init__(self, optimizer, initial_lr: float, max_steps: int, exponent: float = 0.9, current_step: int = None): 6 | self.optimizer = optimizer 7 | self.initial_lr = initial_lr 8 | self.max_steps = max_steps 9 | self.exponent = exponent 10 | self.ctr = 0 11 | super().__init__(optimizer, current_step if current_step is not None else -1, False) 12 | 13 | def step(self, current_step=None): 14 | if current_step is None or current_step == -1: 15 | current_step = self.ctr 16 | self.ctr += 1 17 | 18 | new_lr = self.initial_lr * (1 - current_step / self.max_steps) ** self.exponent 19 | for param_group in self.optimizer.param_groups: 20 | param_group['lr'] = new_lr 21 | -------------------------------------------------------------------------------- /umamba/nnunetv2/training/nnUNetTrainer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/U-Mamba/28459e33ca03769800dd35e23c6e62491d1925b5/umamba/nnunetv2/training/nnUNetTrainer/__init__.py -------------------------------------------------------------------------------- /umamba/nnunetv2/training/nnUNetTrainer/nnUNetTrainerUMambaBot.py: 
-------------------------------------------------------------------------------- 1 | from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer 2 | from nnunetv2.utilities.plans_handling.plans_handler import ConfigurationManager, PlansManager 3 | from torch import nn 4 | from nnunetv2.nets.UMambaBot_3d import get_umamba_bot_3d_from_plans 5 | from nnunetv2.nets.UMambaBot_2d import get_umamba_bot_2d_from_plans 6 | 7 | 8 | class nnUNetTrainerUMambaBot(nnUNetTrainer): 9 | @staticmethod 10 | def build_network_architecture(plans_manager: PlansManager, 11 | dataset_json, 12 | configuration_manager: ConfigurationManager, 13 | num_input_channels, 14 | enable_deep_supervision: bool = True) -> nn.Module: 15 | 16 | if len(configuration_manager.patch_size) == 2: 17 | model = get_umamba_bot_2d_from_plans(plans_manager, dataset_json, configuration_manager, 18 | num_input_channels, deep_supervision=enable_deep_supervision) 19 | elif len(configuration_manager.patch_size) == 3: 20 | model = get_umamba_bot_3d_from_plans(plans_manager, dataset_json, configuration_manager, 21 | num_input_channels, deep_supervision=enable_deep_supervision) 22 | else: 23 | raise NotImplementedError("Only 2D and 3D models are supported") 24 | 25 | print("UMambaBot: {}".format(model)) 26 | 27 | return model 28 | -------------------------------------------------------------------------------- /umamba/nnunetv2/training/nnUNetTrainer/nnUNetTrainerUMambaEnc.py: -------------------------------------------------------------------------------- 1 | from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer 2 | from nnunetv2.utilities.plans_handling.plans_handler import ConfigurationManager, PlansManager 3 | from torch import nn 4 | from nnunetv2.nets.UMambaEnc_3d import get_umamba_enc_3d_from_plans 5 | from nnunetv2.nets.UMambaEnc_2d import get_umamba_enc_2d_from_plans 6 | 7 | class nnUNetTrainerUMambaEnc(nnUNetTrainer): 8 | @staticmethod 9 | def build_network_architecture(plans_manager: PlansManager, 10 | dataset_json, 11 | configuration_manager: ConfigurationManager, 12 | num_input_channels, 13 | enable_deep_supervision: bool = True) -> nn.Module: 14 | 15 | if len(configuration_manager.patch_size) == 2: 16 | model = get_umamba_enc_2d_from_plans(plans_manager, dataset_json, configuration_manager, 17 | num_input_channels, deep_supervision=enable_deep_supervision) 18 | elif len(configuration_manager.patch_size) == 3: 19 | model = get_umamba_enc_3d_from_plans(plans_manager, dataset_json, configuration_manager, 20 | num_input_channels, deep_supervision=enable_deep_supervision) 21 | else: 22 | raise NotImplementedError("Only 2D and 3D models are supported") 23 | 24 | 25 | print("UMambaEnc: {}".format(model)) 26 | 27 | return model 28 | -------------------------------------------------------------------------------- /umamba/nnunetv2/training/nnUNetTrainer/variants/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/U-Mamba/28459e33ca03769800dd35e23c6e62491d1925b5/umamba/nnunetv2/training/nnUNetTrainer/variants/__init__.py -------------------------------------------------------------------------------- /umamba/nnunetv2/training/nnUNetTrainer/variants/benchmarking/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/U-Mamba/28459e33ca03769800dd35e23c6e62491d1925b5/umamba/nnunetv2/training/nnUNetTrainer/variants/benchmarking/__init__.py 
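The two trainer variants above are how U-Mamba hooks its networks into the nnU-Net pipeline: each only overrides `build_network_architecture` and dispatches to the 2D or 3D network builder based on the length of the configured patch size. As a usage sketch (the dataset id and configuration are placeholders; the `-tr` selection pattern is the same one used by the integration-test scripts earlier in this repo):

```commandline
nnUNetv2_train DATASET_ID 3d_fullres 0 -tr nnUNetTrainerUMambaBot
nnUNetv2_train DATASET_ID 3d_fullres 0 -tr nnUNetTrainerUMambaEnc
```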
-------------------------------------------------------------------------------- /umamba/nnunetv2/training/nnUNetTrainer/variants/benchmarking/nnUNetTrainerBenchmark_5epochs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from batchgenerators.utilities.file_and_folder_operations import save_json, join, isfile, load_json 3 | 4 | from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer 5 | from torch import distributed as dist 6 | 7 | 8 | class nnUNetTrainerBenchmark_5epochs(nnUNetTrainer): 9 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, 10 | device: torch.device = torch.device('cuda')): 11 | super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) 12 | assert self.fold == 0, "It makes absolutely no sense to specify a certain fold. Stick with 0 so that we can parse the results." 13 | self.disable_checkpointing = True 14 | self.num_epochs = 5 15 | assert torch.cuda.is_available(), "This only works on GPU" 16 | self.crashed_with_runtime_error = False 17 | 18 | def perform_actual_validation(self, save_probabilities: bool = False): 19 | pass 20 | 21 | def save_checkpoint(self, filename: str) -> None: 22 | # do not trust people to remember that self.disable_checkpointing must be True for this trainer 23 | pass 24 | 25 | def run_training(self): 26 | try: 27 | super().run_training() 28 | except RuntimeError: 29 | self.crashed_with_runtime_error = True 30 | 31 | def on_train_end(self): 32 | super().on_train_end() 33 | 34 | if not self.is_ddp or self.local_rank == 0: 35 | torch_version = torch.__version__ 36 | cudnn_version = torch.backends.cudnn.version() 37 | gpu_name = torch.cuda.get_device_name() 38 | if self.crashed_with_runtime_error: 39 | fastest_epoch = 'Not enough VRAM!' 
40 | else: 41 | epoch_times = [i - j for i, j in zip(self.logger.my_fantastic_logging['epoch_end_timestamps'], 42 | self.logger.my_fantastic_logging['epoch_start_timestamps'])] 43 | fastest_epoch = min(epoch_times) 44 | 45 | if self.is_ddp: 46 | num_gpus = dist.get_world_size() 47 | else: 48 | num_gpus = 1 49 | 50 | benchmark_result_file = join(self.output_folder, 'benchmark_result.json') 51 | if isfile(benchmark_result_file): 52 | old_results = load_json(benchmark_result_file) 53 | else: 54 | old_results = {} 55 | # generate some unique key 56 | my_key = f"{cudnn_version}__{torch_version.replace(' ', '')}__{gpu_name.replace(' ', '')}__gpus_{num_gpus}" 57 | old_results[my_key] = { 58 | 'torch_version': torch_version, 59 | 'cudnn_version': cudnn_version, 60 | 'gpu_name': gpu_name, 61 | 'fastest_epoch': fastest_epoch, 62 | 'num_gpus': num_gpus, 63 | } 64 | save_json(old_results, 65 | benchmark_result_file) 66 | -------------------------------------------------------------------------------- /umamba/nnunetv2/training/nnUNetTrainer/variants/benchmarking/nnUNetTrainerBenchmark_5epochs_noDataLoading.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from nnunetv2.training.nnUNetTrainer.variants.benchmarking.nnUNetTrainerBenchmark_5epochs import ( 4 | nnUNetTrainerBenchmark_5epochs, 5 | ) 6 | from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels 7 | 8 | 9 | class nnUNetTrainerBenchmark_5epochs_noDataLoading(nnUNetTrainerBenchmark_5epochs): 10 | def __init__( 11 | self, 12 | plans: dict, 13 | configuration: str, 14 | fold: int, 15 | dataset_json: dict, 16 | unpack_dataset: bool = True, 17 | device: torch.device = torch.device("cuda"), 18 | ): 19 | super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) 20 | self._set_batch_size_and_oversample() 21 | num_input_channels = determine_num_input_channels( 22 | self.plans_manager, self.configuration_manager, self.dataset_json 23 | ) 24 | patch_size = self.configuration_manager.patch_size 25 | dummy_data = torch.rand((self.batch_size, num_input_channels, *patch_size), device=self.device) 26 | if self.enable_deep_supervision: 27 | dummy_target = [ 28 | torch.round( 29 | torch.rand((self.batch_size, 1, *[int(i * j) for i, j in zip(patch_size, k)]), device=self.device) 30 | * max(self.label_manager.all_labels) 31 | ) 32 | for k in self._get_deep_supervision_scales() 33 | ] 34 | else: 35 | raise NotImplementedError("This trainer only supports running with deep supervision enabled") 36 | self.dummy_batch = {"data": dummy_data, "target": dummy_target} 37 | 38 | def get_dataloaders(self): 39 | return None, None 40 | 41 | def run_training(self): 42 | try: 43 | self.on_train_start() 44 | 45 | for epoch in range(self.current_epoch, self.num_epochs): 46 | self.on_epoch_start() 47 | 48 | self.on_train_epoch_start() 49 | train_outputs = [] 50 | for batch_id in range(self.num_iterations_per_epoch): 51 | train_outputs.append(self.train_step(self.dummy_batch)) 52 | self.on_train_epoch_end(train_outputs) 53 | 54 | with torch.no_grad(): 55 | self.on_validation_epoch_start() 56 | val_outputs = [] 57 | for batch_id in range(self.num_val_iterations_per_epoch): 58 | val_outputs.append(self.validation_step(self.dummy_batch)) 59 | self.on_validation_epoch_end(val_outputs) 60 | 61 | self.on_epoch_end() 62 | 63 | self.on_train_end() 64 | except RuntimeError: 65 | self.crashed_with_runtime_error = True 66 |
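Like the U-Mamba trainers, the benchmark variants above are selected with the `-tr` flag. A usage sketch (the dataset id is a placeholder; the fold must be 0, as asserted in `nnUNetTrainerBenchmark_5epochs.__init__`):

```commandline
nnUNetv2_train DATASET_ID 3d_fullres 0 -tr nnUNetTrainerBenchmark_5epochs
nnUNetv2_train DATASET_ID 3d_fullres 0 -tr nnUNetTrainerBenchmark_5epochs_noDataLoading
```

Both write their timings to `benchmark_result.json` in the trainer's output folder; the `noDataLoading` variant trains on a fixed dummy batch, so the difference between the two runs isolates the data-loading overhead.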
-------------------------------------------------------------------------------- /umamba/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/U-Mamba/28459e33ca03769800dd35e23c6e62491d1925b5/umamba/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/__init__.py -------------------------------------------------------------------------------- /umamba/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerNoDA.py: -------------------------------------------------------------------------------- 1 | from typing import Union, Tuple, List 2 | 3 | from batchgenerators.transforms.abstract_transforms import AbstractTransform 4 | 5 | from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer 6 | import numpy as np 7 | 8 | 9 | class nnUNetTrainerNoDA(nnUNetTrainer): 10 | @staticmethod 11 | def get_training_transforms(patch_size: Union[np.ndarray, Tuple[int]], 12 | rotation_for_DA: dict, 13 | deep_supervision_scales: Union[List, Tuple, None], 14 | mirror_axes: Tuple[int, ...], 15 | do_dummy_2d_data_aug: bool, 16 | order_resampling_data: int = 1, 17 | order_resampling_seg: int = 0, 18 | border_val_seg: int = -1, 19 | use_mask_for_norm: List[bool] = None, 20 | is_cascaded: bool = False, 21 | foreground_labels: Union[Tuple[int, ...], List[int]] = None, 22 | regions: List[Union[List[int], Tuple[int, ...], int]] = None, 23 | ignore_label: int = None) -> AbstractTransform: 24 | return nnUNetTrainer.get_validation_transforms(deep_supervision_scales, is_cascaded, foreground_labels, 25 | regions, ignore_label) 26 | 27 | def get_plain_dataloaders(self, initial_patch_size: Tuple[int, ...], dim: int): 28 | return super().get_plain_dataloaders( 29 | initial_patch_size=self.configuration_manager.patch_size, 30 | dim=dim 31 | ) 32 | 33 | def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self): 34 | # we need to disable mirroring here so that no mirroring will be applied in inference!
35 | rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \ 36 | super().configure_rotation_dummyDA_mirroring_and_inital_patch_size() 37 | mirror_axes = None 38 | self.inference_allowed_mirroring_axes = None 39 | return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes 40 | 41 | -------------------------------------------------------------------------------- /umamba/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerNoMirroring.py: -------------------------------------------------------------------------------- 1 | from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer 2 | 3 | 4 | class nnUNetTrainerNoMirroring(nnUNetTrainer): 5 | def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self): 6 | rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \ 7 | super().configure_rotation_dummyDA_mirroring_and_inital_patch_size() 8 | mirror_axes = None 9 | self.inference_allowed_mirroring_axes = None 10 | return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes 11 | 12 | 13 | class nnUNetTrainer_onlyMirror01(nnUNetTrainer): 14 | """ 15 | Only mirrors along spatial axes 0 and 1 for 3D and 0 for 2D 16 | """ 17 | def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self): 18 | rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \ 19 | super().configure_rotation_dummyDA_mirroring_and_inital_patch_size() 20 | patch_size = self.configuration_manager.patch_size 21 | dim = len(patch_size) 22 | if dim == 2: 23 | mirror_axes = (0, ) 24 | else: 25 | mirror_axes = (0, 1) 26 | self.inference_allowed_mirroring_axes = mirror_axes 27 | return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes 28 | 29 | -------------------------------------------------------------------------------- /umamba/nnunetv2/training/nnUNetTrainer/variants/loss/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/U-Mamba/28459e33ca03769800dd35e23c6e62491d1925b5/umamba/nnunetv2/training/nnUNetTrainer/variants/loss/__init__.py -------------------------------------------------------------------------------- /umamba/nnunetv2/training/nnUNetTrainer/variants/loss/nnUNetTrainerCELoss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from nnunetv2.training.loss.deep_supervision import DeepSupervisionWrapper 3 | from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer 4 | from nnunetv2.training.loss.robust_ce_loss import RobustCrossEntropyLoss 5 | import numpy as np 6 | 7 | 8 | class nnUNetTrainerCELoss(nnUNetTrainer): 9 | def _build_loss(self): 10 | assert not self.label_manager.has_regions, "regions not supported by this trainer" 11 | loss = RobustCrossEntropyLoss( 12 | weight=None, ignore_index=self.label_manager.ignore_label if self.label_manager.has_ignore_label else -100 13 | ) 14 | 15 | # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases 16 | # this gives higher resolution outputs more weight in the loss 17 | if self.enable_deep_supervision: 18 | deep_supervision_scales = self._get_deep_supervision_scales() 19 | weights = np.array([1 / (2**i) for i in range(len(deep_supervision_scales))]) 20 | weights[-1] = 0 21 | 22 | # we don't use the lowest 2 outputs. 
23 | weights = weights / weights.sum() 24 | # now wrap the loss 25 | loss = DeepSupervisionWrapper(loss, weights) 26 | return loss 27 | 28 | 29 | class nnUNetTrainerCELoss_5epochs(nnUNetTrainerCELoss): 30 | def __init__( 31 | self, 32 | plans: dict, 33 | configuration: str, 34 | fold: int, 35 | dataset_json: dict, 36 | unpack_dataset: bool = True, 37 | device: torch.device = torch.device("cuda"), 38 | ): 39 | """used for debugging plans etc""" 40 | super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) 41 | self.num_epochs = 5 42 | -------------------------------------------------------------------------------- /umamba/nnunetv2/training/nnUNetTrainer/variants/loss/nnUNetTrainerDiceLoss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from nnunetv2.training.loss.compound_losses import DC_and_BCE_loss, DC_and_CE_loss 5 | from nnunetv2.training.loss.deep_supervision import DeepSupervisionWrapper 6 | from nnunetv2.training.loss.dice import MemoryEfficientSoftDiceLoss 7 | from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer 8 | from nnunetv2.utilities.helpers import softmax_helper_dim1 9 | 10 | 11 | class nnUNetTrainerDiceLoss(nnUNetTrainer): 12 | def _build_loss(self): 13 | loss = MemoryEfficientSoftDiceLoss(**{'batch_dice': self.configuration_manager.batch_dice, 14 | 'do_bg': self.label_manager.has_regions, 'smooth': 1e-5, 'ddp': self.is_ddp}, 15 | apply_nonlin=torch.sigmoid if self.label_manager.has_regions else softmax_helper_dim1) 16 | 17 | if self.enable_deep_supervision: 18 | deep_supervision_scales = self._get_deep_supervision_scales() 19 | 20 | # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases 21 | # this gives higher resolution outputs more weight in the loss 22 | weights = np.array([1 / (2 ** i) for i in range(len(deep_supervision_scales))]) 23 | weights[-1] = 0 24 | 25 | # we don't use the lowest output. Normalize weights so that they sum to 1 26 | weights = weights / weights.sum() 27 | # now wrap the loss 28 | loss = DeepSupervisionWrapper(loss, weights) 29 | return loss 30 | 31 | 32 | class nnUNetTrainerDiceCELoss_noSmooth(nnUNetTrainer): 33 | def _build_loss(self): 34 | # set smooth to 0 35 | if self.label_manager.has_regions: 36 | loss = DC_and_BCE_loss({}, 37 | {'batch_dice': self.configuration_manager.batch_dice, 38 | 'do_bg': True, 'smooth': 0, 'ddp': self.is_ddp}, 39 | use_ignore_label=self.label_manager.ignore_label is not None, 40 | dice_class=MemoryEfficientSoftDiceLoss) 41 | else: 42 | loss = DC_and_CE_loss({'batch_dice': self.configuration_manager.batch_dice, 43 | 'smooth': 0, 'do_bg': False, 'ddp': self.is_ddp}, {}, weight_ce=1, weight_dice=1, 44 | ignore_label=self.label_manager.ignore_label, 45 | dice_class=MemoryEfficientSoftDiceLoss) 46 | 47 | if self.enable_deep_supervision: 48 | deep_supervision_scales = self._get_deep_supervision_scales() 49 | 50 | # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases 51 | # this gives higher resolution outputs more weight in the loss 52 | weights = np.array([1 / (2 ** i) for i in range(len(deep_supervision_scales))]) 53 | weights[-1] = 0 54 | 55 | # we don't use the lowest output.
Normalize weights so that they sum to 1 56 | weights = weights / weights.sum() 57 | # now wrap the loss 58 | loss = DeepSupervisionWrapper(loss, weights) 59 | return loss 60 | 61 | -------------------------------------------------------------------------------- /umamba/nnunetv2/training/nnUNetTrainer/variants/loss/nnUNetTrainerTopkLoss.py: -------------------------------------------------------------------------------- 1 | from nnunetv2.training.loss.compound_losses import DC_and_topk_loss 2 | from nnunetv2.training.loss.deep_supervision import DeepSupervisionWrapper 3 | from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer 4 | import numpy as np 5 | from nnunetv2.training.loss.robust_ce_loss import TopKLoss 6 | 7 | 8 | class nnUNetTrainerTopk10Loss(nnUNetTrainer): 9 | def _build_loss(self): 10 | assert not self.label_manager.has_regions, "regions not supported by this trainer" 11 | loss = TopKLoss( 12 | ignore_index=self.label_manager.ignore_label if self.label_manager.has_ignore_label else -100, k=10 13 | ) 14 | 15 | if self.enable_deep_supervision: 16 | deep_supervision_scales = self._get_deep_supervision_scales() 17 | 18 | # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases 19 | # this gives higher resolution outputs more weight in the loss 20 | weights = np.array([1 / (2**i) for i in range(len(deep_supervision_scales))]) 21 | weights[-1] = 0 22 | 23 | # we don't use the lowest output. Normalize weights so that they sum to 1 24 | weights = weights / weights.sum() 25 | # now wrap the loss 26 | loss = DeepSupervisionWrapper(loss, weights) 27 | return loss 28 | 29 | 30 | class nnUNetTrainerTopk10LossLS01(nnUNetTrainer): 31 | def _build_loss(self): 32 | assert not self.label_manager.has_regions, "regions not supported by this trainer" 33 | loss = TopKLoss( 34 | ignore_index=self.label_manager.ignore_label if self.label_manager.has_ignore_label else -100, 35 | k=10, 36 | label_smoothing=0.1, 37 | ) 38 | 39 | if self.enable_deep_supervision: 40 | deep_supervision_scales = self._get_deep_supervision_scales() 41 | 42 | # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases 43 | # this gives higher resolution outputs more weight in the loss 44 | weights = np.array([1 / (2**i) for i in range(len(deep_supervision_scales))]) 45 | weights[-1] = 0 46 | 47 | # we don't use the lowest output. Normalize weights so that they sum to 1 48 | weights = weights / weights.sum() 49 | # now wrap the loss 50 | loss = DeepSupervisionWrapper(loss, weights) 51 | return loss 52 | 53 | 54 | class nnUNetTrainerDiceTopK10Loss(nnUNetTrainer): 55 | def _build_loss(self): 56 | assert not self.label_manager.has_regions, "regions not supported by this trainer" 57 | loss = DC_and_topk_loss( 58 | {"batch_dice": self.configuration_manager.batch_dice, "smooth": 1e-5, "do_bg": False, "ddp": self.is_ddp}, 59 | {"k": 10, "label_smoothing": 0.0}, 60 | weight_ce=1, 61 | weight_dice=1, 62 | ignore_label=self.label_manager.ignore_label, 63 | ) 64 | if self.enable_deep_supervision: 65 | deep_supervision_scales = self._get_deep_supervision_scales() 66 | 67 | # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases 68 | # this gives higher resolution outputs more weight in the loss 69 | weights = np.array([1 / (2**i) for i in range(len(deep_supervision_scales))]) 70 | weights[-1] = 0 71 | 72 | # we don't use the lowest output.
Normalize weights so that they sum to 1 73 | weights = weights / weights.sum() 74 | # now wrap the loss 75 | loss = DeepSupervisionWrapper(loss, weights) 76 | return loss 77 | -------------------------------------------------------------------------------- /umamba/nnunetv2/training/nnUNetTrainer/variants/lr_schedule/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/U-Mamba/28459e33ca03769800dd35e23c6e62491d1925b5/umamba/nnunetv2/training/nnUNetTrainer/variants/lr_schedule/__init__.py -------------------------------------------------------------------------------- /umamba/nnunetv2/training/nnUNetTrainer/variants/lr_schedule/nnUNetTrainerCosAnneal.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim.lr_scheduler import CosineAnnealingLR 3 | 4 | from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer 5 | 6 | 7 | class nnUNetTrainerCosAnneal(nnUNetTrainer): 8 | def configure_optimizers(self): 9 | optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay, 10 | momentum=0.99, nesterov=True) 11 | lr_scheduler = CosineAnnealingLR(optimizer, T_max=self.num_epochs) 12 | return optimizer, lr_scheduler 13 | 14 | -------------------------------------------------------------------------------- /umamba/nnunetv2/training/nnUNetTrainer/variants/network_architecture/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/U-Mamba/28459e33ca03769800dd35e23c6e62491d1925b5/umamba/nnunetv2/training/nnUNetTrainer/variants/network_architecture/__init__.py -------------------------------------------------------------------------------- /umamba/nnunetv2/training/nnUNetTrainer/variants/network_architecture/nnUNetTrainerBN.py: -------------------------------------------------------------------------------- 1 | from dynamic_network_architectures.architectures.unet import ResidualEncoderUNet, PlainConvUNet 2 | from dynamic_network_architectures.building_blocks.helper import convert_dim_to_conv_op, get_matching_batchnorm 3 | from dynamic_network_architectures.initialization.weight_init import init_last_bn_before_add_to_0, InitWeights_He 4 | from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer 5 | from nnunetv2.utilities.plans_handling.plans_handler import ConfigurationManager, PlansManager 6 | from torch import nn 7 | 8 | 9 | class nnUNetTrainerBN(nnUNetTrainer): 10 | @staticmethod 11 | def build_network_architecture(plans_manager: PlansManager, 12 | dataset_json, 13 | configuration_manager: ConfigurationManager, 14 | num_input_channels, 15 | enable_deep_supervision: bool = True) -> nn.Module: 16 | num_stages = len(configuration_manager.conv_kernel_sizes) 17 | 18 | dim = len(configuration_manager.conv_kernel_sizes[0]) 19 | conv_op = convert_dim_to_conv_op(dim) 20 | 21 | label_manager = plans_manager.get_label_manager(dataset_json) 22 | 23 | segmentation_network_class_name = configuration_manager.UNet_class_name 24 | mapping = { 25 | 'PlainConvUNet': PlainConvUNet, 26 | 'ResidualEncoderUNet': ResidualEncoderUNet 27 | } 28 | kwargs = { 29 | 'PlainConvUNet': { 30 | 'conv_bias': True, 31 | 'norm_op': get_matching_batchnorm(conv_op), 32 | 'norm_op_kwargs': {'eps': 1e-5, 'affine': True}, 33 | 'dropout_op': None, 'dropout_op_kwargs': None, 34 | 'nonlin': nn.LeakyReLU, 'nonlin_kwargs': {'inplace': True}, 35 | }, 
36 | 'ResidualEncoderUNet': { 37 | 'conv_bias': True, 38 | 'norm_op': get_matching_batchnorm(conv_op), 39 | 'norm_op_kwargs': {'eps': 1e-5, 'affine': True}, 40 | 'dropout_op': None, 'dropout_op_kwargs': None, 41 | 'nonlin': nn.LeakyReLU, 'nonlin_kwargs': {'inplace': True}, 42 | } 43 | } 44 | assert segmentation_network_class_name in mapping.keys(), 'The network architecture specified by the plans file ' \ 45 | 'is non-standard (maybe your own?). You\'ll have to dive ' \ 46 | 'into either this ' \ 47 | 'function (get_network_from_plans) or ' \ 48 | 'the init of your nnUNetModule to accommodate that.' 49 | network_class = mapping[segmentation_network_class_name] 50 | 51 | conv_or_blocks_per_stage = { 52 | 'n_conv_per_stage' 53 | if network_class != ResidualEncoderUNet else 'n_blocks_per_stage': configuration_manager.n_conv_per_stage_encoder, 54 | 'n_conv_per_stage_decoder': configuration_manager.n_conv_per_stage_decoder 55 | } 56 | # network class name!! 57 | model = network_class( 58 | input_channels=num_input_channels, 59 | n_stages=num_stages, 60 | features_per_stage=[min(configuration_manager.UNet_base_num_features * 2 ** i, 61 | configuration_manager.unet_max_num_features) for i in range(num_stages)], 62 | conv_op=conv_op, 63 | kernel_sizes=configuration_manager.conv_kernel_sizes, 64 | strides=configuration_manager.pool_op_kernel_sizes, 65 | num_classes=label_manager.num_segmentation_heads, 66 | deep_supervision=enable_deep_supervision, 67 | **conv_or_blocks_per_stage, 68 | **kwargs[segmentation_network_class_name] 69 | ) 70 | model.apply(InitWeights_He(1e-2)) 71 | if network_class == ResidualEncoderUNet: 72 | model.apply(init_last_bn_before_add_to_0) 73 | return model 74 | -------------------------------------------------------------------------------- /umamba/nnunetv2/training/nnUNetTrainer/variants/network_architecture/nnUNetTrainerNoDeepSupervision.py: -------------------------------------------------------------------------------- 1 | from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer 2 | import torch 3 | 4 | 5 | class nnUNetTrainerNoDeepSupervision(nnUNetTrainer): 6 | def __init__( 7 | self, 8 | plans: dict, 9 | configuration: str, 10 | fold: int, 11 | dataset_json: dict, 12 | unpack_dataset: bool = True, 13 | device: torch.device = torch.device("cuda"), 14 | ): 15 | super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) 16 | self.enable_deep_supervision = False 17 | -------------------------------------------------------------------------------- /umamba/nnunetv2/training/nnUNetTrainer/variants/optimizer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/U-Mamba/28459e33ca03769800dd35e23c6e62491d1925b5/umamba/nnunetv2/training/nnUNetTrainer/variants/optimizer/__init__.py -------------------------------------------------------------------------------- /umamba/nnunetv2/training/nnUNetTrainer/variants/optimizer/nnUNetTrainerAdam.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim import Adam, AdamW 3 | 4 | from nnunetv2.training.lr_scheduler.polylr import PolyLRScheduler 5 | from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer 6 | 7 | 8 | class nnUNetTrainerAdam(nnUNetTrainer): 9 | def configure_optimizers(self): 10 | optimizer = AdamW(self.network.parameters(), 11 | lr=self.initial_lr, 12 | weight_decay=self.weight_decay, 13 | amsgrad=True) 14 | # optimizer =
torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay, 15 | # momentum=0.99, nesterov=True) 16 | lr_scheduler = PolyLRScheduler(optimizer, self.initial_lr, self.num_epochs) 17 | return optimizer, lr_scheduler 18 | 19 | 20 | class nnUNetTrainerVanillaAdam(nnUNetTrainer): 21 | def configure_optimizers(self): 22 | optimizer = Adam(self.network.parameters(), 23 | lr=self.initial_lr, 24 | weight_decay=self.weight_decay) 25 | # optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay, 26 | # momentum=0.99, nesterov=True) 27 | lr_scheduler = PolyLRScheduler(optimizer, self.initial_lr, self.num_epochs) 28 | return optimizer, lr_scheduler 29 | 30 | 31 | class nnUNetTrainerVanillaAdam1en3(nnUNetTrainerVanillaAdam): 32 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, 33 | device: torch.device = torch.device('cuda')): 34 | super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) 35 | self.initial_lr = 1e-3 36 | 37 | 38 | class nnUNetTrainerVanillaAdam3en4(nnUNetTrainerVanillaAdam): 39 | # https://twitter.com/karpathy/status/801621764144971776?lang=en 40 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, 41 | device: torch.device = torch.device('cuda')): 42 | super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) 43 | self.initial_lr = 3e-4 44 | 45 | 46 | class nnUNetTrainerAdam1en3(nnUNetTrainerAdam): 47 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, 48 | device: torch.device = torch.device('cuda')): 49 | super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) 50 | self.initial_lr = 1e-3 51 | 52 | 53 | class nnUNetTrainerAdam3en4(nnUNetTrainerAdam): 54 | # https://twitter.com/karpathy/status/801621764144971776?lang=en 55 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, 56 | device: torch.device = torch.device('cuda')): 57 | super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) 58 | self.initial_lr = 3e-4 59 | -------------------------------------------------------------------------------- /umamba/nnunetv2/training/nnUNetTrainer/variants/optimizer/nnUNetTrainerAdan.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from nnunetv2.training.lr_scheduler.polylr import PolyLRScheduler 4 | from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer 5 | from torch.optim.lr_scheduler import CosineAnnealingLR 6 | try: 7 | from adan_pytorch import Adan 8 | except ImportError: 9 | Adan = None 10 | 11 | 12 | class nnUNetTrainerAdan(nnUNetTrainer): 13 | def configure_optimizers(self): 14 | if Adan is None: 15 | raise RuntimeError('This trainer requires adan_pytorch to be installed, install with "pip install adan-pytorch"') 16 | optimizer = Adan(self.network.parameters(), 17 | lr=self.initial_lr, 18 | # betas=(0.02, 0.08, 0.01), defaults 19 | weight_decay=self.weight_decay) 20 | # optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay, 21 | # momentum=0.99, nesterov=True) 22 | lr_scheduler = PolyLRScheduler(optimizer, self.initial_lr, self.num_epochs) 23 | return optimizer, lr_scheduler 24 | 25 | 26 | class nnUNetTrainerAdan1en3(nnUNetTrainerAdan): 27 | 
def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, 28 | device: torch.device = torch.device('cuda')): 29 | super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) 30 | self.initial_lr = 1e-3 31 | 32 | 33 | class nnUNetTrainerAdan3en4(nnUNetTrainerAdan): 34 | # https://twitter.com/karpathy/status/801621764144971776?lang=en 35 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, 36 | device: torch.device = torch.device('cuda')): 37 | super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) 38 | self.initial_lr = 3e-4 39 | 40 | 41 | class nnUNetTrainerAdan1en1(nnUNetTrainerAdan): 42 | # this trainer makes no sense -> nan! 43 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, 44 | device: torch.device = torch.device('cuda')): 45 | super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) 46 | self.initial_lr = 1e-1 47 | 48 | 49 | class nnUNetTrainerAdanCosAnneal(nnUNetTrainerAdan): 50 | # def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, 51 | # device: torch.device = torch.device('cuda')): 52 | # super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) 53 | # self.num_epochs = 15 54 | 55 | def configure_optimizers(self): 56 | if Adan is None: 57 | raise RuntimeError('This trainer requires adan_pytorch to be installed, install with "pip install adan-pytorch"') 58 | optimizer = Adan(self.network.parameters(), 59 | lr=self.initial_lr, 60 | # betas=(0.02, 0.08, 0.01), defaults 61 | weight_decay=self.weight_decay) 62 | # optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay, 63 | # momentum=0.99, nesterov=True) 64 | lr_scheduler = CosineAnnealingLR(optimizer, T_max=self.num_epochs) 65 | return optimizer, lr_scheduler 66 | 67 | -------------------------------------------------------------------------------- /umamba/nnunetv2/training/nnUNetTrainer/variants/sampling/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/U-Mamba/28459e33ca03769800dd35e23c6e62491d1925b5/umamba/nnunetv2/training/nnUNetTrainer/variants/sampling/__init__.py -------------------------------------------------------------------------------- /umamba/nnunetv2/training/nnUNetTrainer/variants/training_length/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/U-Mamba/28459e33ca03769800dd35e23c6e62491d1925b5/umamba/nnunetv2/training/nnUNetTrainer/variants/training_length/__init__.py -------------------------------------------------------------------------------- /umamba/nnunetv2/training/nnUNetTrainer/variants/training_length/nnUNetTrainer_Xepochs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer 4 | 5 | 6 | class nnUNetTrainer_5epochs(nnUNetTrainer): 7 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, 8 | device: torch.device = torch.device('cuda')): 9 | """used for debugging plans etc""" 10 | super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) 11 
| self.num_epochs = 5 12 | 13 | 14 | class nnUNetTrainer_1epoch(nnUNetTrainer): 15 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, 16 | device: torch.device = torch.device('cuda')): 17 | """used for debugging plans etc""" 18 | super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) 19 | self.num_epochs = 1 20 | 21 | 22 | class nnUNetTrainer_10epochs(nnUNetTrainer): 23 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, 24 | device: torch.device = torch.device('cuda')): 25 | """used for debugging plans etc""" 26 | super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) 27 | self.num_epochs = 10 28 | 29 | 30 | class nnUNetTrainer_20epochs(nnUNetTrainer): 31 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, 32 | device: torch.device = torch.device('cuda')): 33 | super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) 34 | self.num_epochs = 20 35 | 36 | 37 | class nnUNetTrainer_50epochs(nnUNetTrainer): 38 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, 39 | device: torch.device = torch.device('cuda')): 40 | super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) 41 | self.num_epochs = 50 42 | 43 | 44 | class nnUNetTrainer_100epochs(nnUNetTrainer): 45 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, 46 | device: torch.device = torch.device('cuda')): 47 | super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) 48 | self.num_epochs = 100 49 | 50 | 51 | class nnUNetTrainer_250epochs(nnUNetTrainer): 52 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, 53 | device: torch.device = torch.device('cuda')): 54 | super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) 55 | self.num_epochs = 250 56 | 57 | 58 | class nnUNetTrainer_2000epochs(nnUNetTrainer): 59 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, 60 | device: torch.device = torch.device('cuda')): 61 | super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) 62 | self.num_epochs = 2000 63 | 64 | 65 | class nnUNetTrainer_4000epochs(nnUNetTrainer): 66 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, 67 | device: torch.device = torch.device('cuda')): 68 | super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) 69 | self.num_epochs = 4000 70 | 71 | 72 | class nnUNetTrainer_8000epochs(nnUNetTrainer): 73 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, 74 | device: torch.device = torch.device('cuda')): 75 | super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) 76 | self.num_epochs = 8000 77 | -------------------------------------------------------------------------------- /umamba/nnunetv2/training/nnUNetTrainer/variants/training_length/nnUNetTrainer_Xepochs_NoMirroring.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import 
nnUNetTrainer 4 | 5 | 6 | class nnUNetTrainer_250epochs_NoMirroring(nnUNetTrainer): 7 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, 8 | device: torch.device = torch.device('cuda')): 9 | super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) 10 | self.num_epochs = 250 11 | 12 | def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self): 13 | rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \ 14 | super().configure_rotation_dummyDA_mirroring_and_inital_patch_size() 15 | mirror_axes = None 16 | self.inference_allowed_mirroring_axes = None 17 | return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes 18 | 19 | 20 | class nnUNetTrainer_2000epochs_NoMirroring(nnUNetTrainer): 21 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, 22 | device: torch.device = torch.device('cuda')): 23 | super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) 24 | self.num_epochs = 2000 25 | 26 | def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self): 27 | rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \ 28 | super().configure_rotation_dummyDA_mirroring_and_inital_patch_size() 29 | mirror_axes = None 30 | self.inference_allowed_mirroring_axes = None 31 | return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes 32 | 33 | 34 | class nnUNetTrainer_4000epochs_NoMirroring(nnUNetTrainer): 35 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, 36 | device: torch.device = torch.device('cuda')): 37 | super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) 38 | self.num_epochs = 4000 39 | 40 | def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self): 41 | rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \ 42 | super().configure_rotation_dummyDA_mirroring_and_inital_patch_size() 43 | mirror_axes = None 44 | self.inference_allowed_mirroring_axes = None 45 | return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes 46 | 47 | 48 | class nnUNetTrainer_8000epochs_NoMirroring(nnUNetTrainer): 49 | def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, 50 | device: torch.device = torch.device('cuda')): 51 | super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) 52 | self.num_epochs = 8000 53 | 54 | def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self): 55 | rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \ 56 | super().configure_rotation_dummyDA_mirroring_and_inital_patch_size() 57 | mirror_axes = None 58 | self.inference_allowed_mirroring_axes = None 59 | return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes 60 | 61 | -------------------------------------------------------------------------------- /umamba/nnunetv2/utilities/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/U-Mamba/28459e33ca03769800dd35e23c6e62491d1925b5/umamba/nnunetv2/utilities/__init__.py -------------------------------------------------------------------------------- /umamba/nnunetv2/utilities/collate_outputs.py: -------------------------------------------------------------------------------- 
1 | from typing import List 2 | 3 | import numpy as np 4 | 5 | 6 | def collate_outputs(outputs: List[dict]): 7 | """ 8 | used to collate default train_step and validation_step outputs. If you want something different then you gotta 9 | extend this 10 | 11 | we expect outputs to be a list of dictionaries where each of the dicts has the same set of keys 12 | """ 13 | collated = {} 14 | for k in outputs[0].keys(): 15 | if np.isscalar(outputs[0][k]): 16 | collated[k] = [o[k] for o in outputs] 17 | elif isinstance(outputs[0][k], np.ndarray): 18 | collated[k] = np.vstack([o[k][None] for o in outputs]) 19 | elif isinstance(outputs[0][k], list): 20 | collated[k] = [item for o in outputs for item in o[k]] 21 | else: 22 | raise ValueError(f'Cannot collate input of type {type(outputs[0][k])}. ' 23 | f'Modify collate_outputs to add this functionality') 24 | return collated -------------------------------------------------------------------------------- /umamba/nnunetv2/utilities/dataset_name_id_conversion.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from typing import Union 15 | 16 | from nnunetv2.paths import nnUNet_preprocessed, nnUNet_raw, nnUNet_results 17 | from batchgenerators.utilities.file_and_folder_operations import * 18 | import numpy as np 19 | 20 | 21 | def find_candidate_datasets(dataset_id: int): 22 | startswith = "Dataset%03.0d" % dataset_id 23 | if nnUNet_preprocessed is not None and isdir(nnUNet_preprocessed): 24 | candidates_preprocessed = subdirs(nnUNet_preprocessed, prefix=startswith, join=False) 25 | else: 26 | candidates_preprocessed = [] 27 | 28 | if nnUNet_raw is not None and isdir(nnUNet_raw): 29 | candidates_raw = subdirs(nnUNet_raw, prefix=startswith, join=False) 30 | else: 31 | candidates_raw = [] 32 | 33 | candidates_trained_models = [] 34 | if nnUNet_results is not None and isdir(nnUNet_results): 35 | candidates_trained_models += subdirs(nnUNet_results, prefix=startswith, join=False) 36 | 37 | all_candidates = candidates_preprocessed + candidates_raw + candidates_trained_models 38 | unique_candidates = np.unique(all_candidates) 39 | return unique_candidates 40 | 41 | 42 | def convert_id_to_dataset_name(dataset_id: int): 43 | unique_candidates = find_candidate_datasets(dataset_id) 44 | if len(unique_candidates) > 1: 45 | raise RuntimeError("More than one dataset name found for dataset id %d. Please correct that. (I looked in the " 46 | "following folders:\n%s\n%s\n%s)" % (dataset_id, nnUNet_raw, nnUNet_preprocessed, nnUNet_results)) 47 | if len(unique_candidates) == 0: 48 | raise RuntimeError(f"Could not find a dataset with the ID {dataset_id}. Make sure the requested dataset ID " 49 | f"exists and that nnU-Net knows where raw and preprocessed data are located " 50 | f"(see Documentation - Installation).
Here are your currently defined folders:\n" 51 | f"nnUNet_preprocessed={os.environ.get('nnUNet_preprocessed') if os.environ.get('nnUNet_preprocessed') is not None else 'None'}\n" 52 | f"nnUNet_results={os.environ.get('nnUNet_results') if os.environ.get('nnUNet_results') is not None else 'None'}\n" 53 | f"nnUNet_raw={os.environ.get('nnUNet_raw') if os.environ.get('nnUNet_raw') is not None else 'None'}\n" 54 | f"If something is not right, adapt your environment variables.") 55 | return unique_candidates[0] 56 | 57 | 58 | def convert_dataset_name_to_id(dataset_name: str): 59 | assert dataset_name.startswith("Dataset") 60 | dataset_id = int(dataset_name[7:10]) 61 | return dataset_id 62 | 63 | 64 | def maybe_convert_to_dataset_name(dataset_name_or_id: Union[int, str]) -> str: 65 | if isinstance(dataset_name_or_id, str) and dataset_name_or_id.startswith("Dataset"): 66 | return dataset_name_or_id 67 | if isinstance(dataset_name_or_id, str): 68 | try: 69 | dataset_name_or_id = int(dataset_name_or_id) 70 | except ValueError: 71 | raise ValueError("dataset_name_or_id was a string and did not start with 'Dataset' so we tried to " 72 | "convert it to a dataset ID (int). That failed, however. Please give an integer number " 73 | "('1', '2', etc) or a correct dataset name. Your input: %s" % dataset_name_or_id) 74 | return convert_id_to_dataset_name(dataset_name_or_id) 75 | -------------------------------------------------------------------------------- /umamba/nnunetv2/utilities/ddp_allgather.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
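# Usage sketch for AllGatherGrad below (assumes an initialized default process group):
#   gathered = AllGatherGrad.apply(local_tensor)  # shape [world_size, *local_tensor.shape]
# Unlike a bare torch.distributed.all_gather, this stays differentiable: the backward
# pass all-reduces the incoming gradients and returns the slice belonging to this rank.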
14 | from typing import Any, Optional, Tuple 15 | 16 | import torch 17 | from torch import distributed 18 | 19 | 20 | def print_if_rank0(*args): 21 | if distributed.get_rank() == 0: 22 | print(*args) 23 | 24 | 25 | class AllGatherGrad(torch.autograd.Function): 26 | # stolen from pytorch lightning 27 | @staticmethod 28 | def forward( 29 | ctx: Any, 30 | tensor: torch.Tensor, 31 | group: Optional["torch.distributed.ProcessGroup"] = None, 32 | ) -> torch.Tensor: 33 | ctx.group = group 34 | 35 | gathered_tensor = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())] 36 | 37 | torch.distributed.all_gather(gathered_tensor, tensor, group=group) 38 | gathered_tensor = torch.stack(gathered_tensor, dim=0) 39 | 40 | return gathered_tensor 41 | 42 | @staticmethod 43 | def backward(ctx: Any, *grad_output: torch.Tensor) -> Tuple[torch.Tensor, None]: 44 | grad_output = torch.cat(grad_output) 45 | 46 | torch.distributed.all_reduce(grad_output, op=torch.distributed.ReduceOp.SUM, async_op=False, group=ctx.group) 47 | 48 | return grad_output[torch.distributed.get_rank()], None 49 | 50 | -------------------------------------------------------------------------------- /umamba/nnunetv2/utilities/default_n_proc_DA.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import os 3 | 4 | 5 | def get_allowed_n_proc_DA(): 6 | """ 7 | This function is used to set the number of processes used on different Systems. It is specific to our cluster 8 | infrastructure at DKFZ. You can modify it to suit your needs. Everything is allowed. 9 | 10 | IMPORTANT: if the environment variable nnUNet_n_proc_DA is set it will overwrite anything in this script 11 | (see first line). 12 | 13 | Interpret the output as the number of processes used for data augmentation PER GPU. 14 | 15 | The way it is implemented here is simply a look up table. We know the hostnames, CPU and GPU configurations of our 16 | systems and set the numbers accordingly. For example, a system with 4 GPUs and 48 threads can use 12 threads per 17 | GPU without overloading the CPU (technically 11 because we have a main process as well), so that's what we use. 
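    For example (assuming the standard nnU-Net entry points are installed), launching a
    training as `nnUNet_n_proc_DA=4 nnUNetv2_train DATASET CONFIG FOLD` should force 4
    augmentation workers per GPU, bypassing the hostname lookup table below.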
18 | """ 19 | 20 | if 'nnUNet_n_proc_DA' in os.environ.keys(): 21 | use_this = int(os.environ['nnUNet_n_proc_DA']) 22 | else: 23 | hostname = subprocess.getoutput(['hostname']) 24 | if hostname in ['Fabian', ]: 25 | use_this = 12 26 | elif hostname in ['hdf19-gpu16', 'hdf19-gpu17', 'hdf19-gpu18', 'hdf19-gpu19', 'e230-AMDworkstation']: 27 | use_this = 16 28 | elif hostname.startswith('e230-dgx1'): 29 | use_this = 10 30 | elif hostname.startswith('hdf18-gpu') or hostname.startswith('e132-comp'): 31 | use_this = 16 32 | elif hostname.startswith('e230-dgx2'): 33 | use_this = 6 34 | elif hostname.startswith('e230-dgxa100-'): 35 | use_this = 28 36 | elif hostname.startswith('lsf22-gpu'): 37 | use_this = 28 38 | elif hostname.startswith('hdf19-gpu') or hostname.startswith('e071-gpu'): 39 | use_this = 12 40 | else: 41 | use_this = 12 # default value 42 | 43 | use_this = min(use_this, os.cpu_count()) 44 | return use_this 45 | -------------------------------------------------------------------------------- /umamba/nnunetv2/utilities/find_class_by_name.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import pkgutil 3 | 4 | from batchgenerators.utilities.file_and_folder_operations import * 5 | 6 | 7 | def recursive_find_python_class(folder: str, class_name: str, current_module: str): 8 | tr = None 9 | for importer, modname, ispkg in pkgutil.iter_modules([folder]): 10 | # print(modname, ispkg) 11 | if not ispkg: 12 | m = importlib.import_module(current_module + "." + modname) 13 | if hasattr(m, class_name): 14 | tr = getattr(m, class_name) 15 | break 16 | 17 | if tr is None: 18 | for importer, modname, ispkg in pkgutil.iter_modules([folder]): 19 | if ispkg: 20 | next_current_module = current_module + "." + modname 21 | tr = recursive_find_python_class(join(folder, modname), class_name, current_module=next_current_module) 22 | if tr is not None: 23 | break 24 | return tr -------------------------------------------------------------------------------- /umamba/nnunetv2/utilities/get_network_from_plans.py: -------------------------------------------------------------------------------- 1 | from dynamic_network_architectures.architectures.unet import PlainConvUNet, ResidualEncoderUNet 2 | from dynamic_network_architectures.building_blocks.helper import get_matching_instancenorm, convert_dim_to_conv_op 3 | from dynamic_network_architectures.initialization.weight_init import init_last_bn_before_add_to_0 4 | from nnunetv2.utilities.network_initialization import InitWeights_He 5 | from nnunetv2.utilities.plans_handling.plans_handler import ConfigurationManager, PlansManager 6 | from torch import nn 7 | 8 | 9 | def get_network_from_plans(plans_manager: PlansManager, 10 | dataset_json: dict, 11 | configuration_manager: ConfigurationManager, 12 | num_input_channels: int, 13 | deep_supervision: bool = True): 14 | """ 15 | we may have to change this in the future to accommodate other plans -> network mappings 16 | 17 | num_input_channels can differ depending on whether we do cascade. Its best to make this info available in the 18 | trainer rather than inferring it again from the plans here. 
19 | """ 20 | num_stages = len(configuration_manager.conv_kernel_sizes) 21 | 22 | dim = len(configuration_manager.conv_kernel_sizes[0]) 23 | conv_op = convert_dim_to_conv_op(dim) 24 | 25 | label_manager = plans_manager.get_label_manager(dataset_json) 26 | 27 | segmentation_network_class_name = configuration_manager.UNet_class_name 28 | mapping = { 29 | 'PlainConvUNet': PlainConvUNet, 30 | 'ResidualEncoderUNet': ResidualEncoderUNet 31 | } 32 | kwargs = { 33 | 'PlainConvUNet': { 34 | 'conv_bias': True, 35 | 'norm_op': get_matching_instancenorm(conv_op), 36 | 'norm_op_kwargs': {'eps': 1e-5, 'affine': True}, 37 | 'dropout_op': None, 'dropout_op_kwargs': None, 38 | 'nonlin': nn.LeakyReLU, 'nonlin_kwargs': {'inplace': True}, 39 | }, 40 | 'ResidualEncoderUNet': { 41 | 'conv_bias': True, 42 | 'norm_op': get_matching_instancenorm(conv_op), 43 | 'norm_op_kwargs': {'eps': 1e-5, 'affine': True}, 44 | 'dropout_op': None, 'dropout_op_kwargs': None, 45 | 'nonlin': nn.LeakyReLU, 'nonlin_kwargs': {'inplace': True}, 46 | } 47 | } 48 | assert segmentation_network_class_name in mapping.keys(), 'The network architecture specified by the plans file ' \ 49 | 'is non-standard (maybe your own?). Yo\'ll have to dive ' \ 50 | 'into either this ' \ 51 | 'function (get_network_from_plans) or ' \ 52 | 'the init of your nnUNetModule to accommodate that.' 53 | network_class = mapping[segmentation_network_class_name] 54 | 55 | conv_or_blocks_per_stage = { 56 | 'n_conv_per_stage' 57 | if network_class != ResidualEncoderUNet else 'n_blocks_per_stage': configuration_manager.n_conv_per_stage_encoder, 58 | 'n_conv_per_stage_decoder': configuration_manager.n_conv_per_stage_decoder 59 | } 60 | # network class name!! 61 | model = network_class( 62 | input_channels=num_input_channels, 63 | n_stages=num_stages, 64 | features_per_stage=[min(configuration_manager.UNet_base_num_features * 2 ** i, 65 | configuration_manager.unet_max_num_features) for i in range(num_stages)], 66 | conv_op=conv_op, 67 | kernel_sizes=configuration_manager.conv_kernel_sizes, 68 | strides=configuration_manager.pool_op_kernel_sizes, 69 | num_classes=label_manager.num_segmentation_heads, 70 | deep_supervision=deep_supervision, 71 | **conv_or_blocks_per_stage, 72 | **kwargs[segmentation_network_class_name] 73 | ) 74 | model.apply(InitWeights_He(1e-2)) 75 | if network_class == ResidualEncoderUNet: 76 | model.apply(init_last_bn_before_add_to_0) 77 | return model 78 | -------------------------------------------------------------------------------- /umamba/nnunetv2/utilities/helpers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def softmax_helper_dim0(x: torch.Tensor) -> torch.Tensor: 5 | return torch.softmax(x, 0) 6 | 7 | 8 | def softmax_helper_dim1(x: torch.Tensor) -> torch.Tensor: 9 | return torch.softmax(x, 1) 10 | 11 | 12 | def empty_cache(device: torch.device): 13 | if device.type == 'cuda': 14 | torch.cuda.empty_cache() 15 | elif device.type == 'mps': 16 | from torch import mps 17 | mps.empty_cache() 18 | else: 19 | pass 20 | 21 | 22 | class dummy_context(object): 23 | def __enter__(self): 24 | pass 25 | 26 | def __exit__(self, exc_type, exc_val, exc_tb): 27 | pass 28 | -------------------------------------------------------------------------------- /umamba/nnunetv2/utilities/json_export.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Iterable 2 | 3 | import numpy as np 4 | import torch 5 | 6 | 7 | def 
-------------------------------------------------------------------------------- /umamba/nnunetv2/utilities/json_export.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Iterable 2 | 3 | import numpy as np 4 | import torch 5 | 6 | 7 | def recursive_fix_for_json_export(my_dict: dict): 8 | # json is stupid. 'cannot serialize object of type bool_/int64/float64'. Come on bro. 9 | keys = list(my_dict.keys()) # cannot iterate over keys() if we change keys.... 10 | for k in keys: 11 | if isinstance(k, (np.int64, np.int32, np.int8, np.uint8)): 12 | tmp = my_dict[k] 13 | del my_dict[k] 14 | my_dict[int(k)] = tmp 15 | del tmp 16 | k = int(k) 17 | 18 | if isinstance(my_dict[k], dict): 19 | recursive_fix_for_json_export(my_dict[k]) 20 | elif isinstance(my_dict[k], np.ndarray): 21 | assert my_dict[k].ndim == 1, 'only 1d arrays are supported' 22 | my_dict[k] = fix_types_iterable(my_dict[k], output_type=list) 23 | elif isinstance(my_dict[k], (np.bool_,)): 24 | my_dict[k] = bool(my_dict[k]) 25 | elif isinstance(my_dict[k], (np.int64, np.int32, np.int8, np.uint8)): 26 | my_dict[k] = int(my_dict[k]) 27 | elif isinstance(my_dict[k], (np.float32, np.float64, np.float16)): 28 | my_dict[k] = float(my_dict[k]) 29 | elif isinstance(my_dict[k], list): 30 | my_dict[k] = fix_types_iterable(my_dict[k], output_type=type(my_dict[k])) 31 | elif isinstance(my_dict[k], tuple): 32 | my_dict[k] = fix_types_iterable(my_dict[k], output_type=tuple) 33 | elif isinstance(my_dict[k], torch.device): 34 | my_dict[k] = str(my_dict[k]) 35 | else: 36 | pass # pray it can be serialized 37 | 38 | 39 | def fix_types_iterable(iterable, output_type): 40 | # this sh!t is hacky as hell and will break if you use it for anything outside nnunet. Keep your hands off of this. 41 | out = [] 42 | for i in iterable: 43 | if type(i) in (np.int64, np.int32, np.int8, np.uint8): 44 | out.append(int(i)) 45 | elif isinstance(i, dict): 46 | recursive_fix_for_json_export(i) 47 | out.append(i) 48 | elif type(i) in (np.float32, np.float64, np.float16): 49 | out.append(float(i)) 50 | elif type(i) in (np.bool_,): 51 | out.append(bool(i)) 52 | elif isinstance(i, str): 53 | out.append(i) 54 | elif isinstance(i, Iterable): 55 | # print('recursive call on', i, type(i)) 56 | out.append(fix_types_iterable(i, type(i))) 57 | else: 58 | out.append(i) 59 | return output_type(out) 60 | -------------------------------------------------------------------------------- /umamba/nnunetv2/utilities/label_handling/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/U-Mamba/28459e33ca03769800dd35e23c6e62491d1925b5/umamba/nnunetv2/utilities/label_handling/__init__.py -------------------------------------------------------------------------------- /umamba/nnunetv2/utilities/network_initialization.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class InitWeights_He(object): 5 | def __init__(self, neg_slope=1e-2): 6 | self.neg_slope = neg_slope 7 | 8 | def __call__(self, module): 9 | if isinstance(module, nn.Conv3d) or isinstance(module, nn.Conv2d) or isinstance(module, nn.ConvTranspose2d) or isinstance(module, nn.ConvTranspose3d): 10 | module.weight = nn.init.kaiming_normal_(module.weight, a=self.neg_slope) 11 | if module.bias is not None: 12 | module.bias = nn.init.constant_(module.bias, 0) 13 | -------------------------------------------------------------------------------- /umamba/nnunetv2/utilities/plans_handling/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/U-Mamba/28459e33ca03769800dd35e23c6e62491d1925b5/umamba/nnunetv2/utilities/plans_handling/__init__.py
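# Usage sketch for recursive_fix_for_json_export from json_export.py above (the metrics
# dict is hypothetical):
#   import json
#   metrics = {'mean_dice': np.float32(0.87), np.int64(2): {'n_cases': np.int64(12)}}
#   recursive_fix_for_json_export(metrics)  # mutates in place: numpy keys/values -> builtins
#   json.dumps(metrics)                     # now serializes cleanly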
-------------------------------------------------------------------------------- /umamba/nnunetv2/utilities/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 HIP Applied Computer Vision Lab, Division of Medical Image Computing, German Cancer Research Center 2 | # (DKFZ), Heidelberg, Germany 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | import os.path 16 | from functools import lru_cache 17 | from typing import Union 18 | 19 | from batchgenerators.utilities.file_and_folder_operations import * 20 | import numpy as np 21 | import re 22 | 23 | from nnunetv2.paths import nnUNet_raw 24 | 25 | 26 | def get_identifiers_from_splitted_dataset_folder(folder: str, file_ending: str): 27 | files = subfiles(folder, suffix=file_ending, join=False) 28 | # all files have a 4 digit channel index (_XXXX) 29 | crop = len(file_ending) + 5 30 | files = [i[:-crop] for i in files] 31 | # only unique image ids 32 | files = np.unique(files) 33 | return files 34 | 35 | 36 | def create_lists_from_splitted_dataset_folder(folder: str, file_ending: str, identifiers: List[str] = None) -> List[ 37 | List[str]]: 38 | """ 39 | does not rely on dataset.json 40 | """ 41 | if identifiers is None: 42 | identifiers = get_identifiers_from_splitted_dataset_folder(folder, file_ending) 43 | files = subfiles(folder, suffix=file_ending, join=False, sort=True) 44 | list_of_lists = [] 45 | for f in identifiers: 46 | p = re.compile(re.escape(f) + r"_\d\d\d\d" + re.escape(file_ending)) 47 | list_of_lists.append([join(folder, i) for i in files if p.fullmatch(i)]) 48 | return list_of_lists 49 | 50 | 51 | def get_filenames_of_train_images_and_targets(raw_dataset_folder: str, dataset_json: dict = None): 52 | if dataset_json is None: 53 | dataset_json = load_json(join(raw_dataset_folder, 'dataset.json')) 54 | 55 | if 'dataset' in dataset_json.keys(): 56 | dataset = dataset_json['dataset'] 57 | for k in dataset.keys(): 58 | dataset[k]['label'] = os.path.abspath(join(raw_dataset_folder, dataset[k]['label'])) if not os.path.isabs(dataset[k]['label']) else dataset[k]['label'] 59 | dataset[k]['images'] = [os.path.abspath(join(raw_dataset_folder, i)) if not os.path.isabs(i) else i for i in dataset[k]['images']] 60 | else: 61 | identifiers = get_identifiers_from_splitted_dataset_folder(join(raw_dataset_folder, 'imagesTr'), dataset_json['file_ending']) 62 | images = create_lists_from_splitted_dataset_folder(join(raw_dataset_folder, 'imagesTr'), dataset_json['file_ending'], identifiers) 63 | segs = [join(raw_dataset_folder, 'labelsTr', i + dataset_json['file_ending']) for i in identifiers] 64 | dataset = {i: {'images': im, 'label': se} for i, im, se in zip(identifiers, images, segs)} 65 | return dataset 66 | 67 | 68 | if __name__ == '__main__': 69 | print(get_filenames_of_train_images_and_targets(join(nnUNet_raw, 'Dataset002_Heart'))) 70 | --------------------------------------------------------------------------------
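# Usage sketch for the helpers in utils.py above (filenames are hypothetical): a folder
# containing la_001_0000.nii.gz and la_001_0001.nii.gz yields the identifier 'la_001'
# (the 4-digit channel suffix is cropped), so
#   create_lists_from_splitted_dataset_folder(folder, '.nii.gz')
# should return [['.../la_001_0000.nii.gz', '.../la_001_0001.nii.gz']] -- one inner list
# per identifier, one path per channel.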