├── FAMNet.png
├── README.md
├── config.py
├── data
│   ├── ABD
│   │   ├── ABDOMEN_CT
│   │   │   ├── class_slice_index_gen.py
│   │   │   ├── intensity_normalization.py
│   │   │   ├── niftiio.py
│   │   │   └── resampling_and_roi.py
│   │   └── ABDOMEN_MR
│   │       ├── class_slice_index_gen.ipynb
│   │       ├── dcm_img_to_nii.sh
│   │       ├── image_normalize.ipynb
│   │       └── png_gth_to_nii.ipynb
│   ├── Cardiac
│   │   ├── LGE
│   │   │   ├── class_slice_index_gen.ipynb
│   │   │   └── image_normalize.ipynb
│   │   └── bSSFP
│   │       ├── class_slice_index_gen.ipynb
│   │       └── image_normalize.ipynb
│   ├── Prostate
│   │   ├── NCI
│   │   │   ├── class_slice_index_gen.ipynb
│   │   │   ├── dcm_img_to_nii.sh
│   │   │   ├── image_normalize.ipynb
│   │   │   └── png_gth_to_nii.ipynb
│   │   └── UCLH
│   │       ├── class_slice_index_gen.ipynb
│   │       ├── dcm_img_to_nii.sh
│   │       ├── image_normalize.ipynb
│   │       └── png_gth_to_nii.ipynb
│   └── supervoxels
│       ├── _ccomp.pxd
│       ├── _ccomp.pyx
│       ├── felzenszwalb_3d.py
│       ├── felzenszwalb_3d_cy.pyx
│       ├── generate_supervoxels.py
│       └── setup.py
├── dataloaders
│   ├── dataset_specifics.py
│   ├── datasets.py
│   └── image_transforms.py
├── models
│   ├── CDFSMIS.py
│   ├── __pycache__
│   │   ├── CDFSMIS.cpython-311.pyc
│   │   ├── FAM.cpython-311.pyc
│   │   ├── MSFM.cpython-311.pyc
│   │   ├── cdfs_TS.cpython-311.pyc
│   │   └── encoder.cpython-311.pyc
│   └── encoder.py
├── scripts
│   ├── test_LGE2bssFP.sh
│   ├── test_NCI2UCLH.sh
│   ├── test_UCLH2NCI.sh
│   ├── test_bssFP2LGE.sh
│   ├── test_ct2mr.sh
│   ├── test_mr2ct.sh
│   ├── train_LGE.sh
│   ├── train_NCI.sh
│   ├── train_UCLH.sh
│   ├── train_bssFP.sh
│   ├── train_ct2mr.sh
│   └── train_mr2ct.sh
├── test.py
├── train.py
└── utils.py

--------------------------------------------------------------------------------
/FAMNet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/primebo1/FAMNet/3b14839129d2ad7362c458845e56af01238231ec/FAMNet.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FAMNet
2 | [![arXiv](https://img.shields.io/badge/arXiv-Paper-b31b1b.svg?logo=arxiv)](https://arxiv.org/abs/2412.09319)
3 | [![AAAI](https://img.shields.io/badge/AAAI'25-Paper-blue)](https://ojs.aaai.org/index.php/AAAI/article/view/32184)
4 | 
5 | Official code for the AAAI 2025 paper "FAMNet: Frequency-aware Matching Network for Cross-domain Few-shot Medical Image Segmentation"
6 | 
7 | - [**News!**] 24-12-10: Our work has been accepted by AAAI'25. The [arXiv paper](https://arxiv.org/abs/2412.09319) can be found here. 🎉
8 | 
9 | 
10 | ![](./FAMNet.png)
11 | 
12 | 
13 | ## 📋 Abstract
14 | Existing few-shot medical image segmentation (FSMIS) models fail to address a practical issue in medical imaging: the domain shift caused by different imaging techniques, which limits the applicability of current FSMIS models. To overcome this limitation, we focus on the cross-domain few-shot medical image segmentation (CD-FSMIS) task, aiming to develop a generalized model capable of adapting to a broader range of medical image segmentation scenarios with limited labeled data from the novel target domain.
15 | Inspired by the similarity of frequency-domain characteristics across different domains, we propose a Frequency-aware Matching Network (FAMNet), which includes two key components: a Frequency-aware Matching (FAM) module and a Multi-Spectral Fusion (MSF) module. The FAM module tackles two problems during the meta-learning phase: 1) intra-domain variance caused by the inherent support-query bias, due to the different appearances of organs and lesions, and 2) inter-domain variance caused by different medical imaging techniques. Additionally, we design an MSF module to integrate the different frequency features decoupled by the FAM module, and to further mitigate the impact of inter-domain variance on the model's segmentation performance.
16 | Combining these two modules, our FAMNet surpasses existing FSMIS models and cross-domain few-shot semantic segmentation models on three cross-domain datasets, achieving state-of-the-art performance in the CD-FSMIS task.
17 | 
18 | 
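The frequency-band decoupling that FAM builds on can be illustrated in a few lines of PyTorch. This is only an illustrative sketch of the general idea, *not* the FAMNet implementation; the band radii `r1`/`r2` below are arbitrary placeholder choices:

```python
import torch

def split_frequency_bands(feat, r1=0.25, r2=0.5):
    """feat: (B, C, H, W) -> low-, mid-, high-frequency parts of the same shape."""
    _, _, H, W = feat.shape
    spec = torch.fft.fftshift(torch.fft.fft2(feat, norm="ortho"), dim=(-2, -1))
    # normalized radial distance from the centre of the shifted spectrum
    yy = torch.linspace(-1, 1, H).view(H, 1).expand(H, W)
    xx = torch.linspace(-1, 1, W).view(1, W).expand(H, W)
    radius = torch.sqrt(xx ** 2 + yy ** 2)
    bands = []
    for lo, hi in [(0.0, r1), (r1, r2), (r2, 2.0)]:  # 2.0 > max radius sqrt(2)
        mask = ((radius >= lo) & (radius < hi)).to(spec.dtype)
        band = torch.fft.ifft2(torch.fft.ifftshift(spec * mask, dim=(-2, -1)),
                               norm="ortho").real
        bands.append(band)
    return bands

low, mid, high = split_frequency_bands(torch.randn(1, 64, 32, 32))
```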
19 | ## ⏳ Quick start
20 | 
21 | ### 🛠 Dependencies
22 | Please install the following essential dependencies:
23 | ```
24 | dcm2nii
25 | json5==0.8.5
26 | jupyter==1.0.0
27 | nibabel==2.5.1
28 | numpy==1.22.0
29 | opencv-python==4.5.5.62
30 | Pillow>=8.1.1
31 | sacred==0.8.2
32 | scikit-image==0.18.3
33 | SimpleITK==1.2.3
34 | torch==1.10.2
35 | torchvision==0.11.2
36 | tqdm==4.62.3
37 | ```
38 | 
39 | 
40 | ### 📚 Datasets and Preprocessing
41 | Please download:
42 | 1) **Abdominal MRI**: [Combined Healthy Abdominal Organ Segmentation dataset](https://chaos.grand-challenge.org/)
43 | 2) **Abdominal CT**: [Multi-Atlas Abdomen Labeling Challenge](https://www.synapse.org/#!Synapse:syn3193805/wiki/218292)
44 | 3) **Cardiac LGE and b-SSFP**: [Multi-sequence Cardiac MRI Segmentation dataset](https://zmiclab.github.io/zxh/0/mscmrseg19/index.html)
45 | 4) **Prostate UCLH and NCI**: [Cross-institution Male Pelvic Structures](https://zenodo.org/records/7013610)
46 | 
47 | Pre-processing is performed according to [Ouyang et al.](https://github.com/cheng-01037/Self-supervised-Fewshot-Medical-Image-Segmentation/tree/2f2a22b74890cb9ad5e56ac234ea02b9f1c7a535); we follow the procedure described in their GitHub repository.
48 | 
49 | 
50 | ### 🔥 Training
51 | 1. Compile `./data/supervoxels/felzenszwalb_3d_cy.pyx` with Cython (`python ./data/supervoxels/setup.py build_ext --inplace`), then run `./data/supervoxels/generate_supervoxels.py`.
52 | 2. Download the pre-trained ResNet-50 weights ([vanilla version](https://download.pytorch.org/models/fcn_resnet50_coco-1167a1af.pth) or [deeplabv3 version](https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth)), put them in your checkpoints folder, and update the absolute path in `./models/encoder.py` accordingly.
53 | 3. Run the corresponding training script under `./scripts/`, for example: `./scripts/train_ct2mr.sh`
54 | 
55 | 
56 | ### 🙏 Inference
57 | Run the corresponding `./scripts/test_*.sh` script, for example: `./scripts/test_ct2mr.sh`
58 | 
59 | 
60 | ## 🥰 Acknowledgements
61 | Our code is built upon the works of [SSL-ALPNet](https://github.com/cheng-01037/Self-supervised-Fewshot-Medical-Image-Segmentation), [ADNet](https://github.com/sha168/ADNet) and [QNet](https://github.com/ZJLAB-AMMI/Q-Net); we appreciate the authors for their excellent contributions!
62 | 
63 | 
64 | ## 📝 Citation
65 | If you use this code for your research or project, please consider citing our paper. Thanks!🥂:
66 | ```
67 | @inproceedings{bo2025famnet,
68 |   title={FAMNet: Frequency-aware Matching Network for Cross-domain Few-shot Medical Image Segmentation},
69 |   author={Bo, Yuntian and Zhu, Yazhou and Li, Lunbo and Zhang, Haofeng},
70 |   booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
71 |   volume={39},
72 |   number={2},
73 |   pages={1889--1897},
74 |   year={2025},
75 |   doi={10.1609/aaai.v39i2.32184}
76 | }
77 | ```
78 | 
79 | 
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | """
2 | Experiment configuration file
3 | Extended from the config file of the original PANet repository
4 | """
5 | import glob
6 | import itertools
7 | import os
8 | import sacred
9 | from sacred import Experiment
10 | from sacred.observers import FileStorageObserver
11 | from sacred.utils import apply_backspaces_and_linefeeds
12 | from utils import *
13 | from yacs.config import CfgNode as CN
14 | 
15 | sacred.SETTINGS['CONFIG']['READ_ONLY_CONFIG'] = False
16 | sacred.SETTINGS.CAPTURE_MODE = 'no'
17 | 
18 | ex = Experiment("CDFS")
19 | ex.captured_out_filter = apply_backspaces_and_linefeeds
20 | 
21 | ###### Set up source folder ######
22 | source_folders = ['.', './dataloaders', './models', './utils']
23 | sources_to_save = list(itertools.chain.from_iterable(
24 |     [glob.glob(f'{folder}/*.py') for folder in source_folders]))
25 | for source_file in sources_to_save:
26 |     ex.add_source_file(source_file)
27 | 
28 | 
29 | @ex.config
30 | def cfg():
31 |     """Default configurations"""
32 |     seed = 2021
33 |     gpu_id = 0
34 |     num_workers = 0  # 0 for debugging
35 |     mode = 'train'
36 | 
37 |     ## dataset
38 |     dataset = 'MR'  # e.g. abdominal MRI - 'CHAOST2'; cardiac MRI - 'CMR'
39 |     exclude_label = [1, 2, 3, 4]  # None for not excluding test labels (Setting 1: None; Setting 2: exclude them)
40 |     # 1 for Liver, 2 for RK, 3 for LK, 4 for Spleen in 'CHAOST2'
41 |     if dataset == 'Cardiac':
42 |         n_sv = 1000
43 |     else:
44 |         n_sv = 5000
45 |     min_size = 200
46 |     max_slices = 3
47 |     use_gt = False  # True - use ground truth as training label; False - use supervoxels as training labels
48 |     eval_fold = 0  # (0-4) for 5-fold cross-validation
49 |     test_label = [1, 4]  # for evaluation
50 |     supp_idx = 0  # choose which case as the support set for evaluation, (0-4) for 'CHAOST2', (0-7) for 'CMR'
51 |     n_part = 3  # for evaluation, i.e. divide each volume into 3 chunks
52 | 
53 |     ## training
54 |     n_steps = 1000
55 |     batch_size = 1
56 |     n_shot = 1
57 |     n_way = 1
58 |     n_query = 1
59 |     lr_step_gamma = 0.95
60 |     bg_wt = 0.1
61 |     t_loss_scaler = 0.0
62 |     ignore_label = 255
63 |     print_interval = 100  # raw=100
64 |     save_snapshot_every = 1000
65 |     max_iters_per_load = 1000  # epoch size, interval for reloading the dataset
66 | 
67 |     # Network
68 |     # reload_model_path = '.../ADNet/runs/ADNet_train_CHAOST2_cv0/1/snapshots/1000.pth'
69 |     reload_model_path = None
70 | 
71 |     optim_type = 'sgd'
72 |     optim = {
73 |         'lr': 1e-3,
74 |         'momentum': 0.9,
75 |         'weight_decay': 0.0005,  # 0.0005
76 |     }
77 | 
78 |     exp_str = '_'.join(
79 |         [mode]
80 |         + [dataset, ]
81 |         + [f'cv{eval_fold}'])
82 | 
83 |     path = {
84 |         'log_dir': './runs',
85 |         'ABDOMEN_MR': {'data_dir': './data/ABD/ABDOMEN_MR'},
86 |         'ABDOMEN_CT': {'data_dir': './data/ABD/ABDOMEN_CT'},
87 |         'CARDIAC_bssFP': {'data_dir': './data/Cardiac/bSSFP'},
88 |         'CARDIAC_LGE': {'data_dir': './data/Cardiac/LGE'},
89 |         'Prostate_UCLH': {'data_dir': './data/Prostate/UCLH'},
90 |         'Prostate_NCI': {'data_dir': './data/Prostate/NCI'},
91 |     }
92 | 
93 |     # Settings of CLIP
94 | 
95 |     train_organ = [1, 6]  # 1: Spleen, 6: Liver
96 |     test_organ = [2, 3]  # 2: RK, 3: LK
97 |     # train_classname = {'SPLEEN', 'LIVER'}
98 |     # test_classname = {'RIGHT_KIDNEY', 'LEFT_KIDNEY'}
99 | 
100 |     # backbone of the CLIP model
101 |     BACKBONE_NAME = 'RN50'  # RN101, RN50x4, RN50x16, ViT-B/32, ViT-B/16
102 |     N_CTX = 16  # number of context vectors
103 |     CTX_INIT = ""  # initialization words
104 |     PREC = "fp16"  # fp16, fp32, amp
105 |     CLASS_TOKEN_POSITION = "end"  # 'middle' or 'end' or 'front'
106 |     INPUT_SIZE = (224, 224)
107 |     CSC = False  # class-specific context
108 |     INIT_WEIGHTS = ""
109 |     OPTIM = CN()
110 |     PROMPT_INIT = 'VISION'  # or 'RANDOM'
111 | 
112 | 
113 | 
114 | 
115 | 
116 | 
117 | 
118 | @ex.config_hook
119 | def add_observer(config, command_name, logger):
120 |     """A hook function to add an observer"""
121 |     exp_name = f'{ex.path}_{config["exp_str"]}'
122 |     observer = FileStorageObserver.create(os.path.join(config['path']['log_dir'], exp_name))
123 |     ex.observers.append(observer)
124 |     return config
125 | 
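Since the configuration above is a sacred `Experiment`, any default can be overridden at run time. A minimal sketch, assuming `train.py` attaches its main function to this `ex` (the usual sacred pattern); it is equivalent to `python train.py with dataset=MR eval_fold=1 optim.lr=0.0005` on the command line:

```python
from config import ex  # the sacred Experiment defined above

if __name__ == "__main__":
    # nested dicts such as `optim` can be partially overridden
    ex.run(config_updates={"dataset": "MR", "eval_fold": 1, "optim": {"lr": 5e-4}})
```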
--------------------------------------------------------------------------------
/data/ABD/ABDOMEN_CT/class_slice_index_gen.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import glob
4 | import SimpleITK as sitk
5 | import sys
6 | import json
7 | import niftiio as nio
8 | 
9 | IMG_BNAME = "./sabs_CT_normalized/image_*.nii.gz"
10 | SEG_BNAME = "./sabs_CT_normalized/label_*.nii.gz"
11 | 
12 | imgs = glob.glob(IMG_BNAME)
13 | segs = glob.glob(SEG_BNAME)
14 | imgs = [fid for fid in sorted(imgs, key=lambda x: int(x.split("_")[-1].split(".nii.gz")[0]))]
15 | segs = [fid for fid in sorted(segs, key=lambda x: int(x.split("_")[-1].split(".nii.gz")[0]))]
16 | 
17 | classmap = {}
18 | LABEL_NAME = ["BG", "LIVER", "RK", "LK", "SPLEEN"]
19 | 
20 | 
21 | MIN_TP = 1  # minimum number of positive label pixels to be recorded; use >100 when training with manual annotations for more stable training
22 | 
23 | fid = f'./sabs_CT_normalized/classmap_{MIN_TP}.json'  # name of the output file
24 | for _lb in LABEL_NAME:
25 |     classmap[_lb] = {}
26 |     for _sid in segs:
27 |         pid = _sid.split("_")[-1].split(".nii.gz")[0]
28 |         classmap[_lb][pid] = []
29 | 
30 | for seg in segs:
31 |     pid = seg.split("_")[-1].split(".nii.gz")[0]
32 |     lb_vol = nio.read_nii_bysitk(seg)
33 |     n_slice = lb_vol.shape[0]
34 |     for slc in range(n_slice):
35 |         for cls in range(len(LABEL_NAME)):
36 |             if cls in lb_vol[slc, ...]:
37 |                 if np.sum(lb_vol[slc, ...] == cls) >= MIN_TP:  # count pixels of this class only
38 |                     classmap[LABEL_NAME[cls]][str(pid)].append(slc)
39 |     print(f'pid {str(pid)} finished!')
40 | 
41 | with open(fid, 'w') as fopen:
42 |     json.dump(classmap, fopen)
43 | 
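The resulting JSON maps `{label name: {patient id: [slice indices]}}`. A quick way to sanity-check the file (the path assumes the `MIN_TP = 1` default above):

```python
import json

with open('./sabs_CT_normalized/classmap_1.json') as f:
    classmap = json.load(f)

# e.g. which slices of each scan contain the liver
for pid, slices in classmap["LIVER"].items():
    print(f'scan {pid}: {len(slices)} liver slices, first few: {slices[:5]}')
```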
--------------------------------------------------------------------------------
/data/ABD/ABDOMEN_CT/intensity_normalization.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import glob
4 | import SimpleITK as sitk
5 | 
6 | import sys
7 | import niftiio as nio
8 | 
9 | 
10 | IMG_FOLDER = "./data/SABS/img/"
11 | SEG_FOLDER = "./data/SABS/label/"
12 | OUT_FOLDER = "./tmp_normalized/"
13 | 
14 | imgs = glob.glob(IMG_FOLDER + "/*.nii.gz")
15 | imgs = [fid for fid in sorted(imgs)]
16 | segs = [fid for fid in sorted(glob.glob(SEG_FOLDER + "/*.nii.gz"))]
17 | 
18 | pids = [pid.split("img0")[-1].split(".")[0] for pid in imgs]
19 | 
20 | 
21 | # helper function
22 | def copy_spacing_ori(src, dst):
23 |     dst.SetSpacing(src.GetSpacing())
24 |     dst.SetOrigin(src.GetOrigin())
25 |     dst.SetDirection(src.GetDirection())
26 |     return dst
27 | 
28 | import copy
29 | scan_dir = OUT_FOLDER
30 | LIR = -125  # lower bound of the CT intensity window
31 | HIR = 275   # upper bound of the CT intensity window
32 | os.makedirs(scan_dir, exist_ok=True)
33 | 
34 | reindex = 0
35 | for img_fid, seg_fid, pid in zip(imgs, segs, pids):
36 | 
37 |     img_obj = sitk.ReadImage(img_fid)
38 |     seg_obj = sitk.ReadImage(seg_fid)
39 | 
40 |     array = sitk.GetArrayFromImage(img_obj)
41 | 
42 |     # window the intensities, then min-max normalize to [0, 255]
43 |     array[array > HIR] = HIR
44 |     array[array < LIR] = LIR
45 |     array = (array - array.min()) / (array.max() - array.min()) * 255.0
46 | 
47 |     wined_img = sitk.GetImageFromArray(array)
48 |     wined_img = copy_spacing_ori(img_obj, wined_img)
49 | 
50 |     out_img_fid = os.path.join(scan_dir, f'image_{str(reindex)}.nii.gz')
51 |     out_lb_fid = os.path.join(scan_dir, f'label_{str(reindex)}.nii.gz')
52 | 
53 |     # then save
54 |     sitk.WriteImage(wined_img, out_img_fid, True)
55 |     sitk.WriteImage(seg_obj, out_lb_fid, True)
56 |     print("{} has been saved".format(out_img_fid))
57 |     print("{} has been saved".format(out_lb_fid))
58 |     reindex += 1
59 | 
--------------------------------------------------------------------------------
/data/ABD/ABDOMEN_CT/niftiio.py:
--------------------------------------------------------------------------------
1 | """
2 | Utils for datasets
3 | """
4 | import numpy as np
5 | import SimpleITK as sitk
6 | 
7 | 
8 | def read_nii_bysitk(input_fid, peel_info=False):
9 |     """ read nii to numpy through simpleitk
10 |         peel_info: also return direction, origin, spacing and metadata
11 |     """
12 |     img_obj = sitk.ReadImage(input_fid)
13 |     img_np = sitk.GetArrayFromImage(img_obj)
14 |     if peel_info:
15 |         info_obj = {
16 |             "spacing": img_obj.GetSpacing(),
17 |             "origin": img_obj.GetOrigin(),
18 |             "direction": img_obj.GetDirection(),
19 |             "array_size": img_np.shape
20 |         }
21 |         return img_np, info_obj
22 |     else:
23 |         return img_np
24 | 
25 | def convert_to_sitk(input_mat, peeled_info):
26 |     """
27 |     write a numpy array to an sitk image object with essential meta-data
28 |     """
29 |     nii_obj = sitk.GetImageFromArray(input_mat)
30 |     if peeled_info:
31 |         nii_obj.SetSpacing(peeled_info["spacing"])
32 |         nii_obj.SetOrigin(peeled_info["origin"])
33 |         nii_obj.SetDirection(peeled_info["direction"])
34 |     return nii_obj
35 | 
36 | def np2itk(img, ref_obj):
37 |     """
38 |     img: numpy array
39 |     ref_obj: reference sitk object to copy information from
40 |     """
41 |     itk_obj = sitk.GetImageFromArray(img)
42 |     itk_obj.SetSpacing(ref_obj.GetSpacing())
43 |     itk_obj.SetOrigin(ref_obj.GetOrigin())
44 |     itk_obj.SetDirection(ref_obj.GetDirection())
45 |     return itk_obj
46 | 
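A minimal sketch of how these helpers chain together — read a volume with its geometry, transform the array, and write it back (the file names and the z-score step are placeholders):

```python
import SimpleITK as sitk
import niftiio as nio

img_np, info = nio.read_nii_bysitk('image_0.nii.gz', peel_info=True)  # (Z, H, W) array + metadata
img_np = (img_np - img_np.mean()) / (img_np.std() + 1e-8)             # any array-level processing
out_obj = nio.convert_to_sitk(img_np, info)                           # restore spacing/origin/direction
sitk.WriteImage(out_obj, 'image_0_processed.nii.gz', True)
```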
--------------------------------------------------------------------------------
/data/ABD/ABDOMEN_CT/resampling_and_roi.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import glob
4 | import SimpleITK as sitk
5 | import sys
6 | import niftiio as nio
7 | 
8 | IMG_FOLDER = "./tmp_normalized/"
9 | SEG_FOLDER = IMG_FOLDER
10 | imgs = glob.glob(IMG_FOLDER + "/image_*.nii.gz")
11 | imgs = [fid for fid in sorted(imgs)]
12 | segs = [fid for fid in sorted(glob.glob(SEG_FOLDER + "/label_*.nii.gz"))]
13 | 
14 | pids = [pid.split("_")[-1].split(".")[0] for pid in imgs]
15 | 
16 | # helper functions (shared with the pre-processing notebooks)
17 | def resample_by_res(mov_img_obj, new_spacing, interpolator=sitk.sitkLinear, logging=True):
18 |     resample = sitk.ResampleImageFilter()
19 |     resample.SetInterpolator(interpolator)
20 |     resample.SetOutputDirection(mov_img_obj.GetDirection())
21 |     resample.SetOutputOrigin(mov_img_obj.GetOrigin())
22 |     mov_spacing = mov_img_obj.GetSpacing()
23 | 
24 |     resample.SetOutputSpacing(new_spacing)
25 |     RES_COE = np.array(mov_spacing) * 1.0 / np.array(new_spacing)
26 |     new_size = np.array(mov_img_obj.GetSize()) * RES_COE
27 | 
28 |     resample.SetSize([int(sz + 1) for sz in new_size])
29 |     if logging:
30 |         print("Spacing: {} -> {}".format(mov_spacing, new_spacing))
31 |         print("Size {} -> {}".format(mov_img_obj.GetSize(), new_size))
32 | 
33 |     return resample.Execute(mov_img_obj)
34 | 
35 | def resample_lb_by_res(mov_lb_obj, new_spacing, interpolator=sitk.sitkLinear, ref_img=None, logging=True):
36 |     src_mat = sitk.GetArrayFromImage(mov_lb_obj)
37 |     lbvs = np.unique(src_mat)
38 |     if logging:
39 |         print("Label values: {}".format(lbvs))
40 |     for idx, lbv in enumerate(lbvs):
41 |         _src_curr_mat = np.float32(src_mat == lbv)
42 |         _src_curr_obj = sitk.GetImageFromArray(_src_curr_mat)
43 |         _src_curr_obj.CopyInformation(mov_lb_obj)
44 |         _tar_curr_obj = resample_by_res(_src_curr_obj, new_spacing, interpolator, logging)
45 |         _tar_curr_mat = np.rint(sitk.GetArrayFromImage(_tar_curr_obj)) * lbv
46 |         if idx == 0:
47 |             out_vol = _tar_curr_mat
48 |         else:
49 |             out_vol[_tar_curr_mat == lbv] = lbv
50 |     out_obj = sitk.GetImageFromArray(out_vol)
51 |     out_obj.SetSpacing(_tar_curr_obj.GetSpacing())
52 |     if ref_img is not None:
53 |         out_obj.CopyInformation(ref_img)
54 |     return out_obj
55 | 
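# Example (sketch): pairing the two resamplers above. The image is resampled
# with plain linear interpolation; the label volume goes through the per-class
# path (binarize, resample, round), so interpolation cannot create spurious
# intermediate label values. The file names and spacing are placeholders.
#
#   img_obj = sitk.ReadImage('./tmp_normalized/image_0.nii.gz')
#   lb_obj = sitk.ReadImage('./tmp_normalized/label_0.nii.gz')
#   new_spacing = [1.25, 1.25, img_obj.GetSpacing()[-1]]  # keep slice thickness
#   res_img = resample_by_res(img_obj, new_spacing, sitk.sitkLinear, logging=True)
#   res_lb = resample_lb_by_res(lb_obj, new_spacing, sitk.sitkLinear,
#                               ref_img=res_img, logging=True)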
56 | ## Then crop ROI
57 | def get_label_center(label):
58 |     nnz = np.sum(label > 1e-5)
59 |     return np.int32(np.rint(np.sum(np.nonzero(label), axis=1) * 1.0 / nnz))
60 | 
61 | def image_crop(ori_vol, crop_size, reference_ctr_idx, padval=0., only_2d=True):
62 |     """ crop a 3d matrix given the index of the new volume on the original volume
63 |     Args:
64 |         reference_ctr_idx: the center of the new volume on the original volume (in indices)
65 |         only_2d: only do cropping on the first two dimensions
66 |     """
67 |     _expand_cropsize = [x + 1 for x in crop_size]  # to deal with boundary cases
68 |     if only_2d:
69 |         assert len(crop_size) == 2, "Actual len {}".format(len(crop_size))
70 |         assert len(reference_ctr_idx) == 2, "Actual len {}".format(len(reference_ctr_idx))
71 |         _expand_cropsize.append(ori_vol.shape[-1])
72 | 
73 |     image_patch = np.ones(tuple(_expand_cropsize)) * padval
74 | 
75 |     half_size = tuple([int(x * 1.0 / 2) for x in _expand_cropsize])
76 |     _min_idx = [0, 0, 0]
77 |     _max_idx = list(ori_vol.shape)
78 | 
79 |     # bias of actual cropped size to the beginning and the end of this volume
80 |     _bias_start = [0, 0, 0]
81 |     _bias_end = [0, 0, 0]
82 | 
83 |     for dim, hsize in enumerate(half_size):
84 |         if dim == 2 and only_2d:
85 |             break
86 | 
87 |         _bias_start[dim] = np.min([hsize, reference_ctr_idx[dim]])
88 |         _bias_end[dim] = np.min([hsize, ori_vol.shape[dim] - reference_ctr_idx[dim]])
89 | 
90 |         _min_idx[dim] = reference_ctr_idx[dim] - _bias_start[dim]
91 |         _max_idx[dim] = reference_ctr_idx[dim] + _bias_end[dim]
92 | 
93 |     if only_2d:
94 |         image_patch[half_size[0] - _bias_start[0]: half_size[0] + _bias_end[0], \
95 |                     half_size[1] - _bias_start[1]: half_size[1] + _bias_end[1], ...] = \
96 |             ori_vol[reference_ctr_idx[0] - _bias_start[0]: reference_ctr_idx[0] + _bias_end[0], \
97 |                     reference_ctr_idx[1] - _bias_start[1]: reference_ctr_idx[1] + _bias_end[1], ...]
98 | 
99 |         image_patch = image_patch[0: crop_size[0], 0: crop_size[1], :]
100 |         # then goes back to the original volume
101 |     else:
102 |         image_patch[half_size[0] - _bias_start[0]: half_size[0] + _bias_end[0], \
103 |                     half_size[1] - _bias_start[1]: half_size[1] + _bias_end[1], \
104 |                     half_size[2] - _bias_start[2]: half_size[2] + _bias_end[2]] = \
105 |             ori_vol[reference_ctr_idx[0] - _bias_start[0]: reference_ctr_idx[0] + _bias_end[0], \
106 |                     reference_ctr_idx[1] - _bias_start[1]: reference_ctr_idx[1] + _bias_end[1], \
107 |                     reference_ctr_idx[2] - _bias_start[2]: reference_ctr_idx[2] + _bias_end[2]]
108 | 
109 |         image_patch = image_patch[0: crop_size[0], 0: crop_size[1], 0: crop_size[2]]
110 |     return image_patch
111 | 
112 | 
113 | 
114 | def copy_spacing_ori(src, dst):
115 |     dst.SetSpacing(src.GetSpacing())
116 |     dst.SetOrigin(src.GetOrigin())
117 |     dst.SetDirection(src.GetDirection())
118 |     return dst
119 | 
120 | import copy
121 | OUT_FOLDER = "./sabs_CT_normalized"
122 | scan_dir = OUT_FOLDER
123 | os.makedirs(scan_dir, exist_ok=True)
124 | BD_BIAS = 32  # cut irrelevant empty boundary to make the ROI stand out
125 | 
126 | SPA_FAC = (512 - 2 * BD_BIAS) / 256  # spacing factor
127 | 
128 | for img_fid, seg_fid, pid in zip(imgs, segs, pids):
129 | 
130 |     lb_n = nio.read_nii_bysitk(seg_fid)
131 | 
132 |     img_obj = sitk.ReadImage(img_fid)
133 |     seg_obj = sitk.ReadImage(seg_fid)
134 | 
135 |     ## image
136 |     array = sitk.GetArrayFromImage(img_obj)
137 |     # cropping
138 |     array = array[:, BD_BIAS: -BD_BIAS, BD_BIAS: -BD_BIAS]
139 |     cropped_img_o = sitk.GetImageFromArray(array)
140 |     cropped_img_o = copy_spacing_ori(img_obj, cropped_img_o)
141 | 
142 |     # resampling
143 |     img_spa_ori = img_obj.GetSpacing()
144 |     res_img_o = resample_by_res(cropped_img_o, [img_spa_ori[0] * SPA_FAC, img_spa_ori[1] * SPA_FAC, img_spa_ori[-1]], interpolator=sitk.sitkLinear,
145 |                                 logging=True)
146 | 
147 |     ## label
148 |     lb_arr = sitk.GetArrayFromImage(seg_obj)
149 | 
150 |     # cropping
151 |     lb_arr = lb_arr[:, BD_BIAS: -BD_BIAS, BD_BIAS: -BD_BIAS]
152 |     cropped_lb_o = sitk.GetImageFromArray(lb_arr)
153 |     cropped_lb_o = copy_spacing_ori(seg_obj,
cropped_lb_o) 154 | 155 | lb_spa_ori = seg_obj.GetSpacing() 156 | 157 | # resampling 158 | res_lb_o = resample_lb_by_res(cropped_lb_o, [lb_spa_ori[0] * SPA_FAC, lb_spa_ori[1] * SPA_FAC, lb_spa_ori[-1] ], interpolator = sitk.sitkLinear, 159 | ref_img = res_img_o, logging = True) 160 | 161 | 162 | out_img_fid = os.path.join( scan_dir, f'image_{pid}.nii.gz' ) 163 | out_lb_fid = os.path.join( scan_dir, f'label_{pid}.nii.gz' ) 164 | 165 | # then save 166 | sitk.WriteImage(res_img_o, out_img_fid, True) 167 | sitk.WriteImage(res_lb_o, out_lb_fid, True) 168 | print("{} has been saved".format(out_img_fid)) 169 | print("{} has been saved".format(out_lb_fid)) 170 | -------------------------------------------------------------------------------- /data/ABD/ABDOMEN_MR/class_slice_index_gen.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Generate class-slice indexing table for experiments\n", 8 | "\n", 9 | "\n", 10 | "### Overview\n", 11 | "\n", 12 | "This is for experiment setting up for simulating few-shot image segmentation scenarios\n", 13 | "\n", 14 | "Input: pre-processed images and their ground-truth labels\n", 15 | "\n", 16 | "Output: a `json` file for class-slice indexing" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 8, 22 | "metadata": {}, 23 | "outputs": [ 24 | { 25 | "name": "stdout", 26 | "output_type": "stream", 27 | "text": [ 28 | "Once deleted, variables cannot be recovered. Proceed (y/[n])? y\n", 29 | "The autoreload extension is already loaded. To reload it, use:\n", 30 | " %reload_ext autoreload\n" 31 | ] 32 | } 33 | ], 34 | "source": [ 35 | "%reset\n", 36 | "%load_ext autoreload\n", 37 | "%autoreload 2\n", 38 | "import numpy as np\n", 39 | "import os\n", 40 | "import glob\n", 41 | "import SimpleITK as sitk\n", 42 | "import sys\n", 43 | "import json\n", 44 | "sys.path.insert(0, '../../dataloaders/')\n", 45 | "import niftiio as nio" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 9, 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "IMG_BNAME=\"./chaos_MR_T2_normalized/image_*.nii.gz\"\n", 55 | "SEG_BNAME=\"./chaos_MR_T2_normalized/label_*.nii.gz\"" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 10, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "imgs = glob.glob(IMG_BNAME)\n", 65 | "segs = glob.glob(SEG_BNAME)\n", 66 | "imgs = [ fid for fid in sorted(imgs, key = lambda x: int(x.split(\"_\")[-1].split(\".nii.gz\")[0]) ) ]\n", 67 | "segs = [ fid for fid in sorted(segs, key = lambda x: int(x.split(\"_\")[-1].split(\".nii.gz\")[0]) ) ]\n" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 11, 73 | "metadata": {}, 74 | "outputs": [ 75 | { 76 | "data": { 77 | "text/plain": [ 78 | "['./chaos_MR_T2_normalized/image_1.nii.gz',\n", 79 | " './chaos_MR_T2_normalized/image_2.nii.gz',\n", 80 | " './chaos_MR_T2_normalized/image_3.nii.gz',\n", 81 | " './chaos_MR_T2_normalized/image_5.nii.gz',\n", 82 | " './chaos_MR_T2_normalized/image_8.nii.gz',\n", 83 | " './chaos_MR_T2_normalized/image_10.nii.gz',\n", 84 | " './chaos_MR_T2_normalized/image_13.nii.gz',\n", 85 | " './chaos_MR_T2_normalized/image_15.nii.gz',\n", 86 | " './chaos_MR_T2_normalized/image_19.nii.gz',\n", 87 | " './chaos_MR_T2_normalized/image_20.nii.gz',\n", 88 | " './chaos_MR_T2_normalized/image_21.nii.gz',\n", 89 | " './chaos_MR_T2_normalized/image_22.nii.gz',\n", 90 | " 
'./chaos_MR_T2_normalized/image_31.nii.gz',\n", 91 | " './chaos_MR_T2_normalized/image_32.nii.gz',\n", 92 | " './chaos_MR_T2_normalized/image_33.nii.gz',\n", 93 | " './chaos_MR_T2_normalized/image_34.nii.gz',\n", 94 | " './chaos_MR_T2_normalized/image_36.nii.gz',\n", 95 | " './chaos_MR_T2_normalized/image_37.nii.gz',\n", 96 | " './chaos_MR_T2_normalized/image_38.nii.gz',\n", 97 | " './chaos_MR_T2_normalized/image_39.nii.gz']" 98 | ] 99 | }, 100 | "execution_count": 11, 101 | "metadata": {}, 102 | "output_type": "execute_result" 103 | } 104 | ], 105 | "source": [ 106 | "imgs" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": 12, 112 | "metadata": {}, 113 | "outputs": [ 114 | { 115 | "data": { 116 | "text/plain": [ 117 | "['./chaos_MR_T2_normalized/label_1.nii.gz',\n", 118 | " './chaos_MR_T2_normalized/label_2.nii.gz',\n", 119 | " './chaos_MR_T2_normalized/label_3.nii.gz',\n", 120 | " './chaos_MR_T2_normalized/label_5.nii.gz',\n", 121 | " './chaos_MR_T2_normalized/label_8.nii.gz',\n", 122 | " './chaos_MR_T2_normalized/label_10.nii.gz',\n", 123 | " './chaos_MR_T2_normalized/label_13.nii.gz',\n", 124 | " './chaos_MR_T2_normalized/label_15.nii.gz',\n", 125 | " './chaos_MR_T2_normalized/label_19.nii.gz',\n", 126 | " './chaos_MR_T2_normalized/label_20.nii.gz',\n", 127 | " './chaos_MR_T2_normalized/label_21.nii.gz',\n", 128 | " './chaos_MR_T2_normalized/label_22.nii.gz',\n", 129 | " './chaos_MR_T2_normalized/label_31.nii.gz',\n", 130 | " './chaos_MR_T2_normalized/label_32.nii.gz',\n", 131 | " './chaos_MR_T2_normalized/label_33.nii.gz',\n", 132 | " './chaos_MR_T2_normalized/label_34.nii.gz',\n", 133 | " './chaos_MR_T2_normalized/label_36.nii.gz',\n", 134 | " './chaos_MR_T2_normalized/label_37.nii.gz',\n", 135 | " './chaos_MR_T2_normalized/label_38.nii.gz',\n", 136 | " './chaos_MR_T2_normalized/label_39.nii.gz']" 137 | ] 138 | }, 139 | "execution_count": 12, 140 | "metadata": {}, 141 | "output_type": "execute_result" 142 | } 143 | ], 144 | "source": [ 145 | "segs" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": 13, 151 | "metadata": {}, 152 | "outputs": [ 153 | { 154 | "name": "stdout", 155 | "output_type": "stream", 156 | "text": [ 157 | "pid 1 finished!\n", 158 | "pid 2 finished!\n", 159 | "pid 3 finished!\n", 160 | "pid 5 finished!\n", 161 | "pid 8 finished!\n", 162 | "pid 10 finished!\n", 163 | "pid 13 finished!\n", 164 | "pid 15 finished!\n", 165 | "pid 19 finished!\n", 166 | "pid 20 finished!\n", 167 | "pid 21 finished!\n", 168 | "pid 22 finished!\n", 169 | "pid 31 finished!\n", 170 | "pid 32 finished!\n", 171 | "pid 33 finished!\n", 172 | "pid 34 finished!\n", 173 | "pid 36 finished!\n", 174 | "pid 37 finished!\n", 175 | "pid 38 finished!\n", 176 | "pid 39 finished!\n" 177 | ] 178 | } 179 | ], 180 | "source": [ 181 | "classmap = {}\n", 182 | "LABEL_NAME = [\"BG\", \"LIVER\", \"RK\", \"LK\", \"SPLEEN\"] \n", 183 | "\n", 184 | "\n", 185 | "MIN_TP = 100 # minimum number of positive label pixels to be recorded. Use >100 when training with manual annotations for more stable training\n", 186 | "\n", 187 | "fid = f'./chaos_MR_T2_normalized/classmap_{MIN_TP}.json' # name of the output file. 
\n", 188 | "for _lb in LABEL_NAME:\n", 189 | " classmap[_lb] = {}\n", 190 | " for _sid in segs:\n", 191 | " pid = _sid.split(\"_\")[-1].split(\".nii.gz\")[0]\n", 192 | " classmap[_lb][pid] = []\n", 193 | "\n", 194 | "for seg in segs:\n", 195 | " pid = seg.split(\"_\")[-1].split(\".nii.gz\")[0]\n", 196 | " lb_vol = nio.read_nii_bysitk(seg)\n", 197 | " n_slice = lb_vol.shape[0]\n", 198 | " for slc in range(n_slice):\n", 199 | " for cls in range(len(LABEL_NAME)):\n", 200 | " if cls in lb_vol[slc, ...]:\n", 201 | " if np.sum( lb_vol[slc, ...]) >= MIN_TP:\n", 202 | " classmap[LABEL_NAME[cls]][str(pid)].append(slc)\n", 203 | " print(f'pid {str(pid)} finished!')\n", 204 | " \n", 205 | "with open(fid, 'w') as fopen:\n", 206 | " json.dump(classmap, fopen)\n", 207 | " fopen.close() \n", 208 | " " 209 | ] 210 | }, 211 | { 212 | "cell_type": "code", 213 | "execution_count": 14, 214 | "metadata": {}, 215 | "outputs": [], 216 | "source": [ 217 | "with open(fid, 'w') as fopen:\n", 218 | " json.dump(classmap, fopen)\n", 219 | " fopen.close()" 220 | ] 221 | }, 222 | { 223 | "cell_type": "code", 224 | "execution_count": null, 225 | "metadata": {}, 226 | "outputs": [], 227 | "source": [] 228 | } 229 | ], 230 | "metadata": { 231 | "kernelspec": { 232 | "display_name": "Python 3 (ipykernel)", 233 | "language": "python", 234 | "name": "python3" 235 | }, 236 | "language_info": { 237 | "codemirror_mode": { 238 | "name": "ipython", 239 | "version": 3 240 | }, 241 | "file_extension": ".py", 242 | "mimetype": "text/x-python", 243 | "name": "python", 244 | "nbconvert_exporter": "python", 245 | "pygments_lexer": "ipython3", 246 | "version": "3.8.12" 247 | } 248 | }, 249 | "nbformat": 4, 250 | "nbformat_minor": 2 251 | } 252 | -------------------------------------------------------------------------------- /data/ABD/ABDOMEN_MR/dcm_img_to_nii.sh: -------------------------------------------------------------------------------- 1 | #!bin/bash 2 | # Convert dicom-like images to nii files in 3D 3 | # This is the first step for image pre-processing 4 | 5 | # Feed path to the downloaded data here 6 | DATAPATH=./MR # please put chaos dataset training fold here which contains ground truth 7 | 8 | # Feed path to the output folder here 9 | OUTPATH=./niis 10 | 11 | if [ ! -d $OUTPATH/T2SPIR ] 12 | then 13 | mkdir $OUTPATH/T2SPIR 14 | fi 15 | 16 | for sid in $(ls "$DATAPATH") 17 | do 18 | dcm2nii -o "$DATAPATH/$sid/T2SPIR" "$DATAPATH/$sid/T2SPIR/DICOM_anon"; 19 | find "$DATAPATH/$sid/T2SPIR" -name "*.nii.gz" -exec mv {} "$OUTPATH/T2SPIR/image_$sid.nii.gz" \; 20 | done; 21 | 22 | 23 | -------------------------------------------------------------------------------- /data/ABD/ABDOMEN_MR/image_normalize.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Image Pre-processing\n", 8 | "\n", 9 | "### Overview\n", 10 | "\n", 11 | "This is the second step for data preparation\n", 12 | "\n", 13 | "Input: `.nii`-like images and labels converted from `dicom`s/ `png` files\n", 14 | "\n", 15 | "Output: image-labels with unified size (axial), voxel-spacing, and alleviated off-resonance effects" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": 1, 21 | "metadata": {}, 22 | "outputs": [ 23 | { 24 | "name": "stdout", 25 | "output_type": "stream", 26 | "text": [ 27 | "Once deleted, variables cannot be recovered. Proceed (y/[n])? 
y\n" 28 | ] 29 | } 30 | ], 31 | "source": [ 32 | "%reset\n", 33 | "%load_ext autoreload\n", 34 | "%autoreload 2\n", 35 | "import numpy as np\n", 36 | "import os\n", 37 | "import glob\n", 38 | "import SimpleITK as sitk\n", 39 | "\n", 40 | "import matplotlib.pyplot as plt\n", 41 | "import copy\n", 42 | "import sys\n", 43 | "sys.path.insert(0, '../../dataloaders/')\n", 44 | "import niftiio as nio" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 2, 50 | "metadata": {}, 51 | "outputs": [], 52 | "source": [ 53 | "IMG_FOLDER = \"./niis/T2SPIR\" #, path of nii-like images from step 1\n", 54 | "OUT_FOLDER=\"./chaos_MR_T2_normalized/\" # output directory" 55 | ] 56 | }, 57 | { 58 | "cell_type": "markdown", 59 | "metadata": {}, 60 | "source": [ 61 | "**0. Find images and their ground-truth segmentations**" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": 3, 67 | "metadata": {}, 68 | "outputs": [], 69 | "source": [ 70 | "imgs = glob.glob(IMG_FOLDER + f'/image_*.nii.gz')\n", 71 | "imgs = [ fid for fid in sorted(imgs) ]\n", 72 | "segs = [ fid for fid in sorted(glob.glob(IMG_FOLDER + f'/label_*.nii.gz')) ]\n", 73 | "\n", 74 | "pids = [pid.split(\"_\")[-1].split(\".\")[0] for pid in imgs]" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 4, 80 | "metadata": {}, 81 | "outputs": [ 82 | { 83 | "data": { 84 | "text/plain": [ 85 | "['./niis/T2SPIR/image_1.nii.gz',\n", 86 | " './niis/T2SPIR/image_10.nii.gz',\n", 87 | " './niis/T2SPIR/image_13.nii.gz',\n", 88 | " './niis/T2SPIR/image_15.nii.gz',\n", 89 | " './niis/T2SPIR/image_19.nii.gz',\n", 90 | " './niis/T2SPIR/image_2.nii.gz',\n", 91 | " './niis/T2SPIR/image_20.nii.gz',\n", 92 | " './niis/T2SPIR/image_21.nii.gz',\n", 93 | " './niis/T2SPIR/image_22.nii.gz',\n", 94 | " './niis/T2SPIR/image_3.nii.gz',\n", 95 | " './niis/T2SPIR/image_31.nii.gz',\n", 96 | " './niis/T2SPIR/image_32.nii.gz',\n", 97 | " './niis/T2SPIR/image_33.nii.gz',\n", 98 | " './niis/T2SPIR/image_34.nii.gz',\n", 99 | " './niis/T2SPIR/image_36.nii.gz',\n", 100 | " './niis/T2SPIR/image_37.nii.gz',\n", 101 | " './niis/T2SPIR/image_38.nii.gz',\n", 102 | " './niis/T2SPIR/image_39.nii.gz',\n", 103 | " './niis/T2SPIR/image_5.nii.gz',\n", 104 | " './niis/T2SPIR/image_8.nii.gz']" 105 | ] 106 | }, 107 | "execution_count": 4, 108 | "metadata": {}, 109 | "output_type": "execute_result" 110 | } 111 | ], 112 | "source": [ 113 | "imgs" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": 5, 119 | "metadata": {}, 120 | "outputs": [ 121 | { 122 | "data": { 123 | "text/plain": [ 124 | "['./niis/T2SPIR/label_1.nii.gz',\n", 125 | " './niis/T2SPIR/label_10.nii.gz',\n", 126 | " './niis/T2SPIR/label_13.nii.gz',\n", 127 | " './niis/T2SPIR/label_15.nii.gz',\n", 128 | " './niis/T2SPIR/label_19.nii.gz',\n", 129 | " './niis/T2SPIR/label_2.nii.gz',\n", 130 | " './niis/T2SPIR/label_20.nii.gz',\n", 131 | " './niis/T2SPIR/label_21.nii.gz',\n", 132 | " './niis/T2SPIR/label_22.nii.gz',\n", 133 | " './niis/T2SPIR/label_3.nii.gz',\n", 134 | " './niis/T2SPIR/label_31.nii.gz',\n", 135 | " './niis/T2SPIR/label_32.nii.gz',\n", 136 | " './niis/T2SPIR/label_33.nii.gz',\n", 137 | " './niis/T2SPIR/label_34.nii.gz',\n", 138 | " './niis/T2SPIR/label_36.nii.gz',\n", 139 | " './niis/T2SPIR/label_37.nii.gz',\n", 140 | " './niis/T2SPIR/label_38.nii.gz',\n", 141 | " './niis/T2SPIR/label_39.nii.gz',\n", 142 | " './niis/T2SPIR/label_5.nii.gz',\n", 143 | " './niis/T2SPIR/label_8.nii.gz']" 144 | ] 145 | }, 146 | "execution_count": 5, 147 | "metadata": {}, 
148 | "output_type": "execute_result" 149 | } 150 | ], 151 | "source": [ 152 | "segs" 153 | ] 154 | }, 155 | { 156 | "cell_type": "markdown", 157 | "metadata": {}, 158 | "source": [ 159 | "**1. Unify image sizes and roi**\n", 160 | "\n", 161 | "a. Cut bright end of histogram to alleviate off-resonance issue\n", 162 | "\n", 163 | "b. Resample images to unified spacing\n", 164 | "\n", 165 | "c. Crop ROIs out to unify image sizes" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": 6, 171 | "metadata": {}, 172 | "outputs": [], 173 | "source": [ 174 | "# some helper functions\n", 175 | "def resample_by_res(mov_img_obj, new_spacing, interpolator = sitk.sitkLinear, logging = True):\n", 176 | " resample = sitk.ResampleImageFilter()\n", 177 | " resample.SetInterpolator(interpolator)\n", 178 | " resample.SetOutputDirection(mov_img_obj.GetDirection())\n", 179 | " resample.SetOutputOrigin(mov_img_obj.GetOrigin())\n", 180 | " mov_spacing = mov_img_obj.GetSpacing()\n", 181 | "\n", 182 | " resample.SetOutputSpacing(new_spacing)\n", 183 | " RES_COE = np.array(mov_spacing) * 1.0 / np.array(new_spacing)\n", 184 | " new_size = np.array(mov_img_obj.GetSize()) * RES_COE \n", 185 | "\n", 186 | " resample.SetSize( [int(sz+1) for sz in new_size] )\n", 187 | " if logging:\n", 188 | " print(\"Spacing: {} -> {}\".format(mov_spacing, new_spacing))\n", 189 | " print(\"Size {} -> {}\".format( mov_img_obj.GetSize(), new_size ))\n", 190 | "\n", 191 | " return resample.Execute(mov_img_obj)\n", 192 | "\n", 193 | "def resample_lb_by_res(mov_lb_obj, new_spacing, interpolator = sitk.sitkLinear, ref_img = None, logging = True):\n", 194 | " src_mat = sitk.GetArrayFromImage(mov_lb_obj)\n", 195 | " lbvs = np.unique(src_mat)\n", 196 | " if logging:\n", 197 | " print(\"Label values: {}\".format(lbvs))\n", 198 | " for idx, lbv in enumerate(lbvs):\n", 199 | " _src_curr_mat = np.float32(src_mat == lbv) \n", 200 | " _src_curr_obj = sitk.GetImageFromArray(_src_curr_mat)\n", 201 | " _src_curr_obj.CopyInformation(mov_lb_obj)\n", 202 | " _tar_curr_obj = resample_by_res( _src_curr_obj, new_spacing, interpolator, logging )\n", 203 | " _tar_curr_mat = np.rint(sitk.GetArrayFromImage(_tar_curr_obj)) * lbv\n", 204 | " if idx == 0:\n", 205 | " out_vol = _tar_curr_mat\n", 206 | " else:\n", 207 | " out_vol[_tar_curr_mat == lbv] = lbv\n", 208 | " out_obj = sitk.GetImageFromArray(out_vol)\n", 209 | " out_obj.SetSpacing( _tar_curr_obj.GetSpacing() )\n", 210 | " if ref_img != None:\n", 211 | " out_obj.CopyInformation(ref_img)\n", 212 | " return out_obj\n", 213 | " \n", 214 | "def get_label_center(label):\n", 215 | " nnz = np.sum(label > 1e-5)\n", 216 | " return np.int32(np.rint(np.sum(np.nonzero(label), axis = 1) * 1.0 / nnz))\n", 217 | "\n", 218 | "def image_crop(ori_vol, crop_size, referece_ctr_idx, padval = 0., only_2d = True):\n", 219 | " \"\"\" crop a 3d matrix given the index of the new volume on the original volume\n", 220 | " Args:\n", 221 | " refernce_ctr_idx: the center of the new volume on the original volume (in indices)\n", 222 | " only_2d: only do cropping on first two dimensions\n", 223 | " \"\"\"\n", 224 | " _expand_cropsize = [x + 1 for x in crop_size] # to deal with boundary case\n", 225 | " if only_2d:\n", 226 | " assert len(crop_size) == 2, \"Actual len {}\".format(len(crop_size))\n", 227 | " assert len(referece_ctr_idx) == 2, \"Actual len {}\".format(len(referece_ctr_idx))\n", 228 | " _expand_cropsize.append(ori_vol.shape[-1])\n", 229 | " \n", 230 | " image_patch = np.ones(tuple(_expand_cropsize)) * 
padval\n", 231 | "\n", 232 | " half_size = tuple( [int(x * 1.0 / 2) for x in _expand_cropsize] )\n", 233 | " _min_idx = [0,0,0]\n", 234 | " _max_idx = list(ori_vol.shape)\n", 235 | "\n", 236 | " # bias of actual cropped size to the beginning and the end of this volume\n", 237 | " _bias_start = [0,0,0]\n", 238 | " _bias_end = [0,0,0]\n", 239 | "\n", 240 | " for dim,hsize in enumerate(half_size):\n", 241 | " if dim == 2 and only_2d:\n", 242 | " break\n", 243 | "\n", 244 | " _bias_start[dim] = np.min([hsize, referece_ctr_idx[dim]])\n", 245 | " _bias_end[dim] = np.min([hsize, ori_vol.shape[dim] - referece_ctr_idx[dim]])\n", 246 | "\n", 247 | " _min_idx[dim] = referece_ctr_idx[dim] - _bias_start[dim]\n", 248 | " _max_idx[dim] = referece_ctr_idx[dim] + _bias_end[dim]\n", 249 | " \n", 250 | " if only_2d:\n", 251 | " image_patch[ half_size[0] - _bias_start[0]: half_size[0] +_bias_end[0], \\\n", 252 | " half_size[1] - _bias_start[1]: half_size[1] +_bias_end[1], ... ] = \\\n", 253 | " ori_vol[ referece_ctr_idx[0] - _bias_start[0]: referece_ctr_idx[0] +_bias_end[0], \\\n", 254 | " referece_ctr_idx[1] - _bias_start[1]: referece_ctr_idx[1] +_bias_end[1], ... ]\n", 255 | "\n", 256 | " image_patch = image_patch[ 0: crop_size[0], 0: crop_size[1], : ]\n", 257 | " # then goes back to original volume\n", 258 | " else:\n", 259 | " image_patch[ half_size[0] - _bias_start[0]: half_size[0] +_bias_end[0], \\\n", 260 | " half_size[1] - _bias_start[1]: half_size[1] +_bias_end[1], \\\n", 261 | " half_size[2] - _bias_start[2]: half_size[2] +_bias_end[2] ] = \\\n", 262 | " ori_vol[ referece_ctr_idx[0] - _bias_start[0]: referece_ctr_idx[0] +_bias_end[0], \\\n", 263 | " referece_ctr_idx[1] - _bias_start[1]: referece_ctr_idx[1] +_bias_end[1], \\\n", 264 | " referece_ctr_idx[2] - _bias_start[2]: referece_ctr_idx[2] +_bias_end[2] ]\n", 265 | "\n", 266 | " image_patch = image_patch[ 0: crop_size[0], 0: crop_size[1], 0: crop_size[2] ]\n", 267 | " return image_patch\n", 268 | "\n", 269 | "def copy_spacing_ori(src, dst):\n", 270 | " dst.SetSpacing(src.GetSpacing())\n", 271 | " dst.SetOrigin(src.GetOrigin())\n", 272 | " dst.SetDirection(src.GetDirection())\n", 273 | " return dst\n", 274 | "\n", 275 | "s2n = sitk.GetArrayFromImage" 276 | ] 277 | }, 278 | { 279 | "cell_type": "code", 280 | "execution_count": 7, 281 | "metadata": { 282 | "scrolled": false 283 | }, 284 | "outputs": [ 285 | { 286 | "name": "stdout", 287 | "output_type": "stream", 288 | "text": [ 289 | "Failed to create the output folder.\n" 290 | ] 291 | }, 292 | { 293 | "ename": "NameError", 294 | "evalue": "name 'copy_spacing_ori' is not defined", 295 | "output_type": "error", 296 | "traceback": [ 297 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 298 | "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", 299 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[0mhis_img_o\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msitk\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mGetImageFromArray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 26\u001b[0;31m \u001b[0mhis_img_o\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcopy_spacing_ori\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimg_obj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhis_img_o\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 27\u001b[0m 
\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[0;31m# resampling\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 300 | "\u001b[0;31mNameError\u001b[0m: name 'copy_spacing_ori' is not defined" 301 | ] 302 | } 303 | ], 304 | "source": [ 305 | "import copy\n", 306 | "try:\n", 307 | " os.mkdir(OUT_FOLDER)\n", 308 | "except:\n", 309 | " print(\"Failed to create the output folder.\")\n", 310 | " \n", 311 | "HIST_CUT_TOP = 0.5 # cut top 0.5% of intensity historgam to alleviate off-resonance effect\n", 312 | "\n", 313 | "NEW_SPA = [1.25, 1.25, 7.70] # unified voxel spacing\n", 314 | "\n", 315 | "for img_fid, seg_fid, pid in zip(imgs, segs, pids):\n", 316 | "\n", 317 | " lb_n = nio.read_nii_bysitk(seg_fid)\n", 318 | " resample_flg = True\n", 319 | "\n", 320 | " img_obj = sitk.ReadImage( img_fid )\n", 321 | " seg_obj = sitk.ReadImage( seg_fid )\n", 322 | "\n", 323 | " array = sitk.GetArrayFromImage(img_obj)\n", 324 | "\n", 325 | " # cut histogram\n", 326 | " hir = float(np.percentile(array, 100.0 - HIST_CUT_TOP))\n", 327 | " array[array > hir] = hir\n", 328 | "\n", 329 | " his_img_o = sitk.GetImageFromArray(array)\n", 330 | " his_img_o = copy_spacing_ori(img_obj, his_img_o)\n", 331 | "\n", 332 | " # resampling\n", 333 | " img_spa_ori = img_obj.GetSpacing()\n", 334 | " res_img_o = resample_by_res(his_img_o, [NEW_SPA[0], NEW_SPA[1], NEW_SPA[2]],\n", 335 | " interpolator = sitk.sitkLinear, logging = True)\n", 336 | "\n", 337 | "\n", 338 | "\n", 339 | " ## label\n", 340 | " lb_arr = sitk.GetArrayFromImage(seg_obj)\n", 341 | "\n", 342 | " # resampling\n", 343 | " res_lb_o = resample_lb_by_res(seg_obj, [NEW_SPA[0], NEW_SPA[1], NEW_SPA[2] ], interpolator = sitk.sitkLinear,\n", 344 | " ref_img = None, logging = True)\n", 345 | "\n", 346 | "\n", 347 | " # crop out rois\n", 348 | " res_img_a = s2n(res_img_o)\n", 349 | "\n", 350 | " crop_img_a = image_crop(res_img_a.transpose(1,2,0), [256, 256],\n", 351 | " referece_ctr_idx = [res_img_a.shape[1] // 2, res_img_a.shape[2] //2],\n", 352 | " padval = res_img_a.min(), only_2d = True).transpose(2,0,1)\n", 353 | "\n", 354 | " out_img_obj = copy_spacing_ori(res_img_o, sitk.GetImageFromArray(crop_img_a))\n", 355 | "\n", 356 | " res_lb_a = s2n(res_lb_o)\n", 357 | "\n", 358 | " crop_lb_a = image_crop(res_lb_a.transpose(1,2,0), [256, 256],\n", 359 | " referece_ctr_idx = [res_lb_a.shape[1] // 2, res_lb_a.shape[2] //2],\n", 360 | " padval = 0, only_2d = True).transpose(2,0,1)\n", 361 | "\n", 362 | " out_lb_obj = copy_spacing_ori(res_img_o, sitk.GetImageFromArray(crop_lb_a))\n", 363 | "\n", 364 | "\n", 365 | " out_img_fid = os.path.join( OUT_FOLDER, f'image_{pid}.nii.gz' )\n", 366 | " out_lb_fid = os.path.join( OUT_FOLDER, f'label_{pid}.nii.gz' ) \n", 367 | "\n", 368 | " # then save pre-processed images\n", 369 | " sitk.WriteImage(out_img_obj, out_img_fid, True) \n", 370 | " sitk.WriteImage(out_lb_obj, out_lb_fid, True) \n", 371 | " print(\"{} has been saved\".format(out_img_fid))" 372 | ] 373 | } 374 | ], 375 | "metadata": { 376 | "kernelspec": { 377 | "display_name": "Python 3", 378 | "language": "python", 379 | "name": "python3" 380 | }, 381 | "language_info": { 382 | "codemirror_mode": { 383 | "name": "ipython", 384 | "version": 3 385 | }, 386 | "file_extension": ".py", 387 | "mimetype": "text/x-python", 388 | "name": "python", 389 | "nbconvert_exporter": "python", 390 | "pygments_lexer": "ipython3", 391 | "version": "3.6.0" 392 | } 393 | }, 394 | "nbformat": 4, 395 | "nbformat_minor": 2 396 | } 397 | 
-------------------------------------------------------------------------------- /data/ABD/ABDOMEN_MR/png_gth_to_nii.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Converting labels from png to nii file\n", 8 | "\n", 9 | "\n", 10 | "### Overview\n", 11 | "\n", 12 | "This is the first step for data preparation\n", 13 | "\n", 14 | "Input: ground truth labels in `.png` format\n", 15 | "\n", 16 | "Output: labels in `.nii` format, indexed by patient id" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 13, 22 | "metadata": {}, 23 | "outputs": [ 24 | { 25 | "name": "stdout", 26 | "output_type": "stream", 27 | "text": [ 28 | "Once deleted, variables cannot be recovered. Proceed (y/[n])? y\n", 29 | "The autoreload extension is already loaded. To reload it, use:\n", 30 | " %reload_ext autoreload\n" 31 | ] 32 | } 33 | ], 34 | "source": [ 35 | "%reset\n", 36 | "%load_ext autoreload\n", 37 | "%autoreload 2\n", 38 | "import os\n", 39 | "import glob\n", 40 | "\n", 41 | "import numpy as np\n", 42 | "import PIL\n", 43 | "import matplotlib.pyplot as plt\n", 44 | "import SimpleITK as sitk\n", 45 | "import sys\n", 46 | "sys.path.insert(0, '../../dataloaders/')\n", 47 | "import niftiio as nio" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": null, 53 | "metadata": {}, 54 | "outputs": [], 55 | "source": [] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 14, 60 | "metadata": {}, 61 | "outputs": [], 62 | "source": [ 63 | "example = \"./MR/1/T2SPIR/Ground/IMG-0002-00001.png\" # example of ground-truth file name. " 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": 15, 69 | "metadata": {}, 70 | "outputs": [], 71 | "source": [ 72 | "### search for scan ids\n", 73 | "ids = os.listdir(\"./MR/\")\n", 74 | "OUT_DIR = './niis/T2SPIR/'" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 16, 80 | "metadata": {}, 81 | "outputs": [ 82 | { 83 | "data": { 84 | "text/plain": [ 85 | "['37',\n", 86 | " '3',\n", 87 | " '15',\n", 88 | " '34',\n", 89 | " '33',\n", 90 | " '39',\n", 91 | " '20',\n", 92 | " '10',\n", 93 | " '22',\n", 94 | " '8',\n", 95 | " '31',\n", 96 | " '2',\n", 97 | " '36',\n", 98 | " '5',\n", 99 | " '13',\n", 100 | " '19',\n", 101 | " '21',\n", 102 | " '1',\n", 103 | " '38',\n", 104 | " '32']" 105 | ] 106 | }, 107 | "execution_count": 16, 108 | "metadata": {}, 109 | "output_type": "execute_result" 110 | } 111 | ], 112 | "source": [ 113 | "ids" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": 17, 119 | "metadata": {}, 120 | "outputs": [ 121 | { 122 | "name": "stdout", 123 | "output_type": "stream", 124 | "text": [ 125 | "image with id 37 has been saved!\n", 126 | "image with id 3 has been saved!\n", 127 | "image with id 15 has been saved!\n", 128 | "image with id 34 has been saved!\n", 129 | "image with id 33 has been saved!\n", 130 | "image with id 39 has been saved!\n", 131 | "image with id 20 has been saved!\n", 132 | "image with id 10 has been saved!\n", 133 | "image with id 22 has been saved!\n", 134 | "image with id 8 has been saved!\n", 135 | "image with id 31 has been saved!\n", 136 | "image with id 2 has been saved!\n", 137 | "image with id 36 has been saved!\n", 138 | "image with id 5 has been saved!\n", 139 | "image with id 13 has been saved!\n", 140 | "image with id 19 has been saved!\n", 141 | "image with id 21 has been saved!\n", 
142 | "image with id 1 has been saved!\n", 143 | "image with id 38 has been saved!\n", 144 | "image with id 32 has been saved!\n" 145 | ] 146 | } 147 | ], 148 | "source": [ 149 | "#### Write them to nii files for the ease of loading in future\n", 150 | "for curr_id in ids:\n", 151 | " pngs = glob.glob(f'./MR/{curr_id}/T2SPIR/Ground/*.png')\n", 152 | " pngs = sorted(pngs, key = lambda x: int(os.path.basename(x).split(\"-\")[-1].split(\".png\")[0]))\n", 153 | " buffer = []\n", 154 | "\n", 155 | " for fid in pngs:\n", 156 | " buffer.append(PIL.Image.open(fid))\n", 157 | "\n", 158 | " vol = np.stack(buffer, axis = 0)\n", 159 | " # flip correction\n", 160 | " vol = np.flip(vol, axis = 1).copy()\n", 161 | " # remap values\n", 162 | " for new_val, old_val in enumerate(sorted(np.unique(vol))):\n", 163 | " vol[vol == old_val] = new_val\n", 164 | "\n", 165 | " # get reference \n", 166 | " ref_img = f'./niis/T2SPIR/image_{curr_id}.nii.gz'\n", 167 | " img_o = sitk.ReadImage(ref_img)\n", 168 | " vol_o = nio.np2itk(img=vol, ref_obj=img_o)\n", 169 | " sitk.WriteImage(vol_o, f'{OUT_DIR}/label_{curr_id}.nii.gz')\n", 170 | " print(f'image with id {curr_id} has been saved!')\n", 171 | "\n", 172 | " " 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": null, 178 | "metadata": {}, 179 | "outputs": [], 180 | "source": [] 181 | } 182 | ], 183 | "metadata": { 184 | "kernelspec": { 185 | "display_name": "Python 3", 186 | "language": "python", 187 | "name": "python3" 188 | }, 189 | "language_info": { 190 | "codemirror_mode": { 191 | "name": "ipython", 192 | "version": 3 193 | }, 194 | "file_extension": ".py", 195 | "mimetype": "text/x-python", 196 | "name": "python", 197 | "nbconvert_exporter": "python", 198 | "pygments_lexer": "ipython3", 199 | "version": "3.6.0" 200 | } 201 | }, 202 | "nbformat": 4, 203 | "nbformat_minor": 2 204 | } 205 | -------------------------------------------------------------------------------- /data/Cardiac/LGE/class_slice_index_gen.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Generate class-slice indexing table for experiments\n", 8 | "\n", 9 | "\n", 10 | "### Overview\n", 11 | "\n", 12 | "This is for experiment setting up for simulating few-shot image segmentation scenarios\n", 13 | "\n", 14 | "Input: pre-processed images and their ground-truth labels\n", 15 | "\n", 16 | "Output: a `json` file for class-slice indexing" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 8, 22 | "metadata": {}, 23 | "outputs": [ 24 | { 25 | "name": "stdout", 26 | "output_type": "stream", 27 | "text": [ 28 | "Once deleted, variables cannot be recovered. Proceed (y/[n])? y\n", 29 | "The autoreload extension is already loaded. 
To reload it, use:\n", 30 | " %reload_ext autoreload\n" 31 | ] 32 | } 33 | ], 34 | "source": [ 35 | "%reset\n", 36 | "%load_ext autoreload\n", 37 | "%autoreload 2\n", 38 | "import numpy as np\n", 39 | "import os\n", 40 | "import glob\n", 41 | "import SimpleITK as sitk\n", 42 | "import sys\n", 43 | "import json\n", 44 | "sys.path.insert(0, '../../dataloaders/')\n", 45 | "import niftiio as nio" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 9, 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "IMG_BNAME=\"./cmr_MR_normalized/image_*.nii.gz\"\n", 55 | "SEG_BNAME=\"./cmr_MR_normalized/label_*.nii.gz\"" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 10, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "imgs = glob.glob(IMG_BNAME)\n", 65 | "segs = glob.glob(SEG_BNAME)\n", 66 | "imgs = [ fid for fid in sorted(imgs, key = lambda x: int(x.split(\"_\")[-1].split(\".nii.gz\")[0]) ) ]\n", 67 | "segs = [ fid for fid in sorted(segs, key = lambda x: int(x.split(\"_\")[-1].split(\".nii.gz\")[0]) ) ]\n" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 11, 73 | "metadata": {}, 74 | "outputs": [ 75 | { 76 | "data": { 77 | "text/plain": [ 78 | "['./cmr_MR_normalized/image_1.nii.gz',\n", 79 | " './cmr_MR_normalized/image_2.nii.gz',\n", 80 | " './cmr_MR_normalized/image_3.nii.gz',\n", 81 | " './cmr_MR_normalized/image_4.nii.gz',\n", 82 | " './cmr_MR_normalized/image_5.nii.gz',\n", 83 | " './cmr_MR_normalized/image_6.nii.gz',\n", 84 | " './cmr_MR_normalized/image_7.nii.gz',\n", 85 | " './cmr_MR_normalized/image_8.nii.gz',\n", 86 | " './cmr_MR_normalized/image_9.nii.gz',\n", 87 | " './cmr_MR_normalized/image_10.nii.gz',\n", 88 | " './cmr_MR_normalized/image_11.nii.gz',\n", 89 | " './cmr_MR_normalized/image_12.nii.gz',\n", 90 | " './cmr_MR_normalized/image_13.nii.gz',\n", 91 | " './cmr_MR_normalized/image_14.nii.gz',\n", 92 | " './cmr_MR_normalized/image_15.nii.gz',\n", 93 | " './cmr_MR_normalized/image_16.nii.gz',\n", 94 | " './cmr_MR_normalized/image_17.nii.gz',\n", 95 | " './cmr_MR_normalized/image_18.nii.gz',\n", 96 | " './cmr_MR_normalized/image_19.nii.gz',\n", 97 | " './cmr_MR_normalized/image_20.nii.gz',\n", 98 | " './cmr_MR_normalized/image_21.nii.gz',\n", 99 | " './cmr_MR_normalized/image_22.nii.gz',\n", 100 | " './cmr_MR_normalized/image_23.nii.gz',\n", 101 | " './cmr_MR_normalized/image_24.nii.gz',\n", 102 | " './cmr_MR_normalized/image_25.nii.gz',\n", 103 | " './cmr_MR_normalized/image_26.nii.gz',\n", 104 | " './cmr_MR_normalized/image_27.nii.gz',\n", 105 | " './cmr_MR_normalized/image_28.nii.gz',\n", 106 | " './cmr_MR_normalized/image_29.nii.gz',\n", 107 | " './cmr_MR_normalized/image_30.nii.gz',\n", 108 | " './cmr_MR_normalized/image_31.nii.gz',\n", 109 | " './cmr_MR_normalized/image_32.nii.gz',\n", 110 | " './cmr_MR_normalized/image_33.nii.gz',\n", 111 | " './cmr_MR_normalized/image_34.nii.gz',\n", 112 | " './cmr_MR_normalized/image_35.nii.gz']" 113 | ] 114 | }, 115 | "execution_count": 11, 116 | "metadata": {}, 117 | "output_type": "execute_result" 118 | } 119 | ], 120 | "source": [ 121 | "imgs" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": 12, 127 | "metadata": {}, 128 | "outputs": [ 129 | { 130 | "data": { 131 | "text/plain": [ 132 | "['./cmr_MR_normalized/label_1.nii.gz',\n", 133 | " './cmr_MR_normalized/label_2.nii.gz',\n", 134 | " './cmr_MR_normalized/label_3.nii.gz',\n", 135 | " './cmr_MR_normalized/label_4.nii.gz',\n", 136 | " 
'./cmr_MR_normalized/label_5.nii.gz',\n", 137 | " './cmr_MR_normalized/label_6.nii.gz',\n", 138 | " './cmr_MR_normalized/label_7.nii.gz',\n", 139 | " './cmr_MR_normalized/label_8.nii.gz',\n", 140 | " './cmr_MR_normalized/label_9.nii.gz',\n", 141 | " './cmr_MR_normalized/label_10.nii.gz',\n", 142 | " './cmr_MR_normalized/label_11.nii.gz',\n", 143 | " './cmr_MR_normalized/label_12.nii.gz',\n", 144 | " './cmr_MR_normalized/label_13.nii.gz',\n", 145 | " './cmr_MR_normalized/label_14.nii.gz',\n", 146 | " './cmr_MR_normalized/label_15.nii.gz',\n", 147 | " './cmr_MR_normalized/label_16.nii.gz',\n", 148 | " './cmr_MR_normalized/label_17.nii.gz',\n", 149 | " './cmr_MR_normalized/label_18.nii.gz',\n", 150 | " './cmr_MR_normalized/label_19.nii.gz',\n", 151 | " './cmr_MR_normalized/label_20.nii.gz',\n", 152 | " './cmr_MR_normalized/label_21.nii.gz',\n", 153 | " './cmr_MR_normalized/label_22.nii.gz',\n", 154 | " './cmr_MR_normalized/label_23.nii.gz',\n", 155 | " './cmr_MR_normalized/label_24.nii.gz',\n", 156 | " './cmr_MR_normalized/label_25.nii.gz',\n", 157 | " './cmr_MR_normalized/label_26.nii.gz',\n", 158 | " './cmr_MR_normalized/label_27.nii.gz',\n", 159 | " './cmr_MR_normalized/label_28.nii.gz',\n", 160 | " './cmr_MR_normalized/label_29.nii.gz',\n", 161 | " './cmr_MR_normalized/label_30.nii.gz',\n", 162 | " './cmr_MR_normalized/label_31.nii.gz',\n", 163 | " './cmr_MR_normalized/label_32.nii.gz',\n", 164 | " './cmr_MR_normalized/label_33.nii.gz',\n", 165 | " './cmr_MR_normalized/label_34.nii.gz',\n", 166 | " './cmr_MR_normalized/label_35.nii.gz']" 167 | ] 168 | }, 169 | "execution_count": 12, 170 | "metadata": {}, 171 | "output_type": "execute_result" 172 | } 173 | ], 174 | "source": [ 175 | "segs" 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": 13, 181 | "metadata": {}, 182 | "outputs": [ 183 | { 184 | "name": "stdout", 185 | "output_type": "stream", 186 | "text": [ 187 | "pid 1 finished!\n", 188 | "pid 2 finished!\n", 189 | "pid 3 finished!\n", 190 | "pid 4 finished!\n", 191 | "pid 5 finished!\n", 192 | "pid 6 finished!\n", 193 | "pid 7 finished!\n", 194 | "pid 8 finished!\n", 195 | "pid 9 finished!\n", 196 | "pid 10 finished!\n", 197 | "pid 11 finished!\n", 198 | "pid 12 finished!\n", 199 | "pid 13 finished!\n", 200 | "pid 14 finished!\n", 201 | "pid 15 finished!\n", 202 | "pid 16 finished!\n", 203 | "pid 17 finished!\n", 204 | "pid 18 finished!\n", 205 | "pid 19 finished!\n", 206 | "pid 20 finished!\n", 207 | "pid 21 finished!\n", 208 | "pid 22 finished!\n", 209 | "pid 23 finished!\n", 210 | "pid 24 finished!\n", 211 | "pid 25 finished!\n", 212 | "pid 26 finished!\n", 213 | "pid 27 finished!\n", 214 | "pid 28 finished!\n", 215 | "pid 29 finished!\n", 216 | "pid 30 finished!\n", 217 | "pid 31 finished!\n", 218 | "pid 32 finished!\n", 219 | "pid 33 finished!\n", 220 | "pid 34 finished!\n", 221 | "pid 35 finished!\n" 222 | ] 223 | }, 224 | { 225 | "ename": "FileNotFoundError", 226 | "evalue": "[Errno 2] No such file or directory: './chaos_MR_T2_normalized/classmap_1.json'", 227 | "output_type": "error", 228 | "traceback": [ 229 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 230 | "\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)", 231 | "\u001b[0;32m/tmp/ipykernel_1065506/1189938079.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 26\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf'pid {str(pid)} 
finished!'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 27\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 28\u001b[0;31m \u001b[0;32mwith\u001b[0m \u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfid\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'w'\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mfopen\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 29\u001b[0m \u001b[0mjson\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdump\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mclassmap\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfopen\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 30\u001b[0m \u001b[0mfopen\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclose\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 232 | "\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: './chaos_MR_T2_normalized/classmap_1.json'" 233 | ] 234 | } 235 | ], 236 | "source": [ 237 | "classmap = {}\n", 238 | "LABEL_NAME = [\"BG\", \"LV-MYO\", \"LV-BP\", \"RV\"] \n", 239 | "\n", 240 | "\n", 241 | "MIN_TP = 1 # minimum number of positive label pixels to be recorded. Use >100 when training with manual annotations for more stable training\n", 242 | "\n", 243 | "fid = f'./chaos_MR_T2_normalized/classmap_{MIN_TP}.json' # name of the output file. \n", 244 | "for _lb in LABEL_NAME:\n", 245 | " classmap[_lb] = {}\n", 246 | " for _sid in segs:\n", 247 | " pid = _sid.split(\"_\")[-1].split(\".nii.gz\")[0]\n", 248 | " classmap[_lb][pid] = []\n", 249 | "\n", 250 | "for seg in segs:\n", 251 | " pid = seg.split(\"_\")[-1].split(\".nii.gz\")[0]\n", 252 | " lb_vol = nio.read_nii_bysitk(seg)\n", 253 | " n_slice = lb_vol.shape[0]\n", 254 | " lb_vol[lb_vol == 200] = 1\n", 255 | " lb_vol[lb_vol == 500] = 2\n", 256 | " lb_vol[lb_vol == 600] = 3\n", 257 | " for slc in range(n_slice):\n", 258 | " for cls in range(len(LABEL_NAME)):\n", 259 | " if cls in lb_vol[slc, ...]:\n", 260 | " if np.sum( lb_vol[slc, ...]) >= MIN_TP:\n", 261 | " classmap[LABEL_NAME[cls]][str(pid)].append(slc)\n", 262 | " print(f'pid {str(pid)} finished!')\n", 263 | " \n", 264 | "with open(fid, 'w') as fopen:\n", 265 | " json.dump(classmap, fopen)\n", 266 | " fopen.close() \n", 267 | " " 268 | ] 269 | }, 270 | { 271 | "cell_type": "code", 272 | "execution_count": 9, 273 | "metadata": {}, 274 | "outputs": [ 275 | { 276 | "ename": "FileNotFoundError", 277 | "evalue": "[Errno 2] No such file or directory: './chaos_MR_T2_normalized/classmap_1.json'", 278 | "output_type": "error", 279 | "traceback": [ 280 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 281 | "\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)", 282 | "\u001b[0;32m/tmp/ipykernel_1045184/825143362.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mwith\u001b[0m \u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfid\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'w'\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mfopen\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0mjson\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdump\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mclassmap\u001b[0m\u001b[0;34m,\u001b[0m 
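The cell above fails with the `FileNotFoundError` shown because `fid` still points at the abdominal output folder (`./chaos_MR_T2_normalized/`), while this cardiac notebook reads its data from `./cmr_MR_normalized/` (see `IMG_BNAME`/`SEG_BNAME`). The `MIN_TP` check is also applied to `np.sum` of the whole slice rather than to the pixel count of the class being indexed. A minimal corrected sketch, assuming `segs` and `nio` from the cells above; the `./cmr_MR_normalized` output directory is an inference from `IMG_BNAME`, not a confirmed repository convention:

```python
import os
import json
import numpy as np

LABEL_NAME = ["BG", "LV-MYO", "LV-BP", "RV"]
MIN_TP = 1  # minimum number of positive pixels for a slice to be recorded

out_dir = './cmr_MR_normalized'  # assumption: same folder the images are read from
fid = os.path.join(out_dir, f'classmap_{MIN_TP}.json')

classmap = {lb: {} for lb in LABEL_NAME}

for seg in segs:  # `segs` as built earlier in this notebook
    pid = seg.split("_")[-1].split(".nii.gz")[0]
    lb_vol = nio.read_nii_bysitk(seg)
    # remap MS-CMRSeg label values to contiguous class ids
    lb_vol[lb_vol == 200] = 1  # LV myocardium
    lb_vol[lb_vol == 500] = 2  # LV blood pool
    lb_vol[lb_vol == 600] = 3  # RV blood pool
    for cls, name in enumerate(LABEL_NAME):
        classmap[name][pid] = [
            slc for slc in range(lb_vol.shape[0])
            # count pixels of *this* class, not the sum of all label values
            if np.sum(lb_vol[slc, ...] == cls) >= MIN_TP
        ]
    print(f'pid {pid} finished!')

os.makedirs(out_dir, exist_ok=True)  # avoids the FileNotFoundError on write
with open(fid, 'w') as fopen:
    json.dump(classmap, fopen)
```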
\u001b[0mfopen\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mfopen\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclose\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 283 | "\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: './chaos_MR_T2_normalized/classmap_1.json'" 284 | ] 285 | } 286 | ], 287 | "source": [ 288 | "with open(fid, 'w') as fopen:\n", 289 | " json.dump(classmap, fopen)\n", 290 | " fopen.close()" 291 | ] 292 | }, 293 | { 294 | "cell_type": "code", 295 | "execution_count": null, 296 | "metadata": {}, 297 | "outputs": [], 298 | "source": [] 299 | } 300 | ], 301 | "metadata": { 302 | "kernelspec": { 303 | "display_name": "Python 3 (ipykernel)", 304 | "language": "python", 305 | "name": "python3" 306 | }, 307 | "language_info": { 308 | "codemirror_mode": { 309 | "name": "ipython", 310 | "version": 3 311 | }, 312 | "file_extension": ".py", 313 | "mimetype": "text/x-python", 314 | "name": "python", 315 | "nbconvert_exporter": "python", 316 | "pygments_lexer": "ipython3", 317 | "version": "3.8.12" 318 | } 319 | }, 320 | "nbformat": 4, 321 | "nbformat_minor": 2 322 | } 323 | -------------------------------------------------------------------------------- /data/Cardiac/bSSFP/class_slice_index_gen.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Generate class-slice indexing table for experiments\n", 8 | "\n", 9 | "\n", 10 | "### Overview\n", 11 | "\n", 12 | "This is for experiment setting up for simulating few-shot image segmentation scenarios\n", 13 | "\n", 14 | "Input: pre-processed images and their ground-truth labels\n", 15 | "\n", 16 | "Output: a `json` file for class-slice indexing" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 8, 22 | "metadata": {}, 23 | "outputs": [ 24 | { 25 | "name": "stdout", 26 | "output_type": "stream", 27 | "text": [ 28 | "Once deleted, variables cannot be recovered. Proceed (y/[n])? y\n", 29 | "The autoreload extension is already loaded. 
To reload it, use:\n", 30 | " %reload_ext autoreload\n" 31 | ] 32 | } 33 | ], 34 | "source": [ 35 | "%reset\n", 36 | "%load_ext autoreload\n", 37 | "%autoreload 2\n", 38 | "import numpy as np\n", 39 | "import os\n", 40 | "import glob\n", 41 | "import SimpleITK as sitk\n", 42 | "import sys\n", 43 | "import json\n", 44 | "sys.path.insert(0, '../../dataloaders/')\n", 45 | "import niftiio as nio" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 9, 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "IMG_BNAME=\"./cmr_MR_normalized/image_*.nii.gz\"\n", 55 | "SEG_BNAME=\"./cmr_MR_normalized/label_*.nii.gz\"" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 10, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "imgs = glob.glob(IMG_BNAME)\n", 65 | "segs = glob.glob(SEG_BNAME)\n", 66 | "imgs = [ fid for fid in sorted(imgs, key = lambda x: int(x.split(\"_\")[-1].split(\".nii.gz\")[0]) ) ]\n", 67 | "segs = [ fid for fid in sorted(segs, key = lambda x: int(x.split(\"_\")[-1].split(\".nii.gz\")[0]) ) ]\n" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 11, 73 | "metadata": {}, 74 | "outputs": [ 75 | { 76 | "data": { 77 | "text/plain": [ 78 | "['./cmr_MR_normalized/image_1.nii.gz',\n", 79 | " './cmr_MR_normalized/image_2.nii.gz',\n", 80 | " './cmr_MR_normalized/image_3.nii.gz',\n", 81 | " './cmr_MR_normalized/image_4.nii.gz',\n", 82 | " './cmr_MR_normalized/image_5.nii.gz',\n", 83 | " './cmr_MR_normalized/image_6.nii.gz',\n", 84 | " './cmr_MR_normalized/image_7.nii.gz',\n", 85 | " './cmr_MR_normalized/image_8.nii.gz',\n", 86 | " './cmr_MR_normalized/image_9.nii.gz',\n", 87 | " './cmr_MR_normalized/image_10.nii.gz',\n", 88 | " './cmr_MR_normalized/image_11.nii.gz',\n", 89 | " './cmr_MR_normalized/image_12.nii.gz',\n", 90 | " './cmr_MR_normalized/image_13.nii.gz',\n", 91 | " './cmr_MR_normalized/image_14.nii.gz',\n", 92 | " './cmr_MR_normalized/image_15.nii.gz',\n", 93 | " './cmr_MR_normalized/image_16.nii.gz',\n", 94 | " './cmr_MR_normalized/image_17.nii.gz',\n", 95 | " './cmr_MR_normalized/image_18.nii.gz',\n", 96 | " './cmr_MR_normalized/image_19.nii.gz',\n", 97 | " './cmr_MR_normalized/image_20.nii.gz',\n", 98 | " './cmr_MR_normalized/image_21.nii.gz',\n", 99 | " './cmr_MR_normalized/image_22.nii.gz',\n", 100 | " './cmr_MR_normalized/image_23.nii.gz',\n", 101 | " './cmr_MR_normalized/image_24.nii.gz',\n", 102 | " './cmr_MR_normalized/image_25.nii.gz',\n", 103 | " './cmr_MR_normalized/image_26.nii.gz',\n", 104 | " './cmr_MR_normalized/image_27.nii.gz',\n", 105 | " './cmr_MR_normalized/image_28.nii.gz',\n", 106 | " './cmr_MR_normalized/image_29.nii.gz',\n", 107 | " './cmr_MR_normalized/image_30.nii.gz',\n", 108 | " './cmr_MR_normalized/image_31.nii.gz',\n", 109 | " './cmr_MR_normalized/image_32.nii.gz',\n", 110 | " './cmr_MR_normalized/image_33.nii.gz',\n", 111 | " './cmr_MR_normalized/image_34.nii.gz',\n", 112 | " './cmr_MR_normalized/image_35.nii.gz']" 113 | ] 114 | }, 115 | "execution_count": 11, 116 | "metadata": {}, 117 | "output_type": "execute_result" 118 | } 119 | ], 120 | "source": [ 121 | "imgs" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": 12, 127 | "metadata": {}, 128 | "outputs": [ 129 | { 130 | "data": { 131 | "text/plain": [ 132 | "['./cmr_MR_normalized/label_1.nii.gz',\n", 133 | " './cmr_MR_normalized/label_2.nii.gz',\n", 134 | " './cmr_MR_normalized/label_3.nii.gz',\n", 135 | " './cmr_MR_normalized/label_4.nii.gz',\n", 136 | " 
'./cmr_MR_normalized/label_5.nii.gz',\n", 137 | " './cmr_MR_normalized/label_6.nii.gz',\n", 138 | " './cmr_MR_normalized/label_7.nii.gz',\n", 139 | " './cmr_MR_normalized/label_8.nii.gz',\n", 140 | " './cmr_MR_normalized/label_9.nii.gz',\n", 141 | " './cmr_MR_normalized/label_10.nii.gz',\n", 142 | " './cmr_MR_normalized/label_11.nii.gz',\n", 143 | " './cmr_MR_normalized/label_12.nii.gz',\n", 144 | " './cmr_MR_normalized/label_13.nii.gz',\n", 145 | " './cmr_MR_normalized/label_14.nii.gz',\n", 146 | " './cmr_MR_normalized/label_15.nii.gz',\n", 147 | " './cmr_MR_normalized/label_16.nii.gz',\n", 148 | " './cmr_MR_normalized/label_17.nii.gz',\n", 149 | " './cmr_MR_normalized/label_18.nii.gz',\n", 150 | " './cmr_MR_normalized/label_19.nii.gz',\n", 151 | " './cmr_MR_normalized/label_20.nii.gz',\n", 152 | " './cmr_MR_normalized/label_21.nii.gz',\n", 153 | " './cmr_MR_normalized/label_22.nii.gz',\n", 154 | " './cmr_MR_normalized/label_23.nii.gz',\n", 155 | " './cmr_MR_normalized/label_24.nii.gz',\n", 156 | " './cmr_MR_normalized/label_25.nii.gz',\n", 157 | " './cmr_MR_normalized/label_26.nii.gz',\n", 158 | " './cmr_MR_normalized/label_27.nii.gz',\n", 159 | " './cmr_MR_normalized/label_28.nii.gz',\n", 160 | " './cmr_MR_normalized/label_29.nii.gz',\n", 161 | " './cmr_MR_normalized/label_30.nii.gz',\n", 162 | " './cmr_MR_normalized/label_31.nii.gz',\n", 163 | " './cmr_MR_normalized/label_32.nii.gz',\n", 164 | " './cmr_MR_normalized/label_33.nii.gz',\n", 165 | " './cmr_MR_normalized/label_34.nii.gz',\n", 166 | " './cmr_MR_normalized/label_35.nii.gz']" 167 | ] 168 | }, 169 | "execution_count": 12, 170 | "metadata": {}, 171 | "output_type": "execute_result" 172 | } 173 | ], 174 | "source": [ 175 | "segs" 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": 13, 181 | "metadata": {}, 182 | "outputs": [ 183 | { 184 | "name": "stdout", 185 | "output_type": "stream", 186 | "text": [ 187 | "pid 1 finished!\n", 188 | "pid 2 finished!\n", 189 | "pid 3 finished!\n", 190 | "pid 4 finished!\n", 191 | "pid 5 finished!\n", 192 | "pid 6 finished!\n", 193 | "pid 7 finished!\n", 194 | "pid 8 finished!\n", 195 | "pid 9 finished!\n", 196 | "pid 10 finished!\n", 197 | "pid 11 finished!\n", 198 | "pid 12 finished!\n", 199 | "pid 13 finished!\n", 200 | "pid 14 finished!\n", 201 | "pid 15 finished!\n", 202 | "pid 16 finished!\n", 203 | "pid 17 finished!\n", 204 | "pid 18 finished!\n", 205 | "pid 19 finished!\n", 206 | "pid 20 finished!\n", 207 | "pid 21 finished!\n", 208 | "pid 22 finished!\n", 209 | "pid 23 finished!\n", 210 | "pid 24 finished!\n", 211 | "pid 25 finished!\n", 212 | "pid 26 finished!\n", 213 | "pid 27 finished!\n", 214 | "pid 28 finished!\n", 215 | "pid 29 finished!\n", 216 | "pid 30 finished!\n", 217 | "pid 31 finished!\n", 218 | "pid 32 finished!\n", 219 | "pid 33 finished!\n", 220 | "pid 34 finished!\n", 221 | "pid 35 finished!\n" 222 | ] 223 | }, 224 | { 225 | "ename": "FileNotFoundError", 226 | "evalue": "[Errno 2] No such file or directory: './chaos_MR_T2_normalized/classmap_1.json'", 227 | "output_type": "error", 228 | "traceback": [ 229 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 230 | "\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)", 231 | "\u001b[0;32m/tmp/ipykernel_1065506/1189938079.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 26\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf'pid {str(pid)} 
finished!'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 27\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 28\u001b[0;31m \u001b[0;32mwith\u001b[0m \u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfid\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'w'\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mfopen\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 29\u001b[0m \u001b[0mjson\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdump\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mclassmap\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfopen\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 30\u001b[0m \u001b[0mfopen\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclose\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 232 | "\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: './chaos_MR_T2_normalized/classmap_1.json'" 233 | ] 234 | } 235 | ], 236 | "source": [ 237 | "classmap = {}\n", 238 | "LABEL_NAME = [\"BG\", \"LV-MYO\", \"LV-BP\", \"RV\"] \n", 239 | "\n", 240 | "\n", 241 | "MIN_TP = 1 # minimum number of positive label pixels to be recorded. Use >100 when training with manual annotations for more stable training\n", 242 | "\n", 243 | "fid = f'./chaos_MR_T2_normalized/classmap_{MIN_TP}.json' # name of the output file. \n", 244 | "for _lb in LABEL_NAME:\n", 245 | " classmap[_lb] = {}\n", 246 | " for _sid in segs:\n", 247 | " pid = _sid.split(\"_\")[-1].split(\".nii.gz\")[0]\n", 248 | " classmap[_lb][pid] = []\n", 249 | "\n", 250 | "for seg in segs:\n", 251 | " pid = seg.split(\"_\")[-1].split(\".nii.gz\")[0]\n", 252 | " lb_vol = nio.read_nii_bysitk(seg)\n", 253 | " n_slice = lb_vol.shape[0]\n", 254 | " lb_vol[lb_vol == 200] = 1\n", 255 | " lb_vol[lb_vol == 500] = 2\n", 256 | " lb_vol[lb_vol == 600] = 3\n", 257 | " for slc in range(n_slice):\n", 258 | " for cls in range(len(LABEL_NAME)):\n", 259 | " if cls in lb_vol[slc, ...]:\n", 260 | " if np.sum( lb_vol[slc, ...]) >= MIN_TP:\n", 261 | " classmap[LABEL_NAME[cls]][str(pid)].append(slc)\n", 262 | " print(f'pid {str(pid)} finished!')\n", 263 | " \n", 264 | "with open(fid, 'w') as fopen:\n", 265 | " json.dump(classmap, fopen)\n", 266 | " fopen.close() \n", 267 | " " 268 | ] 269 | }, 270 | { 271 | "cell_type": "code", 272 | "execution_count": 9, 273 | "metadata": {}, 274 | "outputs": [ 275 | { 276 | "ename": "FileNotFoundError", 277 | "evalue": "[Errno 2] No such file or directory: './chaos_MR_T2_normalized/classmap_1.json'", 278 | "output_type": "error", 279 | "traceback": [ 280 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 281 | "\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)", 282 | "\u001b[0;32m/tmp/ipykernel_1045184/825143362.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mwith\u001b[0m \u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfid\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'w'\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mfopen\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0mjson\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdump\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mclassmap\u001b[0m\u001b[0;34m,\u001b[0m 
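This retry cell fails for the same reason: the target folder was never created. Creating the parent directory before writing avoids the `FileNotFoundError`; a minimal guard, assuming `fid` and `classmap` from the cell above (here, too, the folder should arguably be `./cmr_MR_normalized/` to match `IMG_BNAME`). Note also that an explicit `fopen.close()` inside a `with` block is redundant:

```python
import json
import os

os.makedirs(os.path.dirname(fid) or '.', exist_ok=True)  # create the target folder if missing
with open(fid, 'w') as fopen:
    json.dump(classmap, fopen)  # the with-block closes the file automatically
```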
\u001b[0mfopen\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mfopen\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclose\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 283 | "\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: './chaos_MR_T2_normalized/classmap_1.json'" 284 | ] 285 | } 286 | ], 287 | "source": [ 288 | "with open(fid, 'w') as fopen:\n", 289 | " json.dump(classmap, fopen)\n", 290 | " fopen.close()" 291 | ] 292 | }, 293 | { 294 | "cell_type": "code", 295 | "execution_count": null, 296 | "metadata": {}, 297 | "outputs": [], 298 | "source": [] 299 | } 300 | ], 301 | "metadata": { 302 | "kernelspec": { 303 | "display_name": "Python 3 (ipykernel)", 304 | "language": "python", 305 | "name": "python3" 306 | }, 307 | "language_info": { 308 | "codemirror_mode": { 309 | "name": "ipython", 310 | "version": 3 311 | }, 312 | "file_extension": ".py", 313 | "mimetype": "text/x-python", 314 | "name": "python", 315 | "nbconvert_exporter": "python", 316 | "pygments_lexer": "ipython3", 317 | "version": "3.8.12" 318 | } 319 | }, 320 | "nbformat": 4, 321 | "nbformat_minor": 2 322 | } 323 | -------------------------------------------------------------------------------- /data/Prostate/NCI/class_slice_index_gen.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Generate class-slice indexing table for experiments\n", 8 | "\n", 9 | "\n", 10 | "### Overview\n", 11 | "\n", 12 | "This is for experiment setting up for simulating few-shot image segmentation scenarios\n", 13 | "\n", 14 | "Input: pre-processed images and their ground-truth labels\n", 15 | "\n", 16 | "Output: a `json` file for class-slice indexing" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 8, 22 | "metadata": {}, 23 | "outputs": [ 24 | { 25 | "name": "stdout", 26 | "output_type": "stream", 27 | "text": [ 28 | "Once deleted, variables cannot be recovered. Proceed (y/[n])? y\n", 29 | "The autoreload extension is already loaded. 
To reload it, use:\n", 30 | " %reload_ext autoreload\n" 31 | ] 32 | } 33 | ], 34 | "source": [ 35 | "%reset\n", 36 | "%load_ext autoreload\n", 37 | "%autoreload 2\n", 38 | "import numpy as np\n", 39 | "import os\n", 40 | "import glob\n", 41 | "import SimpleITK as sitk\n", 42 | "import sys\n", 43 | "import json\n", 44 | "sys.path.insert(0, '../../dataloaders/')\n", 45 | "import niftiio as nio" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 9, 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "IMG_BNAME=\"./chaos_MR_T2_normalized/image_*.nii.gz\"\n", 55 | "SEG_BNAME=\"./chaos_MR_T2_normalized/label_*.nii.gz\"" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 10, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "imgs = glob.glob(IMG_BNAME)\n", 65 | "segs = glob.glob(SEG_BNAME)\n", 66 | "imgs = [ fid for fid in sorted(imgs, key = lambda x: int(x.split(\"_\")[-1].split(\".nii.gz\")[0]) ) ]\n", 67 | "segs = [ fid for fid in sorted(segs, key = lambda x: int(x.split(\"_\")[-1].split(\".nii.gz\")[0]) ) ]\n" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 11, 73 | "metadata": {}, 74 | "outputs": [ 75 | { 76 | "data": { 77 | "text/plain": [ 78 | "['./chaos_MR_T2_normalized/image_1.nii.gz',\n", 79 | " './chaos_MR_T2_normalized/image_2.nii.gz',\n", 80 | " './chaos_MR_T2_normalized/image_3.nii.gz',\n", 81 | " './chaos_MR_T2_normalized/image_5.nii.gz',\n", 82 | " './chaos_MR_T2_normalized/image_8.nii.gz',\n", 83 | " './chaos_MR_T2_normalized/image_10.nii.gz',\n", 84 | " './chaos_MR_T2_normalized/image_13.nii.gz',\n", 85 | " './chaos_MR_T2_normalized/image_15.nii.gz',\n", 86 | " './chaos_MR_T2_normalized/image_19.nii.gz',\n", 87 | " './chaos_MR_T2_normalized/image_20.nii.gz',\n", 88 | " './chaos_MR_T2_normalized/image_21.nii.gz',\n", 89 | " './chaos_MR_T2_normalized/image_22.nii.gz',\n", 90 | " './chaos_MR_T2_normalized/image_31.nii.gz',\n", 91 | " './chaos_MR_T2_normalized/image_32.nii.gz',\n", 92 | " './chaos_MR_T2_normalized/image_33.nii.gz',\n", 93 | " './chaos_MR_T2_normalized/image_34.nii.gz',\n", 94 | " './chaos_MR_T2_normalized/image_36.nii.gz',\n", 95 | " './chaos_MR_T2_normalized/image_37.nii.gz',\n", 96 | " './chaos_MR_T2_normalized/image_38.nii.gz',\n", 97 | " './chaos_MR_T2_normalized/image_39.nii.gz']" 98 | ] 99 | }, 100 | "execution_count": 11, 101 | "metadata": {}, 102 | "output_type": "execute_result" 103 | } 104 | ], 105 | "source": [ 106 | "imgs" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": 12, 112 | "metadata": {}, 113 | "outputs": [ 114 | { 115 | "data": { 116 | "text/plain": [ 117 | "['./chaos_MR_T2_normalized/label_1.nii.gz',\n", 118 | " './chaos_MR_T2_normalized/label_2.nii.gz',\n", 119 | " './chaos_MR_T2_normalized/label_3.nii.gz',\n", 120 | " './chaos_MR_T2_normalized/label_5.nii.gz',\n", 121 | " './chaos_MR_T2_normalized/label_8.nii.gz',\n", 122 | " './chaos_MR_T2_normalized/label_10.nii.gz',\n", 123 | " './chaos_MR_T2_normalized/label_13.nii.gz',\n", 124 | " './chaos_MR_T2_normalized/label_15.nii.gz',\n", 125 | " './chaos_MR_T2_normalized/label_19.nii.gz',\n", 126 | " './chaos_MR_T2_normalized/label_20.nii.gz',\n", 127 | " './chaos_MR_T2_normalized/label_21.nii.gz',\n", 128 | " './chaos_MR_T2_normalized/label_22.nii.gz',\n", 129 | " './chaos_MR_T2_normalized/label_31.nii.gz',\n", 130 | " './chaos_MR_T2_normalized/label_32.nii.gz',\n", 131 | " './chaos_MR_T2_normalized/label_33.nii.gz',\n", 132 | " './chaos_MR_T2_normalized/label_34.nii.gz',\n", 133 
| " './chaos_MR_T2_normalized/label_36.nii.gz',\n", 134 | " './chaos_MR_T2_normalized/label_37.nii.gz',\n", 135 | " './chaos_MR_T2_normalized/label_38.nii.gz',\n", 136 | " './chaos_MR_T2_normalized/label_39.nii.gz']" 137 | ] 138 | }, 139 | "execution_count": 12, 140 | "metadata": {}, 141 | "output_type": "execute_result" 142 | } 143 | ], 144 | "source": [ 145 | "segs" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": 13, 151 | "metadata": {}, 152 | "outputs": [ 153 | { 154 | "name": "stdout", 155 | "output_type": "stream", 156 | "text": [ 157 | "pid 1 finished!\n", 158 | "pid 2 finished!\n", 159 | "pid 3 finished!\n", 160 | "pid 5 finished!\n", 161 | "pid 8 finished!\n", 162 | "pid 10 finished!\n", 163 | "pid 13 finished!\n", 164 | "pid 15 finished!\n", 165 | "pid 19 finished!\n", 166 | "pid 20 finished!\n", 167 | "pid 21 finished!\n", 168 | "pid 22 finished!\n", 169 | "pid 31 finished!\n", 170 | "pid 32 finished!\n", 171 | "pid 33 finished!\n", 172 | "pid 34 finished!\n", 173 | "pid 36 finished!\n", 174 | "pid 37 finished!\n", 175 | "pid 38 finished!\n", 176 | "pid 39 finished!\n" 177 | ] 178 | } 179 | ], 180 | "source": [ 181 | "classmap = {}\n", 182 | "LABEL_NAME = [\"BG\", \"LIVER\", \"RK\", \"LK\", \"SPLEEN\"] \n", 183 | "\n", 184 | "\n", 185 | "MIN_TP = 100 # minimum number of positive label pixels to be recorded. Use >100 when training with manual annotations for more stable training\n", 186 | "\n", 187 | "fid = f'./chaos_MR_T2_normalized/classmap_{MIN_TP}.json' # name of the output file. \n", 188 | "for _lb in LABEL_NAME:\n", 189 | " classmap[_lb] = {}\n", 190 | " for _sid in segs:\n", 191 | " pid = _sid.split(\"_\")[-1].split(\".nii.gz\")[0]\n", 192 | " classmap[_lb][pid] = []\n", 193 | "\n", 194 | "for seg in segs:\n", 195 | " pid = seg.split(\"_\")[-1].split(\".nii.gz\")[0]\n", 196 | " lb_vol = nio.read_nii_bysitk(seg)\n", 197 | " n_slice = lb_vol.shape[0]\n", 198 | " for slc in range(n_slice):\n", 199 | " for cls in range(len(LABEL_NAME)):\n", 200 | " if cls in lb_vol[slc, ...]:\n", 201 | " if np.sum( lb_vol[slc, ...]) >= MIN_TP:\n", 202 | " classmap[LABEL_NAME[cls]][str(pid)].append(slc)\n", 203 | " print(f'pid {str(pid)} finished!')\n", 204 | " \n", 205 | "with open(fid, 'w') as fopen:\n", 206 | " json.dump(classmap, fopen)\n", 207 | " fopen.close() \n", 208 | " " 209 | ] 210 | }, 211 | { 212 | "cell_type": "code", 213 | "execution_count": 14, 214 | "metadata": {}, 215 | "outputs": [], 216 | "source": [ 217 | "with open(fid, 'w') as fopen:\n", 218 | " json.dump(classmap, fopen)\n", 219 | " fopen.close()" 220 | ] 221 | }, 222 | { 223 | "cell_type": "code", 224 | "execution_count": null, 225 | "metadata": {}, 226 | "outputs": [], 227 | "source": [] 228 | } 229 | ], 230 | "metadata": { 231 | "kernelspec": { 232 | "display_name": "Python 3 (ipykernel)", 233 | "language": "python", 234 | "name": "python3" 235 | }, 236 | "language_info": { 237 | "codemirror_mode": { 238 | "name": "ipython", 239 | "version": 3 240 | }, 241 | "file_extension": ".py", 242 | "mimetype": "text/x-python", 243 | "name": "python", 244 | "nbconvert_exporter": "python", 245 | "pygments_lexer": "ipython3", 246 | "version": "3.8.12" 247 | } 248 | }, 249 | "nbformat": 4, 250 | "nbformat_minor": 2 251 | } 252 | -------------------------------------------------------------------------------- /data/Prostate/NCI/dcm_img_to_nii.sh: -------------------------------------------------------------------------------- 1 | #!bin/bash 2 | # Convert dicom-like images to nii files in 3D 3 | # This is 
the first step for image pre-processing 4 | 5 | # Feed path to the downloaded data here 6 | DATAPATH=./MR # please put chaos dataset training fold here which contains ground truth 7 | 8 | # Feed path to the output folder here 9 | OUTPATH=./niis 10 | 11 | if [ ! -d $OUTPATH/T2SPIR ] 12 | then 13 | mkdir $OUTPATH/T2SPIR 14 | fi 15 | 16 | for sid in $(ls "$DATAPATH") 17 | do 18 | dcm2nii -o "$DATAPATH/$sid/T2SPIR" "$DATAPATH/$sid/T2SPIR/DICOM_anon"; 19 | find "$DATAPATH/$sid/T2SPIR" -name "*.nii.gz" -exec mv {} "$OUTPATH/T2SPIR/image_$sid.nii.gz" \; 20 | done; 21 | 22 | 23 | -------------------------------------------------------------------------------- /data/Prostate/NCI/image_normalize.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Image Pre-processing\n", 8 | "\n", 9 | "### Overview\n", 10 | "\n", 11 | "This is the second step for data preparation\n", 12 | "\n", 13 | "Input: `.nii`-like images and labels converted from `dicom`s/ `png` files\n", 14 | "\n", 15 | "Output: image-labels with unified size (axial), voxel-spacing, and alleviated off-resonance effects" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": 1, 21 | "metadata": {}, 22 | "outputs": [ 23 | { 24 | "name": "stdout", 25 | "output_type": "stream", 26 | "text": [ 27 | "Once deleted, variables cannot be recovered. Proceed (y/[n])? y\n" 28 | ] 29 | } 30 | ], 31 | "source": [ 32 | "%reset\n", 33 | "%load_ext autoreload\n", 34 | "%autoreload 2\n", 35 | "import numpy as np\n", 36 | "import os\n", 37 | "import glob\n", 38 | "import SimpleITK as sitk\n", 39 | "\n", 40 | "import matplotlib.pyplot as plt\n", 41 | "import copy\n", 42 | "import sys\n", 43 | "sys.path.insert(0, '../../dataloaders/')\n", 44 | "import niftiio as nio" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 2, 50 | "metadata": {}, 51 | "outputs": [], 52 | "source": [ 53 | "IMG_FOLDER = \"./niis/T2SPIR\" #, path of nii-like images from step 1\n", 54 | "OUT_FOLDER=\"./chaos_MR_T2_normalized/\" # output directory" 55 | ] 56 | }, 57 | { 58 | "cell_type": "markdown", 59 | "metadata": {}, 60 | "source": [ 61 | "**0. 
Find images and their ground-truth segmentations**" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": 3, 67 | "metadata": {}, 68 | "outputs": [], 69 | "source": [ 70 | "imgs = glob.glob(IMG_FOLDER + f'/image_*.nii.gz')\n", 71 | "imgs = [ fid for fid in sorted(imgs) ]\n", 72 | "segs = [ fid for fid in sorted(glob.glob(IMG_FOLDER + f'/label_*.nii.gz')) ]\n", 73 | "\n", 74 | "pids = [pid.split(\"_\")[-1].split(\".\")[0] for pid in imgs]" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 4, 80 | "metadata": {}, 81 | "outputs": [ 82 | { 83 | "data": { 84 | "text/plain": [ 85 | "['./niis/T2SPIR/image_1.nii.gz',\n", 86 | " './niis/T2SPIR/image_10.nii.gz',\n", 87 | " './niis/T2SPIR/image_13.nii.gz',\n", 88 | " './niis/T2SPIR/image_15.nii.gz',\n", 89 | " './niis/T2SPIR/image_19.nii.gz',\n", 90 | " './niis/T2SPIR/image_2.nii.gz',\n", 91 | " './niis/T2SPIR/image_20.nii.gz',\n", 92 | " './niis/T2SPIR/image_21.nii.gz',\n", 93 | " './niis/T2SPIR/image_22.nii.gz',\n", 94 | " './niis/T2SPIR/image_3.nii.gz',\n", 95 | " './niis/T2SPIR/image_31.nii.gz',\n", 96 | " './niis/T2SPIR/image_32.nii.gz',\n", 97 | " './niis/T2SPIR/image_33.nii.gz',\n", 98 | " './niis/T2SPIR/image_34.nii.gz',\n", 99 | " './niis/T2SPIR/image_36.nii.gz',\n", 100 | " './niis/T2SPIR/image_37.nii.gz',\n", 101 | " './niis/T2SPIR/image_38.nii.gz',\n", 102 | " './niis/T2SPIR/image_39.nii.gz',\n", 103 | " './niis/T2SPIR/image_5.nii.gz',\n", 104 | " './niis/T2SPIR/image_8.nii.gz']" 105 | ] 106 | }, 107 | "execution_count": 4, 108 | "metadata": {}, 109 | "output_type": "execute_result" 110 | } 111 | ], 112 | "source": [ 113 | "imgs" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": 5, 119 | "metadata": {}, 120 | "outputs": [ 121 | { 122 | "data": { 123 | "text/plain": [ 124 | "['./niis/T2SPIR/label_1.nii.gz',\n", 125 | " './niis/T2SPIR/label_10.nii.gz',\n", 126 | " './niis/T2SPIR/label_13.nii.gz',\n", 127 | " './niis/T2SPIR/label_15.nii.gz',\n", 128 | " './niis/T2SPIR/label_19.nii.gz',\n", 129 | " './niis/T2SPIR/label_2.nii.gz',\n", 130 | " './niis/T2SPIR/label_20.nii.gz',\n", 131 | " './niis/T2SPIR/label_21.nii.gz',\n", 132 | " './niis/T2SPIR/label_22.nii.gz',\n", 133 | " './niis/T2SPIR/label_3.nii.gz',\n", 134 | " './niis/T2SPIR/label_31.nii.gz',\n", 135 | " './niis/T2SPIR/label_32.nii.gz',\n", 136 | " './niis/T2SPIR/label_33.nii.gz',\n", 137 | " './niis/T2SPIR/label_34.nii.gz',\n", 138 | " './niis/T2SPIR/label_36.nii.gz',\n", 139 | " './niis/T2SPIR/label_37.nii.gz',\n", 140 | " './niis/T2SPIR/label_38.nii.gz',\n", 141 | " './niis/T2SPIR/label_39.nii.gz',\n", 142 | " './niis/T2SPIR/label_5.nii.gz',\n", 143 | " './niis/T2SPIR/label_8.nii.gz']" 144 | ] 145 | }, 146 | "execution_count": 5, 147 | "metadata": {}, 148 | "output_type": "execute_result" 149 | } 150 | ], 151 | "source": [ 152 | "segs" 153 | ] 154 | }, 155 | { 156 | "cell_type": "markdown", 157 | "metadata": {}, 158 | "source": [ 159 | "**1. Unify image sizes and roi**\n", 160 | "\n", 161 | "a. Cut bright end of histogram to alleviate off-resonance issue\n", 162 | "\n", 163 | "b. Resample images to unified spacing\n", 164 | "\n", 165 | "c. 
Crop ROIs out to unify image sizes" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": 6, 171 | "metadata": {}, 172 | "outputs": [], 173 | "source": [ 174 | "# some helper functions\n", 175 | "def resample_by_res(mov_img_obj, new_spacing, interpolator = sitk.sitkLinear, logging = True):\n", 176 | " resample = sitk.ResampleImageFilter()\n", 177 | " resample.SetInterpolator(interpolator)\n", 178 | " resample.SetOutputDirection(mov_img_obj.GetDirection())\n", 179 | " resample.SetOutputOrigin(mov_img_obj.GetOrigin())\n", 180 | " mov_spacing = mov_img_obj.GetSpacing()\n", 181 | "\n", 182 | " resample.SetOutputSpacing(new_spacing)\n", 183 | " RES_COE = np.array(mov_spacing) * 1.0 / np.array(new_spacing)\n", 184 | " new_size = np.array(mov_img_obj.GetSize()) * RES_COE \n", 185 | "\n", 186 | " resample.SetSize( [int(sz+1) for sz in new_size] )\n", 187 | " if logging:\n", 188 | " print(\"Spacing: {} -> {}\".format(mov_spacing, new_spacing))\n", 189 | " print(\"Size {} -> {}\".format( mov_img_obj.GetSize(), new_size ))\n", 190 | "\n", 191 | " return resample.Execute(mov_img_obj)\n", 192 | "\n", 193 | "def resample_lb_by_res(mov_lb_obj, new_spacing, interpolator = sitk.sitkLinear, ref_img = None, logging = True):\n", 194 | " src_mat = sitk.GetArrayFromImage(mov_lb_obj)\n", 195 | " lbvs = np.unique(src_mat)\n", 196 | " if logging:\n", 197 | " print(\"Label values: {}\".format(lbvs))\n", 198 | " for idx, lbv in enumerate(lbvs):\n", 199 | " _src_curr_mat = np.float32(src_mat == lbv) \n", 200 | " _src_curr_obj = sitk.GetImageFromArray(_src_curr_mat)\n", 201 | " _src_curr_obj.CopyInformation(mov_lb_obj)\n", 202 | " _tar_curr_obj = resample_by_res( _src_curr_obj, new_spacing, interpolator, logging )\n", 203 | " _tar_curr_mat = np.rint(sitk.GetArrayFromImage(_tar_curr_obj)) * lbv\n", 204 | " if idx == 0:\n", 205 | " out_vol = _tar_curr_mat\n", 206 | " else:\n", 207 | " out_vol[_tar_curr_mat == lbv] = lbv\n", 208 | " out_obj = sitk.GetImageFromArray(out_vol)\n", 209 | " out_obj.SetSpacing( _tar_curr_obj.GetSpacing() )\n", 210 | " if ref_img != None:\n", 211 | " out_obj.CopyInformation(ref_img)\n", 212 | " return out_obj\n", 213 | " \n", 214 | "def get_label_center(label):\n", 215 | " nnz = np.sum(label > 1e-5)\n", 216 | " return np.int32(np.rint(np.sum(np.nonzero(label), axis = 1) * 1.0 / nnz))\n", 217 | "\n", 218 | "def image_crop(ori_vol, crop_size, referece_ctr_idx, padval = 0., only_2d = True):\n", 219 | " \"\"\" crop a 3d matrix given the index of the new volume on the original volume\n", 220 | " Args:\n", 221 | " refernce_ctr_idx: the center of the new volume on the original volume (in indices)\n", 222 | " only_2d: only do cropping on first two dimensions\n", 223 | " \"\"\"\n", 224 | " _expand_cropsize = [x + 1 for x in crop_size] # to deal with boundary case\n", 225 | " if only_2d:\n", 226 | " assert len(crop_size) == 2, \"Actual len {}\".format(len(crop_size))\n", 227 | " assert len(referece_ctr_idx) == 2, \"Actual len {}\".format(len(referece_ctr_idx))\n", 228 | " _expand_cropsize.append(ori_vol.shape[-1])\n", 229 | " \n", 230 | " image_patch = np.ones(tuple(_expand_cropsize)) * padval\n", 231 | "\n", 232 | " half_size = tuple( [int(x * 1.0 / 2) for x in _expand_cropsize] )\n", 233 | " _min_idx = [0,0,0]\n", 234 | " _max_idx = list(ori_vol.shape)\n", 235 | "\n", 236 | " # bias of actual cropped size to the beginning and the end of this volume\n", 237 | " _bias_start = [0,0,0]\n", 238 | " _bias_end = [0,0,0]\n", 239 | "\n", 240 | " for dim,hsize in enumerate(half_size):\n", 
241 | " if dim == 2 and only_2d:\n", 242 | " break\n", 243 | "\n", 244 | " _bias_start[dim] = np.min([hsize, referece_ctr_idx[dim]])\n", 245 | " _bias_end[dim] = np.min([hsize, ori_vol.shape[dim] - referece_ctr_idx[dim]])\n", 246 | "\n", 247 | " _min_idx[dim] = referece_ctr_idx[dim] - _bias_start[dim]\n", 248 | " _max_idx[dim] = referece_ctr_idx[dim] + _bias_end[dim]\n", 249 | " \n", 250 | " if only_2d:\n", 251 | " image_patch[ half_size[0] - _bias_start[0]: half_size[0] +_bias_end[0], \\\n", 252 | " half_size[1] - _bias_start[1]: half_size[1] +_bias_end[1], ... ] = \\\n", 253 | " ori_vol[ referece_ctr_idx[0] - _bias_start[0]: referece_ctr_idx[0] +_bias_end[0], \\\n", 254 | " referece_ctr_idx[1] - _bias_start[1]: referece_ctr_idx[1] +_bias_end[1], ... ]\n", 255 | "\n", 256 | " image_patch = image_patch[ 0: crop_size[0], 0: crop_size[1], : ]\n", 257 | " # then goes back to original volume\n", 258 | " else:\n", 259 | " image_patch[ half_size[0] - _bias_start[0]: half_size[0] +_bias_end[0], \\\n", 260 | " half_size[1] - _bias_start[1]: half_size[1] +_bias_end[1], \\\n", 261 | " half_size[2] - _bias_start[2]: half_size[2] +_bias_end[2] ] = \\\n", 262 | " ori_vol[ referece_ctr_idx[0] - _bias_start[0]: referece_ctr_idx[0] +_bias_end[0], \\\n", 263 | " referece_ctr_idx[1] - _bias_start[1]: referece_ctr_idx[1] +_bias_end[1], \\\n", 264 | " referece_ctr_idx[2] - _bias_start[2]: referece_ctr_idx[2] +_bias_end[2] ]\n", 265 | "\n", 266 | " image_patch = image_patch[ 0: crop_size[0], 0: crop_size[1], 0: crop_size[2] ]\n", 267 | " return image_patch\n", 268 | "\n", 269 | "def copy_spacing_ori(src, dst):\n", 270 | " dst.SetSpacing(src.GetSpacing())\n", 271 | " dst.SetOrigin(src.GetOrigin())\n", 272 | " dst.SetDirection(src.GetDirection())\n", 273 | " return dst\n", 274 | "\n", 275 | "s2n = sitk.GetArrayFromImage" 276 | ] 277 | }, 278 | { 279 | "cell_type": "code", 280 | "execution_count": 7, 281 | "metadata": { 282 | "scrolled": false 283 | }, 284 | "outputs": [ 285 | { 286 | "name": "stdout", 287 | "output_type": "stream", 288 | "text": [ 289 | "Failed to create the output folder.\n" 290 | ] 291 | }, 292 | { 293 | "ename": "NameError", 294 | "evalue": "name 'copy_spacing_ori' is not defined", 295 | "output_type": "error", 296 | "traceback": [ 297 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 298 | "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", 299 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[0mhis_img_o\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msitk\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mGetImageFromArray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 26\u001b[0;31m \u001b[0mhis_img_o\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcopy_spacing_ori\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimg_obj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhis_img_o\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 27\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[0;31m# resampling\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 300 | "\u001b[0;31mNameError\u001b[0m: name 'copy_spacing_ori' is not defined" 301 | ] 302 | } 303 | ], 304 | "source": [ 305 | "import copy\n", 306 | "try:\n", 307 | " os.mkdir(OUT_FOLDER)\n", 308 | "except:\n", 309 | " print(\"Failed to create the output folder.\")\n", 310 | " 
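The `NameError` above is an execution-order problem, not a code defect: this cell ran before the helper cell defining `copy_spacing_ori`, so running the notebook top to bottom resolves it. Separately, the bare `try/except` around `os.mkdir` swallows every failure (permissions, bad paths), not just "directory already exists". A safer equivalent, assuming `OUT_FOLDER` from the earlier cell:

```python
import os

# Idempotent: no-op when OUT_FOLDER exists, but genuine failures still raise.
os.makedirs(OUT_FOLDER, exist_ok=True)
```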
\n", 311 | "HIST_CUT_TOP = 0.5 # cut top 0.5% of intensity historgam to alleviate off-resonance effect\n", 312 | "\n", 313 | "NEW_SPA = [1.25, 1.25, 7.70] # unified voxel spacing\n", 314 | "\n", 315 | "for img_fid, seg_fid, pid in zip(imgs, segs, pids):\n", 316 | "\n", 317 | " lb_n = nio.read_nii_bysitk(seg_fid)\n", 318 | " resample_flg = True\n", 319 | "\n", 320 | " img_obj = sitk.ReadImage( img_fid )\n", 321 | " seg_obj = sitk.ReadImage( seg_fid )\n", 322 | "\n", 323 | " array = sitk.GetArrayFromImage(img_obj)\n", 324 | "\n", 325 | " # cut histogram\n", 326 | " hir = float(np.percentile(array, 100.0 - HIST_CUT_TOP))\n", 327 | " array[array > hir] = hir\n", 328 | "\n", 329 | " his_img_o = sitk.GetImageFromArray(array)\n", 330 | " his_img_o = copy_spacing_ori(img_obj, his_img_o)\n", 331 | "\n", 332 | " # resampling\n", 333 | " img_spa_ori = img_obj.GetSpacing()\n", 334 | " res_img_o = resample_by_res(his_img_o, [NEW_SPA[0], NEW_SPA[1], NEW_SPA[2]],\n", 335 | " interpolator = sitk.sitkLinear, logging = True)\n", 336 | "\n", 337 | "\n", 338 | "\n", 339 | " ## label\n", 340 | " lb_arr = sitk.GetArrayFromImage(seg_obj)\n", 341 | "\n", 342 | " # resampling\n", 343 | " res_lb_o = resample_lb_by_res(seg_obj, [NEW_SPA[0], NEW_SPA[1], NEW_SPA[2] ], interpolator = sitk.sitkLinear,\n", 344 | " ref_img = None, logging = True)\n", 345 | "\n", 346 | "\n", 347 | " # crop out rois\n", 348 | " res_img_a = s2n(res_img_o)\n", 349 | "\n", 350 | " crop_img_a = image_crop(res_img_a.transpose(1,2,0), [256, 256],\n", 351 | " referece_ctr_idx = [res_img_a.shape[1] // 2, res_img_a.shape[2] //2],\n", 352 | " padval = res_img_a.min(), only_2d = True).transpose(2,0,1)\n", 353 | "\n", 354 | " out_img_obj = copy_spacing_ori(res_img_o, sitk.GetImageFromArray(crop_img_a))\n", 355 | "\n", 356 | " res_lb_a = s2n(res_lb_o)\n", 357 | "\n", 358 | " crop_lb_a = image_crop(res_lb_a.transpose(1,2,0), [256, 256],\n", 359 | " referece_ctr_idx = [res_lb_a.shape[1] // 2, res_lb_a.shape[2] //2],\n", 360 | " padval = 0, only_2d = True).transpose(2,0,1)\n", 361 | "\n", 362 | " out_lb_obj = copy_spacing_ori(res_img_o, sitk.GetImageFromArray(crop_lb_a))\n", 363 | "\n", 364 | "\n", 365 | " out_img_fid = os.path.join( OUT_FOLDER, f'image_{pid}.nii.gz' )\n", 366 | " out_lb_fid = os.path.join( OUT_FOLDER, f'label_{pid}.nii.gz' ) \n", 367 | "\n", 368 | " # then save pre-processed images\n", 369 | " sitk.WriteImage(out_img_obj, out_img_fid, True) \n", 370 | " sitk.WriteImage(out_lb_obj, out_lb_fid, True) \n", 371 | " print(\"{} has been saved\".format(out_img_fid))" 372 | ] 373 | } 374 | ], 375 | "metadata": { 376 | "kernelspec": { 377 | "display_name": "Python 3", 378 | "language": "python", 379 | "name": "python3" 380 | }, 381 | "language_info": { 382 | "codemirror_mode": { 383 | "name": "ipython", 384 | "version": 3 385 | }, 386 | "file_extension": ".py", 387 | "mimetype": "text/x-python", 388 | "name": "python", 389 | "nbconvert_exporter": "python", 390 | "pygments_lexer": "ipython3", 391 | "version": "3.6.0" 392 | } 393 | }, 394 | "nbformat": 4, 395 | "nbformat_minor": 2 396 | } 397 | -------------------------------------------------------------------------------- /data/Prostate/NCI/png_gth_to_nii.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Converting labels from png to nii file\n", 8 | "\n", 9 | "\n", 10 | "### Overview\n", 11 | "\n", 12 | "This is the first step for data preparation\n", 13 | 
"\n", 14 | "Input: ground truth labels in `.png` format\n", 15 | "\n", 16 | "Output: labels in `.nii` format, indexed by patient id" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 13, 22 | "metadata": {}, 23 | "outputs": [ 24 | { 25 | "name": "stdout", 26 | "output_type": "stream", 27 | "text": [ 28 | "Once deleted, variables cannot be recovered. Proceed (y/[n])? y\n", 29 | "The autoreload extension is already loaded. To reload it, use:\n", 30 | " %reload_ext autoreload\n" 31 | ] 32 | } 33 | ], 34 | "source": [ 35 | "%reset\n", 36 | "%load_ext autoreload\n", 37 | "%autoreload 2\n", 38 | "import os\n", 39 | "import glob\n", 40 | "\n", 41 | "import numpy as np\n", 42 | "import PIL\n", 43 | "import matplotlib.pyplot as plt\n", 44 | "import SimpleITK as sitk\n", 45 | "import sys\n", 46 | "sys.path.insert(0, '../../dataloaders/')\n", 47 | "import niftiio as nio" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": null, 53 | "metadata": {}, 54 | "outputs": [], 55 | "source": [] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 14, 60 | "metadata": {}, 61 | "outputs": [], 62 | "source": [ 63 | "example = \"./MR/1/T2SPIR/Ground/IMG-0002-00001.png\" # example of ground-truth file name. " 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": 15, 69 | "metadata": {}, 70 | "outputs": [], 71 | "source": [ 72 | "### search for scan ids\n", 73 | "ids = os.listdir(\"./MR/\")\n", 74 | "OUT_DIR = './niis/T2SPIR/'" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 16, 80 | "metadata": {}, 81 | "outputs": [ 82 | { 83 | "data": { 84 | "text/plain": [ 85 | "['37',\n", 86 | " '3',\n", 87 | " '15',\n", 88 | " '34',\n", 89 | " '33',\n", 90 | " '39',\n", 91 | " '20',\n", 92 | " '10',\n", 93 | " '22',\n", 94 | " '8',\n", 95 | " '31',\n", 96 | " '2',\n", 97 | " '36',\n", 98 | " '5',\n", 99 | " '13',\n", 100 | " '19',\n", 101 | " '21',\n", 102 | " '1',\n", 103 | " '38',\n", 104 | " '32']" 105 | ] 106 | }, 107 | "execution_count": 16, 108 | "metadata": {}, 109 | "output_type": "execute_result" 110 | } 111 | ], 112 | "source": [ 113 | "ids" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": 17, 119 | "metadata": {}, 120 | "outputs": [ 121 | { 122 | "name": "stdout", 123 | "output_type": "stream", 124 | "text": [ 125 | "image with id 37 has been saved!\n", 126 | "image with id 3 has been saved!\n", 127 | "image with id 15 has been saved!\n", 128 | "image with id 34 has been saved!\n", 129 | "image with id 33 has been saved!\n", 130 | "image with id 39 has been saved!\n", 131 | "image with id 20 has been saved!\n", 132 | "image with id 10 has been saved!\n", 133 | "image with id 22 has been saved!\n", 134 | "image with id 8 has been saved!\n", 135 | "image with id 31 has been saved!\n", 136 | "image with id 2 has been saved!\n", 137 | "image with id 36 has been saved!\n", 138 | "image with id 5 has been saved!\n", 139 | "image with id 13 has been saved!\n", 140 | "image with id 19 has been saved!\n", 141 | "image with id 21 has been saved!\n", 142 | "image with id 1 has been saved!\n", 143 | "image with id 38 has been saved!\n", 144 | "image with id 32 has been saved!\n" 145 | ] 146 | } 147 | ], 148 | "source": [ 149 | "#### Write them to nii files for the ease of loading in future\n", 150 | "for curr_id in ids:\n", 151 | " pngs = glob.glob(f'./MR/{curr_id}/T2SPIR/Ground/*.png')\n", 152 | " pngs = sorted(pngs, key = lambda x: int(os.path.basename(x).split(\"-\")[-1].split(\".png\")[0]))\n", 153 | " 
buffer = []\n", 154 | "\n", 155 | " for fid in pngs:\n", 156 | " buffer.append(PIL.Image.open(fid))\n", 157 | "\n", 158 | " vol = np.stack(buffer, axis = 0)\n", 159 | " # flip correction\n", 160 | " vol = np.flip(vol, axis = 1).copy()\n", 161 | " # remap values\n", 162 | " for new_val, old_val in enumerate(sorted(np.unique(vol))):\n", 163 | " vol[vol == old_val] = new_val\n", 164 | "\n", 165 | " # get reference \n", 166 | " ref_img = f'./niis/T2SPIR/image_{curr_id}.nii.gz'\n", 167 | " img_o = sitk.ReadImage(ref_img)\n", 168 | " vol_o = nio.np2itk(img=vol, ref_obj=img_o)\n", 169 | " sitk.WriteImage(vol_o, f'{OUT_DIR}/label_{curr_id}.nii.gz')\n", 170 | " print(f'image with id {curr_id} has been saved!')\n", 171 | "\n", 172 | " " 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": null, 178 | "metadata": {}, 179 | "outputs": [], 180 | "source": [] 181 | } 182 | ], 183 | "metadata": { 184 | "kernelspec": { 185 | "display_name": "Python 3", 186 | "language": "python", 187 | "name": "python3" 188 | }, 189 | "language_info": { 190 | "codemirror_mode": { 191 | "name": "ipython", 192 | "version": 3 193 | }, 194 | "file_extension": ".py", 195 | "mimetype": "text/x-python", 196 | "name": "python", 197 | "nbconvert_exporter": "python", 198 | "pygments_lexer": "ipython3", 199 | "version": "3.6.0" 200 | } 201 | }, 202 | "nbformat": 4, 203 | "nbformat_minor": 2 204 | } 205 | -------------------------------------------------------------------------------- /data/Prostate/UCLH/class_slice_index_gen.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Generate class-slice indexing table for experiments\n", 8 | "\n", 9 | "\n", 10 | "### Overview\n", 11 | "\n", 12 | "This is for experiment setting up for simulating few-shot image segmentation scenarios\n", 13 | "\n", 14 | "Input: pre-processed images and their ground-truth labels\n", 15 | "\n", 16 | "Output: a `json` file for class-slice indexing" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 8, 22 | "metadata": {}, 23 | "outputs": [ 24 | { 25 | "name": "stdout", 26 | "output_type": "stream", 27 | "text": [ 28 | "Once deleted, variables cannot be recovered. Proceed (y/[n])? y\n", 29 | "The autoreload extension is already loaded. 
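The conversion loop above relies on `nio.np2itk` to stamp the reference scan's geometry onto the stacked label array before saving. Its presumable behavior is simply "wrap the array, copy the metadata"; a minimal sketch of that idea (the name `np_to_itk_like` is illustrative, not the repository's actual helper):

```python
import SimpleITK as sitk

def np_to_itk_like(arr, ref_obj):
    """Wrap `arr` as an ITK image carrying ref_obj's spacing, origin and direction."""
    out = sitk.GetImageFromArray(arr)
    out.CopyInformation(ref_obj)  # assumes arr matches ref_obj's dimensions
    return out
```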
To reload it, use:\n", 30 | " %reload_ext autoreload\n" 31 | ] 32 | } 33 | ], 34 | "source": [ 35 | "%reset\n", 36 | "%load_ext autoreload\n", 37 | "%autoreload 2\n", 38 | "import numpy as np\n", 39 | "import os\n", 40 | "import glob\n", 41 | "import SimpleITK as sitk\n", 42 | "import sys\n", 43 | "import json\n", 44 | "sys.path.insert(0, '../../dataloaders/')\n", 45 | "import niftiio as nio" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 9, 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "IMG_BNAME=\"./chaos_MR_T2_normalized/image_*.nii.gz\"\n", 55 | "SEG_BNAME=\"./chaos_MR_T2_normalized/label_*.nii.gz\"" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 10, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "imgs = glob.glob(IMG_BNAME)\n", 65 | "segs = glob.glob(SEG_BNAME)\n", 66 | "imgs = [ fid for fid in sorted(imgs, key = lambda x: int(x.split(\"_\")[-1].split(\".nii.gz\")[0]) ) ]\n", 67 | "segs = [ fid for fid in sorted(segs, key = lambda x: int(x.split(\"_\")[-1].split(\".nii.gz\")[0]) ) ]\n" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 11, 73 | "metadata": {}, 74 | "outputs": [ 75 | { 76 | "data": { 77 | "text/plain": [ 78 | "['./chaos_MR_T2_normalized/image_1.nii.gz',\n", 79 | " './chaos_MR_T2_normalized/image_2.nii.gz',\n", 80 | " './chaos_MR_T2_normalized/image_3.nii.gz',\n", 81 | " './chaos_MR_T2_normalized/image_5.nii.gz',\n", 82 | " './chaos_MR_T2_normalized/image_8.nii.gz',\n", 83 | " './chaos_MR_T2_normalized/image_10.nii.gz',\n", 84 | " './chaos_MR_T2_normalized/image_13.nii.gz',\n", 85 | " './chaos_MR_T2_normalized/image_15.nii.gz',\n", 86 | " './chaos_MR_T2_normalized/image_19.nii.gz',\n", 87 | " './chaos_MR_T2_normalized/image_20.nii.gz',\n", 88 | " './chaos_MR_T2_normalized/image_21.nii.gz',\n", 89 | " './chaos_MR_T2_normalized/image_22.nii.gz',\n", 90 | " './chaos_MR_T2_normalized/image_31.nii.gz',\n", 91 | " './chaos_MR_T2_normalized/image_32.nii.gz',\n", 92 | " './chaos_MR_T2_normalized/image_33.nii.gz',\n", 93 | " './chaos_MR_T2_normalized/image_34.nii.gz',\n", 94 | " './chaos_MR_T2_normalized/image_36.nii.gz',\n", 95 | " './chaos_MR_T2_normalized/image_37.nii.gz',\n", 96 | " './chaos_MR_T2_normalized/image_38.nii.gz',\n", 97 | " './chaos_MR_T2_normalized/image_39.nii.gz']" 98 | ] 99 | }, 100 | "execution_count": 11, 101 | "metadata": {}, 102 | "output_type": "execute_result" 103 | } 104 | ], 105 | "source": [ 106 | "imgs" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": 12, 112 | "metadata": {}, 113 | "outputs": [ 114 | { 115 | "data": { 116 | "text/plain": [ 117 | "['./chaos_MR_T2_normalized/label_1.nii.gz',\n", 118 | " './chaos_MR_T2_normalized/label_2.nii.gz',\n", 119 | " './chaos_MR_T2_normalized/label_3.nii.gz',\n", 120 | " './chaos_MR_T2_normalized/label_5.nii.gz',\n", 121 | " './chaos_MR_T2_normalized/label_8.nii.gz',\n", 122 | " './chaos_MR_T2_normalized/label_10.nii.gz',\n", 123 | " './chaos_MR_T2_normalized/label_13.nii.gz',\n", 124 | " './chaos_MR_T2_normalized/label_15.nii.gz',\n", 125 | " './chaos_MR_T2_normalized/label_19.nii.gz',\n", 126 | " './chaos_MR_T2_normalized/label_20.nii.gz',\n", 127 | " './chaos_MR_T2_normalized/label_21.nii.gz',\n", 128 | " './chaos_MR_T2_normalized/label_22.nii.gz',\n", 129 | " './chaos_MR_T2_normalized/label_31.nii.gz',\n", 130 | " './chaos_MR_T2_normalized/label_32.nii.gz',\n", 131 | " './chaos_MR_T2_normalized/label_33.nii.gz',\n", 132 | " './chaos_MR_T2_normalized/label_34.nii.gz',\n", 133 
| " './chaos_MR_T2_normalized/label_36.nii.gz',\n", 134 | " './chaos_MR_T2_normalized/label_37.nii.gz',\n", 135 | " './chaos_MR_T2_normalized/label_38.nii.gz',\n", 136 | " './chaos_MR_T2_normalized/label_39.nii.gz']" 137 | ] 138 | }, 139 | "execution_count": 12, 140 | "metadata": {}, 141 | "output_type": "execute_result" 142 | } 143 | ], 144 | "source": [ 145 | "segs" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": 13, 151 | "metadata": {}, 152 | "outputs": [ 153 | { 154 | "name": "stdout", 155 | "output_type": "stream", 156 | "text": [ 157 | "pid 1 finished!\n", 158 | "pid 2 finished!\n", 159 | "pid 3 finished!\n", 160 | "pid 5 finished!\n", 161 | "pid 8 finished!\n", 162 | "pid 10 finished!\n", 163 | "pid 13 finished!\n", 164 | "pid 15 finished!\n", 165 | "pid 19 finished!\n", 166 | "pid 20 finished!\n", 167 | "pid 21 finished!\n", 168 | "pid 22 finished!\n", 169 | "pid 31 finished!\n", 170 | "pid 32 finished!\n", 171 | "pid 33 finished!\n", 172 | "pid 34 finished!\n", 173 | "pid 36 finished!\n", 174 | "pid 37 finished!\n", 175 | "pid 38 finished!\n", 176 | "pid 39 finished!\n" 177 | ] 178 | } 179 | ], 180 | "source": [ 181 | "classmap = {}\n", 182 | "LABEL_NAME = [\"BG\", \"LIVER\", \"RK\", \"LK\", \"SPLEEN\"] \n", 183 | "\n", 184 | "\n", 185 | "MIN_TP = 100 # minimum number of positive label pixels to be recorded. Use >100 when training with manual annotations for more stable training\n", 186 | "\n", 187 | "fid = f'./chaos_MR_T2_normalized/classmap_{MIN_TP}.json' # name of the output file. \n", 188 | "for _lb in LABEL_NAME:\n", 189 | " classmap[_lb] = {}\n", 190 | " for _sid in segs:\n", 191 | " pid = _sid.split(\"_\")[-1].split(\".nii.gz\")[0]\n", 192 | " classmap[_lb][pid] = []\n", 193 | "\n", 194 | "for seg in segs:\n", 195 | " pid = seg.split(\"_\")[-1].split(\".nii.gz\")[0]\n", 196 | " lb_vol = nio.read_nii_bysitk(seg)\n", 197 | " n_slice = lb_vol.shape[0]\n", 198 | " for slc in range(n_slice):\n", 199 | " for cls in range(len(LABEL_NAME)):\n", 200 | " if cls in lb_vol[slc, ...]:\n", 201 | " if np.sum( lb_vol[slc, ...]) >= MIN_TP:\n", 202 | " classmap[LABEL_NAME[cls]][str(pid)].append(slc)\n", 203 | " print(f'pid {str(pid)} finished!')\n", 204 | " \n", 205 | "with open(fid, 'w') as fopen:\n", 206 | " json.dump(classmap, fopen)\n", 207 | " fopen.close() \n", 208 | " " 209 | ] 210 | }, 211 | { 212 | "cell_type": "code", 213 | "execution_count": 14, 214 | "metadata": {}, 215 | "outputs": [], 216 | "source": [ 217 | "with open(fid, 'w') as fopen:\n", 218 | " json.dump(classmap, fopen)\n", 219 | " fopen.close()" 220 | ] 221 | }, 222 | { 223 | "cell_type": "code", 224 | "execution_count": null, 225 | "metadata": {}, 226 | "outputs": [], 227 | "source": [] 228 | } 229 | ], 230 | "metadata": { 231 | "kernelspec": { 232 | "display_name": "Python 3 (ipykernel)", 233 | "language": "python", 234 | "name": "python3" 235 | }, 236 | "language_info": { 237 | "codemirror_mode": { 238 | "name": "ipython", 239 | "version": 3 240 | }, 241 | "file_extension": ".py", 242 | "mimetype": "text/x-python", 243 | "name": "python", 244 | "nbconvert_exporter": "python", 245 | "pygments_lexer": "ipython3", 246 | "version": "3.8.12" 247 | } 248 | }, 249 | "nbformat": 4, 250 | "nbformat_minor": 2 251 | } 252 | -------------------------------------------------------------------------------- /data/Prostate/UCLH/dcm_img_to_nii.sh: -------------------------------------------------------------------------------- 1 | #!bin/bash 2 | # Convert dicom-like images to nii files in 3D 3 | # This 
is the first step for image pre-processing 4 | 5 | # Feed path to the downloaded data here 6 | DATAPATH=./MR # please put chaos dataset training fold here which contains ground truth 7 | 8 | # Feed path to the output folder here 9 | OUTPATH=./niis 10 | 11 | if [ ! -d $OUTPATH/T2SPIR ] 12 | then 13 | mkdir $OUTPATH/T2SPIR 14 | fi 15 | 16 | for sid in $(ls "$DATAPATH") 17 | do 18 | dcm2nii -o "$DATAPATH/$sid/T2SPIR" "$DATAPATH/$sid/T2SPIR/DICOM_anon"; 19 | find "$DATAPATH/$sid/T2SPIR" -name "*.nii.gz" -exec mv {} "$OUTPATH/T2SPIR/image_$sid.nii.gz" \; 20 | done; 21 | 22 | 23 | -------------------------------------------------------------------------------- /data/Prostate/UCLH/png_gth_to_nii.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Converting labels from png to nii file\n", 8 | "\n", 9 | "\n", 10 | "### Overview\n", 11 | "\n", 12 | "This is the first step for data preparation\n", 13 | "\n", 14 | "Input: ground truth labels in `.png` format\n", 15 | "\n", 16 | "Output: labels in `.nii` format, indexed by patient id" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 13, 22 | "metadata": {}, 23 | "outputs": [ 24 | { 25 | "name": "stdout", 26 | "output_type": "stream", 27 | "text": [ 28 | "Once deleted, variables cannot be recovered. Proceed (y/[n])? y\n", 29 | "The autoreload extension is already loaded. To reload it, use:\n", 30 | " %reload_ext autoreload\n" 31 | ] 32 | } 33 | ], 34 | "source": [ 35 | "%reset\n", 36 | "%load_ext autoreload\n", 37 | "%autoreload 2\n", 38 | "import os\n", 39 | "import glob\n", 40 | "\n", 41 | "import numpy as np\n", 42 | "import PIL\n", 43 | "import matplotlib.pyplot as plt\n", 44 | "import SimpleITK as sitk\n", 45 | "import sys\n", 46 | "sys.path.insert(0, '../../dataloaders/')\n", 47 | "import niftiio as nio" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": null, 53 | "metadata": {}, 54 | "outputs": [], 55 | "source": [] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 14, 60 | "metadata": {}, 61 | "outputs": [], 62 | "source": [ 63 | "example = \"./MR/1/T2SPIR/Ground/IMG-0002-00001.png\" # example of ground-truth file name. 
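The `example` path above documents the ground-truth naming scheme: the last dash-separated field of the file name is the slice index, which is exactly what the sort key in the conversion loop below extracts. A small illustration of that parsing (the file name here is hypothetical):

```python
import os

fname = "IMG-0002-00001.png"  # pattern taken from `example` above
slice_idx = int(os.path.basename(fname).split("-")[-1].split(".png")[0])
assert slice_idx == 1
```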
" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": 15, 69 | "metadata": {}, 70 | "outputs": [], 71 | "source": [ 72 | "### search for scan ids\n", 73 | "ids = os.listdir(\"./MR/\")\n", 74 | "OUT_DIR = './niis/T2SPIR/'" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 16, 80 | "metadata": {}, 81 | "outputs": [ 82 | { 83 | "data": { 84 | "text/plain": [ 85 | "['37',\n", 86 | " '3',\n", 87 | " '15',\n", 88 | " '34',\n", 89 | " '33',\n", 90 | " '39',\n", 91 | " '20',\n", 92 | " '10',\n", 93 | " '22',\n", 94 | " '8',\n", 95 | " '31',\n", 96 | " '2',\n", 97 | " '36',\n", 98 | " '5',\n", 99 | " '13',\n", 100 | " '19',\n", 101 | " '21',\n", 102 | " '1',\n", 103 | " '38',\n", 104 | " '32']" 105 | ] 106 | }, 107 | "execution_count": 16, 108 | "metadata": {}, 109 | "output_type": "execute_result" 110 | } 111 | ], 112 | "source": [ 113 | "ids" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": 17, 119 | "metadata": {}, 120 | "outputs": [ 121 | { 122 | "name": "stdout", 123 | "output_type": "stream", 124 | "text": [ 125 | "image with id 37 has been saved!\n", 126 | "image with id 3 has been saved!\n", 127 | "image with id 15 has been saved!\n", 128 | "image with id 34 has been saved!\n", 129 | "image with id 33 has been saved!\n", 130 | "image with id 39 has been saved!\n", 131 | "image with id 20 has been saved!\n", 132 | "image with id 10 has been saved!\n", 133 | "image with id 22 has been saved!\n", 134 | "image with id 8 has been saved!\n", 135 | "image with id 31 has been saved!\n", 136 | "image with id 2 has been saved!\n", 137 | "image with id 36 has been saved!\n", 138 | "image with id 5 has been saved!\n", 139 | "image with id 13 has been saved!\n", 140 | "image with id 19 has been saved!\n", 141 | "image with id 21 has been saved!\n", 142 | "image with id 1 has been saved!\n", 143 | "image with id 38 has been saved!\n", 144 | "image with id 32 has been saved!\n" 145 | ] 146 | } 147 | ], 148 | "source": [ 149 | "#### Write them to nii files for the ease of loading in future\n", 150 | "for curr_id in ids:\n", 151 | " pngs = glob.glob(f'./MR/{curr_id}/T2SPIR/Ground/*.png')\n", 152 | " pngs = sorted(pngs, key = lambda x: int(os.path.basename(x).split(\"-\")[-1].split(\".png\")[0]))\n", 153 | " buffer = []\n", 154 | "\n", 155 | " for fid in pngs:\n", 156 | " buffer.append(PIL.Image.open(fid))\n", 157 | "\n", 158 | " vol = np.stack(buffer, axis = 0)\n", 159 | " # flip correction\n", 160 | " vol = np.flip(vol, axis = 1).copy()\n", 161 | " # remap values\n", 162 | " for new_val, old_val in enumerate(sorted(np.unique(vol))):\n", 163 | " vol[vol == old_val] = new_val\n", 164 | "\n", 165 | " # get reference \n", 166 | " ref_img = f'./niis/T2SPIR/image_{curr_id}.nii.gz'\n", 167 | " img_o = sitk.ReadImage(ref_img)\n", 168 | " vol_o = nio.np2itk(img=vol, ref_obj=img_o)\n", 169 | " sitk.WriteImage(vol_o, f'{OUT_DIR}/label_{curr_id}.nii.gz')\n", 170 | " print(f'image with id {curr_id} has been saved!')\n", 171 | "\n", 172 | " " 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": null, 178 | "metadata": {}, 179 | "outputs": [], 180 | "source": [] 181 | } 182 | ], 183 | "metadata": { 184 | "kernelspec": { 185 | "display_name": "Python 3", 186 | "language": "python", 187 | "name": "python3" 188 | }, 189 | "language_info": { 190 | "codemirror_mode": { 191 | "name": "ipython", 192 | "version": 3 193 | }, 194 | "file_extension": ".py", 195 | "mimetype": "text/x-python", 196 | "name": "python", 197 | 
"nbconvert_exporter": "python", 198 | "pygments_lexer": "ipython3", 199 | "version": "3.6.0" 200 | } 201 | }, 202 | "nbformat": 4, 203 | "nbformat_minor": 2 204 | } 205 | -------------------------------------------------------------------------------- /data/supervoxels/_ccomp.pxd: -------------------------------------------------------------------------------- 1 | """Export fast union find in Cython""" 2 | cimport numpy as cnp 3 | 4 | ctypedef cnp.intp_t DTYPE_t 5 | 6 | cdef DTYPE_t find_root(DTYPE_t *forest, DTYPE_t n) nogil 7 | cdef void set_root(DTYPE_t *forest, DTYPE_t n, DTYPE_t root) nogil 8 | cdef void join_trees(DTYPE_t *forest, DTYPE_t n, DTYPE_t m) nogil -------------------------------------------------------------------------------- /data/supervoxels/felzenszwalb_3d.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from felzenszwalb_3d_cy import felzenszwalb_cython_3d 4 | 5 | 6 | def felzenszwalb_3d(image, scale=1, sigma=0.8, min_size=20, multichannel=True, spacing=(1,1,1)): 7 | """ 8 | Code modified from: https://scikit-image.org/docs/dev/api/skimage.segmentation.html#skimage.segmentation.felzenszwalb 9 | 10 | 11 | Computes Felsenszwalb's efficient graph based image segmentation. 12 | 13 | Produces an oversegmentation of a multichannel (i.e. RGB) image 14 | using a fast, minimum spanning tree based clustering on the image grid. 15 | The parameter ``scale`` sets an observation level. Higher scale means 16 | less and larger segments. ``sigma`` is the diameter of a Gaussian kernel, 17 | used for smoothing the image prior to segmentation. 18 | 19 | The number of produced segments as well as their size can only be 20 | controlled indirectly through ``scale``. Segment size within an image can 21 | vary greatly depending on local contrast. 22 | 23 | For RGB images, the algorithm uses the euclidean distance between pixels in 24 | color space. 25 | 26 | Parameters 27 | ---------- 28 | image : (width, height, 3) or (width, height) ndarray 29 | Input image. 30 | scale : float 31 | Free parameter. Higher means larger clusters. 32 | sigma : float 33 | Width (standard deviation) of Gaussian kernel used in preprocessing. 34 | min_size : int 35 | Minimum component size. Enforced using postprocessing. 36 | multichannel : bool, optional (default: True) 37 | Whether the last axis of the image is to be interpreted as multiple 38 | channels. A value of False, for a 3D image, is not currently supported. 39 | 40 | Returns 41 | ------- 42 | segment_mask : (width, height) ndarray 43 | Integer mask indicating segment labels. 44 | 45 | References 46 | ---------- 47 | .. [1] Efficient graph-based image segmentation, Felzenszwalb, P.F. and 48 | Huttenlocher, D.P. International Journal of Computer Vision, 2004 49 | 50 | Notes 51 | ----- 52 | The `k` parameter used in the original paper renamed to `scale` here. 53 | 54 | Examples 55 | -------- 56 | >>> from skimage.segmentation import felzenszwalb 57 | >>> from skimage.data import coffee 58 | >>> img = coffee() 59 | >>> segments = felzenszwalb(img, scale=3.0, sigma=0.95, min_size=5) 60 | """ 61 | 62 | # if not multichannel and image.ndim > 2: 63 | # raise ValueError("This algorithm works only on single or " 64 | # "multi-channel 2d images. 
") 65 | 66 | image = np.atleast_3d(image) 67 | return felzenszwalb_cython_3d(image, scale=scale, sigma=sigma, min_size=min_size, spacing=spacing) 68 | 69 | -------------------------------------------------------------------------------- /data/supervoxels/felzenszwalb_3d_cy.pyx: -------------------------------------------------------------------------------- 1 | #cython: cdivision=True 2 | #cython: boundscheck=False 3 | #cython: nonecheck=False 4 | #cython: wraparound=False 5 | import numpy as np 6 | from scipy import ndimage as ndi 7 | 8 | cimport numpy as cnp 9 | from _ccomp cimport find_root, join_trees 10 | 11 | from skimage.util import img_as_float64 12 | from skimage._shared.utils import warn 13 | 14 | cnp.import_array() 15 | 16 | 17 | def felzenszwalb_cython_3d(image, double scale=1, sigma=0.8, Py_ssize_t min_size=20, spacing=(1,1,1)): 18 | """ 19 | Code modified from: https://scikit-image.org/docs/dev/api/skimage.segmentation.html#skimage.segmentation.felzenszwalb 20 | 21 | Felzenszwalb's efficient graph based segmentation for 22 | single or multiple channels. 23 | 24 | Produces an oversegmentation of a single or multi-channel image 25 | using a fast, minimum spanning tree based clustering on the image grid. 26 | The number of produced segments as well as their size can only be 27 | controlled indirectly through ``scale``. Segment size within an image can 28 | vary greatly depending on local contrast. 29 | 30 | Parameters 31 | ---------- 32 | image : (N, M, C) ndarray 33 | Input image. 34 | scale : float, optional (default 1) 35 | Sets the obervation level. Higher means larger clusters. 36 | sigma : float, optional (default 0.8) 37 | Width of Gaussian smoothing kernel used in preprocessing. 38 | Larger sigma gives smother segment boundaries. 39 | min_size : int, optional (default 20) 40 | Minimum component size. Enforced using postprocessing. 41 | 42 | Returns 43 | ------- 44 | segment_mask : (N, M) ndarray 45 | Integer mask indicating segment labels. 46 | """ 47 | 48 | 49 | image = img_as_float64(image) 50 | dtype = image.dtype 51 | 52 | # rescale scale to behave like in reference implementation 53 | scale = float(scale) / 255. 54 | 55 | spacing = np.ascontiguousarray(spacing, dtype=dtype) 56 | sigma = np.array([sigma, sigma, sigma], dtype=dtype) 57 | sigma /= spacing.astype(dtype) 58 | 59 | image = ndi.gaussian_filter(image, sigma=sigma) 60 | height, width, depth = image.shape # depth, height, width! 61 | image = image[..., None] 62 | 63 | # assuming spacing is equal in xy dir. 
64 | s = spacing[0]/spacing[1] 65 | w1 = 1.0 # x, y, xy 66 | w2 = s**2 # z 67 | w3 = (np.sqrt(1+s**2)/np.sqrt(2))**2 # zx, zy 68 | w4 = (np.sqrt(2 + s**2)/np.sqrt(3))**2 # zxy 69 | 70 | 71 | cost1 = np.sqrt(w1 * np.sum((image[:, 1:, :] - image[:, :width-1, :])**2, axis=-1)) # x 72 | cost2 = np.sqrt(w1 * np.sum((image[:, 1:, 1:] - image[:, :width-1, :depth-1])**2, axis=-1)) # xy 73 | cost3 = np.sqrt(w1 * np.sum((image[:, :, 1:] - image[:, :, :depth-1])**2, axis=-1)) # y 74 | cost7 = np.sqrt(w1 * np.sum((image[:, 1:, :depth-1] - image[:, :width-1, 1:])**2, axis=-1)) # xy 75 | cost9 = np.sqrt(w3 * np.sum((image[1:, 1:, :] - image[:height-1, :width-1, :])**2, axis=-1)) # zx 76 | cost10 = np.sqrt(w4 * np.sum((image[1:, 1:, 1:] - image[:height-1, :width-1, :depth-1])**2, axis=-1)) # zxy 77 | cost11 = np.sqrt(w3 * np.sum((image[1:, :, 1:] - image[:height-1, :, :depth-1])**2, axis=-1)) # zy 78 | cost12 = np.sqrt(w3 * np.sum((image[1:, :width-1, :] - image[:height-1, 1:, :])**2, axis=-1)) # zx 79 | cost13 = np.sqrt(w4 * np.sum((image[1:, :width-1, :depth-1] - image[:height-1, 1:, 1:])**2, axis=-1)) # zxy 80 | cost14 = np.sqrt(w3 * np.sum((image[1:, :, :depth-1] - image[:height-1, :, 1:])**2, axis=-1)) # zy 81 | cost15 = np.sqrt(w4 * np.sum((image[1:, 1:, :depth-1] - image[:height-1, :width-1, 1:])**2, axis=-1)) # zxy 82 | cost16 = np.sqrt(w4 * np.sum((image[1:, :width-1, 1:] - image[:height-1, 1:, :depth-1])**2, axis=-1)) # zxy 83 | cost25 = np.sqrt(w2 * np.sum((image[1:, :, :] - image[:height-1, :, :])**2, axis=-1)) # z 84 | 85 | 86 | cdef cnp.ndarray[cnp.float_t, ndim=1] costs = np.hstack([cost1.ravel(), cost2.ravel(), cost3.ravel(), cost7.ravel(), cost9.ravel(), cost10.ravel(), cost11.ravel(), cost12.ravel(), cost13.ravel(), cost14.ravel(), cost15.ravel(), cost16.ravel(), cost25.ravel()]).astype(float) 87 | 88 | # compute edges between pixels: 89 | cdef cnp.ndarray[cnp.intp_t, ndim=3] segments \ 90 | = np.arange(width * height * depth, dtype=np.intp).reshape(height, width, depth) 91 | 92 | 93 | edges1 = np.c_[segments[:, 1:, :].ravel(), segments[:, :width-1, :].ravel()] 94 | edges2 = np.c_[segments[:, 1:, 1:].ravel(), segments[:, :width-1, :depth-1].ravel()] 95 | edges3 = np.c_[segments[:, :, 1:].ravel(), segments[:, :, :depth-1].ravel()] 96 | edges7 = np.c_[segments[:, 1:, :depth-1].ravel(), segments[:, :width-1, 1:].ravel()] 97 | edges9 = np.c_[segments[1:, 1:, :].ravel(), segments[:height-1, :width-1, :].ravel()] 98 | edges10 = np.c_[segments[1:, 1:, 1:].ravel(), segments[:height-1, :width-1, :depth-1].ravel()] 99 | edges11 = np.c_[segments[1:, :, 1:].ravel(), segments[:height-1, :, :depth-1].ravel()] 100 | edges12 = np.c_[segments[1:, :width-1, :].ravel(), segments[:height-1, 1:, :].ravel()] 101 | edges13 = np.c_[segments[1:, :width-1, :depth-1].ravel(), segments[:height-1, 1:, 1:].ravel()] 102 | edges14 = np.c_[segments[1:, :, :depth-1].ravel(), segments[:height-1, :, 1:].ravel()] 103 | edges15 = np.c_[segments[1:, 1:, :depth-1].ravel(), segments[:height-1, :width-1, 1:].ravel()] 104 | edges16 = np.c_[segments[1:, :width-1, 1:].ravel(), segments[:height-1, 1:, :depth-1].ravel()] 105 | edges25 = np.c_[segments[1:, :, :].ravel(), segments[:height-1, :, :].ravel()] 106 | 107 | cdef cnp.ndarray[cnp.intp_t, ndim=2] edges \ 108 | = np.vstack([edges1, edges2, edges3, edges7, edges9, edges10, edges11, edges12, edges13, edges14, edges15, edges16, edges25]) 109 | 110 | # initialize data structures for segment size 111 | # and inner cost, then start greedy iteration over edges. 
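# (Explanatory note added for readability; cf. Felzenszwalb & Huttenlocher, 2004.)
# Edges are visited in order of increasing cost, and two components C1, C2 are
# merged iff the connecting edge is cheaper than
#     min(Int(C1) + scale/|C1|, Int(C2) + scale/|C2|),
# where Int(C) is the cost at which C was last merged (its internal difference).
# The scale/|C| slack lets small components absorb comparatively expensive
# edges, which is why a larger ``scale`` yields fewer, larger supervoxels.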
112 | edge_queue = np.argsort(costs) 113 | edges = np.ascontiguousarray(edges[edge_queue]) 114 | costs = np.ascontiguousarray(costs[edge_queue]) 115 | cdef cnp.intp_t *segments_p = segments.data 116 | cdef cnp.intp_t *edges_p = edges.data 117 | cdef cnp.float_t *costs_p = costs.data 118 | cdef cnp.ndarray[cnp.intp_t, ndim=1] segment_size \ 119 | = np.ones(width * height * depth, dtype=np.intp) 120 | 121 | # inner cost of segments 122 | cdef cnp.ndarray[cnp.float_t, ndim=1] cint = np.zeros(width * height * depth) 123 | cdef cnp.intp_t seg0, seg1, seg_new, e 124 | cdef float cost, inner_cost0, inner_cost1 125 | cdef Py_ssize_t num_costs = costs.size 126 | 127 | with nogil: 128 | # set costs_p back one. we increase it before we use it 129 | # since we might continue before that. 130 | costs_p -= 1 131 | for e in range(num_costs): 132 | seg0 = find_root(segments_p, edges_p[0]) 133 | seg1 = find_root(segments_p, edges_p[1]) 134 | 135 | edges_p += 2 136 | costs_p += 1 137 | if seg0 == seg1: 138 | continue 139 | 140 | 141 | inner_cost0 = cint[seg0] + scale / segment_size[seg0] 142 | inner_cost1 = cint[seg1] + scale / segment_size[seg1] 143 | 144 | # return 0 # ok 145 | 146 | if costs_p[0] < min(inner_cost0, inner_cost1): 147 | # update size and cost 148 | 149 | join_trees(segments_p, seg0, seg1) # TODO: not ok! 150 | #return 0 # not ok!! 151 | seg_new = find_root(segments_p, seg0) 152 | segment_size[seg_new] = segment_size[seg0] + segment_size[seg1] 153 | cint[seg_new] = costs_p[0] 154 | 155 | 156 | # postprocessing to remove small segments 157 | edges_p = edges.data 158 | for e in range(num_costs): 159 | seg0 = find_root(segments_p, edges_p[0]) 160 | seg1 = find_root(segments_p, edges_p[1]) 161 | edges_p += 2 162 | if seg0 == seg1: 163 | continue 164 | if segment_size[seg0] < min_size or segment_size[seg1] < min_size: 165 | join_trees(segments_p, seg0, seg1) 166 | seg_new = find_root(segments_p, seg0) 167 | segment_size[seg_new] = segment_size[seg0] + segment_size[seg1] 168 | 169 | 170 | 171 | # unravel the union find tree 172 | flat = segments.ravel() 173 | old = np.zeros_like(flat) 174 | while (old != flat).any(): 175 | old = flat 176 | flat = flat[flat] 177 | flat = np.unique(flat, return_inverse=True)[1] 178 | return flat.reshape((height, width, depth)) -------------------------------------------------------------------------------- /data/supervoxels/generate_supervoxels.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from Ouyang et al. 
3 | https://github.com/cheng-01037/Self-supervised-Fewshot-Medical-Image-Segmentation
4 | """
5 | 
6 | import os
7 | import SimpleITK as sitk
8 | import glob
9 | import numpy as np  # used throughout; previously only available implicitly via the star import below
10 | from skimage.measure import label
11 | import scipy.ndimage.morphology as snm
12 | from felzenszwalb_3d import *
13 | 
14 | base_dir = '/CHAOST2/chaos_MR_T2_normalized'
15 | # base_dir = '/CMR/cmr_MR_normalized'
16 | 
17 | imgs = glob.glob(os.path.join(base_dir, 'image*'))
18 | labels = glob.glob(os.path.join(base_dir, 'label*'))
19 | 
20 | imgs = sorted(imgs, key=lambda x: int(x.split('_')[-1].split('.nii.gz')[0]))
21 | labels = sorted(labels, key=lambda x: int(x.split('_')[-1].split('.nii.gz')[0]))
22 | 
23 | fg_thresh = 10
24 | 
25 | MODE = 'MIDDLE'
26 | n_sv = 5000
27 | # n_sv = 1000
28 | 
29 | def read_nii_bysitk(input_fid):
30 |     """ read nii to numpy through simpleitk
31 |     peelinfo: taking direction, origin, spacing and metadata out
32 |     """
33 |     img_obj = sitk.ReadImage(input_fid)
34 |     img_np = sitk.GetArrayFromImage(img_obj)
35 |     return img_np
36 | 
37 | # thresholding the intensity values to get a binary mask of the patient
38 | def fg_mask2d(img_2d, thresh):
39 |     mask_map = np.float32(img_2d > thresh)
40 | 
41 |     def getLargestCC(segmentation):  # largest connected component
42 |         labels = label(segmentation)
43 |         assert (labels.max() != 0)  # assume at least 1 CC
44 |         largestCC = labels == np.argmax(np.bincount(labels.flat)[1:]) + 1
45 |         return largestCC
46 | 
47 |     if mask_map.max() < 0.999:  # slice contains no foreground above the threshold
48 |         return mask_map
49 |     else:
50 |         post_mask = getLargestCC(mask_map)
51 |         fill_mask = snm.binary_fill_holes(post_mask)
52 |         return fill_mask
53 | 
54 | 
55 | # remove supervoxels within the empty regions
56 | def supervox_masking(seg, mask):
57 | 
58 |     seg[seg == 0] = seg.max() + 1
59 |     seg = np.int32(seg)
60 |     seg[mask == 0] = 0
61 | 
62 |     return seg
63 | 
64 | # make supervoxels
65 | for img_path in imgs:
66 |     img = read_nii_bysitk(img_path)
67 |     img = 255 * (img - img.min()) / img.ptp()
68 | 
69 |     reader = sitk.ImageFileReader()
70 |     reader.SetFileName(img_path)
71 |     reader.LoadPrivateTagsOn()
72 |     reader.ReadImageInformation()
73 | 
74 |     x = float(reader.GetMetaData('pixdim[1]'))
75 |     y = float(reader.GetMetaData('pixdim[2]'))
76 |     z = float(reader.GetMetaData('pixdim[3]'))
77 | 
78 |     segments_felzenszwalb = felzenszwalb_3d(img, min_size=n_sv, sigma=0, spacing=(z, x, y))
79 | 
80 |     # post processing: remove bg (low intensity regions)
81 |     fg_mask_vol = np.zeros(segments_felzenszwalb.shape)
82 |     for ii in range(segments_felzenszwalb.shape[0]):
83 |         _fgm = fg_mask2d(img[ii, ...], fg_thresh)
84 |         fg_mask_vol[ii] = _fgm
85 |     processed_seg_vol = supervox_masking(segments_felzenszwalb, fg_mask_vol)
86 | 
87 |     # write to nii.gz; the output folder is assumed to sit next to the normalized data,
88 |     # matching the 'supervoxels_{n_sv}' layout expected by the dataloader -- adjust if your layout differs
89 |     out_seg = sitk.GetImageFromArray(processed_seg_vol)
90 |     out_dir = os.path.join(os.path.dirname(base_dir), f'supervoxels_{n_sv}')
91 |     os.makedirs(out_dir, exist_ok=True)
92 | 
93 |     idx = os.path.basename(img_path).split("_")[-1].split(".nii.gz")[0]
94 | 
95 |     seg_fid = os.path.join(out_dir, f'superpix-{MODE}_{idx}.nii.gz')
96 |     sitk.WriteImage(out_seg, seg_fid)
97 |     print(f'image with id {idx} has finished')
98 | 
99 | 
--------------------------------------------------------------------------------
/data/supervoxels/setup.py:
--------------------------------------------------------------------------------
1 | from distutils.core import setup
2 | from Cython.Build import cythonize
3 | import numpy
4 | from distutils.core import Extension
5 | 
6 | 
7 | extensions = [
8 |     Extension("felzenszwalb_3d_cy", ["felzenszwalb_3d_cy.pyx"], include_dirs=[numpy.get_include()])
9 | ]
10 | setup(
11 |     name='felzenszwalb_3d_cy',
12 | 
ext_modules=cythonize(extensions) 13 | ) 14 | 15 | extensions = [ 16 | Extension("_ccomp", ["_ccomp.pyx"], include_dirs=[numpy.get_include()]) 17 | ] 18 | setup( 19 | name='_ccomp', 20 | ext_modules=cythonize(extensions) 21 | ) -------------------------------------------------------------------------------- /dataloaders/dataset_specifics.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dataset Specifics 3 | Extended from ADNet code by Hansen et al. 4 | """ 5 | 6 | import torch 7 | import random 8 | 9 | 10 | def get_label_names(dataset): 11 | label_names = {} 12 | if dataset == 'CARDIAC_bssFP': 13 | label_names[0] = 'BG' 14 | label_names[1] = 'LV-MYO' 15 | label_names[2] = 'LV-BP' 16 | label_names[3] = 'RV' 17 | elif dataset == 'CARDIAC_LGE': 18 | label_names[0] = 'BG' 19 | label_names[1] = 'LV-MYO' 20 | label_names[2] = 'LV-BP' 21 | label_names[3] = 'RV' 22 | elif dataset == 'ABDOMEN_MR': 23 | label_names[0] = 'BG' 24 | label_names[1] = 'LIVER' 25 | label_names[2] = 'RIGHT_KIDNEY' 26 | label_names[3] = 'LEFT_KIDNEY' 27 | label_names[4] = 'SPLEEN' 28 | elif dataset == 'ABDOMEN_CT': 29 | label_names[0] = 'BG' 30 | label_names[1] = 'SPLEEN' 31 | label_names[2] = 'RIGHT_KIDNEY' 32 | label_names[3] = 'LEFT_KIDNEY' 33 | label_names[4] = 'GALLBLADDER' 34 | label_names[5] = 'ESOPHAGUS' 35 | label_names[6] = 'LIVER' 36 | label_names[7] = 'STOMACH' 37 | label_names[8] = 'AORTA' 38 | label_names[9] = 'INFERIOR_VENA_CAVA' # Inferior vena cava 39 | label_names[10] = 'PORTAL_VEIN_AND_SPLENIC_VEIN' # portal vein and splenic vein 40 | label_names[11] = 'PANCREAS' 41 | label_names[12] = 'RIGHT_ADRENAL_GLAND' # right adrenal gland 42 | label_names[13] = 'LEFT_ADRENAL_GLAND' # left adrenal gland 43 | elif dataset == 'Prostate_UCLH': 44 | label_names[0] = 'BG' 45 | label_names[1] = 'Bladder' 46 | label_names[2] = 'Bone' 47 | label_names[3] = 'Obturator_Internus' 48 | label_names[4] = 'Transition_Zone' 49 | label_names[5] = 'Central_Gland' 50 | label_names[6] = 'Rectum' 51 | label_names[7] = 'Seminal_Vesicle' 52 | label_names[8] = 'Neurovascular_Bundle' 53 | elif dataset == 'Prostate_TCIA_PD': 54 | label_names[0] = 'BG' 55 | label_names[1] = 'Bladder' 56 | label_names[2] = 'Bone' 57 | label_names[3] = 'Obturator_Internus' 58 | label_names[4] = 'Transition_Zone' 59 | label_names[5] = 'Central_Gland' 60 | label_names[6] = 'Rectum' 61 | label_names[7] = 'Seminal_Vesicle' 62 | label_names[8] = 'Neurovascular_Bundle' 63 | elif dataset == 'Prostate_NCI': 64 | label_names[0] = 'BG' 65 | label_names[1] = 'Bladder' 66 | label_names[2] = 'Bone' 67 | label_names[3] = 'Obturator_Internus' 68 | label_names[4] = 'Transition_Zone' 69 | label_names[5] = 'Central_Gland' 70 | label_names[6] = 'Rectum' 71 | label_names[7] = 'Seminal_Vesicle' 72 | label_names[8] = 'Neurovascular_Bundle' 73 | 74 | return label_names 75 | 76 | 77 | def get_folds(dataset): 78 | FOLD = {} 79 | if dataset == 'CARDIAC_bssFP': 80 | FOLD[0] = set(range(0, 8)) 81 | FOLD[1] = set(range(9, 17)) 82 | FOLD[2] = set(range(18, 26)) 83 | FOLD[3] = set(range(27, 35)) 84 | FOLD[4] = set(range(36, 44)) 85 | FOLD[4].update([0]) 86 | return FOLD 87 | 88 | elif dataset == 'CARDIAC_LGE': 89 | FOLD[0] = set(range(0, 8)) 90 | FOLD[1] = set(range(9, 17)) 91 | FOLD[2] = set(range(18, 26)) 92 | FOLD[3] = set(range(27, 35)) 93 | FOLD[4] = set(range(36, 44)) 94 | FOLD[4].update([0]) 95 | return FOLD 96 | 97 | elif dataset == 'ABDOMEN_MR': 98 | FOLD[0] = set(range(0, 5)) 99 | FOLD[1] = set(range(4, 9)) 100 | FOLD[2] = 
set(range(8, 13)) 101 | FOLD[3] = set(range(12, 17)) 102 | FOLD[4] = set(range(16, 20)) 103 | FOLD[4].update([0]) 104 | return FOLD 105 | elif dataset == 'ABDOMEN_CT': 106 | FOLD[0] = set(range(0, 5)) 107 | FOLD[1] = set(range(4, 9)) 108 | FOLD[2] = set(range(8, 13)) 109 | FOLD[3] = set(range(12, 17)) 110 | FOLD[4] = set(range(16, 20)) 111 | FOLD[4].update([0]) 112 | return FOLD 113 | elif dataset == 'Prostate_UCLH': 114 | FOLD[0] = set(range(0, 5)) 115 | FOLD[1] = set(range(4, 9)) 116 | FOLD[2] = set(range(8, 13)) 117 | FOLD[3] = set(range(12, 17)) 118 | FOLD[4] = set(range(16, 20)) 119 | FOLD[4].update([0]) 120 | return FOLD 121 | elif dataset == 'Prostate_Picture': 122 | FOLD[0] = set(range(0, 5)) 123 | FOLD[1] = set(range(4, 9)) 124 | FOLD[2] = set(range(8, 13)) 125 | FOLD[3] = set(range(12, 17)) 126 | FOLD[4] = set(range(16, 20)) 127 | FOLD[4].update([0]) 128 | return FOLD 129 | elif dataset == 'Prostate_NCI': 130 | FOLD[0] = set(range(0, 5)) 131 | FOLD[1] = set(range(4, 9)) 132 | FOLD[2] = set(range(8, 13)) 133 | FOLD[3] = set(range(12, 17)) 134 | FOLD[4] = set(range(16, 20)) 135 | FOLD[4].update([0]) 136 | return FOLD 137 | elif dataset == 'Prostate_TCIA_PD': 138 | FOLD[0] = set(range(0, 5)) 139 | FOLD[1] = set(range(4, 9)) 140 | FOLD[2] = set(range(8, 13)) 141 | FOLD[3] = set(range(12, 17)) 142 | FOLD[4] = set(range(16, 20)) 143 | FOLD[4].update([0]) 144 | return FOLD 145 | 146 | else: 147 | raise ValueError(f'Dataset: {dataset} not found') 148 | 149 | 150 | def sample_xy(spr, k=0, b=215): 151 | _, h, v = torch.where(spr) 152 | 153 | if len(h) == 0 or len(v) == 0: 154 | horizontal = 0 155 | vertical = 0 156 | else: 157 | 158 | h_min = min(h) 159 | h_max = max(h) 160 | if b > (h_max - h_min): 161 | kk = min(k, int((h_max - h_min) / 2)) 162 | horizontal = random.randint(max(h_max - b - kk, 0), min(h_min + kk, 256 - b - 1)) 163 | else: 164 | kk = min(k, int(b / 2)) 165 | horizontal = random.randint(max(h_min - kk, 0), min(h_max - b + kk, 256 - b - 1)) 166 | 167 | v_min = min(v) 168 | v_max = max(v) 169 | if b > (v_max - v_min): 170 | kk = min(k, int((v_max - v_min) / 2)) 171 | vertical = random.randint(max(v_max - b - kk, 0), min(v_min + kk, 256 - b - 1)) 172 | else: 173 | kk = min(k, int(b / 2)) 174 | vertical = random.randint(max(v_min - kk, 0), min(v_max - b + kk, 256 - b - 1)) 175 | 176 | return horizontal, vertical 177 | 178 | 179 | -------------------------------------------------------------------------------- /dataloaders/datasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dataset for Training and Test 3 | Extended from ADNet code by Hansen et al. 4 | """ 5 | import torch 6 | import cv2 7 | from torch.utils.data import Dataset 8 | import torchvision.transforms as deftfx 9 | import glob 10 | import os 11 | import SimpleITK as sitk 12 | import random 13 | import numpy as np 14 | from . 
import image_transforms as myit 15 | from .dataset_specifics import * 16 | 17 | 18 | class TestDataset(Dataset): 19 | 20 | def __init__(self, args): 21 | 22 | # reading the paths 23 | if args['dataset'] == 'CARDIAC_bssFP': 24 | self.image_dirs = glob.glob(os.path.join(args['data_dir'], 'cmr_bssFP_normalized/image*')) 25 | elif args['dataset'] == 'CARDIAC_LGE': 26 | self.image_dirs = glob.glob(os.path.join(args['data_dir'], 'cmr_LGE_normalized/image*')) 27 | elif args['dataset'] == 'ABDOMEN_MR': 28 | self.image_dirs = glob.glob(os.path.join(args['data_dir'], 'chaos_MR_T2_normalized/image*')) 29 | elif args['dataset'] == 'ABDOMEN_CT': 30 | self.image_dirs = glob.glob(os.path.join(args['data_dir'], 'sabs_CT_normalized/image*')) 31 | elif args['dataset'] == 'Prostate_UCLH': 32 | self.image_dirs = glob.glob(os.path.join(args['data_dir'], 'UCLH_normalized/image*')) 33 | elif args['dataset'] == 'Prostate_TCIA_PD': 34 | self.image_dirs = glob.glob(os.path.join(args['data_dir'], 'tcia_pd_normalized/image*')) 35 | elif args['dataset'] == 'Prostate_NCI': 36 | self.image_dirs = glob.glob(os.path.join(args['data_dir'], 'NCI_normalized/image*')) 37 | 38 | self.image_dirs = sorted(self.image_dirs, key=lambda x: int(x.split('_')[-1].split('.nii.gz')[0])) 39 | 40 | # remove test fold! 41 | self.FOLD = get_folds(args['dataset']) 42 | self.image_dirs = [elem for idx, elem in enumerate(self.image_dirs) if idx in self.FOLD[args['eval_fold']]] 43 | 44 | # split into support/query 45 | idx = np.arange(len(self.image_dirs)) 46 | self.support_dir = self.image_dirs[idx[args['supp_idx']]] 47 | self.image_dirs.pop(idx[args['supp_idx']]) # remove support 48 | self.label = None 49 | 50 | def __len__(self): 51 | return len(self.image_dirs) 52 | 53 | def __getitem__(self, idx): 54 | 55 | img_path = self.image_dirs[idx] 56 | img = sitk.GetArrayFromImage(sitk.ReadImage(img_path)) 57 | img = (img - img.mean()) / img.std() 58 | img = np.stack(3 * [img], axis=1) 59 | 60 | lbl = sitk.GetArrayFromImage( 61 | sitk.ReadImage(img_path.split('image_')[0] + 'label_' + img_path.split('image_')[-1])) 62 | 63 | lbl[lbl == 200] = 1 64 | lbl[lbl == 500] = 2 65 | lbl[lbl == 600] = 3 66 | lbl = 1 * (lbl == self.label) 67 | 68 | sample = {'id': img_path} 69 | 70 | # Evaluation protocol. 71 | idx = lbl.sum(axis=(1, 2)) > 0 72 | sample['image'] = torch.from_numpy(img[idx]) 73 | sample['label'] = torch.from_numpy(lbl[idx]) 74 | 75 | return sample 76 | 77 | def get_support_index(self, n_shot, C): 78 | """ 79 | Selecting intervals according to Ouyang et al. 
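        The C candidate slices are split into n_shot chunks of equal length
        and the middle slice of each chunk is used as a support slice.
        Worked example: n_shot=3 with C=30 gives pcts = [1/6, 1/2, 5/6]
        and support indices [5, 15, 25].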
80 | """ 81 | if n_shot == 1: 82 | pcts = [0.5] 83 | else: 84 | half_part = 1 / (n_shot * 2) 85 | part_interval = (1.0 - 1.0 / n_shot) / (n_shot - 1) 86 | pcts = [half_part + part_interval * ii for ii in range(n_shot)] 87 | 88 | return (np.array(pcts) * C).astype('int') 89 | 90 | def getSupport(self, label=None, all_slices=True, N=None): 91 | if label is None: 92 | raise ValueError('Need to specify label class!') 93 | 94 | img_path = self.support_dir 95 | img = sitk.GetArrayFromImage(sitk.ReadImage(img_path)) 96 | img = (img - img.mean()) / img.std() 97 | img = np.stack(3 * [img], axis=1) 98 | 99 | lbl = sitk.GetArrayFromImage( 100 | sitk.ReadImage(img_path.split('image_')[0] + 'label_' + img_path.split('image_')[-1])) 101 | lbl[lbl == 200] = 1 102 | lbl[lbl == 500] = 2 103 | lbl[lbl == 600] = 3 104 | lbl = 1 * (lbl == label) 105 | 106 | sample = {} 107 | if all_slices: 108 | sample['image'] = torch.from_numpy(img) 109 | sample['label'] = torch.from_numpy(lbl) 110 | else: 111 | # select N labeled slices 112 | if N is None: 113 | raise ValueError('Need to specify number of labeled slices!') 114 | idx = lbl.sum(axis=(1, 2)) > 0 115 | idx_ = self.get_support_index(N, idx.sum()) 116 | 117 | sample['image'] = torch.from_numpy(img[idx][idx_]) 118 | sample['label'] = torch.from_numpy(lbl[idx][idx_]) 119 | 120 | return sample 121 | 122 | 123 | class TrainDataset(Dataset): 124 | 125 | def __init__(self, args): 126 | self.n_shot = args['n_shot'] 127 | self.n_way = args['n_way'] 128 | self.n_query = args['n_query'] 129 | self.n_sv = args['n_sv'] 130 | self.max_iter = args['max_iter'] 131 | self.read = True # read images before get_item 132 | self.train_sampling = 'neighbors' 133 | 134 | self.min_size = args['min_size'] 135 | self.test_label = args['test_label'] 136 | self.exclude_label = args['exclude_label'] 137 | self.use_gt = args['use_gt'] 138 | 139 | # reading the paths (leaving the reading of images into memory to __getitem__) 140 | 141 | if args['dataset'] == 'CARDIAC_bssFP': 142 | self.image_dirs = glob.glob(os.path.join(args['data_dir'], 'cmr_bssFP_normalized/image*')) 143 | self.label_dirs = glob.glob(os.path.join(args['data_dir'], 'cmr_bssFP_normalized/label*')) 144 | elif args['dataset'] == 'CARDIAC_LGE': 145 | self.image_dirs = glob.glob(os.path.join(args['data_dir'], 'cmr_LGE_normalized/image*')) 146 | self.label_dirs = glob.glob(os.path.join(args['data_dir'], 'cmr_LGE_normalized/label*')) 147 | elif args['dataset'] == 'ABDOMEN_MR': 148 | self.image_dirs = glob.glob(os.path.join(args['data_dir'], 'chaos_MR_T2_normalized/image*')) 149 | self.label_dirs = glob.glob(os.path.join(args['data_dir'], 'chaos_MR_T2_normalized/label*')) 150 | elif args['dataset'] == 'ABDOMEN_CT': 151 | self.image_dirs = glob.glob(os.path.join(args['data_dir'], 'sabs_CT_normalized/image*')) 152 | self.label_dirs = glob.glob(os.path.join(args['data_dir'], 'sabs_CT_normalized/label*')) 153 | elif args['dataset'] == 'Prostate_UCLH': 154 | self.image_dirs = glob.glob(os.path.join(args['data_dir'], 'UCLH_normalized/image*')) 155 | self.label_dirs = glob.glob(os.path.join(args['data_dir'], 'UCLH_normalized/label*')) 156 | elif args['dataset'] == 'Prostate_NCI': 157 | self.image_dirs = glob.glob(os.path.join(args['data_dir'], 'NCI_normalized/image*')) 158 | self.label_dirs = glob.glob(os.path.join(args['data_dir'], 'NCI_normalized/label*')) 159 | elif args['dataset'] == 'Prostate_TCIA_PD': 160 | self.image_dirs = glob.glob(os.path.join(args['data_dir'], 'tcia_pd_normalized/image*')) 161 | self.label_dirs = 
glob.glob(os.path.join(args['data_dir'], 'tcia_pd_normalized/label*')) 162 | 163 | 164 | self.image_dirs = sorted(self.image_dirs, key=lambda x: int(x.split('_')[-1].split('.nii.gz')[0])) 165 | self.label_dirs = sorted(self.label_dirs, key=lambda x: int(x.split('_')[-1].split('.nii.gz')[0])) 166 | self.sprvxl_dirs = glob.glob(os.path.join(args['data_dir'], 'supervoxels_' + str(args['n_sv']), 'super*')) 167 | self.sprvxl_dirs = sorted(self.sprvxl_dirs, key=lambda x: int(x.split('_')[-1].split('.nii.gz')[0])) 168 | 169 | # remove test fold! 170 | self.FOLD = get_folds(args['dataset']) 171 | self.image_dirs = [elem for idx, elem in enumerate(self.image_dirs) if idx not in self.FOLD[args['eval_fold']]] 172 | self.label_dirs = [elem for idx, elem in enumerate(self.label_dirs) if idx not in self.FOLD[args['eval_fold']]] 173 | self.sprvxl_dirs = [elem for idx, elem in enumerate(self.sprvxl_dirs) if 174 | idx not in self.FOLD[args['eval_fold']]] 175 | 176 | # read images 177 | if self.read: 178 | self.images = {} 179 | self.labels = {} 180 | self.sprvxls = {} 181 | for image_dir, label_dir, sprvxl_dir in zip(self.image_dirs, self.label_dirs, self.sprvxl_dirs): 182 | self.images[image_dir] = sitk.GetArrayFromImage(sitk.ReadImage(image_dir)) 183 | self.labels[label_dir] = sitk.GetArrayFromImage(sitk.ReadImage(label_dir)) 184 | self.sprvxls[sprvxl_dir] = sitk.GetArrayFromImage(sitk.ReadImage(sprvxl_dir)) 185 | 186 | def __len__(self): 187 | return self.max_iter 188 | 189 | def gamma_tansform(self, img): 190 | gamma_range = (0.5, 1.5) 191 | gamma = np.random.rand() * (gamma_range[1] - gamma_range[0]) + gamma_range[0] 192 | cmin = img.min() 193 | irange = (img.max() - cmin + 1e-5) 194 | 195 | img = img - cmin + 1e-5 196 | img = irange * np.power(img * 1.0 / irange, gamma) 197 | img = img + cmin 198 | 199 | return img 200 | 201 | def geom_transform(self, img, mask): 202 | 203 | affine = {'rotate': 5, 'shift': (5, 5), 'shear': 5, 'scale': (0.9, 1.2)} 204 | alpha = 10 205 | sigma = 5 206 | order = 3 207 | 208 | tfx = [] 209 | tfx.append(myit.RandomAffine(affine.get('rotate'), 210 | affine.get('shift'), 211 | affine.get('shear'), 212 | affine.get('scale'), 213 | affine.get('scale_iso', True), 214 | order=order)) 215 | tfx.append(myit.ElasticTransform(alpha, sigma)) 216 | transform = deftfx.Compose(tfx) 217 | 218 | if len(img.shape) > 4: 219 | n_shot = img.shape[1] 220 | for shot in range(n_shot): 221 | cat = np.concatenate((img[0, shot], mask[:, shot])).transpose(1, 2, 0) 222 | cat = transform(cat).transpose(2, 0, 1) 223 | img[0, shot] = cat[:3, :, :] 224 | mask[:, shot] = np.rint(cat[3:, :, :]) 225 | 226 | else: 227 | for q in range(img.shape[0]): 228 | cat = np.concatenate((img[q], mask[q][None])).transpose(1, 2, 0) 229 | cat = transform(cat).transpose(2, 0, 1) 230 | img[q] = cat[:3, :, :] 231 | mask[q] = np.rint(cat[3:, :, :].squeeze()) 232 | 233 | return img, mask 234 | 235 | def __getitem__(self, idx): 236 | 237 | # sample patient idx 238 | pat_idx = random.choice(range(len(self.image_dirs))) 239 | 240 | if self.read: 241 | # get image/supervoxel volume from dictionary 242 | img = self.images[self.image_dirs[pat_idx]] 243 | gt = self.labels[self.label_dirs[pat_idx]] 244 | sprvxl = self.sprvxls[self.sprvxl_dirs[pat_idx]] 245 | else: 246 | # read image/supervoxel volume into memory 247 | img = sitk.GetArrayFromImage(sitk.ReadImage(self.image_dirs[pat_idx])) 248 | gt = sitk.GetArrayFromImage(sitk.ReadImage(self.label_dirs[pat_idx])) 249 | sprvxl = 
sitk.GetArrayFromImage(sitk.ReadImage(self.sprvxl_dirs[pat_idx])) 250 | 251 | if self.exclude_label is not None: # identify the slices containing test labels 252 | idx = np.arange(gt.shape[0]) 253 | exclude_idx = np.full(gt.shape[0], True, dtype=bool) 254 | for i in range(len(self.exclude_label)): 255 | exclude_idx = exclude_idx & (np.sum(gt == self.exclude_label[i], axis=(1, 2)) > 0) 256 | exclude_idx = idx[exclude_idx] 257 | else: 258 | exclude_idx = [] 259 | 260 | # normalize 261 | img = (img - img.mean()) / img.std() 262 | 263 | # chose training label 264 | if self.use_gt: 265 | lbl = gt.copy() 266 | else: 267 | lbl = sprvxl.copy() 268 | # lbl is label numpy 269 | 270 | # sample class(es) (gt/supervoxel) 271 | unique = list(np.unique(lbl)) 272 | unique.remove(0) 273 | if self.use_gt: 274 | unique = list(set(unique) - set(self.test_label)) 275 | 276 | size = 0 277 | while size < self.min_size: 278 | n_slices = (self.n_shot * self.n_way) + self.n_query - 1 279 | while n_slices < ((self.n_shot * self.n_way) + self.n_query): 280 | cls_idx = random.choice(unique) # cls_idx is sampled class id 281 | 282 | # extract slices containing the sampled class 283 | sli_idx = np.sum(lbl == cls_idx, axis=(1, 2)) > 0 284 | idx = np.arange(lbl.shape[0]) 285 | sli_idx = idx[sli_idx] 286 | sli_idx = list( 287 | set(sli_idx) - set(np.intersect1d(sli_idx, exclude_idx))) # remove slices containing test labels 288 | n_slices = len(sli_idx) 289 | 290 | # generate possible subsets with successive slices (size = self.n_shot * self.n_way + self.n_query) 291 | subsets = [] 292 | for i in range(len(sli_idx)): 293 | if not subsets: 294 | subsets.append([sli_idx[i]]) 295 | elif sli_idx[i - 1] + 1 == sli_idx[i]: 296 | subsets[-1].append(sli_idx[i]) 297 | else: 298 | subsets.append([sli_idx[i]]) 299 | i = 0 300 | while i < len(subsets): 301 | if len(subsets[i]) < (self.n_shot * self.n_way + self.n_query): 302 | del subsets[i] 303 | else: 304 | i += 1 305 | if not len(subsets): 306 | return self.__getitem__(idx + np.random.randint(low=0, high=self.max_iter - 1, size=(1,))) 307 | 308 | # sample support and query slices 309 | i = random.choice(np.arange(len(subsets))) # subset index 310 | i = random.choice(subsets[i][:-(self.n_shot * self.n_way + self.n_query - 1)]) 311 | sample = np.arange(i, i + (self.n_shot * self.n_way) + self.n_query) 312 | 313 | lbl_cls = 1 * (lbl == cls_idx) 314 | 315 | size = max(np.sum(lbl_cls[sample[0]]), np.sum(lbl_cls[sample[1]])) 316 | 317 | # invert order 318 | if np.random.random(1) > 0.5: 319 | sample = sample[::-1] # successive slices (inverted) 320 | 321 | sup_lbl = lbl_cls[sample[:self.n_shot * self.n_way]][None,] # n_way * (n_shot * C) * H * W 322 | qry_lbl = lbl_cls[sample[self.n_shot * self.n_way:]] # n_qry * C * H * W 323 | 324 | sup_img = img[sample[:self.n_shot * self.n_way]][None,] # n_way * (n_shot * C) * H * W 325 | sup_img = np.stack((sup_img, sup_img, sup_img), axis=2) 326 | qry_img = img[sample[self.n_shot * self.n_way:]] # n_qry * C * H * W 327 | qry_img = np.stack((qry_img, qry_img, qry_img), axis=1) 328 | 329 | # gamma transform 330 | if np.random.random(1) > 0.5: 331 | qry_img = self.gamma_tansform(qry_img) 332 | else: 333 | sup_img = self.gamma_tansform(sup_img) 334 | 335 | # geom transform 336 | if np.random.random(1) > 0.5: 337 | qry_img, qry_lbl = self.geom_transform(qry_img, qry_lbl) 338 | else: 339 | sup_img, sup_lbl = self.geom_transform(sup_img, sup_lbl) 340 | 341 | sample = {'support_images': sup_img, 342 | 'support_fg_labels': sup_lbl, 343 | 'query_images': 
qry_img, 344 | 'query_labels': qry_lbl, 345 | 'selected_class': cls_idx} 346 | 347 | return sample 348 | 349 | -------------------------------------------------------------------------------- /dataloaders/image_transforms.py: -------------------------------------------------------------------------------- 1 | """ 2 | Image Transformation 3 | Code originally from Ouyang et al. (used in the 2D setting) 4 | """ 5 | 6 | from collections.abc import Sequence 7 | import cv2 8 | import numpy as np 9 | 10 | from scipy.ndimage.filters import gaussian_filter 11 | from scipy.ndimage.interpolation import map_coordinates 12 | from numpy.lib.stride_tricks import as_strided 13 | 14 | 15 | ###### UTILITIES ###### 16 | def random_num_generator(config, random_state=np.random): 17 | if config[0] == 'uniform': 18 | ret = random_state.uniform(config[1], config[2], 1)[0] 19 | elif config[0] == 'lognormal': 20 | ret = random_state.lognormal(config[1], config[2], 1)[0] 21 | else: 22 | # print(config) 23 | raise Exception('unsupported format') 24 | return ret 25 | 26 | 27 | def get_translation_matrix(translation): 28 | """ translation: [tx, ty] """ 29 | tx, ty = translation 30 | translation_matrix = np.array([[1, 0, tx], 31 | [0, 1, ty], 32 | [0, 0, 1]]) 33 | return translation_matrix 34 | 35 | 36 | def get_rotation_matrix(rotation, input_shape, centred=True): 37 | theta = np.pi / 180 * np.array(rotation) 38 | if centred: 39 | rotation_matrix = cv2.getRotationMatrix2D((input_shape[0] / 2, input_shape[1] // 2), rotation, 1) 40 | rotation_matrix = np.vstack([rotation_matrix, [0, 0, 1]]) 41 | else: 42 | rotation_matrix = np.array([[np.cos(theta), -np.sin(theta), 0], 43 | [np.sin(theta), np.cos(theta), 0], 44 | [0, 0, 1]]) 45 | return rotation_matrix 46 | 47 | 48 | def get_zoom_matrix(zoom, input_shape, centred=True): 49 | zx, zy = zoom 50 | if centred: 51 | zoom_matrix = cv2.getRotationMatrix2D((input_shape[0] / 2, input_shape[1] // 2), 0, zoom[0]) 52 | zoom_matrix = np.vstack([zoom_matrix, [0, 0, 1]]) 53 | else: 54 | zoom_matrix = np.array([[zx, 0, 0], 55 | [0, zy, 0], 56 | [0, 0, 1]]) 57 | return zoom_matrix 58 | 59 | 60 | def get_shear_matrix(shear_angle): 61 | theta = (np.pi * shear_angle) / 180 62 | shear_matrix = np.array([[1, -np.sin(theta), 0], 63 | [0, np.cos(theta), 0], 64 | [0, 0, 1]]) 65 | return shear_matrix 66 | 67 | 68 | ###### AFFINE TRANSFORM ###### 69 | class RandomAffine(object): 70 | """Apply random affine transformation on a numpy.ndarray (H x W x C) 71 | Comment by co1818: this is still doing affine on 2d (H x W plane). 72 | A same transform is applied to all C channels 73 | 74 | Parameter: 75 | ---------- 76 | 77 | alpha: Range [0, 4] seems good for small images 78 | 79 | order: interpolation method (c.f. opencv) 80 | """ 81 | 82 | def __init__(self, 83 | rotation_range=None, 84 | translation_range=None, 85 | shear_range=None, 86 | zoom_range=None, 87 | zoom_keep_aspect=False, 88 | interp='bilinear', 89 | order=3): 90 | """ 91 | Perform an affine transforms. 92 | 93 | Arguments 94 | --------- 95 | rotation_range : one integer or float 96 | image will be rotated randomly between (-degrees, degrees) 97 | 98 | translation_range : (x_shift, y_shift) 99 | shifts in pixels 100 | 101 | *NOT TESTED* shear_range : float 102 | image will be sheared randomly between (-degrees, degrees) 103 | 104 | zoom_range : (zoom_min, zoom_max) 105 | list/tuple with two floats between [0, infinity). 106 | first float should be less than the second 107 | lower and upper bounds on percent zoom. 
108 | Anything less than 1.0 will zoom in on the image, 109 | anything greater than 1.0 will zoom out on the image. 110 | e.g. (0.7, 1.0) will only zoom in, 111 | (1.0, 1.4) will only zoom out, 112 | (0.7, 1.4) will randomly zoom in or out 113 | """ 114 | 115 | self.rotation_range = rotation_range 116 | self.translation_range = translation_range 117 | self.shear_range = shear_range 118 | self.zoom_range = zoom_range 119 | self.zoom_keep_aspect = zoom_keep_aspect 120 | self.interp = interp 121 | self.order = order 122 | 123 | def build_M(self, input_shape): 124 | tfx = [] 125 | final_tfx = np.eye(3) 126 | if self.rotation_range: 127 | rot = np.random.uniform(-self.rotation_range, self.rotation_range) 128 | tfx.append(get_rotation_matrix(rot, input_shape)) 129 | if self.translation_range: 130 | tx = np.random.uniform(-self.translation_range[0], self.translation_range[0]) 131 | ty = np.random.uniform(-self.translation_range[1], self.translation_range[1]) 132 | tfx.append(get_translation_matrix((tx, ty))) 133 | if self.shear_range: 134 | rot = np.random.uniform(-self.shear_range, self.shear_range) 135 | tfx.append(get_shear_matrix(rot)) 136 | if self.zoom_range: 137 | sx = np.random.uniform(self.zoom_range[0], self.zoom_range[1]) 138 | if self.zoom_keep_aspect: 139 | sy = sx 140 | else: 141 | sy = np.random.uniform(self.zoom_range[0], self.zoom_range[1]) 142 | 143 | tfx.append(get_zoom_matrix((sx, sy), input_shape)) 144 | 145 | for tfx_mat in tfx: 146 | final_tfx = np.dot(tfx_mat, final_tfx) 147 | 148 | return final_tfx.astype(np.float32) 149 | 150 | def __call__(self, image): 151 | # build matrix 152 | input_shape = image.shape[:2] 153 | M = self.build_M(input_shape) 154 | 155 | res = np.zeros_like(image) 156 | # if isinstance(self.interp, Sequence): 157 | if type(self.order) is list or type(self.order) is tuple: 158 | for i, intp in enumerate(self.order): 159 | res[..., i] = affine_transform_via_M(image[..., i], M[:2], interp=intp) 160 | else: 161 | # squeeze if needed 162 | orig_shape = image.shape 163 | image_s = np.squeeze(image) 164 | res = affine_transform_via_M(image_s, M[:2], interp=self.order) 165 | res = res.reshape(orig_shape) 166 | 167 | # res = affine_transform_via_M(image, M[:2], interp=self.order) 168 | 169 | return res 170 | 171 | 172 | def affine_transform_via_M(image, M, borderMode=cv2.BORDER_CONSTANT, interp=cv2.INTER_NEAREST): 173 | imshape = image.shape 174 | shape_size = imshape[:2] 175 | 176 | # Random affine 177 | warped = cv2.warpAffine(image.reshape(shape_size + (-1,)), M, shape_size[::-1], 178 | flags=interp, borderMode=borderMode) 179 | 180 | # print(imshape, warped.shape) 181 | 182 | warped = warped[..., np.newaxis].reshape(imshape) 183 | 184 | return warped 185 | 186 | 187 | ###### ELASTIC TRANSFORM ###### 188 | def elastic_transform(image, alpha=1000, sigma=30, spline_order=1, mode='nearest', random_state=np.random): 189 | """Elastic deformation of image as described in [Simard2003]_. 190 | .. [Simard2003] Simard, Steinkraus and Platt, "Best Practices for 191 | Convolutional Neural Networks applied to Visual Document Analysis", in 192 | Proc. of the International Conference on Document Analysis and 193 | Recognition, 2003. 
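    alpha scales the random displacement field (roughly the magnitude of
    the warp in pixels) and sigma is the Gaussian smoothing of that field:
    a small sigma produces noisy, very local distortions, while a large
    sigma yields a smooth, nearly affine deformation.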
194 | """ 195 | assert image.ndim == 3 196 | shape = image.shape[:2] 197 | 198 | dx = gaussian_filter((random_state.rand(*shape) * 2 - 1), 199 | sigma, mode="constant", cval=0) * alpha 200 | dy = gaussian_filter((random_state.rand(*shape) * 2 - 1), 201 | sigma, mode="constant", cval=0) * alpha 202 | 203 | x, y = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), indexing='ij') 204 | indices = [np.reshape(x + dx, (-1, 1)), np.reshape(y + dy, (-1, 1))] 205 | result = np.empty_like(image) 206 | for i in range(image.shape[2]): 207 | result[:, :, i] = map_coordinates( 208 | image[:, :, i], indices, order=spline_order, mode=mode).reshape(shape) 209 | return result 210 | 211 | 212 | def elastic_transform_nd(image, alpha, sigma, random_state=None, order=1, lazy=False): 213 | """Expects data to be (nx, ny, n1 ,..., nm) 214 | params: 215 | ------ 216 | 217 | alpha: 218 | the scaling parameter. 219 | E.g.: alpha=2 => distorts images up to 2x scaling 220 | 221 | sigma: 222 | standard deviation of gaussian filter. 223 | E.g. 224 | low (sig~=1e-3) => no smoothing, pixelated. 225 | high (1/5 * imsize) => smooth, more like affine. 226 | very high (1/2*im_size) => translation 227 | """ 228 | 229 | if random_state is None: 230 | random_state = np.random.RandomState(None) 231 | 232 | shape = image.shape 233 | imsize = shape[:2] 234 | dim = shape[2:] 235 | 236 | # Random affine 237 | blur_size = int(4 * sigma) | 1 238 | dx = cv2.GaussianBlur(random_state.rand(*imsize) * 2 - 1, 239 | ksize=(blur_size, blur_size), sigmaX=sigma) * alpha 240 | dy = cv2.GaussianBlur(random_state.rand(*imsize) * 2 - 1, 241 | ksize=(blur_size, blur_size), sigmaX=sigma) * alpha 242 | 243 | # use as_strided to copy things over across n1...nn channels 244 | dx = as_strided(dx.astype(np.float32), 245 | strides=(0,) * len(dim) + (4 * shape[1], 4), 246 | shape=dim + (shape[0], shape[1])) 247 | dx = np.transpose(dx, axes=(-2, -1) + tuple(range(len(dim)))) 248 | 249 | dy = as_strided(dy.astype(np.float32), 250 | strides=(0,) * len(dim) + (4 * shape[1], 4), 251 | shape=dim + (shape[0], shape[1])) 252 | dy = np.transpose(dy, axes=(-2, -1) + tuple(range(len(dim)))) 253 | 254 | coord = np.meshgrid(*[np.arange(shape_i) for shape_i in (shape[1], shape[0]) + dim]) 255 | indices = [np.reshape(e + de, (-1, 1)) for e, de in zip([coord[1], coord[0]] + coord[2:], 256 | [dy, dx] + [0] * len(dim))] 257 | 258 | if lazy: 259 | return indices 260 | 261 | return map_coordinates(image, indices, order=order, mode='reflect').reshape(shape) 262 | 263 | 264 | class ElasticTransform(object): 265 | """Apply elastic transformation on a numpy.ndarray (H x W x C) 266 | """ 267 | 268 | def __init__(self, alpha, sigma, order=1): 269 | self.alpha = alpha 270 | self.sigma = sigma 271 | self.order = order 272 | 273 | def __call__(self, image): 274 | if isinstance(self.alpha, Sequence): 275 | alpha = random_num_generator(self.alpha) 276 | else: 277 | alpha = self.alpha 278 | if isinstance(self.sigma, Sequence): 279 | sigma = random_num_generator(self.sigma) 280 | else: 281 | sigma = self.sigma 282 | return elastic_transform_nd(image, alpha=alpha, sigma=sigma, order=self.order) 283 | 284 | 285 | class RandomFlip3D(object): 286 | 287 | def __init__(self, h=True, v=True, t=True, p=0.5): 288 | """ 289 | Randomly flip an image horizontally and/or vertically with 290 | some probability. 
291 | 
292 | Arguments
293 | ---------
294 | h : boolean
295 | whether to horizontally flip w/ probability p
296 | 
297 | v : boolean
298 | whether to vertically flip w/ probability p
299 | 
300 | t : boolean
301 | whether to flip along the last (depth) axis w/ probability p
302 | 
303 | p : float between [0,1]
304 | probability with which to apply allowed flipping operations
305 | """
306 | self.horizontal = h
307 | self.vertical = v
308 | self.depth = t
309 | self.p = p
310 | 
311 | def __call__(self, x, y=None):
312 | # horizontal flip with p = self.p
313 | if self.horizontal:
314 | if np.random.random() < self.p:
315 | x = x[::-1, ...]
316 | 
317 | # vertical flip with p = self.p
318 | if self.vertical:
319 | if np.random.random() < self.p:
320 | x = x[:, ::-1, ...]
321 | 
322 | # depth flip with p = self.p
323 | if self.depth:
324 | if np.random.random() < self.p:
325 | x = x[..., ::-1]
326 | 
327 | return x
328 | 
--------------------------------------------------------------------------------
/models/__pycache__/CDFSMIS.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/primebo1/FAMNet/3b14839129d2ad7362c458845e56af01238231ec/models/__pycache__/CDFSMIS.cpython-311.pyc
--------------------------------------------------------------------------------
/models/__pycache__/FAM.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/primebo1/FAMNet/3b14839129d2ad7362c458845e56af01238231ec/models/__pycache__/FAM.cpython-311.pyc
--------------------------------------------------------------------------------
/models/__pycache__/MSFM.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/primebo1/FAMNet/3b14839129d2ad7362c458845e56af01238231ec/models/__pycache__/MSFM.cpython-311.pyc
--------------------------------------------------------------------------------
/models/__pycache__/cdfs_TS.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/primebo1/FAMNet/3b14839129d2ad7362c458845e56af01238231ec/models/__pycache__/cdfs_TS.cpython-311.pyc
--------------------------------------------------------------------------------
/models/__pycache__/encoder.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/primebo1/FAMNet/3b14839129d2ad7362c458845e56af01238231ec/models/__pycache__/encoder.cpython-311.pyc
--------------------------------------------------------------------------------
/models/encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torchvision
4 | 
5 | """
6 | Pretrained Model: ResNet-50
7 | Pretrained on Dataset: COCO or ImageNet
8 | """
9 | 
10 | 
11 | class Res50Encoder(nn.Module):
12 |     """
13 |     ResNet-50 backbone from DeepLabv3;
14 |     the 'downsample' component in layer2 and/or layer3 and/or layer4 is modified as in the vanilla ResNet
15 |     """
16 | 
17 |     def __init__(self, replace_stride_with_dilation=None, pretrained_weights='resnet50'):
18 |         super().__init__()
19 |         # using pretrained model's weights
20 |         if pretrained_weights == 'COCO':
21 |             self.pretrained_weights = torch.load(
22 |                 "deeplabv3_resnet50_coco-cd0a2569.pth", map_location='cpu')  # pretrained on COCO
23 |         elif pretrained_weights == 'ImageNet':
24 |             self.pretrained_weights = torch.load(
25 |                 "resnet50-19c8e357.pth", map_location='cpu')  # pretrained on ImageNet
26 |         else:
27 | 
self.pretrained_weights = pretrained_weights 28 | 29 | _model = torchvision.models.resnet.resnet50(pretrained=False, 30 | replace_stride_with_dilation=replace_stride_with_dilation) 31 | self.backbone = nn.ModuleDict() 32 | for dic, m in _model.named_children(): 33 | self.backbone[dic] = m 34 | 35 | self.reduce1 = nn.Conv2d(1024, 512, kernel_size=1, bias=False) 36 | self.reduce2 = nn.Conv2d(2048, 512, kernel_size=1, bias=False) 37 | self.reduce1d = nn.Linear(in_features=1000, out_features=1, bias=True) 38 | 39 | self.IN_layer = nn.Sequential( 40 | nn.InstanceNorm2d(512), 41 | nn.ReLU(inplace=True) 42 | ) 43 | 44 | self._init_weights() 45 | 46 | def forward(self, x): 47 | 48 | """ 49 | :param x: (2, 3, 256, 256) 50 | :return: 51 | """ 52 | x = self.backbone["conv1"](x) 53 | x = self.backbone["bn1"](x) 54 | x = self.backbone["relu"](x) 55 | 56 | x = self.backbone["maxpool"](x) 57 | x = self.backbone["layer1"](x) 58 | x = self.backbone["layer2"](x) 59 | x = self.backbone["layer3"](x) 60 | feature = self.reduce1(x) # (2, 512, 64, 64) 61 | feature = self.IN_layer(feature) 62 | x = self.backbone["layer4"](x) 63 | # feature map -> avgpool -> fc -> single value 64 | t = self.backbone["avgpool"](x) 65 | t = torch.flatten(t, 1) 66 | t = self.backbone["fc"](t) 67 | t = self.reduce1d(t) 68 | return (feature, t) 69 | 70 | def _init_weights(self): 71 | for m in self.modules(): 72 | if isinstance(m, nn.Conv2d): 73 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 74 | elif isinstance(m, nn.BatchNorm2d): 75 | nn.init.constant_(m.weight, 1) 76 | nn.init.constant_(m.bias, 0) 77 | 78 | if self.pretrained_weights is not None: 79 | keys = list(self.pretrained_weights.keys()) 80 | new_dic = self.state_dict() 81 | new_keys = list(new_dic.keys()) 82 | 83 | for i in range(len(keys)): 84 | if keys[i] in new_keys: 85 | new_dic[keys[i]] = self.pretrained_weights[keys[i]] 86 | 87 | self.load_state_dict(new_dic) 88 | 89 | -------------------------------------------------------------------------------- /scripts/test_LGE2bssFP.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # test a model to segment abdominal/cardiac MRI 3 | GPUID1=0 4 | export CUDA_VISIBLE_DEVICES=$GPUID1 5 | 6 | ###### Shared configs ###### 7 | DATASET='CARDIAC_bssFP' 8 | #DATASET='CMR' 9 | NWORKER=16 10 | RUNS=1 11 | ALL_EV=(0 1 2 3 4) # 5-fold cross validation (0, 1, 2, 3, 4) 12 | TEST_LABEL=[1,2,3] 13 | ###### Training configs ###### 14 | NSTEP=39001 15 | DECAY=0.98 16 | 17 | MAX_ITER=3000 # defines the size of an epoch 18 | SNAPSHOT_INTERVAL=3000 # interval for saving snapshot 19 | SEED=2025 20 | 21 | N_PART=3 # defines the number of chunks for evaluation 22 | ALL_SUPP=(2) # CHAOST2: 0-4, CMR: 0-7 23 | echo ======================================================================== 24 | for EVAL_FOLD in "${ALL_EV[@]}" 25 | do 26 | PREFIX="test_${DATASET}_cv${EVAL_FOLD}" 27 | echo $PREFIX 28 | LOGDIR="./results" 29 | 30 | if [ ! 
-d $LOGDIR ] 31 | then 32 | mkdir -p $LOGDIR 33 | fi 34 | for SUPP_IDX in "${ALL_SUPP[@]}" 35 | do 36 | # RELOAD_PATH='please feed the absolute path to the trained weights here' # path to the reloaded model 37 | RELOAD_MODEL_PATH=".../exps_train_on_CARDIAC_LGE_cdfs_FAM_resnet50/CDFS_train_CARDIAC_LGE_cv${EVAL_FOLD}/1/snapshots/39000.pth" 38 | python3 test.py with \ 39 | mode="test" \ 40 | dataset=$DATASET \ 41 | num_workers=$NWORKER \ 42 | n_steps=$NSTEP \ 43 | eval_fold=$EVAL_FOLD \ 44 | max_iters_per_load=$MAX_ITER \ 45 | supp_idx=$SUPP_IDX \ 46 | test_label=$TEST_LABEL \ 47 | seed=$SEED \ 48 | n_part=$N_PART \ 49 | reload_model_path=$RELOAD_MODEL_PATH \ 50 | save_snapshot_every=$SNAPSHOT_INTERVAL \ 51 | lr_step_gamma=$DECAY \ 52 | path.log_dir=$LOGDIR 53 | done 54 | done 55 | 56 | 57 | 58 | 59 | 60 | 61 | -------------------------------------------------------------------------------- /scripts/test_NCI2UCLH.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # test a model to segment abdominal/cardiac MRI 3 | GPUID1=0 4 | export CUDA_VISIBLE_DEVICES=$GPUID1 5 | 6 | ###### Shared configs ###### 7 | DATASET='Prostate_UCLH' 8 | #DATASET='CMR' 9 | NWORKER=16 10 | RUNS=1 11 | ALL_EV=(0 1 2 3 4) # 5-fold cross validation (0, 1, 2, 3, 4) 12 | TEST_LABEL=[1,5,6] 13 | ###### Training configs ###### 14 | NSTEP=39001 15 | DECAY=0.98 16 | 17 | MAX_ITER=3000 # defines the size of an epoch 18 | SNAPSHOT_INTERVAL=3000 # interval for saving snapshot 19 | SEED=2025 20 | 21 | N_PART=3 # defines the number of chunks for evaluation 22 | ALL_SUPP=(2) # CHAOST2: 0-4, CMR: 0-7, 23 | echo ======================================================================== 24 | for EVAL_FOLD in "${ALL_EV[@]}" 25 | do 26 | PREFIX="test_${DATASET}_cv${EVAL_FOLD}" 27 | echo $PREFIX 28 | LOGDIR="./results" 29 | 30 | if [ ! 
-d $LOGDIR ] 31 | then 32 | mkdir -p $LOGDIR 33 | fi 34 | for SUPP_IDX in "${ALL_SUPP[@]}" 35 | do 36 | RELOAD_MODEL_PATH=".../exps_train_on_Prostate_NCI_cdfs_FAM/CDFS_train_Prostate_NCI_cv${EVAL_FOLD}/1/snapshots/39000.pth" 37 | python3 test.py with \ 38 | mode="test" \ 39 | dataset=$DATASET \ 40 | num_workers=$NWORKER \ 41 | n_steps=$NSTEP \ 42 | eval_fold=$EVAL_FOLD \ 43 | max_iters_per_load=$MAX_ITER \ 44 | supp_idx=$SUPP_IDX \ 45 | test_label=$TEST_LABEL \ 46 | seed=$SEED \ 47 | n_part=$N_PART \ 48 | reload_model_path=$RELOAD_MODEL_PATH \ 49 | save_snapshot_every=$SNAPSHOT_INTERVAL \ 50 | lr_step_gamma=$DECAY \ 51 | path.log_dir=$LOGDIR 52 | done 53 | done 54 | 55 | 56 | 57 | 58 | 59 | 60 | -------------------------------------------------------------------------------- /scripts/test_UCLH2NCI.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # test a model trained on prostate UCLH to segment prostate NCI MRI 3 | GPUID1=0 4 | export CUDA_VISIBLE_DEVICES=$GPUID1 5 | 6 | ###### Shared configs ###### 7 | DATASET='Prostate_NCI' 8 | #DATASET='CMR' 9 | NWORKER=16 10 | RUNS=1 11 | ALL_EV=(0 1 2 3 4) # 5-fold cross validation (0, 1, 2, 3, 4) 12 | TEST_LABEL=[1,5,6] 13 | ###### Training configs ###### 14 | NSTEP=40000 15 | DECAY=0.98 16 | 17 | MAX_ITER=3000 # defines the size of an epoch 18 | SNAPSHOT_INTERVAL=3000 # interval for saving snapshot 19 | SEED=2025 20 | 21 | N_PART=3 # defines the number of chunks for evaluation 22 | ALL_SUPP=(2) # CHAOST2: 0-4, CMR: 0-7, 23 | echo ======================================================================== 24 | for EVAL_FOLD in "${ALL_EV[@]}" 25 | do 26 | PREFIX="test_${DATASET}_cv${EVAL_FOLD}" 27 | echo $PREFIX 28 | LOGDIR="./results" 29 | 30 | if [ ! -d $LOGDIR ] 31 | then 32 | mkdir -p $LOGDIR 33 | fi 34 | for SUPP_IDX in "${ALL_SUPP[@]}" 35 | do 36 | # RELOAD_PATH='please feed the absolute path to the trained weights here' # path to the reloaded model 37 | RELOAD_MODEL_PATH=".../exps_train_on_Prostate_UCLH_cdfs_FAM/CDFS_train_Prostate_UCLH_cv${EVAL_FOLD}/1/snapshots/39000.pth" 38 | 39 | python3 test.py with \ 40 | mode="test" \ 41 | dataset=$DATASET \ 42 | num_workers=$NWORKER \ 43 | n_steps=$NSTEP \ 44 | eval_fold=$EVAL_FOLD \ 45 | max_iters_per_load=$MAX_ITER \ 46 | supp_idx=$SUPP_IDX \ 47 | test_label=$TEST_LABEL \ 48 | seed=$SEED \ 49 | n_part=$N_PART \ 50 | reload_model_path=$RELOAD_MODEL_PATH \ 51 | save_snapshot_every=$SNAPSHOT_INTERVAL \ 52 | lr_step_gamma=$DECAY \ 53 | path.log_dir=$LOGDIR 54 | done 55 | done 56 | 57 | 58 | 59 | 60 | 61 | 62 | -------------------------------------------------------------------------------- /scripts/test_bssFP2LGE.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # test a model trained on cardiac bSSFP to segment cardiac LGE MRI 3 | GPUID1=0 4 | export CUDA_VISIBLE_DEVICES=$GPUID1 5 | 6 | ###### Shared configs ###### 7 | DATASET='CARDIAC_LGE' 8 | #DATASET='CMR' 9 | NWORKER=16 10 | RUNS=1 11 | ALL_EV=(0 1 2 3 4) # 5-fold cross validation (0, 1, 2, 3, 4) 12 | TEST_LABEL=[1,2,3] 13 | ###### Training configs ###### 14 | NSTEP=39001 15 | DECAY=0.98 16 | 17 | MAX_ITER=3000 # defines the size of an epoch 18 | SNAPSHOT_INTERVAL=3000 # interval for saving snapshot 19 | SEED=2025 20 | 21 | N_PART=3 # defines the number of chunks for evaluation 22 | ALL_SUPP=(2) # CHAOST2: 0-4, CMR: 0-7 23 | 24 | echo ======================================================================== 25 | for EVAL_FOLD in "${ALL_EV[@]}" 26 | do 27 |
PREFIX="test_${DATASET}_cv${EVAL_FOLD}" 28 | echo $PREFIX 29 | LOGDIR="./results" 30 | 31 | if [ ! -d $LOGDIR ] 32 | then 33 | mkdir -p $LOGDIR 34 | fi 35 | for SUPP_IDX in "${ALL_SUPP[@]}" 36 | do 37 | # RELOAD_PATH='please feed the absolute path to the trained weights here' # path to the reloaded model 38 | RELOAD_MODEL_PATH=".../exps_train_on_CARDIAC_bssFP_cdfs_FAM_resnet50/CDFS_train_CARDIAC_bssFP_cv${EVAL_FOLD}/1/snapshots/39000.pth" 39 | python3 test.py with \ 40 | mode="test" \ 41 | dataset=$DATASET \ 42 | num_workers=$NWORKER \ 43 | n_steps=$NSTEP \ 44 | eval_fold=$EVAL_FOLD \ 45 | max_iters_per_load=$MAX_ITER \ 46 | supp_idx=$SUPP_IDX \ 47 | test_label=$TEST_LABEL \ 48 | seed=$SEED \ 49 | n_part=$N_PART \ 50 | reload_model_path=$RELOAD_MODEL_PATH \ 51 | save_snapshot_every=$SNAPSHOT_INTERVAL \ 52 | lr_step_gamma=$DECAY \ 53 | path.log_dir=$LOGDIR 54 | done 55 | done 56 | 57 | 58 | 59 | 60 | 61 | 62 | -------------------------------------------------------------------------------- /scripts/test_ct2mr.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # test a model to segment abdominal/cardiac MRI 3 | GPUID1=0 4 | export CUDA_VISIBLE_DEVICES=$GPUID1 5 | 6 | ###### Shared configs ###### 7 | DATASET='ABDOMEN_MR' 8 | #DATASET='CMR' 9 | NWORKER=16 10 | RUNS=1 11 | ALL_EV=(0 1 2 3 4) # 5-fold cross validation (0, 1, 2, 3, 4) 12 | TEST_LABEL=[1,2,3,4] 13 | ###### Training configs ###### 14 | NSTEP=39001 15 | DECAY=0.98 16 | 17 | MAX_ITER=3000 # defines the size of an epoch 18 | SNAPSHOT_INTERVAL=3000 # interval for saving snapshot 19 | SEED=2025 20 | 21 | N_PART=3 # defines the number of chunks for evaluation 22 | ALL_SUPP=(2) # CHAOST2: 0-4, CMR: 0-7 23 | 24 | echo ======================================================================== 25 | 26 | for EVAL_FOLD in "${ALL_EV[@]}" 27 | do 28 | PREFIX="test_${DATASET}_cv${EVAL_FOLD}" 29 | echo $PREFIX 30 | LOGDIR="./results" 31 | 32 | if [ ! 
-d $LOGDIR ] 33 | then 34 | mkdir -p $LOGDIR 35 | fi 36 | for SUPP_IDX in "${ALL_SUPP[@]}" 37 | do 38 | # RELOAD_PATH='please feed the absolute path to the trained weights here' # path to the reloaded model 39 | RELOAD_MODEL_PATH=".../exps_train_on_ABDOMEN_CT_cdfs_FAM/CDFS_train_ABDOMEN_CT_cv${EVAL_FOLD}/1/snapshots/39000.pth" 40 | 41 | python3 test.py with \ 42 | mode="test" \ 43 | dataset=$DATASET \ 44 | num_workers=$NWORKER \ 45 | n_steps=$NSTEP \ 46 | eval_fold=$EVAL_FOLD \ 47 | max_iters_per_load=$MAX_ITER \ 48 | supp_idx=$SUPP_IDX \ 49 | test_label=$TEST_LABEL \ 50 | seed=$SEED \ 51 | n_part=$N_PART \ 52 | reload_model_path=$RELOAD_MODEL_PATH \ 53 | save_snapshot_every=$SNAPSHOT_INTERVAL \ 54 | lr_step_gamma=$DECAY \ 55 | path.log_dir=$LOGDIR 56 | done 57 | done 58 | 59 | 60 | 61 | 62 | 63 | 64 | -------------------------------------------------------------------------------- /scripts/test_mr2ct.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # test a model trained on abdominal MRI to segment abdominal CT 3 | GPUID1=0 4 | export CUDA_VISIBLE_DEVICES=$GPUID1 5 | 6 | ###### Shared configs ###### 7 | DATASET='ABDOMEN_CT' 8 | #DATASET='CMR' 9 | NWORKER=16 10 | RUNS=1 11 | ALL_EV=(0 1 2 3 4) # 5-fold cross validation (0, 1, 2, 3, 4) 12 | TEST_LABEL=[1,2,3,6] 13 | ###### Training configs ###### 14 | NSTEP=39001 15 | DECAY=0.98 16 | 17 | MAX_ITER=3000 # defines the size of an epoch 18 | SNAPSHOT_INTERVAL=3000 # interval for saving snapshot 19 | SEED=2025 20 | 21 | N_PART=3 # defines the number of chunks for evaluation 22 | ALL_SUPP=(2) # CHAOST2: 0-4, CMR: 0-7 23 | echo ======================================================================== 24 | for EVAL_FOLD in "${ALL_EV[@]}" 25 | do 26 | PREFIX="test_${DATASET}_cv${EVAL_FOLD}" 27 | echo $PREFIX 28 | LOGDIR="./results" 29 | 30 | if [ !
-d $LOGDIR ] 31 | then 32 | mkdir -p $LOGDIR 33 | fi 34 | for SUPP_IDX in "${ALL_SUPP[@]}" 35 | do 36 | # RELOAD_PATH='please feed the absolute path to the trained weights here' # path to the reloaded model 37 | RELOAD_MODEL_PATH=".../exps_train_on_ABDOMEN_MR_cdfs_FAM/CDFS_train_ABDOMEN_MR_cv${EVAL_FOLD}/1/snapshots/39000.pth" 38 | python3 test.py with \ 39 | mode="test" \ 40 | dataset=$DATASET \ 41 | num_workers=$NWORKER \ 42 | n_steps=$NSTEP \ 43 | eval_fold=$EVAL_FOLD \ 44 | max_iters_per_load=$MAX_ITER \ 45 | supp_idx=$SUPP_IDX \ 46 | test_label=$TEST_LABEL \ 47 | seed=$SEED \ 48 | n_part=$N_PART \ 49 | reload_model_path=$RELOAD_MODEL_PATH \ 50 | save_snapshot_every=$SNAPSHOT_INTERVAL \ 51 | lr_step_gamma=$DECAY \ 52 | path.log_dir=$LOGDIR 53 | done 54 | done 55 | 56 | 57 | 58 | 59 | 60 | 61 | -------------------------------------------------------------------------------- /scripts/train_LGE.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # train a model to segment cardiac LGE MRI 3 | GPUID1=0 4 | export CUDA_VISIBLE_DEVICES=$GPUID1 5 | 6 | ###### Shared configs ###### 7 | DATASET='CARDIAC_LGE' 8 | NWORKER=16 9 | RUNS=1 10 | ALL_EV=(0 1 2 3 4) # 5-fold cross validation (0, 1, 2, 3, 4) 11 | TEST_LABEL=[1,2,3] 12 | EXCLUDE_LABEL=None 13 | USE_GT=False 14 | ###### Training configs ###### 15 | NSTEP=40001 16 | DECAY=0.98 17 | 18 | MAX_ITER=3000 # defines the size of an epoch 19 | SNAPSHOT_INTERVAL=1000 # interval for saving snapshot 20 | SEED=2025 21 | 22 | echo ======================================================================== 23 | 24 | for EVAL_FOLD in "${ALL_EV[@]}" 25 | do 26 | PREFIX="train_${DATASET}_cv${EVAL_FOLD}" 27 | echo $PREFIX 28 | LOGDIR="./exps_train_on_${DATASET}_cdfs_FAM_resnet50" 29 | 30 | if [ ! -d $LOGDIR ] 31 | then 32 | mkdir -p $LOGDIR 33 | fi 34 | 35 | python3 train.py with \ 36 | mode='train' \ 37 | dataset=$DATASET \ 38 | num_workers=$NWORKER \ 39 | n_steps=$NSTEP \ 40 | eval_fold=$EVAL_FOLD \ 41 | test_label=$TEST_LABEL \ 42 | exclude_label=$EXCLUDE_LABEL \ 43 | use_gt=$USE_GT \ 44 | max_iters_per_load=$MAX_ITER \ 45 | seed=$SEED \ 46 | save_snapshot_every=$SNAPSHOT_INTERVAL \ 47 | lr_step_gamma=$DECAY \ 48 | path.log_dir=$LOGDIR 49 | done 50 | -------------------------------------------------------------------------------- /scripts/train_NCI.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # train a model to segment prostate MRI (NCI dataset) 3 | GPUID1=0 4 | export CUDA_VISIBLE_DEVICES=$GPUID1 5 | 6 | ###### Shared configs ###### 7 | DATASET='Prostate_NCI' 8 | NWORKER=16 9 | RUNS=1 10 | ALL_EV=(0 1 2 3 4) # 5-fold cross validation (0, 1, 2, 3, 4) 11 | TEST_LABEL=[1,5,6] 12 | EXCLUDE_LABEL=None 13 | USE_GT=False 14 | ###### Training configs ###### 15 | NSTEP=39001 16 | DECAY=0.98 17 | 18 | MAX_ITER=3000 # defines the size of an epoch 19 | SNAPSHOT_INTERVAL=1000 # interval for saving snapshot 20 | SEED=2025 21 | 22 | echo ======================================================================== 23 | 24 | for EVAL_FOLD in "${ALL_EV[@]}" 25 | do 26 | PREFIX="train_${DATASET}_cv${EVAL_FOLD}" 27 | echo $PREFIX 28 | LOGDIR="./exps_train_on_${DATASET}_cdfs_FAM" 29 | 30 | if [ !
-d $LOGDIR ] 31 | then 32 | mkdir -p $LOGDIR 33 | fi 34 | 35 | python3 train.py with \ 36 | mode='train' \ 37 | dataset=$DATASET \ 38 | num_workers=$NWORKER \ 39 | n_steps=$NSTEP \ 40 | eval_fold=$EVAL_FOLD \ 41 | test_label=$TEST_LABEL \ 42 | exclude_label=$EXCLUDE_LABEL \ 43 | use_gt=$USE_GT \ 44 | max_iters_per_load=$MAX_ITER \ 45 | seed=$SEED \ 46 | save_snapshot_every=$SNAPSHOT_INTERVAL \ 47 | lr_step_gamma=$DECAY \ 48 | path.log_dir=$LOGDIR 49 | done 50 | -------------------------------------------------------------------------------- /scripts/train_UCLH.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # train a model to segment prostate MRI (UCLH dataset) 3 | GPUID1=0 4 | export CUDA_VISIBLE_DEVICES=$GPUID1 5 | 6 | ###### Shared configs ###### 7 | DATASET='Prostate_UCLH' 8 | NWORKER=16 9 | RUNS=1 10 | ALL_EV=(0 1 2 3 4) # 5-fold cross validation (0, 1, 2, 3, 4) 11 | TEST_LABEL=[1,5,6] 12 | EXCLUDE_LABEL=None 13 | USE_GT=False 14 | ###### Training configs ###### 15 | NSTEP=39001 16 | DECAY=0.98 17 | 18 | MAX_ITER=3000 # defines the size of an epoch 19 | SNAPSHOT_INTERVAL=1000 # interval for saving snapshot 20 | SEED=2025 21 | 22 | echo ======================================================================== 23 | 24 | for EVAL_FOLD in "${ALL_EV[@]}" 25 | do 26 | PREFIX="train_${DATASET}_cv${EVAL_FOLD}" 27 | echo $PREFIX 28 | LOGDIR="./exps_train_on_${DATASET}_cdfs_FAM" 29 | 30 | if [ ! -d $LOGDIR ] 31 | then 32 | mkdir -p $LOGDIR 33 | fi 34 | 35 | python3 train.py with \ 36 | mode='train' \ 37 | dataset=$DATASET \ 38 | num_workers=$NWORKER \ 39 | n_steps=$NSTEP \ 40 | eval_fold=$EVAL_FOLD \ 41 | test_label=$TEST_LABEL \ 42 | exclude_label=$EXCLUDE_LABEL \ 43 | use_gt=$USE_GT \ 44 | max_iters_per_load=$MAX_ITER \ 45 | seed=$SEED \ 46 | save_snapshot_every=$SNAPSHOT_INTERVAL \ 47 | lr_step_gamma=$DECAY \ 48 | path.log_dir=$LOGDIR 49 | done 50 | -------------------------------------------------------------------------------- /scripts/train_bssFP.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # train a model to segment cardiac bSSFP MRI 3 | GPUID1=0 4 | export CUDA_VISIBLE_DEVICES=$GPUID1 5 | 6 | ###### Shared configs ###### 7 | DATASET='CARDIAC_bssFP' 8 | NWORKER=16 9 | RUNS=1 10 | ALL_EV=(0 1 2 3 4) # 5-fold cross validation (0, 1, 2, 3, 4) 11 | TEST_LABEL=[1,2,3] 12 | EXCLUDE_LABEL=None 13 | USE_GT=False 14 | ###### Training configs ###### 15 | NSTEP=40001 16 | DECAY=0.98 17 | 18 | MAX_ITER=3000 # defines the size of an epoch 19 | SNAPSHOT_INTERVAL=1000 # interval for saving snapshot 20 | SEED=2025 21 | 22 | echo ======================================================================== 23 | 24 | for EVAL_FOLD in "${ALL_EV[@]}" 25 | do 26 | PREFIX="train_${DATASET}_cv${EVAL_FOLD}" 27 | echo $PREFIX 28 | LOGDIR="./exps_train_on_${DATASET}_cdfs_FAM_resnet50" 29 | 30 | if [ !
-d $LOGDIR ] 31 | then 32 | mkdir -p $LOGDIR 33 | fi 34 | 35 | python3 train.py with \ 36 | mode='train' \ 37 | dataset=$DATASET \ 38 | num_workers=$NWORKER \ 39 | n_steps=$NSTEP \ 40 | eval_fold=$EVAL_FOLD \ 41 | test_label=$TEST_LABEL \ 42 | exclude_label=$EXCLUDE_LABEL \ 43 | use_gt=$USE_GT \ 44 | max_iters_per_load=$MAX_ITER \ 45 | seed=$SEED \ 46 | save_snapshot_every=$SNAPSHOT_INTERVAL \ 47 | lr_step_gamma=$DECAY \ 48 | path.log_dir=$LOGDIR 49 | done 50 | -------------------------------------------------------------------------------- /scripts/train_ct2mr.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # train a model to segment abdominal CT (Multi-Atlas Abdomen Labeling dataset) 3 | GPUID1=0 4 | export CUDA_VISIBLE_DEVICES=$GPUID1 5 | 6 | ###### Shared configs ###### 7 | DATASET='ABDOMEN_CT' 8 | NWORKER=16 9 | RUNS=1 10 | ALL_EV=(0 1 2 3 4) # 5-fold cross validation (0, 1, 2, 3, 4) 11 | TEST_LABEL=[1,2,3,4] 12 | EXCLUDE_LABEL=None 13 | USE_GT=False 14 | ###### Training configs ###### 15 | NSTEP=40001 16 | DECAY=0.98 17 | 18 | MAX_ITER=3000 # defines the size of an epoch 19 | SNAPSHOT_INTERVAL=1000 # interval for saving snapshot 20 | SEED=2025 21 | 22 | echo ======================================================================== 23 | 24 | for EVAL_FOLD in "${ALL_EV[@]}" 25 | do 26 | PREFIX="train_${DATASET}_cv${EVAL_FOLD}" 27 | echo $PREFIX 28 | LOGDIR="./exps_train_on_${DATASET}_cdfs_FAM" 29 | 30 | if [ ! -d $LOGDIR ] 31 | then 32 | mkdir -p $LOGDIR 33 | fi 34 | 35 | python3 train.py with \ 36 | mode='train' \ 37 | dataset=$DATASET \ 38 | num_workers=$NWORKER \ 39 | n_steps=$NSTEP \ 40 | eval_fold=$EVAL_FOLD \ 41 | test_label=$TEST_LABEL \ 42 | exclude_label=$EXCLUDE_LABEL \ 43 | use_gt=$USE_GT \ 44 | max_iters_per_load=$MAX_ITER \ 45 | seed=$SEED \ 46 | save_snapshot_every=$SNAPSHOT_INTERVAL \ 47 | lr_step_gamma=$DECAY \ 48 | path.log_dir=$LOGDIR 49 | done 50 | -------------------------------------------------------------------------------- /scripts/train_mr2ct.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # train a model to segment abdominal MRI (T2 fold of CHAOS challenge) 3 | GPUID1=0 4 | export CUDA_VISIBLE_DEVICES=$GPUID1 5 | 6 | ###### Shared configs ###### 7 | DATASET='ABDOMEN_MR' 8 | NWORKER=16 9 | RUNS=1 10 | ALL_EV=(0 1 2 3 4) # 5-fold cross validation (0, 1, 2, 3, 4) 11 | TEST_LABEL=[1,2,3,6] 12 | EXCLUDE_LABEL=None # None: no labels are excluded from training 13 | USE_GT=False 14 | ###### Training configs ###### 15 | NSTEP=40001 16 | DECAY=0.98 17 | 18 | MAX_ITER=3000 # defines the size of an epoch 19 | SNAPSHOT_INTERVAL=1000 # interval for saving snapshot 20 | SEED=2025 21 | 22 | echo ======================================================================== 23 | 24 | for EVAL_FOLD in "${ALL_EV[@]}" 25 | do 26 | PREFIX="train_${DATASET}_cv${EVAL_FOLD}" 27 | echo $PREFIX 28 | LOGDIR="./exps_train_on_${DATASET}_cdfs_FAM" 29 | 30 | if [ !
-d $LOGDIR ] 31 | then 32 | mkdir -p $LOGDIR 33 | fi 34 | 35 | python3 train.py with \ 36 | mode='train' \ 37 | dataset=$DATASET \ 38 | num_workers=$NWORKER \ 39 | n_steps=$NSTEP \ 40 | eval_fold=$EVAL_FOLD \ 41 | test_label=$TEST_LABEL \ 42 | exclude_label=$EXCLUDE_LABEL \ 43 | use_gt=$USE_GT \ 44 | max_iters_per_load=$MAX_ITER \ 45 | seed=$SEED \ 46 | save_snapshot_every=$SNAPSHOT_INTERVAL \ 47 | lr_step_gamma=$DECAY \ 48 | path.log_dir=$LOGDIR 49 | done 50 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ 3 | For evaluation 4 | Extended from ADNet code by Hansen et al. 5 | """ 6 | import shutil 7 | import SimpleITK as sitk 8 | import torch.backends.cudnn as cudnn 9 | import torch.optim 10 | from torch.utils.data import DataLoader 11 | from models.CDFSMIS import FewShotSeg 12 | from dataloaders.datasets import TestDataset 13 | from dataloaders.dataset_specifics import * 14 | from utils import * 15 | from config import ex 16 | 17 | 18 | @ex.automain 19 | def main(_run, _config, _log): 20 | if _run.observers: 21 | os.makedirs(f'{_run.observers[0].dir}/interm_preds', exist_ok=True) 22 | for source_file, _ in _run.experiment_info['sources']: 23 | os.makedirs(os.path.dirname(f'{_run.observers[0].dir}/source/{source_file}'), 24 | exist_ok=True) 25 | _run.observers[0].save_file(source_file, f'source/{source_file}') 26 | shutil.rmtree(f'{_run.observers[0].basedir}/_sources') 27 | 28 | # Set up logger -> log to .txt 29 | file_handler = logging.FileHandler(os.path.join(f'{_run.observers[0].dir}', f'logger.log')) 30 | file_handler.setLevel('INFO') 31 | formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(name)s - %(message)s') 32 | file_handler.setFormatter(formatter) 33 | _log.handlers.append(file_handler) 34 | _log.info(f'Run "{_config["exp_str"]}" with ID "{_run.observers[0].dir[-1]}"') 35 | 36 | # Deterministic setting for reproducibility. 37 | if _config['seed'] is not None: 38 | random.seed(_config['seed']) 39 | torch.manual_seed(_config['seed']) 40 | torch.cuda.manual_seed_all(_config['seed']) 41 | cudnn.deterministic = True 42 | 43 | # Enable cuDNN benchmark mode to select the fastest convolution algorithm.
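# Note: benchmark mode auto-tunes convolution algorithms per input size, which can
# reintroduce run-to-run nondeterminism despite the seeding above; set
# cudnn.benchmark = False if strict reproducibility is required.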
44 | cudnn.enabled = True 45 | cudnn.benchmark = True 46 | torch.cuda.set_device(device=_config['gpu_id']) 47 | torch.set_num_threads(1) 48 | 49 | _log.info(f'Create model...') 50 | model_config = { 51 | # 'data_dir': _config['path'][_config['dataset']]['data_dir'], 52 | 'dataset': _config['dataset'], 53 | 'PREC': _config['PREC'], 54 | # 'train_classname': _config['train_classname'], 55 | # 'test_classname': _config['test_classname'], 56 | 'BACKBONE_NAME': _config['BACKBONE_NAME'], 57 | 'N_CTX': _config['N_CTX'], 58 | 'CTX_INIT': _config['CTX_INIT'], 59 | 'CLASS_TOKEN_POSITION': _config['CLASS_TOKEN_POSITION'], 60 | 'INPUT_SIZE': _config['INPUT_SIZE'], 61 | 'CSC': _config['CSC'], 62 | 'INIT_WEIGHTS': _config['INIT_WEIGHTS'], 63 | 'OPTIM': _config['OPTIM'], 64 | 'PROMPT_INIT': _config['PROMPT_INIT'], 65 | # 'classnames': _config['classnames'], 66 | } 67 | model = FewShotSeg(model_config) 68 | model.cuda() 69 | model.load_state_dict(torch.load(_config['reload_model_path'], map_location='cpu'), strict=False) 70 | 71 | _log.info(f'Load data...') 72 | data_config = { 73 | 'data_dir': _config['path'][_config['dataset']]['data_dir'], 74 | 'dataset': _config['dataset'], 75 | 'n_shot': _config['n_shot'], 76 | 'n_way': _config['n_way'], 77 | 'n_query': _config['n_query'], 78 | 'n_sv': _config['n_sv'], 79 | 'max_iter': _config['max_iters_per_load'], 80 | 'eval_fold': _config['eval_fold'], 81 | 'min_size': _config['min_size'], 82 | 'max_slices': _config['max_slices'], 83 | 'supp_idx': _config['supp_idx'], 84 | } 85 | test_dataset = TestDataset(data_config) 86 | test_loader = DataLoader(test_dataset, 87 | batch_size=_config['batch_size'], 88 | shuffle=True, 89 | num_workers=_config['num_workers'], 90 | pin_memory=True, 91 | drop_last=True) 92 | 93 | # Get unique labels (classes). 94 | labels = get_label_names(_config['dataset']) 95 | 96 | # Loop over classes. 97 | class_dice = {} 98 | class_iou = {} 99 | 100 | _log.info(f'Starting validation...') 101 | for label_val, label_name in labels.items(): 102 | 103 | # Skip BG class. 104 | if label_name == 'BG': 105 | continue 106 | elif np.intersect1d([label_val], _config['test_label']).size == 0: 107 | continue 108 | 109 | _log.info(f'Test Class: {label_name}') 110 | 111 | # Get support sample + mask for current class. 112 | support_sample = test_dataset.getSupport(label=label_val, all_slices=False, N=_config['n_part']) 113 | # support_sample['image']: (3, 3, 256, 256) 114 | # support_sample['label']: (3, 256, 256) 115 | 116 | test_dataset.label = label_val 117 | 118 | # Test. 119 | with torch.no_grad(): 120 | model.eval() 121 | 122 | # Unpack support data. 123 | support_image = [support_sample['image'][[i]].float().cuda() for i in 124 | range(support_sample['image'].shape[0])] # n_shot x 3 x H x W; support_image is a list of 3 tensors, each (1, 3, 256, 256) 125 | support_fg_mask = [support_sample['label'][[i]].float().cuda() for i in 126 | range(support_sample['image'].shape[0])] # n_shot x H x W 127 | 128 | # Loop through query volumes. 129 | scores = Scores() 130 | for i, sample in enumerate(test_loader): # one iteration per query volume 131 | 132 | # Unpack query data. 133 | query_image = [sample['image'][i].float().cuda() for i in 134 | range(sample['image'].shape[0])] # query_image is a list of tensors, each (C, 3, H, W) 135 | query_label = sample['label'].long() # C x H x W 136 | query_id = sample['id'][0].split('image_')[1][:-len('.nii.gz')] 137 | 138 | # prompt = _config['dataset'] 139 | 140 | # Compute output. 141 | # Match support slices to query sub-chunks.
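# (The C_q slices of the query volume are split into n_part equal chunks via
# np.linspace below, and chunk k is segmented against the k-th support slice,
# so each part of the query volume is paired with a support slice at a
# similar axial position.)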
142 | query_pred = torch.zeros(query_label.shape[-3:]) 143 | C_q = sample['image'].shape[1] # number of slices in the query volume 144 | 145 | idx_ = np.linspace(0, C_q, _config['n_part'] + 1).astype('int') 146 | for sub_chunk in range(_config['n_part']): # n_part = 3 147 | support_image_s = [support_image[sub_chunk]] # 1 x 3 x H x W 148 | support_fg_mask_s = [support_fg_mask[sub_chunk]] # 1 x H x W 149 | query_image_s = query_image[0][idx_[sub_chunk]:idx_[sub_chunk + 1]] # C' x 3 x H x W 150 | query_pred_s = [] 151 | for j in range(query_image_s.shape[0]): 152 | _pred_s, _ = model([support_image_s], [support_fg_mask_s], [query_image_s[[j]]], 153 | None, None, train=False) # 1 x 2 x H x W; query labels and optimizer are assumed unused at inference, so pass None 154 | 155 | query_pred_s.append(_pred_s) 156 | query_pred_s = torch.cat(query_pred_s, dim=0) 157 | query_pred_s = query_pred_s.argmax(dim=1).cpu() # C x H x W 158 | query_pred[idx_[sub_chunk]:idx_[sub_chunk + 1]] = query_pred_s 159 | 160 | # Record scores. 161 | scores.record(query_pred, query_label) 162 | 163 | # Log. 164 | _log.info( 165 | f'Tested query volume: {sample["id"][0][len(_config["path"][_config["dataset"]]["data_dir"]):]}.') 166 | _log.info(f'Dice score: {scores.patient_dice[-1].item()}') 167 | 168 | # Save predictions. 169 | file_name = os.path.join(f'{_run.observers[0].dir}/interm_preds', 170 | f'prediction_{query_id}_{label_name}.nii.gz') 171 | itk_pred = sitk.GetImageFromArray(query_pred) 172 | sitk.WriteImage(itk_pred, file_name, True) 173 | _log.info(f'{query_id} has been saved. ') 174 | 175 | # Log class-wise results 176 | class_dice[label_name] = torch.tensor(scores.patient_dice).mean().item() 177 | class_iou[label_name] = torch.tensor(scores.patient_iou).mean().item() 178 | _log.info(f'Test Class: {label_name}') 179 | _log.info(f'Mean class IoU: {class_iou[label_name]}') 180 | _log.info(f'Mean class Dice: {class_dice[label_name]}') 181 | 182 | _log.info(f'Final results...') 183 | _log.info(f'Mean IoU per class: {class_iou}') 184 | _log.info(f'Mean Dice per class: {class_dice}') 185 | 186 | 187 | def dict_Avg(Dict): 188 | L = len(Dict) # number of key-value pairs in the dict 189 | S = sum(Dict.values()) # sum of the values in the dict 190 | A = S / L 191 | return A 192 | 193 | value = dict_Avg(class_dice) 194 | with open('results.txt', 'w') as file: 195 | file.write(str(value)) 196 | 197 | _log.info(f'Whole mean Dice: {dict_Avg(class_dice)}') 198 | 199 | _log.info(f'End of validation.') 200 | return 1 201 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ 3 | For training 4 | Extended from ADNet code by Hansen et al.
5 | """ 6 | import shutil 7 | 8 | import torch.backends.cudnn as cudnn 9 | import torch.nn as nn 10 | import torch.optim 11 | from sklearn.metrics import accuracy_score 12 | from torch.optim.lr_scheduler import MultiStepLR 13 | from torch.utils.data import DataLoader 14 | 15 | from config import ex 16 | from dataloaders.datasets import TrainDataset as TrainDataset 17 | from models.CDFSMIS import FewShotSeg 18 | from utils import * 19 | from thop import profile 20 | 21 | def pixel_accuracy(pred, label): 22 | pred_flatten = pred.flatten() 23 | label_flatten = label.flatten() 24 | accuracy = accuracy_score(label_flatten, pred_flatten) 25 | return accuracy 26 | 27 | 28 | @ex.automain 29 | def main(_run, _config, _log): 30 | if _run.observers: 31 | # Set up source folder 32 | os.makedirs(f'{_run.observers[0].dir}/snapshots', exist_ok=True) 33 | for source_file, _ in _run.experiment_info['sources']: 34 | os.makedirs(os.path.dirname(f'{_run.observers[0].dir}/source/{source_file}'), 35 | exist_ok=True) 36 | _run.observers[0].save_file(source_file, f'source/{source_file}') 37 | shutil.rmtree(f'{_run.observers[0].basedir}/_sources') 38 | 39 | # Set up logger -> log to .txt 40 | file_handler = logging.FileHandler(os.path.join(f'{_run.observers[0].dir}', f'logger.log')) 41 | file_handler.setLevel('INFO') 42 | formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(name)s - %(message)s') 43 | file_handler.setFormatter(formatter) 44 | _log.handlers.append(file_handler) 45 | _log.info(f'Run "{_config["exp_str"]}" with ID "{_run.observers[0].dir[-1]}"') 46 | 47 | # Deterministic setting for reproduciablity. 48 | if _config['seed'] is not None: 49 | random.seed(_config['seed']) 50 | torch.manual_seed(_config['seed']) 51 | torch.cuda.manual_seed_all(_config['seed']) 52 | cudnn.deterministic = True 53 | 54 | # Enable cuDNN benchmark mode to select the fastest convolution algorithm. 
55 | cudnn.enabled = True 56 | cudnn.benchmark = True 57 | torch.cuda.set_device(device=_config['gpu_id']) 58 | torch.set_num_threads(1) 59 | 60 | _log.info(f'Create model...') 61 | model_config = { 62 | # 'data_dir': _config['path'][_config['dataset']]['data_dir'], 63 | 'dataset': _config['dataset'], 64 | 'PREC': _config['PREC'], 65 | # 'train_classname': _config['train_classname'], 66 | # 'test_classname': _config['test_classname'], 67 | 'BACKBONE_NAME': _config['BACKBONE_NAME'], 68 | 'N_CTX': _config['N_CTX'], 69 | 'CTX_INIT': _config['CTX_INIT'], 70 | 'CLASS_TOKEN_POSITION': _config['CLASS_TOKEN_POSITION'], 71 | 'INPUT_SIZE': _config['INPUT_SIZE'], 72 | 'CSC': _config['CSC'], 73 | 'INIT_WEIGHTS': _config['INIT_WEIGHTS'], 74 | 'OPTIM': _config['OPTIM'], 75 | 'PROMPT_INIT': _config['PROMPT_INIT'], 76 | # 'classnames': _config['classnames'], 77 | } 78 | model = FewShotSeg(model_config) 79 | model = model.cuda() 80 | model.train() 81 | 82 | _log.info(f'Set optimizer...') 83 | optimizer = torch.optim.SGD(model.parameters(), **_config['optim']) 84 | lr_milestones = [(ii + 1) * _config['max_iters_per_load'] for ii in 85 | range(_config['n_steps'] // _config['max_iters_per_load'] - 1)] 86 | scheduler = MultiStepLR(optimizer, milestones=lr_milestones, gamma=_config['lr_step_gamma']) 87 | 88 | my_weight = torch.FloatTensor([0.1, 1.0]).cuda() 89 | criterion = nn.NLLLoss(ignore_index=255, weight=my_weight) 90 | # criterion = nn.CrossEntropyLoss() 91 | 92 | _log.info(f'Load data...') 93 | data_config = { 94 | 'data_dir': _config['path'][_config['dataset']]['data_dir'], 95 | 'dataset': _config['dataset'], 96 | 'n_shot': _config['n_shot'], 97 | 'n_way': _config['n_way'], 98 | 'n_query': _config['n_query'], 99 | 'n_sv': _config['n_sv'], 100 | 'max_iter': _config['max_iters_per_load'], 101 | 'eval_fold': _config['eval_fold'], 102 | 'min_size': _config['min_size'], 103 | 'max_slices': _config['max_slices'], 104 | 'test_label': _config['test_label'], 105 | 'exclude_label': _config['exclude_label'], 106 | 'use_gt': _config['use_gt'], 107 | 'train_organ': _config['train_organ'], 108 | 109 | } 110 | train_dataset = TrainDataset(data_config) 111 | train_loader = DataLoader(train_dataset, 112 | batch_size=_config['batch_size'], 113 | shuffle=True, 114 | num_workers=_config['num_workers'], 115 | pin_memory=True, 116 | drop_last=True) 117 | 118 | n_sub_epochs = _config['n_steps'] // _config['max_iters_per_load'] # number of times for reloading 119 | log_loss = {'total_loss': 0, 'query_loss': 0, 'align_loss': 0, 'thresh_loss': 0} 120 | 121 | loss_values = [] 122 | i_iter = 0 123 | _log.info(f'Start training...') 124 | for sub_epoch in range(n_sub_epochs): 125 | _log.info(f'This is epoch "{sub_epoch}" of "{n_sub_epochs}" epochs.') 126 | for _, sample in enumerate(train_loader): 127 | 128 | # Prepare episode data. 129 | support_images = [[shot.float().cuda() for shot in way] 130 | for way in sample['support_images']] 131 | support_fg_mask = [[shot.float().cuda() for shot in way] 132 | for way in sample['support_fg_labels']] 133 | 134 | # prompt = _config['dataset'] 135 | 136 | query_images = [query_image.float().cuda() for query_image in sample['query_images']] 137 | query_labels = torch.cat([query_label.long().cuda() for query_label in sample['query_labels']], dim=0) 138 | 139 | 140 | # Compute outputs and losses. 
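# (query_pred is expected to hold per-pixel class probabilities -- torch.log is
# taken below before NLLLoss -- clamped away from 0/1 for numerical stability,
# with the background class down-weighted to 0.1 via my_weight to counter the
# foreground/background imbalance; proto_loss is the auxiliary loss returned
# by the model.)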
141 | query_pred, proto_loss = model(support_images, support_fg_mask, query_images, query_labels, opt=optimizer, train=True) 142 | 143 | query_loss = criterion(torch.log(torch.clamp(query_pred, torch.finfo(torch.float32).eps, 144 | 1 - torch.finfo(torch.float32).eps)), query_labels) 145 | 146 | 147 | 148 | loss = query_loss + proto_loss 149 | 150 | # Compute gradient and do SGD step. 151 | for param in model.parameters(): 152 | param.grad = None 153 | 154 | loss.backward() 155 | optimizer.step() 156 | scheduler.step() 157 | 158 | # Log loss 159 | query_loss = query_loss.detach().data.cpu().numpy() 160 | 161 | loss_values.append(query_loss) 162 | 163 | _run.log_scalar('total_loss', loss.item()) 164 | _run.log_scalar('query_loss', query_loss) 165 | 166 | log_loss['total_loss'] += loss.item() 167 | log_loss['query_loss'] += query_loss 168 | 169 | # Print loss and take snapshots. 170 | if (i_iter + 1) % _config['print_interval'] == 0: 171 | total_loss = log_loss['total_loss'] / _config['print_interval'] 172 | query_loss = log_loss['query_loss'] / _config['print_interval'] 173 | 174 | log_loss['total_loss'] = 0 175 | log_loss['query_loss'] = 0 176 | 177 | _log.info(f'step {i_iter + 1}: total_loss: {total_loss}, query_loss: {query_loss},') 178 | # f' align_loss: {align_loss}') 179 | 180 | if (i_iter + 1) % _config['save_snapshot_every'] == 0: 181 | _log.info('###### Taking snapshot ######') 182 | torch.save(model.state_dict(), 183 | os.path.join(f'{_run.observers[0].dir}/snapshots', f'{i_iter + 1}.pth')) 184 | 185 | i_iter += 1 186 | 187 | loss_values = np.array(loss_values) 188 | # loss_values = loss_values.detach().cpu().numpy() 189 | np.savetxt('loss_values.txt', loss_values) 190 | 191 | _log.info('End of training.') 192 | return 1 193 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utils for Dataset 3 | Extended from ADNet code by Hansen et al. 
4 | """ 5 | import random 6 | import torch 7 | import numpy as np 8 | import operator 9 | import os 10 | import logging 11 | 12 | 13 | def set_seed(seed): 14 | """ 15 | Set the random seed 16 | """ 17 | random.seed(seed) 18 | torch.manual_seed(seed) 19 | torch.cuda.manual_seed_all(seed) 20 | 21 | 22 | CLASS_LABELS = { 23 | 'CHAOST2': { 24 | 'pa_all': set(range(1, 5)), 25 | 0: set([1, 4]), # upper_abdomen, leaving kidneies as testing classes 26 | 1: set([2, 3]), # lower_abdomen 27 | }, 28 | } 29 | 30 | 31 | def get_bbox(fg_mask, inst_mask): 32 | """ 33 | Get the ground truth bounding boxes 34 | """ 35 | 36 | fg_bbox = torch.zeros_like(fg_mask, device=fg_mask.device) 37 | bg_bbox = torch.ones_like(fg_mask, device=fg_mask.device) 38 | 39 | inst_mask[fg_mask == 0] = 0 40 | area = torch.bincount(inst_mask.view(-1)) 41 | cls_id = area[1:].argmax() + 1 42 | cls_ids = np.unique(inst_mask)[1:] 43 | 44 | mask_idx = np.where(inst_mask[0] == cls_id) 45 | y_min = mask_idx[0].min() 46 | y_max = mask_idx[0].max() 47 | x_min = mask_idx[1].min() 48 | x_max = mask_idx[1].max() 49 | fg_bbox[0, y_min:y_max + 1, x_min:x_max + 1] = 1 50 | 51 | for i in cls_ids: 52 | mask_idx = np.where(inst_mask[0] == i) 53 | y_min = max(mask_idx[0].min(), 0) 54 | y_max = min(mask_idx[0].max(), fg_mask.shape[1] - 1) 55 | x_min = max(mask_idx[1].min(), 0) 56 | x_max = min(mask_idx[1].max(), fg_mask.shape[2] - 1) 57 | bg_bbox[0, y_min:y_max + 1, x_min:x_max + 1] = 0 58 | return fg_bbox, bg_bbox 59 | 60 | 61 | def t2n(img_t): 62 | """ 63 | torch to numpy regardless of whether tensor is on gpu or memory 64 | """ 65 | if img_t.is_cuda: 66 | return img_t.data.cpu().numpy() 67 | else: 68 | return img_t.data.numpy() 69 | 70 | 71 | def to01(x_np): 72 | """ 73 | normalize a numpy to 0-1 for visualize 74 | """ 75 | return (x_np - x_np.min()) / (x_np.max() - x_np.min() + 1e-5) 76 | 77 | 78 | class Scores(): 79 | 80 | def __init__(self): 81 | self.TP = 0 82 | self.TN = 0 83 | self.FP = 0 84 | self.FN = 0 85 | 86 | self.patient_dice = [] 87 | self.patient_iou = [] 88 | 89 | def record(self, preds, label): 90 | assert len(torch.unique(preds)) < 3 91 | 92 | tp = torch.sum((label == 1) * (preds == 1)) 93 | tn = torch.sum((label == 0) * (preds == 0)) 94 | fp = torch.sum((label == 0) * (preds == 1)) 95 | fn = torch.sum((label == 1) * (preds == 0)) 96 | 97 | self.patient_dice.append(2 * tp / (2 * tp + fp + fn)) 98 | self.patient_iou.append(tp / (tp + fp + fn)) 99 | 100 | self.TP += tp 101 | self.TN += tn 102 | self.FP += fp 103 | self.FN += fn 104 | 105 | def compute_dice(self): 106 | return 2 * self.TP / (2 * self.TP + self.FP + self.FN) 107 | 108 | def compute_iou(self): 109 | return self.TP / (self.TP + self.FP + self.FN) 110 | 111 | 112 | def set_logger(path): 113 | logger = logging.getLogger() 114 | logger.handlers = [] 115 | formatter = logging.Formatter('[%(levelname)] - %(name)s - %(message)s') 116 | logger.setLevel("INFO") 117 | 118 | # log to .txt 119 | file_handler = logging.FileHandler(path) 120 | file_handler.setFormatter(formatter) 121 | logger.addHandler(file_handler) 122 | 123 | # log to console 124 | stream_handler = logging.StreamHandler() 125 | stream_handler.setFormatter(formatter) 126 | logger.addHandler(stream_handler) 127 | return logger 128 | --------------------------------------------------------------------------------