├── checkpoints
│   └── setting.txt
├── overview.png
├── visualization
│   ├── README.md
│   ├── main.m
│   ├── drawlabel2image.m
│   ├── mask_generalization_cmr.py
│   ├── mask_generalization_ct.py
│   └── mask_generalization_mri.py
├── data
│   ├── supervoxels
│   │   ├── _ccomp.pxd
│   │   ├── setup.py
│   │   ├── felzenszwalb_3d.py
│   │   ├── generate_supervoxels.py
│   │   └── felzenszwalb_3d_cy.pyx
│   ├── CHAOST2
│   │   ├── dcm_img_to_nii.sh
│   │   ├── png_gth_to_nii.ipynb
│   │   ├── class_slice_index_gen.ipynb
│   │   └── image_normalize.ipynb
│   ├── SABS
│   │   ├── niftiio.py
│   │   ├── class_slice_index_gen.py
│   │   ├── intensity_normalization.py
│   │   └── resampling_and_roi.py
│   └── CMR
│       └── class_slice_index_gen.ipynb
├── scripts
│   ├── train_CMR.sh
│   ├── train_SABS_setting1.sh
│   ├── train_CHAOST2_setting1.sh
│   ├── train_SABS_setting2.sh
│   ├── train_CHAOST2_setting2.sh
│   ├── test_CMR.sh
│   ├── test_SABS_setting1.sh
│   ├── test_SABS_setting2.sh
│   ├── test_CHAOST2_setting1.sh
│   └── test_CHAOST2_setting2.sh
├── config.py
├── losses.py
├── dataloaders
│   ├── dataset_specifics.py
│   ├── image_transforms.py
│   ├── datasets.py
│   └── datasets_outside.py
├── utils.py
├── models
│   ├── attention.py
│   ├── encoder.py
│   └── fewshot.py
├── README.md
├── train.py
└── test.py
/checkpoints/setting.txt: -------------------------------------------------------------------------------- 1 | Put the checkpoint of the backbone in this folder 2 | -------------------------------------------------------------------------------- /overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YazhouZhu19/RPT/HEAD/overview.png -------------------------------------------------------------------------------- /visualization/README.md: -------------------------------------------------------------------------------- 1 | # Introduction to visualizing the predictions 2 | 3 | ### We take the CMR dataset as an example 4 | #### Step 1: run `mask_generalization_cmr.py` 5 | #### Step 2: run `main.m` 6 | 7 | -------------------------------------------------------------------------------- /data/supervoxels/_ccomp.pxd: -------------------------------------------------------------------------------- 1 | """Export fast union find in Cython""" 2 | cimport numpy as cnp 3 | 4 | ctypedef cnp.intp_t DTYPE_t 5 | 6 | cdef DTYPE_t find_root(DTYPE_t *forest, DTYPE_t n) nogil 7 | cdef void set_root(DTYPE_t *forest, DTYPE_t n, DTYPE_t root) nogil 8 | cdef void join_trees(DTYPE_t *forest, DTYPE_t n, DTYPE_t m) nogil -------------------------------------------------------------------------------- /data/supervoxels/setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | from Cython.Build import cythonize 3 | import numpy 4 | from distutils.core import Extension 5 | 6 | 7 | extensions = [ 8 | Extension("felzenszwalb_3d_cy", ["felzenszwalb_3d_cy.pyx"], include_dirs=[numpy.get_include()]) 9 | ] 10 | setup( 11 | name='felzenszwalb_3d_cy', 12 | ext_modules=cythonize(extensions) 13 | ) 14 | 15 | extensions = [ 16 | Extension("_ccomp", ["_ccomp.pyx"], include_dirs=[numpy.get_include()]) 17 | ] 18 | setup( 19 | name='_ccomp', 20 | ext_modules=cythonize(extensions) 21 | ) -------------------------------------------------------------------------------- /data/CHAOST2/dcm_img_to_nii.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Convert dicom-like images to nii files in 3D 3 | # This is the first step for image pre-processing 4 | 5 | # Feed path to the downloaded data here 6 | DATAPATH=./MR # please put 
chaos dataset training fold here which contains ground truth 7 | 8 | # Feed path to the output folder here 9 | OUTPATH=./niis 10 | 11 | if [ ! -d $OUTPATH/T2SPIR ] 12 | then 13 | mkdir $OUTPATH/T2SPIR 14 | fi 15 | 16 | for sid in $(ls "$DATAPATH") 17 | do 18 | dcm2nii -o "$DATAPATH/$sid/T2SPIR" "$DATAPATH/$sid/T2SPIR/DICOM_anon"; 19 | find "$DATAPATH/$sid/T2SPIR" -name "*.nii.gz" -exec mv {} "$OUTPATH/T2SPIR/image_$sid.nii.gz" \; 20 | done; 21 | 22 | 23 | -------------------------------------------------------------------------------- /visualization/main.m: -------------------------------------------------------------------------------- 1 | % the pth of mask 2 | label = imread('.../cmr_lvbp_gt.png'); 3 | % the pth of raw image 4 | im = imread('.../cmr_lvbp_img.png'); 5 | 6 | color1 = [1,0,0; 0,1,0; 0,0,1; 1,1,0; 0,1,1]; 7 | alpha1 = 0.1; 8 | colorimg1 = drawlabel2image(im,label,color1,alpha1); 9 | 10 | color2 = [1,0,0; 0,1,0; 0,0,1; 1,1,0; 0,1,1]; 11 | alpha2 = 0.9; 12 | colorimg2 = drawlabel2image(im,label,color2,alpha2); 13 | 14 | color3 = [1,1,1; 1,1,1; 1,1,1; 1,1,1; 0,0,1]; 15 | alpha3 = 0.8; 16 | colorimg3 = drawlabel2image(im,label,color3,alpha3); 17 | 18 | % MATLAB自带的colormap 19 | color4 = jet; % matlab自带 20 | alpha4 = 0.5; 21 | colorimg4 = drawlabel2image(im,label,color4,alpha4); 22 | 23 | % the pth of masked image 24 | imwrite(colorimg3, '.../cmr_masked_gt.png'); 25 | 26 | % 显示 27 | figure 28 | subplot(2,2,1), imshow(colorimg1) 29 | subplot(2,2,2), imshow(colorimg2) 30 | subplot(2,2,3), imshow(colorimg3) 31 | subplot(2,2,4), imshow(colorimg4) 32 | 33 | 34 | % cmr 天蓝 深红 黄 35 | % Abd 红 蓝 绿 橘黄 36 | -------------------------------------------------------------------------------- /scripts/train_CMR.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | GPUID1=0 3 | export CUDA_VISIBLE_DEVICES=$GPUID1 4 | 5 | ###### Shared configs ###### 6 | DATASET='CMR' 7 | NWORKER=16 8 | RUNS=1 9 | ALL_EV=(0 1 2 3 4) # 5-fold cross validation (0, 1, 2, 3, 4) 10 | TEST_LABEL=[1,2,3] 11 | EXCLUDE_LABEL=None 12 | USE_GT=False 13 | ###### Training configs ###### 14 | NSTEP=40000 15 | DECAY=0.98 16 | 17 | MAX_ITER=5000 # defines the size of an epoch 18 | SNAPSHOT_INTERVAL=5000 # interval for saving snapshot 19 | SEED=2023 20 | 21 | echo ======================================================================== 22 | 23 | for EVAL_FOLD in "${ALL_EV[@]}" 24 | do 25 | PREFIX="train_${DATASET}_cv${EVAL_FOLD}" 26 | echo $PREFIX 27 | LOGDIR="./exps_on_${DATASET}_fewshot" 28 | 29 | if [ ! 
-d $LOGDIR ] 30 | then 31 | mkdir -p $LOGDIR 32 | fi 33 | 34 | python3 train.py with \ 35 | mode='train' \ 36 | dataset=$DATASET \ 37 | num_workers=$NWORKER \ 38 | n_steps=$NSTEP \ 39 | eval_fold=$EVAL_FOLD \ 40 | test_label=$TEST_LABEL \ 41 | exclude_label=$EXCLUDE_LABEL \ 42 | use_gt=$USE_GT \ 43 | max_iters_per_load=$MAX_ITER \ 44 | seed=$SEED \ 45 | save_snapshot_every=$SNAPSHOT_INTERVAL \ 46 | lr_step_gamma=$DECAY \ 47 | path.log_dir=$LOGDIR 48 | done 49 | -------------------------------------------------------------------------------- /scripts/train_SABS_setting1.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | GPUID1=0 3 | export CUDA_VISIBLE_DEVICES=$GPUID1 4 | 5 | ###### Shared configs ###### 6 | DATASET='SABS' 7 | NWORKER=16 8 | RUNS=1 9 | ALL_EV=(0 1 2 3 4) # 5-fold cross validation (0, 1, 2, 3, 4) 10 | TEST_LABEL=[1,2,3,6] 11 | EXCLUDE_LABEL=None 12 | USE_GT=False 13 | ###### Training configs ###### 14 | NSTEP=40000 15 | DECAY=0.98 16 | 17 | MAX_ITER=5000 # defines the size of an epoch 18 | SNAPSHOT_INTERVAL=5000 # interval for saving snapshot 19 | SEED=2023 20 | 21 | echo ======================================================================== 22 | 23 | for EVAL_FOLD in "${ALL_EV[@]}" 24 | do 25 | PREFIX="train_${DATASET}_cv${EVAL_FOLD}" 26 | echo $PREFIX 27 | LOGDIR="./exps_on_${DATASET}_fewshot_setting1" 28 | 29 | if [ ! -d $LOGDIR ] 30 | then 31 | mkdir -p $LOGDIR 32 | fi 33 | 34 | python3 train.py with \ 35 | mode='train' \ 36 | dataset=$DATASET \ 37 | num_workers=$NWORKER \ 38 | n_steps=$NSTEP \ 39 | eval_fold=$EVAL_FOLD \ 40 | test_label=$TEST_LABEL \ 41 | exclude_label=$EXCLUDE_LABEL \ 42 | use_gt=$USE_GT \ 43 | max_iters_per_load=$MAX_ITER \ 44 | seed=$SEED \ 45 | save_snapshot_every=$SNAPSHOT_INTERVAL \ 46 | lr_step_gamma=$DECAY \ 47 | path.log_dir=$LOGDIR 48 | done 49 | -------------------------------------------------------------------------------- /scripts/train_CHAOST2_setting1.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | GPUID1=0 3 | export CUDA_VISIBLE_DEVICES=$GPUID1 4 | 5 | ###### Shared configs ###### 6 | DATASET='CHAOST2' 7 | NWORKER=16 8 | RUNS=1 9 | ALL_EV=(0 1 2 3 4) # 5-fold cross validation (0, 1, 2, 3, 4) 10 | TEST_LABEL=[1,2,3,4] 11 | EXCLUDE_LABEL=None 12 | USE_GT=False 13 | ###### Training configs ###### 14 | NSTEP=40000 15 | DECAY=0.98 16 | 17 | MAX_ITER=5000 # defines the size of an epoch 18 | SNAPSHOT_INTERVAL=5000 # interval for saving snapshot 19 | SEED=2023 20 | 21 | echo ======================================================================== 22 | 23 | for EVAL_FOLD in "${ALL_EV[@]}" 24 | do 25 | PREFIX="train_${DATASET}_cv${EVAL_FOLD}" 26 | echo $PREFIX 27 | LOGDIR="./exps_on_${DATASET}_fewshot_setting1" 28 | 29 | if [ ! 
-d $LOGDIR ] 30 | then 31 | mkdir -p $LOGDIR 32 | fi 33 | 34 | python3 train.py with \ 35 | mode='train' \ 36 | dataset=$DATASET \ 37 | num_workers=$NWORKER \ 38 | n_steps=$NSTEP \ 39 | eval_fold=$EVAL_FOLD \ 40 | test_label=$TEST_LABEL \ 41 | exclude_label=$EXCLUDE_LABEL \ 42 | use_gt=$USE_GT \ 43 | max_iters_per_load=$MAX_ITER \ 44 | seed=$SEED \ 45 | save_snapshot_every=$SNAPSHOT_INTERVAL \ 46 | lr_step_gamma=$DECAY \ 47 | path.log_dir=$LOGDIR 48 | done 49 | -------------------------------------------------------------------------------- /scripts/train_SABS_setting2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | GPUID1=0 3 | export CUDA_VISIBLE_DEVICES=$GPUID1 4 | 5 | ###### Shared configs ###### 6 | DATASET='SABS' 7 | NWORKER=16 8 | RUNS=1 9 | ALL_EV=(0 1 2 3 4) # 5-fold cross validation (0, 1, 2, 3, 4) 10 | TEST_LABEL=[1,2,3,6] 11 | EXCLUDE_LABEL=[1,2,3,6] 12 | USE_GT=False 13 | ###### Training configs ###### 14 | NSTEP=40000 15 | DECAY=0.98 16 | 17 | MAX_ITER=5000 # defines the size of an epoch 18 | SNAPSHOT_INTERVAL=5000 # interval for saving snapshot 19 | SEED=2023 20 | 21 | echo ======================================================================== 22 | 23 | for EVAL_FOLD in "${ALL_EV[@]}" 24 | do 25 | PREFIX="train_${DATASET}_cv${EVAL_FOLD}" 26 | echo $PREFIX 27 | LOGDIR="./exps_on_${DATASET}_fewshot_setting2" 28 | 29 | if [ ! -d $LOGDIR ] 30 | then 31 | mkdir -p $LOGDIR 32 | fi 33 | 34 | python3 train.py with \ 35 | mode='train' \ 36 | dataset=$DATASET \ 37 | num_workers=$NWORKER \ 38 | n_steps=$NSTEP \ 39 | eval_fold=$EVAL_FOLD \ 40 | test_label=$TEST_LABEL \ 41 | exclude_label=$EXCLUDE_LABEL \ 42 | use_gt=$USE_GT \ 43 | max_iters_per_load=$MAX_ITER \ 44 | seed=$SEED \ 45 | save_snapshot_every=$SNAPSHOT_INTERVAL \ 46 | lr_step_gamma=$DECAY \ 47 | path.log_dir=$LOGDIR 48 | done 49 | -------------------------------------------------------------------------------- /scripts/train_CHAOST2_setting2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | GPUID1=0 3 | export CUDA_VISIBLE_DEVICES=$GPUID1 4 | 5 | ###### Shared configs ###### 6 | DATASET='CHAOST2' 7 | NWORKER=16 8 | RUNS=1 9 | ALL_EV=(0 1 2 3 4) # 5-fold cross validation (0, 1, 2, 3, 4) 10 | TEST_LABEL=[1,2,3,4] 11 | EXCLUDE_LABEL=[1,2,3,4] 12 | USE_GT=False 13 | ###### Training configs ###### 14 | NSTEP=40000 15 | DECAY=0.98 16 | 17 | MAX_ITER=5000 # defines the size of an epoch 18 | SNAPSHOT_INTERVAL=5000 # interval for saving snapshot 19 | SEED=2023 20 | 21 | echo ======================================================================== 22 | 23 | for EVAL_FOLD in "${ALL_EV[@]}" 24 | do 25 | PREFIX="train_${DATASET}_cv${EVAL_FOLD}" 26 | echo $PREFIX 27 | LOGDIR="./exps_on_${DATASET}_fewshot_setting2" 28 | 29 | if [ ! 
-d $LOGDIR ] 30 | then 31 | mkdir -p $LOGDIR 32 | fi 33 | 34 | python3 train.py with \ 35 | mode='train' \ 36 | dataset=$DATASET \ 37 | num_workers=$NWORKER \ 38 | n_steps=$NSTEP \ 39 | eval_fold=$EVAL_FOLD \ 40 | test_label=$TEST_LABEL \ 41 | exclude_label=$EXCLUDE_LABEL \ 42 | use_gt=$USE_GT \ 43 | max_iters_per_load=$MAX_ITER \ 44 | seed=$SEED \ 45 | save_snapshot_every=$SNAPSHOT_INTERVAL \ 46 | lr_step_gamma=$DECAY \ 47 | path.log_dir=$LOGDIR 48 | done 49 | -------------------------------------------------------------------------------- /visualization/drawlabel2image.m: -------------------------------------------------------------------------------- 1 | % gt = imread('abd_ct_gt.png') 2 | % im = imread('abd_ct_img.png') 3 | % figure 4 | % subplot(1,3,1), imshow(gt) 5 | % subplot(1,3,2), imshow(im) 6 | 7 | % img 原始图像 8 | % label 目标标签 9 | % color 每个类的颜色 10 | % alpha 每个类颜色的不透明度 11 | function colorimg = drawlabel2image(img, label, color, alpha) 12 | [row, col, dim] = size(img); 13 | if dim==1 14 | img = cat(3, img, img, img); 15 | elseif dim==3 16 | else 17 | error('请输入灰度图或RGB图') 18 | end 19 | 20 | % 预处理 21 | img = im2double(img); 22 | label = uint16(label); 23 | nlabel = max(label(:)); 24 | 25 | % color修正 26 | while size(color,1)2*nlabel 31 | gap = floor(size(color,1)/nlabel); 32 | color = color(1:gap:end,:); 33 | end 34 | 35 | alpha = alpha(:); 36 | while length(alpha)0); 42 | bg = img.*double(~mask); 43 | 44 | 45 | obj = zeros(row,col,3,nlabel); 46 | for idx = 1:nlabel 47 | objmask = double(label==idx); 48 | R = img(:,:,1).*objmask*(1-alpha(idx))+objmask*color(idx,1)*alpha(idx); 49 | G = img(:,:,2).*objmask*(1-alpha(idx))+objmask*color(idx,2)*alpha(idx); 50 | B = img(:,:,3).*objmask*(1-alpha(idx))+objmask*color(idx,3)*alpha(idx); 51 | obj(:,:,:,idx) = cat(3,R,G,B); 52 | end 53 | 54 | colorimg = sum(obj,4)+bg; 55 | 56 | end 57 | 58 | 59 | -------------------------------------------------------------------------------- /data/SABS/niftiio.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utils for datasets 3 | """ 4 | import numpy as np 5 | 6 | import numpy as np 7 | import SimpleITK as sitk 8 | 9 | 10 | def read_nii_bysitk(input_fid, peel_info = False): 11 | """ read nii to numpy through simpleitk 12 | peelinfo: taking direction, origin, spacing and metadata out 13 | """ 14 | img_obj = sitk.ReadImage(input_fid) 15 | img_np = sitk.GetArrayFromImage(img_obj) 16 | if peel_info: 17 | info_obj = { 18 | "spacing": img_obj.GetSpacing(), 19 | "origin": img_obj.GetOrigin(), 20 | "direction": img_obj.GetDirection(), 21 | "array_size": img_np.shape 22 | } 23 | return img_np, info_obj 24 | else: 25 | return img_np 26 | 27 | def convert_to_sitk(input_mat, peeled_info): 28 | """ 29 | write a numpy array to sitk image object with essential meta-data 30 | """ 31 | nii_obj = sitk.GetImageFromArray(input_mat) 32 | if peeled_info: 33 | nii_obj.SetSpacing( peeled_info["spacing"] ) 34 | nii_obj.SetOrigin( peeled_info["origin"] ) 35 | nii_obj.SetDirection(peeled_info["direction"] ) 36 | return nii_obj 37 | 38 | def np2itk(img, ref_obj): 39 | """ 40 | img: numpy array 41 | ref_obj: reference sitk object for copying information from 42 | """ 43 | itk_obj = sitk.GetImageFromArray(img) 44 | itk_obj.SetSpacing( ref_obj.GetSpacing() ) 45 | itk_obj.SetOrigin( ref_obj.GetOrigin() ) 46 | itk_obj.SetDirection( ref_obj.GetDirection() ) 47 | return itk_obj 48 | 49 | -------------------------------------------------------------------------------- /scripts/test_CMR.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | GPUID1=0 3 | export CUDA_VISIBLE_DEVICES=$GPUID1 4 | 5 | ###### Shared configs ###### 6 | DATASET='CMR' 7 | NWORKER=16 8 | RUNS=1 9 | ALL_EV=(0 1 2 3 4) # 5-fold cross validation (0, 1, 2, 3, 4) 10 | TEST_LABEL=[1,2,3] 11 | ###### Training configs ###### 12 | NSTEP=45000 13 | DECAY=0.98 14 | 15 | MAX_ITER=1000 # defines the size of an epoch 16 | SNAPSHOT_INTERVAL=1000 # interval for saving snapshot 17 | SEED=2021 18 | 19 | N_PART=3 # defines the number of chunks for evaluation 20 | ALL_SUPP=(2) # CHAOST2: 0-4, CMR: 0-7 21 | echo ======================================================================== 22 | 23 | for EVAL_FOLD in "${ALL_EV[@]}" 24 | do 25 | PREFIX="test_${DATASET}_cv${EVAL_FOLD}" 26 | echo $PREFIX 27 | LOGDIR="./results" 28 | 29 | if [ ! -d $LOGDIR ] 30 | then 31 | mkdir -p $LOGDIR 32 | fi 33 | for SUPP_IDX in "${ALL_SUPP[@]}" 34 | do 35 | # RELOAD_PATH='please feed the absolute path to the trained weights here' # path to the reloaded model 36 | RELOAD_MODEL_PATH=".../exps_on_CMR_fewshot/RPT_train_CMR_cv${EVAL_FOLD}/1/snapshots/30000.pth" 37 | python3 test.py with \ 38 | mode="test" \ 39 | dataset=$DATASET \ 40 | num_workers=$NWORKER \ 41 | n_steps=$NSTEP \ 42 | eval_fold=$EVAL_FOLD \ 43 | max_iters_per_load=$MAX_ITER \ 44 | supp_idx=$SUPP_IDX \ 45 | test_label=$TEST_LABEL \ 46 | seed=$SEED \ 47 | n_part=$N_PART \ 48 | reload_model_path=$RELOAD_MODEL_PATH \ 49 | save_snapshot_every=$SNAPSHOT_INTERVAL \ 50 | lr_step_gamma=$DECAY \ 51 | path.log_dir=$LOGDIR 52 | done 53 | done 54 | 55 | -------------------------------------------------------------------------------- /scripts/test_SABS_setting1.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | GPUID1=0 3 | export CUDA_VISIBLE_DEVICES=$GPUID1 4 | 5 | ###### Shared configs ###### 6 | DATASET='SABS' 7 | NWORKER=16 8 | RUNS=1 9 | ALL_EV=(0 1 2 3 4) # 5-fold cross validation (0, 1, 2, 3, 4) 10 | TEST_LABEL=[1,2,3,6] 11 | ###### Training configs ###### 12 | NSTEP=45000 13 | DECAY=0.98 14 | 15 | MAX_ITER=1000 # defines the size of an epoch 16 | SNAPSHOT_INTERVAL=1000 # interval for saving snapshot 17 | SEED=2021 18 | 19 | N_PART=3 # defines the number of chunks for evaluation 20 | ALL_SUPP=(2) # CHAOST2: 0-4, CMR: 0-7 21 | echo ======================================================================== 22 | 23 | for EVAL_FOLD in "${ALL_EV[@]}" 24 | do 25 | PREFIX="test_${DATASET}_cv${EVAL_FOLD}" 26 | echo $PREFIX 27 | LOGDIR="./results" 28 | 29 | if [ ! 
-d $LOGDIR ] 30 | then 31 | mkdir -p $LOGDIR 32 | fi 33 | for SUPP_IDX in "${ALL_SUPP[@]}" 34 | do 35 | # RELOAD_PATH='please feed the absolute path to the trained weights here' # path to the reloaded model 36 | RELOAD_MODEL_PATH=".../exps_on_SABS_fewshot_setting1/RPT_train_SABS_cv${EVAL_FOLD}/1/snapshots/30000.pth" 37 | python3 test.py with \ 38 | mode="test" \ 39 | dataset=$DATASET \ 40 | num_workers=$NWORKER \ 41 | n_steps=$NSTEP \ 42 | eval_fold=$EVAL_FOLD \ 43 | max_iters_per_load=$MAX_ITER \ 44 | supp_idx=$SUPP_IDX \ 45 | test_label=$TEST_LABEL \ 46 | seed=$SEED \ 47 | n_part=$N_PART \ 48 | reload_model_path=$RELOAD_MODEL_PATH \ 49 | save_snapshot_every=$SNAPSHOT_INTERVAL \ 50 | lr_step_gamma=$DECAY \ 51 | path.log_dir=$LOGDIR 52 | done 53 | done 54 | 55 | -------------------------------------------------------------------------------- /scripts/test_SABS_setting2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | GPUID1=0 3 | export CUDA_VISIBLE_DEVICES=$GPUID1 4 | 5 | ###### Shared configs ###### 6 | DATASET='SABS' 7 | NWORKER=16 8 | RUNS=1 9 | ALL_EV=(0 1 2 3 4) # 5-fold cross validation (0, 1, 2, 3, 4) 10 | TEST_LABEL=[1,2,3,6] 11 | ###### Training configs ###### 12 | NSTEP=45000 13 | DECAY=0.98 14 | 15 | MAX_ITER=1000 # defines the size of an epoch 16 | SNAPSHOT_INTERVAL=1000 # interval for saving snapshot 17 | SEED=2021 18 | 19 | N_PART=3 # defines the number of chunks for evaluation 20 | ALL_SUPP=(2) # CHAOST2: 0-4, CMR: 0-7 21 | echo ======================================================================== 22 | 23 | for EVAL_FOLD in "${ALL_EV[@]}" 24 | do 25 | PREFIX="test_${DATASET}_cv${EVAL_FOLD}" 26 | echo $PREFIX 27 | LOGDIR="./results" 28 | 29 | if [ ! -d $LOGDIR ] 30 | then 31 | mkdir -p $LOGDIR 32 | fi 33 | for SUPP_IDX in "${ALL_SUPP[@]}" 34 | do 35 | # RELOAD_PATH='please feed the absolute path to the trained weights here' # path to the reloaded model 36 | RELOAD_MODEL_PATH=".../exps_on_SABS_fewshot_setting2/RPT_train_SABS_cv${EVAL_FOLD}/1/snapshots/30000.pth" 37 | python3 test.py with \ 38 | mode="test" \ 39 | dataset=$DATASET \ 40 | num_workers=$NWORKER \ 41 | n_steps=$NSTEP \ 42 | eval_fold=$EVAL_FOLD \ 43 | max_iters_per_load=$MAX_ITER \ 44 | supp_idx=$SUPP_IDX \ 45 | test_label=$TEST_LABEL \ 46 | seed=$SEED \ 47 | n_part=$N_PART \ 48 | reload_model_path=$RELOAD_MODEL_PATH \ 49 | save_snapshot_every=$SNAPSHOT_INTERVAL \ 50 | lr_step_gamma=$DECAY \ 51 | path.log_dir=$LOGDIR 52 | done 53 | done 54 | 55 | -------------------------------------------------------------------------------- /scripts/test_CHAOST2_setting1.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | GPUID1=0 3 | export CUDA_VISIBLE_DEVICES=$GPUID1 4 | 5 | ###### Shared configs ###### 6 | DATASET='CHAOST2' 7 | NWORKER=16 8 | RUNS=1 9 | ALL_EV=(0 1 2 3 4) # 5-fold cross validation (0, 1, 2, 3, 4) 10 | TEST_LABEL=[1,2,3,4] 11 | ###### Training configs ###### 12 | NSTEP=45000 13 | DECAY=0.98 14 | 15 | MAX_ITER=1000 # defines the size of an epoch 16 | SNAPSHOT_INTERVAL=1000 # interval for saving snapshot 17 | SEED=2021 18 | 19 | N_PART=3 # defines the number of chunks for evaluation 20 | ALL_SUPP=(2) # CHAOST2: 0-4, CMR: 0-7 21 | echo ======================================================================== 22 | 23 | for EVAL_FOLD in "${ALL_EV[@]}" 24 | do 25 | PREFIX="test_${DATASET}_cv${EVAL_FOLD}" 26 | echo $PREFIX 27 | LOGDIR="./results" 28 | 29 | if [ ! 
-d $LOGDIR ] 30 | then 31 | mkdir -p $LOGDIR 32 | fi 33 | for SUPP_IDX in "${ALL_SUPP[@]}" 34 | do 35 | # RELOAD_PATH='please feed the absolute path to the trained weights here' # path to the reloaded model 36 | RELOAD_MODEL_PATH=".../exps_on_CHAOST2_fewshot_setting1/RPT_train_CHAOST2_cv${EVAL_FOLD}/1/snapshots/30000.pth" 37 | python3 test.py with \ 38 | mode="test" \ 39 | dataset=$DATASET \ 40 | num_workers=$NWORKER \ 41 | n_steps=$NSTEP \ 42 | eval_fold=$EVAL_FOLD \ 43 | max_iters_per_load=$MAX_ITER \ 44 | supp_idx=$SUPP_IDX \ 45 | test_label=$TEST_LABEL \ 46 | seed=$SEED \ 47 | n_part=$N_PART \ 48 | reload_model_path=$RELOAD_MODEL_PATH \ 49 | save_snapshot_every=$SNAPSHOT_INTERVAL \ 50 | lr_step_gamma=$DECAY \ 51 | path.log_dir=$LOGDIR 52 | done 53 | done 54 | 55 | -------------------------------------------------------------------------------- /scripts/test_CHAOST2_setting2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | GPUID1=0 3 | export CUDA_VISIBLE_DEVICES=$GPUID1 4 | 5 | ###### Shared configs ###### 6 | DATASET='CHAOST2' 7 | NWORKER=16 8 | RUNS=1 9 | ALL_EV=(0 1 2 3 4) # 5-fold cross validation (0, 1, 2, 3, 4) 10 | TEST_LABEL=[1,2,3,4] 11 | ###### Training configs ###### 12 | NSTEP=45000 13 | DECAY=0.98 14 | 15 | MAX_ITER=1000 # defines the size of an epoch 16 | SNAPSHOT_INTERVAL=1000 # interval for saving snapshot 17 | SEED=2021 18 | 19 | N_PART=3 # defines the number of chunks for evaluation 20 | ALL_SUPP=(2) # CHAOST2: 0-4, CMR: 0-7 21 | echo ======================================================================== 22 | 23 | for EVAL_FOLD in "${ALL_EV[@]}" 24 | do 25 | PREFIX="test_${DATASET}_cv${EVAL_FOLD}" 26 | echo $PREFIX 27 | LOGDIR="./results" 28 | 29 | if [ ! -d $LOGDIR ] 30 | then 31 | mkdir -p $LOGDIR 32 | fi 33 | for SUPP_IDX in "${ALL_SUPP[@]}" 34 | do 35 | # RELOAD_PATH='please feed the absolute path to the trained weights here' # path to the reloaded model 36 | RELOAD_MODEL_PATH=".../exps_on_CHAOST2_fewshot_setting2/RPT_train_CHAOST2_cv${EVAL_FOLD}/1/snapshots/30000.pth" 37 | python3 test.py with \ 38 | mode="test" \ 39 | dataset=$DATASET \ 40 | num_workers=$NWORKER \ 41 | n_steps=$NSTEP \ 42 | eval_fold=$EVAL_FOLD \ 43 | max_iters_per_load=$MAX_ITER \ 44 | supp_idx=$SUPP_IDX \ 45 | test_label=$TEST_LABEL \ 46 | seed=$SEED \ 47 | n_part=$N_PART \ 48 | reload_model_path=$RELOAD_MODEL_PATH \ 49 | save_snapshot_every=$SNAPSHOT_INTERVAL \ 50 | lr_step_gamma=$DECAY \ 51 | path.log_dir=$LOGDIR 52 | done 53 | done 54 | 55 | -------------------------------------------------------------------------------- /data/SABS/class_slice_index_gen.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import glob 4 | import SimpleITK as sitk 5 | import sys 6 | import json 7 | import niftiio as nio 8 | 9 | IMG_BNAME="./sabs_CT_normalized/image_*.nii.gz" 10 | SEG_BNAME="./sabs_CT_normalized/label_*.nii.gz" 11 | 12 | imgs = glob.glob(IMG_BNAME) 13 | segs = glob.glob(SEG_BNAME) 14 | imgs = [ fid for fid in sorted(imgs, key = lambda x: int(x.split("_")[-1].split(".nii.gz")[0]) ) ] 15 | segs = [ fid for fid in sorted(segs, key = lambda x: int(x.split("_")[-1].split(".nii.gz")[0]) ) ] 16 | 17 | classmap = {} 18 | LABEL_NAME = ["BG", "LIVER", "RK", "LK", "SPLEEN"] 19 | 20 | 21 | MIN_TP = 1 # minimum number of positive label pixels to be recorded. 
Use >100 when training with manual annotations for more stable training 22 | 23 | fid = f'./sabs_CT_normalized/classmap_{MIN_TP}.json' # name of the output file. 24 | for _lb in LABEL_NAME: 25 | classmap[_lb] = {} 26 | for _sid in segs: 27 | pid = _sid.split("_")[-1].split(".nii.gz")[0] 28 | classmap[_lb][pid] = [] 29 | 30 | for seg in segs: 31 | pid = seg.split("_")[-1].split(".nii.gz")[0] 32 | lb_vol = nio.read_nii_bysitk(seg) 33 | n_slice = lb_vol.shape[0] 34 | for slc in range(n_slice): 35 | for cls in range(len(LABEL_NAME)): 36 | if cls in lb_vol[slc, ...]: 37 | if np.sum( lb_vol[slc, ...]) >= MIN_TP: 38 | classmap[LABEL_NAME[cls]][str(pid)].append(slc) 39 | print(f'pid {str(pid)} finished!') 40 | 41 | with open(fid, 'w') as fopen: 42 | json.dump(classmap, fopen) 43 | fopen.close() 44 | 45 | 46 | with open(fid, 'w') as fopen: 47 | json.dump(classmap, fopen) 48 | fopen.close() 49 | 50 | 51 | 52 | -------------------------------------------------------------------------------- /data/SABS/intensity_normalization.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import glob 4 | import SimpleITK as sitk 5 | 6 | import sys 7 | import niftiio as nio 8 | 9 | 10 | IMG_FOLDER="./data/SABS/img/" 11 | SEG_FOLDER="./data/SABS/label/" 12 | OUT_FOLDER="./tmp_normalized/" 13 | 14 | imgs = glob.glob(IMG_FOLDER + "/*.nii.gz") 15 | imgs = [ fid for fid in sorted(imgs) ] 16 | segs = [ fid for fid in sorted(glob.glob(SEG_FOLDER + "/*.nii.gz")) ] 17 | 18 | pids = [ pid.split("img0")[-1].split(".")[0] for pid in imgs] 19 | 20 | 21 | # helper function 22 | def copy_spacing_ori(src, dst): 23 | dst.SetSpacing(src.GetSpacing()) 24 | dst.SetOrigin(src.GetOrigin()) 25 | dst.SetDirection(src.GetDirection()) 26 | return dst 27 | 28 | import copy 29 | scan_dir = OUT_FOLDER 30 | LIR = -125 31 | HIR = 275 32 | os.makedirs(scan_dir, exist_ok = True) 33 | 34 | reindex = 0 35 | for img_fid, seg_fid, pid in zip(imgs, segs, pids): 36 | 37 | img_obj = sitk.ReadImage( img_fid ) 38 | seg_obj = sitk.ReadImage( seg_fid ) 39 | 40 | array = sitk.GetArrayFromImage(img_obj) 41 | 42 | array[array > HIR] = HIR 43 | array[array < LIR] = LIR 44 | 45 | array = (array - array.min()) / (array.max() - array.min()) * 255.0 46 | 47 | # then normalize this 48 | 49 | wined_img = sitk.GetImageFromArray(array) 50 | wined_img = copy_spacing_ori(img_obj, wined_img) 51 | 52 | out_img_fid = os.path.join( scan_dir, f'image_{str(reindex)}.nii.gz' ) 53 | out_lb_fid = os.path.join( scan_dir, f'label_{str(reindex)}.nii.gz' ) 54 | 55 | # then save 56 | sitk.WriteImage(wined_img, out_img_fid, True) 57 | sitk.WriteImage(seg_obj, out_lb_fid, True) 58 | print("{} has been save".format(out_img_fid)) 59 | print("{} has been save".format(out_lb_fid)) 60 | reindex += 1 61 | 62 | 63 | 64 | 65 | -------------------------------------------------------------------------------- /data/supervoxels/felzenszwalb_3d.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from felzenszwalb_3d_cy import felzenszwalb_cython_3d 4 | 5 | 6 | def felzenszwalb_3d(image, scale=1, sigma=0.8, min_size=20, multichannel=True, spacing=(1,1,1)): 7 | """ 8 | Code modified from: https://scikit-image.org/docs/dev/api/skimage.segmentation.html#skimage.segmentation.felzenszwalb 9 | 10 | 11 | Computes Felsenszwalb's efficient graph based image segmentation. 12 | 13 | Produces an oversegmentation of a multichannel (i.e. 
RGB) image 14 | using a fast, minimum spanning tree based clustering on the image grid. 15 | The parameter ``scale`` sets an observation level. Higher scale means 16 | less and larger segments. ``sigma`` is the diameter of a Gaussian kernel, 17 | used for smoothing the image prior to segmentation. 18 | 19 | The number of produced segments as well as their size can only be 20 | controlled indirectly through ``scale``. Segment size within an image can 21 | vary greatly depending on local contrast. 22 | 23 | For RGB images, the algorithm uses the euclidean distance between pixels in 24 | color space. 25 | 26 | Parameters 27 | ---------- 28 | image : (width, height, 3) or (width, height) ndarray 29 | Input image. 30 | scale : float 31 | Free parameter. Higher means larger clusters. 32 | sigma : float 33 | Width (standard deviation) of Gaussian kernel used in preprocessing. 34 | min_size : int 35 | Minimum component size. Enforced using postprocessing. 36 | multichannel : bool, optional (default: True) 37 | Whether the last axis of the image is to be interpreted as multiple 38 | channels. A value of False, for a 3D image, is not currently supported. 39 | 40 | Returns 41 | ------- 42 | segment_mask : (width, height) ndarray 43 | Integer mask indicating segment labels. 44 | 45 | References 46 | ---------- 47 | .. [1] Efficient graph-based image segmentation, Felzenszwalb, P.F. and 48 | Huttenlocher, D.P. International Journal of Computer Vision, 2004 49 | 50 | Notes 51 | ----- 52 | The `k` parameter used in the original paper renamed to `scale` here. 53 | 54 | Examples 55 | -------- 56 | >>> from skimage.segmentation import felzenszwalb 57 | >>> from skimage.data import coffee 58 | >>> img = coffee() 59 | >>> segments = felzenszwalb(img, scale=3.0, sigma=0.95, min_size=5) 60 | """ 61 | 62 | # if not multichannel and image.ndim > 2: 63 | # raise ValueError("This algorithm works only on single or " 64 | # "multi-channel 2d images. ") 65 | 66 | image = np.atleast_3d(image) 67 | return felzenszwalb_cython_3d(image, scale=scale, sigma=sigma, min_size=min_size, spacing=spacing) 68 | 69 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | """ 2 | Experiment configuration file 3 | """ 4 | import glob 5 | import itertools 6 | import os 7 | import sacred 8 | from sacred import Experiment 9 | from sacred.observers import FileStorageObserver 10 | from sacred.utils import apply_backspaces_and_linefeeds 11 | from utils import * 12 | 13 | sacred.SETTINGS['CONFIG']['READ_ONLY_CONFIG'] = False 14 | sacred.SETTINGS.CAPTURE_MODE = 'no' 15 | 16 | ex = Experiment("RPT") 17 | ex.captured_out_filter = apply_backspaces_and_linefeeds 18 | 19 | ###### Set up source folder ###### 20 | source_folders = ['.', './dataloaders', './models', './utils'] 21 | sources_to_save = list(itertools.chain.from_iterable( 22 | [glob.glob(f'{folder}/*.py') for folder in source_folders])) 23 | for source_file in sources_to_save: 24 | ex.add_source_file(source_file) 25 | 26 | 27 | @ex.config 28 | def cfg(): 29 | """Default configurations""" 30 | seed = 2021 31 | gpu_id = 0 32 | num_workers = 0 # 0 for debugging. 33 | mode = 'train' 34 | 35 | ## dataset 36 | dataset = 'CHAOST2' # i.e. 
abdominal MRI - 'CHAOST2'; cardiac MRI - CMR 37 | exclude_label = [1,2,3,4] # None, for not excluding test labels; Setting 1: None, Setting 2: True 38 | # 1 for Liver, 2 for RK, 3 for LK, 4 for Spleen in 'CHAOST2' 39 | if dataset == 'CMR': 40 | n_sv = 1000 41 | else: 42 | n_sv = 5000 43 | min_size = 200 44 | max_slices = 3 45 | use_gt = False # True - use ground truth as training label, False - use supervoxel as training label 46 | eval_fold = 0 # (0-4) for 5-fold cross-validation 47 | test_label = [1, 4] # for evaluation 48 | supp_idx = 0 # choose which case as the support set for evaluation, (0-4) for 'CHAOST2', (0-7) for 'CMR' 49 | n_part = 3 # for evaluation, i.e. 3 chunks 50 | 51 | ## training 52 | n_steps = 1000 53 | batch_size = 1 54 | n_shot = 1 55 | n_way = 1 56 | n_query = 1 57 | lr_step_gamma = 0.95 58 | bg_wt = 0.1 59 | t_loss_scaler = 0.0 60 | ignore_label = 255 61 | print_interval = 100 62 | save_snapshot_every = 1000 63 | max_iters_per_load = 1000 # epoch size, interval for reloading the dataset 64 | 65 | # Network 66 | reload_model_path = None 67 | 68 | optim_type = 'sgd' 69 | optim = { 70 | 'lr': 1e-3, 71 | 'momentum': 0.9, 72 | 'weight_decay': 0.0005, 73 | } 74 | 75 | exp_str = '_'.join( 76 | [mode] 77 | + [dataset, ] 78 | + [f'cv{eval_fold}']) 79 | 80 | path = { 81 | 'log_dir': './runs', 82 | 'CHAOST2': {'data_dir': './data/CHAOST2'}, 83 | 'SABS': {'data_dir': './data/SABS'}, 84 | 'CMR': {'data_dir': './data/CMR'}, 85 | } 86 | 87 | 88 | @ex.config_hook 89 | def add_observer(config, command_name, logger): 90 | """A hook fucntion to add observer""" 91 | exp_name = f'{ex.path}_{config["exp_str"]}' 92 | observer = FileStorageObserver.create(os.path.join(config['path']['log_dir'], exp_name)) 93 | ex.observers.append(observer) 94 | return config 95 | -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | # boundary loss code from https://github.com/yiskw713/boundary_loss_for_remote_sensing 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | def one_hot(label, n_classes, requires_grad=True): 9 | """Return One Hot Label""" 10 | device = label.device 11 | one_hot_label = torch.eye( 12 | n_classes, device=device, requires_grad=requires_grad)[label] 13 | one_hot_label = one_hot_label.transpose(1, 3).transpose(2, 3) 14 | 15 | return one_hot_label 16 | 17 | 18 | class BoundaryLoss(nn.Module): 19 | """Boundary Loss proposed in: 20 | Alexey Bokhovkin et al., Boundary Loss for Remote Sensing Imagery Semantic Segmentation 21 | https://arxiv.org/abs/1905.07852 22 | """ 23 | 24 | def __init__(self, theta0=3, theta=5): 25 | super().__init__() 26 | 27 | self.theta0 = theta0 28 | self.theta = theta 29 | 30 | def forward(self, pred, gt): 31 | """ 32 | Input: 33 | - pred: the output from model (before softmax) 34 | shape (N, C, H, W) 35 | - gt: ground truth map 36 | shape (N, H, w) 37 | Return: 38 | - boundary loss, averaged over mini-bathc 39 | """ 40 | 41 | n, c, _, _ = pred.shape 42 | 43 | # softmax so that predicted map can be distributed in [0, 1] 44 | pred = torch.softmax(pred, dim=1) 45 | 46 | # one-hot vector of ground truth 47 | one_hot_gt = one_hot(gt, c) 48 | 49 | # boundary map 50 | gt_b = F.max_pool2d( 51 | 1 - one_hot_gt, kernel_size=self.theta0, stride=1, padding=(self.theta0 - 1) // 2) 52 | gt_b -= 1 - one_hot_gt 53 | 54 | pred_b = F.max_pool2d( 55 | 1 - pred, kernel_size=self.theta0, stride=1, padding=(self.theta0 
- 1) // 2) 56 | pred_b -= 1 - pred 57 | 58 | # extended boundary map 59 | gt_b_ext = F.max_pool2d( 60 | gt_b, kernel_size=self.theta, stride=1, padding=(self.theta - 1) // 2) 61 | 62 | pred_b_ext = F.max_pool2d( 63 | pred_b, kernel_size=self.theta, stride=1, padding=(self.theta - 1) // 2) 64 | 65 | # reshape 66 | gt_b = gt_b.view(n, c, -1) 67 | pred_b = pred_b.view(n, c, -1) 68 | gt_b_ext = gt_b_ext.view(n, c, -1) 69 | pred_b_ext = pred_b_ext.view(n, c, -1) 70 | 71 | # Precision, Recall 72 | P = torch.sum(pred_b * gt_b_ext, dim=2) / (torch.sum(pred_b, dim=2) + 1e-7) 73 | R = torch.sum(pred_b_ext * gt_b, dim=2) / (torch.sum(gt_b, dim=2) + 1e-7) 74 | 75 | # Boundary F1 Score 76 | BF1 = 2 * P * R / (P + R + 1e-7) 77 | 78 | # summing BF1 Score for each class and average over mini-batch 79 | loss = torch.mean(1 - BF1) 80 | 81 | return loss 82 | 83 | 84 | class DiceLoss(nn.Module): 85 | def __init__(self): 86 | super(DiceLoss, self).__init__() 87 | 88 | def forward(self, y_pred, y_true): 89 | smooth = 1. 90 | y_pred = torch.sigmoid(y_pred)[:, 1, :, :] 91 | y_true = (y_true > 0.5).float() 92 | 93 | intersection = torch.sum(y_pred * y_true) 94 | union = torch.sum(y_pred) + torch.sum(y_true) 95 | dice = (2.0 * intersection + smooth) / (union + smooth) 96 | loss = 1 - dice 97 | return loss 98 | 99 | -------------------------------------------------------------------------------- /dataloaders/dataset_specifics.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dataset Specifics 3 | Extended from ADNet code by Hansen et al. 4 | """ 5 | 6 | import torch 7 | import random 8 | 9 | 10 | def get_label_names(dataset): 11 | label_names = {} 12 | if dataset == 'CMR': 13 | label_names[0] = 'BG' 14 | label_names[1] = 'LV-MYO' 15 | label_names[2] = 'LV-BP' 16 | label_names[3] = 'RV' 17 | 18 | elif dataset == 'CHAOST2': 19 | label_names[0] = 'BG' 20 | label_names[1] = 'LIVER' 21 | label_names[2] = 'RK' 22 | label_names[3] = 'LK' 23 | label_names[4] = 'SPLEEN' 24 | elif dataset == 'SABS': 25 | label_names[0] = 'BG' 26 | label_names[1] = 'SPLEEN' 27 | label_names[2] = 'RK' 28 | label_names[3] = 'LK' 29 | label_names[4] = 'GALLBLADDER' 30 | label_names[5] = 'ESOPHAGUS' 31 | label_names[6] = 'LIVER' 32 | label_names[7] = 'STOMACH' 33 | label_names[8] = 'AORTA' 34 | label_names[9] = 'IVC' # Inferior vena cava 35 | label_names[10] = 'PS_VEIN' # portal vein and splenic vein 36 | label_names[11] = 'PANCREAS' 37 | label_names[12] = 'AG_R' # right adrenal gland 38 | label_names[13] = 'AG_L' # left adrenal gland 39 | 40 | return label_names 41 | 42 | 43 | def get_folds(dataset): 44 | FOLD = {} 45 | if dataset == 'CMR': 46 | FOLD[0] = set(range(0, 8)) 47 | FOLD[1] = set(range(7, 15)) 48 | FOLD[2] = set(range(14, 22)) 49 | FOLD[3] = set(range(21, 29)) 50 | FOLD[4] = set(range(28, 35)) 51 | FOLD[4].update([0]) 52 | return FOLD 53 | 54 | elif dataset == 'CHAOST2': 55 | FOLD[0] = set(range(0, 5)) 56 | FOLD[1] = set(range(4, 9)) 57 | FOLD[2] = set(range(8, 13)) 58 | FOLD[3] = set(range(12, 17)) 59 | FOLD[4] = set(range(16, 20)) 60 | FOLD[4].update([0]) 61 | return FOLD 62 | elif dataset == 'SABS': 63 | FOLD[0] = set(range(0, 7)) 64 | FOLD[1] = set(range(6, 13)) 65 | FOLD[2] = set(range(12, 19)) 66 | FOLD[3] = set(range(18, 25)) 67 | FOLD[4] = set(range(24, 30)) 68 | FOLD[4].update([0]) 69 | return FOLD 70 | else: 71 | raise ValueError(f'Dataset: {dataset} not found') 72 | 73 | 74 | def sample_xy(spr, k=0, b=215): 75 | _, h, v = torch.where(spr) 76 | 77 | if len(h) == 0 or len(v) == 0: 
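# when the supervoxel mask has no nonzero pixels there is nothing to sample around,
# so both offsets simply default to 0, e.g.:
#   >>> sample_xy(torch.zeros(1, 256, 256))
#   (0, 0)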
78 | horizontal = 0 79 | vertical = 0 80 | else: 81 | 82 | h_min = min(h) 83 | h_max = max(h) 84 | if b > (h_max - h_min): 85 | kk = min(k, int((h_max - h_min) / 2)) 86 | horizontal = random.randint(max(h_max - b - kk, 0), min(h_min + kk, 256 - b - 1)) 87 | else: 88 | kk = min(k, int(b / 2)) 89 | horizontal = random.randint(max(h_min - kk, 0), min(h_max - b + kk, 256 - b - 1)) 90 | 91 | v_min = min(v) 92 | v_max = max(v) 93 | if b > (v_max - v_min): 94 | kk = min(k, int((v_max - v_min) / 2)) 95 | vertical = random.randint(max(v_max - b - kk, 0), min(v_min + kk, 256 - b - 1)) 96 | else: 97 | kk = min(k, int(b / 2)) 98 | vertical = random.randint(max(v_min - kk, 0), min(v_max - b + kk, 256 - b - 1)) 99 | 100 | return horizontal, vertical 101 | -------------------------------------------------------------------------------- /data/supervoxels/generate_supervoxels.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from Ouyang et al. 3 | https://github.com/cheng-01037/Self-supervised-Fewshot-Medical-Image-Segmentation 4 | """ 5 | 6 | import os 7 | import SimpleITK as sitk 8 | import glob 9 | from skimage.measure import label 10 | import scipy.ndimage.morphology as snm 11 | from felzenszwalb_3d import * 12 | 13 | base_dir = '../../data/SABS/sabs_CT_normalized' 14 | # base_dir = '../../data/CHAOST2/chaos_MR_T2_normalized' 15 | # base_dir = '/CMR/cmr_MR_normalized' 16 | 17 | imgs = glob.glob(os.path.join(base_dir, 'image*')) 18 | labels = glob.glob(os.path.join(base_dir, 'label*')) 19 | 20 | imgs = sorted(imgs, key=lambda x: int(x.split('_')[-1].split('.nii.gz')[0])) 21 | labels = sorted(labels, key=lambda x: int(x.split('_')[-1].split('.nii.gz')[0])) 22 | 23 | fg_thresh = 10 24 | 25 | MODE = 'MIDDLE' 26 | n_sv = 5000 27 | # n_sv = 1000 28 | # if ~os.path.exists(f'../../data/CHAOST2/supervoxels_{n_sv}/'): 29 | # os.mkdir(f'../../data/CHAOST2/supervoxels_{n_sv}/') 30 | if ~os.path.exists(f'../../data/SABS/supervoxels_{n_sv}/'): 31 | os.mkdir(f'../../data/SABS/supervoxels_{n_sv}/') 32 | 33 | 34 | def read_nii_bysitk(input_fid): 35 | """ read nii to numpy through simpleitk 36 | peelinfo: taking direction, origin, spacing and metadata out 37 | """ 38 | img_obj = sitk.ReadImage(input_fid) 39 | img_np = sitk.GetArrayFromImage(img_obj) 40 | return img_np 41 | 42 | 43 | # thresholding the intensity values to get a binary mask of the patient 44 | def fg_mask2d(img_2d, thresh): 45 | mask_map = np.float32(img_2d > thresh) 46 | 47 | def getLargestCC(segmentation): # largest connected components 48 | labels = label(segmentation) 49 | assert (labels.max() != 0) # assume at least 1 CC 50 | largestCC = labels == np.argmax(np.bincount(labels.flat)[1:]) + 1 51 | return largestCC 52 | 53 | if mask_map.max() < 0.999: 54 | return mask_map 55 | else: 56 | post_mask = getLargestCC(mask_map) 57 | fill_mask = snm.binary_fill_holes(post_mask) 58 | return fill_mask 59 | 60 | 61 | # remove supervoxels within the empty regions 62 | def supervox_masking(seg, mask): 63 | seg[seg == 0] = seg.max() + 1 64 | seg = np.int32(seg) 65 | seg[mask == 0] = 0 66 | 67 | return seg 68 | 69 | 70 | # make supervoxels 71 | for img_path in imgs: 72 | img = read_nii_bysitk(img_path) 73 | img = 255 * (img - img.min()) / img.ptp() 74 | 75 | reader = sitk.ImageFileReader() 76 | reader.SetFileName(img_path) 77 | reader.LoadPrivateTagsOn() 78 | reader.ReadImageInformation() 79 | 80 | x = float(reader.GetMetaData('pixdim[1]')) 81 | y = float(reader.GetMetaData('pixdim[2]')) 82 | z = 
float(reader.GetMetaData('pixdim[3]')) 83 | 84 | segments_felzenszwalb = felzenszwalb_3d(img, min_size=n_sv, sigma=0, spacing=(z, x, y)) 85 | 86 | # post processing: remove bg (low intensity regions) 87 | fg_mask_vol = np.zeros(segments_felzenszwalb.shape) 88 | for ii in range(segments_felzenszwalb.shape[0]): 89 | _fgm = fg_mask2d(img[ii, ...], fg_thresh) 90 | fg_mask_vol[ii] = _fgm 91 | processed_seg_vol = supervox_masking(segments_felzenszwalb, fg_mask_vol) 92 | 93 | # write to nii.gz 94 | out_seg = sitk.GetImageFromArray(processed_seg_vol) 95 | 96 | idx = os.path.basename(img_path).split("_")[-1].split(".nii.gz")[0] 97 | 98 | seg_fid = os.path.join(f'../../data/SABS/supervoxels_{n_sv}/', f'superpix-{MODE}_{idx}.nii.gz') 99 | # seg_fid = os.path.join(f'../../data/CHAOST2/supervoxels_{n_sv}/', f'superpix-{MODE}_{idx}.nii.gz') 100 | sitk.WriteImage(out_seg, seg_fid) 101 | print(f'image with id {idx} has finished') 102 | -------------------------------------------------------------------------------- /visualization/mask_generalization_cmr.py: -------------------------------------------------------------------------------- 1 | # the code for generating masks for CMR 2 | 3 | import os 4 | import cv2 5 | import numpy as np 6 | import SimpleITK as itk 7 | import matplotlib.pyplot as plt 8 | import matplotlib 9 | import cv2 10 | import SimpleITK as itk 11 | 12 | # CMR dataset 13 | 14 | abd_ct_pth = './data/Abd_CT/' 15 | abd_mri_pth = './data/Abd_MRI/' 16 | cmr_pth = './data/CMR/' 17 | 18 | cmr_gt_pth = os.path.join(cmr_pth, 'CMR_x_GT.nii.gz') # the ground truth mask, x is case number 19 | cmr_img_pth = os.path.join(cmr_pth, 'CMR_x.nii.gz') # the image 20 | cmr_lvbp_pth = os.path.join(cmr_pth, 'prediction_x_LV-BP.nii.gz') # the LV-BP prediction of case x 21 | cmr_lvmyo_pth = os.path.join(cmr_pth, 'prediction_x_LV-MYO.nii.gz') # the LV-MYO prediction of case x 22 | cmr_rv_pth = os.path.join(cmr_pth, 'prediction_x_RV.nii.gz') # the RV prediction of case x 23 | 24 | 25 | cmr_gt = itk.GetArrayFromImage(itk.ReadImage(cmr_gt_pth)) 26 | cmr_img = itk.GetArrayFromImage(itk.ReadImage(cmr_img_pth)) 27 | 28 | cmr_gt[cmr_gt == 200] = 1 29 | cmr_gt[cmr_gt == 500] = 2 30 | cmr_gt[cmr_gt == 600] = 3 31 | 32 | # ********************************lvbp********************************** 33 | cmr_lvbp = itk.GetArrayFromImage(itk.ReadImage(cmr_lvbp_pth)) 34 | cmr_lvbp_gt = 1 * (cmr_gt == 2) 35 | idx = cmr_lvbp_gt.sum(axis=(1, 2)) > 0 36 | cmr_lvbp_gt = cmr_lvbp_gt[idx] 37 | cmr_img_lvbp = cmr_img[idx] 38 | 39 | cmr_lvbp_gt_show = cmr_lvbp_gt[4]*200 # choose the 4th slice of case x to illustrate, you also can choose other slices 40 | cmr_img_show_1 = cmr_img_lvbp[4] / 4.5 41 | 42 | cmr_lvbp_spt = cmr_lvbp_gt[2]*200 # choose the 2th slice of case x as support image 43 | cmr_img_spt = cmr_img_lvbp[2] / 4.5 44 | 45 | 46 | cv2.imwrite("./data/CMR/cmr_lvbp_gt.png", cmr_lvbp_gt_show) # ground truth 47 | cv2.imwrite("./data/CMR/cmr_lvbp_img.png", cmr_img_show_1) # the image 48 | cv2.imwrite("./data/CMR/cmr_lvbp.png", cmr_img_lvbp_show) # the lvbp prediction 49 | 50 | cv2.imwrite("./data/CMR/cmr_lvbp_spt.png", cmr_lvbp_spt) # the lvbp mask of support image 51 | cv2.imwrite("./data/CMR/cmr_lvbp_img_spt.png", cmr_img_spt) # the support image 52 | 53 | # **********************************lvmyo********************************* 54 | cmr_lvmyo = itk.GetArrayFromImage(itk.ReadImage(cmr_lvmyo_pth)) 55 | cmr_lvmyo_gt = 1 * (cmr_gt == 1) 56 | idx = cmr_lvmyo_gt.sum(axis=(1, 2)) > 0 57 | cmr_lvmyo_gt = cmr_lvmyo_gt[idx] 58 | cmr_img_lvmyo 
= cmr_img[idx] 59 | 60 | cmr_lvmyo_gt_show = cmr_lvmyo_gt[6]*200 61 | cmr_img_show_2 = cmr_img_lvmyo[6] / 4.5 62 | 63 | cmr_lvmyo_spt = cmr_lvmyo_gt[3]*200 64 | cmr_img_spt = cmr_img_lvmyo[3] / 4.5 65 | 66 | cv2.imwrite("./data/CMR/cmr_lvmyo_gt.png", cmr_lvmyo_gt_show) 67 | cv2.imwrite("./data/CMR/cmr_lvmyo_img.png", cmr_img_show_2) 68 | cv2.imwrite("./data/CMR/cmr_lvmyo.png", cmr_img_lvmyo_show) 69 | 70 | cv2.imwrite("./data/CMR/cmr_lvmyo_spt.png", cmr_lvmyo_spt) 71 | cv2.imwrite("./data/CMR/cmr_lvmyo_img_spt.png", cmr_img_spt) 72 | 73 | # **********************************rv************************************ 74 | cmr_rv = itk.GetArrayFromImage(itk.ReadImage(cmr_rv_pth)) 75 | cmr_rv_gt = 1 * (cmr_gt == 3) 76 | idx = cmr_rv_gt.sum(axis=(1, 2)) > 0 77 | cmr_rv_gt = cmr_rv_gt[idx] 78 | cmr_img_rv = cmr_img[idx] 79 | 80 | cmr_rv_gt_show = cmr_rv_gt[2]*200 81 | cmr_img_show_3 = cmr_img_rv[2] / 4.5 82 | 83 | cmr_rv_spt = cmr_rv_gt[5]*200 84 | cmr_img_spt = cmr_img_rv[5] / 4.5 85 | 86 | cv2.imwrite("./data/CMR/cmr_rv_gt.png", cmr_rv_gt_show) 87 | cv2.imwrite("./data/CMR/cmr_rv_img.png", cmr_img_show_3) 88 | cv2.imwrite("./data/CMR/cmr_rv.png", cmr_img_rv_show) 89 | 90 | cv2.imwrite("./data/CMR/cmr_rv_spt.png", cmr_rv_spt) 91 | cv2.imwrite("./data/CMR/cmr_rv_img_spt.png", cmr_img_spt) 92 | 93 | 94 | 95 | 96 | 97 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utils for Dataset 3 | Extended from ADNet code by Hansen et al. 4 | """ 5 | import random 6 | import torch 7 | import numpy as np 8 | import operator 9 | import os 10 | import logging 11 | 12 | 13 | def set_seed(seed): 14 | """ 15 | Set the random seed 16 | """ 17 | random.seed(seed) 18 | torch.manual_seed(seed) 19 | torch.cuda.manual_seed_all(seed) 20 | 21 | 22 | CLASS_LABELS = { 23 | 'CHAOST2': { 24 | 'pa_all': set(range(1, 5)), 25 | 0: set([1, 4]), # upper_abdomen, leaving kidneies as testing classes 26 | 1: set([2, 3]), # lower_abdomen 27 | }, 28 | } 29 | 30 | 31 | def get_bbox(fg_mask, inst_mask): 32 | """ 33 | Get the ground truth bounding boxes 34 | """ 35 | 36 | fg_bbox = torch.zeros_like(fg_mask, device=fg_mask.device) 37 | bg_bbox = torch.ones_like(fg_mask, device=fg_mask.device) 38 | 39 | inst_mask[fg_mask == 0] = 0 40 | area = torch.bincount(inst_mask.view(-1)) 41 | cls_id = area[1:].argmax() + 1 42 | cls_ids = np.unique(inst_mask)[1:] 43 | 44 | mask_idx = np.where(inst_mask[0] == cls_id) 45 | y_min = mask_idx[0].min() 46 | y_max = mask_idx[0].max() 47 | x_min = mask_idx[1].min() 48 | x_max = mask_idx[1].max() 49 | fg_bbox[0, y_min:y_max + 1, x_min:x_max + 1] = 1 50 | 51 | for i in cls_ids: 52 | mask_idx = np.where(inst_mask[0] == i) 53 | y_min = max(mask_idx[0].min(), 0) 54 | y_max = min(mask_idx[0].max(), fg_mask.shape[1] - 1) 55 | x_min = max(mask_idx[1].min(), 0) 56 | x_max = min(mask_idx[1].max(), fg_mask.shape[2] - 1) 57 | bg_bbox[0, y_min:y_max + 1, x_min:x_max + 1] = 0 58 | return fg_bbox, bg_bbox 59 | 60 | 61 | def t2n(img_t): 62 | """ 63 | torch to numpy regardless of whether tensor is on gpu or memory 64 | """ 65 | if img_t.is_cuda: 66 | return img_t.data.cpu().numpy() 67 | else: 68 | return img_t.data.numpy() 69 | 70 | 71 | def to01(x_np): 72 | """ 73 | normalize a numpy to 0-1 for visualize 74 | """ 75 | return (x_np - x_np.min()) / (x_np.max() - x_np.min() + 1e-5) 76 | 77 | 78 | class Scores(): 79 | 80 | def __init__(self): 81 | self.TP = 0 82 | self.TN = 0 83 | self.FP = 0 84 
| self.FN = 0 85 | 86 | self.patient_dice = [] 87 | self.patient_iou = [] 88 | 89 | def record(self, preds, label): 90 | assert len(torch.unique(preds)) < 3 91 | 92 | tp = torch.sum((label == 1) * (preds == 1)) 93 | tn = torch.sum((label == 0) * (preds == 0)) 94 | fp = torch.sum((label == 0) * (preds == 1)) 95 | fn = torch.sum((label == 1) * (preds == 0)) 96 | 97 | self.patient_dice.append(2 * tp / (2 * tp + fp + fn)) 98 | self.patient_iou.append(tp / (tp + fp + fn)) 99 | 100 | self.TP += tp 101 | self.TN += tn 102 | self.FP += fp 103 | self.FN += fn 104 | 105 | def compute_dice(self): 106 | return 2 * self.TP / (2 * self.TP + self.FP + self.FN) 107 | 108 | def compute_iou(self): 109 | return self.TP / (self.TP + self.FP + self.FN) 110 | 111 | 112 | def set_logger(path): 113 | logger = logging.getLogger() 114 | logger.handlers = [] 115 | formatter = logging.Formatter('[%(levelname)] - %(name)s - %(message)s') 116 | logger.setLevel("INFO") 117 | 118 | # log to .txt 119 | file_handler = logging.FileHandler(path) 120 | file_handler.setFormatter(formatter) 121 | logger.addHandler(file_handler) 122 | 123 | # log to console 124 | stream_handler = logging.StreamHandler() 125 | stream_handler.setFormatter(formatter) 126 | logger.addHandler(stream_handler) 127 | return logger 128 | -------------------------------------------------------------------------------- /models/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | import numpy as np 6 | 7 | class ScaledDotProductAttention(nn.Module): 8 | """ Scaled Dot-Product Attention """ 9 | 10 | def __init__(self, temperature, attn_dropout=0.1): 11 | super().__init__() 12 | self.temperature = temperature 13 | self.dropout = nn.Dropout(attn_dropout) 14 | self.softmax = nn.Softmax(dim=-1) 15 | 16 | def forward(self, q, k, v): 17 | attn = torch.bmm(q, k.transpose(1, 2)) 18 | attn = attn / self.temperature 19 | log_attn = F.log_softmax(attn, 2) 20 | attn = self.softmax(attn) 21 | attn = self.dropout(attn) 22 | 23 | output = torch.bmm(attn, v) 24 | return output, attn, log_attn 25 | 26 | 27 | class MultiHeadAttention(nn.Module): # for 64 channel 28 | """ Multi-Head Attention module """ 29 | 30 | def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1, do_activation=True): 31 | super().__init__() 32 | self.n_head = n_head 33 | self.d_k = d_k 34 | self.d_v = d_v 35 | self.do_activation = do_activation 36 | 37 | self.w_qs = nn.Linear(d_model, n_head * d_k) 38 | self.w_ks = nn.Linear(d_model, n_head * d_k) 39 | self.w_vs = nn.Linear(d_model, n_head * d_v) 40 | nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k))) 41 | nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k))) 42 | nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v))) 43 | 44 | self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5)) 45 | self.layer_norm = nn.LayerNorm(d_model) 46 | 47 | self.fc = nn.Linear(n_head * d_v, d_model) 48 | nn.init.xavier_normal_(self.fc.weight) 49 | self.dropout = nn.Dropout(dropout) 50 | 51 | self.activation = F.relu 52 | 53 | def forward(self, q, k, v): 54 | 55 | """ 56 | q shape: 1 x N'x C 57 | k shape: 1 x N'x C 58 | v shape: 1 x N'x C 59 | """ 60 | 61 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head 62 | 63 | sz_b, len_q, _ = q.size() 64 | sz_b, len_k, _ = k.size() 65 | sz_b, len_v, _ = v.size() 66 | 67 | residual = q 68 | 69 | q = 
self.w_qs(q).view(sz_b, len_q, n_head, d_k) 70 | k = self.w_ks(k).view(sz_b, len_k, n_head, d_k) 71 | v = self.w_vs(v).view(sz_b, len_v, n_head, d_v) 72 | 73 | # activation here 74 | if self.do_activation: 75 | q = self.activation(q) 76 | k = self.activation(k) 77 | v = self.activation(v) 78 | 79 | q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk 80 | k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk 81 | v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv 82 | 83 | output, attn, log_attn = self.attention(q, k, v) 84 | 85 | output = output.view(n_head, sz_b, len_q, d_v) 86 | output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1) # b x lq x (n*dv) 87 | 88 | output = self.dropout(self.fc(output)) 89 | 90 | # activation here 91 | if self.do_activation: 92 | output = self.activation(output) 93 | 94 | output = self.layer_norm(output + residual) 95 | 96 | return output 97 | 98 | 99 | class MultiLayerPerceptron(nn.Module): 100 | """ 101 | Multi-layer Perceptron module 102 | """ 103 | def __init__(self, dim, mlp_dim): 104 | super(MultiLayerPerceptron, self).__init__() 105 | self.norm = nn.LayerNorm(dim) 106 | self.mlp = nn.Sequential( 107 | nn.Linear(dim, mlp_dim), 108 | nn.GELU(), 109 | nn.Linear(mlp_dim, dim) 110 | ) 111 | 112 | def forward(self, x): 113 | x = self.mlp(x) + x 114 | x = self.norm(x) 115 | return x 116 | 117 | 118 | 119 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 |
# Few-Shot Medical Image Segmentation via a Region-enhanced Prototypical Transformer
9 | 10 | 11 | ## Abstract 12 | Automated segmentation of large volumes of medical images is often plagued by the limited availability of fully annotated data and the diversity of organ surface properties resulting from the use of different acquisition protocols for different patients. In this paper, we introduce a more promising few-shot learning-based method named Region-enhanced Prototypical Transformer (RPT) to mitigate the effects of large intra-class diversity/bias. First, a subdivision strategy is introduced to produce a collection of regional prototypes from the foreground of the support prototype. Second, a self-selection mechanism is proposed and incorporated into the Bias-alleviated Transformer (BaT) block to suppress or remove interferences present in the query prototype and regional support prototypes. By stacking BaT blocks, the proposed RPT can iteratively optimize the generated regional prototypes and finally produce rectified and more accurate global prototypes for Few-Shot Medical Image Segmentation (FSMS). Extensive experiments are conducted on three publicly available medical image datasets, and the obtained results show consistent improvements compared to state-of-the-art FSMS methods. 13 | 14 | 15 | # Getting started 16 | 17 | ### Dependencies 18 | Please install the following essential dependencies: 19 | ``` 20 | dcm2nii 21 | json5==0.8.5 22 | jupyter==1.0.0 23 | nibabel==2.5.1 24 | numpy==1.22.0 25 | opencv-python==4.5.5.62 26 | Pillow>=8.1.1 27 | sacred==0.8.2 28 | scikit-image==0.18.3 29 | SimpleITK==1.2.3 30 | torch==1.10.2 31 | torchvision==0.11.2 32 | tqdm==4.62.3 33 | ``` 34 | 35 | Pre-processing is performed according to [Ouyang et al.](https://github.com/cheng-01037/Self-supervised-Fewshot-Medical-Image-Segmentation/tree/2f2a22b74890cb9ad5e56ac234ea02b9f1c7a535), and we follow the procedure described in their GitHub repository. 36 | 37 | 38 | The trained models can be downloaded from the following links: 39 | 1) [trained models for CHAOS under Setting 1](https://drive.google.com/drive/folders/1gp2Hp4EPBOKwIbVN4l3QyfWeAN8l4jJj?usp=drive_link) 40 | 2) [trained models for CHAOS under Setting 2](https://drive.google.com/drive/folders/1RQ0B0XQfOIwoO-7R2sUG7h7dea6v4TJa?usp=drive_link) 41 | 3) [trained models for SABS under Setting 1](https://drive.google.com/drive/folders/1xXK8_1fQVQyRoL1N49RN7ZW3H10E-Y5-?usp=drive_link) 42 | 4) [trained models for SABS under Setting 2](https://drive.google.com/drive/folders/1EZamwmnh8DkJ51J3VJbC0Mn2vhdE-cGt?usp=drive_link) 43 | 5) [trained models for CMR](https://drive.google.com/drive/folders/1czW-1mMOdaouI5PBPBNXI8cLbt9jJ2xq?usp=drive_link) 44 | 45 | 46 | 47 | The pre-processed data and supervoxels can be downloaded from the following links: 48 | 1) [Pre-processed CHAOS-T2 data and supervoxels](https://drive.google.com/drive/folders/1elxzn67Hhe0m1PvjjwLGls6QbkIQr1m1?usp=share_link) 49 | 2) [Pre-processed SABS data and supervoxels](https://drive.google.com/drive/folders/1pgm9sPE6ihqa2OuaiSz7X8QhXKkoybv5?usp=share_link) 50 | 3) [Pre-processed CMR data and supervoxels](https://drive.google.com/drive/folders/1aaU5KQiKOZelfVOpQxxfZNXKNkhrcvY2?usp=share_link) 51 | ### Training 52 | 1. Compile `./data/supervoxels/felzenszwalb_3d_cy.pyx` with Cython (`cd ./data/supervoxels && python setup.py build_ext --inplace`) and run `./data/supervoxels/generate_supervoxels.py` 53 | 2. 
Download the pre-trained ResNet-101 weights ([vanilla version](https://download.pytorch.org/models/resnet101-63fe2227.pth) or [deeplabv3 version](https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth)), put them into your `./checkpoints` folder, and then update the absolute paths in `./models/encoder.py` accordingly. 54 | 3. Run the training script for your dataset and setting from `./scripts/`, e.g. `./scripts/train_CHAOST2_setting1.sh` 55 | 56 | ### Inference 57 | Run the corresponding test script from `./scripts/`, e.g. `./scripts/test_CHAOST2_setting1.sh` 58 | 59 | ### Acknowledgement 60 | Our code is based on the works: [SSL-ALPNet](https://github.com/cheng-01037/Self-supervised-Fewshot-Medical-Image-Segmentation), [ADNet](https://github.com/sha168/ADNet) and [QNet](https://github.com/ZJLAB-AMMI/Q-Net) 61 | 62 | ## Citation 63 | ```bibtex 64 | @inproceedings{zhu2023few, 65 | title={Few-Shot Medical Image Segmentation via a Region-Enhanced Prototypical Transformer}, 66 | author={Zhu, Yazhou and Wang, Shidong and Xin, Tong and Zhang, Haofeng}, 67 | booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention}, 68 | pages={271--280}, 69 | year={2023}, 70 | organization={Springer} 71 | } 72 | ``` 73 | 74 | -------------------------------------------------------------------------------- /visualization/mask_generalization_ct.py: -------------------------------------------------------------------------------- 1 | # the code for generating masks for ABD-CT 2 | 3 | import os 4 | import cv2 5 | import numpy as np 6 | import SimpleITK as itk 7 | import matplotlib.pyplot as plt 8 | import matplotlib 9 | import cv2 10 | import SimpleITK as itk 11 | 12 | # Abd_CT dataset 13 | 14 | abd_ct_pth = './data/Abd_CT/' 15 | abd_mri_pth = './data/Abd_MRI/' 16 | cmr_pth = './data/CMR/' 17 | 18 | abd_ct_gt_pth = os.path.join(abd_ct_pth, 'Abd_CT_x_GT.nii.gz') # the ground truth mask, x is case number 19 | abd_ct_img_pth = os.path.join(abd_ct_pth, 'Abd_CT_x.nii.gz') # the image 20 | abd_ct_liver_pth = os.path.join(abd_ct_pth, 'prediction_x_LIVER.nii.gz') # the liver prediction of case x 21 | abd_ct_spleen_pth = os.path.join(abd_ct_pth, 'prediction_x_SPLEEN.nii.gz') # the spleen prediction of case x 22 | abd_ct_rk_pth = os.path.join(abd_ct_pth, 'prediction_x_RK.nii.gz') # the right kidney prediction of case x 23 | abd_ct_lk_pth = os.path.join(abd_ct_pth, 'prediction_x_LK.nii.gz') # the left kidney prediction of case x 24 | 25 | 26 | abd_ct_gt = itk.GetArrayFromImage(itk.ReadImage(abd_ct_gt_pth)) 27 | abd_ct_img = itk.GetArrayFromImage(itk.ReadImage(abd_ct_img_pth)) 28 | 29 | abd_ct_gt[abd_ct_gt == 200] = 1 30 | abd_ct_gt[abd_ct_gt == 500] = 2 31 | abd_ct_gt[abd_ct_gt == 600] = 3 32 | 33 | # ********************************liver********************************** 34 | abd_ct_liver = itk.GetArrayFromImage(itk.ReadImage(abd_ct_liver_pth)) 35 | abd_ct_liver_gt = 1 * (abd_ct_gt == 6) 36 | idx = abd_ct_liver_gt.sum(axis=(1, 2)) > 0 37 | abd_ct_liver_gt = abd_ct_liver_gt[idx] 38 | abd_ct_img_liver = abd_ct_img[idx] 39 | 40 | abd_ct_liver_gt_show = abd_ct_liver_gt[16] * 200 # choose the 16th slice of case x to illustrate, you can also choose other slices 41 | abd_ct_img_show_1 = abd_ct_img_liver[16] 42 | abd_ct_liver_show = abd_ct_liver[16] * 200 43 | 44 | abd_ct_liver_spt = abd_ct_liver_gt[13] * 200 # choose the 13th slice of case x as support image.
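# NOTE: the binary masks in this script are multiplied by 200 simply so that the
# foreground is clearly visible once the arrays are written out as 8-bit PNGs with cv2.imwrite.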
45 | abd_ct_img_spt = abd_ct_img_liver[13] 46 | 47 | cv2.imwrite("./data/Abd_CT/abd_ct_liver_gt.png", abd_ct_liver_gt_show) # ground truth 48 | cv2.imwrite("./data/Abd_CT/abd_ct_liver_img.png", abd_ct_img_show_1) # the image 49 | cv2.imwrite("./data/Abd_CT/abd_ct_liver.png", abd_ct_liver_show) # the liver prediction 50 | 51 | cv2.imwrite("./data/Abd_CT/abd_ct_liver_spt.png", abd_ct_liver_spt) # the liver mask of support image 52 | cv2.imwrite("./data/Abd_CT/abd_ct_liver_img_spt.png", abd_ct_img_spt) # the support image 53 | 54 | # **********************************spleen********************************* 55 | abd_ct_spleen = itk.GetArrayFromImage(itk.ReadImage(abd_ct_spleen_pth)) 56 | abd_ct_spleen_gt = 1 * (abd_ct_gt == 1) 57 | idx = abd_ct_spleen_gt.sum(axis=(1, 2)) > 0 58 | abd_ct_spleen_gt = abd_ct_spleen_gt[idx] 59 | abd_ct_img_spleen = abd_ct_img[idx] 60 | 61 | abd_ct_spleen_gt_show = abd_ct_spleen_gt[14]*200 62 | abd_ct_img_show_2 = abd_ct_img_spleen[14] 63 | 64 | abd_ct_spleen_spt = abd_ct_spleen_gt[8]*200 65 | abd_ct_img_spt = abd_ct_img_spleen[8] 66 | 67 | cv2.imwrite("./data/Abd_CT/abd_ct_spleen_gt.png", abd_ct_spleen_gt_show) 68 | cv2.imwrite("./data/Abd_CT/abd_ct_spleen_img.png", abd_ct_img_show_2) 69 | cv2.imwrite("./data/Abd_CT/abd_ct_spleen.png", abd_ct_spleen_show) 70 | 71 | cv2.imwrite("./data/Abd_CT/abd_ct_spleen_spt.png", abd_ct_spleen_spt) 72 | cv2.imwrite("./data/Abd_CT/abd_ct_spleen_img_spt.png", abd_ct_img_spt) 73 | 74 | # **********************************RK************************************ 75 | abd_ct_rk = itk.GetArrayFromImage(itk.ReadImage(abd_ct_rk_pth)) 76 | abd_ct_rk_gt = 1 * (abd_ct_gt == 2) 77 | idx = abd_ct_rk_gt.sum(axis=(1, 2)) > 0 78 | abd_ct_rk_gt = abd_ct_rk_gt[idx] 79 | abd_ct_img_rk = abd_ct_img[idx] 80 | 81 | abd_ct_rk_gt_show = abd_ct_rk_gt[18]*200 82 | abd_ct_img_show_3 = abd_ct_img_rk[18] 83 | 84 | abd_ct_rk_spt = abd_ct_rk_gt[10]*200 85 | abd_ct_img_spt = abd_ct_img_rk[10] 86 | 87 | cv2.imwrite("./data/Abd_CT/abd_ct_rk_gt.png", abd_ct_rk_gt_show) 88 | cv2.imwrite("./data/Abd_CT/abd_ct_rk_img.png", abd_ct_img_show_3) 89 | cv2.imwrite("./data/Abd_CT/abd_ct_rk.png", abd_ct_rk_show) 90 | 91 | cv2.imwrite("./data/Abd_CT/abd_ct_rk_spt.png", abd_ct_rk_spt) 92 | cv2.imwrite("./data/Abd_CT/abd_ct_rk_img_spt.png", abd_ct_img_spt) 93 | 94 | # *********************************LK************************************** 95 | abd_ct_lk = itk.GetArrayFromImage(itk.ReadImage(abd_ct_lk_pth)) 96 | abd_ct_lk_gt = 1 * (abd_ct_gt == 3) 97 | idx = abd_ct_lk_gt.sum(axis=(1, 2)) > 0 98 | abd_ct_lk_gt = abd_ct_lk_gt[idx] 99 | abd_ct_img_lk = abd_ct_img[idx] 100 | 101 | abd_ct_lk_gt_show = abd_ct_lk_gt[17]*200 102 | abd_ct_img_show_4 = abd_ct_img_lk[17] 103 | 104 | abd_ct_lk_spt = abd_ct_lk_gt[8]*200 105 | abd_ct_img_spt = abd_ct_img_lk[8] 106 | 107 | cv2.imwrite("./data/Abd_CT/abd_ct_lk_gt.png", abd_ct_lk_gt_show) 108 | cv2.imwrite("./data/Abd_CT/abd_ct_lk_img.png", abd_ct_img_show_4) 109 | cv2.imwrite("./data/Abd_CT/abd_ct_lk.png", abd_ct_lk_show) 110 | 111 | cv2.imwrite("./data/Abd_CT/abd_ct_lk_spt.png", abd_ct_lk_spt) 112 | cv2.imwrite("./data/Abd_CT/abd_ct_lk_img_spt.png", abd_ct_img_spt) 113 | 114 | 115 | -------------------------------------------------------------------------------- /visualization/mask_generalization_mri.py: -------------------------------------------------------------------------------- 1 | # the code for generating masks for ABD-MRI 2 | 3 | import os 4 | import cv2 5 | import numpy as np 6 | import SimpleITK as itk 7 | import matplotlib.pyplot 
as plt 8 | import matplotlib 9 | import cv2 10 | import SimpleITK as itk 11 | 12 | # Abd_MRI dataset 13 | 14 | 15 | abd_ct_pth = './data/Abd_CT/' 16 | abd_mri_pth = './data/Abd_MRI/' 17 | cmr_pth = './data/CMR/' 18 | 19 | abd_mri_gt_pth = os.path.join(abd_mri_pth, 'Abd_MRI_x_GT.nii.gz') # the ground truth mask, x is case number 20 | abd_mri_img_pth = os.path.join(abd_mri_pth, 'Abd_MRI_x.nii.gz') # the image 21 | 22 | abd_mri_liver_pth = os.path.join(abd_mri_pth, 'prediction_x_LIVER.nii.gz') # the liver prediction of case x 23 | abd_mri_spleen_pth = os.path.join(abd_mri_pth, 'prediction_x_SPLEEN.nii.gz') # the spleen prediction of case x 24 | abd_mri_rk_pth = os.path.join(abd_mri_pth, 'prediction_x_RK.nii.gz') # the right kidney prediction of case x 25 | abd_mri_lk_pth = os.path.join(abd_mri_pth, 'prediction_x_LK.nii.gz') # the left kidney prediction of case x 26 | 27 | abd_mri_gt = itk.GetArrayFromImage(itk.ReadImage(abd_mri_gt_pth)) 28 | abd_mri_img = itk.GetArrayFromImage(itk.ReadImage(abd_mri_img_pth)) 29 | 30 | 31 | abd_mri_gt[abd_mri_gt == 200] = 1 32 | abd_mri_gt[abd_mri_gt == 500] = 2 33 | abd_mri_gt[abd_mri_gt == 600] = 3 34 | 35 | # ********************************liver********************************** 36 | abd_mri_liver = itk.GetArrayFromImage(itk.ReadImage(abd_mri_liver_pth)) 37 | abd_mri_liver_gt = 1 * (abd_mri_gt == 1) 38 | idx = abd_mri_liver_gt.sum(axis=(1, 2)) > 0 39 | abd_mri_liver_gt = abd_mri_liver_gt[idx] 40 | abd_mri_img_liver = abd_mri_img[idx] 41 | 42 | abd_mri_liver_gt_show = abd_mri_liver_gt[11] * 200 # choose the 11th slice of case x to illustrate, you also can choose other slices 43 | abd_mri_img_show_1 = abd_mri_img_liver[11] / 4.5 44 | abd_mri_liver_show = abd_mri_liver[11] * 200 45 | 46 | abd_mri_liver_spt = abd_mri_liver_gt[5] * 200 # choose the 5th slice of case x as support image. 
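# NOTE: the MRI slices in this script are divided by 4.5, presumably to bring their
# intensity range roughly into [0, 255] before they are saved as 8-bit PNGs with cv2.imwrite;
# adjust this factor if the exported images look too dark or too bright.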
47 | abd_mri_img_spt = abd_mri_img_liver[5] / 4.5 48 | 49 | cv2.imwrite("./data/Abd_MRI/abd_mri_liver_gt.png", abd_mri_liver_gt_show) # the ground truth 50 | cv2.imwrite("./data/Abd_MRI/abd_mri_liver_img.png", abd_mri_img_show_1) # the image 51 | cv2.imwrite("./data/Abd_MRI/abd_mri_liver.png", abd_mri_liver_show) # the liver prediction 52 | 53 | cv2.imwrite("./data/Abd_MRI/abd_mri_liver_spt.png", abd_mri_liver_spt) # the liver mask of support image 54 | cv2.imwrite("./data/Abd_MRI/abd_mri_liver_img_spt.png", abd_mri_img_spt) # the support image 55 | 56 | # **********************************spleen********************************* 57 | abd_mri_spleen = itk.GetArrayFromImage(itk.ReadImage(abd_mri_spleen_pth)) 58 | abd_mri_spleen_gt = 1 * (abd_mri_gt == 4) 59 | idx = abd_mri_spleen_gt.sum(axis=(1, 2)) > 0 60 | abd_mri_spleen_gt = abd_mri_spleen_gt[idx] 61 | abd_mri_img_spleen = abd_mri_img[idx] 62 | 63 | abd_mri_spleen_gt_show = abd_mri_spleen_gt[10]*200 64 | abd_mri_img_show_2 = abd_mri_img_spleen[10] / 4.5 65 | 66 | abd_mri_spleen_spt = abd_mri_spleen_gt[4]*200 67 | abd_mri_img_spt = abd_mri_img_spleen[4] / 4.5 68 | 69 | cv2.imwrite("./data/Abd_MRI/abd_mri_spleen_gt.png", abd_mri_spleen_gt_show) 70 | cv2.imwrite("./data/Abd_MRI/abd_mri_spleen_img.png", abd_mri_img_show_2) 71 | cv2.imwrite("./data/Abd_MRI/abd_mri_spleen.png", abd_mri_spleen_show) 72 | 73 | cv2.imwrite("./data/Abd_MRI/abd_mri_spleen_spt.png", abd_mri_spleen_spt) 74 | cv2.imwrite("./data/Abd_MRI/abd_mri_spleen_img_spt.png", abd_mri_img_spt) 75 | 76 | # **********************************RK************************************ 77 | abd_mri_rk = itk.GetArrayFromImage(itk.ReadImage(abd_mri_rk_pth)) 78 | abd_mri_rk_gt = 1 * (abd_mri_gt == 2) 79 | idx = abd_mri_rk_gt.sum(axis=(1, 2)) > 0 80 | abd_mri_rk_gt = abd_mri_rk_gt[idx] 81 | abd_mri_img_rk = abd_mri_img[idx] 82 | 83 | abd_mri_rk_gt_show = abd_mri_rk_gt[10]*200 84 | abd_mri_img_show_3 = abd_mri_img_rk[10] / 4.5 85 | 86 | abd_mri_rk_spt = abd_mri_rk_gt[11]*200 87 | abd_mri_img_spt = abd_mri_img_rk[11] / 4.5 88 | 89 | cv2.imwrite("./data/Abd_MRI/abd_mri_rk_gt.png", abd_mri_rk_gt_show) 90 | cv2.imwrite("./data/Abd_MRI/abd_mri_rk_img.png", abd_mri_img_show_3) 91 | cv2.imwrite("./data/Abd_MRI/abd_mri_rk.png", abd_mri_rk_show) 92 | 93 | cv2.imwrite("./data/Abd_MRI/abd_mri_rk_spt.png", abd_mri_rk_spt) 94 | cv2.imwrite("./data/Abd_MRI/abd_mri_rk_img_spt.png", abd_mri_img_spt) 95 | 96 | # *********************************LK************************************** 97 | abd_mri_lk = itk.GetArrayFromImage(itk.ReadImage(abd_mri_lk_pth)) 98 | abd_mri_lk_gt = 1 * (abd_mri_gt == 3) 99 | idx = abd_mri_lk_gt.sum(axis=(1, 2)) > 0 100 | abd_mri_lk_gt = abd_mri_lk_gt[idx] 101 | abd_mri_img_lk = abd_mri_img[idx] 102 | 103 | abd_mri_lk_gt_show = abd_mri_lk_gt[11]*200 104 | abd_mri_img_show_4 = abd_mri_img_lk[11] / 4.5 105 | 106 | abd_mri_lk_spt = abd_mri_lk_gt[5]*200 107 | abd_mri_img_spt = abd_mri_img_lk[5] / 4.5 108 | 109 | cv2.imwrite("./data/Abd_MRI/abd_mri_lk_gt.png", abd_mri_lk_gt_show) 110 | cv2.imwrite("./data/Abd_MRI/abd_mri_lk_img.png", abd_mri_img_show_4) 111 | cv2.imwrite("./data/Abd_MRI/abd_mri_lk.png", abd_mri_lk_show) 112 | 113 | cv2.imwrite("./data/Abd_MRI/abd_mri_lk_spt.png", abd_mri_lk_spt) 114 | cv2.imwrite("./data/Abd_MRI/abd_mri_lk_img_spt.png", abd_mri_img_spt) 115 | 116 | -------------------------------------------------------------------------------- /data/CHAOST2/png_gth_to_nii.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | 
"cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Converting labels from png to nii file\n", 8 | "\n", 9 | "\n", 10 | "### Overview\n", 11 | "\n", 12 | "This is the first step for data preparation\n", 13 | "\n", 14 | "Input: ground truth labels in `.png` format\n", 15 | "\n", 16 | "Output: labels in `.nii` format, indexed by patient id" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 13, 22 | "metadata": {}, 23 | "outputs": [ 24 | { 25 | "name": "stdout", 26 | "output_type": "stream", 27 | "text": [ 28 | "Once deleted, variables cannot be recovered. Proceed (y/[n])? y\n", 29 | "The autoreload extension is already loaded. To reload it, use:\n", 30 | " %reload_ext autoreload\n" 31 | ] 32 | } 33 | ], 34 | "source": [ 35 | "%reset\n", 36 | "%load_ext autoreload\n", 37 | "%autoreload 2\n", 38 | "import os\n", 39 | "import glob\n", 40 | "\n", 41 | "import numpy as np\n", 42 | "import PIL\n", 43 | "import matplotlib.pyplot as plt\n", 44 | "import SimpleITK as sitk\n", 45 | "import sys\n", 46 | "sys.path.insert(0, '../../dataloaders/')\n", 47 | "import niftiio as nio" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": null, 53 | "metadata": {}, 54 | "outputs": [], 55 | "source": [] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 14, 60 | "metadata": {}, 61 | "outputs": [], 62 | "source": [ 63 | "example = \"./MR/1/T2SPIR/Ground/IMG-0002-00001.png\" # example of ground-truth file name. " 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": 15, 69 | "metadata": {}, 70 | "outputs": [], 71 | "source": [ 72 | "### search for scan ids\n", 73 | "ids = os.listdir(\"./MR/\")\n", 74 | "OUT_DIR = './niis/T2SPIR/'" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 16, 80 | "metadata": {}, 81 | "outputs": [ 82 | { 83 | "data": { 84 | "text/plain": [ 85 | "['37',\n", 86 | " '3',\n", 87 | " '15',\n", 88 | " '34',\n", 89 | " '33',\n", 90 | " '39',\n", 91 | " '20',\n", 92 | " '10',\n", 93 | " '22',\n", 94 | " '8',\n", 95 | " '31',\n", 96 | " '2',\n", 97 | " '36',\n", 98 | " '5',\n", 99 | " '13',\n", 100 | " '19',\n", 101 | " '21',\n", 102 | " '1',\n", 103 | " '38',\n", 104 | " '32']" 105 | ] 106 | }, 107 | "execution_count": 16, 108 | "metadata": {}, 109 | "output_type": "execute_result" 110 | } 111 | ], 112 | "source": [ 113 | "ids" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": 17, 119 | "metadata": {}, 120 | "outputs": [ 121 | { 122 | "name": "stdout", 123 | "output_type": "stream", 124 | "text": [ 125 | "image with id 37 has been saved!\n", 126 | "image with id 3 has been saved!\n", 127 | "image with id 15 has been saved!\n", 128 | "image with id 34 has been saved!\n", 129 | "image with id 33 has been saved!\n", 130 | "image with id 39 has been saved!\n", 131 | "image with id 20 has been saved!\n", 132 | "image with id 10 has been saved!\n", 133 | "image with id 22 has been saved!\n", 134 | "image with id 8 has been saved!\n", 135 | "image with id 31 has been saved!\n", 136 | "image with id 2 has been saved!\n", 137 | "image with id 36 has been saved!\n", 138 | "image with id 5 has been saved!\n", 139 | "image with id 13 has been saved!\n", 140 | "image with id 19 has been saved!\n", 141 | "image with id 21 has been saved!\n", 142 | "image with id 1 has been saved!\n", 143 | "image with id 38 has been saved!\n", 144 | "image with id 32 has been saved!\n" 145 | ] 146 | } 147 | ], 148 | "source": [ 149 | "#### Write them to nii files for the 
ease of loading in future\n", 150 | "for curr_id in ids:\n", 151 | " pngs = glob.glob(f'./MR/{curr_id}/T2SPIR/Ground/*.png')\n", 152 | " pngs = sorted(pngs, key = lambda x: int(os.path.basename(x).split(\"-\")[-1].split(\".png\")[0]))\n", 153 | " buffer = []\n", 154 | "\n", 155 | " for fid in pngs:\n", 156 | " buffer.append(PIL.Image.open(fid))\n", 157 | "\n", 158 | " vol = np.stack(buffer, axis = 0)\n", 159 | " # flip correction\n", 160 | " vol = np.flip(vol, axis = 1).copy()\n", 161 | " # remap values\n", 162 | " for new_val, old_val in enumerate(sorted(np.unique(vol))):\n", 163 | " vol[vol == old_val] = new_val\n", 164 | "\n", 165 | " # get reference \n", 166 | " ref_img = f'./niis/T2SPIR/image_{curr_id}.nii.gz'\n", 167 | " img_o = sitk.ReadImage(ref_img)\n", 168 | " vol_o = nio.np2itk(img=vol, ref_obj=img_o)\n", 169 | " sitk.WriteImage(vol_o, f'{OUT_DIR}/label_{curr_id}.nii.gz')\n", 170 | " print(f'image with id {curr_id} has been saved!')\n", 171 | "\n", 172 | " " 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": null, 178 | "metadata": {}, 179 | "outputs": [], 180 | "source": [] 181 | } 182 | ], 183 | "metadata": { 184 | "kernelspec": { 185 | "display_name": "Python 3", 186 | "language": "python", 187 | "name": "python3" 188 | }, 189 | "language_info": { 190 | "codemirror_mode": { 191 | "name": "ipython", 192 | "version": 3 193 | }, 194 | "file_extension": ".py", 195 | "mimetype": "text/x-python", 196 | "name": "python", 197 | "nbconvert_exporter": "python", 198 | "pygments_lexer": "ipython3", 199 | "version": "3.6.0" 200 | } 201 | }, 202 | "nbformat": 4, 203 | "nbformat_minor": 2 204 | } 205 | -------------------------------------------------------------------------------- /models/encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision 4 | 5 | 6 | class Res101Encoder(nn.Module): 7 | """ 8 | Resnet101 backbone from deeplabv3 9 | modify the 'downsample' component in layer2 and/or layer3 and/or layer4 as the vanilla Resnet 10 | """ 11 | 12 | def __init__(self, replace_stride_with_dilation=None, pretrained_weights='resnet101'): 13 | super().__init__() 14 | # using pretrained model's weights 15 | if pretrained_weights == 'deeplabv3': 16 | self.pretrained_weights = torch.load( 17 | "/home/cs4007/data/zyz/RPT-main/checkpoints/deeplabv3_resnet101_coco-586e9e4e.pth", map_location='cpu') 18 | elif pretrained_weights == 'resnet101': 19 | self.pretrained_weights = torch.load("/home/cs4007/data/zyz/RPT-main/checkpoints/resnet101-63fe2227.pth", 20 | map_location='cpu') 21 | else: 22 | self.pretrained_weights = pretrained_weights 23 | 24 | _model = torchvision.models.resnet.resnet101(pretrained=False, 25 | replace_stride_with_dilation=replace_stride_with_dilation) 26 | self.backbone = nn.ModuleDict() 27 | for dic, m in _model.named_children(): 28 | self.backbone[dic] = m 29 | 30 | self.reduce1 = nn.Conv2d(1024, 512, kernel_size=1, bias=False) 31 | self.reduce2 = nn.Conv2d(2048, 512, kernel_size=1, bias=False) 32 | self.reduce1d = nn.Linear(in_features=1000, out_features=1, bias=True) 33 | 34 | self._init_weights() 35 | 36 | def forward(self, x): 37 | features = dict() 38 | x = self.backbone["conv1"](x) 39 | x = self.backbone["bn1"](x) 40 | x = self.backbone["relu"](x) 41 | 42 | x = self.backbone["maxpool"](x) 43 | x = self.backbone["layer1"](x) 44 | x = self.backbone["layer2"](x) 45 | x = self.backbone["layer3"](x) 46 | feature = self.reduce1(x) # (2, 512, 64, 64) 47 | 
x = self.backbone["layer4"](x) 48 | # feature map -> avgpool -> fc -> single value 49 | t = self.backbone["avgpool"](x) 50 | t = torch.flatten(t, 1) 51 | t = self.backbone["fc"](t) 52 | t = self.reduce1d(t) 53 | return (feature, t) 54 | 55 | def _init_weights(self): 56 | for m in self.modules(): 57 | if isinstance(m, nn.Conv2d): 58 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 59 | elif isinstance(m, nn.BatchNorm2d): 60 | nn.init.constant_(m.weight, 1) 61 | nn.init.constant_(m.bias, 0) 62 | 63 | if self.pretrained_weights is not None: 64 | keys = list(self.pretrained_weights.keys()) 65 | new_dic = self.state_dict() 66 | new_keys = list(new_dic.keys()) 67 | 68 | for i in range(len(keys)): 69 | if keys[i] in new_keys: 70 | new_dic[keys[i]] = self.pretrained_weights[keys[i]] 71 | 72 | self.load_state_dict(new_dic) 73 | 74 | 75 | class Res50Encoder(nn.Module): 76 | """ 77 | Resnet50 backbone from deeplabv3 78 | modify the 'downsample' component in layer2 and/or layer3 and/or layer4 as the vanilla Resnet 79 | """ 80 | 81 | def __init__(self, replace_stride_with_dilation=None, pretrained_weights='resnet50'): 82 | super().__init__() 83 | # using pretrained model's weights 84 | if pretrained_weights == 'deeplabv3': 85 | self.pretrained_weights = torch.load( 86 | "/home/cs4007/data/zyz/CDFSMIS/checkpoints/deeplabv3_resnet50_coco-cd0a2569.pth", map_location='cpu') # pretrained on COCO 87 | elif pretrained_weights == 'resnet50': 88 | self.pretrained_weights = torch.load("/home/cs4007/data/zyz/CDFSMIS/checkpoints/resnet50-19c8e357.pth", 89 | map_location='cpu') # pretrained on ImageNet 90 | else: 91 | self.pretrained_weights = pretrained_weights 92 | 93 | _model = torchvision.models.resnet.resnet50(pretrained=False, 94 | replace_stride_with_dilation=replace_stride_with_dilation) 95 | self.backbone = nn.ModuleDict() 96 | for dic, m in _model.named_children(): 97 | self.backbone[dic] = m 98 | 99 | self.reduce1 = nn.Conv2d(1024, 512, kernel_size=1, bias=False) 100 | self.reduce2 = nn.Conv2d(2048, 512, kernel_size=1, bias=False) 101 | self.reduce1d = nn.Linear(in_features=1000, out_features=1, bias=True) 102 | 103 | self._init_weights() 104 | 105 | def forward(self, x): 106 | x = self.backbone["conv1"](x) 107 | x = self.backbone["bn1"](x) 108 | x = self.backbone["relu"](x) 109 | 110 | x = self.backbone["maxpool"](x) 111 | x = self.backbone["layer1"](x) 112 | x = self.backbone["layer2"](x) 113 | x = self.backbone["layer3"](x) 114 | feature = self.reduce1(x) # (2, 512, 64, 64) 115 | x = self.backbone["layer4"](x) 116 | # feature map -> avgpool -> fc -> single value 117 | t = self.backbone["avgpool"](x) 118 | t = torch.flatten(t, 1) 119 | t = self.backbone["fc"](t) 120 | t = self.reduce1d(t) 121 | return (feature, t) 122 | 123 | def _init_weights(self): 124 | for m in self.modules(): 125 | if isinstance(m, nn.Conv2d): 126 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 127 | elif isinstance(m, nn.BatchNorm2d): 128 | nn.init.constant_(m.weight, 1) 129 | nn.init.constant_(m.bias, 0) 130 | 131 | if self.pretrained_weights is not None: 132 | keys = list(self.pretrained_weights.keys()) 133 | new_dic = self.state_dict() 134 | new_keys = list(new_dic.keys()) 135 | 136 | for i in range(len(keys)): 137 | if keys[i] in new_keys: 138 | new_dic[keys[i]] = self.pretrained_weights[keys[i]] 139 | 140 | self.load_state_dict(new_dic) 141 | -------------------------------------------------------------------------------- /train.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import os 3 | import random 4 | import logging 5 | import shutil 6 | 7 | import torch.nn as nn 8 | import torch.backends.cudnn as cudnn 9 | import torch.optim 10 | from torch.utils.data import DataLoader 11 | from torch.optim.lr_scheduler import MultiStepLR 12 | from models.fewshot import FewShotSeg 13 | from dataloaders.datasets import TrainDataset as TrainDataset 14 | from utils import * 15 | from config import ex 16 | from losses import * 17 | 18 | 19 | @ex.automain 20 | def main(_run, _config, _log): 21 | if _run.observers: 22 | # Set up source folder 23 | os.makedirs(f'{_run.observers[0].dir}/snapshots', exist_ok=True) 24 | for source_file, _ in _run.experiment_info['sources']: 25 | os.makedirs(os.path.dirname(f'{_run.observers[0].dir}/source/{source_file}'), 26 | exist_ok=True) 27 | _run.observers[0].save_file(source_file, f'source/{source_file}') 28 | shutil.rmtree(f'{_run.observers[0].basedir}/_sources') 29 | 30 | # Set up logger -> log to .txt 31 | file_handler = logging.FileHandler(os.path.join(f'{_run.observers[0].dir}', f'logger.log')) 32 | file_handler.setLevel('INFO') 33 | formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(name)s - %(message)s') 34 | file_handler.setFormatter(formatter) 35 | _log.handlers.append(file_handler) 36 | _log.info(f'Run "{_config["exp_str"]}" with ID "{_run.observers[0].dir[-1]}"') 37 | 38 | # Deterministic setting for reproduciablity. 39 | if _config['seed'] is not None: 40 | random.seed(_config['seed']) 41 | torch.manual_seed(_config['seed']) 42 | torch.cuda.manual_seed_all(_config['seed']) 43 | cudnn.deterministic = True 44 | 45 | # Enable cuDNN benchmark mode to select the fastest convolution algorithm. 
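# NOTE: benchmark mode lets cuDNN time several convolution algorithms for the fixed
# input size and cache the fastest one; this autotuning can introduce small run-to-run
# differences even though cudnn.deterministic is requested above.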
46 | cudnn.enabled = True 47 | cudnn.benchmark = True 48 | torch.cuda.set_device(device=_config['gpu_id']) 49 | torch.set_num_threads(1) 50 | 51 | _log.info(f'Create model...') 52 | model = FewShotSeg() 53 | model = model.cuda() 54 | model.train() 55 | 56 | _log.info(f'Set optimizer...') 57 | optimizer = torch.optim.SGD(model.parameters(), **_config['optim']) 58 | lr_milestones = [(ii + 1) * _config['max_iters_per_load'] for ii in 59 | range(_config['n_steps'] // _config['max_iters_per_load'] - 1)] 60 | scheduler = MultiStepLR(optimizer, milestones=lr_milestones, gamma=_config['lr_step_gamma']) 61 | 62 | my_weight = torch.FloatTensor([0.1, 1.0]).cuda() 63 | criterion = nn.NLLLoss(ignore_index=255, weight=my_weight) 64 | criterion_bd = BoundaryLoss() 65 | criterion_dice = DiceLoss() 66 | 67 | _log.info(f'Load data...') 68 | data_config = { 69 | 'data_dir': _config['path'][_config['dataset']]['data_dir'], 70 | 'dataset': _config['dataset'], 71 | 'n_shot': _config['n_shot'], 72 | 'n_way': _config['n_way'], 73 | 'n_query': _config['n_query'], 74 | 'n_sv': _config['n_sv'], 75 | 'max_iter': _config['max_iters_per_load'], 76 | 'eval_fold': _config['eval_fold'], 77 | 'min_size': _config['min_size'], 78 | 'max_slices': _config['max_slices'], 79 | 'test_label': _config['test_label'], 80 | 'exclude_label': _config['exclude_label'], 81 | 'use_gt': _config['use_gt'], 82 | } 83 | train_dataset = TrainDataset(data_config) 84 | train_loader = DataLoader(train_dataset, 85 | batch_size=_config['batch_size'], 86 | shuffle=True, 87 | num_workers=_config['num_workers'], 88 | pin_memory=True, 89 | drop_last=True) 90 | 91 | n_sub_epochs = _config['n_steps'] // _config['max_iters_per_load'] # number of times for reloading 92 | log_loss = {'total_loss': 0, 'query_loss': 0, 'align_loss': 0, 'thresh_loss': 0} 93 | 94 | i_iter = 0 95 | _log.info(f'Start training...') 96 | 97 | eta = 1. 98 | for sub_epoch in range(n_sub_epochs): 99 | _log.info(f'This is epoch "{sub_epoch}" of "{n_sub_epochs}" epochs.') 100 | 101 | for _, sample in enumerate(train_loader): 102 | # Prepare episode data. 103 | support_images = [[shot.float().cuda() for shot in way] 104 | for way in sample['support_images']] 105 | support_fg_mask = [[shot.float().cuda() for shot in way] 106 | for way in sample['support_fg_labels']] 107 | 108 | query_images = [query_image.float().cuda() for query_image in sample['query_images']] 109 | query_labels = torch.cat([query_label.long().cuda() for query_label in sample['query_labels']], dim=0) 110 | 111 | # Compute outputs and losses. 112 | query_pred, periphery_loss, align_loss, mse_loss, qry_loss = model(support_images, support_fg_mask, 113 | query_images, query_labels, train=True) 114 | 115 | query_loss = criterion(torch.log(torch.clamp(query_pred, torch.finfo(torch.float32).eps, 116 | 1 - torch.finfo(torch.float32).eps)), query_labels) 117 | 118 | # bd_loss = criterion_bd(query_pred, query_labels) 119 | # dice_loss = criterion_dice(query_pred, query_labels) 120 | 121 | loss = query_loss + 0.1 * periphery_loss + align_loss + 0.1 * mse_loss + qry_loss 122 | 123 | # Compute gradient and do SGD step. 
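# NOTE: setting every parameter's .grad to None below is equivalent to calling
# optimizer.zero_grad(set_to_none=True); it skips the memset of zeroed gradient
# buffers that a plain optimizer.zero_grad() would perform.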
124 | for param in model.parameters(): 125 | param.grad = None 126 | 127 | loss.backward() 128 | optimizer.step() 129 | scheduler.step() 130 | 131 | # Log loss 132 | query_loss = query_loss.detach().data.cpu().numpy() 133 | # align_loss = align_loss.detach().data.cpu().numpy() 134 | 135 | _run.log_scalar('total_loss', loss.item()) 136 | _run.log_scalar('query_loss', query_loss) 137 | 138 | log_loss['total_loss'] += loss.item() 139 | log_loss['query_loss'] += query_loss 140 | 141 | # Print loss and take snapshots. 142 | if (i_iter + 1) % _config['print_interval'] == 0: 143 | total_loss = log_loss['total_loss'] / _config['print_interval'] 144 | query_loss = log_loss['query_loss'] / _config['print_interval'] 145 | 146 | log_loss['total_loss'] = 0 147 | log_loss['query_loss'] = 0 148 | 149 | _log.info(f'step {i_iter + 1}: total_loss: {total_loss}, query_loss: {query_loss},' 150 | ) 151 | 152 | if (i_iter + 1) % _config['save_snapshot_every'] == 0: 153 | _log.info('###### Taking snapshot ######') 154 | torch.save(model.state_dict(), 155 | os.path.join(f'{_run.observers[0].dir}/snapshots', f'{i_iter + 1}.pth')) 156 | 157 | i_iter += 1 158 | 159 | eta = eta - 0.01 160 | 161 | _log.info('End of training.') 162 | return 1 163 | -------------------------------------------------------------------------------- /data/SABS/resampling_and_roi.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import glob 4 | import SimpleITK as sitk 5 | import sys 6 | import niftiio as nio 7 | 8 | IMG_FOLDER = "./tmp_normalized/" 9 | SEG_FOLDER = IMG_FOLDER 10 | imgs = glob.glob(IMG_FOLDER + "/image_*.nii.gz") 11 | imgs = [ fid for fid in sorted(imgs) ] 12 | segs = [ fid for fid in sorted(glob.glob(SEG_FOLDER + "/label_*.nii.gz")) ] 13 | 14 | pids = [pid.split("_")[-1].split(".")[0] for pid in imgs] 15 | 16 | # helper functions copy pasted 17 | def resample_by_res(mov_img_obj, new_spacing, interpolator = sitk.sitkLinear, logging = True): 18 | resample = sitk.ResampleImageFilter() 19 | resample.SetInterpolator(interpolator) 20 | resample.SetOutputDirection(mov_img_obj.GetDirection()) 21 | resample.SetOutputOrigin(mov_img_obj.GetOrigin()) 22 | mov_spacing = mov_img_obj.GetSpacing() 23 | 24 | resample.SetOutputSpacing(new_spacing) 25 | RES_COE = np.array(mov_spacing) * 1.0 / np.array(new_spacing) 26 | new_size = np.array(mov_img_obj.GetSize()) * RES_COE 27 | 28 | resample.SetSize( [int(sz+1) for sz in new_size] ) 29 | if logging: 30 | print("Spacing: {} -> {}".format(mov_spacing, new_spacing)) 31 | print("Size {} -> {}".format( mov_img_obj.GetSize(), new_size )) 32 | 33 | return resample.Execute(mov_img_obj) 34 | 35 | def resample_lb_by_res(mov_lb_obj, new_spacing, interpolator = sitk.sitkLinear, ref_img = None, logging = True): 36 | src_mat = sitk.GetArrayFromImage(mov_lb_obj) 37 | lbvs = np.unique(src_mat) 38 | if logging: 39 | print("Label values: {}".format(lbvs)) 40 | for idx, lbv in enumerate(lbvs): 41 | _src_curr_mat = np.float32(src_mat == lbv) 42 | _src_curr_obj = sitk.GetImageFromArray(_src_curr_mat) 43 | _src_curr_obj.CopyInformation(mov_lb_obj) 44 | _tar_curr_obj = resample_by_res( _src_curr_obj, new_spacing, interpolator, logging ) 45 | _tar_curr_mat = np.rint(sitk.GetArrayFromImage(_tar_curr_obj)) * lbv 46 | if idx == 0: 47 | out_vol = _tar_curr_mat 48 | else: 49 | out_vol[_tar_curr_mat == lbv] = lbv 50 | out_obj = sitk.GetImageFromArray(out_vol) 51 | out_obj.SetSpacing( _tar_curr_obj.GetSpacing() ) 52 | if ref_img != None: 53 | 
out_obj.CopyInformation(ref_img) 54 | return out_obj 55 | 56 | ## Then crop ROI 57 | def get_label_center(label): 58 | nnz = np.sum(label > 1e-5) 59 | return np.int32(np.rint(np.sum(np.nonzero(label), axis = 1) * 1.0 / nnz)) 60 | 61 | def image_crop(ori_vol, crop_size, referece_ctr_idx, padval = 0., only_2d = True): 62 | """ crop a 3d matrix given the index of the new volume on the original volume 63 | Args: 64 | refernce_ctr_idx: the center of the new volume on the original volume (in indices) 65 | only_2d: only do cropping on first two dimensions 66 | """ 67 | _expand_cropsize = [x + 1 for x in crop_size] # to deal with boundary case 68 | if only_2d: 69 | assert len(crop_size) == 2, "Actual len {}".format(len(crop_size)) 70 | assert len(referece_ctr_idx) == 2, "Actual len {}".format(len(referece_ctr_idx)) 71 | _expand_cropsize.append(ori_vol.shape[-1]) 72 | 73 | image_patch = np.ones(tuple(_expand_cropsize)) * padval 74 | 75 | half_size = tuple( [int(x * 1.0 / 2) for x in _expand_cropsize] ) 76 | _min_idx = [0,0,0] 77 | _max_idx = list(ori_vol.shape) 78 | 79 | # bias of actual cropped size to the beginning and the end of this volume 80 | _bias_start = [0,0,0] 81 | _bias_end = [0,0,0] 82 | 83 | for dim,hsize in enumerate(half_size): 84 | if dim == 2 and only_2d: 85 | break 86 | 87 | _bias_start[dim] = np.min([hsize, referece_ctr_idx[dim]]) 88 | _bias_end[dim] = np.min([hsize, ori_vol.shape[dim] - referece_ctr_idx[dim]]) 89 | 90 | _min_idx[dim] = referece_ctr_idx[dim] - _bias_start[dim] 91 | _max_idx[dim] = referece_ctr_idx[dim] + _bias_end[dim] 92 | 93 | if only_2d: 94 | image_patch[ half_size[0] - _bias_start[0]: half_size[0] +_bias_end[0], \ 95 | half_size[1] - _bias_start[1]: half_size[1] +_bias_end[1], ... ] = \ 96 | ori_vol[ referece_ctr_idx[0] - _bias_start[0]: referece_ctr_idx[0] +_bias_end[0], \ 97 | referece_ctr_idx[1] - _bias_start[1]: referece_ctr_idx[1] +_bias_end[1], ... 
] 98 | 99 | image_patch = image_patch[ 0: crop_size[0], 0: crop_size[1], : ] 100 | # then goes back to original volume 101 | else: 102 | image_patch[ half_size[0] - _bias_start[0]: half_size[0] +_bias_end[0], \ 103 | half_size[1] - _bias_start[1]: half_size[1] +_bias_end[1], \ 104 | half_size[2] - _bias_start[2]: half_size[2] +_bias_end[2] ] = \ 105 | ori_vol[ referece_ctr_idx[0] - _bias_start[0]: referece_ctr_idx[0] +_bias_end[0], \ 106 | referece_ctr_idx[1] - _bias_start[1]: referece_ctr_idx[1] +_bias_end[1], \ 107 | referece_ctr_idx[2] - _bias_start[2]: referece_ctr_idx[2] +_bias_end[2] ] 108 | 109 | image_patch = image_patch[ 0: crop_size[0], 0: crop_size[1], 0: crop_size[2] ] 110 | return image_patch 111 | 112 | 113 | 114 | def copy_spacing_ori(src, dst): 115 | dst.SetSpacing(src.GetSpacing()) 116 | dst.SetOrigin(src.GetOrigin()) 117 | dst.SetDirection(src.GetDirection()) 118 | return dst 119 | 120 | import copy 121 | OUT_FOLDER = "./sabs_CT_normalized" 122 | scan_dir = OUT_FOLDER 123 | os.makedirs(scan_dir, exist_ok = True) 124 | BD_BIAS = 32 # cut irrelavent empty boundary to make roi stands out 125 | 126 | SPA_FAC = (512 - 2 * BD_BIAS) / 256 # spacing factor 127 | 128 | for img_fid, seg_fid, pid in zip(imgs, segs, pids): 129 | 130 | lb_n = nio.read_nii_bysitk(seg_fid) 131 | 132 | img_obj = sitk.ReadImage( img_fid ) 133 | seg_obj = sitk.ReadImage( seg_fid ) 134 | 135 | ## image 136 | array = sitk.GetArrayFromImage(img_obj) 137 | # cropping 138 | array = array[:, BD_BIAS: -BD_BIAS, BD_BIAS: -BD_BIAS] 139 | cropped_img_o = sitk.GetImageFromArray(array) 140 | cropped_img_o = copy_spacing_ori(img_obj, cropped_img_o) 141 | 142 | # resampling 143 | img_spa_ori = img_obj.GetSpacing() 144 | res_img_o = resample_by_res(cropped_img_o, [img_spa_ori[0] * SPA_FAC, img_spa_ori[1] * SPA_FAC, img_spa_ori[-1]], interpolator = sitk.sitkLinear, 145 | logging = True) 146 | 147 | ## label 148 | lb_arr = sitk.GetArrayFromImage(seg_obj) 149 | 150 | # cropping 151 | lb_arr = lb_arr[:,BD_BIAS: -BD_BIAS, BD_BIAS: -BD_BIAS] 152 | cropped_lb_o = sitk.GetImageFromArray(lb_arr) 153 | cropped_lb_o = copy_spacing_ori(seg_obj, cropped_lb_o) 154 | 155 | lb_spa_ori = seg_obj.GetSpacing() 156 | 157 | # resampling 158 | res_lb_o = resample_lb_by_res(cropped_lb_o, [lb_spa_ori[0] * SPA_FAC, lb_spa_ori[1] * SPA_FAC, lb_spa_ori[-1] ], interpolator = sitk.sitkLinear, 159 | ref_img = res_img_o, logging = True) 160 | 161 | 162 | out_img_fid = os.path.join( scan_dir, f'image_{pid}.nii.gz' ) 163 | out_lb_fid = os.path.join( scan_dir, f'label_{pid}.nii.gz' ) 164 | 165 | # then save 166 | sitk.WriteImage(res_img_o, out_img_fid, True) 167 | sitk.WriteImage(res_lb_o, out_lb_fid, True) 168 | print("{} has been saved".format(out_img_fid)) 169 | print("{} has been saved".format(out_lb_fid)) 170 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ 3 | For evaluation 4 | """ 5 | import shutil 6 | import SimpleITK as sitk 7 | import torch.backends.cudnn as cudnn 8 | import torch.optim 9 | from torch.utils.data import DataLoader 10 | from models.fewshot import FewShotSeg 11 | from dataloaders.datasets import TestDataset 12 | from dataloaders.dataset_specifics import * 13 | from utils import * 14 | from config import ex 15 | 16 | 17 | @ex.automain 18 | def main(_run, _config, _log): 19 | if _run.observers: 20 | os.makedirs(f'{_run.observers[0].dir}/interm_preds', exist_ok=True) 21 | for 
source_file, _ in _run.experiment_info['sources']: 22 | os.makedirs(os.path.dirname(f'{_run.observers[0].dir}/source/{source_file}'), 23 | exist_ok=True) 24 | _run.observers[0].save_file(source_file, f'source/{source_file}') 25 | shutil.rmtree(f'{_run.observers[0].basedir}/_sources') 26 | 27 | # Set up logger -> log to .txt 28 | file_handler = logging.FileHandler(os.path.join(f'{_run.observers[0].dir}', f'logger.log')) 29 | file_handler.setLevel('INFO') 30 | formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(name)s - %(message)s') 31 | file_handler.setFormatter(formatter) 32 | _log.handlers.append(file_handler) 33 | _log.info(f'Run "{_config["exp_str"]}" with ID "{_run.observers[0].dir[-1]}"') 34 | 35 | # Deterministic setting for reproduciablity. 36 | if _config['seed'] is not None: 37 | random.seed(_config['seed']) 38 | torch.manual_seed(_config['seed']) 39 | torch.cuda.manual_seed_all(_config['seed']) 40 | cudnn.deterministic = True 41 | 42 | # Enable cuDNN benchmark mode to select the fastest convolution algorithm. 43 | cudnn.enabled = True 44 | cudnn.benchmark = True 45 | torch.cuda.set_device(device=_config['gpu_id']) 46 | torch.set_num_threads(1) 47 | 48 | _log.info(f'Create model...') 49 | model = FewShotSeg() 50 | model.cuda() 51 | model.load_state_dict(torch.load(_config['reload_model_path'], map_location='cpu')) 52 | 53 | _log.info(f'Load data...') 54 | data_config = { 55 | 'data_dir': _config['path'][_config['dataset']]['data_dir'], 56 | 'dataset': _config['dataset'], 57 | 'n_shot': _config['n_shot'], 58 | 'n_way': _config['n_way'], 59 | 'n_query': _config['n_query'], 60 | 'n_sv': _config['n_sv'], 61 | 'max_iter': _config['max_iters_per_load'], 62 | 'eval_fold': _config['eval_fold'], 63 | 'min_size': _config['min_size'], 64 | 'max_slices': _config['max_slices'], 65 | 'supp_idx': _config['supp_idx'], 66 | } 67 | test_dataset = TestDataset(data_config) 68 | test_loader = DataLoader(test_dataset, 69 | batch_size=_config['batch_size'], 70 | shuffle=False, 71 | num_workers=_config['num_workers'], 72 | pin_memory=True, 73 | drop_last=True) 74 | 75 | # Get unique labels (classes). 76 | labels = get_label_names(_config['dataset']) 77 | 78 | # Loop over classes. 79 | class_dice = {} 80 | class_iou = {} 81 | 82 | _log.info(f'Starting validation...') 83 | for label_val, label_name in labels.items(): 84 | 85 | # Skip BG class. 86 | if label_name == 'BG': 87 | continue 88 | elif (not np.intersect1d([label_val], _config['test_label'])): 89 | continue 90 | 91 | _log.info(f'Test Class: {label_name}') 92 | 93 | # Get support sample + mask for current class. 94 | support_sample = test_dataset.getSupport(label=label_val, all_slices=False, N=_config['n_part']) 95 | 96 | test_dataset.label = label_val 97 | 98 | # Test. 99 | with torch.no_grad(): 100 | model.eval() 101 | 102 | # Unpack support data. 103 | support_image = [support_sample['image'][[i]].float().cuda() for i in 104 | range(support_sample['image'].shape[0])] # n_shot x 3 x H x W, support_image is a list {3X(1, 3, 256, 256)} 105 | support_fg_mask = [support_sample['label'][[i]].float().cuda() for i in 106 | range(support_sample['image'].shape[0])] # n_shot x H x W 107 | 108 | # Loop through query volumes. 109 | scores = Scores() 110 | for i, sample in enumerate(test_loader): # this "for" loops 4 times 111 | 112 | # Unpack query data. 
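# NOTE: each episode evaluates one query volume; sample['image'] is 1 x C x 3 x H x W
# (C = number of slices). Below, the volume is split into n_part chunks along the slice
# axis and every chunk is segmented using the support slice taken from the corresponding
# part of the support scan.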
113 | query_image = [sample['image'][i].float().cuda() for i in 114 | range(sample['image'].shape[0])] # [C x 3 x H x W] query_image is list {(C x 3 x H x W)} 115 | query_label = sample['label'].long() # C x H x W 116 | query_id = sample['id'][0].split('image_')[1][:-len('.nii.gz')] 117 | 118 | # Compute output. 119 | # Match support slice and query sub-chunck. 120 | query_pred = torch.zeros(query_label.shape[-3:]) 121 | C_q = sample['image'].shape[1] # slice number of query img 122 | 123 | idx_ = np.linspace(0, C_q, _config['n_part'] + 1).astype('int') 124 | for sub_chunck in range(_config['n_part']): # n_part = 3 125 | support_image_s = [support_image[sub_chunck]] # 1 x 3 x H x W 126 | support_fg_mask_s = [support_fg_mask[sub_chunck]] # 1 x H x W 127 | query_image_s = query_image[0][idx_[sub_chunck]:idx_[sub_chunck + 1]] # C' x 3 x H x W 128 | query_pred_s = [] 129 | for i in range(query_image_s.shape[0]): 130 | _pred_s, _, _, _, _ = model([support_image_s], [support_fg_mask_s], [query_image_s[[i]]], _, train=False) # 1 x 2 x H x W 131 | query_pred_s.append(_pred_s) 132 | query_pred_s = torch.cat(query_pred_s, dim=0) 133 | query_pred_s = query_pred_s.argmax(dim=1).cpu() # C x H x W 134 | query_pred[idx_[sub_chunck]:idx_[sub_chunck + 1]] = query_pred_s 135 | 136 | # Record scores. 137 | scores.record(query_pred, query_label) 138 | 139 | # Log. 140 | _log.info( 141 | f'Tested query volume: {sample["id"][0][len(_config["path"][_config["dataset"]]["data_dir"]):]}.') 142 | _log.info(f'Dice score: {scores.patient_dice[-1].item()}') 143 | 144 | # Save predictions. 145 | file_name = os.path.join(f'{_run.observers[0].dir}/interm_preds', 146 | f'prediction_{query_id}_{label_name}.nii.gz') 147 | itk_pred = sitk.GetImageFromArray(query_pred) 148 | sitk.WriteImage(itk_pred, file_name, True) 149 | _log.info(f'{query_id} has been saved. 
') 150 | 151 | # Log class-wise results 152 | class_dice[label_name] = torch.tensor(scores.patient_dice).mean().item() 153 | class_iou[label_name] = torch.tensor(scores.patient_iou).mean().item() 154 | _log.info(f'Test Class: {label_name}') 155 | _log.info(f'Mean class IoU: {class_iou[label_name]}') 156 | _log.info(f'Mean class Dice: {class_dice[label_name]}') 157 | 158 | _log.info(f'Final results...') 159 | _log.info(f'Mean IoU: {class_iou}') 160 | _log.info(f'Mean Dice: {class_dice}') 161 | 162 | def dict_Avg(Dict): 163 | L = len(Dict) # number of key-value pairs in the dict 164 | S = sum(Dict.values()) # sum of all values in the dict 165 | A = S / L 166 | return A 167 | 168 | value = dict_Avg(class_dice) 169 | with open('results.txt', 'w') as file: 170 | file.write(str(value)) 171 | 172 | _log.info(f'Whole mean Dice: {dict_Avg(class_dice)}') 173 | _log.info(f'End of validation.') 174 | return 1 175 | -------------------------------------------------------------------------------- /data/CHAOST2/class_slice_index_gen.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Generate class-slice indexing table for experiments\n", 8 | "\n", 9 | "\n", 10 | "### Overview\n", 11 | "\n", 12 | "This is for setting up experiments that simulate few-shot image segmentation scenarios\n", 13 | "\n", 14 | "Input: pre-processed images and their ground-truth labels\n", 15 | "\n", 16 | "Output: a `json` file for class-slice indexing" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 8, 22 | "metadata": {}, 23 | "outputs": [ 24 | { 25 | "name": "stdout", 26 | "output_type": "stream", 27 | "text": [ 28 | "Once deleted, variables cannot be recovered. Proceed (y/[n])? y\n", 29 | "The autoreload extension is already loaded. 
To reload it, use:\n", 30 | " %reload_ext autoreload\n" 31 | ] 32 | } 33 | ], 34 | "source": [ 35 | "%reset\n", 36 | "%load_ext autoreload\n", 37 | "%autoreload 2\n", 38 | "import numpy as np\n", 39 | "import os\n", 40 | "import glob\n", 41 | "import SimpleITK as sitk\n", 42 | "import sys\n", 43 | "import json\n", 44 | "sys.path.insert(0, '../../dataloaders/')\n", 45 | "import niftiio as nio" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 9, 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "IMG_BNAME=\"./chaos_MR_T2_normalized/image_*.nii.gz\"\n", 55 | "SEG_BNAME=\"./chaos_MR_T2_normalized/label_*.nii.gz\"" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 10, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "imgs = glob.glob(IMG_BNAME)\n", 65 | "segs = glob.glob(SEG_BNAME)\n", 66 | "imgs = [ fid for fid in sorted(imgs, key = lambda x: int(x.split(\"_\")[-1].split(\".nii.gz\")[0]) ) ]\n", 67 | "segs = [ fid for fid in sorted(segs, key = lambda x: int(x.split(\"_\")[-1].split(\".nii.gz\")[0]) ) ]\n" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 11, 73 | "metadata": {}, 74 | "outputs": [ 75 | { 76 | "data": { 77 | "text/plain": [ 78 | "['./chaos_MR_T2_normalized/image_1.nii.gz',\n", 79 | " './chaos_MR_T2_normalized/image_2.nii.gz',\n", 80 | " './chaos_MR_T2_normalized/image_3.nii.gz',\n", 81 | " './chaos_MR_T2_normalized/image_5.nii.gz',\n", 82 | " './chaos_MR_T2_normalized/image_8.nii.gz',\n", 83 | " './chaos_MR_T2_normalized/image_10.nii.gz',\n", 84 | " './chaos_MR_T2_normalized/image_13.nii.gz',\n", 85 | " './chaos_MR_T2_normalized/image_15.nii.gz',\n", 86 | " './chaos_MR_T2_normalized/image_19.nii.gz',\n", 87 | " './chaos_MR_T2_normalized/image_20.nii.gz',\n", 88 | " './chaos_MR_T2_normalized/image_21.nii.gz',\n", 89 | " './chaos_MR_T2_normalized/image_22.nii.gz',\n", 90 | " './chaos_MR_T2_normalized/image_31.nii.gz',\n", 91 | " './chaos_MR_T2_normalized/image_32.nii.gz',\n", 92 | " './chaos_MR_T2_normalized/image_33.nii.gz',\n", 93 | " './chaos_MR_T2_normalized/image_34.nii.gz',\n", 94 | " './chaos_MR_T2_normalized/image_36.nii.gz',\n", 95 | " './chaos_MR_T2_normalized/image_37.nii.gz',\n", 96 | " './chaos_MR_T2_normalized/image_38.nii.gz',\n", 97 | " './chaos_MR_T2_normalized/image_39.nii.gz']" 98 | ] 99 | }, 100 | "execution_count": 11, 101 | "metadata": {}, 102 | "output_type": "execute_result" 103 | } 104 | ], 105 | "source": [ 106 | "imgs" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": 12, 112 | "metadata": {}, 113 | "outputs": [ 114 | { 115 | "data": { 116 | "text/plain": [ 117 | "['./chaos_MR_T2_normalized/label_1.nii.gz',\n", 118 | " './chaos_MR_T2_normalized/label_2.nii.gz',\n", 119 | " './chaos_MR_T2_normalized/label_3.nii.gz',\n", 120 | " './chaos_MR_T2_normalized/label_5.nii.gz',\n", 121 | " './chaos_MR_T2_normalized/label_8.nii.gz',\n", 122 | " './chaos_MR_T2_normalized/label_10.nii.gz',\n", 123 | " './chaos_MR_T2_normalized/label_13.nii.gz',\n", 124 | " './chaos_MR_T2_normalized/label_15.nii.gz',\n", 125 | " './chaos_MR_T2_normalized/label_19.nii.gz',\n", 126 | " './chaos_MR_T2_normalized/label_20.nii.gz',\n", 127 | " './chaos_MR_T2_normalized/label_21.nii.gz',\n", 128 | " './chaos_MR_T2_normalized/label_22.nii.gz',\n", 129 | " './chaos_MR_T2_normalized/label_31.nii.gz',\n", 130 | " './chaos_MR_T2_normalized/label_32.nii.gz',\n", 131 | " './chaos_MR_T2_normalized/label_33.nii.gz',\n", 132 | " './chaos_MR_T2_normalized/label_34.nii.gz',\n", 133 
| " './chaos_MR_T2_normalized/label_36.nii.gz',\n", 134 | " './chaos_MR_T2_normalized/label_37.nii.gz',\n", 135 | " './chaos_MR_T2_normalized/label_38.nii.gz',\n", 136 | " './chaos_MR_T2_normalized/label_39.nii.gz']" 137 | ] 138 | }, 139 | "execution_count": 12, 140 | "metadata": {}, 141 | "output_type": "execute_result" 142 | } 143 | ], 144 | "source": [ 145 | "segs" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": 13, 151 | "metadata": {}, 152 | "outputs": [ 153 | { 154 | "name": "stdout", 155 | "output_type": "stream", 156 | "text": [ 157 | "pid 1 finished!\n", 158 | "pid 2 finished!\n", 159 | "pid 3 finished!\n", 160 | "pid 5 finished!\n", 161 | "pid 8 finished!\n", 162 | "pid 10 finished!\n", 163 | "pid 13 finished!\n", 164 | "pid 15 finished!\n", 165 | "pid 19 finished!\n", 166 | "pid 20 finished!\n", 167 | "pid 21 finished!\n", 168 | "pid 22 finished!\n", 169 | "pid 31 finished!\n", 170 | "pid 32 finished!\n", 171 | "pid 33 finished!\n", 172 | "pid 34 finished!\n", 173 | "pid 36 finished!\n", 174 | "pid 37 finished!\n", 175 | "pid 38 finished!\n", 176 | "pid 39 finished!\n" 177 | ] 178 | } 179 | ], 180 | "source": [ 181 | "classmap = {}\n", 182 | "LABEL_NAME = [\"BG\", \"LIVER\", \"RK\", \"LK\", \"SPLEEN\"] \n", 183 | "\n", 184 | "\n", 185 | "MIN_TP = 100 # minimum number of positive label pixels to be recorded. Use >100 when training with manual annotations for more stable training\n", 186 | "\n", 187 | "fid = f'./chaos_MR_T2_normalized/classmap_{MIN_TP}.json' # name of the output file. \n", 188 | "for _lb in LABEL_NAME:\n", 189 | " classmap[_lb] = {}\n", 190 | " for _sid in segs:\n", 191 | " pid = _sid.split(\"_\")[-1].split(\".nii.gz\")[0]\n", 192 | " classmap[_lb][pid] = []\n", 193 | "\n", 194 | "for seg in segs:\n", 195 | " pid = seg.split(\"_\")[-1].split(\".nii.gz\")[0]\n", 196 | " lb_vol = nio.read_nii_bysitk(seg)\n", 197 | " n_slice = lb_vol.shape[0]\n", 198 | " for slc in range(n_slice):\n", 199 | " for cls in range(len(LABEL_NAME)):\n", 200 | " if cls in lb_vol[slc, ...]:\n", 201 | " if np.sum( lb_vol[slc, ...]) >= MIN_TP:\n", 202 | " classmap[LABEL_NAME[cls]][str(pid)].append(slc)\n", 203 | " print(f'pid {str(pid)} finished!')\n", 204 | " \n", 205 | "with open(fid, 'w') as fopen:\n", 206 | " json.dump(classmap, fopen)\n", 207 | " fopen.close() \n", 208 | " " 209 | ] 210 | }, 211 | { 212 | "cell_type": "code", 213 | "execution_count": 14, 214 | "metadata": {}, 215 | "outputs": [], 216 | "source": [ 217 | "with open(fid, 'w') as fopen:\n", 218 | " json.dump(classmap, fopen)\n", 219 | " fopen.close()" 220 | ] 221 | }, 222 | { 223 | "cell_type": "code", 224 | "execution_count": null, 225 | "metadata": {}, 226 | "outputs": [], 227 | "source": [] 228 | } 229 | ], 230 | "metadata": { 231 | "kernelspec": { 232 | "display_name": "Python 3 (ipykernel)", 233 | "language": "python", 234 | "name": "python3" 235 | }, 236 | "language_info": { 237 | "codemirror_mode": { 238 | "name": "ipython", 239 | "version": 3 240 | }, 241 | "file_extension": ".py", 242 | "mimetype": "text/x-python", 243 | "name": "python", 244 | "nbconvert_exporter": "python", 245 | "pygments_lexer": "ipython3", 246 | "version": "3.8.12" 247 | } 248 | }, 249 | "nbformat": 4, 250 | "nbformat_minor": 2 251 | } 252 | -------------------------------------------------------------------------------- /data/supervoxels/felzenszwalb_3d_cy.pyx: -------------------------------------------------------------------------------- 1 | #cython: cdivision=True 2 | #cython: boundscheck=False 3 | #cython: 
nonecheck=False 4 | #cython: wraparound=False 5 | import numpy as np 6 | from scipy import ndimage as ndi 7 | 8 | cimport numpy as cnp 9 | from _ccomp cimport find_root, join_trees 10 | 11 | from skimage.util import img_as_float64 12 | from skimage._shared.utils import warn 13 | 14 | cnp.import_array() 15 | 16 | 17 | def felzenszwalb_cython_3d(image, double scale=1, sigma=0.8, Py_ssize_t min_size=20, spacing=(1,1,1)): 18 | """ 19 | Code modified from: https://scikit-image.org/docs/dev/api/skimage.segmentation.html#skimage.segmentation.felzenszwalb 20 | 21 | Felzenszwalb's efficient graph based segmentation for 22 | single or multiple channels. 23 | 24 | Produces an oversegmentation of a single or multi-channel image 25 | using a fast, minimum spanning tree based clustering on the image grid. 26 | The number of produced segments as well as their size can only be 27 | controlled indirectly through ``scale``. Segment size within an image can 28 | vary greatly depending on local contrast. 29 | 30 | Parameters 31 | ---------- 32 | image : (N, M, C) ndarray 33 | Input image. 34 | scale : float, optional (default 1) 35 | Sets the obervation level. Higher means larger clusters. 36 | sigma : float, optional (default 0.8) 37 | Width of Gaussian smoothing kernel used in preprocessing. 38 | Larger sigma gives smother segment boundaries. 39 | min_size : int, optional (default 20) 40 | Minimum component size. Enforced using postprocessing. 41 | 42 | Returns 43 | ------- 44 | segment_mask : (N, M) ndarray 45 | Integer mask indicating segment labels. 46 | """ 47 | 48 | 49 | image = img_as_float64(image) 50 | dtype = image.dtype 51 | 52 | # rescale scale to behave like in reference implementation 53 | scale = float(scale) / 255. 54 | 55 | spacing = np.ascontiguousarray(spacing, dtype=dtype) 56 | sigma = np.array([sigma, sigma, sigma], dtype=dtype) 57 | sigma /= spacing.astype(dtype) 58 | 59 | image = ndi.gaussian_filter(image, sigma=sigma) 60 | height, width, depth = image.shape # depth, height, width! 61 | image = image[..., None] 62 | 63 | # assuming spacing is equal in xy dir. 
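# NOTE: s below is the anisotropy ratio (assuming spacing[0] is the through-plane and
# spacing[1] the in-plane voxel spacing); the w* factors rescale the squared intensity
# differences so that edges spanning longer physical distances are weighted consistently
# with in-plane edges: w1 for x/y/xy, w2 for pure z, w3 for the zx/zy diagonals, and
# w4 for the full 3D zxy diagonals.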
64 | s = spacing[0]/spacing[1] 65 | w1 = 1.0 # x, y, xy 66 | w2 = s**2 # z 67 | w3 = (np.sqrt(1+s**2)/np.sqrt(2))**2 # zx, zy 68 | w4 = (np.sqrt(2 + s**2)/np.sqrt(3))**2 # zxy 69 | 70 | 71 | cost1 = np.sqrt(w1 * np.sum((image[:, 1:, :] - image[:, :width-1, :])**2, axis=-1)) # x 72 | cost2 = np.sqrt(w1 * np.sum((image[:, 1:, 1:] - image[:, :width-1, :depth-1])**2, axis=-1)) # xy 73 | cost3 = np.sqrt(w1 * np.sum((image[:, :, 1:] - image[:, :, :depth-1])**2, axis=-1)) # y 74 | cost7 = np.sqrt(w1 * np.sum((image[:, 1:, :depth-1] - image[:, :width-1, 1:])**2, axis=-1)) # xy 75 | cost9 = np.sqrt(w3 * np.sum((image[1:, 1:, :] - image[:height-1, :width-1, :])**2, axis=-1)) # zx 76 | cost10 = np.sqrt(w4 * np.sum((image[1:, 1:, 1:] - image[:height-1, :width-1, :depth-1])**2, axis=-1)) # zxy 77 | cost11 = np.sqrt(w3 * np.sum((image[1:, :, 1:] - image[:height-1, :, :depth-1])**2, axis=-1)) # zy 78 | cost12 = np.sqrt(w3 * np.sum((image[1:, :width-1, :] - image[:height-1, 1:, :])**2, axis=-1)) # zx 79 | cost13 = np.sqrt(w4 * np.sum((image[1:, :width-1, :depth-1] - image[:height-1, 1:, 1:])**2, axis=-1)) # zxy 80 | cost14 = np.sqrt(w3 * np.sum((image[1:, :, :depth-1] - image[:height-1, :, 1:])**2, axis=-1)) # zy 81 | cost15 = np.sqrt(w4 * np.sum((image[1:, 1:, :depth-1] - image[:height-1, :width-1, 1:])**2, axis=-1)) # zxy 82 | cost16 = np.sqrt(w4 * np.sum((image[1:, :width-1, 1:] - image[:height-1, 1:, :depth-1])**2, axis=-1)) # zxy 83 | cost25 = np.sqrt(w2 * np.sum((image[1:, :, :] - image[:height-1, :, :])**2, axis=-1)) # z 84 | 85 | 86 | cdef cnp.ndarray[cnp.float_t, ndim=1] costs = np.hstack([cost1.ravel(), cost2.ravel(), cost3.ravel(), cost7.ravel(), cost9.ravel(), cost10.ravel(), cost11.ravel(), cost12.ravel(), cost13.ravel(), cost14.ravel(), cost15.ravel(), cost16.ravel(), cost25.ravel()]).astype(float) 87 | 88 | # compute edges between pixels: 89 | cdef cnp.ndarray[cnp.intp_t, ndim=3] segments \ 90 | = np.arange(width * height * depth, dtype=np.intp).reshape(height, width, depth) 91 | 92 | 93 | edges1 = np.c_[segments[:, 1:, :].ravel(), segments[:, :width-1, :].ravel()] 94 | edges2 = np.c_[segments[:, 1:, 1:].ravel(), segments[:, :width-1, :depth-1].ravel()] 95 | edges3 = np.c_[segments[:, :, 1:].ravel(), segments[:, :, :depth-1].ravel()] 96 | edges7 = np.c_[segments[:, 1:, :depth-1].ravel(), segments[:, :width-1, 1:].ravel()] 97 | edges9 = np.c_[segments[1:, 1:, :].ravel(), segments[:height-1, :width-1, :].ravel()] 98 | edges10 = np.c_[segments[1:, 1:, 1:].ravel(), segments[:height-1, :width-1, :depth-1].ravel()] 99 | edges11 = np.c_[segments[1:, :, 1:].ravel(), segments[:height-1, :, :depth-1].ravel()] 100 | edges12 = np.c_[segments[1:, :width-1, :].ravel(), segments[:height-1, 1:, :].ravel()] 101 | edges13 = np.c_[segments[1:, :width-1, :depth-1].ravel(), segments[:height-1, 1:, 1:].ravel()] 102 | edges14 = np.c_[segments[1:, :, :depth-1].ravel(), segments[:height-1, :, 1:].ravel()] 103 | edges15 = np.c_[segments[1:, 1:, :depth-1].ravel(), segments[:height-1, :width-1, 1:].ravel()] 104 | edges16 = np.c_[segments[1:, :width-1, 1:].ravel(), segments[:height-1, 1:, :depth-1].ravel()] 105 | edges25 = np.c_[segments[1:, :, :].ravel(), segments[:height-1, :, :].ravel()] 106 | 107 | cdef cnp.ndarray[cnp.intp_t, ndim=2] edges \ 108 | = np.vstack([edges1, edges2, edges3, edges7, edges9, edges10, edges11, edges12, edges13, edges14, edges15, edges16, edges25]) 109 | 110 | # initialize data structures for segment size 111 | # and inner cost, then start greedy iteration over edges. 
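# NOTE: this is Felzenszwalb's merging rule: edges are visited in order of increasing
# cost, and two components are merged only when the edge cost is below each component's
# internal cost plus scale / component_size, so a larger `scale` produces fewer and
# larger supervoxels.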
112 | edge_queue = np.argsort(costs) 113 | edges = np.ascontiguousarray(edges[edge_queue]) 114 | costs = np.ascontiguousarray(costs[edge_queue]) 115 | cdef cnp.intp_t *segments_p = segments.data 116 | cdef cnp.intp_t *edges_p = edges.data 117 | cdef cnp.float_t *costs_p = costs.data 118 | cdef cnp.ndarray[cnp.intp_t, ndim=1] segment_size \ 119 | = np.ones(width * height * depth, dtype=np.intp) 120 | 121 | # inner cost of segments 122 | cdef cnp.ndarray[cnp.float_t, ndim=1] cint = np.zeros(width * height * depth) 123 | cdef cnp.intp_t seg0, seg1, seg_new, e 124 | cdef float cost, inner_cost0, inner_cost1 125 | cdef Py_ssize_t num_costs = costs.size 126 | 127 | with nogil: 128 | # set costs_p back one. we increase it before we use it 129 | # since we might continue before that. 130 | costs_p -= 1 131 | for e in range(num_costs): 132 | seg0 = find_root(segments_p, edges_p[0]) 133 | seg1 = find_root(segments_p, edges_p[1]) 134 | 135 | edges_p += 2 136 | costs_p += 1 137 | if seg0 == seg1: 138 | continue 139 | 140 | 141 | inner_cost0 = cint[seg0] + scale / segment_size[seg0] 142 | inner_cost1 = cint[seg1] + scale / segment_size[seg1] 143 | 144 | # return 0 # ok 145 | 146 | if costs_p[0] < min(inner_cost0, inner_cost1): 147 | # update size and cost 148 | 149 | join_trees(segments_p, seg0, seg1) # TODO: not ok! 150 | #return 0 # not ok!! 151 | seg_new = find_root(segments_p, seg0) 152 | segment_size[seg_new] = segment_size[seg0] + segment_size[seg1] 153 | cint[seg_new] = costs_p[0] 154 | 155 | 156 | # postprocessing to remove small segments 157 | edges_p = edges.data 158 | for e in range(num_costs): 159 | seg0 = find_root(segments_p, edges_p[0]) 160 | seg1 = find_root(segments_p, edges_p[1]) 161 | edges_p += 2 162 | if seg0 == seg1: 163 | continue 164 | if segment_size[seg0] < min_size or segment_size[seg1] < min_size: 165 | join_trees(segments_p, seg0, seg1) 166 | seg_new = find_root(segments_p, seg0) 167 | segment_size[seg_new] = segment_size[seg0] + segment_size[seg1] 168 | 169 | 170 | 171 | # unravel the union find tree 172 | flat = segments.ravel() 173 | old = np.zeros_like(flat) 174 | while (old != flat).any(): 175 | old = flat 176 | flat = flat[flat] 177 | flat = np.unique(flat, return_inverse=True)[1] 178 | return flat.reshape((height, width, depth)) -------------------------------------------------------------------------------- /dataloaders/image_transforms.py: -------------------------------------------------------------------------------- 1 | """ 2 | Image Transformation 3 | Code originally from Ouyang et al. 
(used in the 2D setting) 4 | """ 5 | 6 | from collections import Sequence 7 | import cv2 8 | import numpy as np 9 | import scipy 10 | from scipy.ndimage.filters import gaussian_filter 11 | from scipy.ndimage.interpolation import map_coordinates 12 | from numpy.lib.stride_tricks import as_strided 13 | 14 | 15 | ###### UTILITIES ###### 16 | def random_num_generator(config, random_state=np.random): 17 | if config[0] == 'uniform': 18 | ret = random_state.uniform(config[1], config[2], 1)[0] 19 | elif config[0] == 'lognormal': 20 | ret = random_state.lognormal(config[1], config[2], 1)[0] 21 | else: 22 | # print(config) 23 | raise Exception('unsupported format') 24 | return ret 25 | 26 | 27 | def get_translation_matrix(translation): 28 | """ translation: [tx, ty] """ 29 | tx, ty = translation 30 | translation_matrix = np.array([[1, 0, tx], 31 | [0, 1, ty], 32 | [0, 0, 1]]) 33 | return translation_matrix 34 | 35 | 36 | def get_rotation_matrix(rotation, input_shape, centred=True): 37 | theta = np.pi / 180 * np.array(rotation) 38 | if centred: 39 | rotation_matrix = cv2.getRotationMatrix2D((input_shape[0] / 2, input_shape[1] // 2), rotation, 1) 40 | rotation_matrix = np.vstack([rotation_matrix, [0, 0, 1]]) 41 | else: 42 | rotation_matrix = np.array([[np.cos(theta), -np.sin(theta), 0], 43 | [np.sin(theta), np.cos(theta), 0], 44 | [0, 0, 1]]) 45 | return rotation_matrix 46 | 47 | 48 | def get_zoom_matrix(zoom, input_shape, centred=True): 49 | zx, zy = zoom 50 | if centred: 51 | zoom_matrix = cv2.getRotationMatrix2D((input_shape[0] / 2, input_shape[1] // 2), 0, zoom[0]) 52 | zoom_matrix = np.vstack([zoom_matrix, [0, 0, 1]]) 53 | else: 54 | zoom_matrix = np.array([[zx, 0, 0], 55 | [0, zy, 0], 56 | [0, 0, 1]]) 57 | return zoom_matrix 58 | 59 | 60 | def get_shear_matrix(shear_angle): 61 | theta = (np.pi * shear_angle) / 180 62 | shear_matrix = np.array([[1, -np.sin(theta), 0], 63 | [0, np.cos(theta), 0], 64 | [0, 0, 1]]) 65 | return shear_matrix 66 | 67 | 68 | ###### AFFINE TRANSFORM ###### 69 | class RandomAffine(object): 70 | """Apply random affine transformation on a numpy.ndarray (H x W x C) 71 | Comment by co1818: this is still doing affine on 2d (H x W plane). 72 | A same transform is applied to all C channels 73 | 74 | Parameter: 75 | ---------- 76 | 77 | alpha: Range [0, 4] seems good for small images 78 | 79 | order: interpolation method (c.f. opencv) 80 | """ 81 | 82 | def __init__(self, 83 | rotation_range=None, 84 | translation_range=None, 85 | shear_range=None, 86 | zoom_range=None, 87 | zoom_keep_aspect=False, 88 | interp='bilinear', 89 | order=3): 90 | """ 91 | Perform an affine transforms. 92 | 93 | Arguments 94 | --------- 95 | rotation_range : one integer or float 96 | image will be rotated randomly between (-degrees, degrees) 97 | 98 | translation_range : (x_shift, y_shift) 99 | shifts in pixels 100 | 101 | *NOT TESTED* shear_range : float 102 | image will be sheared randomly between (-degrees, degrees) 103 | 104 | zoom_range : (zoom_min, zoom_max) 105 | list/tuple with two floats between [0, infinity). 106 | first float should be less than the second 107 | lower and upper bounds on percent zoom. 108 | Anything less than 1.0 will zoom in on the image, 109 | anything greater than 1.0 will zoom out on the image. 110 | e.g. 
(0.7, 1.0) will only zoom in, 111 | (1.0, 1.4) will only zoom out, 112 | (0.7, 1.4) will randomly zoom in or out 113 | """ 114 | 115 | self.rotation_range = rotation_range 116 | self.translation_range = translation_range 117 | self.shear_range = shear_range 118 | self.zoom_range = zoom_range 119 | self.zoom_keep_aspect = zoom_keep_aspect 120 | self.interp = interp 121 | self.order = order 122 | 123 | def build_M(self, input_shape): 124 | tfx = [] 125 | final_tfx = np.eye(3) 126 | if self.rotation_range: 127 | rot = np.random.uniform(-self.rotation_range, self.rotation_range) 128 | tfx.append(get_rotation_matrix(rot, input_shape)) 129 | if self.translation_range: 130 | tx = np.random.uniform(-self.translation_range[0], self.translation_range[0]) 131 | ty = np.random.uniform(-self.translation_range[1], self.translation_range[1]) 132 | tfx.append(get_translation_matrix((tx, ty))) 133 | if self.shear_range: 134 | rot = np.random.uniform(-self.shear_range, self.shear_range) 135 | tfx.append(get_shear_matrix(rot)) 136 | if self.zoom_range: 137 | sx = np.random.uniform(self.zoom_range[0], self.zoom_range[1]) 138 | if self.zoom_keep_aspect: 139 | sy = sx 140 | else: 141 | sy = np.random.uniform(self.zoom_range[0], self.zoom_range[1]) 142 | 143 | tfx.append(get_zoom_matrix((sx, sy), input_shape)) 144 | 145 | for tfx_mat in tfx: 146 | final_tfx = np.dot(tfx_mat, final_tfx) 147 | 148 | return final_tfx.astype(np.float32) 149 | 150 | def __call__(self, image): 151 | # build matrix 152 | input_shape = image.shape[:2] 153 | M = self.build_M(input_shape) 154 | 155 | res = np.zeros_like(image) 156 | # if isinstance(self.interp, Sequence): 157 | if type(self.order) is list or type(self.order) is tuple: 158 | for i, intp in enumerate(self.order): 159 | res[..., i] = affine_transform_via_M(image[..., i], M[:2], interp=intp) 160 | else: 161 | # squeeze if needed 162 | orig_shape = image.shape 163 | image_s = np.squeeze(image) 164 | res = affine_transform_via_M(image_s, M[:2], interp=self.order) 165 | res = res.reshape(orig_shape) 166 | 167 | # res = affine_transform_via_M(image, M[:2], interp=self.order) 168 | 169 | return res 170 | 171 | 172 | def affine_transform_via_M(image, M, borderMode=cv2.BORDER_CONSTANT, interp=cv2.INTER_NEAREST): 173 | imshape = image.shape 174 | shape_size = imshape[:2] 175 | 176 | # Random affine 177 | warped = cv2.warpAffine(image.reshape(shape_size + (-1,)), M, shape_size[::-1], 178 | flags=interp, borderMode=borderMode) 179 | 180 | # print(imshape, warped.shape) 181 | 182 | warped = warped[..., np.newaxis].reshape(imshape) 183 | 184 | return warped 185 | 186 | 187 | ###### ELASTIC TRANSFORM ###### 188 | def elastic_transform(image, alpha=1000, sigma=30, spline_order=1, mode='nearest', random_state=np.random): 189 | """Elastic deformation of image as described in [Simard2003]_. 190 | .. [Simard2003] Simard, Steinkraus and Platt, "Best Practices for 191 | Convolutional Neural Networks applied to Visual Document Analysis", in 192 | Proc. of the International Conference on Document Analysis and 193 | Recognition, 2003. 
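    (Added note) In this implementation, `alpha` scales the random
    displacement field, `sigma` is the standard deviation of the Gaussian
    smoothing applied to that field, and `spline_order`/`mode` are passed to
    scipy's map_coordinates when each channel is resampled.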
194 | """ 195 | assert image.ndim == 3 196 | shape = image.shape[:2] 197 | 198 | dx = gaussian_filter((random_state.rand(*shape) * 2 - 1), 199 | sigma, mode="constant", cval=0) * alpha 200 | dy = gaussian_filter((random_state.rand(*shape) * 2 - 1), 201 | sigma, mode="constant", cval=0) * alpha 202 | 203 | x, y = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), indexing='ij') 204 | indices = [np.reshape(x + dx, (-1, 1)), np.reshape(y + dy, (-1, 1))] 205 | result = np.empty_like(image) 206 | for i in range(image.shape[2]): 207 | result[:, :, i] = map_coordinates( 208 | image[:, :, i], indices, order=spline_order, mode=mode).reshape(shape) 209 | return result 210 | 211 | 212 | def elastic_transform_nd(image, alpha, sigma, random_state=None, order=1, lazy=False): 213 | """Expects data to be (nx, ny, n1 ,..., nm) 214 | params: 215 | ------ 216 | 217 | alpha: 218 | the scaling parameter. 219 | E.g.: alpha=2 => distorts images up to 2x scaling 220 | 221 | sigma: 222 | standard deviation of gaussian filter. 223 | E.g. 224 | low (sig~=1e-3) => no smoothing, pixelated. 225 | high (1/5 * imsize) => smooth, more like affine. 226 | very high (1/2*im_size) => translation 227 | """ 228 | 229 | if random_state is None: 230 | random_state = np.random.RandomState(None) 231 | 232 | shape = image.shape 233 | imsize = shape[:2] 234 | dim = shape[2:] 235 | 236 | # Random affine 237 | blur_size = int(4 * sigma) | 1 238 | dx = cv2.GaussianBlur(random_state.rand(*imsize) * 2 - 1, 239 | ksize=(blur_size, blur_size), sigmaX=sigma) * alpha 240 | dy = cv2.GaussianBlur(random_state.rand(*imsize) * 2 - 1, 241 | ksize=(blur_size, blur_size), sigmaX=sigma) * alpha 242 | 243 | # use as_strided to copy things over across n1...nn channels 244 | dx = as_strided(dx.astype(np.float32), 245 | strides=(0,) * len(dim) + (4 * shape[1], 4), 246 | shape=dim + (shape[0], shape[1])) 247 | dx = np.transpose(dx, axes=(-2, -1) + tuple(range(len(dim)))) 248 | 249 | dy = as_strided(dy.astype(np.float32), 250 | strides=(0,) * len(dim) + (4 * shape[1], 4), 251 | shape=dim + (shape[0], shape[1])) 252 | dy = np.transpose(dy, axes=(-2, -1) + tuple(range(len(dim)))) 253 | 254 | coord = np.meshgrid(*[np.arange(shape_i) for shape_i in (shape[1], shape[0]) + dim]) 255 | indices = [np.reshape(e + de, (-1, 1)) for e, de in zip([coord[1], coord[0]] + coord[2:], 256 | [dy, dx] + [0] * len(dim))] 257 | 258 | if lazy: 259 | return indices 260 | 261 | return map_coordinates(image, indices, order=order, mode='reflect').reshape(shape) 262 | 263 | 264 | class ElasticTransform(object): 265 | """Apply elastic transformation on a numpy.ndarray (H x W x C) 266 | """ 267 | 268 | def __init__(self, alpha, sigma, order=1): 269 | self.alpha = alpha 270 | self.sigma = sigma 271 | self.order = order 272 | 273 | def __call__(self, image): 274 | if isinstance(self.alpha, Sequence): 275 | alpha = random_num_generator(self.alpha) 276 | else: 277 | alpha = self.alpha 278 | if isinstance(self.sigma, Sequence): 279 | sigma = random_num_generator(self.sigma) 280 | else: 281 | sigma = self.sigma 282 | return elastic_transform_nd(image, alpha=alpha, sigma=sigma, order=self.order) 283 | 284 | 285 | class RandomFlip3D(object): 286 | 287 | def __init__(self, h=True, v=True, t=True, p=0.5): 288 | """ 289 | Randomly flip an image horizontally and/or vertically with 290 | some probability. 
291 | 292 | Arguments 293 | --------- 294 | h : boolean 295 | whether to horizontally flip w/ probability p 296 | 297 | v : boolean 298 | whether to vertically flip w/ probability p 299 | 300 | p : float between [0,1] 301 | probability with which to apply allowed flipping operations 302 | """ 303 | self.horizontal = h 304 | self.vertical = v 305 | self.depth = t 306 | self.p = p 307 | 308 | def __call__(self, x, y=None): 309 | # horizontal flip with p = self.p 310 | if self.horizontal: 311 | if np.random.random() < self.p: 312 | x = x[::-1, ...] 313 | 314 | # vertical flip with p = self.p 315 | if self.vertical: 316 | if np.random.random() < self.p: 317 | x = x[:, ::-1, ...] 318 | 319 | if self.depth: 320 | if np.random.random() < self.p: 321 | x = x[..., ::-1] 322 | 323 | return x 324 | -------------------------------------------------------------------------------- /data/CMR/class_slice_index_gen.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Generate class-slice indexing table for experiments\n", 8 | "\n", 9 | "\n", 10 | "### Overview\n", 11 | "\n", 12 | "This is for experiment setting up for simulating few-shot image segmentation scenarios\n", 13 | "\n", 14 | "Input: pre-processed images and their ground-truth labels\n", 15 | "\n", 16 | "Output: a `json` file for class-slice indexing" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 8, 22 | "metadata": {}, 23 | "outputs": [ 24 | { 25 | "name": "stdout", 26 | "output_type": "stream", 27 | "text": [ 28 | "Once deleted, variables cannot be recovered. Proceed (y/[n])? y\n", 29 | "The autoreload extension is already loaded. To reload it, use:\n", 30 | " %reload_ext autoreload\n" 31 | ] 32 | } 33 | ], 34 | "source": [ 35 | "%reset\n", 36 | "%load_ext autoreload\n", 37 | "%autoreload 2\n", 38 | "import numpy as np\n", 39 | "import os\n", 40 | "import glob\n", 41 | "import SimpleITK as sitk\n", 42 | "import sys\n", 43 | "import json\n", 44 | "sys.path.insert(0, '../../dataloaders/')\n", 45 | "import niftiio as nio" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 9, 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "IMG_BNAME=\"./cmr_MR_normalized/image_*.nii.gz\"\n", 55 | "SEG_BNAME=\"./cmr_MR_normalized/label_*.nii.gz\"" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 10, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "imgs = glob.glob(IMG_BNAME)\n", 65 | "segs = glob.glob(SEG_BNAME)\n", 66 | "imgs = [ fid for fid in sorted(imgs, key = lambda x: int(x.split(\"_\")[-1].split(\".nii.gz\")[0]) ) ]\n", 67 | "segs = [ fid for fid in sorted(segs, key = lambda x: int(x.split(\"_\")[-1].split(\".nii.gz\")[0]) ) ]\n" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 11, 73 | "metadata": {}, 74 | "outputs": [ 75 | { 76 | "data": { 77 | "text/plain": [ 78 | "['./cmr_MR_normalized/image_1.nii.gz',\n", 79 | " './cmr_MR_normalized/image_2.nii.gz',\n", 80 | " './cmr_MR_normalized/image_3.nii.gz',\n", 81 | " './cmr_MR_normalized/image_4.nii.gz',\n", 82 | " './cmr_MR_normalized/image_5.nii.gz',\n", 83 | " './cmr_MR_normalized/image_6.nii.gz',\n", 84 | " './cmr_MR_normalized/image_7.nii.gz',\n", 85 | " './cmr_MR_normalized/image_8.nii.gz',\n", 86 | " './cmr_MR_normalized/image_9.nii.gz',\n", 87 | " './cmr_MR_normalized/image_10.nii.gz',\n", 88 | " './cmr_MR_normalized/image_11.nii.gz',\n", 89 | " 
'./cmr_MR_normalized/image_12.nii.gz',\n", 90 | " './cmr_MR_normalized/image_13.nii.gz',\n", 91 | " './cmr_MR_normalized/image_14.nii.gz',\n", 92 | " './cmr_MR_normalized/image_15.nii.gz',\n", 93 | " './cmr_MR_normalized/image_16.nii.gz',\n", 94 | " './cmr_MR_normalized/image_17.nii.gz',\n", 95 | " './cmr_MR_normalized/image_18.nii.gz',\n", 96 | " './cmr_MR_normalized/image_19.nii.gz',\n", 97 | " './cmr_MR_normalized/image_20.nii.gz',\n", 98 | " './cmr_MR_normalized/image_21.nii.gz',\n", 99 | " './cmr_MR_normalized/image_22.nii.gz',\n", 100 | " './cmr_MR_normalized/image_23.nii.gz',\n", 101 | " './cmr_MR_normalized/image_24.nii.gz',\n", 102 | " './cmr_MR_normalized/image_25.nii.gz',\n", 103 | " './cmr_MR_normalized/image_26.nii.gz',\n", 104 | " './cmr_MR_normalized/image_27.nii.gz',\n", 105 | " './cmr_MR_normalized/image_28.nii.gz',\n", 106 | " './cmr_MR_normalized/image_29.nii.gz',\n", 107 | " './cmr_MR_normalized/image_30.nii.gz',\n", 108 | " './cmr_MR_normalized/image_31.nii.gz',\n", 109 | " './cmr_MR_normalized/image_32.nii.gz',\n", 110 | " './cmr_MR_normalized/image_33.nii.gz',\n", 111 | " './cmr_MR_normalized/image_34.nii.gz',\n", 112 | " './cmr_MR_normalized/image_35.nii.gz']" 113 | ] 114 | }, 115 | "execution_count": 11, 116 | "metadata": {}, 117 | "output_type": "execute_result" 118 | } 119 | ], 120 | "source": [ 121 | "imgs" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": 12, 127 | "metadata": {}, 128 | "outputs": [ 129 | { 130 | "data": { 131 | "text/plain": [ 132 | "['./cmr_MR_normalized/label_1.nii.gz',\n", 133 | " './cmr_MR_normalized/label_2.nii.gz',\n", 134 | " './cmr_MR_normalized/label_3.nii.gz',\n", 135 | " './cmr_MR_normalized/label_4.nii.gz',\n", 136 | " './cmr_MR_normalized/label_5.nii.gz',\n", 137 | " './cmr_MR_normalized/label_6.nii.gz',\n", 138 | " './cmr_MR_normalized/label_7.nii.gz',\n", 139 | " './cmr_MR_normalized/label_8.nii.gz',\n", 140 | " './cmr_MR_normalized/label_9.nii.gz',\n", 141 | " './cmr_MR_normalized/label_10.nii.gz',\n", 142 | " './cmr_MR_normalized/label_11.nii.gz',\n", 143 | " './cmr_MR_normalized/label_12.nii.gz',\n", 144 | " './cmr_MR_normalized/label_13.nii.gz',\n", 145 | " './cmr_MR_normalized/label_14.nii.gz',\n", 146 | " './cmr_MR_normalized/label_15.nii.gz',\n", 147 | " './cmr_MR_normalized/label_16.nii.gz',\n", 148 | " './cmr_MR_normalized/label_17.nii.gz',\n", 149 | " './cmr_MR_normalized/label_18.nii.gz',\n", 150 | " './cmr_MR_normalized/label_19.nii.gz',\n", 151 | " './cmr_MR_normalized/label_20.nii.gz',\n", 152 | " './cmr_MR_normalized/label_21.nii.gz',\n", 153 | " './cmr_MR_normalized/label_22.nii.gz',\n", 154 | " './cmr_MR_normalized/label_23.nii.gz',\n", 155 | " './cmr_MR_normalized/label_24.nii.gz',\n", 156 | " './cmr_MR_normalized/label_25.nii.gz',\n", 157 | " './cmr_MR_normalized/label_26.nii.gz',\n", 158 | " './cmr_MR_normalized/label_27.nii.gz',\n", 159 | " './cmr_MR_normalized/label_28.nii.gz',\n", 160 | " './cmr_MR_normalized/label_29.nii.gz',\n", 161 | " './cmr_MR_normalized/label_30.nii.gz',\n", 162 | " './cmr_MR_normalized/label_31.nii.gz',\n", 163 | " './cmr_MR_normalized/label_32.nii.gz',\n", 164 | " './cmr_MR_normalized/label_33.nii.gz',\n", 165 | " './cmr_MR_normalized/label_34.nii.gz',\n", 166 | " './cmr_MR_normalized/label_35.nii.gz']" 167 | ] 168 | }, 169 | "execution_count": 12, 170 | "metadata": {}, 171 | "output_type": "execute_result" 172 | } 173 | ], 174 | "source": [ 175 | "segs" 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": 13, 181 | "metadata": {}, 
182 | "outputs": [ 183 | { 184 | "name": "stdout", 185 | "output_type": "stream", 186 | "text": [ 187 | "pid 1 finished!\n", 188 | "pid 2 finished!\n", 189 | "pid 3 finished!\n", 190 | "pid 4 finished!\n", 191 | "pid 5 finished!\n", 192 | "pid 6 finished!\n", 193 | "pid 7 finished!\n", 194 | "pid 8 finished!\n", 195 | "pid 9 finished!\n", 196 | "pid 10 finished!\n", 197 | "pid 11 finished!\n", 198 | "pid 12 finished!\n", 199 | "pid 13 finished!\n", 200 | "pid 14 finished!\n", 201 | "pid 15 finished!\n", 202 | "pid 16 finished!\n", 203 | "pid 17 finished!\n", 204 | "pid 18 finished!\n", 205 | "pid 19 finished!\n", 206 | "pid 20 finished!\n", 207 | "pid 21 finished!\n", 208 | "pid 22 finished!\n", 209 | "pid 23 finished!\n", 210 | "pid 24 finished!\n", 211 | "pid 25 finished!\n", 212 | "pid 26 finished!\n", 213 | "pid 27 finished!\n", 214 | "pid 28 finished!\n", 215 | "pid 29 finished!\n", 216 | "pid 30 finished!\n", 217 | "pid 31 finished!\n", 218 | "pid 32 finished!\n", 219 | "pid 33 finished!\n", 220 | "pid 34 finished!\n", 221 | "pid 35 finished!\n" 222 | ] 223 | }, 224 | { 225 | "ename": "FileNotFoundError", 226 | "evalue": "[Errno 2] No such file or directory: './chaos_MR_T2_normalized/classmap_1.json'", 227 | "output_type": "error", 228 | "traceback": [ 229 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 230 | "\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)", 231 | "\u001b[0;32m/tmp/ipykernel_1065506/1189938079.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 26\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf'pid {str(pid)} finished!'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 27\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 28\u001b[0;31m \u001b[0;32mwith\u001b[0m \u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfid\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'w'\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mfopen\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 29\u001b[0m \u001b[0mjson\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdump\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mclassmap\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfopen\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 30\u001b[0m \u001b[0mfopen\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclose\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 232 | "\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: './chaos_MR_T2_normalized/classmap_1.json'" 233 | ] 234 | } 235 | ], 236 | "source": [ 237 | "classmap = {}\n", 238 | "LABEL_NAME = [\"BG\", \"LV-MYO\", \"LV-BP\", \"RV\"] \n", 239 | "\n", 240 | "\n", 241 | "MIN_TP = 1 # minimum number of positive label pixels to be recorded. Use >100 when training with manual annotations for more stable training\n", 242 | "\n", 243 | "fid = f'./chaos_MR_T2_normalized/classmap_{MIN_TP}.json' # name of the output file. 
\n", 244 | "for _lb in LABEL_NAME:\n", 245 | " classmap[_lb] = {}\n", 246 | " for _sid in segs:\n", 247 | " pid = _sid.split(\"_\")[-1].split(\".nii.gz\")[0]\n", 248 | " classmap[_lb][pid] = []\n", 249 | "\n", 250 | "for seg in segs:\n", 251 | " pid = seg.split(\"_\")[-1].split(\".nii.gz\")[0]\n", 252 | " lb_vol = nio.read_nii_bysitk(seg)\n", 253 | " n_slice = lb_vol.shape[0]\n", 254 | " lb_vol[lb_vol == 200] = 1\n", 255 | " lb_vol[lb_vol == 500] = 2\n", 256 | " lb_vol[lb_vol == 600] = 3\n", 257 | " for slc in range(n_slice):\n", 258 | " for cls in range(len(LABEL_NAME)):\n", 259 | " if cls in lb_vol[slc, ...]:\n", 260 | " if np.sum( lb_vol[slc, ...]) >= MIN_TP:\n", 261 | " classmap[LABEL_NAME[cls]][str(pid)].append(slc)\n", 262 | " print(f'pid {str(pid)} finished!')\n", 263 | " \n", 264 | "with open(fid, 'w') as fopen:\n", 265 | " json.dump(classmap, fopen)\n", 266 | " fopen.close() \n", 267 | " " 268 | ] 269 | }, 270 | { 271 | "cell_type": "code", 272 | "execution_count": 9, 273 | "metadata": {}, 274 | "outputs": [ 275 | { 276 | "ename": "FileNotFoundError", 277 | "evalue": "[Errno 2] No such file or directory: './chaos_MR_T2_normalized/classmap_1.json'", 278 | "output_type": "error", 279 | "traceback": [ 280 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 281 | "\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)", 282 | "\u001b[0;32m/tmp/ipykernel_1045184/825143362.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mwith\u001b[0m \u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfid\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'w'\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mfopen\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0mjson\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdump\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mclassmap\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfopen\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mfopen\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclose\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 283 | "\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: './chaos_MR_T2_normalized/classmap_1.json'" 284 | ] 285 | } 286 | ], 287 | "source": [ 288 | "with open(fid, 'w') as fopen:\n", 289 | " json.dump(classmap, fopen)\n", 290 | " fopen.close()" 291 | ] 292 | }, 293 | { 294 | "cell_type": "code", 295 | "execution_count": null, 296 | "metadata": {}, 297 | "outputs": [], 298 | "source": [] 299 | } 300 | ], 301 | "metadata": { 302 | "kernelspec": { 303 | "display_name": "Python 3 (ipykernel)", 304 | "language": "python", 305 | "name": "python3" 306 | }, 307 | "language_info": { 308 | "codemirror_mode": { 309 | "name": "ipython", 310 | "version": 3 311 | }, 312 | "file_extension": ".py", 313 | "mimetype": "text/x-python", 314 | "name": "python", 315 | "nbconvert_exporter": "python", 316 | "pygments_lexer": "ipython3", 317 | "version": "3.8.12" 318 | } 319 | }, 320 | "nbformat": 4, 321 | "nbformat_minor": 2 322 | } 323 | -------------------------------------------------------------------------------- /dataloaders/datasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dataset for Training and Test 3 | Extended from ADNet code by Hansen et al. 
4 | """ 5 | import torch 6 | from torch.utils.data import Dataset 7 | import torchvision.transforms as deftfx 8 | import glob 9 | import os 10 | import SimpleITK as sitk 11 | import random 12 | import numpy as np 13 | from . import image_transforms as myit 14 | from .dataset_specifics import * 15 | 16 | 17 | class TestDataset(Dataset): 18 | 19 | def __init__(self, args): 20 | 21 | # reading the paths 22 | if args['dataset'] == 'CMR': 23 | self.image_dirs = glob.glob(os.path.join(args['data_dir'], 'cmr_MR_normalized/image*')) 24 | elif args['dataset'] == 'CHAOST2': 25 | self.image_dirs = glob.glob(os.path.join(args['data_dir'], 'chaos_MR_T2_normalized/image*')) 26 | elif args['dataset'] == 'SABS': 27 | self.image_dirs = glob.glob(os.path.join(args['data_dir'], 'sabs_CT_normalized/image*')) 28 | 29 | self.image_dirs = sorted(self.image_dirs, key=lambda x: int(x.split('_')[-1].split('.nii.gz')[0])) 30 | 31 | # remove test fold! 32 | self.FOLD = get_folds(args['dataset']) 33 | self.image_dirs = [elem for idx, elem in enumerate(self.image_dirs) if idx in self.FOLD[args['eval_fold']]] 34 | 35 | # split into support/query 36 | idx = np.arange(len(self.image_dirs)) 37 | self.support_dir = self.image_dirs[idx[args['supp_idx']]] 38 | self.image_dirs.pop(idx[args['supp_idx']]) # remove support 39 | self.label = None 40 | 41 | def __len__(self): 42 | return len(self.image_dirs) 43 | 44 | def __getitem__(self, idx): 45 | 46 | img_path = self.image_dirs[idx] 47 | img = sitk.GetArrayFromImage(sitk.ReadImage(img_path)) 48 | img = (img - img.mean()) / img.std() 49 | img = np.stack(3 * [img], axis=1) 50 | 51 | lbl = sitk.GetArrayFromImage( 52 | sitk.ReadImage(img_path.split('image_')[0] + 'label_' + img_path.split('image_')[-1])) 53 | 54 | lbl[lbl == 200] = 1 55 | lbl[lbl == 500] = 2 56 | lbl[lbl == 600] = 3 57 | lbl = 1 * (lbl == self.label) 58 | 59 | sample = {'id': img_path} 60 | 61 | # Evaluation protocol. 62 | idx = lbl.sum(axis=(1, 2)) > 0 63 | sample['image'] = torch.from_numpy(img[idx]) 64 | sample['label'] = torch.from_numpy(lbl[idx]) 65 | 66 | return sample 67 | 68 | def get_support_index(self, n_shot, C): 69 | """ 70 | Selecting intervals according to Ouyang et al. 
71 | """ 72 | if n_shot == 1: 73 | pcts = [0.5] 74 | else: 75 | half_part = 1 / (n_shot * 2) 76 | part_interval = (1.0 - 1.0 / n_shot) / (n_shot - 1) 77 | pcts = [half_part + part_interval * ii for ii in range(n_shot)] 78 | 79 | return (np.array(pcts) * C).astype('int') 80 | 81 | def getSupport(self, label=None, all_slices=True, N=None): 82 | if label is None: 83 | raise ValueError('Need to specify label class!') 84 | 85 | img_path = self.support_dir 86 | img = sitk.GetArrayFromImage(sitk.ReadImage(img_path)) 87 | img = (img - img.mean()) / img.std() 88 | img = np.stack(3 * [img], axis=1) 89 | 90 | lbl = sitk.GetArrayFromImage( 91 | sitk.ReadImage(img_path.split('image_')[0] + 'label_' + img_path.split('image_')[-1])) 92 | lbl[lbl == 200] = 1 93 | lbl[lbl == 500] = 2 94 | lbl[lbl == 600] = 3 95 | lbl = 1 * (lbl == label) 96 | 97 | sample = {} 98 | if all_slices: 99 | sample['image'] = torch.from_numpy(img) 100 | sample['label'] = torch.from_numpy(lbl) 101 | else: 102 | # select N labeled slices 103 | if N is None: 104 | raise ValueError('Need to specify number of labeled slices!') 105 | idx = lbl.sum(axis=(1, 2)) > 0 106 | idx_ = self.get_support_index(N, idx.sum()) 107 | 108 | sample['image'] = torch.from_numpy(img[idx][idx_]) 109 | sample['label'] = torch.from_numpy(lbl[idx][idx_]) 110 | 111 | return sample 112 | 113 | 114 | class TrainDataset(Dataset): 115 | 116 | def __init__(self, args): 117 | self.n_shot = args['n_shot'] 118 | self.n_way = args['n_way'] 119 | self.n_query = args['n_query'] 120 | self.n_sv = args['n_sv'] 121 | self.max_iter = args['max_iter'] 122 | self.read = True # read images before get_item 123 | self.train_sampling = 'neighbors' 124 | self.min_size = args['min_size'] 125 | self.test_label = args['test_label'] 126 | self.exclude_label = args['exclude_label'] 127 | self.use_gt = args['use_gt'] 128 | 129 | # reading the paths (leaving the reading of images into memory to __getitem__) 130 | if args['dataset'] == 'CMR': 131 | self.image_dirs = glob.glob(os.path.join(args['data_dir'], 'cmr_MR_normalized/image*')) 132 | self.label_dirs = glob.glob(os.path.join(args['data_dir'], 'cmr_MR_normalized/label*')) 133 | elif args['dataset'] == 'CHAOST2': 134 | self.image_dirs = glob.glob(os.path.join(args['data_dir'], 'chaos_MR_T2_normalized/image*')) 135 | self.label_dirs = glob.glob(os.path.join(args['data_dir'], 'chaos_MR_T2_normalized/label*')) 136 | elif args['dataset'] == 'SABS': 137 | self.image_dirs = glob.glob(os.path.join(args['data_dir'], 'sabs_CT_normalized/image*')) 138 | self.label_dirs = glob.glob(os.path.join(args['data_dir'], 'sabs_CT_normalized/label*')) 139 | 140 | self.image_dirs = sorted(self.image_dirs, key=lambda x: int(x.split('_')[-1].split('.nii.gz')[0])) 141 | self.label_dirs = sorted(self.label_dirs, key=lambda x: int(x.split('_')[-1].split('.nii.gz')[0])) 142 | self.sprvxl_dirs = glob.glob(os.path.join(args['data_dir'], 'supervoxels_' + str(args['n_sv']), 'super*')) 143 | self.sprvxl_dirs = sorted(self.sprvxl_dirs, key=lambda x: int(x.split('_')[-1].split('.nii.gz')[0])) 144 | 145 | # remove test fold! 
146 | self.FOLD = get_folds(args['dataset']) 147 | self.image_dirs = [elem for idx, elem in enumerate(self.image_dirs) if idx not in self.FOLD[args['eval_fold']]] 148 | self.label_dirs = [elem for idx, elem in enumerate(self.label_dirs) if idx not in self.FOLD[args['eval_fold']]] 149 | self.sprvxl_dirs = [elem for idx, elem in enumerate(self.sprvxl_dirs) if 150 | idx not in self.FOLD[args['eval_fold']]] 151 | 152 | # read images 153 | if self.read: 154 | self.images = {} 155 | self.labels = {} 156 | self.sprvxls = {} 157 | for image_dir, label_dir, sprvxl_dir in zip(self.image_dirs, self.label_dirs, self.sprvxl_dirs): 158 | self.images[image_dir] = sitk.GetArrayFromImage(sitk.ReadImage(image_dir)) 159 | self.labels[label_dir] = sitk.GetArrayFromImage(sitk.ReadImage(label_dir)) 160 | self.sprvxls[sprvxl_dir] = sitk.GetArrayFromImage(sitk.ReadImage(sprvxl_dir)) 161 | 162 | def __len__(self): 163 | return self.max_iter 164 | 165 | def gamma_tansform(self, img): 166 | gamma_range = (0.5, 1.5) 167 | gamma = np.random.rand() * (gamma_range[1] - gamma_range[0]) + gamma_range[0] 168 | cmin = img.min() 169 | irange = (img.max() - cmin + 1e-5) 170 | 171 | img = img - cmin + 1e-5 172 | img = irange * np.power(img * 1.0 / irange, gamma) 173 | img = img + cmin 174 | 175 | return img 176 | 177 | def geom_transform(self, img, mask): 178 | 179 | affine = {'rotate': 5, 'shift': (5, 5), 'shear': 5, 'scale': (0.9, 1.2)} 180 | alpha = 10 181 | sigma = 5 182 | order = 3 183 | 184 | tfx = [] 185 | tfx.append(myit.RandomAffine(affine.get('rotate'), 186 | affine.get('shift'), 187 | affine.get('shear'), 188 | affine.get('scale'), 189 | affine.get('scale_iso', True), 190 | order=order)) 191 | tfx.append(myit.ElasticTransform(alpha, sigma)) 192 | transform = deftfx.Compose(tfx) 193 | 194 | if len(img.shape) > 4: 195 | n_shot = img.shape[1] 196 | for shot in range(n_shot): 197 | cat = np.concatenate((img[0, shot], mask[:, shot])).transpose(1, 2, 0) 198 | cat = transform(cat).transpose(2, 0, 1) 199 | img[0, shot] = cat[:3, :, :] 200 | mask[:, shot] = np.rint(cat[3:, :, :]) 201 | 202 | else: 203 | for q in range(img.shape[0]): 204 | cat = np.concatenate((img[q], mask[q][None])).transpose(1, 2, 0) 205 | cat = transform(cat).transpose(2, 0, 1) 206 | img[q] = cat[:3, :, :] 207 | mask[q] = np.rint(cat[3:, :, :].squeeze()) 208 | 209 | return img, mask 210 | 211 | def __getitem__(self, idx): 212 | 213 | # sample patient idx 214 | pat_idx = random.choice(range(len(self.image_dirs))) 215 | 216 | if self.read: 217 | # get image/supervoxel volume from dictionary 218 | img = self.images[self.image_dirs[pat_idx]] 219 | gt = self.labels[self.label_dirs[pat_idx]] 220 | sprvxl = self.sprvxls[self.sprvxl_dirs[pat_idx]] 221 | else: 222 | # read image/supervoxel volume into memory 223 | img = sitk.GetArrayFromImage(sitk.ReadImage(self.image_dirs[pat_idx])) 224 | gt = sitk.GetArrayFromImage(sitk.ReadImage(self.label_dirs[pat_idx])) 225 | sprvxl = sitk.GetArrayFromImage(sitk.ReadImage(self.sprvxl_dirs[pat_idx])) 226 | 227 | if self.exclude_label is not None: # identify the slices containing test labels 228 | idx = np.arange(gt.shape[0]) 229 | exclude_idx = np.full(gt.shape[0], True, dtype=bool) 230 | for i in range(len(self.exclude_label)): 231 | exclude_idx = exclude_idx & (np.sum(gt == self.exclude_label[i], axis=(1, 2)) > 0) 232 | exclude_idx = idx[exclude_idx] 233 | else: 234 | exclude_idx = [] 235 | 236 | # normalize 237 | img = (img - img.mean()) / img.std() 238 | 239 | # chose training label 240 | if self.use_gt: 241 | lbl = 
gt.copy() 242 | else: 243 | lbl = sprvxl.copy() 244 | # lbl is label numpy 245 | 246 | # sample class(es) (gt/supervoxel) 247 | unique = list(np.unique(lbl)) 248 | unique.remove(0) 249 | if self.use_gt: 250 | unique = list(set(unique) - set(self.test_label)) 251 | 252 | size = 0 253 | while size < self.min_size: 254 | n_slices = (self.n_shot * self.n_way) + self.n_query - 1 255 | while n_slices < ((self.n_shot * self.n_way) + self.n_query): 256 | 257 | cls_idx = random.choice(unique) # cls_idx is sampled class id 258 | 259 | # extract slices containing the sampled class 260 | sli_idx = np.sum(lbl == cls_idx, axis=(1, 2)) > 0 261 | idx = np.arange(lbl.shape[0]) 262 | sli_idx = idx[sli_idx] 263 | sli_idx = list(set(sli_idx) - set(np.intersect1d(sli_idx, exclude_idx))) # remove slices containing test labels 264 | n_slices = len(sli_idx) 265 | 266 | # generate possible subsets with successive slices (size = self.n_shot * self.n_way + self.n_query) 267 | subsets = [] 268 | for i in range(len(sli_idx)): 269 | if not subsets: 270 | subsets.append([sli_idx[i]]) 271 | elif sli_idx[i - 1] + 1 == sli_idx[i]: 272 | subsets[-1].append(sli_idx[i]) 273 | else: 274 | subsets.append([sli_idx[i]]) 275 | i = 0 276 | while i < len(subsets): 277 | if len(subsets[i]) < (self.n_shot * self.n_way + self.n_query): 278 | del subsets[i] 279 | else: 280 | i += 1 281 | if not len(subsets): 282 | return self.__getitem__(idx + np.random.randint(low=0, high=self.max_iter - 1, size=(1,))) 283 | 284 | # sample support and query slices 285 | i = random.choice(np.arange(len(subsets))) # subset index 286 | i = random.choice(subsets[i][:-(self.n_shot * self.n_way + self.n_query - 1)]) 287 | sample = np.arange(i, i + (self.n_shot * self.n_way) + self.n_query) 288 | 289 | lbl_cls = 1 * (lbl == cls_idx) 290 | 291 | size = max(np.sum(lbl_cls[sample[0]]), np.sum(lbl_cls[sample[1]])) 292 | 293 | # invert order 294 | if np.random.random(1) > 0.5: 295 | sample = sample[::-1] # successive slices (inverted) 296 | 297 | sup_lbl = lbl_cls[sample[:self.n_shot * self.n_way]][None,] # n_way * (n_shot * C) * H * W 298 | qry_lbl = lbl_cls[sample[self.n_shot * self.n_way:]] # n_qry * C * H * W 299 | 300 | sup_img = img[sample[:self.n_shot * self.n_way]][None,] # n_way * (n_shot * C) * H * W 301 | sup_img = np.stack((sup_img, sup_img, sup_img), axis=2) 302 | qry_img = img[sample[self.n_shot * self.n_way:]] # n_qry * C * H * W 303 | qry_img = np.stack((qry_img, qry_img, qry_img), axis=1) 304 | 305 | # gamma transform 306 | if np.random.random(1) > 0.5: 307 | qry_img = self.gamma_tansform(qry_img) 308 | else: 309 | sup_img = self.gamma_tansform(sup_img) 310 | 311 | # geom transform 312 | if np.random.random(1) > 0.5: 313 | qry_img, qry_lbl = self.geom_transform(qry_img, qry_lbl) 314 | else: 315 | sup_img, sup_lbl = self.geom_transform(sup_img, sup_lbl) 316 | 317 | sample = {'support_images': sup_img, 318 | 'support_fg_labels': sup_lbl, 319 | 'query_images': qry_img, 320 | 'query_labels': qry_lbl, 321 | 'selected_class': cls_idx} 322 | 323 | return sample 324 | -------------------------------------------------------------------------------- /data/CHAOST2/image_normalize.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Image Pre-processing\n", 8 | "\n", 9 | "### Overview\n", 10 | "\n", 11 | "This is the second step for data preparation\n", 12 | "\n", 13 | "Input: `.nii`-like images and labels converted from 
`dicom`s/ `png` files\n", 14 | "\n", 15 | "Output: image-labels with unified size (axial), voxel-spacing, and alleviated off-resonance effects" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": 1, 21 | "metadata": {}, 22 | "outputs": [ 23 | { 24 | "name": "stdout", 25 | "output_type": "stream", 26 | "text": [ 27 | "Once deleted, variables cannot be recovered. Proceed (y/[n])? y\n" 28 | ] 29 | } 30 | ], 31 | "source": [ 32 | "%reset\n", 33 | "%load_ext autoreload\n", 34 | "%autoreload 2\n", 35 | "import numpy as np\n", 36 | "import os\n", 37 | "import glob\n", 38 | "import SimpleITK as sitk\n", 39 | "\n", 40 | "import matplotlib.pyplot as plt\n", 41 | "import copy\n", 42 | "import sys\n", 43 | "sys.path.insert(0, '../../dataloaders/')\n", 44 | "import niftiio as nio" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 2, 50 | "metadata": {}, 51 | "outputs": [], 52 | "source": [ 53 | "IMG_FOLDER = \"./niis/T2SPIR\" #, path of nii-like images from step 1\n", 54 | "OUT_FOLDER=\"./chaos_MR_T2_normalized/\" # output directory" 55 | ] 56 | }, 57 | { 58 | "cell_type": "markdown", 59 | "metadata": {}, 60 | "source": [ 61 | "**0. Find images and their ground-truth segmentations**" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": 3, 67 | "metadata": {}, 68 | "outputs": [], 69 | "source": [ 70 | "imgs = glob.glob(IMG_FOLDER + f'/image_*.nii.gz')\n", 71 | "imgs = [ fid for fid in sorted(imgs) ]\n", 72 | "segs = [ fid for fid in sorted(glob.glob(IMG_FOLDER + f'/label_*.nii.gz')) ]\n", 73 | "\n", 74 | "pids = [pid.split(\"_\")[-1].split(\".\")[0] for pid in imgs]" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 4, 80 | "metadata": {}, 81 | "outputs": [ 82 | { 83 | "data": { 84 | "text/plain": [ 85 | "['./niis/T2SPIR/image_1.nii.gz',\n", 86 | " './niis/T2SPIR/image_10.nii.gz',\n", 87 | " './niis/T2SPIR/image_13.nii.gz',\n", 88 | " './niis/T2SPIR/image_15.nii.gz',\n", 89 | " './niis/T2SPIR/image_19.nii.gz',\n", 90 | " './niis/T2SPIR/image_2.nii.gz',\n", 91 | " './niis/T2SPIR/image_20.nii.gz',\n", 92 | " './niis/T2SPIR/image_21.nii.gz',\n", 93 | " './niis/T2SPIR/image_22.nii.gz',\n", 94 | " './niis/T2SPIR/image_3.nii.gz',\n", 95 | " './niis/T2SPIR/image_31.nii.gz',\n", 96 | " './niis/T2SPIR/image_32.nii.gz',\n", 97 | " './niis/T2SPIR/image_33.nii.gz',\n", 98 | " './niis/T2SPIR/image_34.nii.gz',\n", 99 | " './niis/T2SPIR/image_36.nii.gz',\n", 100 | " './niis/T2SPIR/image_37.nii.gz',\n", 101 | " './niis/T2SPIR/image_38.nii.gz',\n", 102 | " './niis/T2SPIR/image_39.nii.gz',\n", 103 | " './niis/T2SPIR/image_5.nii.gz',\n", 104 | " './niis/T2SPIR/image_8.nii.gz']" 105 | ] 106 | }, 107 | "execution_count": 4, 108 | "metadata": {}, 109 | "output_type": "execute_result" 110 | } 111 | ], 112 | "source": [ 113 | "imgs" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": 5, 119 | "metadata": {}, 120 | "outputs": [ 121 | { 122 | "data": { 123 | "text/plain": [ 124 | "['./niis/T2SPIR/label_1.nii.gz',\n", 125 | " './niis/T2SPIR/label_10.nii.gz',\n", 126 | " './niis/T2SPIR/label_13.nii.gz',\n", 127 | " './niis/T2SPIR/label_15.nii.gz',\n", 128 | " './niis/T2SPIR/label_19.nii.gz',\n", 129 | " './niis/T2SPIR/label_2.nii.gz',\n", 130 | " './niis/T2SPIR/label_20.nii.gz',\n", 131 | " './niis/T2SPIR/label_21.nii.gz',\n", 132 | " './niis/T2SPIR/label_22.nii.gz',\n", 133 | " './niis/T2SPIR/label_3.nii.gz',\n", 134 | " './niis/T2SPIR/label_31.nii.gz',\n", 135 | " './niis/T2SPIR/label_32.nii.gz',\n", 136 | " 
'./niis/T2SPIR/label_33.nii.gz',\n", 137 | " './niis/T2SPIR/label_34.nii.gz',\n", 138 | " './niis/T2SPIR/label_36.nii.gz',\n", 139 | " './niis/T2SPIR/label_37.nii.gz',\n", 140 | " './niis/T2SPIR/label_38.nii.gz',\n", 141 | " './niis/T2SPIR/label_39.nii.gz',\n", 142 | " './niis/T2SPIR/label_5.nii.gz',\n", 143 | " './niis/T2SPIR/label_8.nii.gz']" 144 | ] 145 | }, 146 | "execution_count": 5, 147 | "metadata": {}, 148 | "output_type": "execute_result" 149 | } 150 | ], 151 | "source": [ 152 | "segs" 153 | ] 154 | }, 155 | { 156 | "cell_type": "markdown", 157 | "metadata": {}, 158 | "source": [ 159 | "**1. Unify image sizes and roi**\n", 160 | "\n", 161 | "a. Cut bright end of histogram to alleviate off-resonance issue\n", 162 | "\n", 163 | "b. Resample images to unified spacing\n", 164 | "\n", 165 | "c. Crop ROIs out to unify image sizes" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": 6, 171 | "metadata": {}, 172 | "outputs": [], 173 | "source": [ 174 | "# some helper functions\n", 175 | "def resample_by_res(mov_img_obj, new_spacing, interpolator = sitk.sitkLinear, logging = True):\n", 176 | " resample = sitk.ResampleImageFilter()\n", 177 | " resample.SetInterpolator(interpolator)\n", 178 | " resample.SetOutputDirection(mov_img_obj.GetDirection())\n", 179 | " resample.SetOutputOrigin(mov_img_obj.GetOrigin())\n", 180 | " mov_spacing = mov_img_obj.GetSpacing()\n", 181 | "\n", 182 | " resample.SetOutputSpacing(new_spacing)\n", 183 | " RES_COE = np.array(mov_spacing) * 1.0 / np.array(new_spacing)\n", 184 | " new_size = np.array(mov_img_obj.GetSize()) * RES_COE \n", 185 | "\n", 186 | " resample.SetSize( [int(sz+1) for sz in new_size] )\n", 187 | " if logging:\n", 188 | " print(\"Spacing: {} -> {}\".format(mov_spacing, new_spacing))\n", 189 | " print(\"Size {} -> {}\".format( mov_img_obj.GetSize(), new_size ))\n", 190 | "\n", 191 | " return resample.Execute(mov_img_obj)\n", 192 | "\n", 193 | "def resample_lb_by_res(mov_lb_obj, new_spacing, interpolator = sitk.sitkLinear, ref_img = None, logging = True):\n", 194 | " src_mat = sitk.GetArrayFromImage(mov_lb_obj)\n", 195 | " lbvs = np.unique(src_mat)\n", 196 | " if logging:\n", 197 | " print(\"Label values: {}\".format(lbvs))\n", 198 | " for idx, lbv in enumerate(lbvs):\n", 199 | " _src_curr_mat = np.float32(src_mat == lbv) \n", 200 | " _src_curr_obj = sitk.GetImageFromArray(_src_curr_mat)\n", 201 | " _src_curr_obj.CopyInformation(mov_lb_obj)\n", 202 | " _tar_curr_obj = resample_by_res( _src_curr_obj, new_spacing, interpolator, logging )\n", 203 | " _tar_curr_mat = np.rint(sitk.GetArrayFromImage(_tar_curr_obj)) * lbv\n", 204 | " if idx == 0:\n", 205 | " out_vol = _tar_curr_mat\n", 206 | " else:\n", 207 | " out_vol[_tar_curr_mat == lbv] = lbv\n", 208 | " out_obj = sitk.GetImageFromArray(out_vol)\n", 209 | " out_obj.SetSpacing( _tar_curr_obj.GetSpacing() )\n", 210 | " if ref_img != None:\n", 211 | " out_obj.CopyInformation(ref_img)\n", 212 | " return out_obj\n", 213 | " \n", 214 | "def get_label_center(label):\n", 215 | " nnz = np.sum(label > 1e-5)\n", 216 | " return np.int32(np.rint(np.sum(np.nonzero(label), axis = 1) * 1.0 / nnz))\n", 217 | "\n", 218 | "def image_crop(ori_vol, crop_size, referece_ctr_idx, padval = 0., only_2d = True):\n", 219 | " \"\"\" crop a 3d matrix given the index of the new volume on the original volume\n", 220 | " Args:\n", 221 | " refernce_ctr_idx: the center of the new volume on the original volume (in indices)\n", 222 | " only_2d: only do cropping on first two dimensions\n", 223 | " \"\"\"\n", 224 
| " _expand_cropsize = [x + 1 for x in crop_size] # to deal with boundary case\n", 225 | " if only_2d:\n", 226 | " assert len(crop_size) == 2, \"Actual len {}\".format(len(crop_size))\n", 227 | " assert len(referece_ctr_idx) == 2, \"Actual len {}\".format(len(referece_ctr_idx))\n", 228 | " _expand_cropsize.append(ori_vol.shape[-1])\n", 229 | " \n", 230 | " image_patch = np.ones(tuple(_expand_cropsize)) * padval\n", 231 | "\n", 232 | " half_size = tuple( [int(x * 1.0 / 2) for x in _expand_cropsize] )\n", 233 | " _min_idx = [0,0,0]\n", 234 | " _max_idx = list(ori_vol.shape)\n", 235 | "\n", 236 | " # bias of actual cropped size to the beginning and the end of this volume\n", 237 | " _bias_start = [0,0,0]\n", 238 | " _bias_end = [0,0,0]\n", 239 | "\n", 240 | " for dim,hsize in enumerate(half_size):\n", 241 | " if dim == 2 and only_2d:\n", 242 | " break\n", 243 | "\n", 244 | " _bias_start[dim] = np.min([hsize, referece_ctr_idx[dim]])\n", 245 | " _bias_end[dim] = np.min([hsize, ori_vol.shape[dim] - referece_ctr_idx[dim]])\n", 246 | "\n", 247 | " _min_idx[dim] = referece_ctr_idx[dim] - _bias_start[dim]\n", 248 | " _max_idx[dim] = referece_ctr_idx[dim] + _bias_end[dim]\n", 249 | " \n", 250 | " if only_2d:\n", 251 | " image_patch[ half_size[0] - _bias_start[0]: half_size[0] +_bias_end[0], \\\n", 252 | " half_size[1] - _bias_start[1]: half_size[1] +_bias_end[1], ... ] = \\\n", 253 | " ori_vol[ referece_ctr_idx[0] - _bias_start[0]: referece_ctr_idx[0] +_bias_end[0], \\\n", 254 | " referece_ctr_idx[1] - _bias_start[1]: referece_ctr_idx[1] +_bias_end[1], ... ]\n", 255 | "\n", 256 | " image_patch = image_patch[ 0: crop_size[0], 0: crop_size[1], : ]\n", 257 | " # then goes back to original volume\n", 258 | " else:\n", 259 | " image_patch[ half_size[0] - _bias_start[0]: half_size[0] +_bias_end[0], \\\n", 260 | " half_size[1] - _bias_start[1]: half_size[1] +_bias_end[1], \\\n", 261 | " half_size[2] - _bias_start[2]: half_size[2] +_bias_end[2] ] = \\\n", 262 | " ori_vol[ referece_ctr_idx[0] - _bias_start[0]: referece_ctr_idx[0] +_bias_end[0], \\\n", 263 | " referece_ctr_idx[1] - _bias_start[1]: referece_ctr_idx[1] +_bias_end[1], \\\n", 264 | " referece_ctr_idx[2] - _bias_start[2]: referece_ctr_idx[2] +_bias_end[2] ]\n", 265 | "\n", 266 | " image_patch = image_patch[ 0: crop_size[0], 0: crop_size[1], 0: crop_size[2] ]\n", 267 | " return image_patch\n", 268 | "\n", 269 | "def copy_spacing_ori(src, dst):\n", 270 | " dst.SetSpacing(src.GetSpacing())\n", 271 | " dst.SetOrigin(src.GetOrigin())\n", 272 | " dst.SetDirection(src.GetDirection())\n", 273 | " return dst\n", 274 | "\n", 275 | "s2n = sitk.GetArrayFromImage" 276 | ] 277 | }, 278 | { 279 | "cell_type": "code", 280 | "execution_count": 7, 281 | "metadata": { 282 | "scrolled": false 283 | }, 284 | "outputs": [ 285 | { 286 | "name": "stdout", 287 | "output_type": "stream", 288 | "text": [ 289 | "Failed to create the output folder.\n" 290 | ] 291 | }, 292 | { 293 | "ename": "NameError", 294 | "evalue": "name 'copy_spacing_ori' is not defined", 295 | "output_type": "error", 296 | "traceback": [ 297 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 298 | "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", 299 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[0mhis_img_o\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0msitk\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mGetImageFromArray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 26\u001b[0;31m \u001b[0mhis_img_o\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcopy_spacing_ori\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimg_obj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhis_img_o\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 27\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[0;31m# resampling\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 300 | "\u001b[0;31mNameError\u001b[0m: name 'copy_spacing_ori' is not defined" 301 | ] 302 | } 303 | ], 304 | "source": [ 305 | "import copy\n", 306 | "try:\n", 307 | " os.mkdir(OUT_FOLDER)\n", 308 | "except:\n", 309 | " print(\"Failed to create the output folder.\")\n", 310 | " \n", 311 | "HIST_CUT_TOP = 0.5 # cut top 0.5% of intensity historgam to alleviate off-resonance effect\n", 312 | "\n", 313 | "NEW_SPA = [1.25, 1.25, 7.70] # unified voxel spacing\n", 314 | "\n", 315 | "for img_fid, seg_fid, pid in zip(imgs, segs, pids):\n", 316 | "\n", 317 | " lb_n = nio.read_nii_bysitk(seg_fid)\n", 318 | " resample_flg = True\n", 319 | "\n", 320 | " img_obj = sitk.ReadImage( img_fid )\n", 321 | " seg_obj = sitk.ReadImage( seg_fid )\n", 322 | "\n", 323 | " array = sitk.GetArrayFromImage(img_obj)\n", 324 | "\n", 325 | " # cut histogram\n", 326 | " hir = float(np.percentile(array, 100.0 - HIST_CUT_TOP))\n", 327 | " array[array > hir] = hir\n", 328 | "\n", 329 | " his_img_o = sitk.GetImageFromArray(array)\n", 330 | " his_img_o = copy_spacing_ori(img_obj, his_img_o)\n", 331 | "\n", 332 | " # resampling\n", 333 | " img_spa_ori = img_obj.GetSpacing()\n", 334 | " res_img_o = resample_by_res(his_img_o, [NEW_SPA[0], NEW_SPA[1], NEW_SPA[2]],\n", 335 | " interpolator = sitk.sitkLinear, logging = True)\n", 336 | "\n", 337 | "\n", 338 | "\n", 339 | " ## label\n", 340 | " lb_arr = sitk.GetArrayFromImage(seg_obj)\n", 341 | "\n", 342 | " # resampling\n", 343 | " res_lb_o = resample_lb_by_res(seg_obj, [NEW_SPA[0], NEW_SPA[1], NEW_SPA[2] ], interpolator = sitk.sitkLinear,\n", 344 | " ref_img = None, logging = True)\n", 345 | "\n", 346 | "\n", 347 | " # crop out rois\n", 348 | " res_img_a = s2n(res_img_o)\n", 349 | "\n", 350 | " crop_img_a = image_crop(res_img_a.transpose(1,2,0), [256, 256],\n", 351 | " referece_ctr_idx = [res_img_a.shape[1] // 2, res_img_a.shape[2] //2],\n", 352 | " padval = res_img_a.min(), only_2d = True).transpose(2,0,1)\n", 353 | "\n", 354 | " out_img_obj = copy_spacing_ori(res_img_o, sitk.GetImageFromArray(crop_img_a))\n", 355 | "\n", 356 | " res_lb_a = s2n(res_lb_o)\n", 357 | "\n", 358 | " crop_lb_a = image_crop(res_lb_a.transpose(1,2,0), [256, 256],\n", 359 | " referece_ctr_idx = [res_lb_a.shape[1] // 2, res_lb_a.shape[2] //2],\n", 360 | " padval = 0, only_2d = True).transpose(2,0,1)\n", 361 | "\n", 362 | " out_lb_obj = copy_spacing_ori(res_img_o, sitk.GetImageFromArray(crop_lb_a))\n", 363 | "\n", 364 | "\n", 365 | " out_img_fid = os.path.join( OUT_FOLDER, f'image_{pid}.nii.gz' )\n", 366 | " out_lb_fid = os.path.join( OUT_FOLDER, f'label_{pid}.nii.gz' ) \n", 367 | "\n", 368 | " # then save pre-processed images\n", 369 | " sitk.WriteImage(out_img_obj, out_img_fid, True) \n", 370 | " sitk.WriteImage(out_lb_obj, out_lb_fid, True) \n", 371 | " print(\"{} has been saved\".format(out_img_fid))" 372 | ] 373 | } 374 | ], 375 | "metadata": { 376 | "kernelspec": { 377 | "display_name": 
"Python 3", 378 | "language": "python", 379 | "name": "python3" 380 | }, 381 | "language_info": { 382 | "codemirror_mode": { 383 | "name": "ipython", 384 | "version": 3 385 | }, 386 | "file_extension": ".py", 387 | "mimetype": "text/x-python", 388 | "name": "python", 389 | "nbconvert_exporter": "python", 390 | "pygments_lexer": "ipython3", 391 | "version": "3.6.0" 392 | } 393 | }, 394 | "nbformat": 4, 395 | "nbformat_minor": 2 396 | } 397 | -------------------------------------------------------------------------------- /dataloaders/datasets_outside.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dataset for Training and Test 3 | Extended from ADNet code by Hansen et al. 4 | """ 5 | import torch 6 | from torch.utils.data import Dataset 7 | import torchvision.transforms as deftfx 8 | import glob 9 | import os 10 | import SimpleITK as sitk 11 | import random 12 | import numpy as np 13 | from . import image_transforms as myit 14 | from .dataset_specifics import * 15 | 16 | 17 | class TestDataset(Dataset): 18 | 19 | def __init__(self, args): 20 | 21 | # reading the paths 22 | if args['dataset'] == 'CMR': 23 | self.image_dirs = glob.glob(os.path.join(args['data_dir'], 'cmr_MR_normalized/image*')) 24 | elif args['dataset'] == 'CHAOST2': 25 | self.image_dirs = glob.glob(os.path.join(args['data_dir'], 'chaos_MR_T2_normalized/image*')) 26 | elif args['dataset'] == 'SABS': 27 | self.image_dirs = glob.glob(os.path.join(args['data_dir'], 'sabs_CT_normalized/image*')) 28 | 29 | self.image_dirs = sorted(self.image_dirs, key=lambda x: int(x.split('_')[-1].split('.nii.gz')[0])) 30 | 31 | # remove test fold! 32 | self.FOLD = get_folds(args['dataset']) 33 | self.image_dirs = [elem for idx, elem in enumerate(self.image_dirs) if idx in self.FOLD[args['eval_fold']]] 34 | 35 | # split into support/query 36 | idx = np.arange(len(self.image_dirs)) 37 | self.support_dir = self.image_dirs[idx[args['supp_idx']]] 38 | self.image_dirs.pop(idx[args['supp_idx']]) # remove support 39 | self.label = None 40 | 41 | def __len__(self): 42 | return len(self.image_dirs) 43 | 44 | def __getitem__(self, idx): 45 | 46 | img_path = self.image_dirs[idx] 47 | img = sitk.GetArrayFromImage(sitk.ReadImage(img_path)) 48 | img = (img - img.mean()) / img.std() 49 | img = np.stack(3 * [img], axis=1) 50 | 51 | lbl = sitk.GetArrayFromImage( 52 | sitk.ReadImage(img_path.split('image_')[0] + 'label_' + img_path.split('image_')[-1])) 53 | lbl[lbl == 200] = 1 54 | lbl[lbl == 500] = 2 55 | lbl[lbl == 600] = 3 56 | lbl = 1 * (lbl == self.label) 57 | 58 | sample = {'id': img_path} 59 | 60 | # Evaluation protocol. 61 | idx = lbl.sum(axis=(1, 2)) > 0 62 | sample['image'] = torch.from_numpy(img[idx]) 63 | sample['label'] = torch.from_numpy(lbl[idx]) 64 | 65 | return sample 66 | 67 | def get_support_index(self, n_shot, C): 68 | """ 69 | Selecting intervals according to Ouyang et al. 
70 | """ 71 | if n_shot == 1: 72 | pcts = [0.5] 73 | else: 74 | half_part = 1 / (n_shot * 2) 75 | part_interval = (1.0 - 1.0 / n_shot) / (n_shot - 1) 76 | pcts = [half_part + part_interval * ii for ii in range(n_shot)] 77 | 78 | return (np.array(pcts) * C).astype('int') 79 | 80 | def getSupport(self, label=None, all_slices=True, N=None): 81 | if label is None: 82 | raise ValueError('Need to specify label class!') 83 | 84 | img_path = self.support_dir 85 | img = sitk.GetArrayFromImage(sitk.ReadImage(img_path)) 86 | img = (img - img.mean()) / img.std() 87 | img = np.stack(3 * [img], axis=1) 88 | 89 | lbl = sitk.GetArrayFromImage( 90 | sitk.ReadImage(img_path.split('image_')[0] + 'label_' + img_path.split('image_')[-1])) 91 | lbl[lbl == 200] = 1 92 | lbl[lbl == 500] = 2 93 | lbl[lbl == 600] = 3 94 | lbl = 1 * (lbl == label) 95 | 96 | sample = {} 97 | if all_slices: 98 | sample['image'] = torch.from_numpy(img) 99 | sample['label'] = torch.from_numpy(lbl) 100 | else: 101 | # select N labeled slices 102 | if N is None: 103 | raise ValueError('Need to specify number of labeled slices!') 104 | idx = lbl.sum(axis=(1, 2)) > 0 105 | idx_ = self.get_support_index(N, idx.sum()) 106 | 107 | sample['image'] = torch.from_numpy(img[idx][idx_]) 108 | sample['label'] = torch.from_numpy(lbl[idx][idx_]) 109 | 110 | return sample 111 | 112 | 113 | class TrainDataset(Dataset): 114 | 115 | def __init__(self, args): 116 | self.n_shot = args['n_shot'] 117 | self.n_way = args['n_way'] 118 | self.n_query = args['n_query'] 119 | self.n_sv = args['n_sv'] 120 | self.max_iter = args['max_iter'] 121 | self.read = True # read images before get_item 122 | self.train_sampling = 'neighbors' 123 | self.min_size = args['min_size'] # 200 124 | self.test_label = args['test_label'] 125 | self.exclude_label = args['exclude_label'] 126 | self.use_gt = args['use_gt'] 127 | 128 | # reading the paths (leaving the reading of images into memory to __getitem__) 129 | if args['dataset'] == 'CMR': 130 | self.image_dirs = glob.glob(os.path.join(args['data_dir'], 'cmr_MR_normalized/image*')) 131 | self.label_dirs = glob.glob(os.path.join(args['data_dir'], 'cmr_MR_normalized/label*')) 132 | elif args['dataset'] == 'CHAOST2': 133 | self.image_dirs = glob.glob(os.path.join(args['data_dir'], 'chaos_MR_T2_normalized/image*')) 134 | self.label_dirs = glob.glob(os.path.join(args['data_dir'], 'chaos_MR_T2_normalized/label*')) 135 | elif args['dataset'] == 'SABS': 136 | self.image_dirs = glob.glob(os.path.join(args['data_dir'], 'sabs_CT_normalized/image*')) 137 | self.label_dirs = glob.glob(os.path.join(args['data_dir'], 'sabs_CT_normalized/label*')) 138 | 139 | self.image_dirs = sorted(self.image_dirs, key=lambda x: int(x.split('_')[-1].split('.nii.gz')[0])) 140 | self.label_dirs = sorted(self.label_dirs, key=lambda x: int(x.split('_')[-1].split('.nii.gz')[0])) 141 | self.sprvxl_dirs = glob.glob(os.path.join(args['data_dir'], 'supervoxels_' + str(args['n_sv']), 'super*')) 142 | self.sprvxl_dirs = sorted(self.sprvxl_dirs, key=lambda x: int(x.split('_')[-1].split('.nii.gz')[0])) 143 | 144 | # remove test fold! 
145 | self.FOLD = get_folds(args['dataset']) 146 | self.image_dirs = [elem for idx, elem in enumerate(self.image_dirs) if idx not in self.FOLD[args['eval_fold']]] 147 | self.label_dirs = [elem for idx, elem in enumerate(self.label_dirs) if idx not in self.FOLD[args['eval_fold']]] 148 | self.sprvxl_dirs = [elem for idx, elem in enumerate(self.sprvxl_dirs) if 149 | idx not in self.FOLD[args['eval_fold']]] 150 | 151 | # read images 152 | if self.read: 153 | self.images = {} 154 | self.labels = {} 155 | self.sprvxls = {} 156 | for image_dir, label_dir, sprvxl_dir in zip(self.image_dirs, self.label_dirs, self.sprvxl_dirs): 157 | self.images[image_dir] = sitk.GetArrayFromImage(sitk.ReadImage(image_dir)) 158 | self.labels[label_dir] = sitk.GetArrayFromImage(sitk.ReadImage(label_dir)) 159 | self.sprvxls[sprvxl_dir] = sitk.GetArrayFromImage(sitk.ReadImage(sprvxl_dir)) 160 | 161 | def __len__(self): 162 | return self.max_iter 163 | 164 | def gamma_tansform(self, img): 165 | gamma_range = (0.5, 1.5) 166 | gamma = np.random.rand() * (gamma_range[1] - gamma_range[0]) + gamma_range[0] 167 | cmin = img.min() 168 | irange = (img.max() - cmin + 1e-5) 169 | 170 | img = img - cmin + 1e-5 171 | img = irange * np.power(img * 1.0 / irange, gamma) 172 | img = img + cmin 173 | 174 | return img 175 | 176 | def geom_transform(self, img, mask): 177 | 178 | affine = {'rotate': 5, 'shift': (5, 5), 'shear': 5, 'scale': (0.9, 1.2)} 179 | alpha = 10 180 | sigma = 5 181 | order = 3 182 | 183 | tfx = [] 184 | tfx.append(myit.RandomAffine(affine.get('rotate'), 185 | affine.get('shift'), 186 | affine.get('shear'), 187 | affine.get('scale'), 188 | affine.get('scale_iso', True), 189 | order=order)) 190 | tfx.append(myit.ElasticTransform(alpha, sigma)) 191 | transform = deftfx.Compose(tfx) 192 | 193 | if len(img.shape) > 4: 194 | n_shot = img.shape[1] 195 | for shot in range(n_shot): 196 | cat = np.concatenate((img[0, shot], mask[:, shot])).transpose(1, 2, 0) 197 | cat = transform(cat).transpose(2, 0, 1) 198 | img[0, shot] = cat[:3, :, :] 199 | mask[:, shot] = np.rint(cat[3:, :, :]) 200 | 201 | else: 202 | for q in range(img.shape[0]): 203 | cat = np.concatenate((img[q], mask[q][None])).transpose(1, 2, 0) 204 | cat = transform(cat).transpose(2, 0, 1) 205 | img[q] = cat[:3, :, :] 206 | mask[q] = np.rint(cat[3:, :, :].squeeze()) 207 | 208 | return img, mask 209 | 210 | def __getitem__(self, idx): 211 | 212 | # sample patient idx 213 | pat_idx = random.choice(range(len(self.image_dirs))) 214 | 215 | if self.read: 216 | # get image/supervoxel volume from dictionary 217 | img = self.images[self.image_dirs[pat_idx]] 218 | gt = self.labels[self.label_dirs[pat_idx]] 219 | sprvxl = self.sprvxls[self.sprvxl_dirs[pat_idx]] 220 | else: 221 | # read image/supervoxel volume into memory 222 | img = sitk.GetArrayFromImage(sitk.ReadImage(self.image_dirs[pat_idx])) 223 | gt = sitk.GetArrayFromImage(sitk.ReadImage(self.label_dirs[pat_idx])) 224 | sprvxl = sitk.GetArrayFromImage(sitk.ReadImage(self.sprvxl_dirs[pat_idx])) 225 | 226 | if self.exclude_label is not None: # identify the slices containing test labels 227 | idx = np.arange(gt.shape[0]) 228 | exclude_idx = np.full(gt.shape[0], True, dtype=bool) 229 | for i in range(len(self.exclude_label)): 230 | exclude_idx = exclude_idx & (np.sum(gt == self.exclude_label[i], axis=(1, 2)) > 0) 231 | exclude_idx = idx[exclude_idx] 232 | else: 233 | exclude_idx = [] 234 | 235 | # normalize 236 | img = (img - img.mean()) / img.std() 237 | 238 | # chose training label 239 | if self.use_gt: 240 | lbl = 
gt.copy() 241 | else: 242 | lbl = sprvxl.copy() 243 | # lbl is label numpy 244 | 245 | # sample class(es) (gt/supervoxel) 246 | unique = list(np.unique(lbl)) 247 | unique.remove(0) 248 | if self.use_gt: 249 | unique = list(set(unique) - set(self.test_label)) 250 | # unique is type of list 251 | 252 | size = 0 253 | while size < self.min_size: 254 | n_slices = (self.n_shot * self.n_way) + self.n_query - 1 255 | while n_slices < ((self.n_shot * self.n_way) + self.n_query): 256 | 257 | cls_idx = random.choice(unique) # cls_idx is sampled class id 258 | 259 | # extract slices containing the sampled class 260 | sli_idx = np.sum(lbl == cls_idx, axis=(1, 2)) > 0 261 | idx = np.arange(lbl.shape[0]) 262 | sli_idx = idx[sli_idx] 263 | sli_idx = list(set(sli_idx) - set(np.intersect1d(sli_idx, exclude_idx))) # remove slices containing test labels 264 | n_slices = len(sli_idx) 265 | 266 | # generate possible subsets with successive slices (size = self.n_shot * self.n_way + self.n_query) 267 | subsets = [] 268 | for i in range(len(sli_idx)): 269 | if not subsets: 270 | subsets.append([sli_idx[i]]) 271 | elif sli_idx[i - 1] + 1 == sli_idx[i]: 272 | subsets[-1].append(sli_idx[i]) 273 | else: 274 | subsets.append([sli_idx[i]]) 275 | i = 0 276 | while i < len(subsets): 277 | if len(subsets[i]) < (self.n_shot * self.n_way + self.n_query): 278 | del subsets[i] 279 | else: 280 | i += 1 281 | if not len(subsets): 282 | return self.__getitem__(idx + np.random.randint(low=0, high=self.max_iter - 1, size=(1,))) 283 | 284 | # sample support and query slices 285 | i = random.choice(np.arange(len(subsets))) # subset index 286 | i = random.choice(subsets[i][:-(self.n_shot * self.n_way + self.n_query - 1)]) 287 | sample = np.arange(i, i + (self.n_shot * self.n_way) + self.n_query) 288 | 289 | lbl_cls = 1 * (lbl == cls_idx) 290 | 291 | size = max(np.sum(lbl_cls[sample[0]]), np.sum(lbl_cls[sample[1]])) 292 | 293 | # invert order 294 | if np.random.random(1) > 0.5: 295 | sample = sample[::-1] # successive slices (inverted) 296 | 297 | sup_lbl = lbl_cls[sample[:self.n_shot * self.n_way]][None,] # n_way * (n_shot * C) * H * W 298 | qry_lbl = lbl_cls[sample[self.n_shot * self.n_way:]] # n_qry * C * H * W 299 | 300 | sup_img = img[sample[:self.n_shot * self.n_way]][None,] # n_way * (n_shot * C) * H * W 301 | sup_img = np.stack((sup_img, sup_img, sup_img), axis=2) 302 | qry_img = img[sample[self.n_shot * self.n_way:]] # n_qry * C * H * W 303 | qry_img = np.stack((qry_img, qry_img, qry_img), axis=1) 304 | 305 | # Select outside classes and according images&labels 306 | size_outside = 0 307 | while size_outside < self.min_size: 308 | n_slices_outside = (self.n_shot * self.n_way) + self.n_query - 1 309 | while n_slices_outside < ((self.n_shot * self.n_way) + self.n_query): 310 | 311 | unique_outside = list(set(unique) - {cls_idx} - set(exclude_idx)) # remove slected and test classes 312 | cls_outside_idx = random.choice(unique_outside) 313 | 314 | # extract slices containing the outside class 315 | sli_idx_outside = np.sum(lbl == cls_outside_idx, axis=(1, 2)) > 0 316 | idx_outside = np.arange(lbl.shape[0]) 317 | sli_idx_outside = idx_outside[sli_idx_outside] 318 | sli_idx_outside = list(set(sli_idx_outside) - set(np.intersect1d(sli_idx_outside, exclude_idx))) 319 | n_slices_outside = len(sli_idx_outside) 320 | 321 | # generate possible subsets with outside slices (size = self.n_shot * self.n_way + self.n_query) 322 | subsets_outside = [] 323 | for i in range(len(sli_idx_outside)): 324 | if not subsets_outside: 325 | 
subsets_outside.append([sli_idx_outside[i]]) 326 | elif sli_idx_outside[i - 1] + 1 == sli_idx_outside[i]: 327 | subsets_outside[-1].append(sli_idx_outside[i]) 328 | else: 329 | subsets_outside.append([sli_idx_outside[i]]) 330 | i = 0 331 | while i < len(subsets_outside): 332 | if len(subsets_outside[i]) < (self.n_shot * self.n_way + self.n_query): 333 | del subsets_outside[i] 334 | else: 335 | i += 1 336 | if not len(subsets_outside): 337 | return self.__getitem__(idx + np.random.randint(low=0, high=self.max_iter - 1, size=(1,))) 338 | 339 | # sample outside support and query slices 340 | i_outside = random.choice(np.arange(len(subsets_outside))) 341 | i_outside = random.choice(subsets_outside[i_outside][:-(self.n_shot * self.n_way + self.n_query - 1)]) 342 | sample_outside = np.arange(i_outside, i_outside + (self.n_shot * self.n_way) + self.n_query) 343 | 344 | lbl_cls_outside = 1 * (lbl == cls_outside_idx) 345 | 346 | size_outside = max(np.sum(lbl_cls_outside[sample_outside[0]]), np.sum(lbl_cls_outside[sample_outside[1]])) 347 | 348 | # invert order 349 | if np.random.random(1) > 0.5: 350 | sample_outside = sample_outside[::-1] # successive outside slices (inverted) 351 | 352 | sup_lbl_outside = lbl_cls_outside[sample_outside[:self.n_shot * self.n_way]][None,] 353 | qry_lbl_outside = lbl_cls_outside[sample_outside[self.n_shot * self.n_way:]] 354 | 355 | sup_img_outside = img[sample_outside[:self.n_shot * self.n_way]][None,] 356 | sup_img_outside = np.stack((sup_img_outside, sup_img_outside, sup_img_outside), axis=2) 357 | qry_img_outside = img[sample_outside[self.n_shot * self.n_way:]] 358 | qry_img_outside = np.stack((qry_img_outside, qry_img_outside, qry_img_outside), axis=1) 359 | 360 | # gamma transform 361 | if np.random.random(1) > 0.5: 362 | qry_img = self.gamma_tansform(qry_img) 363 | else: 364 | sup_img = self.gamma_tansform(sup_img) 365 | 366 | # geom transform 367 | if np.random.random(1) > 0.5: 368 | qry_img, qry_lbl = self.geom_transform(qry_img, qry_lbl) 369 | else: 370 | sup_img, sup_lbl = self.geom_transform(sup_img, sup_lbl) 371 | 372 | # gamma transform for outside classes 373 | if np.random.random(1) > 0.5: 374 | qry_img_outside = self.gamma_tansform(qry_img_outside) 375 | else: 376 | sup_img_outside = self.gamma_tansform(sup_img_outside) 377 | 378 | # geom transform for outside classes 379 | if np.random.random(1) > 0.5: 380 | qry_img_outside, qry_lbl_outside = self.geom_transform(qry_img_outside, qry_lbl_outside) 381 | else: 382 | sup_img_outside, sup_lbl_outside = self.geom_transform(sup_img_outside, sup_lbl_outside) 383 | 384 | sample = {'support_images': sup_img, 385 | 'support_fg_labels': sup_lbl, 386 | 'query_images': qry_img, 387 | 'query_labels': qry_lbl, 388 | 'support_images_outside': sup_img_outside, 389 | 'support_outside_labels': sup_lbl_outside, 390 | 'query_images_outside': qry_img_outside, 391 | 'query_outside_labels': qry_lbl_outside, 392 | 'selected_class': cls_idx} 393 | 394 | return sample 395 | -------------------------------------------------------------------------------- /models/fewshot.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch.nn.parameter import Parameter 8 | from .encoder import Res101Encoder 9 | from .attention import MultiHeadAttention 10 | from .attention import MultiLayerPerceptron 11 | 12 | 13 | class FewShotSeg(nn.Module): 14 | 15 | def 
__init__(self, pretrained_weights="deeplabv3"): 16 | super().__init__() 17 | 18 | # Encoder 19 | self.encoder = Res101Encoder(replace_stride_with_dilation=[True, True, False], 20 | pretrained_weights=pretrained_weights) # or "resnet101" 21 | self.device = torch.device('cuda') 22 | self.scaler = 20.0 23 | self.criterion = nn.NLLLoss() 24 | self.criterion_MSE = nn.MSELoss() 25 | self.fg_sampler = np.random.RandomState(1289) 26 | self.fg_num = 100 # number of foreground partitions 27 | self.MHA = MultiHeadAttention(n_head=3, d_model=512, d_k=512, d_v=512) 28 | self.MLP = MultiLayerPerceptron(dim=512, mlp_dim=1024) 29 | self.layer_norm = nn.LayerNorm(512) 30 | 31 | def forward(self, supp_imgs, supp_mask, qry_imgs, qry_mask, train=False, t_loss_scaler=1, n_iters=20): 32 | """ 33 | Args: 34 | supp_imgs: support images 35 | way x shot x [B x 3 x H x W], list of lists of tensors 36 | fore_mask: foreground masks for support images 37 | way x shot x [B x H x W], list of lists of tensors 38 | back_mask: background masks for support images 39 | way x shot x [B x H x W], list of lists of tensors 40 | qry_imgs: query images 41 | N x [B x 3 x H x W], list of tensors 42 | """ 43 | 44 | self.n_ways = len(supp_imgs) 45 | self.n_shots = len(supp_imgs[0]) 46 | self.n_queries = len(qry_imgs) 47 | self.iter = 3 48 | assert self.n_ways == 1 # for now only one-way, because not every shot has multiple sub-images 49 | assert self.n_queries == 1 50 | 51 | qry_bs = qry_imgs[0].shape[0] 52 | supp_bs = supp_imgs[0][0].shape[0] 53 | img_size = supp_imgs[0][0].shape[-2:] 54 | supp_mask = torch.stack([torch.stack(way, dim=0) for way in supp_mask], 55 | dim=0).view(supp_bs, self.n_ways, self.n_shots, *img_size) # B x Wa x Sh x H x W 56 | 57 | # Dilate the mask 58 | kernel = np.ones((3, 3), np.uint8) 59 | supp_mask_ = supp_mask.cpu().numpy()[0][0][0] 60 | supp_dilated_mask = cv2.dilate(supp_mask_, kernel, iterations=1) # (256, 256) 61 | supp_periphery_mask = supp_dilated_mask - supp_mask_ 62 | supp_periphery_mask = np.reshape(supp_periphery_mask, (supp_bs, self.n_ways, self.n_shots, 63 | np.shape(supp_periphery_mask)[0], 64 | np.shape(supp_periphery_mask)[1])) 65 | supp_dilated_mask = np.reshape(supp_dilated_mask, (supp_bs, self.n_ways, self.n_shots, 66 | np.shape(supp_dilated_mask)[0], 67 | np.shape(supp_dilated_mask)[1])) 68 | supp_periphery_mask = torch.tensor(supp_periphery_mask).cuda() # (1, 1, 1, 256, 256) B x Wa x Sh x H x W 69 | supp_dilated_mask = torch.tensor(supp_dilated_mask).cuda() # (1, 1, 1, 256, 256) B x Wa x Sh x H x W 70 | 71 | # Extract features # 72 | imgs_concat = torch.cat([torch.cat(way, dim=0) for way in supp_imgs] 73 | + [torch.cat(qry_imgs, dim=0), ], dim=0) 74 | img_fts, tao = self.encoder(imgs_concat) 75 | 76 | supp_fts = img_fts[:self.n_ways * self.n_shots * supp_bs].view( # B x Wa x Sh x C x H' x W' 77 | supp_bs, self.n_ways, self.n_shots, -1, *img_fts.shape[-2:]) 78 | 79 | qry_fts = img_fts[self.n_ways * self.n_shots * supp_bs:].view( # B x N x C x H' x W' 80 | qry_bs, self.n_queries, -1, *img_fts.shape[-2:]) 81 | 82 | # Get threshold # 83 | self.t = tao[self.n_ways * self.n_shots * supp_bs:] # t for query features 84 | self.thresh_pred = [self.t for _ in range(self.n_ways)] 85 | 86 | self.t_ = tao[:self.n_ways * self.n_shots * supp_bs] # t for support features 87 | self.thresh_pred_ = [self.t_ for _ in range(self.n_ways)] 88 | 89 | # Compute loss # 90 | periphery_loss = torch.zeros(1).to(self.device) 91 | align_loss = torch.zeros(1).to(self.device) 92 | mse_loss = 
torch.zeros(1).to(self.device) 93 | qry_loss = torch.zeros(1).to(self.device) 94 | outputs = [] 95 | for epi in range(supp_bs): 96 | # Partition the foreground object into N parts, the coarse support prototypes 97 | fg_partition_prototypes = [[self.compute_multiple_prototypes( 98 | self.fg_num, supp_fts[[epi], way, shot], supp_mask[[epi], way, shot], self.fg_sampler) 99 | for shot in range(self.n_shots)] for way in range(self.n_ways)] 100 | 101 | # calculate coarse query prototype 102 | supp_fts_ = [[self.getFeatures(supp_fts[[epi], way, shot], supp_mask[[epi], way, shot]) 103 | for shot in range(self.n_shots)] for way in range(self.n_ways)] 104 | 105 | fg_prototypes = self.getPrototype(supp_fts_) # the coarse foreground 106 | 107 | # Dilated region prototypes # 108 | supp_fts_dilated = [[self.getFeatures(supp_fts[[epi], way, shot], supp_dilated_mask[[epi], way, shot]) 109 | for shot in range(self.n_shots)] for way in range(self.n_ways)] 110 | fg_prototypes_dilated = self.getPrototype(supp_fts_dilated) 111 | 112 | # Segment periphery region with support images 113 | supp_pred_object = torch.stack([self.getPred(supp_fts[epi][way], fg_prototypes[way], self.thresh_pred_[way]) 114 | for way in range(self.n_ways)], dim=1) # N x Wa x H' x W' 115 | supp_pred_object = F.interpolate(supp_pred_object, size=img_size, mode='bilinear', align_corners=True) 116 | # supp_pred_object: (1, 1, 256, 256) 117 | 118 | supp_pred_dilated = torch.stack([self.getPred(supp_fts[epi][way], fg_prototypes_dilated[way], self.thresh_pred_[way]) 119 | for way in range(self.n_ways)], dim=1) # N x Wa x H' x W' 120 | supp_pred_dilated = F.interpolate(supp_pred_dilated, size=img_size, mode='bilinear', align_corners=True) 121 | # supp_pred_dilated: (1, 1, 256, 256) 122 | 123 | # Prediction of periphery region 124 | pred_periphery = supp_pred_dilated - supp_pred_object 125 | pred_periphery = torch.cat((1.0 - pred_periphery, pred_periphery), dim=1) 126 | # pred_periphery: (1, 2, 256, 256) B x C x H x W 127 | label_periphery = torch.full_like(supp_periphery_mask[epi][0][0], 255, device=supp_periphery_mask.device) 128 | label_periphery[supp_periphery_mask[epi][0][0] == 1] = 1 129 | label_periphery[supp_periphery_mask[epi][0][0] == 0] = 0 130 | # label_periphery: (256, 256) H x W 131 | 132 | # Compute periphery loss 133 | eps_ = torch.finfo(torch.float32).eps 134 | log_prob_ = torch.log(torch.clamp(pred_periphery, eps_, 1 - eps_)) 135 | periphery_loss += self.criterion(log_prob_, label_periphery[None, ...].long()) / self.n_shots / self.n_ways 136 | 137 | qry_pred = torch.stack( 138 | [self.getPred(qry_fts[epi], fg_prototypes[way], self.thresh_pred[way]) 139 | for way in range(self.n_ways)], dim=1) # N x Wa x H' x W' 140 | 141 | qry_prototype_coarse = self.getFeatures(qry_fts[epi], qry_pred[epi]) 142 | 143 | # # The first BATE block 144 | for i in range(self.iter): 145 | fg_partition_prototypes = [[self.BATE(fg_partition_prototypes[way][shot][epi], qry_prototype_coarse) 146 | for shot in range(self.n_shots)] for way in range(self.n_ways)] 147 | 148 | supp_proto = [[torch.mean(fg_partition_prototypes[way][shot], dim=1) + fg_prototypes[way] for shot in range(self.n_shots)] 149 | for way in range(self.n_ways)] 150 | 151 | # CQPC module 152 | qry_pred_coarse = torch.stack( 153 | [self.getPred(qry_fts[epi], supp_proto[way][epi], self.thresh_pred[way]) 154 | for way in range(self.n_ways)], dim=1) 155 | 156 | qry_prototype_coarse = self.getFeatures(qry_fts[epi], qry_pred_coarse[epi]) 157 | 158 | # Get query predictions # 159 | 160 | 
qry_pred = torch.stack( 161 | [self.getPred(qry_fts[epi], supp_proto[way][epi], self.thresh_pred[way]) 162 | for way in range(self.n_ways)], dim=1) # N x Wa x H' x W' 163 | 164 | # Combine predictions of different feature maps # 165 | qry_pred_up = F.interpolate(qry_pred, size=img_size, mode='bilinear', align_corners=True) 166 | 167 | preds = torch.cat((1.0 - qry_pred_up, qry_pred_up), dim=1) 168 | 169 | outputs.append(preds) 170 | 171 | if train: 172 | align_loss_epi = self.alignLoss(supp_fts[epi], qry_fts[epi], preds, supp_mask[epi]) 173 | align_loss += align_loss_epi 174 | if train: 175 | proto_mse_loss_epi = self.proto_mse(qry_fts[epi], preds, supp_mask[epi], fg_prototypes) 176 | mse_loss += proto_mse_loss_epi 177 | if train: 178 | qry_fts_ = [[self.getFeatures(qry_fts[epi], qry_mask)]] 179 | qry_prototypes = self.getPrototype(qry_fts_) 180 | qry_pred = self.getPred(qry_fts[epi], qry_prototypes[epi], self.thresh_pred[epi]) 181 | 182 | qry_pred = F.interpolate(qry_pred[None, ...], size=img_size, mode='bilinear', align_corners=True) 183 | preds = torch.cat((1.0 - qry_pred, qry_pred), dim=1) 184 | 185 | qry_label = torch.full_like(qry_mask[epi], 255, device=qry_mask.device) 186 | qry_label[qry_mask[epi] == 1] = 1 187 | qry_label[qry_mask[epi] == 0] = 0 188 | 189 | # Compute Loss 190 | eps = torch.finfo(torch.float32).eps 191 | log_prob = torch.log(torch.clamp(preds, eps, 1 - eps)) 192 | qry_loss += self.criterion(log_prob, qry_label[None, ...].long()) / self.n_shots / self.n_ways 193 | 194 | output = torch.stack(outputs, dim=1) 195 | output = output.view(-1, *output.shape[2:]) 196 | 197 | return output, periphery_loss / supp_bs, align_loss / supp_bs, mse_loss / supp_bs, qry_loss / supp_bs 198 | 199 | def getPred(self, fts, prototype, thresh): 200 | """ 201 | Calculate the distance between features and prototypes 202 | 203 | Args: 204 | fts: input features 205 | expect shape: N x C x H x W 206 | prototype: prototype of one semantic class 207 | expect shape: 1 x C 208 | """ 209 | 210 | sim = -F.cosine_similarity(fts, prototype[..., None, None], dim=1) * self.scaler 211 | pred = 1.0 - torch.sigmoid(0.5 * (sim - thresh)) 212 | 213 | return pred 214 | 215 | def getFeatures(self, fts, mask): 216 | """ 217 | Extract foreground and background features via masked average pooling 218 | 219 | Args: 220 | fts: input features, expect shape: 1 x C x H' x W' 221 | mask: binary mask, expect shape: 1 x H x W 222 | """ 223 | 224 | fts = F.interpolate(fts, size=mask.shape[-2:], mode='bilinear') 225 | 226 | # masked fg features 227 | masked_fts = torch.sum(fts * mask[None, ...], dim=(-2, -1)) \ 228 | / (mask[None, ...].sum(dim=(-2, -1)) + 1e-5) # 1 x C 229 | 230 | return masked_fts 231 | 232 | def getFeatures_fg(self, fts, mask): 233 | """ 234 | Args: 235 | fts: input features, expect shape: 1 x C x H' x W' 236 | mask: binary mask, expect shape: 1 x H x W 237 | """ 238 | 239 | fts_ = fts.squeeze(0).permute(1, 2, 0) 240 | 241 | fts_ = fts_.view(fts_.size()[0] * fts_.size()[1], fts_.size()[2]) 242 | mask_ = F.interpolate(mask.unsqueeze(0), size=fts.shape[-2:], mode='bilinear') 243 | mask_ = mask_.view(-1) 244 | 245 | l = math.ceil(mask_.sum()) 246 | c = torch.argsort(mask_, descending=True, dim=0) 247 | fg = c[:l] 248 | 249 | fts_fg = fts_[fg] 250 | 251 | return fts_fg 252 | 253 | def getPrototype(self, fg_fts): 254 | """ 255 | Average the features to obtain the prototype 256 | 257 | Args: 258 | fg_fts: lists of list of foreground features for each way/shot 259 | expect shape: Wa x Sh x [1 x C] 260 | bg_fts: 
lists of list of background features for each way/shot 261 | expect shape: Wa x Sh x [1 x C] 262 | """ 263 | 264 | n_ways, n_shots = len(fg_fts), len(fg_fts[0]) 265 | fg_prototypes = [torch.sum(torch.cat([tr for tr in way], dim=0), dim=0, keepdim=True) / n_shots for way in 266 | fg_fts] ## concat all fg_fts 267 | 268 | return fg_prototypes 269 | 270 | def compute_multiple_prototypes(self, fg_num, sup_fts, sup_fg, sampler): 271 | """ 272 | 273 | Parameters 274 | ---------- 275 | fg_num: int 276 | Foreground partition numbers 277 | sup_fts: torch.Tensor 278 | [B, C, h, w], float32 279 | sup_fg: torch. Tensor 280 | [B, h, w], float32 (0,1) 281 | sampler: np.random.RandomState 282 | 283 | Returns 284 | ------- 285 | fg_proto: torch.Tensor 286 | [B, k, C], where k is the number of foreground proxies 287 | 288 | """ 289 | 290 | B, C, h, w = sup_fts.shape # B=1, C=512 291 | fg_mask = F.interpolate(sup_fg.unsqueeze(0), size=sup_fts.shape[-2:], mode='bilinear') 292 | fg_mask = fg_mask.squeeze(0).bool() # [B, h, w] --> bool 293 | batch_fg_protos = [] 294 | 295 | for b in range(B): 296 | fg_protos = [] 297 | 298 | fg_mask_i = fg_mask[b] # [h, w] 299 | 300 | # Check if zero 301 | with torch.no_grad(): 302 | if fg_mask_i.sum() < fg_num: 303 | fg_mask_i = fg_mask[b].clone() # don't change original mask 304 | fg_mask_i.view(-1)[:fg_num] = True 305 | 306 | # Iteratively select farthest points as centers of foreground local regions 307 | all_centers = [] 308 | first = True 309 | pts = torch.stack(torch.where(fg_mask_i), dim=1) 310 | for _ in range(fg_num): 311 | if first: 312 | i = sampler.choice(pts.shape[0]) 313 | first = False 314 | else: 315 | dist = pts.reshape(-1, 1, 2) - torch.stack(all_centers, dim=0).reshape(1, -1, 2) 316 | # choose the farthest point 317 | i = torch.argmax((dist ** 2).sum(-1).min(1)[0]) 318 | pt = pts[i] # center y, x 319 | all_centers.append(pt) 320 | 321 | # Assign fg labels for fg pixels 322 | dist = pts.reshape(-1, 1, 2) - torch.stack(all_centers, dim=0).reshape(1, -1, 2) 323 | fg_labels = torch.argmin((dist ** 2).sum(-1), dim=1) 324 | 325 | # Compute fg prototypes 326 | fg_feats = sup_fts[b].permute(1, 2, 0)[fg_mask_i] # [N, C] 327 | for i in range(fg_num): 328 | proto = fg_feats[fg_labels == i].mean(0) # [C] 329 | fg_protos.append(proto) 330 | 331 | fg_protos = torch.stack(fg_protos, dim=1) # [C, k] 332 | batch_fg_protos.append(fg_protos) 333 | fg_proto = torch.stack(batch_fg_protos, dim=0).transpose(1, 2) # [B, k, C] 334 | 335 | return fg_proto 336 | 337 | def BATE(self, fg_prototypes, qry_prototype_coarse): 338 | 339 | # S&W module 340 | A = torch.mm(fg_prototypes, qry_prototype_coarse.t()) 341 | kc = ((A.min() + A.mean()) / 2).floor() 342 | 343 | if A is not None: 344 | S = torch.zeros(A.size(), dtype=torch.float).cuda() 345 | S[A < kc] = -10000.0 346 | 347 | A = torch.softmax((A + S), dim=0) 348 | # fg_prototypes = A * fg_prototypes 349 | A = torch.mm(A, qry_prototype_coarse) 350 | A = self.layer_norm(A + fg_prototypes) 351 | 352 | # rest Transformer operation 353 | T = self.MHA(A.unsqueeze(0), A.unsqueeze(0), A.unsqueeze(0)) 354 | T = self.MLP(T) 355 | 356 | 357 | return T 358 | 359 | def alignLoss(self, supp_fts, qry_fts, pred, fore_mask): 360 | n_ways, n_shots = len(fore_mask), len(fore_mask[0]) 361 | 362 | # Get query mask 363 | pred_mask = pred.argmax(dim=1, keepdim=True).squeeze(1) # N x H' x W' 364 | binary_masks = [pred_mask == i for i in range(1 + n_ways)] 365 | skip_ways = [i for i in range(n_ways) if binary_masks[i + 1].sum() == 0] 366 | pred_mask = 
torch.stack(binary_masks, dim=0).float() # (1 + Wa) x N x H' x W' 367 | 368 | # Compute the support loss 369 | loss = torch.zeros(1).to(self.device) 370 | for way in range(n_ways): 371 | if way in skip_ways: 372 | continue 373 | # Get the query prototypes 374 | for shot in range(n_shots): 375 | # Get prototypes 376 | qry_fts_ = [self.getFeatures(qry_fts, pred_mask[way + 1])] 377 | fg_prototypes = self.getPrototype([qry_fts_]) 378 | 379 | # Get predictions 380 | supp_pred = self.getPred(supp_fts[way, [shot]], fg_prototypes[way], self.thresh_pred[way]) # N x Wa x H' x W' 381 | supp_pred = F.interpolate(supp_pred[None, ...], size=fore_mask.shape[-2:], mode='bilinear', 382 | align_corners=True) 383 | 384 | 385 | # Combine predictions of different feature maps 386 | preds = supp_pred 387 | pred_ups = torch.cat((1.0 - preds, preds), dim=1) 388 | 389 | # Construct the support Ground-Truth segmentation 390 | supp_label = torch.full_like(fore_mask[way, shot], 255, device=fore_mask.device) 391 | supp_label[fore_mask[way, shot] == 1] = 1 392 | supp_label[fore_mask[way, shot] == 0] = 0 393 | 394 | # Compute Loss 395 | eps = torch.finfo(torch.float32).eps 396 | log_prob = torch.log(torch.clamp(pred_ups, eps, 1 - eps)) 397 | loss += self.criterion(log_prob, supp_label[None, ...].long()) / n_shots / n_ways 398 | 399 | return loss 400 | 401 | def proto_mse(self, qry_fts, pred, fore_mask, supp_prototypes): 402 | n_ways, n_shots = len(fore_mask), len(fore_mask[0]) 403 | 404 | pred_mask = pred.argmax(dim=1, keepdim=True).squeeze(1) 405 | binary_masks = [pred_mask == i for i in range(1 + n_ways)] 406 | skip_ways = [i for i in range(n_ways) if binary_masks[i + 1].sum() == 0] 407 | pred_mask = torch.stack(binary_masks, dim=0).float() # (1 + Wa) x N x H' x W' 408 | 409 | # Compute the support loss 410 | loss_sim = torch.zeros(1).to(self.device) 411 | for way in range(n_ways): 412 | if way in skip_ways: 413 | continue 414 | # Get the query prototypes 415 | for shot in range(n_shots): 416 | # Get prototypes 417 | qry_fts_ = [[self.getFeatures(qry_fts, pred_mask[way + 1])]] 418 | 419 | fg_prototypes = self.getPrototype(qry_fts_) 420 | 421 | fg_prototypes_ = torch.sum(torch.stack(fg_prototypes, dim=0), dim=0) 422 | supp_prototypes_ = torch.sum(torch.stack(supp_prototypes, dim=0), dim=0) 423 | 424 | # Combine prototypes from different scales 425 | # fg_prototypes = self.alpha * fg_prototypes[way] 426 | # fg_prototypes = torch.sum(torch.stack(fg_prototypes, dim=0), dim=0) / torch.sum(self.alpha) 427 | # supp_prototypes_ = [self.alpha[n] * supp_prototypes[n][way] for n in range(len(supp_fts))] 428 | # supp_prototypes_ = torch.sum(torch.stack(supp_prototypes_, dim=0), dim=0) / torch.sum(self.alpha) 429 | 430 | # Compute the MSE loss 431 | 432 | loss_sim += self.criterion_MSE(fg_prototypes_, supp_prototypes_) 433 | 434 | return loss_sim 435 | 436 | 437 | 438 | 439 | 440 | 441 | 442 | --------------------------------------------------------------------------------
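For reference, below is a minimal sketch of how `FewShotSeg` in `models/fewshot.py` can be driven directly, following the shape comments in its `forward` docstring (one way, one shot, batch size 1, 256×256 slices). This is not part of the repository: the dummy tensors, the hand-drawn foreground square, and the assumption that `Res101Encoder` can load its ResNet-101 backbone on a CUDA device are illustrative only; in the actual pipeline these inputs are expected to be assembled by the training/testing scripts from the dataloader episodes.

```python
# Minimal usage sketch (not repository code) for FewShotSeg, assuming a CUDA device
# and that the ResNet-101 backbone weights are available to Res101Encoder.
import torch
from models.fewshot import FewShotSeg

model = FewShotSeg(pretrained_weights="deeplabv3").cuda()
model.eval()

H = W = 256
supp_imgs = [[torch.randn(1, 3, H, W).cuda()]]   # way x shot x [B x 3 x H x W]
supp_mask = [[torch.zeros(1, H, W).cuda()]]      # way x shot x [B x H x W], binary
supp_mask[0][0][:, 96:160, 96:160] = 1.0         # dummy foreground region
qry_imgs = [torch.randn(1, 3, H, W).cuda()]      # N x [B x 3 x H x W]
qry_mask = torch.zeros(1, H, W).cuda()           # only consumed when train=True

with torch.no_grad():
    preds, periphery_loss, align_loss, mse_loss, qry_loss = model(
        supp_imgs, supp_mask, qry_imgs, qry_mask, train=False)

print(preds.shape)  # expected: (1, 2, 256, 256) -- background/foreground scores
```

With `train=False` the four auxiliary losses come back as zero tensors and only the query prediction map is meaningful; passing `train=True` additionally requires a valid binary `qry_mask` so the alignment, prototype-MSE, and query losses can be computed.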