├── models
│   ├── __init__.py
│   ├── backbone
│   │   ├── __init__.py
│   │   └── torchvision_backbones.py
│   ├── alpmodule.py
│   └── grid_proto_fewshot.py
├── dataloaders
│   ├── __init__.py
│   ├── niftiio.py
│   ├── dataset_utils.py
│   ├── augutils.py
│   ├── common.py
│   ├── dev_customized_med.py
│   ├── image_transforms.py
│   ├── GenericSuperDatasetv2.py
│   └── ManualAnnoDatasetv2.py
├── util
│   ├── __init__.py
│   ├── utils.py
│   └── metric.py
├── intro.png
├── pigeon.jpg
├── requirements.txt
├── data
│   ├── CHAOST2
│   │   ├── dcm_img_to_nii.sh
│   │   ├── png_gth_to_nii.ipynb
│   │   ├── class_slice_index_gen.ipynb
│   │   └── image_normalize.ipynb
│   ├── SABS
│   │   ├── Synapse_abdominal_classmap.ipynb
│   │   └── intensity_normalization.ipynb
│   └── pseudolabel_gen.ipynb
├── LICENSE
├── examples
│   ├── train_ssl_abdominal_ct.sh
│   ├── train_ssl_abdominal_mri.sh
│   ├── test_ssl_abdominal_ct.sh
│   └── test_ssl_abdominal_mri.sh
├── .gitignore
├── README.md
├── config_ssl_upload.py
├── training.py
└── validation.py

/models/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/dataloaders/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/models/backbone/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/util/__init__.py:
--------------------------------------------------------------------------------
from .utils import *  # relative import, so that `import util` works under Python 3
--------------------------------------------------------------------------------
/intro.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cheng-01037/Self-supervised-Fewshot-Medical-Image-Segmentation/HEAD/intro.png
--------------------------------------------------------------------------------
/pigeon.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cheng-01037/Self-supervised-Fewshot-Medical-Image-Segmentation/HEAD/pigeon.jpg
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
dcm2nii
json5==0.8.5
jupyter==1.0.0
nibabel==2.5.1
numpy==1.22.0
opencv-python==4.2.0.32
Pillow>=8.1.1
sacred==0.7.5
scikit-image==0.14.0
SimpleITK==1.2.3
torch==1.3.0
torchvision==0.4.1
tqdm==4.32.2
--------------------------------------------------------------------------------
/data/CHAOST2/dcm_img_to_nii.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Convert dicom-like images to nii files in 3D
# This is the first step of image pre-processing

# Feed path to the downloaded data here
DATAPATH=./MR # put the CHAOS training fold (the one that contains the ground truth) here

# Feed path to the output folder here
OUTPATH=./niis

if [ ! -d $OUTPATH/T2SPIR ]
then
    mkdir -p $OUTPATH/T2SPIR
fi

for sid in $(ls "$DATAPATH")
do
    dcm2nii -o "$DATAPATH/$sid/T2SPIR" "$DATAPATH/$sid/T2SPIR/DICOM_anon";
    find "$DATAPATH/$sid/T2SPIR" -name "*.nii.gz" -exec mv {} "$OUTPATH/T2SPIR/image_$sid.nii.gz" \;
done;

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 Cheng

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/dataloaders/niftiio.py:
--------------------------------------------------------------------------------
"""
I/O utils for nii(.gz) files, through SimpleITK
"""
import numpy as np
import SimpleITK as sitk


def read_nii_bysitk(input_fid, peel_info = False):
    """ read nii to numpy through simpleitk
        peel_info: also return direction, origin, spacing and array size
    """
    img_obj = sitk.ReadImage(input_fid)
    img_np = sitk.GetArrayFromImage(img_obj)
    if peel_info:
        info_obj = {
            "spacing": img_obj.GetSpacing(),
            "origin": img_obj.GetOrigin(),
            "direction": img_obj.GetDirection(),
            "array_size": img_np.shape
        }
        return img_np, info_obj
    else:
        return img_np

def convert_to_sitk(input_mat, peeled_info):
    """
    write a numpy array to a sitk image object with essential meta-data
    """
    nii_obj = sitk.GetImageFromArray(input_mat)
    if peeled_info:
        nii_obj.SetSpacing(peeled_info["spacing"])
        nii_obj.SetOrigin(peeled_info["origin"])
        nii_obj.SetDirection(peeled_info["direction"])
    return nii_obj

def np2itk(img, ref_obj):
    """
    img: numpy array
    ref_obj: reference sitk object to copy meta-data from
    """
    itk_obj = sitk.GetImageFromArray(img)
    itk_obj.SetSpacing(ref_obj.GetSpacing())
    itk_obj.SetOrigin(ref_obj.GetOrigin())
    itk_obj.SetDirection(ref_obj.GetDirection())
    return itk_obj

--------------------------------------------------------------------------------
/models/backbone/torchvision_backbones.py:
--------------------------------------------------------------------------------
"""
Backbones supported by torchvision.
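Illustrative shapes (an assumption based on the default config, where inputs are 3 x 256 x 256):

    enc = TVDeeplabRes101Encoder(use_coco_init=True)
    high = enc(torch.randn(1, 3, 256, 256), low_level=False)  # -> [1, 256, 32, 32]

The dilated ResNet-101 has an output stride of 8, so a 256 x 256 input yields the
32 x 32 feature map that `feature_hw = [32, 32]` in config_ssl_upload.py expects.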
"""
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision

class TVDeeplabRes101Encoder(nn.Module):
    """
    FCN-Resnet101 backbone from torchvision deeplabv3
    No ASPP is used, as we found empirically that it hurts performance
    """
    def __init__(self, use_coco_init, aux_dim_keep = 64, use_aspp = False):
        super().__init__()
        _model = torchvision.models.segmentation.deeplabv3_resnet101(pretrained=use_coco_init, progress=True, num_classes=21, aux_loss=None)
        if use_coco_init:
            print("###### NETWORK: Using ms-coco initialization ######")
        else:
            print("###### NETWORK: Training from scratch ######")

        _model_list = list(_model.children())
        self.aux_dim_keep = aux_dim_keep
        self.backbone = _model_list[0]
        self.localconv = nn.Conv2d(2048, 256, kernel_size = 1, stride = 1, bias = False) # reduce feature map dimension
        self.asppconv = nn.Conv2d(256, 256, kernel_size = 1, bias = False)

        _aspp = _model_list[1][0]
        _conv256 = _model_list[1][1]
        self.aspp_out = nn.Sequential(*[_aspp, _conv256])
        self.use_aspp = use_aspp

    def forward(self, x_in, low_level):
        """
        Args:
            low_level: whether to also return aggregated low-level features from the FCN
        """
        fts = self.backbone(x_in)
        if self.use_aspp:
            fts256 = self.aspp_out(fts['out'])
            high_level_fts = fts256
        else:
            fts2048 = fts['out']
            high_level_fts = self.localconv(fts2048)

        if low_level:
            low_level_fts = fts['aux'][:, : self.aux_dim_keep]
            return high_level_fts, low_level_fts
        else:
            return high_level_fts

--------------------------------------------------------------------------------
/examples/train_ssl_abdominal_ct.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# train a model to segment abdominal CT
GPUID1=0
export CUDA_VISIBLE_DEVICES=$GPUID1

####### Shared configs ######
PROTO_GRID=8 # using 32 / 8 = 4, i.e. a 4-by-4 prototype pooling window during training
CPT="myexp"
DATASET='SABS_Superpix'
NWORKER=4

ALL_EV=( 0) # 5-fold cross validation (0, 1, 2, 3, 4)
ALL_SCALE=( "MIDDLE") # config of pseudolabels

### Use L/R kidney as testing classes
LABEL_SETS=0
EXCLU='[2,3]' # setting 2: exclude kidneys from the training set to test generalization capability, even though they are unlabeled. Use '[]' for setting 1 of Roy et al.

### Use liver and spleen as testing classes
# LABEL_SETS=1
# EXCLU='[1,6]'

###### Training configs ######
NSTEP=100100
DECAY=0.95

MAX_ITER=1000 # defines the size of an epoch
SNAPSHOT_INTERVAL=25000 # interval for saving snapshots
SEED='1234'

###### Validation configs ######
SUPP_ID='[6]' # using the additionally loaded scan as support

echo ===================================

for EVAL_FOLD in "${ALL_EV[@]}"
do
    for SUPERPIX_SCALE in "${ALL_SCALE[@]}"
    do
    PREFIX="train_${DATASET}_lbgroup${LABEL_SETS}_scale_${SUPERPIX_SCALE}_vfold${EVAL_FOLD}"
    echo $PREFIX
    LOGDIR="./exps/${CPT}_${SUPERPIX_SCALE}_${LABEL_SETS}"

    if [ !
-d $LOGDIR ] 45 | then 46 | mkdir $LOGDIR 47 | fi 48 | 49 | python3 training.py with \ 50 | 'modelname=dlfcn_res101' \ 51 | 'usealign=True' \ 52 | 'optim_type=sgd' \ 53 | num_workers=$NWORKER \ 54 | scan_per_load=-1 \ 55 | label_sets=$LABEL_SETS \ 56 | 'use_wce=True' \ 57 | exp_prefix=$PREFIX \ 58 | 'clsname=grid_proto' \ 59 | n_steps=$NSTEP \ 60 | exclude_cls_list=$EXCLU \ 61 | eval_fold=$EVAL_FOLD \ 62 | dataset=$DATASET \ 63 | proto_grid_size=$PROTO_GRID \ 64 | max_iters_per_load=$MAX_ITER \ 65 | min_fg_data=1 seed=$SEED \ 66 | save_snapshot_every=$SNAPSHOT_INTERVAL \ 67 | superpix_scale=$SUPERPIX_SCALE \ 68 | lr_step_gamma=$DECAY \ 69 | path.log_dir=$LOGDIR \ 70 | support_idx=$SUPP_ID 71 | done 72 | done 73 | -------------------------------------------------------------------------------- /examples/train_ssl_abdominal_mri.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # train a model to segment abdominal MRI (T2 fold of CHAOS challenge) 3 | GPUID1=0 4 | export CUDA_VISIBLE_DEVICES=$GPUID1 5 | 6 | ####### Shared configs 7 | PROTO_GRID=8 # using 32 / 8 = 4, 4-by-4 prototype pooling window during training 8 | CPT="myexperiments" 9 | DATASET='CHAOST2_Superpix' 10 | NWORKER=4 11 | 12 | ALL_EV=( 0) # 5-fold cross validation (0, 1, 2, 3, 4) 13 | ALL_SCALE=( "MIDDLE") # config of pseudolabels 14 | 15 | ### Use L/R kidney as testing classes 16 | LABEL_SETS=0 17 | EXCLU='[2,3]' # setting 2: excluding kidneies in training set to test generalization capability even though they are unlabeled. Use [] for setting 1 by Roy et al. 18 | 19 | ### Use Liver and spleen as testing classes 20 | # LABEL_SETS=1 21 | # EXCLU='[1,4]' 22 | 23 | ###### Training configs ###### 24 | NSTEP=100100 25 | DECAY=0.95 26 | 27 | MAX_ITER=1000 # defines the size of an epoch 28 | SNAPSHOT_INTERVAL=25000 # interval for saving snapshot 29 | SEED='1234' 30 | 31 | ###### Validation configs ###### 32 | SUPP_ID='[4]' # # using the additionally loaded scan as support 33 | 34 | echo =================================== 35 | 36 | for EVAL_FOLD in "${ALL_EV[@]}" 37 | do 38 | for SUPERPIX_SCALE in "${ALL_SCALE[@]}" 39 | do 40 | PREFIX="train_${DATASET}_lbgroup${LABEL_SETS}_scale_${SUPERPIX_SCALE}_vfold${EVAL_FOLD}" 41 | echo $PREFIX 42 | LOGDIR="./exps/${CPT}_${SUPERPIX_SCALE}_${LABEL_SETS}" 43 | 44 | if [ ! 
-d $LOGDIR ] 45 | then 46 | mkdir $LOGDIR 47 | fi 48 | 49 | python3 training.py with \ 50 | 'modelname=dlfcn_res101' \ 51 | 'usealign=True' \ 52 | 'optim_type=sgd' \ 53 | num_workers=$NWORKER \ 54 | scan_per_load=-1 \ 55 | label_sets=$LABEL_SETS \ 56 | 'use_wce=True' \ 57 | exp_prefix=$PREFIX \ 58 | 'clsname=grid_proto' \ 59 | n_steps=$NSTEP \ 60 | exclude_cls_list=$EXCLU \ 61 | eval_fold=$EVAL_FOLD \ 62 | dataset=$DATASET \ 63 | proto_grid_size=$PROTO_GRID \ 64 | max_iters_per_load=$MAX_ITER \ 65 | min_fg_data=1 seed=$SEED \ 66 | save_snapshot_every=$SNAPSHOT_INTERVAL \ 67 | superpix_scale=$SUPERPIX_SCALE \ 68 | lr_step_gamma=$DECAY \ 69 | path.log_dir=$LOGDIR \ 70 | support_idx=$SUPP_ID 71 | done 72 | done 73 | -------------------------------------------------------------------------------- /examples/test_ssl_abdominal_ct.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # train a model to segment abdominal CT 3 | GPUID1=0 4 | export CUDA_VISIBLE_DEVICES=$GPUID1 5 | 6 | ####### Shared configs ###### 7 | PROTO_GRID=8 # using 32 / 8 = 4, 4-by-4 prototype pooling window during training 8 | CPT="myexp" 9 | DATASET='SABS_Superpix' 10 | NWORKER=4 11 | 12 | ALL_EV=( 0) # 5-fold cross validation (0, 1, 2, 3, 4) 13 | ALL_SCALE=( "MIDDLE") # config of pseudolabels 14 | 15 | ### Use L/R kidney as testing classes 16 | LABEL_SETS=0 17 | EXCLU='[2,3]' # setting 2: excluding kidneies in training set to test generalization capability even though they are unlabeled. Use [] for setting 1 by Roy et al. 18 | 19 | ### Use Liver and spleen as testing classes 20 | # LABEL_SETS=1 21 | # EXCLU='[1,6]' 22 | 23 | ###### Training configs (irrelavent in testing) ###### 24 | NSTEP=100100 25 | DECAY=0.95 26 | 27 | MAX_ITER=1000 # defines the size of an epoch 28 | SNAPSHOT_INTERVAL=25000 # interval for saving snapshot 29 | SEED='1234' 30 | 31 | ###### Validation configs ###### 32 | SUPP_ID='[6]' # using the additionally loaded scan as support 33 | 34 | echo =================================== 35 | 36 | for EVAL_FOLD in "${ALL_EV[@]}" 37 | do 38 | for SUPERPIX_SCALE in "${ALL_SCALE[@]}" 39 | do 40 | PREFIX="test_vfold${EVAL_FOLD}" 41 | echo $PREFIX 42 | LOGDIR="./exps/${CPT}" 43 | 44 | if [ ! 
-d $LOGDIR ] 45 | then 46 | mkdir $LOGDIR 47 | fi 48 | 49 | RELOAD_PATH='please feed the path to the trained weights here' # path to the reloaded model 50 | 51 | python3 validation.py with \ 52 | 'modelname=dlfcn_res101' \ 53 | 'usealign=True' \ 54 | 'optim_type=sgd' \ 55 | reload_model_path=$RELOAD_PATH \ 56 | num_workers=$NWORKER \ 57 | scan_per_load=-1 \ 58 | label_sets=$LABEL_SETS \ 59 | 'use_wce=True' \ 60 | exp_prefix=$PREFIX \ 61 | 'clsname=grid_proto' \ 62 | n_steps=$NSTEP \ 63 | exclude_cls_list=$EXCLU \ 64 | eval_fold=$EVAL_FOLD \ 65 | dataset=$DATASET \ 66 | proto_grid_size=$PROTO_GRID \ 67 | max_iters_per_load=$MAX_ITER \ 68 | min_fg_data=1 seed=$SEED \ 69 | save_snapshot_every=$SNAPSHOT_INTERVAL \ 70 | superpix_scale=$SUPERPIX_SCALE \ 71 | lr_step_gamma=$DECAY \ 72 | path.log_dir=$LOGDIR \ 73 | support_idx=$SUPP_ID 74 | done 75 | done 76 | -------------------------------------------------------------------------------- /examples/test_ssl_abdominal_mri.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # train a model to segment abdominal MRI 3 | GPUID1=0 4 | export CUDA_VISIBLE_DEVICES=$GPUID1 5 | 6 | ####### Shared configs ###### 7 | PROTO_GRID=8 # using 32 / 8 = 4, 4-by-4 prototype pooling window during training 8 | CPT="myexp" 9 | DATASET='CHAOST2_Superpix' 10 | NWORKER=4 11 | 12 | ALL_EV=( 0) # 5-fold cross validation (0, 1, 2, 3, 4) 13 | ALL_SCALE=( "MIDDLE") # config of pseudolabels 14 | 15 | ### Use L/R kidney as testing classes 16 | LABEL_SETS=0 17 | EXCLU='[2,3]' # setting 2: excluding kidneies in training set to test generalization capability even though they are unlabeled. Use [] for setting 1 by Roy et al. 18 | 19 | ### Use Liver and spleen as testing classes 20 | # LABEL_SETS=1 21 | # EXCLU='[1,4]' 22 | 23 | ###### Training configs (irrelavent in testing) ###### 24 | NSTEP=100100 25 | DECAY=0.95 26 | 27 | MAX_ITER=1000 # defines the size of an epoch 28 | SNAPSHOT_INTERVAL=25000 # interval for saving snapshot 29 | SEED='1234' 30 | 31 | ###### Validation configs ###### 32 | SUPP_ID='[4]' # using the additionally loaded scan as support 33 | 34 | echo =================================== 35 | 36 | for EVAL_FOLD in "${ALL_EV[@]}" 37 | do 38 | for SUPERPIX_SCALE in "${ALL_SCALE[@]}" 39 | do 40 | PREFIX="test_vfold${EVAL_FOLD}" 41 | echo $PREFIX 42 | LOGDIR="./exps/${CPT}" 43 | 44 | if [ ! 
-d $LOGDIR ] 45 | then 46 | mkdir $LOGDIR 47 | fi 48 | 49 | RELOAD_PATH='please feed the path to the trained weights here' # path to the reloaded model 50 | 51 | python3 validation.py with \ 52 | 'modelname=dlfcn_res101' \ 53 | 'usealign=True' \ 54 | 'optim_type=sgd' \ 55 | reload_model_path=$RELOAD_PATH \ 56 | num_workers=$NWORKER \ 57 | scan_per_load=-1 \ 58 | label_sets=$LABEL_SETS \ 59 | 'use_wce=True' \ 60 | exp_prefix=$PREFIX \ 61 | 'clsname=grid_proto' \ 62 | n_steps=$NSTEP \ 63 | exclude_cls_list=$EXCLU \ 64 | eval_fold=$EVAL_FOLD \ 65 | dataset=$DATASET \ 66 | proto_grid_size=$PROTO_GRID \ 67 | max_iters_per_load=$MAX_ITER \ 68 | min_fg_data=1 seed=$SEED \ 69 | save_snapshot_every=$SNAPSHOT_INTERVAL \ 70 | superpix_scale=$SUPERPIX_SCALE \ 71 | lr_step_gamma=$DECAY \ 72 | path.log_dir=$LOGDIR \ 73 | support_idx=$SUPP_ID 74 | done 75 | done 76 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /util/utils.py: -------------------------------------------------------------------------------- 1 | """Util functions 2 | Extended from original PANet code 3 | TODO: move part of dataset configurations to data_utils 4 | """ 5 | import random 6 | import torch 7 | import numpy as np 8 | import operator 9 | 10 | def set_seed(seed): 11 | """ 12 | Set the random seed 13 | """ 14 | random.seed(seed) 15 | torch.manual_seed(seed) 16 | torch.cuda.manual_seed_all(seed) 17 | 18 | CLASS_LABELS = { 19 | 'SABS': { 20 | 'pa_all': set( [1,2,3,6] ), 21 | 0: set([1,6] ), # upper_abdomen: spleen + liver as training, kidneis are testing 22 | 1: set( [2,3] ), # lower_abdomen 23 | }, 24 | 'C0': { 25 | 'pa_all': set(range(1, 4)), 26 | 0: set([2,3]), 27 | 1: set([1,3]), 28 | 2: set([1,2]), 29 | }, 30 | 'CHAOST2': { 31 | 'pa_all': set(range(1, 5)), 32 | 0: set([1, 4]), # upper_abdomen, leaving kidneies as testing classes 33 | 1: set([2, 3]), # lower_abdomen 34 | }, 35 | } 36 | 37 | def get_bbox(fg_mask, inst_mask): 38 | """ 39 | Get the ground truth bounding boxes 40 | """ 41 | 42 | fg_bbox = torch.zeros_like(fg_mask, device=fg_mask.device) 43 | bg_bbox = torch.ones_like(fg_mask, device=fg_mask.device) 44 | 45 | inst_mask[fg_mask == 0] = 0 46 | area = torch.bincount(inst_mask.view(-1)) 47 | cls_id = area[1:].argmax() + 1 48 | cls_ids = np.unique(inst_mask)[1:] 49 | 50 | mask_idx = np.where(inst_mask[0] == cls_id) 51 | y_min = mask_idx[0].min() 52 | y_max = mask_idx[0].max() 53 | x_min = mask_idx[1].min() 54 | x_max = mask_idx[1].max() 55 | fg_bbox[0, y_min:y_max+1, x_min:x_max+1] = 1 56 | 57 | for i in cls_ids: 58 | mask_idx = np.where(inst_mask[0] == i) 59 | y_min = max(mask_idx[0].min(), 0) 60 | y_max = min(mask_idx[0].max(), fg_mask.shape[1] - 1) 61 | x_min = max(mask_idx[1].min(), 0) 62 | x_max = min(mask_idx[1].max(), fg_mask.shape[2] - 1) 63 | bg_bbox[0, y_min:y_max+1, x_min:x_max+1] = 0 64 | return fg_bbox, bg_bbox 65 | 66 | def t2n(img_t): 67 | """ 68 | torch to numpy regardless of whether tensor is on gpu or memory 69 | """ 70 | if img_t.is_cuda: 71 | return img_t.data.cpu().numpy() 72 | else: 73 | return img_t.data.numpy() 74 | 75 | def to01(x_np): 76 | """ 77 | normalize a numpy to 0-1 for visualize 78 | """ 79 | return (x_np - x_np.min()) / (x_np.max() - x_np.min() + 1e-5) 80 | 81 | def compose_wt_simple(is_wce, data_name): 82 | """ 83 | Weights for cross-entropy loss 84 | """ 85 | if is_wce: 86 | if data_name in ['SABS', 'SABS_Superpix', 'C0', 'C0_Superpix', 'CHAOST2', 'CHAOST2_Superpix']: 87 | return torch.FloatTensor([0.05, 1.0]).cuda() 88 | else: 89 | raise NotImplementedError 90 | else: 91 | return torch.FloatTensor([1.0, 1.0]).cuda() 92 | 93 | 94 | class CircularList(list): 95 | """ 96 | Helper for spliting training and validation scans 97 | Originally: https://stackoverflow.com/questions/8951020/pythonic-circular-list/8951224 98 | 
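    Example (illustrative): integer indices and slices wrap modulo the list length,
        cl = CircularList(range(5))
        cl[6]    # -> 1
        cl[3:7]  # -> [3, 4, 0, 1]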
""" 99 | def __getitem__(self, x): 100 | if isinstance(x, slice): 101 | return [self[x] for x in self._rangeify(x)] 102 | 103 | index = operator.index(x) 104 | try: 105 | return super().__getitem__(index % len(self)) 106 | except ZeroDivisionError: 107 | raise IndexError('list index out of range') 108 | 109 | def _rangeify(self, slice): 110 | start, stop, step = slice.start, slice.stop, slice.step 111 | if start is None: 112 | start = 0 113 | if stop is None: 114 | stop = len(self) 115 | if step is None: 116 | step = 1 117 | return range(start, stop, step) 118 | 119 | -------------------------------------------------------------------------------- /dataloaders/dataset_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utils for datasets 3 | """ 4 | import numpy as np 5 | 6 | import os 7 | import sys 8 | import nibabel as nib 9 | import numpy as np 10 | import pdb 11 | import SimpleITK as sitk 12 | 13 | DATASET_INFO = { 14 | "CHAOST2": { 15 | 'PSEU_LABEL_NAME': ["BGD", "SUPFG"], 16 | 'REAL_LABEL_NAME': ["BG", "LIVER", "RK", "LK", "SPLEEN"], 17 | '_SEP': [0, 4, 8, 12, 16, 20], 18 | 'MODALITY': 'MR', 19 | 'LABEL_GROUP': { 20 | 'pa_all': set(range(1, 5)), 21 | 0: set([1, 4]), # upper_abdomen, leaving kidneies as testing classes 22 | 1: set([2, 3]), # lower_abdomen 23 | }, 24 | }, 25 | 26 | "SABS": { 27 | 'PSEU_LABEL_NAME': ["BGD", "SUPFG"], 28 | 29 | 'REAL_LABEL_NAME': ["BGD", "SPLEEN", "KID_R", "KID_l", "GALLBLADDER", "ESOPHAGUS", "LIVER", "STOMACH", "AORTA", "IVC",\ 30 | "PS_VEIN", "PANCREAS", "AG_R", "AG_L"], 31 | '_SEP': [0, 6, 12, 18, 24, 30], 32 | 'MODALITY': 'CT', 33 | 'LABEL_GROUP':{ 34 | 'pa_all': set( [1,2,3,6] ), 35 | 0: set([1,6] ), # upper_abdomen: spleen + liver as training, kidneis are testing 36 | 1: set( [2,3] ), # lower_abdomen 37 | } 38 | } 39 | 40 | } 41 | 42 | def read_nii_bysitk(input_fid, peel_info = False): 43 | """ read nii to numpy through simpleitk 44 | 45 | peelinfo: taking direction, origin, spacing and metadata out 46 | """ 47 | img_obj = sitk.ReadImage(input_fid) 48 | img_np = sitk.GetArrayFromImage(img_obj) 49 | if peel_info: 50 | info_obj = { 51 | "spacing": img_obj.GetSpacing(), 52 | "origin": img_obj.GetOrigin(), 53 | "direction": img_obj.GetDirection(), 54 | "array_size": img_np.shape 55 | } 56 | return img_np, info_obj 57 | else: 58 | return img_np 59 | 60 | def get_normalize_op(modality, fids): 61 | """ 62 | As title 63 | Args: 64 | modality: CT or MR 65 | fids: fids for the fold 66 | """ 67 | 68 | def get_CT_statistics(scan_fids): 69 | """ 70 | As CT are quantitative, get mean and std for CT images for image normalizing 71 | As in reality we might not be able to load all images at a time, we would better detach statistics calculation with actual data loading 72 | """ 73 | total_val = 0 74 | n_pix = 0 75 | for fid in scan_fids: 76 | in_img = read_nii_bysitk(fid) 77 | total_val += in_img.sum() 78 | n_pix += np.prod(in_img.shape) 79 | del in_img 80 | meanval = total_val / n_pix 81 | 82 | total_var = 0 83 | for fid in scan_fids: 84 | in_img = read_nii_bysitk(fid) 85 | total_var += np.sum((in_img - meanval) ** 2 ) 86 | del in_img 87 | var_all = total_var / n_pix 88 | 89 | global_std = var_all ** 0.5 90 | 91 | return meanval, global_std 92 | 93 | if modality == 'MR': 94 | 95 | def MR_normalize(x_in): 96 | return (x_in - x_in.mean()) / x_in.std() 97 | 98 | return MR_normalize #, {'mean': None, 'std': None} # we do not really need the global statistics for MR 99 | 100 | elif modality == 'CT': 101 | ct_mean, 
ct_std = get_CT_statistics(fids)
        # debug
        print(f'###### DEBUG_DATASET CT_STATS NORMALIZED MEAN {ct_mean / 255} STD {ct_std / 255} ######')

        def CT_normalize(x_in):
            """
            Normalize CT images based on the global statistics
            """
            return (x_in - ct_mean) / ct_std

        return CT_normalize #, {'mean': ct_mean, 'std': ct_std}

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# SSL_ALPNet: self-supervised few-shot / in-context medical image segmentation

[ECCV'20] [Self-supervision with Superpixels: Training Few-shot Medical Image Segmentation without Annotation](https://arxiv.org/abs/2007.09886v2)

![](./intro.png)

**Abstract**:

Few-shot semantic segmentation (FSS, also known as in-context segmentation since the MLLM era) has great potential for medical imaging applications. Most of the existing FSS techniques require abundant annotated semantic classes for training. However, these methods may not be applicable to medical images due to the lack of annotations. To address this problem, we make several contributions: (1) a novel self-supervised FSS framework for medical images, which eliminates the requirement for annotations during training; superpixel-based pseudo-labels are generated to provide supervision instead; (2) an adaptive local prototype pooling module, plugged into prototypical networks, to solve the common and challenging foreground-background imbalance problem in medical image segmentation; (3) a demonstration of the general applicability of the proposed approach using three different tasks: abdominal organ segmentation for CT and MRI, as well as cardiac segmentation for MRI. Our results show that, for medical image segmentation, the proposed method outperforms conventional FSS methods that require manual annotations for training.

**NOTE: We are actively updating this repository**

If you find this code base useful, please cite our paper. Thanks!

```
@article{ouyang2020self,
  title={Self-Supervision with Superpixels: Training Few-shot Medical Image Segmentation without Annotation},
  author={Ouyang, Cheng and Biffi, Carlo and Chen, Chen and Kart, Turkay and Qiu, Huaqi and Rueckert, Daniel},
  journal={arXiv preprint arXiv:2007.09886},
  year={2020}
}
```

### 1. Dependencies

Please install the essential dependencies (see `requirements.txt`)

```
dcm2nii
json5==0.8.5
jupyter==1.0.0
nibabel==2.5.1
numpy==1.22.0
opencv-python==4.2.0.32
Pillow>=8.1.1
sacred==0.7.5
scikit-image==0.14.0
SimpleITK==1.2.3
torch==1.3.0
torchvision==0.4.1
tqdm==4.32.2
```

### 2. Data pre-processing

**Abdominal MRI**

0. Download the [Combined Healthy Abdominal Organ Segmentation dataset](https://chaos.grand-challenge.org/) and put the `/MR` folder under the `./data/CHAOST2/` directory

1. Convert the downloaded data (T2 fold) to 3D `nii` files for ease of reading

run `./data/CHAOST2/dcm_img_to_nii.sh` to convert dicom images to nifti files.

run `./data/CHAOST2/png_gth_to_nii.ipynb` to convert ground truth in `png` format to nifti.

2. Pre-process the downloaded images

run `./data/CHAOST2/image_normalize.ipynb`
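For reference, the MRI steps above can be chained from the repository root roughly as follows (headless notebook execution via `nbconvert` is an assumption here; running the notebooks interactively works just as well):

```
cd ./data/CHAOST2
bash ./dcm_img_to_nii.sh                                           # step 1: dicom images -> nii
jupyter nbconvert --to notebook --execute ./png_gth_to_nii.ipynb   # step 1: png labels -> nii
jupyter nbconvert --to notebook --execute ./image_normalize.ipynb  # step 2: normalization
```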
**Abdominal CT**

0. Download the [Synapse Multi-atlas Abdominal Segmentation dataset](https://www.synapse.org/#!Synapse:syn3193805/wiki/217789) and put the `/img` and `/label` folders under the `./data/SABS/` directory

1. Intensity windowing

run `./data/SABS/intensity_normalization.ipynb` to apply the abdominal window.

2. Crop irrelevant empty background and resample images

run `./data/SABS/resampling_and_roi.ipynb`

**Shared steps**

3. Build class-slice indexing for setting up experiments

run `./data/class_slice_index_gen.ipynb`

`
You are highly welcome to use this pre-processing pipeline in your own work for evaluating few-shot medical image segmentation in the future. Please consider citing our paper (as well as the original sources of the data) if you find this pipeline useful. Thanks!
`

### 3. Pseudolabel generation

run `./data/pseudolabel_gen.ipynb`. You might need to specify which dataset to use within the notebook.

### 4. Running training and evaluation

run `./examples/train_ssl_abdominal_ct.sh` / `./examples/train_ssl_abdominal_mri.sh` and `./examples/test_ssl_abdominal_ct.sh` / `./examples/test_ssl_abdominal_mri.sh`

The results should be easy to reproduce by following the steps described above. Still, if pretrained models are needed, you can find them [here](https://drive.google.com/file/d/1n_y5IDzMAQU8MIAYXLABSZ6003eJdbAe/view?usp=sharing): 1-shot abdominal CT models under setting 2 (objects of interest excluded from training even though unlabeled), without the boundary prior loss.

### Acknowledgement

This code is based on vanilla [PANet](https://github.com/kaixin96/PANet) (ICCV'19) by [Kaixin Wang](https://github.com/kaixin96) et al. The data augmentation tools are from Dr. [Jo Schlemper](https://github.com/js3611). Should you have any further questions, please let us know. Thanks again for your interest.

--------------------------------------------------------------------------------
/config_ssl_upload.py:
--------------------------------------------------------------------------------
"""
Experiment configuration file
Extended from the config file of the original PANet repository
"""
import os
import re
import glob
import itertools

import sacred
from sacred import Experiment
from sacred.observers import FileStorageObserver
from sacred.utils import apply_backspaces_and_linefeeds

from platform import node
from datetime import datetime

sacred.SETTINGS['CONFIG']['READ_ONLY_CONFIG'] = False
sacred.SETTINGS.CAPTURE_MODE = 'no'

ex = Experiment('mySSL')
ex.captured_out_filter = apply_backspaces_and_linefeeds

source_folders = ['.', './dataloaders', './models', './util']
sources_to_save = list(itertools.chain.from_iterable(
    [glob.glob(f'{folder}/*.py') for folder in source_folders]))
for source_file in sources_to_save:
    ex.add_source_file(source_file)

@ex.config
def cfg():
    """Default configurations"""
    seed = 1234
    gpu_id = 0
    mode = 'train' # only 'train' is supported for now
    num_workers = 4 # 0 for debugging.

    dataset = 'CHAOST2_Superpix' # i.e. abdominal MRI
    use_coco_init = True # initialize backbone with MS_COCO initialization; COCO contains no medical images, but the pretrained weights are still a useful starting point

    ### Training
    n_steps = 100100
    batch_size = 1
    lr_milestones = [ (ii + 1) * 1000 for ii in range(n_steps // 1000 - 1)]
    lr_step_gamma = 0.95
    ignore_label = 255
    print_interval = 100
    save_snapshot_every = 25000
    max_iters_per_load = 1000 # epoch size, interval for reloading the dataset
    scan_per_load = -1 # number of 3d scans per load, for saving memory. If -1, load the entire dataset into memory
    which_aug = 'sabs_aug' # standard data augmentation with intensity and geometric transforms
    input_size = (256, 256)
    min_fg_data='100' # when training with manual annotations, the minimum number of foreground pixels of a single class in a single slice. This empirically stabilizes the training process
    label_sets = 0 # which group of labels to take for training (the rest are for testing)
    exclude_cls_list = [2, 3] # testing classes to be excluded in training. Set to [] if testing under setting 1
    usealign = True # see vanilla PANet
    use_wce = True

    ### Validation
    z_margin = 0
    eval_fold = 0 # which fold for 5-fold cross validation
    support_idx=[-1] # indicating which scan is used as support in testing.
    val_wsize=2 # L_H, L_W in testing
    n_sup_part = 3 # number of chunks in testing

    # Network
    modelname = 'dlfcn_res101' # resnet 101 backbone from torchvision fcn-deeplab
    clsname = None #
    reload_model_path = None # path for reloading a trained model (overrides ms-coco initialization)
    proto_grid_size = 8 # L_H, L_W = (32, 32) / 8 = (4, 4) in training
    feature_hw = [32, 32] # feature map size; should be coupled with the backbone in future

    # SSL
    superpix_scale = 'MIDDLE' # MIDDLE / LARGE

    model = {
        'align': usealign,
        'use_coco_init': use_coco_init,
        'which_model': modelname,
        'cls_name': clsname,
        'proto_grid_size' : proto_grid_size,
        'feature_hw': feature_hw,
        'reload_model_path': reload_model_path
    }

    task = {
        'n_ways': 1,
        'n_shots': 1,
        'n_queries': 1,
        'npart': n_sup_part
    }

    optim_type = 'sgd'
    optim = {
        'lr': 1e-3,
        'momentum': 0.9,
        'weight_decay': 0.0005,
    }

    exp_prefix = ''

    exp_str = '_'.join(
        [exp_prefix]
        + [dataset,]
        + [f'sets_{label_sets}_{task["n_shots"]}shot'])

    path = {
        'log_dir': './runs',
        'SABS':{'data_dir': "./data/SABS/sabs_CT_normalized"},
        'C0':{'data_dir': "feed your dataset path here"},
        'CHAOST2':{'data_dir': "./data/CHAOST2/chaos_MR_T2_normalized/"},
        'SABS_Superpix':{'data_dir': "./data/SABS/sabs_CT_normalized"},
        'C0_Superpix':{'data_dir': "feed your dataset path here"},
        'CHAOST2_Superpix':{'data_dir': "./data/CHAOST2/chaos_MR_T2_normalized/"},
    }


@ex.config_hook
def add_observer(config, command_name, logger):
    """A hook function to add a file-storage observer"""
    exp_name = f'{ex.path}_{config["exp_str"]}'
    observer = FileStorageObserver.create(os.path.join(config['path']['log_dir'], exp_name))
    ex.observers.append(observer)
    return config

--------------------------------------------------------------------------------
/dataloaders/augutils.py:
--------------------------------------------------------------------------------
'''
Utilities for augmentation. Partly credited to Dr.
Jo Schlemper 3 | ''' 4 | from os.path import join 5 | 6 | import torch 7 | import numpy as np 8 | import torchvision.transforms as deftfx 9 | import dataloaders.image_transforms as myit 10 | import copy 11 | 12 | sabs_aug = { 13 | # turn flipping off as medical data has fixed orientations 14 | 'flip' : { 'v':False, 'h':False, 't': False, 'p':0.25 }, 15 | 'affine' : { 16 | 'rotate':5, 17 | 'shift':(5,5), 18 | 'shear':5, 19 | 'scale':(0.9, 1.2), 20 | }, 21 | 'elastic' : {'alpha':10,'sigma':5}, 22 | 'patch': 256, 23 | 'reduce_2d': True, 24 | 'gamma_range': (0.5, 1.5) 25 | } 26 | 27 | sabs_augv3 = { 28 | 'flip' : { 'v':False, 'h':False, 't': False, 'p':0.25 }, 29 | 'affine' : { 30 | 'rotate':30, 31 | 'shift':(30,30), 32 | 'shear':30, 33 | 'scale':(0.8, 1.3), 34 | }, 35 | 'elastic' : {'alpha':20,'sigma':5}, 36 | 'patch': 256, 37 | 'reduce_2d': True, 38 | 'gamma_range': (0.2, 1.8) 39 | } 40 | 41 | augs = { 42 | 'sabs_aug': sabs_aug, 43 | 'aug_v3': sabs_augv3, # more aggresive 44 | } 45 | 46 | 47 | def get_geometric_transformer(aug, order=3): 48 | """order: interpolation degree. Select order=0 for augmenting segmentation """ 49 | affine = aug['aug'].get('affine', 0) 50 | alpha = aug['aug'].get('elastic',{'alpha': 0})['alpha'] 51 | sigma = aug['aug'].get('elastic',{'sigma': 0})['sigma'] 52 | flip = aug['aug'].get('flip', {'v': True, 'h': True, 't': True, 'p':0.125}) 53 | 54 | tfx = [] 55 | if 'flip' in aug['aug']: 56 | tfx.append(myit.RandomFlip3D(**flip)) 57 | 58 | if 'affine' in aug['aug']: 59 | tfx.append(myit.RandomAffine(affine.get('rotate'), 60 | affine.get('shift'), 61 | affine.get('shear'), 62 | affine.get('scale'), 63 | affine.get('scale_iso',True), 64 | order=order)) 65 | 66 | if 'elastic' in aug['aug']: 67 | tfx.append(myit.ElasticTransform(alpha, sigma)) 68 | input_transform = deftfx.Compose(tfx) 69 | return input_transform 70 | 71 | def get_intensity_transformer(aug): 72 | """some basic intensity transforms""" 73 | 74 | def gamma_tansform(img): 75 | gamma_range = aug['aug']['gamma_range'] 76 | if isinstance(gamma_range, tuple): 77 | gamma = np.random.rand() * (gamma_range[1] - gamma_range[0]) + gamma_range[0] 78 | cmin = img.min() 79 | irange = (img.max() - cmin + 1e-5) 80 | 81 | img = img - cmin + 1e-5 82 | img = irange * np.power(img * 1.0 / irange, gamma) 83 | img = img + cmin 84 | 85 | elif gamma_range == False: 86 | pass 87 | else: 88 | raise ValueError("Cannot identify gamma transform range {}".format(gamma_range)) 89 | return img 90 | 91 | return gamma_tansform 92 | 93 | def transform_with_label(aug): 94 | """ 95 | Doing image geometric transform 96 | Proposed image to have the following configurations 97 | [H x W x C + CL] 98 | Where CL is the number of channels for the label. It is NOT in one-hot form 99 | """ 100 | 101 | geometric_tfx = get_geometric_transformer(aug) 102 | intensity_tfx = get_intensity_transformer(aug) 103 | 104 | def transform(comp, c_label, c_img, use_onehot, nclass, **kwargs): 105 | """ 106 | Args 107 | comp: a numpy array with shape [H x W x C + c_label] 108 | c_label: number of channels for a compact label. Note that the current version only supports 1 slice (H x W x 1) 109 | nc_onehot: -1 for not using one-hot representation of mask. 
otherwise, specify number of classes in the label 110 | 111 | """ 112 | comp = copy.deepcopy(comp) 113 | if (use_onehot is True) and (c_label != 1): 114 | raise NotImplementedError("Only allow compact label, also the label can only be 2d") 115 | assert c_img + 1 == comp.shape[-1], "only allow single slice 2D label" 116 | 117 | # geometric transform 118 | _label = comp[..., c_img ] 119 | _h_label = np.float32(np.arange( nclass ) == (_label[..., None]) ) 120 | comp = np.concatenate( [comp[..., :c_img ], _h_label], -1 ) 121 | comp = geometric_tfx(comp) 122 | # round one_hot labels to 0 or 1 123 | t_label_h = comp[..., c_img : ] 124 | t_label_h = np.rint(t_label_h) 125 | assert t_label_h.max() <= 1 126 | t_img = comp[..., 0 : c_img ] 127 | 128 | # intensity transform 129 | t_img = intensity_tfx(t_img) 130 | 131 | if use_onehot is True: 132 | t_label = t_label_h 133 | else: 134 | t_label = np.expand_dims(np.argmax(t_label_h, axis = -1), -1) 135 | return t_img, t_label 136 | 137 | return transform 138 | 139 | -------------------------------------------------------------------------------- /data/CHAOST2/png_gth_to_nii.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Converting labels from png to nii file\n", 8 | "\n", 9 | "\n", 10 | "### Overview\n", 11 | "\n", 12 | "This is the first step for data preparation\n", 13 | "\n", 14 | "Input: ground truth labels in `.png` format\n", 15 | "\n", 16 | "Output: labels in `.nii` format, indexed by patient id" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 13, 22 | "metadata": {}, 23 | "outputs": [ 24 | { 25 | "name": "stdout", 26 | "output_type": "stream", 27 | "text": [ 28 | "Once deleted, variables cannot be recovered. Proceed (y/[n])? y\n", 29 | "The autoreload extension is already loaded. To reload it, use:\n", 30 | " %reload_ext autoreload\n" 31 | ] 32 | } 33 | ], 34 | "source": [ 35 | "%reset\n", 36 | "%load_ext autoreload\n", 37 | "%autoreload 2\n", 38 | "import os\n", 39 | "import glob\n", 40 | "\n", 41 | "import numpy as np\n", 42 | "import PIL\n", 43 | "import matplotlib.pyplot as plt\n", 44 | "import SimpleITK as sitk\n", 45 | "import sys\n", 46 | "sys.path.insert(0, '../../dataloaders/')\n", 47 | "import niftiio as nio" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": null, 53 | "metadata": {}, 54 | "outputs": [], 55 | "source": [] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 14, 60 | "metadata": {}, 61 | "outputs": [], 62 | "source": [ 63 | "example = \"./MR/1/T2SPIR/Ground/IMG-0002-00001.png\" # example of ground-truth file name. 
" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": 15, 69 | "metadata": {}, 70 | "outputs": [], 71 | "source": [ 72 | "### search for scan ids\n", 73 | "ids = os.listdir(\"./MR/\")\n", 74 | "OUT_DIR = './niis/T2SPIR/'" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 16, 80 | "metadata": {}, 81 | "outputs": [ 82 | { 83 | "data": { 84 | "text/plain": [ 85 | "['37',\n", 86 | " '3',\n", 87 | " '15',\n", 88 | " '34',\n", 89 | " '33',\n", 90 | " '39',\n", 91 | " '20',\n", 92 | " '10',\n", 93 | " '22',\n", 94 | " '8',\n", 95 | " '31',\n", 96 | " '2',\n", 97 | " '36',\n", 98 | " '5',\n", 99 | " '13',\n", 100 | " '19',\n", 101 | " '21',\n", 102 | " '1',\n", 103 | " '38',\n", 104 | " '32']" 105 | ] 106 | }, 107 | "execution_count": 16, 108 | "metadata": {}, 109 | "output_type": "execute_result" 110 | } 111 | ], 112 | "source": [ 113 | "ids" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": 17, 119 | "metadata": {}, 120 | "outputs": [ 121 | { 122 | "name": "stdout", 123 | "output_type": "stream", 124 | "text": [ 125 | "image with id 37 has been saved!\n", 126 | "image with id 3 has been saved!\n", 127 | "image with id 15 has been saved!\n", 128 | "image with id 34 has been saved!\n", 129 | "image with id 33 has been saved!\n", 130 | "image with id 39 has been saved!\n", 131 | "image with id 20 has been saved!\n", 132 | "image with id 10 has been saved!\n", 133 | "image with id 22 has been saved!\n", 134 | "image with id 8 has been saved!\n", 135 | "image with id 31 has been saved!\n", 136 | "image with id 2 has been saved!\n", 137 | "image with id 36 has been saved!\n", 138 | "image with id 5 has been saved!\n", 139 | "image with id 13 has been saved!\n", 140 | "image with id 19 has been saved!\n", 141 | "image with id 21 has been saved!\n", 142 | "image with id 1 has been saved!\n", 143 | "image with id 38 has been saved!\n", 144 | "image with id 32 has been saved!\n" 145 | ] 146 | } 147 | ], 148 | "source": [ 149 | "#### Write them to nii files for the ease of loading in future\n", 150 | "for curr_id in ids:\n", 151 | " pngs = glob.glob(f'./MR/{curr_id}/T2SPIR/Ground/*.png')\n", 152 | " pngs = sorted(pngs, key = lambda x: int(os.path.basename(x).split(\"-\")[-1].split(\".png\")[0]))\n", 153 | " buffer = []\n", 154 | "\n", 155 | " for fid in pngs:\n", 156 | " buffer.append(PIL.Image.open(fid))\n", 157 | "\n", 158 | " vol = np.stack(buffer, axis = 0)\n", 159 | " # flip correction\n", 160 | " vol = np.flip(vol, axis = 1).copy()\n", 161 | " # remap values\n", 162 | " for new_val, old_val in enumerate(sorted(np.unique(vol))):\n", 163 | " vol[vol == old_val] = new_val\n", 164 | "\n", 165 | " # get reference \n", 166 | " ref_img = f'./niis/T2SPIR/image_{curr_id}.nii.gz'\n", 167 | " img_o = sitk.ReadImage(ref_img)\n", 168 | " vol_o = nio.np2itk(img=vol, ref_obj=img_o)\n", 169 | " sitk.WriteImage(vol_o, f'{OUT_DIR}/label_{curr_id}.nii.gz')\n", 170 | " print(f'image with id {curr_id} has been saved!')\n", 171 | "\n", 172 | " " 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": null, 178 | "metadata": {}, 179 | "outputs": [], 180 | "source": [] 181 | } 182 | ], 183 | "metadata": { 184 | "kernelspec": { 185 | "display_name": "Python 3", 186 | "language": "python", 187 | "name": "python3" 188 | }, 189 | "language_info": { 190 | "codemirror_mode": { 191 | "name": "ipython", 192 | "version": 3 193 | }, 194 | "file_extension": ".py", 195 | "mimetype": "text/x-python", 196 | "name": "python", 197 | 
"nbconvert_exporter": "python", 198 | "pygments_lexer": "ipython3", 199 | "version": "3.6.0" 200 | } 201 | }, 202 | "nbformat": 4, 203 | "nbformat_minor": 2 204 | } 205 | -------------------------------------------------------------------------------- /models/alpmodule.py: -------------------------------------------------------------------------------- 1 | """ 2 | ALPModule 3 | """ 4 | import torch 5 | import math 6 | from torch import nn 7 | from torch.nn import functional as F 8 | import numpy as np 9 | from pdb import set_trace 10 | import matplotlib.pyplot as plt 11 | # for unit test from spatial_similarity_module import NONLocalBlock2D, LayerNorm 12 | 13 | class MultiProtoAsConv(nn.Module): 14 | def __init__(self, proto_grid, feature_hw, upsample_mode = 'bilinear'): 15 | """ 16 | ALPModule 17 | Args: 18 | proto_grid: Grid size when doing multi-prototyping. For a 32-by-32 feature map, a size of 16-by-16 leads to a pooling window of 2-by-2 19 | feature_hw: Spatial size of input feature map 20 | 21 | """ 22 | super(MultiProtoAsConv, self).__init__() 23 | self.proto_grid = proto_grid 24 | self.upsample_mode = upsample_mode 25 | kernel_size = [ ft_l // grid_l for ft_l, grid_l in zip(feature_hw, proto_grid) ] 26 | self.avg_pool_op = nn.AvgPool2d( kernel_size ) 27 | 28 | def forward(self, qry, sup_x, sup_y, mode, thresh, isval = False, val_wsize = None, vis_sim = False, **kwargs): 29 | """ 30 | Now supports 31 | Args: 32 | mode: 'mask'/ 'grid'. if mask, works as original prototyping 33 | qry: [way(1), nc, h, w] 34 | sup_x: [nb, nc, h, w] 35 | sup_y: [nb, 1, h, w] 36 | vis_sim: visualize raw similarities or not 37 | New 38 | mode: 'mask'/ 'grid'. if mask, works as original prototyping 39 | qry: [way(1), nb(1), nc, h, w] 40 | sup_x: [way(1), shot, nb(1), nc, h, w] 41 | sup_y: [way(1), shot, nb(1), h, w] 42 | vis_sim: visualize raw similarities or not 43 | """ 44 | 45 | qry = qry.squeeze(1) # [way(1), nb(1), nc, hw] -> [way(1), nc, h, w] 46 | sup_x = sup_x.squeeze(0).squeeze(1) # [nshot, nc, h, w] 47 | sup_y = sup_y.squeeze(0) # [nshot, 1, h, w] 48 | 49 | def safe_norm(x, p = 2, dim = 1, eps = 1e-4): 50 | x_norm = torch.norm(x, p = p, dim = dim) # .detach() 51 | x_norm = torch.max(x_norm, torch.ones_like(x_norm).cuda() * eps) 52 | x = x.div(x_norm.unsqueeze(1).expand_as(x)) 53 | return x 54 | 55 | if mode == 'mask': # class-level prototype only 56 | proto = torch.sum(sup_x * sup_y, dim=(-1, -2)) \ 57 | / (sup_y.sum(dim=(-1, -2)) + 1e-5) # nb x C 58 | 59 | proto = proto.mean(dim = 0, keepdim = True) # 1 X C, take the mean of everything 60 | pred_mask = F.cosine_similarity(qry, proto[..., None, None], dim=1, eps = 1e-4) * 20.0 # [1, h, w] 61 | 62 | vis_dict = {'proto_assign': None} # things to visualize 63 | if vis_sim: 64 | vis_dict['raw_local_sims'] = pred_mask 65 | return pred_mask.unsqueeze(1), [pred_mask], vis_dict # just a placeholder. 
pred_mask returned as [1, way(1), h, w] 66 | 67 | # no need to merge with gridconv+ 68 | elif mode == 'gridconv': # using local prototypes only 69 | 70 | input_size = qry.shape 71 | nch = input_size[1] 72 | 73 | sup_nshot = sup_x.shape[0] 74 | 75 | n_sup_x = F.avg_pool2d(sup_x, val_wsize) if isval else self.avg_pool_op( sup_x ) 76 | 77 | n_sup_x = n_sup_x.view(sup_nshot, nch, -1).permute(0,2,1).unsqueeze(0) # way(1),nb, hw, nc 78 | n_sup_x = n_sup_x.reshape(1, -1, nch).unsqueeze(0) 79 | 80 | sup_y_g = F.avg_pool2d(sup_y, val_wsize) if isval else self.avg_pool_op(sup_y) 81 | sup_y_g = sup_y_g.view( sup_nshot, 1, -1 ).permute(1, 0, 2).view(1, -1).unsqueeze(0) 82 | 83 | protos = n_sup_x[sup_y_g > thresh, :] # npro, nc 84 | pro_n = safe_norm(protos) 85 | qry_n = safe_norm(qry) 86 | 87 | dists = F.conv2d(qry_n, pro_n[..., None, None]) * 20 88 | 89 | pred_grid = torch.sum(F.softmax(dists, dim = 1) * dists, dim = 1, keepdim = True) 90 | debug_assign = dists.argmax(dim = 1).float().detach() 91 | 92 | vis_dict = {'proto_assign': debug_assign} # things to visualize 93 | 94 | if vis_sim: # return the similarity for visualization 95 | vis_dict['raw_local_sims'] = dists.clone().detach() 96 | 97 | return pred_grid, [debug_assign], vis_dict 98 | 99 | 100 | elif mode == 'gridconv+': # local and global prototypes 101 | 102 | input_size = qry.shape 103 | nch = input_size[1] 104 | nb_q = input_size[0] 105 | 106 | sup_size = sup_x.shape[0] 107 | 108 | n_sup_x = F.avg_pool2d(sup_x, val_wsize) if isval else self.avg_pool_op( sup_x ) 109 | 110 | sup_nshot = sup_x.shape[0] 111 | 112 | n_sup_x = n_sup_x.view(sup_nshot, nch, -1).permute(0,2,1).unsqueeze(0) 113 | n_sup_x = n_sup_x.reshape(1, -1, nch).unsqueeze(0) 114 | 115 | sup_y_g = F.avg_pool2d(sup_y, val_wsize) if isval else self.avg_pool_op(sup_y) 116 | 117 | sup_y_g = sup_y_g.view( sup_nshot, 1, -1 ).permute(1, 0, 2).view(1, -1).unsqueeze(0) 118 | 119 | protos = n_sup_x[sup_y_g > thresh, :] 120 | 121 | 122 | glb_proto = torch.sum(sup_x * sup_y, dim=(-1, -2)) \ 123 | / (sup_y.sum(dim=(-1, -2)) + 1e-5) 124 | 125 | pro_n = safe_norm( torch.cat( [protos, glb_proto], dim = 0 ) ) 126 | 127 | qry_n = safe_norm(qry) 128 | 129 | dists = F.conv2d(qry_n, pro_n[..., None, None]) * 20 130 | 131 | pred_grid = torch.sum(F.softmax(dists, dim = 1) * dists, dim = 1, keepdim = True) 132 | raw_local_sims = dists.detach() 133 | 134 | 135 | debug_assign = dists.argmax(dim = 1).float() 136 | 137 | vis_dict = {'proto_assign': debug_assign} 138 | if vis_sim: 139 | vis_dict['raw_local_sims'] = dists.clone().detach() 140 | 141 | return pred_grid, [debug_assign], vis_dict 142 | 143 | else: 144 | raise NotImplementedError 145 | 146 | -------------------------------------------------------------------------------- /data/SABS/Synapse_abdominal_classmap.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "Get class-pid-index map for synapse abdominal CT dataset " 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 18, 13 | "metadata": {}, 14 | "outputs": [ 15 | { 16 | "name": "stdout", 17 | "output_type": "stream", 18 | "text": [ 19 | "Once deleted, variables cannot be recovered. Proceed (y/[n])? y\n", 20 | "The autoreload extension is already loaded. 
To reload it, use:\n", 21 | " %reload_ext autoreload\n" 22 | ] 23 | } 24 | ], 25 | "source": [ 26 | "%reset\n", 27 | "%load_ext autoreload\n", 28 | "%autoreload 2\n", 29 | "import numpy as np\n", 30 | "import os\n", 31 | "import glob\n", 32 | "import SimpleITK as sitk\n", 33 | "import matplotlib.pyplot as plt\n", 34 | "import sys\n", 35 | "sys.path.insert(0, '../../dataloaders/')\n", 36 | "import niftiio as nio" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 19, 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "# lets save them in a same way as mmwhs for the ease of modifying dataloader\n", 46 | "\n", 47 | "# normalization: cut top 2% of histogram, then doing volume-wise normalization\n", 48 | "\n", 49 | "IMG_BNAME=\"./sabs_CT_normalized/image_*.nii.gz\"\n", 50 | "SEG_BNAME=\"./sabs_CT_normalized/label_*.nii.gz\"\n" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": 20, 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "imgs = glob.glob(IMG_BNAME)\n", 60 | "segs = glob.glob(SEG_BNAME)\n", 61 | "imgs = [ fid for fid in sorted(imgs, key = lambda x: int(x.split(\"_\")[-1].split(\".nii.gz\")[0]) ) ]\n", 62 | "segs = [ fid for fid in sorted(segs, key = lambda x: int(x.split(\"_\")[-1].split(\".nii.gz\")[0]) ) ]\n" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": 16, 68 | "metadata": {}, 69 | "outputs": [ 70 | { 71 | "data": { 72 | "text/plain": [ 73 | "['./sabs_CT_normalized/image_0.nii.gz',\n", 74 | " './sabs_CT_normalized/image_1.nii.gz',\n", 75 | " './sabs_CT_normalized/image_2.nii.gz',\n", 76 | " './sabs_CT_normalized/image_3.nii.gz',\n", 77 | " './sabs_CT_normalized/image_4.nii.gz',\n", 78 | " './sabs_CT_normalized/image_5.nii.gz',\n", 79 | " './sabs_CT_normalized/image_6.nii.gz',\n", 80 | " './sabs_CT_normalized/image_7.nii.gz',\n", 81 | " './sabs_CT_normalized/image_8.nii.gz',\n", 82 | " './sabs_CT_normalized/image_9.nii.gz',\n", 83 | " './sabs_CT_normalized/image_10.nii.gz',\n", 84 | " './sabs_CT_normalized/image_11.nii.gz',\n", 85 | " './sabs_CT_normalized/image_12.nii.gz',\n", 86 | " './sabs_CT_normalized/image_13.nii.gz',\n", 87 | " './sabs_CT_normalized/image_14.nii.gz',\n", 88 | " './sabs_CT_normalized/image_15.nii.gz',\n", 89 | " './sabs_CT_normalized/image_16.nii.gz',\n", 90 | " './sabs_CT_normalized/image_17.nii.gz',\n", 91 | " './sabs_CT_normalized/image_18.nii.gz',\n", 92 | " './sabs_CT_normalized/image_19.nii.gz',\n", 93 | " './sabs_CT_normalized/image_20.nii.gz',\n", 94 | " './sabs_CT_normalized/image_21.nii.gz',\n", 95 | " './sabs_CT_normalized/image_22.nii.gz',\n", 96 | " './sabs_CT_normalized/image_23.nii.gz',\n", 97 | " './sabs_CT_normalized/image_24.nii.gz',\n", 98 | " './sabs_CT_normalized/image_25.nii.gz',\n", 99 | " './sabs_CT_normalized/image_26.nii.gz',\n", 100 | " './sabs_CT_normalized/image_27.nii.gz',\n", 101 | " './sabs_CT_normalized/image_28.nii.gz',\n", 102 | " './sabs_CT_normalized/image_29.nii.gz']" 103 | ] 104 | }, 105 | "execution_count": 16, 106 | "metadata": {}, 107 | "output_type": "execute_result" 108 | } 109 | ], 110 | "source": [ 111 | "imgs" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": 23, 117 | "metadata": {}, 118 | "outputs": [ 119 | { 120 | "data": { 121 | "text/plain": [ 122 | "(85, 257, 257)" 123 | ] 124 | }, 125 | "execution_count": 23, 126 | "metadata": {}, 127 | "output_type": "execute_result" 128 | } 129 | ], 130 | "source": [ 131 | "lb = nio.read_nii_bysitk(segs[0])\n", 132 | "lb.shape\n", 133 | "# please check 
the organizations of dimensions. We will iterate through the z dimension.\n", 134 | "# it should keep consistent with those for CHAOS dataset" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": 25, 140 | "metadata": {}, 141 | "outputs": [ 142 | { 143 | "name": "stdout", 144 | "output_type": "stream", 145 | "text": [ 146 | "pid 0 finished!\n", 147 | "pid 1 finished!\n", 148 | "pid 2 finished!\n", 149 | "pid 3 finished!\n", 150 | "pid 4 finished!\n", 151 | "pid 5 finished!\n", 152 | "pid 6 finished!\n", 153 | "pid 7 finished!\n", 154 | "pid 8 finished!\n", 155 | "pid 9 finished!\n", 156 | "pid 10 finished!\n", 157 | "pid 11 finished!\n", 158 | "pid 12 finished!\n", 159 | "pid 13 finished!\n", 160 | "pid 14 finished!\n", 161 | "pid 15 finished!\n", 162 | "pid 16 finished!\n", 163 | "pid 17 finished!\n", 164 | "pid 18 finished!\n", 165 | "pid 19 finished!\n", 166 | "pid 20 finished!\n", 167 | "pid 21 finished!\n", 168 | "pid 22 finished!\n", 169 | "pid 23 finished!\n", 170 | "pid 24 finished!\n", 171 | "pid 25 finished!\n", 172 | "pid 26 finished!\n", 173 | "pid 27 finished!\n", 174 | "pid 28 finished!\n", 175 | "pid 29 finished!\n" 176 | ] 177 | } 178 | ], 179 | "source": [ 180 | "import json\n", 181 | "classmap = {}\n", 182 | "LABEL_NAME = [\"BGD\", \"SPLEEN\", \"KID_R\", \"KID_l\", \"GALLBLADDER\", \"ESOPHAGUS\", \"LIVER\", \"STOMACH\", \"AORTA\", \"IVC\", \"PS_VEIN\", \"PANCREAS\", \"AG_R\", \"AG_L\"] \n", 183 | "\n", 184 | "MIN_TP=1 # minimum number of true positive pixels in a slice\n", 185 | "\n", 186 | "fid = f'./sabs_CT_normalized/classmap_{MIN_TP}.json'\n", 187 | "for _lb in LABEL_NAME:\n", 188 | " classmap[_lb] = {}\n", 189 | " for pid in range(len(segs)):\n", 190 | " classmap[_lb][str(pid)] = []\n", 191 | "\n", 192 | "for pid, seg in enumerate(segs):\n", 193 | " lb_vol = nio.read_nii_bysitk(seg)\n", 194 | " n_slice = lb_vol.shape[0]\n", 195 | " for slc in range(n_slice):\n", 196 | " for cls in range(len(LABEL_NAME)):\n", 197 | " if cls in lb_vol[slc, ...]:\n", 198 | " if np.sum( lb_vol[slc, ...] == cls) >= MIN_TP:\n", 199 | " classmap[LABEL_NAME[cls]][str(pid)].append(slc)\n", 200 | " print(f'pid {str(pid)} finished!')\n", 201 | " \n", 202 | "with open(fid, 'w') as fopen:\n", 203 | " json.dump(classmap, fopen)\n", 204 | " fopen.close() \n", 205 | " " 206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "execution_count": 26, 211 | "metadata": {}, 212 | "outputs": [], 213 | "source": [] 214 | } 215 | ], 216 | "metadata": { 217 | "kernelspec": { 218 | "display_name": "Python 3", 219 | "language": "python", 220 | "name": "python3" 221 | }, 222 | "language_info": { 223 | "codemirror_mode": { 224 | "name": "ipython", 225 | "version": 3 226 | }, 227 | "file_extension": ".py", 228 | "mimetype": "text/x-python", 229 | "name": "python", 230 | "nbconvert_exporter": "python", 231 | "pygments_lexer": "ipython3", 232 | "version": "3.6.9" 233 | } 234 | }, 235 | "nbformat": 4, 236 | "nbformat_minor": 2 237 | } 238 | -------------------------------------------------------------------------------- /training.py: -------------------------------------------------------------------------------- 1 | """ 2 | Training the model 3 | Extended from original implementation of PANet by Wang et al. 
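Launched through sacred; see ./examples/*.sh for complete command lines, e.g.
    python3 training.py with 'modelname=dlfcn_res101' dataset=CHAOST2_Superpix eval_fold=0 ...
Any key defined in config_ssl_upload.py can be overridden with the same `with key=value` syntax.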
4 | """ 5 | import os 6 | import shutil 7 | import torch 8 | import torch.nn as nn 9 | import torch.optim 10 | from torch.utils.data import DataLoader 11 | from torch.optim.lr_scheduler import MultiStepLR 12 | import torch.backends.cudnn as cudnn 13 | import numpy as np 14 | 15 | from models.grid_proto_fewshot import FewShotSeg 16 | from dataloaders.dev_customized_med import med_fewshot 17 | from dataloaders.GenericSuperDatasetv2 import SuperpixelDataset 18 | from dataloaders.dataset_utils import DATASET_INFO 19 | import dataloaders.augutils as myaug 20 | 21 | from util.utils import set_seed, t2n, to01, compose_wt_simple 22 | from util.metric import Metric 23 | 24 | from config_ssl_upload import ex 25 | import tqdm 26 | 27 | # config pre-trained model caching path 28 | os.environ['TORCH_HOME'] = "./pretrained_model" 29 | 30 | @ex.automain 31 | def main(_run, _config, _log): 32 | if _run.observers: 33 | os.makedirs(f'{_run.observers[0].dir}/snapshots', exist_ok=True) 34 | for source_file, _ in _run.experiment_info['sources']: 35 | os.makedirs(os.path.dirname(f'{_run.observers[0].dir}/source/{source_file}'), 36 | exist_ok=True) 37 | _run.observers[0].save_file(source_file, f'source/{source_file}') 38 | shutil.rmtree(f'{_run.observers[0].basedir}/_sources') 39 | 40 | set_seed(_config['seed']) 41 | cudnn.enabled = True 42 | cudnn.benchmark = True 43 | torch.cuda.set_device(device=_config['gpu_id']) 44 | torch.set_num_threads(1) 45 | 46 | _log.info('###### Create model ######') 47 | model = FewShotSeg(pretrained_path=None, cfg=_config['model']) 48 | 49 | model = model.cuda() 50 | model.train() 51 | 52 | _log.info('###### Load data ######') 53 | ### Training set 54 | data_name = _config['dataset'] 55 | if data_name == 'SABS_Superpix': 56 | baseset_name = 'SABS' 57 | elif data_name == 'C0_Superpix': 58 | raise NotImplementedError 59 | baseset_name = 'C0' 60 | elif data_name == 'CHAOST2_Superpix': 61 | baseset_name = 'CHAOST2' 62 | else: 63 | raise ValueError(f'Dataset: {data_name} not found') 64 | 65 | ### Transforms for data augmentation 66 | tr_transforms = myaug.transform_with_label({'aug': myaug.augs[_config['which_aug']]}) 67 | assert _config['scan_per_load'] < 0 # by default we load the entire dataset directly 68 | 69 | test_labels = DATASET_INFO[baseset_name]['LABEL_GROUP']['pa_all'] - DATASET_INFO[baseset_name]['LABEL_GROUP'][_config["label_sets"]] 70 | _log.info(f'###### Labels excluded in training : {[lb for lb in _config["exclude_cls_list"]]} ######') 71 | _log.info(f'###### Unseen labels evaluated in testing: {[lb for lb in test_labels]} ######') 72 | 73 | tr_parent = SuperpixelDataset( # base dataset 74 | which_dataset = baseset_name, 75 | base_dir=_config['path'][data_name]['data_dir'], 76 | idx_split = _config['eval_fold'], 77 | mode='train', 78 | min_fg=str(_config["min_fg_data"]), # dummy entry for superpixel dataset 79 | transforms=tr_transforms, 80 | nsup = _config['task']['n_shots'], 81 | scan_per_load = _config['scan_per_load'], 82 | exclude_list = _config["exclude_cls_list"], 83 | superpix_scale = _config["superpix_scale"], 84 | fix_length = _config["max_iters_per_load"] if (data_name == 'C0_Superpix') or (data_name == 'CHAOST2_Superpix') else None 85 | ) 86 | 87 | ### dataloaders 88 | trainloader = DataLoader( 89 | tr_parent, 90 | batch_size=_config['batch_size'], 91 | shuffle=True, 92 | num_workers=_config['num_workers'], 93 | pin_memory=True, 94 | drop_last=True 95 | ) 96 | 97 | _log.info('###### Set optimizer ######') 98 | if _config['optim_type'] == 'sgd': 99 | 
optimizer = torch.optim.SGD(model.parameters(), **_config['optim']) 100 | else: 101 | raise NotImplementedError 102 | 103 | scheduler = MultiStepLR(optimizer, milestones=_config['lr_milestones'], gamma = _config['lr_step_gamma']) 104 | 105 | my_weight = compose_wt_simple(_config["use_wce"], data_name) 106 | criterion = nn.CrossEntropyLoss(ignore_index=_config['ignore_label'], weight = my_weight) 107 | 108 | i_iter = 0 # total number of iterations 109 | n_sub_epochs = _config['n_steps'] // _config['max_iters_per_load'] # number of times the dataset buffer is reloaded 110 | 111 | log_loss = {'loss': 0, 'align_loss': 0} 112 | 113 | _log.info('###### Training ######') 114 | for sub_epoch in range(n_sub_epochs): 115 | _log.info(f'###### This is epoch {sub_epoch} of {n_sub_epochs} epochs ######') 116 | for _, sample_batched in enumerate(trainloader): 117 | # Prepare input 118 | i_iter += 1 119 | # add writers 120 | support_images = [[shot.cuda() for shot in way] 121 | for way in sample_batched['support_images']] 122 | support_fg_mask = [[shot['fg_mask'].float().cuda() for shot in way] 123 | for way in sample_batched['support_mask']] 124 | support_bg_mask = [[shot['bg_mask'].float().cuda() for shot in way] 125 | for way in sample_batched['support_mask']] 126 | 127 | query_images = [query_image.cuda() 128 | for query_image in sample_batched['query_images']] 129 | query_labels = torch.cat( 130 | [query_label.long().cuda() for query_label in sample_batched['query_labels']], dim=0) 131 | 132 | optimizer.zero_grad() 133 | # FIXME: in the model definition, filter out the failure case where the pseudolabel falls outside of the image or is too small to compute a prototype 134 | try: 135 | query_pred, align_loss, debug_vis, assign_mats = model(support_images, support_fg_mask, support_bg_mask, query_images, isval = False, val_wsize = None) 136 | except Exception: # a degenerate pseudolabel can make prototype computation fail; skip such a batch 137 | print('Faulty batch detected, skip') 138 | continue 139 | 140 | query_loss = criterion(query_pred, query_labels) 141 | loss = query_loss + align_loss 142 | loss.backward() 143 | optimizer.step() 144 | scheduler.step() 145 | 146 | # Log loss 147 | query_loss = query_loss.detach().data.cpu().numpy() 148 | align_loss = align_loss.detach().data.cpu().numpy() if align_loss != 0 else 0 149 | 150 | _run.log_scalar('loss', query_loss) 151 | _run.log_scalar('align_loss', align_loss) 152 | log_loss['loss'] += query_loss 153 | log_loss['align_loss'] += align_loss 154 | 155 | # print loss and take snapshots 156 | if (i_iter + 1) % _config['print_interval'] == 0: 157 | 158 | loss = log_loss['loss'] / _config['print_interval'] 159 | align_loss = log_loss['align_loss'] / _config['print_interval'] 160 | 161 | log_loss['loss'] = 0 162 | log_loss['align_loss'] = 0 163 | 164 | print(f'step {i_iter+1}: loss: {loss}, align_loss: {align_loss},') 165 | 166 | if (i_iter + 1) % _config['save_snapshot_every'] == 0: 167 | _log.info('###### Taking snapshot ######') 168 | torch.save(model.state_dict(), 169 | os.path.join(f'{_run.observers[0].dir}/snapshots', f'{i_iter + 1}.pth')) 170 | 171 | if data_name == 'C0_Superpix' or data_name == 'CHAOST2_Superpix': 172 | if (i_iter + 1) % _config['max_iters_per_load'] == 0: 173 | _log.info('###### Reloading dataset ######') 174 | trainloader.dataset.reload_buffer() 175 | print(f'###### New dataset with {len(trainloader.dataset)} slices has been loaded ######') 176 | 177 | if (i_iter - 2) > _config['n_steps']: 178 | return 1 # finish up 179 | 180 | --------------------------------------------------------------------------------
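[Editor's note] A minimal sketch, not part of the original scripts, of the episode structure the training loop above feeds to the model: nested way x shot lists of tensors with batch size fixed to 1, following the comments above and the FewShotSeg.forward docstring in models/grid_proto_fewshot.py. The config values below are illustrative assumptions (the real ones come from config_ssl_upload.py; the alignment branch is switched off for brevity), and a GPU is assumed, as in the scripts:

import torch
from models.grid_proto_fewshot import FewShotSeg

# hypothetical minimal config; the actual keys and values live in config_ssl_upload.py
cfg = {'align': False, 'which_model': 'dlfcn_res101', 'use_coco_init': True,
       'cls_name': 'grid_proto', 'proto_grid_size': 8, 'feature_hw': [32, 32]}
model = FewShotSeg(pretrained_path=None, cfg=cfg).cuda()

B, H, W = 1, 256, 256
support_images = [[torch.randn(B, 3, H, W).cuda()]]  # way(1) x shot(1) x [B x 3 x H x W]
fg = torch.zeros(B, H, W).cuda()
fg[:, 96:160, 96:160] = 1                            # a dummy foreground pseudolabel
support_fg_mask = [[fg]]                             # way x shot x [B x H x W]
support_bg_mask = [[1 - fg]]
query_images = [torch.randn(B, 3, H, W).cuda()]      # n_queries(1) x [B x 3 x H x W]

query_pred, align_loss, _, _ = model(support_images, support_fg_mask, support_bg_mask,
                                     query_images, isval=False, val_wsize=None)
print(query_pred.shape)  # [1, 2, H, W]: background scores plus one foreground way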
/data/SABS/intensity_normalization.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "### Overview\n", 8 | "\n", 9 | "This is the first step for data preparation.\n", 10 | "\n", 11 | "Window images, as well as reindex them \n", 12 | "\n", 13 | "Input: original CT images\n", 14 | "\n", 15 | "Output: Images with abdominal windowing\n" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": 35, 21 | "metadata": {}, 22 | "outputs": [ 23 | { 24 | "name": "stdout", 25 | "output_type": "stream", 26 | "text": [ 27 | "Once deleted, variables cannot be recovered. Proceed (y/[n])? y\n", 28 | "The autoreload extension is already loaded. To reload it, use:\n", 29 | " %reload_ext autoreload\n" 30 | ] 31 | } 32 | ], 33 | "source": [ 34 | "%reset\n", 35 | "%load_ext autoreload\n", 36 | "%autoreload 2\n", 37 | "import numpy as np\n", 38 | "import os\n", 39 | "import glob\n", 40 | "import SimpleITK as sitk\n", 41 | "\n", 42 | "import sys\n", 43 | "\n", 44 | "sys.path.insert(0, '../../dataloaders/')\n", 45 | "import niftiio as nio" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 40, 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "# set up directories for images\n", 55 | "IMG_FOLDER=\"./CT/img/\"\n", 56 | "SEG_FOLDER=\"./CT/label/\"\n", 57 | "OUT_FOLDER=\"./tmp_normalized/\"" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": 42, 63 | "metadata": {}, 64 | "outputs": [], 65 | "source": [ 66 | "imgs = glob.glob(IMG_FOLDER + \"/*.nii.gz\")\n", 67 | "imgs = [ fid for fid in sorted(imgs) ]\n", 68 | "segs = [ fid for fid in sorted(glob.glob(SEG_FOLDER + \"/*.nii.gz\")) ]\n", 69 | "\n", 70 | "pids = [ pid.split(\"img0\")[-1].split(\".\")[0] for pid in imgs]" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": 43, 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "# helper function\n", 80 | "def copy_spacing_ori(src, dst):\n", 81 | " dst.SetSpacing(src.GetSpacing())\n", 82 | " dst.SetOrigin(src.GetOrigin())\n", 83 | " dst.SetDirection(src.GetDirection())\n", 84 | " return dst" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": 44, 90 | "metadata": {}, 91 | "outputs": [ 92 | { 93 | "name": "stdout", 94 | "output_type": "stream", 95 | "text": [ 96 | "./tmp_normalized/image_0.nii.gz has been save\n", 97 | "./tmp_normalized/label_0.nii.gz has been save\n", 98 | "./tmp_normalized/image_1.nii.gz has been save\n", 99 | "./tmp_normalized/label_1.nii.gz has been save\n", 100 | "./tmp_normalized/image_2.nii.gz has been save\n", 101 | "./tmp_normalized/label_2.nii.gz has been save\n", 102 | "./tmp_normalized/image_3.nii.gz has been save\n", 103 | "./tmp_normalized/label_3.nii.gz has been save\n", 104 | "./tmp_normalized/image_4.nii.gz has been save\n", 105 | "./tmp_normalized/label_4.nii.gz has been save\n", 106 | "./tmp_normalized/image_5.nii.gz has been save\n", 107 | "./tmp_normalized/label_5.nii.gz has been save\n", 108 | "./tmp_normalized/image_6.nii.gz has been save\n", 109 | "./tmp_normalized/label_6.nii.gz has been save\n", 110 | "./tmp_normalized/image_7.nii.gz has been save\n", 111 | "./tmp_normalized/label_7.nii.gz has been save\n", 112 | "./tmp_normalized/image_8.nii.gz has been save\n", 113 | "./tmp_normalized/label_8.nii.gz has been save\n", 114 | "./tmp_normalized/image_9.nii.gz has been save\n", 115 | "./tmp_normalized/label_9.nii.gz has been save\n", 
116 | "./tmp_normalized/image_10.nii.gz has been save\n", 117 | "./tmp_normalized/label_10.nii.gz has been save\n", 118 | "./tmp_normalized/image_11.nii.gz has been save\n", 119 | "./tmp_normalized/label_11.nii.gz has been save\n", 120 | "./tmp_normalized/image_12.nii.gz has been save\n", 121 | "./tmp_normalized/label_12.nii.gz has been save\n", 122 | "./tmp_normalized/image_13.nii.gz has been save\n", 123 | "./tmp_normalized/label_13.nii.gz has been save\n", 124 | "./tmp_normalized/image_14.nii.gz has been save\n", 125 | "./tmp_normalized/label_14.nii.gz has been save\n", 126 | "./tmp_normalized/image_15.nii.gz has been save\n", 127 | "./tmp_normalized/label_15.nii.gz has been save\n", 128 | "./tmp_normalized/image_16.nii.gz has been save\n", 129 | "./tmp_normalized/label_16.nii.gz has been save\n", 130 | "./tmp_normalized/image_17.nii.gz has been save\n", 131 | "./tmp_normalized/label_17.nii.gz has been save\n", 132 | "./tmp_normalized/image_18.nii.gz has been save\n", 133 | "./tmp_normalized/label_18.nii.gz has been save\n", 134 | "./tmp_normalized/image_19.nii.gz has been save\n", 135 | "./tmp_normalized/label_19.nii.gz has been save\n", 136 | "./tmp_normalized/image_20.nii.gz has been save\n", 137 | "./tmp_normalized/label_20.nii.gz has been save\n", 138 | "./tmp_normalized/image_21.nii.gz has been save\n", 139 | "./tmp_normalized/label_21.nii.gz has been save\n", 140 | "./tmp_normalized/image_22.nii.gz has been save\n", 141 | "./tmp_normalized/label_22.nii.gz has been save\n", 142 | "./tmp_normalized/image_23.nii.gz has been save\n", 143 | "./tmp_normalized/label_23.nii.gz has been save\n", 144 | "./tmp_normalized/image_24.nii.gz has been save\n", 145 | "./tmp_normalized/label_24.nii.gz has been save\n", 146 | "./tmp_normalized/image_25.nii.gz has been save\n", 147 | "./tmp_normalized/label_25.nii.gz has been save\n", 148 | "./tmp_normalized/image_26.nii.gz has been save\n", 149 | "./tmp_normalized/label_26.nii.gz has been save\n", 150 | "./tmp_normalized/image_27.nii.gz has been save\n", 151 | "./tmp_normalized/label_27.nii.gz has been save\n", 152 | "./tmp_normalized/image_28.nii.gz has been save\n", 153 | "./tmp_normalized/label_28.nii.gz has been save\n", 154 | "./tmp_normalized/image_29.nii.gz has been save\n", 155 | "./tmp_normalized/label_29.nii.gz has been save\n" 156 | ] 157 | } 158 | ], 159 | "source": [ 160 | "import copy\n", 161 | "scan_dir = OUT_FOLDER\n", 162 | "LIR = -125\n", 163 | "HIR = 275\n", 164 | "os.makedirs(scan_dir, exist_ok = True)\n", 165 | "\n", 166 | "reindex = 0\n", 167 | "for img_fid, seg_fid, pid in zip(imgs, segs, pids):\n", 168 | "\n", 169 | " img_obj = sitk.ReadImage( img_fid )\n", 170 | " seg_obj = sitk.ReadImage( seg_fid )\n", 171 | "\n", 172 | " array = sitk.GetArrayFromImage(img_obj)\n", 173 | "\n", 174 | " array[array > HIR] = HIR\n", 175 | " array[array < LIR] = LIR\n", 176 | " \n", 177 | " array = (array - array.min()) / (array.max() - array.min()) * 255.0\n", 178 | " \n", 179 | " # then normalize this\n", 180 | " \n", 181 | " wined_img = sitk.GetImageFromArray(array)\n", 182 | " wined_img = copy_spacing_ori(img_obj, wined_img)\n", 183 | " \n", 184 | " out_img_fid = os.path.join( scan_dir, f'image_{str(reindex)}.nii.gz' )\n", 185 | " out_lb_fid = os.path.join( scan_dir, f'label_{str(reindex)}.nii.gz' ) \n", 186 | " \n", 187 | " # then save\n", 188 | " sitk.WriteImage(wined_img, out_img_fid, True) \n", 189 | " sitk.WriteImage(seg_obj, out_lb_fid, True) \n", 190 | " print(\"{} has been save\".format(out_img_fid))\n", 191 | " print(\"{} has been 
save\".format(out_lb_fid))\n", 192 | " reindex += 1\n" 193 | ] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "execution_count": null, 198 | "metadata": {}, 199 | "outputs": [], 200 | "source": [] 201 | } 202 | ], 203 | "metadata": { 204 | "kernelspec": { 205 | "display_name": "Python 3", 206 | "language": "python", 207 | "name": "python3" 208 | }, 209 | "language_info": { 210 | "codemirror_mode": { 211 | "name": "ipython", 212 | "version": 3 213 | }, 214 | "file_extension": ".py", 215 | "mimetype": "text/x-python", 216 | "name": "python", 217 | "nbconvert_exporter": "python", 218 | "pygments_lexer": "ipython3", 219 | "version": "3.6.0" 220 | } 221 | }, 222 | "nbformat": 4, 223 | "nbformat_minor": 2 224 | } 225 | -------------------------------------------------------------------------------- /data/CHAOST2/class_slice_index_gen.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Generate class-slice indexing table for experiments\n", 8 | "\n", 9 | "\n", 10 | "### Overview\n", 11 | "\n", 12 | "This is for experiment setting up for simulating few-shot image segmentation scenarios\n", 13 | "\n", 14 | "Input: pre-processed images and their ground-truth labels\n", 15 | "\n", 16 | "Output: a `json` file for class-slice indexing" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 2, 22 | "metadata": {}, 23 | "outputs": [ 24 | { 25 | "name": "stdout", 26 | "output_type": "stream", 27 | "text": [ 28 | "Once deleted, variables cannot be recovered. Proceed (y/[n])? y\n", 29 | "The autoreload extension is already loaded. To reload it, use:\n", 30 | " %reload_ext autoreload\n" 31 | ] 32 | } 33 | ], 34 | "source": [ 35 | "%reset\n", 36 | "%load_ext autoreload\n", 37 | "%autoreload 2\n", 38 | "import numpy as np\n", 39 | "import os\n", 40 | "import glob\n", 41 | "import SimpleITK as sitk\n", 42 | "import sys\n", 43 | "import json\n", 44 | "sys.path.insert(0, '../../dataloaders/')\n", 45 | "import niftiio as nio" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 4, 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "IMG_BNAME=\"./chaos_MR_T2_normalized/image_*.nii.gz\"\n", 55 | "SEG_BNAME=\"./chaos_MR_T2_normalized/label_*.nii.gz\"" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 5, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "imgs = glob.glob(IMG_BNAME)\n", 65 | "segs = glob.glob(SEG_BNAME)\n", 66 | "imgs = [ fid for fid in sorted(imgs, key = lambda x: int(x.split(\"_\")[-1].split(\".nii.gz\")[0]) ) ]\n", 67 | "segs = [ fid for fid in sorted(segs, key = lambda x: int(x.split(\"_\")[-1].split(\".nii.gz\")[0]) ) ]\n" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 6, 73 | "metadata": {}, 74 | "outputs": [ 75 | { 76 | "data": { 77 | "text/plain": [ 78 | "['./chaos_MR_T2_normalized/image_1.nii.gz',\n", 79 | " './chaos_MR_T2_normalized/image_2.nii.gz',\n", 80 | " './chaos_MR_T2_normalized/image_3.nii.gz',\n", 81 | " './chaos_MR_T2_normalized/image_5.nii.gz',\n", 82 | " './chaos_MR_T2_normalized/image_8.nii.gz',\n", 83 | " './chaos_MR_T2_normalized/image_10.nii.gz',\n", 84 | " './chaos_MR_T2_normalized/image_13.nii.gz',\n", 85 | " './chaos_MR_T2_normalized/image_15.nii.gz',\n", 86 | " './chaos_MR_T2_normalized/image_19.nii.gz',\n", 87 | " './chaos_MR_T2_normalized/image_20.nii.gz',\n", 88 | " './chaos_MR_T2_normalized/image_21.nii.gz',\n", 89 | " 
'./chaos_MR_T2_normalized/image_22.nii.gz',\n", 90 | " './chaos_MR_T2_normalized/image_31.nii.gz',\n", 91 | " './chaos_MR_T2_normalized/image_32.nii.gz',\n", 92 | " './chaos_MR_T2_normalized/image_33.nii.gz',\n", 93 | " './chaos_MR_T2_normalized/image_34.nii.gz',\n", 94 | " './chaos_MR_T2_normalized/image_36.nii.gz',\n", 95 | " './chaos_MR_T2_normalized/image_37.nii.gz',\n", 96 | " './chaos_MR_T2_normalized/image_38.nii.gz',\n", 97 | " './chaos_MR_T2_normalized/image_39.nii.gz']" 98 | ] 99 | }, 100 | "execution_count": 6, 101 | "metadata": {}, 102 | "output_type": "execute_result" 103 | } 104 | ], 105 | "source": [ 106 | "imgs" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": 7, 112 | "metadata": {}, 113 | "outputs": [ 114 | { 115 | "data": { 116 | "text/plain": [ 117 | "['./chaos_MR_T2_normalized/label_1.nii.gz',\n", 118 | " './chaos_MR_T2_normalized/label_2.nii.gz',\n", 119 | " './chaos_MR_T2_normalized/label_3.nii.gz',\n", 120 | " './chaos_MR_T2_normalized/label_5.nii.gz',\n", 121 | " './chaos_MR_T2_normalized/label_8.nii.gz',\n", 122 | " './chaos_MR_T2_normalized/label_10.nii.gz',\n", 123 | " './chaos_MR_T2_normalized/label_13.nii.gz',\n", 124 | " './chaos_MR_T2_normalized/label_15.nii.gz',\n", 125 | " './chaos_MR_T2_normalized/label_19.nii.gz',\n", 126 | " './chaos_MR_T2_normalized/label_20.nii.gz',\n", 127 | " './chaos_MR_T2_normalized/label_21.nii.gz',\n", 128 | " './chaos_MR_T2_normalized/label_22.nii.gz',\n", 129 | " './chaos_MR_T2_normalized/label_31.nii.gz',\n", 130 | " './chaos_MR_T2_normalized/label_32.nii.gz',\n", 131 | " './chaos_MR_T2_normalized/label_33.nii.gz',\n", 132 | " './chaos_MR_T2_normalized/label_34.nii.gz',\n", 133 | " './chaos_MR_T2_normalized/label_36.nii.gz',\n", 134 | " './chaos_MR_T2_normalized/label_37.nii.gz',\n", 135 | " './chaos_MR_T2_normalized/label_38.nii.gz',\n", 136 | " './chaos_MR_T2_normalized/label_39.nii.gz']" 137 | ] 138 | }, 139 | "execution_count": 7, 140 | "metadata": {}, 141 | "output_type": "execute_result" 142 | } 143 | ], 144 | "source": [ 145 | "segs" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": 13, 151 | "metadata": {}, 152 | "outputs": [ 153 | { 154 | "name": "stdout", 155 | "output_type": "stream", 156 | "text": [ 157 | "pid 1 finished!\n", 158 | "pid 2 finished!\n", 159 | "pid 3 finished!\n", 160 | "pid 5 finished!\n", 161 | "pid 8 finished!\n", 162 | "pid 10 finished!\n", 163 | "pid 13 finished!\n", 164 | "pid 15 finished!\n", 165 | "pid 19 finished!\n", 166 | "pid 20 finished!\n", 167 | "pid 21 finished!\n", 168 | "pid 22 finished!\n", 169 | "pid 31 finished!\n", 170 | "pid 32 finished!\n", 171 | "pid 33 finished!\n", 172 | "pid 34 finished!\n", 173 | "pid 36 finished!\n", 174 | "pid 37 finished!\n", 175 | "pid 38 finished!\n", 176 | "pid 39 finished!\n" 177 | ] 178 | } 179 | ], 180 | "source": [ 181 | "classmap = {}\n", 182 | "LABEL_NAME = [\"BG\", \"LIVER\", \"RK\", \"LK\", \"SPLEEN\"] \n", 183 | "\n", 184 | "\n", 185 | "MIN_TP = 1 # minimum number of positive label pixels to be recorded. Use >100 when training with manual annotations for more stable training\n", 186 | "\n", 187 | "fid = f'./chaos_MR_T2_normalized/classmap_{MIN_TP}.json' # name of the output file. 
\n", 188 | "for _lb in LABEL_NAME:\n", 189 | " classmap[_lb] = {}\n", 190 | " for _sid in segs:\n", 191 | " pid = _sid.split(\"_\")[-1].split(\".nii.gz\")[0]\n", 192 | " classmap[_lb][pid] = []\n", 193 | "\n", 194 | "for seg in segs:\n", 195 | " pid = seg.split(\"_\")[-1].split(\".nii.gz\")[0]\n", 196 | " lb_vol = nio.read_nii_bysitk(seg)\n", 197 | " n_slice = lb_vol.shape[0]\n", 198 | " for slc in range(n_slice):\n", 199 | " for cls in range(len(LABEL_NAME)):\n", 200 | " if cls in lb_vol[slc, ...]:\n", 201 | " if np.sum( lb_vol[slc, ...]) >= MIN_TP:\n", 202 | " classmap[LABEL_NAME[cls]][str(pid)].append(slc)\n", 203 | " print(f'pid {str(pid)} finished!')\n", 204 | " \n", 205 | "with open(fid, 'w') as fopen:\n", 206 | " json.dump(classmap, fopen)\n", 207 | " fopen.close() \n", 208 | " " 209 | ] 210 | }, 211 | { 212 | "cell_type": "code", 213 | "execution_count": 12, 214 | "metadata": {}, 215 | "outputs": [], 216 | "source": [ 217 | "with open(fid, 'w') as fopen:\n", 218 | " json.dump(classmap, fopen)\n", 219 | " fopen.close()" 220 | ] 221 | }, 222 | { 223 | "cell_type": "code", 224 | "execution_count": null, 225 | "metadata": {}, 226 | "outputs": [], 227 | "source": [] 228 | } 229 | ], 230 | "metadata": { 231 | "kernelspec": { 232 | "display_name": "Python 3", 233 | "language": "python", 234 | "name": "python3" 235 | }, 236 | "language_info": { 237 | "codemirror_mode": { 238 | "name": "ipython", 239 | "version": 3 240 | }, 241 | "file_extension": ".py", 242 | "mimetype": "text/x-python", 243 | "name": "python", 244 | "nbconvert_exporter": "python", 245 | "pygments_lexer": "ipython3", 246 | "version": "3.6.0" 247 | } 248 | }, 249 | "nbformat": 4, 250 | "nbformat_minor": 2 251 | } 252 | -------------------------------------------------------------------------------- /dataloaders/common.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dataset classes for common uses 3 | Extended from vanilla PANet code by Wang et al. 4 | """ 5 | import random 6 | import torch 7 | 8 | from torch.utils.data import Dataset 9 | 10 | class BaseDataset(Dataset): 11 | """ 12 | Base Dataset 13 | Args: 14 | base_dir: 15 | dataset directory 16 | """ 17 | def __init__(self, base_dir): 18 | self._base_dir = base_dir 19 | self.aux_attrib = {} 20 | self.aux_attrib_args = {} 21 | self.ids = [] # must be overloaded in subclass 22 | 23 | def add_attrib(self, key, func, func_args): 24 | """ 25 | Add attribute to the data sample dict 26 | 27 | Args: 28 | key: 29 | key in the data sample dict for the new attribute 30 | e.g. sample['click_map'], sample['depth_map'] 31 | func: 32 | function to process a data sample and create an attribute (e.g. user clicks) 33 | func_args: 34 | extra arguments to pass, expected a dict 35 | """ 36 | if key in self.aux_attrib: 37 | raise KeyError("Attribute '{0}' already exists, please use 'set_attrib'.".format(key)) 38 | else: 39 | self.set_attrib(key, func, func_args) 40 | 41 | def set_attrib(self, key, func, func_args): 42 | """ 43 | Set attribute in the data sample dict 44 | 45 | Args: 46 | key: 47 | key in the data sample dict for the new attribute 48 | e.g. sample['click_map'], sample['depth_map'] 49 | func: 50 | function to process a data sample and create an attribute (e.g. 
user clicks) 51 | func_args: 52 | extra arguments to pass, expected a dict 53 | """ 54 | self.aux_attrib[key] = func 55 | self.aux_attrib_args[key] = func_args 56 | 57 | def del_attrib(self, key): 58 | """ 59 | Remove attribute in the data sample dict 60 | 61 | Args: 62 | key: 63 | key in the data sample dict 64 | """ 65 | self.aux_attrib.pop(key) 66 | self.aux_attrib_args.pop(key) 67 | 68 | def subsets(self, sub_ids, sub_args_lst=None): 69 | """ 70 | Create subsets by ids 71 | 72 | Args: 73 | sub_ids: 74 | a sequence of sequences, each sequence contains data ids for one subset 75 | sub_args_lst: 76 | a list of args for some subset-specific auxiliary attribute function 77 | """ 78 | 79 | indices = [[self.ids.index(id_) for id_ in ids] for ids in sub_ids] 80 | if sub_args_lst is not None: 81 | subsets = [Subset(dataset=self, indices=index, sub_attrib_args=args) 82 | for index, args in zip(indices, sub_args_lst)] 83 | else: 84 | subsets = [Subset(dataset=self, indices=index) for index in indices] 85 | return subsets 86 | 87 | def __len__(self): 88 | pass 89 | 90 | def __getitem__(self, idx): 91 | pass 92 | 93 | 94 | class ReloadPairedDataset(Dataset): 95 | """ 96 | Make pairs of data from dataset 97 | Enables loading only part of the entire dataset in each epoch and then reloading the next part 98 | Args: 99 | datasets: 100 | source datasets, expect a list of Dataset. 101 | Each dataset indexes a certain class. It contains a list of all z-indices of this class for each scan 102 | n_elements: 103 | number of elements in a pair 104 | curr_max_iters: 105 | number of pairs in an epoch 106 | pair_based_transforms: 107 | some transformation performed on a pair basis, expect a list of functions, 108 | each function takes a pair sample and returns a transformed one.
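Example (editor's sketch, not in the original docstring): with datasets = [ds_a, ds_b], n_elements = [2, 1] and curr_max_iters = 100, update_index() builds 100 episodes, each drawing 2 samples from one randomly chosen class subset and 1 sample from another.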
109 | """ 110 | def __init__(self, datasets, n_elements, curr_max_iters, 111 | pair_based_transforms=None): 112 | super().__init__() 113 | self.datasets = datasets 114 | self.n_datasets = len(self.datasets) 115 | self.n_data = [len(dataset) for dataset in self.datasets] 116 | self.n_elements = n_elements 117 | self.curr_max_iters = curr_max_iters 118 | self.pair_based_transforms = pair_based_transforms 119 | self.update_index() 120 | 121 | def update_index(self): 122 | """ 123 | update the order of batches for the next episode 124 | """ 125 | 126 | # update number of elements for each subset 127 | if hasattr(self, 'indices'): 128 | n_data_old = self.n_data # DEBUG 129 | self.n_data = [len(dataset) for dataset in self.datasets] 130 | 131 | if isinstance(self.n_elements, list): 132 | self.indices = [[(dataset_idx, data_idx) for i, dataset_idx in enumerate(random.sample(range(self.n_datasets), k=len(self.n_elements))) # select which way(s) to use 133 | for data_idx in random.sample(range(self.n_data[dataset_idx]), k=self.n_elements[i])] # for each way, which sample to use 134 | for i_iter in range(self.curr_max_iters)] # sample iterations 135 | 136 | elif self.n_elements > self.n_datasets: 137 | raise ValueError("When 'same=False', 'n_element' should be no more than n_datasets") 138 | else: 139 | self.indices = [[(dataset_idx, random.randrange(self.n_data[dataset_idx])) 140 | for dataset_idx in random.sample(range(self.n_datasets), 141 | k=n_elements)] 142 | for i in range(curr_max_iters)] 143 | 144 | def __len__(self): 145 | return self.curr_max_iters 146 | 147 | def __getitem__(self, idx): 148 | sample = [self.datasets[dataset_idx][data_idx] 149 | for dataset_idx, data_idx in self.indices[idx]] 150 | if self.pair_based_transforms is not None: 151 | for transform, args in self.pair_based_transforms: 152 | sample = transform(sample, **args) 153 | return sample 154 | 155 | class Subset(Dataset): 156 | """ 157 | Subset of a dataset at specified indices. Used for seperating a dataset by class in our context 158 | 159 | Args: 160 | dataset: 161 | The whole Dataset 162 | indices: 163 | Indices of samples of the current class in the entire dataset 164 | sub_attrib_args: 165 | Subset-specific arguments for attribute functions, expected a dict 166 | """ 167 | def __init__(self, dataset, indices, sub_attrib_args=None): 168 | self.dataset = dataset 169 | self.indices = indices 170 | self.sub_attrib_args = sub_attrib_args 171 | 172 | def __getitem__(self, idx): 173 | if self.sub_attrib_args is not None: 174 | for key in self.sub_attrib_args: 175 | # Make sure the dataset already has the corresponding attributes 176 | # Here we only make the arguments subset dependent 177 | # (i.e. pass different arguments for each subset) 178 | self.dataset.aux_attrib_args[key].update(self.sub_attrib_args[key]) 179 | return self.dataset[self.indices[idx]] 180 | 181 | def __len__(self): 182 | return len(self.indices) 183 | 184 | class ValidationDataset(Dataset): 185 | """ 186 | Dataset for validation 187 | 188 | Args: 189 | dataset: 190 | source dataset with a __getitem__ method 191 | test_classes: 192 | test classes 193 | npart: int. 
number of parts, used for evaluation when assigning support images 194 | 195 | """ 196 | def __init__(self, dataset, test_classes: list, npart: int): 197 | super().__init__() 198 | self.dataset = dataset 199 | self.__curr_cls = None 200 | self.test_classes = test_classes 201 | self.dataset.aux_attrib = None 202 | self.npart = npart 203 | 204 | def set_curr_cls(self, curr_cls): 205 | assert curr_cls in self.test_classes 206 | self.__curr_cls = curr_cls 207 | 208 | def get_curr_cls(self): 209 | return self.__curr_cls 210 | 211 | def read_dataset(self): 212 | """ 213 | override original read_dataset to allow reading with z_margin 214 | """ 215 | raise NotImplementedError 216 | 217 | def __len__(self): 218 | return len(self.dataset) 219 | 220 | def label_strip(self, label): 221 | """ 222 | mask unrelated labels out 223 | """ 224 | out = torch.where(label == self.__curr_cls, 225 | torch.ones_like(label), torch.zeros_like(label)) 226 | return out 227 | 228 | def __getitem__(self, idx): 229 | if self.__curr_cls is None: 230 | raise Exception("Please initialize current class first") 231 | 232 | sample = self.dataset[idx] 233 | sample["label"] = self.label_strip( sample["label"] ) 234 | sample["label_t"] = sample["label"].unsqueeze(-1).data.numpy() 235 | 236 | labelname = self.dataset.all_label_names[self.__curr_cls] 237 | z_min = min(self.dataset.tp1_cls_map[labelname][sample['scan_id']]) 238 | z_max = max(self.dataset.tp1_cls_map[labelname][sample['scan_id']]) 239 | sample["z_min"], sample["z_max"] = z_min, z_max 240 | try: 241 | part_assign = int((sample["z_id"] - z_min) // ((z_max - z_min) / self.npart)) 242 | except ZeroDivisionError: # z_max == z_min, i.e. the class appears on a single slice 243 | part_assign = 0 244 | print("###### DATASET: support has only one valid slice ######") 245 | if part_assign < 0: 246 | part_assign = 0 247 | elif part_assign >= self.npart: 248 | part_assign = self.npart - 1 249 | sample["part_assign"] = part_assign 250 | 251 | return sample 252 | 253 | -------------------------------------------------------------------------------- /validation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Validation script 3 | """ 4 | import os 5 | import shutil 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim 9 | from torch.utils.data import DataLoader 10 | from torch.optim.lr_scheduler import MultiStepLR 11 | import torch.backends.cudnn as cudnn 12 | import numpy as np 13 | 14 | from models.grid_proto_fewshot import FewShotSeg 15 | 16 | from dataloaders.dev_customized_med import med_fewshot_val 17 | from dataloaders.ManualAnnoDatasetv2 import ManualAnnoDataset 18 | from dataloaders.GenericSuperDatasetv2 import SuperpixelDataset 19 | from dataloaders.dataset_utils import DATASET_INFO, get_normalize_op 20 | from dataloaders.niftiio import convert_to_sitk 21 | 22 | from util.metric import Metric 23 | 24 | from config_ssl_upload import ex 25 | 26 | import tqdm 27 | import SimpleITK as sitk 28 | from torchvision.utils import make_grid 29 | 30 | # configure the pre-trained model caching path 31 | os.environ['TORCH_HOME'] = "./pretrained_model" 32 | 33 | @ex.automain 34 | def main(_run, _config, _log): 35 | if _run.observers: 36 | os.makedirs(f'{_run.observers[0].dir}/interm_preds', exist_ok=True) 37 | for source_file, _ in _run.experiment_info['sources']: 38 | os.makedirs(os.path.dirname(f'{_run.observers[0].dir}/source/{source_file}'), 39 | exist_ok=True) 40 | _run.observers[0].save_file(source_file, f'source/{source_file}') 41 | shutil.rmtree(f'{_run.observers[0].basedir}/_sources') 42 | 43 |
cudnn.enabled = True 44 | cudnn.benchmark = True 45 | torch.cuda.set_device(device=_config['gpu_id']) 46 | torch.set_num_threads(1) 47 | 48 | _log.info(f'###### Reload model {_config["reload_model_path"]} ######') 49 | model = FewShotSeg(pretrained_path = _config['reload_model_path'], cfg=_config['model']) 50 | model = model.cuda() 51 | model.eval() 52 | 53 | _log.info('###### Load data ######') 54 | ### Training set 55 | data_name = _config['dataset'] 56 | if data_name == 'SABS_Superpix': 57 | baseset_name = 'SABS' 58 | max_label = 13 59 | elif data_name == 'C0_Superpix': 60 | raise NotImplementedError 61 | baseset_name = 'C0' 62 | max_label = 3 63 | elif data_name == 'CHAOST2_Superpix': 64 | baseset_name = 'CHAOST2' 65 | max_label = 4 66 | else: 67 | raise ValueError(f'Dataset: {data_name} not found') 68 | 69 | test_labels = DATASET_INFO[baseset_name]['LABEL_GROUP']['pa_all'] - DATASET_INFO[baseset_name]['LABEL_GROUP'][_config["label_sets"]] 70 | 71 | ### Transforms for data augmentation 72 | te_transforms = None 73 | 74 | assert _config['scan_per_load'] < 0 # by default we load the entire dataset directly 75 | 76 | _log.info(f'###### Labels excluded in training : {[lb for lb in _config["exclude_cls_list"]]} ######') 77 | _log.info(f'###### Unseen labels evaluated in testing: {[lb for lb in test_labels]} ######') 78 | 79 | if baseset_name == 'SABS': # for CT we need the training split's intensity statistics to build the normalization op 80 | tr_parent = SuperpixelDataset( # base dataset 81 | which_dataset = baseset_name, 82 | base_dir=_config['path'][data_name]['data_dir'], 83 | idx_split = _config['eval_fold'], 84 | mode='train', 85 | min_fg=str(_config["min_fg_data"]), # dummy entry for superpixel dataset 86 | transforms=None, 87 | nsup = _config['task']['n_shots'], 88 | scan_per_load = _config['scan_per_load'], 89 | exclude_list = _config["exclude_cls_list"], 90 | superpix_scale = _config["superpix_scale"], 91 | fix_length = _config["max_iters_per_load"] if (data_name == 'C0_Superpix') or (data_name == 'CHAOST2_Superpix') else None 92 | ) 93 | norm_func = tr_parent.norm_func 94 | else: 95 | norm_func = get_normalize_op(modality = 'MR', fids = None) 96 | 97 | 98 | te_dataset, te_parent = med_fewshot_val( 99 | dataset_name = baseset_name, 100 | base_dir=_config['path'][baseset_name]['data_dir'], 101 | idx_split = _config['eval_fold'], 102 | scan_per_load = _config['scan_per_load'], 103 | act_labels=test_labels, 104 | npart = _config['task']['npart'], 105 | nsup = _config['task']['n_shots'], 106 | extern_normalize_func = norm_func 107 | ) 108 | 109 | ### dataloaders 110 | testloader = DataLoader( 111 | te_dataset, 112 | batch_size = 1, 113 | shuffle=False, 114 | num_workers=1, 115 | pin_memory=False, 116 | drop_last=False 117 | ) 118 | 119 | _log.info('###### Set validation nodes ######') 120 | mar_val_metric_node = Metric(max_label=max_label, n_scans= len(te_dataset.dataset.pid_curr_load) - _config['task']['n_shots']) 121 | 122 | _log.info('###### Starting validation ######') 123 | model.eval() 124 | mar_val_metric_node.reset() 125 | 126 | with torch.no_grad(): 127 | save_pred_buffer = {} # indexed by class 128 | 129 | for curr_lb in test_labels: 130 | te_dataset.set_curr_cls(curr_lb) 131 | support_batched = te_parent.get_support(curr_class = curr_lb, class_idx = [curr_lb], scan_idx = _config["support_idx"], npart=_config['task']['npart']) 132 | 133 | # way(1 for now) x part x shot x [3 x H x W] 134 | support_images = [[shot.cuda() for shot in way] 135 | for way in support_batched['support_images']] # way x part x [shot x C x H x W]
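# (editor's note) get_support above returns supports pre-chunked along z into npart groups;
# in the loop below each query slice is matched to the support chunk covering the same relative
# z-position via sample_batched["part_assign"] (computed in ValidationDataset.__getitem__).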
136 | suffix = 'mask' 137 | support_fg_mask = [[shot[f'fg_{suffix}'].float().cuda() for shot in way] 138 | for way in support_batched['support_mask']] 139 | support_bg_mask = [[shot[f'bg_{suffix}'].float().cuda() for shot in way] 140 | for way in support_batched['support_mask']] 141 | 142 | curr_scan_count = -1 # counting for current scan 143 | _lb_buffer = {} # indexed by scan 144 | 145 | last_qpart = 0 # used as indicator for adding result to buffer 146 | 147 | for sample_batched in testloader: 148 | 149 | _scan_id = sample_batched["scan_id"][0] # we assume batch size for query is 1 150 | if _scan_id in te_parent.potential_support_sid: # skip the support scan, don't include it in the query set 151 | continue 152 | if sample_batched["is_start"]: 153 | ii = 0 154 | curr_scan_count += 1 155 | _scan_id = sample_batched["scan_id"][0] 156 | outsize = te_dataset.dataset.info_by_scan[_scan_id]["array_size"] 157 | outsize = (256, 256, outsize[0]) # original image read by itk: Z, H, W, in prediction we use H, W, Z 158 | _pred = np.zeros( outsize ) 159 | _pred.fill(np.nan) 160 | 161 | q_part = sample_batched["part_assign"] # the chunk of the query, for assignment with support 162 | query_images = [sample_batched['image'].cuda()] 163 | query_labels = torch.cat([ sample_batched['label'].cuda()], dim=0) 164 | 165 | # [way, [part, [shot x C x H x W]]] -> 166 | sup_img_part = [[shot_tensor.unsqueeze(0) for shot_tensor in support_images[0][q_part]]] # way(1) x shot x [B(1) x C x H x W] 167 | sup_fgm_part = [[shot_tensor.unsqueeze(0) for shot_tensor in support_fg_mask[0][q_part]]] 168 | sup_bgm_part = [[shot_tensor.unsqueeze(0) for shot_tensor in support_bg_mask[0][q_part]]] 169 | 170 | query_pred, _, _, assign_mats = model( sup_img_part , sup_fgm_part, sup_bgm_part, query_images, isval = True, val_wsize = _config["val_wsize"] ) 171 | 172 | query_pred = np.array(query_pred.argmax(dim=1)[0].cpu()) 173 | _pred[..., ii] = query_pred.copy() 174 | 175 | if (sample_batched["z_id"] - sample_batched["z_max"] <= _config['z_margin']) and (sample_batched["z_id"] - sample_batched["z_min"] >= -1 * _config['z_margin']): 176 | mar_val_metric_node.record(query_pred, np.array(query_labels[0].cpu()), labels=[curr_lb], n_scan=curr_scan_count) 177 | else: 178 | pass 179 | 180 | ii += 1 181 | # when a scan ends, buffer the reassembled prediction volume 182 | if sample_batched["is_end"]: 183 | if _config['dataset'] != 'C0': 184 | _lb_buffer[_scan_id] = _pred.transpose(2,0,1) # H, W, Z -> to Z H W 185 | else: 186 | _lb_buffer[_scan_id] = _pred 187 | 188 | save_pred_buffer[str(curr_lb)] = _lb_buffer 189 | 190 | ### save results 191 | for curr_lb, _preds in save_pred_buffer.items(): 192 | for _scan_id, _pred in _preds.items(): 193 | _pred *= float(curr_lb) 194 | itk_pred = convert_to_sitk(_pred, te_dataset.dataset.info_by_scan[_scan_id]) 195 | fid = os.path.join(f'{_run.observers[0].dir}/interm_preds', f'scan_{_scan_id}_label_{curr_lb}.nii.gz') 196 | sitk.WriteImage(itk_pred, fid, True) 197 | _log.info(f'###### {fid} has been saved ######') 198 | 199 | del save_pred_buffer 200 | 201 | del sample_batched, support_images, support_bg_mask, query_images, query_labels, query_pred 202 | 203 | # compute dice scores by scan 204 | m_classDice,_, m_meanDice,_, m_rawDice = mar_val_metric_node.get_mDice(labels=sorted(test_labels), n_scan=None, give_raw = True) 205 | 206 | m_classPrec,_, m_meanPrec,_, m_classRec,_, m_meanRec,_, m_rawPrec, m_rawRec = mar_val_metric_node.get_mPrecRecall(labels=sorted(test_labels), n_scan=None, give_raw = True) 207 | 208 | mar_val_metric_node.reset() # reset
this calculation node 209 | 210 | # write validation result to log file 211 | _run.log_scalar('mar_val_batches_classDice', m_classDice.tolist()) 212 | _run.log_scalar('mar_val_batches_meanDice', m_meanDice.tolist()) 213 | _run.log_scalar('mar_val_batches_rawDice', m_rawDice.tolist()) 214 | 215 | _run.log_scalar('mar_val_batches_classPrec', m_classPrec.tolist()) 216 | _run.log_scalar('mar_val_batches_meanPrec', m_meanPrec.tolist()) 217 | _run.log_scalar('mar_val_batches_rawPrec', m_rawPrec.tolist()) 218 | 219 | _run.log_scalar('mar_val_batches_classRec', m_classRec.tolist()) 220 | _run.log_scalar('mar_val_batches_meanRec', m_meanRec.tolist()) 221 | _run.log_scalar('mar_val_batches_rawRec', m_rawRec.tolist()) 222 | 223 | _log.info(f'mar_val batches classDice: {m_classDice}') 224 | _log.info(f'mar_val batches meanDice: {m_meanDice}') 225 | 226 | _log.info(f'mar_val batches classPrec: {m_classPrec}') 227 | _log.info(f'mar_val batches meanPrec: {m_meanPrec}') 228 | 229 | _log.info(f'mar_val batches classRec: {m_classRec}') 230 | _log.info(f'mar_val batches meanRec: {m_meanRec}') 231 | 232 | print("============ ============") 233 | 234 | _log.info('End of validation') 235 | return 1 236 | 237 | 238 | -------------------------------------------------------------------------------- /dataloaders/dev_customized_med.py: -------------------------------------------------------------------------------- 1 | """ 2 | Customized dataset. Extended from vanilla PANet script by Wang et al. 3 | """ 4 | 5 | import os 6 | import random 7 | import torch 8 | import numpy as np 9 | 10 | from dataloaders.common import ReloadPairedDataset, ValidationDataset 11 | from dataloaders.ManualAnnoDatasetv2 import ManualAnnoDataset 12 | 13 | def attrib_basic(_sample, class_id): 14 | """ 15 | Add basic attribute 16 | Args: 17 | _sample: data sample 18 | class_id: class label associated with the data 19 | (sometimes indicating from which subset the data are drawn) 20 | """ 21 | return {'class_id': class_id} 22 | 23 | def getMaskOnly(label, class_id, class_ids): 24 | """ 25 | Generate FG/BG mask from the segmentation mask 26 | 27 | Args: 28 | label: 29 | semantic mask 30 | (unlike getMasks, this mask-only variant 31 | takes no scribble input) 32 | class_id: 33 | semantic class of interest 34 | class_ids: 35 | all class ids in this episode 36 | """ 37 | # Dense Mask 38 | fg_mask = torch.where(label == class_id, 39 | torch.ones_like(label), torch.zeros_like(label)) 40 | bg_mask = torch.where(label != class_id, 41 | torch.ones_like(label), torch.zeros_like(label)) 42 | for class_id in class_ids: 43 | bg_mask[label == class_id] = 0 44 | 45 | return {'fg_mask': fg_mask, 46 | 'bg_mask': bg_mask} 47 | 48 | def getMasks(*args, **kwargs): 49 | raise NotImplementedError 50 | 51 | def fewshot_pairing(paired_sample, n_ways, n_shots, cnt_query, coco=False, mask_only = True): 52 | """ 53 | Postprocess paired sample for fewshot settings 54 | For now only 1-way is tested but we leave multi-way possible (inherited from original PANet) 55 | 56 | Args: 57 | paired_sample: 58 | data sample from a PairedDataset 59 | n_ways: 60 | n-way few-shot learning 61 | n_shots: 62 | n-shot few-shot learning 63 | cnt_query: 64 | number of query images for each class in the support set 65 | coco: 66 | MS COCO dataset. This is from the original PANet dataset but let's keep it for further extension 67 | mask_only: 68 | only give masks and no scribbles/instances.
Suitable for medical images (for now) 69 | """ 70 | if not mask_only: 71 | raise NotImplementedError 72 | ###### Compose the support and query image list ###### 73 | cumsum_idx = np.cumsum([0,] + [n_shots + x for x in cnt_query]) # separation for supports and queries 74 | 75 | # support class ids 76 | class_ids = [paired_sample[cumsum_idx[i]]['basic_class_id'] for i in range(n_ways)] # class ids for each image (support and query) 77 | 78 | # support images 79 | support_images = [[paired_sample[cumsum_idx[i] + j]['image'] for j in range(n_shots)] 80 | for i in range(n_ways)] # fetch support images for each class 81 | 82 | # support image labels 83 | if coco: 84 | support_labels = [[paired_sample[cumsum_idx[i] + j]['label'][class_ids[i]] 85 | for j in range(n_shots)] for i in range(n_ways)] 86 | else: 87 | support_labels = [[paired_sample[cumsum_idx[i] + j]['label'] for j in range(n_shots)] 88 | for i in range(n_ways)] 89 | 90 | if not mask_only: 91 | support_scribbles = [[paired_sample[cumsum_idx[i] + j]['scribble'] for j in range(n_shots)] 92 | for i in range(n_ways)] 93 | support_insts = [[paired_sample[cumsum_idx[i] + j]['inst'] for j in range(n_shots)] 94 | for i in range(n_ways)] 95 | else: 96 | support_insts, support_scribbles = [], [] # keep both return interfaces defined in mask-only mode 97 | 98 | # query images, masks and class indices 99 | query_images = [paired_sample[cumsum_idx[i+1] - j - 1]['image'] for i in range(n_ways) 100 | for j in range(cnt_query[i])] 101 | if coco: 102 | query_labels = [paired_sample[cumsum_idx[i+1] - j - 1]['label'][class_ids[i]] 103 | for i in range(n_ways) for j in range(cnt_query[i])] 104 | else: 105 | query_labels = [paired_sample[cumsum_idx[i+1] - j - 1]['label'] for i in range(n_ways) 106 | for j in range(cnt_query[i])] 107 | query_cls_idx = [sorted([0,] + [class_ids.index(x) + 1 108 | for x in set(np.unique(query_label)) & set(class_ids)]) 109 | for query_label in query_labels] 110 | 111 | ###### Generate support image masks ###### 112 | if not mask_only: 113 | support_mask = [[getMasks(support_labels[way][shot], support_scribbles[way][shot], 114 | class_ids[way], class_ids) 115 | for shot in range(n_shots)] for way in range(n_ways)] 116 | else: 117 | support_mask = [[getMaskOnly(support_labels[way][shot], 118 | class_ids[way], class_ids) 119 | for shot in range(n_shots)] for way in range(n_ways)] 120 | 121 | ###### Generate query label (class indices in one episode, i.e. the ground truth)###### 122 | query_labels_tmp = [torch.zeros_like(x) for x in query_labels] 123 | for i, query_label_tmp in enumerate(query_labels_tmp): 124 | query_label_tmp[query_labels[i] == 255] = 255 125 | for j in range(n_ways): 126 | query_label_tmp[query_labels[i] == class_ids[j]] = j + 1 127 | 128 | ###### Generate query mask for each semantic class (including BG) ###### 129 | # BG class 130 | query_masks = [[torch.where(query_label == 0, 131 | torch.ones_like(query_label), 132 | torch.zeros_like(query_label))[None, ...],] 133 | for query_label in query_labels] 134 | # Other classes in query image 135 | for i, query_label in enumerate(query_labels): 136 | for idx in query_cls_idx[i][1:]: 137 | mask = torch.where(query_label == class_ids[idx - 1], 138 | torch.ones_like(query_label), 139 | torch.zeros_like(query_label))[None, ...]
140 | query_masks[i].append(mask) 141 | 142 | 143 | return {'class_ids': class_ids, 144 | 'support_images': support_images, 145 | 'support_mask': support_mask, 146 | 'support_inst': support_insts, # leave these interfaces 147 | 'support_scribbles': support_scribbles, 148 | 149 | 'query_images': query_images, 150 | 'query_labels': query_labels_tmp, 151 | 'query_masks': query_masks, 152 | 'query_cls_idx': query_cls_idx, 153 | } 154 | 155 | 156 | def med_fewshot(dataset_name, base_dir, idx_split, mode, scan_per_load, 157 | transforms, act_labels, n_ways, n_shots, max_iters_per_load, min_fg = '', n_queries=1, fix_parent_len = None, exclude_list = [], **kwargs): 158 | """ 159 | Dataset wrapper 160 | Args: 161 | dataset_name: 162 | indicates what dataset to use 163 | base_dir: 164 | dataset directory 165 | mode: 166 | which mode to use 167 | choose from ('train', 'val', 'trainval', 'trainaug') 168 | idx_split: 169 | index of split 170 | scan_per_load: 171 | number of scans to load into memory as the dataset is large 172 | use that together with reload_buffer 173 | transforms: 174 | transformations to be performed on images/masks 175 | act_labels: 176 | active labels involved in training process. Should be a subset of all labels 177 | n_ways: 178 | n-way few-shot learning, should be no more than # of object class labels 179 | n_shots: 180 | n-shot few-shot learning 181 | max_iters_per_load: 182 | number of pairs per load (epoch size) 183 | n_queries: 184 | number of query images 185 | fix_parent_len: 186 | fixed length of the parent dataset 187 | """ 188 | med_set = ManualAnnoDataset 189 | 190 | 191 | mydataset = med_set(which_dataset = dataset_name, base_dir=base_dir, idx_split = idx_split, mode = mode,\ 192 | scan_per_load = scan_per_load, transforms=transforms, min_fg = min_fg, fix_length = fix_parent_len,\ 193 | exclude_list = exclude_list, **kwargs) 194 | 195 | mydataset.add_attrib('basic', attrib_basic, {}) 196 | 197 | # Create sub-datasets and add class_id attribute. Here the class file is internally loaded and reloaded inside 198 | subsets = mydataset.subsets([{'basic': {'class_id': ii}} 199 | for ii, _ in enumerate(mydataset.label_name)]) 200 | 201 | # Choose the classes of queries 202 | cnt_query = np.bincount(random.choices(population=range(n_ways), k=n_queries), minlength=n_ways) 203 | # Number of queries for each way 204 | # Set the number of images for each class 205 | n_elements = [n_shots + x for x in cnt_query] # supports + [i] queries 206 | # Create paired dataset. We do not include background. 
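# (editor's note, worked example) with the defaults n_ways = 1, n_shots = 1, n_queries = 1:
# cnt_query == [1] and n_elements == [2], so each episode drawn below contains two samples
# of the chosen class, one support followed by one query.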
207 | paired_data = ReloadPairedDataset([subsets[ii] for ii in act_labels], n_elements=n_elements, curr_max_iters=max_iters_per_load, 208 | pair_based_transforms=[ 209 | (fewshot_pairing, {'n_ways': n_ways, 'n_shots': n_shots, 210 | 'cnt_query': cnt_query, 'mask_only': True})]) 211 | return paired_data, mydataset 212 | 213 | def update_loader_dset(loader, parent_set): 214 | """ 215 | Update data loader and the parent dataset behind 216 | Args: 217 | loader: actual dataloader 218 | parent_set: parent dataset which actually stores the data 219 | """ 220 | parent_set.reload_buffer() 221 | loader.dataset.update_index() 222 | print(f'###### Loader and dataset have been updated ######' ) 223 | 224 | def med_fewshot_val(dataset_name, base_dir, idx_split, scan_per_load, act_labels, npart, fix_length = None, nsup = 1, **kwargs): 225 | """ 226 | validation set for med images 227 | Args: 228 | dataset_name: 229 | indicates what dataset to use 230 | base_dir: 231 | SABS dataset directory 232 | mode: (original split) 233 | which split to use 234 | choose from ('train', 'val', 'trainval', 'trainaug') 235 | idx_split: 236 | index of split 237 | scan_per_batch: 238 | number of scans to load into memory as the dataset is large 239 | use that together with reload_buffer 240 | act_labels: 241 | actual labels involved in training process. Should be a subset of all labels 242 | npart: number of chunks for splitting a 3d volume 243 | nsup: number of support scans, equivalent to nshot 244 | """ 245 | mydataset = ManualAnnoDataset(which_dataset = dataset_name, base_dir=base_dir, idx_split = idx_split, mode = 'val', scan_per_load = scan_per_load, transforms=None, min_fg = 1, fix_length = fix_length, nsup = nsup, **kwargs) 246 | mydataset.add_attrib('basic', attrib_basic, {}) 247 | 248 | valset = ValidationDataset(mydataset, test_classes = act_labels, npart = npart) 249 | 250 | return valset, mydataset 251 | 252 | -------------------------------------------------------------------------------- /models/grid_proto_fewshot.py: -------------------------------------------------------------------------------- 1 | """ 2 | ALPNet 3 | """ 4 | from collections import OrderedDict 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from .alpmodule import MultiProtoAsConv 10 | from .backbone.torchvision_backbones import TVDeeplabRes101Encoder 11 | # DEBUG 12 | from pdb import set_trace 13 | 14 | import pickle 15 | import torchvision 16 | 17 | # options for type of prototypes 18 | FG_PROT_MODE = 'gridconv+' # using both local and global prototype 19 | BG_PROT_MODE = 'gridconv' # using local prototype only. 
Also 'mask' refers to using global prototype only (as done in vanilla PANet) 20 | 21 | # thresholds for deciding class of prototypes 22 | FG_THRESH = 0.95 23 | BG_THRESH = 0.95 24 | 25 | class FewShotSeg(nn.Module): 26 | """ 27 | ALPNet 28 | Args: 29 | in_channels: Number of input channels 30 | cfg: Model configurations 31 | """ 32 | def __init__(self, in_channels=3, pretrained_path=None, cfg=None): 33 | super(FewShotSeg, self).__init__() 34 | self.pretrained_path = pretrained_path 35 | self.config = cfg or {'align': False} 36 | self.get_encoder(in_channels) 37 | self.get_cls() 38 | 39 | def get_encoder(self, in_channels): 40 | # if self.config['which_model'] == 'deeplab_res101': 41 | if self.config['which_model'] == 'dlfcn_res101': 42 | use_coco_init = self.config['use_coco_init'] 43 | self.encoder = TVDeeplabRes101Encoder(use_coco_init) 44 | 45 | else: 46 | raise NotImplementedError(f'Backbone network {self.config["which_model"]} not implemented') 47 | 48 | if self.pretrained_path: 49 | self.load_state_dict(torch.load(self.pretrained_path)) 50 | print(f'###### Pre-trained model {self.pretrained_path} has been loaded ######') 51 | 52 | def get_cls(self): 53 | """ 54 | Obtain the similarity-based classifier 55 | """ 56 | proto_hw = self.config["proto_grid_size"] 57 | feature_hw = self.config["feature_hw"] 58 | assert self.config['cls_name'] == 'grid_proto' 59 | if self.config['cls_name'] == 'grid_proto': 60 | self.cls_unit = MultiProtoAsConv(proto_grid = [proto_hw, proto_hw], feature_hw = self.config["feature_hw"]) # when treating it as ordinary prototype 61 | else: 62 | raise NotImplementedError(f'Classifier {self.config["cls_name"]} not implemented') 63 | 64 | def forward(self, supp_imgs, fore_mask, back_mask, qry_imgs, isval, val_wsize, show_viz = False): 65 | """ 66 | Args: 67 | supp_imgs: support images 68 | way x shot x [B x 3 x H x W], list of lists of tensors 69 | fore_mask: foreground masks for support images 70 | way x shot x [B x H x W], list of lists of tensors 71 | back_mask: background masks for support images 72 | way x shot x [B x H x W], list of lists of tensors 73 | qry_imgs: query images 74 | N x [B x 3 x H x W], list of tensors 75 | show_viz: return the visualization dictionary 76 | """ 77 | # NOTE: please go through this piece of code carefully 78 | n_ways = len(supp_imgs) 79 | n_shots = len(supp_imgs[0]) 80 | n_queries = len(qry_imgs) 81 | 82 | assert n_ways == 1, "Multi-way has not been implemented yet" # NOTE: actual shots in support go in the batch dimension 83 | assert n_queries == 1 84 | 85 | sup_bsize = supp_imgs[0][0].shape[0] 86 | img_size = supp_imgs[0][0].shape[-2:] 87 | qry_bsize = qry_imgs[0].shape[0] 88 | 89 | assert sup_bsize == qry_bsize == 1 90 | 91 | imgs_concat = torch.cat([torch.cat(way, dim=0) for way in supp_imgs] 92 | + [torch.cat(qry_imgs, dim=0),], dim=0) 93 | 94 | img_fts = self.encoder(imgs_concat, low_level = False) 95 | fts_size = img_fts.shape[-2:] 96 | 97 | supp_fts = img_fts[:n_ways * n_shots * sup_bsize].view( 98 | n_ways, n_shots, sup_bsize, -1, *fts_size) # Wa x Sh x B x C x H' x W' 99 | qry_fts = img_fts[n_ways * n_shots * sup_bsize:].view( 100 | n_queries, qry_bsize, -1, *fts_size) # N x B x C x H' x W' 101 | fore_mask = torch.stack([torch.stack(way, dim=0) 102 | for way in fore_mask], dim=0) # Wa x Sh x B x H' x W' 103 | fore_mask = torch.autograd.Variable(fore_mask, requires_grad = True) 104 | back_mask = torch.stack([torch.stack(way, dim=0) 105 | for way in back_mask], dim=0) # Wa x Sh x B x H' x W' 106 | 107 | ###### Compute loss
###### 108 | align_loss = 0 109 | outputs = [] 110 | visualizes = [] # the buffer for visualization 111 | 112 | for epi in range(1): # batch dimension, fixed to 1 113 | fg_masks = [] # keep the way part 114 | 115 | ''' 116 | for way in range(n_ways): 117 | # note: index of n_ways starts from 0 118 | mean_sup_ft = supp_fts[way].mean(dim = 0) # [ nb, C, H, W]. Just assume batch size is 1 as pytorch only allows this 119 | mean_sup_msk = F.interpolate(fore_mask[way].mean(dim = 0).unsqueeze(1), size = mean_sup_ft.shape[-2:], mode = 'bilinear') 120 | fg_masks.append( mean_sup_msk ) 121 | 122 | mean_bg_msk = F.interpolate(back_mask[way].mean(dim = 0).unsqueeze(1), size = mean_sup_ft.shape[-2:], mode = 'bilinear') # [nb, C, H, W] 123 | ''' 124 | # re-interpolate support mask to the same size as the support features 125 | res_fg_msk = torch.stack([F.interpolate(fore_mask_w, size = fts_size, mode = 'bilinear') for fore_mask_w in fore_mask], dim = 0) # [nway, ns, nb, nh', nw'] 126 | res_bg_msk = torch.stack([F.interpolate(back_mask_w, size = fts_size, mode = 'bilinear') for back_mask_w in back_mask], dim = 0) # [nway, ns, nb, nh', nw'] 127 | 128 | 129 | scores = [] 130 | assign_maps = [] 131 | bg_sim_maps = [] 132 | fg_sim_maps = [] 133 | 134 | _raw_score, _, aux_attr = self.cls_unit(qry_fts, supp_fts, res_bg_msk, mode = BG_PROT_MODE, thresh = BG_THRESH, isval = isval, val_wsize = val_wsize, vis_sim = show_viz ) 135 | 136 | scores.append(_raw_score) 137 | assign_maps.append(aux_attr['proto_assign']) 138 | if show_viz: 139 | bg_sim_maps.append(aux_attr['raw_local_sims']) 140 | 141 | for way, _msk in enumerate(res_fg_msk): 142 | _raw_score, _, aux_attr = self.cls_unit(qry_fts, supp_fts, _msk.unsqueeze(0), mode = FG_PROT_MODE if F.avg_pool2d(_msk, 4).max() >= FG_THRESH and FG_PROT_MODE != 'mask' else 'mask', thresh = FG_THRESH, isval = isval, val_wsize = val_wsize, vis_sim = show_viz ) 143 | 144 | scores.append(_raw_score) 145 | if show_viz: 146 | fg_sim_maps.append(aux_attr['raw_local_sims']) 147 | 148 | pred = torch.cat(scores, dim=1) # N x (1 + Wa) x H' x W' 149 | outputs.append(F.interpolate(pred, size=img_size, mode='bilinear')) 150 | 151 | ###### Prototype alignment loss ###### 152 | if self.config['align'] and self.training: 153 | align_loss_epi = self.alignLoss(qry_fts[:, epi], pred, supp_fts[:, :, epi], 154 | fore_mask[:, :, epi], back_mask[:, :, epi]) 155 | align_loss += align_loss_epi 156 | output = torch.stack(outputs, dim=1) # N x B x (1 + Wa) x H x W 157 | output = output.view(-1, *output.shape[2:]) 158 | assign_maps = torch.stack(assign_maps, dim = 1) 159 | bg_sim_maps = torch.stack(bg_sim_maps, dim = 1) if show_viz else None 160 | fg_sim_maps = torch.stack(fg_sim_maps, dim = 1) if show_viz else None 161 | 162 | return output, align_loss / sup_bsize, [bg_sim_maps, fg_sim_maps], assign_maps 163 | 164 | 165 | # Batch was at the outer loop 166 | def alignLoss(self, qry_fts, pred, supp_fts, fore_mask, back_mask): 167 | """ 168 | Compute the loss for the prototype alignment branch 169 | 170 | Args: 171 | qry_fts: embedding features for query images 172 | expect shape: N x C x H' x W' 173 | pred: predicted segmentation score 174 | expect shape: N x (1 + Wa) x H x W 175 | supp_fts: embedding features for support images 176 | expect shape: Wa x Sh x C x H' x W' 177 | fore_mask: foreground masks for support images 178 | expect shape: way x shot x H x W 179 | back_mask: background masks for support images 180 | expect shape: way x shot x H x W 181 | """ 182 | n_ways, n_shots = len(fore_mask),
len(fore_mask[0])
183 | 
184 |         # Masks for getting query prototype
185 |         pred_mask = pred.argmax(dim=1).unsqueeze(0) # 1 x N x H' x W'
186 |         binary_masks = [pred_mask == i for i in range(1 + n_ways)]
187 | 
188 |         # skip_ways = [i for i in range(n_ways) if binary_masks[i + 1].sum() == 0]
189 |         # FIXME: to be fixed in the future. For now we make the stronger assumption that a positive class must be present, to avoid under-segmentation/laziness
190 |         skip_ways = []
191 | 
192 |         ### added for matching dimensions to the new data format
193 |         qry_fts = qry_fts.unsqueeze(0).unsqueeze(2) # added to nway(1) and nb(1)
194 | 
195 |         ### end of added part
196 | 
197 |         loss = []
198 |         for way in range(n_ways):
199 |             if way in skip_ways:
200 |                 continue
201 |             # Get the query prototypes
202 |             for shot in range(n_shots):
203 |                 img_fts = supp_fts[way: way + 1, shot: shot + 1] # actual local query [way(1), nb(1, nb is now nshot), nc, h, w]
204 | 
205 |                 qry_pred_fg_msk = F.interpolate(binary_masks[way + 1].float(), size = img_fts.shape[-2:], mode = 'bilinear') # [1 (way), n (shot), h, w]
206 | 
207 |                 # background
208 |                 qry_pred_bg_msk = F.interpolate(binary_masks[0].float(), size = img_fts.shape[-2:], mode = 'bilinear') # 1, n, h, w
209 |                 scores = []
210 | 
211 |                 _raw_score_bg, _, _ = self.cls_unit(qry = img_fts, sup_x = qry_fts, sup_y = qry_pred_bg_msk.unsqueeze(-3), mode = BG_PROT_MODE, thresh = BG_THRESH )
212 | 
213 |                 scores.append(_raw_score_bg)
214 | 
215 |                 _raw_score_fg, _, _ = self.cls_unit(qry = img_fts, sup_x = qry_fts, sup_y = qry_pred_fg_msk.unsqueeze(-3), mode = FG_PROT_MODE if F.avg_pool2d(qry_pred_fg_msk, 4).max() >= FG_THRESH and FG_PROT_MODE != 'mask' else 'mask', thresh = FG_THRESH )
216 |                 scores.append(_raw_score_fg)
217 | 
218 |                 supp_pred = torch.cat(scores, dim=1) # N x (1 + Wa) x H' x W'
219 |                 supp_pred = F.interpolate(supp_pred, size=fore_mask.shape[-2:], mode='bilinear')
220 | 
221 |                 # Construct the support Ground-Truth segmentation
222 |                 supp_label = torch.full_like(fore_mask[way, shot], 255,
223 |                                              device=img_fts.device).long()
224 |                 supp_label[fore_mask[way, shot] == 1] = 1
225 |                 supp_label[back_mask[way, shot] == 1] = 0
226 |                 # Compute Loss
227 |                 loss.append( F.cross_entropy(
228 |                     supp_pred, supp_label[None, ...], ignore_index=255) / n_shots / n_ways)
229 | 
230 |         return torch.sum( torch.stack(loss))
231 | 
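232 | 
233 | if __name__ == '__main__':
234 |     # A minimal sketch of the input layout expected by forward(), assuming the
235 |     # 1-way / 1-shot / batch-size-1 setting asserted above. The model call is
236 |     # left commented since it needs a full config dict and backbone weights;
237 |     # the config values shown are illustrative, not the experiment settings.
238 |     supp_imgs = [[torch.randn(1, 3, 256, 256)]]  # way x shot x [B x 3 x H x W]
239 |     fore_mask = [[torch.rand(1, 256, 256)]]      # way x shot x [B x H x W]
240 |     back_mask = [[1.0 - fore_mask[0][0]]]        # complementary background mask
241 |     qry_imgs = [torch.randn(1, 3, 256, 256)]     # N x [B x 3 x H x W]
242 |     # model = FewShotSeg(in_channels=3, cfg={'align': True, 'which_model': 'dlfcn_res101',
243 |     #                                        'use_coco_init': True, 'cls_name': 'grid_proto',
244 |     #                                        'proto_grid_size': 8, 'feature_hw': [32, 32]})
245 |     # output, align_loss, sim_maps, assign_maps = model(supp_imgs, fore_mask, back_mask,
246 |     #                                                   qry_imgs, isval=False, val_wsize=None)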
--------------------------------------------------------------------------------
/util/metric.py:
--------------------------------------------------------------------------------
1 | """
2 | Metrics for computing evaluation results
3 | Modified from vanilla PANet code by Wang et al.
4 | """
5 | 
6 | import numpy as np
7 | 
8 | class Metric(object):
9 |     """
10 |     Compute evaluation result
11 | 
12 |     Args:
13 |         max_label:
14 |             max label index in the data (0 denoting background)
15 |         n_scans:
16 |             number of test scans
17 |     """
18 |     def __init__(self, max_label=20, n_scans=None):
19 |         self.labels = list(range(max_label + 1)) # all class labels
20 |         self.n_scans = 1 if n_scans is None else n_scans
21 | 
22 |         # list of lists of arrays; each array saves the TP/FP/FN statistics of one testing sample
23 |         self.tp_lst = [[] for _ in range(self.n_scans)]
24 |         self.fp_lst = [[] for _ in range(self.n_scans)]
25 |         self.fn_lst = [[] for _ in range(self.n_scans)]
26 | 
27 |     def reset(self):
28 |         """
29 |         Reset accumulated evaluation.
30 |         """
31 |         # assert self.n_scans == 1, 'Should not reset accumulated result when we are not doing one-time batch-wise validation'
32 |         del self.tp_lst, self.fp_lst, self.fn_lst
33 |         self.tp_lst = [[] for _ in range(self.n_scans)]
34 |         self.fp_lst = [[] for _ in range(self.n_scans)]
35 |         self.fn_lst = [[] for _ in range(self.n_scans)]
36 | 
37 |     def record(self, pred, target, labels=None, n_scan=None):
38 |         """
39 |         Record the evaluation result for each sample and each class label, including:
40 |         True Positive, False Positive, False Negative
41 | 
42 |         Args:
43 |             pred:
44 |                 predicted mask array, expected shape is H x W
45 |             target:
46 |                 target mask array, expected shape is H x W
47 |             labels:
48 |                 only count specific labels, used when all possible labels are known in advance
49 |         """
50 |         assert pred.shape == target.shape
51 | 
52 |         if self.n_scans == 1:
53 |             n_scan = 0
54 | 
55 |         # arrays to save the TP/FP/FN statistics for each class (plus BG)
56 |         tp_arr = np.full(len(self.labels), np.nan)
57 |         fp_arr = np.full(len(self.labels), np.nan)
58 |         fn_arr = np.full(len(self.labels), np.nan)
59 | 
60 |         if labels is None:
61 |             labels = self.labels
62 |         else:
63 |             labels = [0,] + labels
64 | 
65 |         for j, label in enumerate(labels):
66 |             # Get the location of the pixels that are predicted as class j
67 |             idx = np.where(np.logical_and(pred == j, target != 255))
68 |             pred_idx_j = set(zip(idx[0].tolist(), idx[1].tolist()))
69 |             # Get the location of the pixels that are class j in ground truth
70 |             idx = np.where(target == j)
71 |             target_idx_j = set(zip(idx[0].tolist(), idx[1].tolist()))
72 | 
73 |             # NOTE: unlike vanilla PANet, do not guard this with `if target_idx_j:` (i.e., only when the ground truth contains this class);
74 |             # silently skipping absent classes would poison the accumulated statistics
75 |             tp_arr[label] = len(set.intersection(pred_idx_j, target_idx_j))
76 |             fp_arr[label] = len(pred_idx_j - target_idx_j)
77 |             fn_arr[label] = len(target_idx_j - pred_idx_j)
78 | 
79 |         self.tp_lst[n_scan].append(tp_arr)
80 |         self.fp_lst[n_scan].append(fp_arr)
81 |         self.fn_lst[n_scan].append(fn_arr)
82 | 
83 |     def get_mIoU(self, labels=None, n_scan=None):
84 |         """
85 |         Compute mean IoU
86 | 
87 |         Args:
88 |             labels:
89 |                 specify a subset of labels to compute mean IoU over, default is using all classes
90 |         """
91 |         if labels is None:
92 |             labels = self.labels
93 |         # Sum TP, FP, FN statistics of all samples
94 |         if n_scan is None:
95 |             tp_sum = [np.nansum(np.vstack(self.tp_lst[_scan]), axis=0).take(labels)
96 |                       for _scan in range(self.n_scans)]
97 |             fp_sum = [np.nansum(np.vstack(self.fp_lst[_scan]), axis=0).take(labels)
98 |                       for _scan in range(self.n_scans)]
99 |             fn_sum = [np.nansum(np.vstack(self.fn_lst[_scan]), axis=0).take(labels)
100 |                       for _scan in range(self.n_scans)]
101 | 
102 |             # Compute mean IoU class-wise
103 |             # Average across n_scans, then average over classes
104 |             mIoU_class = np.vstack([tp_sum[_scan] / (tp_sum[_scan] + fp_sum[_scan] + fn_sum[_scan])
105 |                                     for _scan in range(self.n_scans)])
106 |             mIoU = mIoU_class.mean(axis=1)
107 | 
108 |             return (mIoU_class.mean(axis=0), mIoU_class.std(axis=0),
109 |                     mIoU.mean(axis=0), mIoU.std(axis=0))
110 |         else:
111 |             tp_sum = np.nansum(np.vstack(self.tp_lst[n_scan]), axis=0).take(labels)
112 |             fp_sum = np.nansum(np.vstack(self.fp_lst[n_scan]), axis=0).take(labels)
113 |             fn_sum = np.nansum(np.vstack(self.fn_lst[n_scan]), axis=0).take(labels)
114 | 
115 |             # Compute mean IoU class-wise and average over classes
116 |             mIoU_class = tp_sum / (tp_sum + fp_sum + fn_sum)
117 |             mIoU = mIoU_class.mean()
118 | 
119 |             return mIoU_class, mIoU
120 | 
121 |     def 
get_mDice(self, labels=None, n_scan=None, give_raw = False):
122 |         """
123 |         Compute mean Dice score (at the 3D scan level)
124 | 
125 |         Args:
126 |             labels:
127 |                 specify a subset of labels to compute mean Dice over, default is using all classes
128 |         """
129 |         # NOTE: unverified
130 |         if labels is None:
131 |             labels = self.labels
132 |         # Sum TP, FP, FN statistics of all samples
133 |         if n_scan is None:
134 |             tp_sum = [np.nansum(np.vstack(self.tp_lst[_scan]), axis=0).take(labels)
135 |                       for _scan in range(self.n_scans)]
136 |             fp_sum = [np.nansum(np.vstack(self.fp_lst[_scan]), axis=0).take(labels)
137 |                       for _scan in range(self.n_scans)]
138 |             fn_sum = [np.nansum(np.vstack(self.fn_lst[_scan]), axis=0).take(labels)
139 |                       for _scan in range(self.n_scans)]
140 | 
141 |             # Average across n_scans, then average over classes
142 |             mDice_class = np.vstack([ 2 * tp_sum[_scan] / ( 2 * tp_sum[_scan] + fp_sum[_scan] + fn_sum[_scan])
143 |                                      for _scan in range(self.n_scans)])
144 |             mDice = mDice_class.mean(axis=1)
145 | 
146 |             if not give_raw:
147 |                 return (mDice_class.mean(axis=0), mDice_class.std(axis=0),
148 |                         mDice.mean(axis=0), mDice.std(axis=0))
149 |             else:
150 |                 return (mDice_class.mean(axis=0), mDice_class.std(axis=0),
151 |                         mDice.mean(axis=0), mDice.std(axis=0), mDice_class)
152 | 
153 |         else:
154 |             tp_sum = np.nansum(np.vstack(self.tp_lst[n_scan]), axis=0).take(labels)
155 |             fp_sum = np.nansum(np.vstack(self.fp_lst[n_scan]), axis=0).take(labels)
156 |             fn_sum = np.nansum(np.vstack(self.fn_lst[n_scan]), axis=0).take(labels)
157 | 
158 |             # Compute Dice class-wise and average over classes
159 |             mDice_class = 2 * tp_sum / ( 2 * tp_sum + fp_sum + fn_sum)
160 |             mDice = mDice_class.mean()
161 | 
162 |             return mDice_class, mDice
163 | 
164 |     def get_mPrecRecall(self, labels=None, n_scan=None, give_raw = False):
165 |         """
166 |         Compute precision and recall
167 | 
168 |         Args:
169 |             labels:
170 |                 specify a subset of labels to compute precision and recall over, default is using all classes
171 |         """
172 |         # NOTE: unverified
173 |         if labels is None:
174 |             labels = self.labels
175 |         # Sum TP, FP, FN statistics of all samples
176 |         if n_scan is None:
177 |             tp_sum = [np.nansum(np.vstack(self.tp_lst[_scan]), axis=0).take(labels)
178 |                       for _scan in range(self.n_scans)]
179 |             fp_sum = [np.nansum(np.vstack(self.fp_lst[_scan]), axis=0).take(labels)
180 |                       for _scan in range(self.n_scans)]
181 |             fn_sum = [np.nansum(np.vstack(self.fn_lst[_scan]), axis=0).take(labels)
182 |                       for _scan in range(self.n_scans)]
183 | 
184 |             # Compute precision and recall class-wise
185 |             # Average across n_scans, then average over classes
186 |             mPrec_class = np.vstack([ tp_sum[_scan] / ( tp_sum[_scan] + fp_sum[_scan] )
187 |                                      for _scan in range(self.n_scans)])
188 | 
189 |             mRec_class = np.vstack([ tp_sum[_scan] / ( tp_sum[_scan] + fn_sum[_scan] )
190 |                                     for _scan in range(self.n_scans)])
191 | 
192 |             mPrec = mPrec_class.mean(axis=1)
193 |             mRec = mRec_class.mean(axis=1)
194 |             if not give_raw:
195 |                 return (mPrec_class.mean(axis=0), mPrec_class.std(axis=0), mPrec.mean(axis=0), mPrec.std(axis=0), mRec_class.mean(axis=0), mRec_class.std(axis=0), mRec.mean(axis=0), mRec.std(axis=0))
196 |             else:
197 |                 return (mPrec_class.mean(axis=0), mPrec_class.std(axis=0), mPrec.mean(axis=0), mPrec.std(axis=0), mRec_class.mean(axis=0), mRec_class.std(axis=0), mRec.mean(axis=0), mRec.std(axis=0), mPrec_class, mRec_class)
198 | 
199 | 
200 |         else:
201 |             tp_sum = np.nansum(np.vstack(self.tp_lst[n_scan]), axis=0).take(labels)
202 |             fp_sum = np.nansum(np.vstack(self.fp_lst[n_scan]),
axis=0).take(labels)
203 |             fn_sum = np.nansum(np.vstack(self.fn_lst[n_scan]), axis=0).take(labels)
204 | 
205 |             # Compute precision and recall class-wise and average over classes
206 |             mPrec_class = tp_sum / (tp_sum + fp_sum)
207 |             mPrec = mPrec_class.mean()
208 | 
209 |             mRec_class = tp_sum / (tp_sum + fn_sum)
210 |             mRec = mRec_class.mean()
211 | 
212 |             return mPrec_class, mPrec, mRec_class, mRec
213 | 
214 |     def get_mIoU_binary(self, n_scan=None):
215 |         """
216 |         Compute mean IoU for the binary scenario
217 |         (sum all foreground classes as one class)
218 |         """
219 |         # Sum TP, FP, FN statistics of all samples
220 |         if n_scan is None:
221 |             tp_sum = [np.nansum(np.vstack(self.tp_lst[_scan]), axis=0)
222 |                       for _scan in range(self.n_scans)]
223 |             fp_sum = [np.nansum(np.vstack(self.fp_lst[_scan]), axis=0)
224 |                       for _scan in range(self.n_scans)]
225 |             fn_sum = [np.nansum(np.vstack(self.fn_lst[_scan]), axis=0)
226 |                       for _scan in range(self.n_scans)]
227 | 
228 |             # Sum over all foreground classes
229 |             tp_sum = [np.c_[tp_sum[_scan][0], np.nansum(tp_sum[_scan][1:])]
230 |                       for _scan in range(self.n_scans)]
231 |             fp_sum = [np.c_[fp_sum[_scan][0], np.nansum(fp_sum[_scan][1:])]
232 |                       for _scan in range(self.n_scans)]
233 |             fn_sum = [np.c_[fn_sum[_scan][0], np.nansum(fn_sum[_scan][1:])]
234 |                       for _scan in range(self.n_scans)]
235 | 
236 |             # Compute mean IoU class-wise and average across classes
237 |             mIoU_class = np.vstack([tp_sum[_scan] / (tp_sum[_scan] + fp_sum[_scan] + fn_sum[_scan])
238 |                                     for _scan in range(self.n_scans)])
239 |             mIoU = mIoU_class.mean(axis=1)
240 | 
241 |             return (mIoU_class.mean(axis=0), mIoU_class.std(axis=0),
242 |                     mIoU.mean(axis=0), mIoU.std(axis=0))
243 |         else:
244 |             tp_sum = np.nansum(np.vstack(self.tp_lst[n_scan]), axis=0)
245 |             fp_sum = np.nansum(np.vstack(self.fp_lst[n_scan]), axis=0)
246 |             fn_sum = np.nansum(np.vstack(self.fn_lst[n_scan]), axis=0)
247 | 
248 |             # Sum over all foreground classes
249 |             tp_sum = np.c_[tp_sum[0], np.nansum(tp_sum[1:])]
250 |             fp_sum = np.c_[fp_sum[0], np.nansum(fp_sum[1:])]
251 |             fn_sum = np.c_[fn_sum[0], np.nansum(fn_sum[1:])]
252 | 
253 |             mIoU_class = tp_sum / (tp_sum + fp_sum + fn_sum)
254 |             mIoU = mIoU_class.mean()
255 | 
256 |             return mIoU_class, mIoU
257 | 
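258 | 
259 | if __name__ == '__main__':
260 |     # Minimal smoke test on toy arrays (an illustrative sketch only, not part
261 |     # of the repository's evaluation pipeline): a perfect prediction should
262 |     # give a Dice score of 1.0 for every class.
263 |     toy = np.zeros((4, 4), dtype=np.int32)
264 |     toy[:2, :2] = 1  # two classes: background 0 and foreground 1
265 |     metric = Metric(max_label=1, n_scans=1)
266 |     metric.record(toy, toy, labels=[1], n_scan=0)
267 |     dice_class, dice_mean = metric.get_mDice(labels=[0, 1], n_scan=0)
268 |     print(dice_class, dice_mean)  # expect [1. 1.] and 1.0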
--------------------------------------------------------------------------------
/dataloaders/image_transforms.py:
--------------------------------------------------------------------------------
1 | """
2 | Image transforms functions for data augmentation
3 | Credit to Dr. Jo Schlemper
4 | """
5 | 
6 | from collections.abc import Sequence
7 | import cv2
8 | import numpy as np
9 | import scipy
10 | from scipy.ndimage import gaussian_filter
11 | from scipy.ndimage import map_coordinates
12 | from numpy.lib.stride_tricks import as_strided
13 | 
14 | ###### UTILITIES ######
15 | def random_num_generator(config, random_state=np.random):
16 |     if config[0] == 'uniform':
17 |         ret = random_state.uniform(config[1], config[2], 1)[0]
18 |     elif config[0] == 'lognormal':
19 |         ret = random_state.lognormal(config[1], config[2], 1)[0]
20 |     else:
21 |         #print(config)
22 |         raise ValueError('unsupported format: {}'.format(config[0]))
23 |     return ret
24 | 
25 | def get_translation_matrix(translation):
26 |     """ translation: [tx, ty] """
27 |     tx, ty = translation
28 |     translation_matrix = np.array([[1, 0, tx],
29 |                                    [0, 1, ty],
30 |                                    [0, 0, 1]])
31 |     return translation_matrix
32 | 
33 | 
34 | 
35 | def get_rotation_matrix(rotation, input_shape, centred=True):
36 |     theta = np.pi / 180 * np.array(rotation)
37 |     if centred:
38 |         rotation_matrix = cv2.getRotationMatrix2D((input_shape[0] / 2, input_shape[1] / 2), rotation, 1)
39 |         rotation_matrix = np.vstack([rotation_matrix, [0, 0, 1]])
40 |     else:
41 |         rotation_matrix = np.array([[np.cos(theta), -np.sin(theta), 0],
42 |                                     [np.sin(theta), np.cos(theta), 0],
43 |                                     [0, 0, 1]])
44 |     return rotation_matrix
45 | 
46 | def get_zoom_matrix(zoom, input_shape, centred=True):
47 |     zx, zy = zoom
48 |     if centred:
49 |         zoom_matrix = cv2.getRotationMatrix2D((input_shape[0] / 2, input_shape[1] / 2), 0, zoom[0])
50 |         zoom_matrix = np.vstack([zoom_matrix, [0, 0, 1]])
51 |     else:
52 |         zoom_matrix = np.array([[zx, 0, 0],
53 |                                 [0, zy, 0],
54 |                                 [0, 0, 1]])
55 |     return zoom_matrix
56 | 
57 | def get_shear_matrix(shear_angle):
58 |     theta = (np.pi * shear_angle) / 180
59 |     shear_matrix = np.array([[1, -np.sin(theta), 0],
60 |                              [0, np.cos(theta), 0],
61 |                              [0, 0, 1]])
62 |     return shear_matrix
63 | 
64 | ###### AFFINE TRANSFORM ######
65 | class RandomAffine(object):
66 |     """Apply random affine transformation on a numpy.ndarray (H x W x C)
67 |     Comment by co1818: this is still doing affine on 2d (H x W plane).
68 |     The same transform is applied to all C channels
69 | 
70 |     Parameter:
71 |     ----------
72 | 
73 |     alpha: Range [0, 4] seems good for small images
74 | 
75 |     order: interpolation method (c.f. opencv)
76 |     """
77 | 
78 |     def __init__(self,
79 |                  rotation_range=None,
80 |                  translation_range=None,
81 |                  shear_range=None,
82 |                  zoom_range=None,
83 |                  zoom_keep_aspect=False,
84 |                  interp='bilinear',
85 |                  order=3):
86 |         """
87 |         Perform an affine transform.
88 | 
89 |         Arguments
90 |         ---------
91 |         rotation_range : one integer or float
92 |             image will be rotated randomly between (-degrees, degrees)
93 | 
94 |         translation_range : (x_shift, y_shift)
95 |             shifts in pixels
96 | 
97 |         *NOT TESTED* shear_range : float
98 |             image will be sheared randomly between (-degrees, degrees)
99 | 
100 |         zoom_range : (zoom_min, zoom_max)
101 |             list/tuple with two floats between [0, infinity).
102 |             first float should be less than the second
103 |             lower and upper bounds on percent zoom.
104 |             Anything less than 1.0 will zoom in on the image,
105 |             anything greater than 1.0 will zoom out on the image.
106 |             e.g. 
(0.7, 1.0) will only zoom in, 107 | (1.0, 1.4) will only zoom out, 108 | (0.7, 1.4) will randomly zoom in or out 109 | """ 110 | 111 | self.rotation_range = rotation_range 112 | self.translation_range = translation_range 113 | self.shear_range = shear_range 114 | self.zoom_range = zoom_range 115 | self.zoom_keep_aspect = zoom_keep_aspect 116 | self.interp = interp 117 | self.order = order 118 | 119 | def build_M(self, input_shape): 120 | tfx = [] 121 | final_tfx = np.eye(3) 122 | if self.rotation_range: 123 | rot = np.random.uniform(-self.rotation_range, self.rotation_range) 124 | tfx.append(get_rotation_matrix(rot, input_shape)) 125 | if self.translation_range: 126 | tx = np.random.uniform(-self.translation_range[0], self.translation_range[0]) 127 | ty = np.random.uniform(-self.translation_range[1], self.translation_range[1]) 128 | tfx.append(get_translation_matrix((tx,ty))) 129 | if self.shear_range: 130 | rot = np.random.uniform(-self.shear_range, self.shear_range) 131 | tfx.append(get_shear_matrix(rot)) 132 | if self.zoom_range: 133 | sx = np.random.uniform(self.zoom_range[0], self.zoom_range[1]) 134 | if self.zoom_keep_aspect: 135 | sy = sx 136 | else: 137 | sy = np.random.uniform(self.zoom_range[0], self.zoom_range[1]) 138 | 139 | tfx.append(get_zoom_matrix((sx, sy), input_shape)) 140 | 141 | for tfx_mat in tfx: 142 | final_tfx = np.dot(tfx_mat, final_tfx) 143 | 144 | return final_tfx.astype(np.float32) 145 | 146 | def __call__(self, image): 147 | # build matrix 148 | input_shape = image.shape[:2] 149 | M = self.build_M(input_shape) 150 | 151 | res = np.zeros_like(image) 152 | #if isinstance(self.interp, Sequence): 153 | if type(self.order) is list or type(self.order) is tuple: 154 | for i, intp in enumerate(self.order): 155 | res[..., i] = affine_transform_via_M(image[..., i], M[:2], interp=intp) 156 | else: 157 | # squeeze if needed 158 | orig_shape = image.shape 159 | image_s = np.squeeze(image) 160 | res = affine_transform_via_M(image_s, M[:2], interp=self.order) 161 | res = res.reshape(orig_shape) 162 | 163 | #res = affine_transform_via_M(image, M[:2], interp=self.order) 164 | 165 | return res 166 | 167 | def affine_transform_via_M(image, M, borderMode=cv2.BORDER_CONSTANT, interp=cv2.INTER_NEAREST): 168 | imshape = image.shape 169 | shape_size = imshape[:2] 170 | 171 | # Random affine 172 | warped = cv2.warpAffine(image.reshape(shape_size + (-1,)), M, shape_size[::-1], 173 | flags=interp, borderMode=borderMode) 174 | 175 | #print(imshape, warped.shape) 176 | 177 | warped = warped[..., np.newaxis].reshape(imshape) 178 | 179 | return warped 180 | 181 | ###### ELASTIC TRANSFORM ###### 182 | def elastic_transform(image, alpha=1000, sigma=30, spline_order=1, mode='nearest', random_state=np.random): 183 | """Elastic deformation of image as described in [Simard2003]_. 184 | .. [Simard2003] Simard, Steinkraus and Platt, "Best Practices for 185 | Convolutional Neural Networks applied to Visual Document Analysis", in 186 | Proc. of the International Conference on Document Analysis and 187 | Recognition, 2003. 
188 | """ 189 | assert image.ndim == 3 190 | shape = image.shape[:2] 191 | 192 | dx = gaussian_filter((random_state.rand(*shape) * 2 - 1), 193 | sigma, mode="constant", cval=0) * alpha 194 | dy = gaussian_filter((random_state.rand(*shape) * 2 - 1), 195 | sigma, mode="constant", cval=0) * alpha 196 | 197 | x, y = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), indexing='ij') 198 | indices = [np.reshape(x + dx, (-1, 1)), np.reshape(y + dy, (-1, 1))] 199 | result = np.empty_like(image) 200 | for i in range(image.shape[2]): 201 | result[:, :, i] = map_coordinates( 202 | image[:, :, i], indices, order=spline_order, mode=mode).reshape(shape) 203 | return result 204 | 205 | 206 | def elastic_transform_nd(image, alpha, sigma, random_state=None, order=1, lazy=False): 207 | """Expects data to be (nx, ny, n1 ,..., nm) 208 | params: 209 | ------ 210 | 211 | alpha: 212 | the scaling parameter. 213 | E.g.: alpha=2 => distorts images up to 2x scaling 214 | 215 | sigma: 216 | standard deviation of gaussian filter. 217 | E.g. 218 | low (sig~=1e-3) => no smoothing, pixelated. 219 | high (1/5 * imsize) => smooth, more like affine. 220 | very high (1/2*im_size) => translation 221 | """ 222 | 223 | if random_state is None: 224 | random_state = np.random.RandomState(None) 225 | 226 | shape = image.shape 227 | imsize = shape[:2] 228 | dim = shape[2:] 229 | 230 | # Random affine 231 | blur_size = int(4*sigma) | 1 232 | dx = cv2.GaussianBlur(random_state.rand(*imsize)*2-1, 233 | ksize=(blur_size, blur_size), sigmaX=sigma) * alpha 234 | dy = cv2.GaussianBlur(random_state.rand(*imsize)*2-1, 235 | ksize=(blur_size, blur_size), sigmaX=sigma) * alpha 236 | 237 | # use as_strided to copy things over across n1...nn channels 238 | dx = as_strided(dx.astype(np.float32), 239 | strides=(0,) * len(dim) + (4*shape[1], 4), 240 | shape=dim+(shape[0], shape[1])) 241 | dx = np.transpose(dx, axes=(-2, -1) + tuple(range(len(dim)))) 242 | 243 | dy = as_strided(dy.astype(np.float32), 244 | strides=(0,) * len(dim) + (4*shape[1], 4), 245 | shape=dim+(shape[0], shape[1])) 246 | dy = np.transpose(dy, axes=(-2, -1) + tuple(range(len(dim)))) 247 | 248 | coord = np.meshgrid(*[np.arange(shape_i) for shape_i in (shape[1], shape[0]) + dim]) 249 | indices = [np.reshape(e+de, (-1, 1)) for e, de in zip([coord[1], coord[0]] + coord[2:], 250 | [dy, dx] + [0] * len(dim))] 251 | 252 | if lazy: 253 | return indices 254 | 255 | return map_coordinates(image, indices, order=order, mode='reflect').reshape(shape) 256 | 257 | class ElasticTransform(object): 258 | """Apply elastic transformation on a numpy.ndarray (H x W x C) 259 | """ 260 | 261 | def __init__(self, alpha, sigma, order=1): 262 | self.alpha = alpha 263 | self.sigma = sigma 264 | self.order = order 265 | 266 | def __call__(self, image): 267 | if isinstance(self.alpha, Sequence): 268 | alpha = random_num_generator(self.alpha) 269 | else: 270 | alpha = self.alpha 271 | if isinstance(self.sigma, Sequence): 272 | sigma = random_num_generator(self.sigma) 273 | else: 274 | sigma = self.sigma 275 | return elastic_transform_nd(image, alpha=alpha, sigma=sigma, order=self.order) 276 | 277 | class RandomFlip3D(object): 278 | 279 | def __init__(self, h=True, v=True, t=True, p=0.5): 280 | """ 281 | Randomly flip an image horizontally and/or vertically with 282 | some probability. 
283 | 
284 |         Arguments
285 |         ---------
286 |         h : boolean
287 |             whether to horizontally flip w/ probability p
288 |         v : boolean
289 |             whether to vertically flip w/ probability p
290 |         t : boolean
291 |             whether to flip along the depth (last) axis w/ probability p
292 |         p : float between [0,1]
293 |             probability with which to apply allowed flipping operations
294 |         """
295 |         self.horizontal = h
296 |         self.vertical = v
297 |         self.depth = t
298 |         self.p = p
299 | 
300 |     def __call__(self, x, y=None):
301 |         # horizontal flip with p = self.p
302 |         if self.horizontal:
303 |             if np.random.random() < self.p:
304 |                 x = x[::-1, ...]
305 | 
306 |         # vertical flip with p = self.p
307 |         if self.vertical:
308 |             if np.random.random() < self.p:
309 |                 x = x[:, ::-1, ...]
310 | 
311 |         if self.depth:
312 |             if np.random.random() < self.p:
313 |                 x = x[..., ::-1]
314 | 
315 |         return x
316 | 
317 | 
318 | 
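319 | if __name__ == '__main__':
320 |     # A small usage sketch on a toy array (illustrative parameter values,
321 |     # not the training configuration used elsewhere in this repository).
322 |     dummy = np.random.rand(256, 256, 2).astype(np.float32)  # image + label channel
323 |     affine = RandomAffine(rotation_range=15.0,
324 |                           translation_range=(10, 10),
325 |                           zoom_range=(0.9, 1.1),
326 |                           order=[1, 0])  # bilinear for the image, nearest for the label
327 |     elastic = ElasticTransform(alpha=1000, sigma=30)
328 |     out = elastic(affine(dummy))
329 |     print(out.shape)  # (256, 256, 2)
--------------------------------------------------------------------------------
/data/pseudolabel_gen.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "markdown",
5 |    "metadata": {},
6 |    "source": [
7 |     "## Generate superpixel-based pseudolabels\n",
8 |     "\n",
9 |     "\n",
10 |     "### Overview\n",
11 |     "\n",
12 |     "This is the third step for data preparation\n",
13 |     "\n",
14 |     "Input: normalized images\n",
15 |     "\n",
16 |     "Output: pseudolabel candidates for all the images"
17 |    ]
18 |   },
19 |   {
20 |    "cell_type": "code",
21 |    "execution_count": 1,
22 |    "metadata": {},
23 |    "outputs": [
24 |     {
25 |      "name": "stdout",
26 |      "output_type": "stream",
27 |      "text": [
28 |       "Once deleted, variables cannot be recovered. Proceed (y/[n])? y\n"
29 |      ]
30 |     }
31 |    ],
32 |    "source": [
33 |     "%reset\n",
34 |     "%load_ext autoreload\n",
35 |     "%autoreload 2\n",
36 |     "import matplotlib.pyplot as plt\n",
37 |     "import copy\n",
38 |     "import skimage\n",
39 |     "\n",
40 |     "from skimage.segmentation import slic\n",
41 |     "from skimage.segmentation import mark_boundaries\n",
42 |     "from skimage.util import img_as_float\n",
43 |     "from skimage.measure import label \n",
44 |     "import scipy.ndimage.morphology as snm\n",
45 |     "from skimage import io\n",
46 |     "import argparse\n",
47 |     "import numpy as np\n",
48 |     "import glob\n",
49 |     "\n",
50 |     "import SimpleITK as sitk\n",
51 |     "import os\n",
52 |     "\n",
53 |     "to01 = lambda x: (x - x.min()) / (x.max() - x.min())\n",
54 |     "\n"
55 |    ]
56 |   },
57 |   {
58 |    "cell_type": "markdown",
59 |    "metadata": {},
60 |    "source": [
61 |     "**Summary**\n",
62 |     "\n",
63 |     "a. Generate a mask of the patient to avoid pseudolabels of empty regions in the background\n",
64 |     "\n",
65 |     "b. 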
Generate superpixels as pseudolabels\n", 66 | "\n", 67 | "**Configurations of pseudlabels**\n", 68 | "\n", 69 | "```python\n", 70 | "# default setting of minimum superpixel sizes\n", 71 | "segs = seg_func(img[ii, ...], min_size = 400, sigma = 1)\n", 72 | "# you can also try other configs\n", 73 | "segs = seg_func(img[ii, ...], min_size = 100, sigma = 0.8)\n", 74 | "```\n" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 2, 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "DATASET_CONFIG = {'SABS':{\n", 84 | " 'img_bname': f'./SABS/sabs_CT_normalized/image_*.nii.gz',\n", 85 | " 'out_dir': './SABS/sabs_CT_normalized',\n", 86 | " 'fg_thresh': 1e-4\n", 87 | " },\n", 88 | " 'CHAOST2':{\n", 89 | " 'img_bname': f'../CHAOST2/chaos_MR_T2_normalized/image_*.nii.gz',\n", 90 | " 'out_dir': './CHAOST2/chaos_MR_T2_normalized',\n", 91 | " 'fg_thresh': 1e-4 + 50\n", 92 | " }\n", 93 | " }\n", 94 | " \n", 95 | "\n", 96 | "DOMAIN = 'SABS'\n", 97 | "img_bname = DATASET_CONFIG[DOMAIN]['img_bname']\n", 98 | "imgs = glob.glob(img_bname)\n", 99 | "out_dir = DATASET_CONFIG[DOMAIN]['out_dir']\n" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": 3, 105 | "metadata": {}, 106 | "outputs": [ 107 | { 108 | "data": { 109 | "text/plain": [ 110 | "['./SABS/sabs_CT_normalized/image_25.nii.gz',\n", 111 | " './SABS/sabs_CT_normalized/image_2.nii.gz',\n", 112 | " './SABS/sabs_CT_normalized/image_28.nii.gz',\n", 113 | " './SABS/sabs_CT_normalized/image_12.nii.gz',\n", 114 | " './SABS/sabs_CT_normalized/image_0.nii.gz',\n", 115 | " './SABS/sabs_CT_normalized/image_27.nii.gz',\n", 116 | " './SABS/sabs_CT_normalized/image_10.nii.gz',\n", 117 | " './SABS/sabs_CT_normalized/image_6.nii.gz',\n", 118 | " './SABS/sabs_CT_normalized/image_21.nii.gz',\n", 119 | " './SABS/sabs_CT_normalized/image_16.nii.gz',\n", 120 | " './SABS/sabs_CT_normalized/image_9.nii.gz',\n", 121 | " './SABS/sabs_CT_normalized/image_23.nii.gz',\n", 122 | " './SABS/sabs_CT_normalized/image_4.nii.gz',\n", 123 | " './SABS/sabs_CT_normalized/image_14.nii.gz',\n", 124 | " './SABS/sabs_CT_normalized/image_19.nii.gz',\n", 125 | " './SABS/sabs_CT_normalized/image_17.nii.gz',\n", 126 | " './SABS/sabs_CT_normalized/image_20.nii.gz',\n", 127 | " './SABS/sabs_CT_normalized/image_7.nii.gz',\n", 128 | " './SABS/sabs_CT_normalized/image_18.nii.gz',\n", 129 | " './SABS/sabs_CT_normalized/image_15.nii.gz',\n", 130 | " './SABS/sabs_CT_normalized/image_5.nii.gz',\n", 131 | " './SABS/sabs_CT_normalized/image_22.nii.gz',\n", 132 | " './SABS/sabs_CT_normalized/image_8.nii.gz',\n", 133 | " './SABS/sabs_CT_normalized/image_13.nii.gz',\n", 134 | " './SABS/sabs_CT_normalized/image_29.nii.gz',\n", 135 | " './SABS/sabs_CT_normalized/image_3.nii.gz',\n", 136 | " './SABS/sabs_CT_normalized/image_24.nii.gz',\n", 137 | " './SABS/sabs_CT_normalized/image_11.nii.gz',\n", 138 | " './SABS/sabs_CT_normalized/image_26.nii.gz',\n", 139 | " './SABS/sabs_CT_normalized/image_1.nii.gz']" 140 | ] 141 | }, 142 | "execution_count": 3, 143 | "metadata": {}, 144 | "output_type": "execute_result" 145 | } 146 | ], 147 | "source": [ 148 | "imgs" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": 4, 154 | "metadata": {}, 155 | "outputs": [], 156 | "source": [ 157 | "imgs = sorted(imgs, key = lambda x: int(x.split('_')[-1].split('.nii.gz')[0]) )" 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": 5, 163 | "metadata": {}, 164 | "outputs": [ 165 | { 166 | "data": { 167 | "text/plain": [ 168 | 
"['./SABS/sabs_CT_normalized/image_0.nii.gz',\n", 169 | " './SABS/sabs_CT_normalized/image_1.nii.gz',\n", 170 | " './SABS/sabs_CT_normalized/image_2.nii.gz',\n", 171 | " './SABS/sabs_CT_normalized/image_3.nii.gz',\n", 172 | " './SABS/sabs_CT_normalized/image_4.nii.gz',\n", 173 | " './SABS/sabs_CT_normalized/image_5.nii.gz',\n", 174 | " './SABS/sabs_CT_normalized/image_6.nii.gz',\n", 175 | " './SABS/sabs_CT_normalized/image_7.nii.gz',\n", 176 | " './SABS/sabs_CT_normalized/image_8.nii.gz',\n", 177 | " './SABS/sabs_CT_normalized/image_9.nii.gz',\n", 178 | " './SABS/sabs_CT_normalized/image_10.nii.gz',\n", 179 | " './SABS/sabs_CT_normalized/image_11.nii.gz',\n", 180 | " './SABS/sabs_CT_normalized/image_12.nii.gz',\n", 181 | " './SABS/sabs_CT_normalized/image_13.nii.gz',\n", 182 | " './SABS/sabs_CT_normalized/image_14.nii.gz',\n", 183 | " './SABS/sabs_CT_normalized/image_15.nii.gz',\n", 184 | " './SABS/sabs_CT_normalized/image_16.nii.gz',\n", 185 | " './SABS/sabs_CT_normalized/image_17.nii.gz',\n", 186 | " './SABS/sabs_CT_normalized/image_18.nii.gz',\n", 187 | " './SABS/sabs_CT_normalized/image_19.nii.gz',\n", 188 | " './SABS/sabs_CT_normalized/image_20.nii.gz',\n", 189 | " './SABS/sabs_CT_normalized/image_21.nii.gz',\n", 190 | " './SABS/sabs_CT_normalized/image_22.nii.gz',\n", 191 | " './SABS/sabs_CT_normalized/image_23.nii.gz',\n", 192 | " './SABS/sabs_CT_normalized/image_24.nii.gz',\n", 193 | " './SABS/sabs_CT_normalized/image_25.nii.gz',\n", 194 | " './SABS/sabs_CT_normalized/image_26.nii.gz',\n", 195 | " './SABS/sabs_CT_normalized/image_27.nii.gz',\n", 196 | " './SABS/sabs_CT_normalized/image_28.nii.gz',\n", 197 | " './SABS/sabs_CT_normalized/image_29.nii.gz']" 198 | ] 199 | }, 200 | "execution_count": 5, 201 | "metadata": {}, 202 | "output_type": "execute_result" 203 | } 204 | ], 205 | "source": [ 206 | "imgs" 207 | ] 208 | }, 209 | { 210 | "cell_type": "code", 211 | "execution_count": 23, 212 | "metadata": {}, 213 | "outputs": [], 214 | "source": [ 215 | "MODE = 'MIDDLE' # minimum size of pesudolabels. 'MIDDLE' is the default setting\n", 216 | "\n", 217 | "# wrapper for process 3d image in 2d\n", 218 | "def superpix_vol(img, method = 'fezlen', **kwargs):\n", 219 | " \"\"\"\n", 220 | " loop through the entire volume\n", 221 | " assuming image with axis z, x, y\n", 222 | " \"\"\"\n", 223 | " if method =='fezlen':\n", 224 | " seg_func = skimage.segmentation.felzenszwalb\n", 225 | " else:\n", 226 | " raise NotImplementedError\n", 227 | " \n", 228 | " out_vol = np.zeros(img.shape)\n", 229 | " for ii in range(img.shape[0]):\n", 230 | " if MODE == 'MIDDLE':\n", 231 | " segs = seg_func(img[ii, ...], min_size = 400, sigma = 1)\n", 232 | " else:\n", 233 | " raise NotImplementedError\n", 234 | " out_vol[ii, ...] 
= segs\n",
235 |     "    \n",
236 |     "    return out_vol\n",
237 |     "\n",
238 |     "# thresholding the intensity values to get a binary mask of the patient\n",
239 |     "def fg_mask2d(img_2d, thresh): # change this as needed\n",
240 |     "    mask_map = np.float32(img_2d > thresh)\n",
241 |     "    \n",
242 |     "    def getLargestCC(segmentation): # largest connected components\n",
243 |     "        labels = label(segmentation)\n",
244 |     "        assert( labels.max() != 0 ) # assume at least 1 CC\n",
245 |     "        largestCC = labels == np.argmax(np.bincount(labels.flat)[1:])+1\n",
246 |     "        return largestCC\n",
247 |     "    if mask_map.max() < 0.999:\n",
248 |     "        return mask_map\n",
249 |     "    else:\n",
250 |     "        post_mask = getLargestCC(mask_map)\n",
251 |     "        fill_mask = snm.binary_fill_holes(post_mask)\n",
252 |     "    return fill_mask\n",
253 |     "\n",
254 |     "# remove superpixels within the empty regions\n",
255 |     "def superpix_masking(raw_seg2d, mask2d):\n",
256 |     "    raw_seg2d = np.int32(raw_seg2d)\n",
257 |     "    lbvs = np.unique(raw_seg2d)\n",
258 |     "    max_lb = lbvs.max()\n",
259 |     "    raw_seg2d[raw_seg2d == 0] = max_lb + 1\n",
260 |     "    lbvs = list(lbvs)\n",
261 |     "    lbvs.append( max_lb + 1 )\n",
262 |     "    raw_seg2d = raw_seg2d * mask2d\n",
263 |     "    lb_new = 1\n",
264 |     "    out_seg2d = np.zeros(raw_seg2d.shape)\n",
265 |     "    for lbv in lbvs:\n",
266 |     "        if lbv == 0:\n",
267 |     "            continue\n",
268 |     "        else:\n",
269 |     "            out_seg2d[raw_seg2d == lbv] = lb_new\n",
270 |     "            lb_new += 1\n",
271 |     "    \n",
272 |     "    return out_seg2d\n",
273 |     "    \n",
274 |     "def superpix_wrapper(img, verbose = False, fg_thresh = 1e-4):\n",
275 |     "    raw_seg = superpix_vol(img)\n",
276 |     "    fg_mask_vol = np.zeros(raw_seg.shape)\n",
277 |     "    processed_seg_vol = np.zeros(raw_seg.shape)\n",
278 |     "    for ii in range(raw_seg.shape[0]):\n",
279 |     "        if verbose:\n",
280 |     "            print(\"doing {} slice\".format(ii))\n",
281 |     "        _fgm = fg_mask2d(img[ii, ...], fg_thresh )\n",
282 |     "        _out_seg = superpix_masking(raw_seg[ii, ...], _fgm)\n",
283 |     "        fg_mask_vol[ii] = _fgm\n",
284 |     "        processed_seg_vol[ii] = _out_seg\n",
285 |     "    return fg_mask_vol, processed_seg_vol\n",
286 |     "    \n",
287 |     "# copy spacing and orientation info between sitk objects\n",
288 |     "def copy_info(src, dst):\n",
289 |     "    dst.SetSpacing(src.GetSpacing())\n",
290 |     "    dst.SetOrigin(src.GetOrigin())\n",
291 |     "    dst.SetDirection(src.GetDirection())\n",
292 |     "    # dst.CopyInformation(src)\n",
293 |     "    return dst\n",
294 |     "\n",
295 |     "\n",
296 |     "def strip_(img, lb):\n",
297 |     "    img = np.int32(img)\n",
298 |     "    if isinstance(lb, float):\n",
299 |     "        lb = int(lb)\n",
300 |     "        return np.float32(img == lb) * float(lb)\n",
301 |     "    elif isinstance(lb, list):\n",
302 |     "        out = np.zeros(img.shape)\n",
303 |     "        for _lb in lb:\n",
304 |     "            out += np.float32(img == int(_lb)) * float(_lb)\n",
305 |     "        \n",
306 |     "        return out\n",
307 |     "    else:\n",
308 |     "        raise TypeError('lb must be a float or a list')"
309 |    ]
310 |   },
311 |   {
312 |    "cell_type": "code",
313 |    "execution_count": 24,
314 |    "metadata": {},
315 |    "outputs": [
316 |     {
317 |      "name": "stdout",
318 |      "output_type": "stream",
319 |      "text": [
320 |       "image with id 0 has finished\n",
321 |       "image with id 1 has finished\n",
322 |       "image with id 2 has finished\n",
323 |       "image with id 3 has finished\n",
324 |       "image with id 4 has finished\n",
325 |       "image with id 5 has finished\n",
326 |       "image with id 6 has finished\n",
327 |       "image with id 7 has finished\n",
328 |       "image with id 8 has finished\n",
329 |       "image with id 9 has finished\n",
330 |       "image with id 10 has finished\n",
331 |       "image with id 11 has finished\n",
332 |       "image with id 12 has finished\n",
333 | 
"image with id 13 has finished\n", 334 | "image with id 14 has finished\n", 335 | "image with id 15 has finished\n", 336 | "image with id 16 has finished\n", 337 | "image with id 17 has finished\n", 338 | "image with id 18 has finished\n", 339 | "image with id 19 has finished\n", 340 | "image with id 20 has finished\n", 341 | "image with id 21 has finished\n", 342 | "image with id 22 has finished\n", 343 | "image with id 23 has finished\n", 344 | "image with id 24 has finished\n", 345 | "image with id 25 has finished\n", 346 | "image with id 26 has finished\n", 347 | "image with id 27 has finished\n", 348 | "image with id 28 has finished\n", 349 | "image with id 29 has finished\n" 350 | ] 351 | } 352 | ], 353 | "source": [ 354 | "# Generate pseudolabels for every image and save them\n", 355 | "for img_fid in imgs:\n", 356 | "# img_fid = imgs[0]\n", 357 | "\n", 358 | " idx = os.path.basename(img_fid).split(\"_\")[-1].split(\".nii.gz\")[0]\n", 359 | " im_obj = sitk.ReadImage(img_fid)\n", 360 | "\n", 361 | " out_fg, out_seg = superpix_wrapper(sitk.GetArrayFromImage(im_obj), fg_thresh = DATASET_CONFIG[DOMAIN]['fg_thresh'] )\n", 362 | " out_fg_o = sitk.GetImageFromArray(out_fg ) \n", 363 | " out_seg_o = sitk.GetImageFromArray(out_seg )\n", 364 | "\n", 365 | " out_fg_o = copy_info(im_obj, out_fg_o)\n", 366 | " out_seg_o = copy_info(im_obj, out_seg_o)\n", 367 | " seg_fid = os.path.join(out_dir, f'superpix-{MODE}_{idx}.nii.gz')\n", 368 | " msk_fid = os.path.join(out_dir, f'fgmask_{idx}.nii.gz')\n", 369 | " sitk.WriteImage(out_fg_o, msk_fid)\n", 370 | " sitk.WriteImage(out_seg_o, seg_fid)\n", 371 | " print(f'image with id {idx} has finished')\n" 372 | ] 373 | }, 374 | { 375 | "cell_type": "code", 376 | "execution_count": null, 377 | "metadata": {}, 378 | "outputs": [], 379 | "source": [] 380 | } 381 | ], 382 | "metadata": { 383 | "kernelspec": { 384 | "display_name": "Python 3", 385 | "language": "python", 386 | "name": "python3" 387 | }, 388 | "language_info": { 389 | "codemirror_mode": { 390 | "name": "ipython", 391 | "version": 3 392 | }, 393 | "file_extension": ".py", 394 | "mimetype": "text/x-python", 395 | "name": "python", 396 | "nbconvert_exporter": "python", 397 | "pygments_lexer": "ipython3", 398 | "version": "3.6.0" 399 | } 400 | }, 401 | "nbformat": 4, 402 | "nbformat_minor": 2 403 | } 404 | -------------------------------------------------------------------------------- /dataloaders/GenericSuperDatasetv2.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dataset for training with pseudolabels 3 | TODO: 4 | 1. Merge with manual annotated dataset 5 | 2. 
superpixel_scale -> superpix_config, feed it like a dict
6 | """
7 | import glob
8 | import numpy as np
9 | import dataloaders.augutils as myaug
10 | import torch
11 | import random
12 | import os
13 | import copy
14 | import platform
15 | import json
16 | import re
17 | from dataloaders.common import BaseDataset, Subset
18 | from dataloaders.dataset_utils import *
19 | from pdb import set_trace
20 | from util.utils import CircularList
21 | 
22 | class SuperpixelDataset(BaseDataset):
23 |     def __init__(self, which_dataset, base_dir, idx_split, mode, transforms, scan_per_load, num_rep = 2, min_fg = '', nsup = 1, fix_length = None, tile_z_dim = 3, exclude_list = [], superpix_scale = 'SMALL', **kwargs):
24 |         """
25 |         Pseudolabel dataset
26 |         Args:
27 |             which_dataset: name of the dataset to use
28 |             base_dir: directory of the dataset
29 |             idx_split: index of the data split, as we will do cross-validation
30 |             mode: 'train', 'val'.
31 |             nsup: number of scans used as support. Currently unused for the superpixel dataset
32 |             transforms: data transform (augmentation) function
33 |             scan_per_load: load only a portion of the entire dataset, in case the dataset is too large to fit into memory. Set to -1 to load the entire dataset at one time
34 |             num_rep: number of augmented copies generated for the same pseudolabel
35 |             tile_z_dim: number of identical slices to tile along the channel dimension, for fitting 2D single-channel medical images into off-the-shelf networks designed for RGB natural images
36 |             fix_length: fix the length of the dataset
37 |             exclude_list: labels to be excluded
38 |             superpix_scale: config of superpixels
39 |         """
40 |         super(SuperpixelDataset, self).__init__(base_dir)
41 | 
42 |         self.img_modality = DATASET_INFO[which_dataset]['MODALITY']
43 |         self.sep = DATASET_INFO[which_dataset]['_SEP']
44 |         self.pseu_label_name = DATASET_INFO[which_dataset]['PSEU_LABEL_NAME']
45 |         self.real_label_name = DATASET_INFO[which_dataset]['REAL_LABEL_NAME']
46 | 
47 |         self.transforms = transforms
48 |         self.is_train = (mode == 'train')
49 |         assert mode == 'train'
50 |         self.fix_length = fix_length
51 |         self.nclass = len(self.pseu_label_name)
52 |         self.num_rep = num_rep
53 |         self.tile_z_dim = tile_z_dim
54 | 
55 |         # find scans in the data folder
56 |         self.nsup = nsup
57 |         self.base_dir = base_dir
58 |         self.img_pids = [ re.findall(r'\d+', fid)[-1] for fid in glob.glob(self.base_dir + "/image_*.nii.gz") ]
59 |         self.img_pids = CircularList(sorted( self.img_pids, key = lambda x: int(x)))
60 | 
61 |         # experiment configs
62 |         self.exclude_lbs = exclude_list
63 |         self.superpix_scale = superpix_scale
64 |         if len(exclude_list) > 0:
65 |             print(f'###### Dataset: the following classes have been excluded {exclude_list} ######')
66 |         self.idx_split = idx_split
67 |         self.scan_ids = self.get_scanids(mode, idx_split) # patient ids of the entire fold
68 |         self.min_fg = min_fg if isinstance(min_fg, str) else str(min_fg)
69 |         self.scan_per_load = scan_per_load
70 | 
71 |         self.info_by_scan = None
72 |         self.img_lb_fids = self.organize_sample_fids() # information of scans of the entire fold
73 |         self.norm_func = get_normalize_op(self.img_modality, [ fid_pair['img_fid'] for _, fid_pair in self.img_lb_fids.items()])
74 | 
75 |         if self.is_train:
76 |             if scan_per_load > 0: # if the dataset is too large, only reload a subset in each sub-epoch
77 |                 self.pid_curr_load = np.random.choice( self.scan_ids, replace = False, size = self.scan_per_load)
78 |             else: # load the entire set without a buffer
79 |                 self.pid_curr_load = self.scan_ids
80 |         elif mode == 
'val':
81 |             self.pid_curr_load = self.scan_ids
82 |         else:
83 |             raise ValueError(f'Unknown dataset mode {mode}')
84 |         self.actual_dataset = self.read_dataset()
85 |         self.size = len(self.actual_dataset)
86 |         self.overall_slice_by_cls = self.read_classfiles()
87 | 
88 |         print("###### Initial scans loaded: ######")
89 |         print(self.pid_curr_load)
90 | 
91 |     def get_scanids(self, mode, idx_split):
92 |         """
93 |         Load scans by train-test split
94 |         Leave one additional scan as the support scan. If this is the last fold, take scan 0 as the additional one
95 |         Args:
96 |             idx_split: index for splitting cross-validation folds
97 |         """
98 |         val_ids = copy.deepcopy(self.img_pids[self.sep[idx_split]: self.sep[idx_split + 1] + self.nsup])
99 |         if mode == 'train':
100 |             return [ ii for ii in self.img_pids if ii not in val_ids ]
101 |         elif mode == 'val':
102 |             return val_ids
103 | 
104 |     def reload_buffer(self):
105 |         """
106 |         Reload only a portion of the entire dataset, if the dataset is too large
107 |         1. delete original buffer
108 |         2. update self.ids_this_batch
109 |         3. update other internal variables like __len__
110 |         """
111 |         if self.scan_per_load <= 0:
112 |             print("We are not using the reload buffer, doing nothing")
113 |             return -1
114 | 
115 |         del self.actual_dataset
116 |         del self.info_by_scan
117 | 
118 |         self.pid_curr_load = np.random.choice( self.scan_ids, size = self.scan_per_load, replace = False )
119 |         self.actual_dataset = self.read_dataset()
120 |         self.size = len(self.actual_dataset)
121 |         self.update_subclass_lookup()
122 |         print(f'Loader buffer reloaded with a new size of {self.size} slices')
123 | 
124 |     def organize_sample_fids(self):
125 |         out_list = {}
126 |         for curr_id in self.scan_ids:
127 |             curr_dict = {}
128 | 
129 |             _img_fid = os.path.join(self.base_dir, f'image_{curr_id}.nii.gz')
130 |             _lb_fid = os.path.join(self.base_dir, f'superpix-{self.superpix_scale}_{curr_id}.nii.gz')
131 | 
132 |             curr_dict["img_fid"] = _img_fid
133 |             curr_dict["lbs_fid"] = _lb_fid
134 |             out_list[str(curr_id)] = curr_dict
135 |         return out_list
136 | 
137 |     def read_dataset(self):
138 |         """
139 |         Read images into memory and store them in 2D
140 |         Build tables for the position of an individual 2D slice in the entire dataset
141 |         """
142 |         out_list = []
143 |         self.scan_z_idx = {}
144 |         self.info_by_scan = {} # meta data of each scan
145 |         glb_idx = 0 # global index of a certain slice in a certain scan in the entire dataset
146 | 
147 |         for scan_id, itm in self.img_lb_fids.items():
148 |             if scan_id not in self.pid_curr_load:
149 |                 continue
150 | 
151 |             img, _info = read_nii_bysitk(itm["img_fid"], peel_info = True) # get the meta information out
152 |             img = img.transpose(1,2,0)
153 |             self.info_by_scan[scan_id] = _info
154 | 
155 |             img = np.float32(img)
156 |             img = self.norm_func(img)
157 | 
158 |             self.scan_z_idx[scan_id] = [-1 for _ in range(img.shape[-1])]
159 | 
160 |             lb = read_nii_bysitk(itm["lbs_fid"])
161 |             lb = lb.transpose(1,2,0)
162 |             lb = np.int32(lb)
163 | 
164 |             img = img[:256, :256, :]
165 |             lb = lb[:256, :256, :]
166 | 
167 |             # format of slices: [axial_H x axial_W x Z]
168 | 
169 |             assert img.shape[-1] == lb.shape[-1]
170 |             base_idx = img.shape[-1] // 2 # index of the middle slice
171 | 
172 |             # re-organize 3D images into 2D slices and record essential information for each slice
173 |             out_list.append( {"img": img[..., 0: 1],
174 |                               "lb": lb[..., 0: 0 + 1],
175 |                               "sup_max_cls": lb[..., 0: 0 + 1].max(),
176 |                               "is_start": True,
177 |                               "is_end": False,
178 |                               "nframe": img.shape[-1],
179 |                               "scan_id": scan_id,
180 |                               "z_id": 0})
181 | 
182 | 
self.scan_z_idx[scan_id][0] = glb_idx 183 | glb_idx += 1 184 | 185 | for ii in range(1, img.shape[-1] - 1): 186 | out_list.append( {"img": img[..., ii: ii + 1], 187 | "lb":lb[..., ii: ii + 1], 188 | "is_start": False, 189 | "is_end": False, 190 | "sup_max_cls": lb[..., ii: ii + 1].max(), 191 | "nframe": -1, 192 | "scan_id": scan_id, 193 | "z_id": ii 194 | }) 195 | self.scan_z_idx[scan_id][ii] = glb_idx 196 | glb_idx += 1 197 | 198 | ii += 1 # last slice of a 3D volume 199 | out_list.append( {"img": img[..., ii: ii + 1], 200 | "lb":lb[..., ii: ii+ 1], 201 | "is_start": False, 202 | "is_end": True, 203 | "sup_max_cls": lb[..., ii: ii + 1].max(), 204 | "nframe": -1, 205 | "scan_id": scan_id, 206 | "z_id": ii 207 | }) 208 | 209 | self.scan_z_idx[scan_id][ii] = glb_idx 210 | glb_idx += 1 211 | 212 | return out_list 213 | 214 | def read_classfiles(self): 215 | """ 216 | Load the scan-slice-class indexing file 217 | """ 218 | with open( os.path.join(self.base_dir, f'classmap_{self.min_fg}.json') , 'r' ) as fopen: 219 | cls_map = json.load( fopen) 220 | fopen.close() 221 | 222 | with open( os.path.join(self.base_dir, 'classmap_1.json') , 'r' ) as fopen: 223 | self.tp1_cls_map = json.load( fopen) 224 | fopen.close() 225 | 226 | return cls_map 227 | 228 | def supcls_pick_binarize(self, super_map, sup_max_cls, bi_val = None): 229 | """ 230 | pick up a certain super-pixel class or multiple classes, and binarize it into segmentation target 231 | Args: 232 | super_map: super-pixel map 233 | bi_val: if given, pick up a certain superpixel. Otherwise, draw a random one 234 | sup_max_cls: max index of superpixel for avoiding overshooting when selecting superpixel 235 | 236 | """ 237 | if bi_val == None: 238 | bi_val = int(torch.randint(low = 1, high = int(sup_max_cls), size = (1,))) 239 | 240 | return np.float32(super_map == bi_val) 241 | 242 | 243 | def __getitem__(self, index): 244 | index = index % len(self.actual_dataset) 245 | curr_dict = self.actual_dataset[index] 246 | sup_max_cls = curr_dict['sup_max_cls'] 247 | if sup_max_cls < 1: 248 | return self.__getitem__(index + 1) 249 | 250 | image_t = curr_dict["img"] 251 | label_raw = curr_dict["lb"] 252 | 253 | for _ex_cls in self.exclude_lbs: 254 | if curr_dict["z_id"] in self.tp1_cls_map[self.real_label_name[_ex_cls]][curr_dict["scan_id"]]: # if using setting 1, this slice need to be excluded since it contains label which is supposed to be unseen 255 | return self.__getitem__(torch.randint(low = 0, high = self.__len__() - 1, size = (1,))) 256 | 257 | label_t = self.supcls_pick_binarize(label_raw, sup_max_cls) 258 | 259 | pair_buffer = [] 260 | 261 | comp = np.concatenate( [curr_dict["img"], label_t], axis = -1 ) 262 | 263 | for ii in range(self.num_rep): 264 | img, lb = self.transforms(comp, c_img = 1, c_label = 1, nclass = self.nclass, is_train = True, use_onehot = False) 265 | 266 | img = torch.from_numpy( np.transpose( img, (2, 0, 1)) ) 267 | lb = torch.from_numpy( lb.squeeze(-1)) 268 | 269 | if self.tile_z_dim: 270 | img = img.repeat( [ self.tile_z_dim, 1, 1] ) 271 | assert img.ndimension() == 3, f'actual dim {img.ndimension()}' 272 | 273 | is_start = curr_dict["is_start"] 274 | is_end = curr_dict["is_end"] 275 | nframe = np.int32(curr_dict["nframe"]) 276 | scan_id = curr_dict["scan_id"] 277 | z_id = curr_dict["z_id"] 278 | 279 | sample = {"image": img, 280 | "label":lb, 281 | "is_start": is_start, 282 | "is_end": is_end, 283 | "nframe": nframe, 284 | "scan_id": scan_id, 285 | "z_id": z_id 286 | } 287 | 288 | # Add auxiliary attributes 289 | if 
self.aux_attrib is not None:
290 |                 for key_prefix in self.aux_attrib:
291 |                     # Process the data sample, create new attributes and save them in a dictionary
292 |                     aux_attrib_val = self.aux_attrib[key_prefix](sample, **self.aux_attrib_args[key_prefix])
293 |                     for key_suffix in aux_attrib_val:
294 |                         # one function may create multiple attributes, so we need a suffix to distinguish them
295 |                         sample[key_prefix + '_' + key_suffix] = aux_attrib_val[key_suffix]
296 |             pair_buffer.append(sample)
297 | 
298 |         support_images = []
299 |         support_mask = []
300 |         support_class = []
301 | 
302 |         query_images = []
303 |         query_labels = []
304 |         query_class = []
305 | 
306 |         for idx, itm in enumerate(pair_buffer):
307 |             if idx % 2 == 0:
308 |                 support_images.append(itm["image"])
309 |                 support_class.append(1) # pseudolabel class
310 |                 support_mask.append( self.getMaskMedImg( itm["label"], 1, [1] ))
311 |             else:
312 |                 query_images.append(itm["image"])
313 |                 query_class.append(1)
314 |                 query_labels.append( itm["label"])
315 | 
316 |         return {'class_ids': [support_class],
317 |                 'support_images': [support_images],
318 |                 'support_mask': [support_mask],
319 |                 'query_images': query_images,
320 |                 'query_labels': query_labels,
321 |                 }
322 | 
323 | 
324 |     def __len__(self):
325 |         """
326 |         Copied from the basic naive dataset configuration
327 |         """
328 |         if self.fix_length is not None:
329 |             assert self.fix_length >= len(self.actual_dataset)
330 |             return self.fix_length
331 |         else:
332 |             return len(self.actual_dataset)
333 | 
334 |     def getMaskMedImg(self, label, class_id, class_ids):
335 |         """
336 |         Generate FG/BG mask from the segmentation mask
337 | 
338 |         Args:
339 |             label: semantic mask
340 |             class_id: semantic class of interest
341 |             class_ids: all class ids in this episode
342 |         """
343 |         fg_mask = torch.where(label == class_id,
344 |                               torch.ones_like(label), torch.zeros_like(label))
345 |         bg_mask = torch.where(label != class_id,
346 |                               torch.ones_like(label), torch.zeros_like(label))
347 |         for cid in class_ids:
348 |             bg_mask[label == cid] = 0
349 | 
350 |         return {'fg_mask': fg_mask,
351 |                 'bg_mask': bg_mask}
352 | 
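353 | 
354 | if __name__ == '__main__':
355 |     # Toy check of the FG/BG mask convention above (an illustrative sketch;
356 |     # getMaskMedImg does not touch `self`, so it is called unbound here).
357 |     toy_lb = torch.zeros(8, 8)
358 |     toy_lb[2:5, 2:5] = 1
359 |     masks = SuperpixelDataset.getMaskMedImg(None, toy_lb, class_id=1, class_ids=[1])
360 |     print(masks['fg_mask'].sum().item(), masks['bg_mask'].sum().item())  # 9.0 and 55.0
--------------------------------------------------------------------------------
/data/CHAOST2/image_normalize.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "markdown",
5 |    "metadata": {},
6 |    "source": [
7 |     "## Image Pre-processing\n",
8 |     "\n",
9 |     "### Overview\n",
10 |     "\n",
11 |     "This is the second step for data preparation\n",
12 |     "\n",
13 |     "Input: `.nii`-like images and labels converted from `dicom`s/ `png` files\n",
14 |     "\n",
15 |     "Output: image-labels with unified size (axial), voxel-spacing, and alleviated off-resonance effects"
16 |    ]
17 |   },
18 |   {
19 |    "cell_type": "code",
20 |    "execution_count": 1,
21 |    "metadata": {},
22 |    "outputs": [
23 |     {
24 |      "name": "stdout",
25 |      "output_type": "stream",
26 |      "text": [
27 |       "Once deleted, variables cannot be recovered. Proceed (y/[n])? 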
y\n" 28 | ] 29 | } 30 | ], 31 | "source": [ 32 | "%reset\n", 33 | "%load_ext autoreload\n", 34 | "%autoreload 2\n", 35 | "import numpy as np\n", 36 | "import os\n", 37 | "import glob\n", 38 | "import SimpleITK as sitk\n", 39 | "\n", 40 | "import matplotlib.pyplot as plt\n", 41 | "import copy\n", 42 | "import sys\n", 43 | "sys.path.insert(0, '../../dataloaders/')\n", 44 | "import niftiio as nio" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 2, 50 | "metadata": {}, 51 | "outputs": [], 52 | "source": [ 53 | "IMG_FOLDER = \"./niis/T2SPIR\" #, path of nii-like images from step 1\n", 54 | "OUT_FOLDER=\"./chaos_MR_T2_normalized/\" # output directory" 55 | ] 56 | }, 57 | { 58 | "cell_type": "markdown", 59 | "metadata": {}, 60 | "source": [ 61 | "**0. Find images and their ground-truth segmentations**" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": 3, 67 | "metadata": {}, 68 | "outputs": [], 69 | "source": [ 70 | "imgs = glob.glob(IMG_FOLDER + f'/image_*.nii.gz')\n", 71 | "imgs = [ fid for fid in sorted(imgs) ]\n", 72 | "segs = [ fid for fid in sorted(glob.glob(IMG_FOLDER + f'/label_*.nii.gz')) ]\n", 73 | "\n", 74 | "pids = [pid.split(\"_\")[-1].split(\".\")[0] for pid in imgs]" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 4, 80 | "metadata": {}, 81 | "outputs": [ 82 | { 83 | "data": { 84 | "text/plain": [ 85 | "['./niis/T2SPIR/image_1.nii.gz',\n", 86 | " './niis/T2SPIR/image_10.nii.gz',\n", 87 | " './niis/T2SPIR/image_13.nii.gz',\n", 88 | " './niis/T2SPIR/image_15.nii.gz',\n", 89 | " './niis/T2SPIR/image_19.nii.gz',\n", 90 | " './niis/T2SPIR/image_2.nii.gz',\n", 91 | " './niis/T2SPIR/image_20.nii.gz',\n", 92 | " './niis/T2SPIR/image_21.nii.gz',\n", 93 | " './niis/T2SPIR/image_22.nii.gz',\n", 94 | " './niis/T2SPIR/image_3.nii.gz',\n", 95 | " './niis/T2SPIR/image_31.nii.gz',\n", 96 | " './niis/T2SPIR/image_32.nii.gz',\n", 97 | " './niis/T2SPIR/image_33.nii.gz',\n", 98 | " './niis/T2SPIR/image_34.nii.gz',\n", 99 | " './niis/T2SPIR/image_36.nii.gz',\n", 100 | " './niis/T2SPIR/image_37.nii.gz',\n", 101 | " './niis/T2SPIR/image_38.nii.gz',\n", 102 | " './niis/T2SPIR/image_39.nii.gz',\n", 103 | " './niis/T2SPIR/image_5.nii.gz',\n", 104 | " './niis/T2SPIR/image_8.nii.gz']" 105 | ] 106 | }, 107 | "execution_count": 4, 108 | "metadata": {}, 109 | "output_type": "execute_result" 110 | } 111 | ], 112 | "source": [ 113 | "imgs" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": 5, 119 | "metadata": {}, 120 | "outputs": [ 121 | { 122 | "data": { 123 | "text/plain": [ 124 | "['./niis/T2SPIR/label_1.nii.gz',\n", 125 | " './niis/T2SPIR/label_10.nii.gz',\n", 126 | " './niis/T2SPIR/label_13.nii.gz',\n", 127 | " './niis/T2SPIR/label_15.nii.gz',\n", 128 | " './niis/T2SPIR/label_19.nii.gz',\n", 129 | " './niis/T2SPIR/label_2.nii.gz',\n", 130 | " './niis/T2SPIR/label_20.nii.gz',\n", 131 | " './niis/T2SPIR/label_21.nii.gz',\n", 132 | " './niis/T2SPIR/label_22.nii.gz',\n", 133 | " './niis/T2SPIR/label_3.nii.gz',\n", 134 | " './niis/T2SPIR/label_31.nii.gz',\n", 135 | " './niis/T2SPIR/label_32.nii.gz',\n", 136 | " './niis/T2SPIR/label_33.nii.gz',\n", 137 | " './niis/T2SPIR/label_34.nii.gz',\n", 138 | " './niis/T2SPIR/label_36.nii.gz',\n", 139 | " './niis/T2SPIR/label_37.nii.gz',\n", 140 | " './niis/T2SPIR/label_38.nii.gz',\n", 141 | " './niis/T2SPIR/label_39.nii.gz',\n", 142 | " './niis/T2SPIR/label_5.nii.gz',\n", 143 | " './niis/T2SPIR/label_8.nii.gz']" 144 | ] 145 | }, 146 | "execution_count": 5, 147 | "metadata": {}, 
148 | "output_type": "execute_result" 149 | } 150 | ], 151 | "source": [ 152 | "segs" 153 | ] 154 | }, 155 | { 156 | "cell_type": "markdown", 157 | "metadata": {}, 158 | "source": [ 159 | "**1. Unify image sizes and roi**\n", 160 | "\n", 161 | "a. Cut bright end of histogram to alleviate off-resonance issue\n", 162 | "\n", 163 | "b. Resample images to unified spacing\n", 164 | "\n", 165 | "c. Crop ROIs out to unify image sizes" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": 6, 171 | "metadata": {}, 172 | "outputs": [], 173 | "source": [ 174 | "# some helper functions\n", 175 | "def resample_by_res(mov_img_obj, new_spacing, interpolator = sitk.sitkLinear, logging = True):\n", 176 | " resample = sitk.ResampleImageFilter()\n", 177 | " resample.SetInterpolator(interpolator)\n", 178 | " resample.SetOutputDirection(mov_img_obj.GetDirection())\n", 179 | " resample.SetOutputOrigin(mov_img_obj.GetOrigin())\n", 180 | " mov_spacing = mov_img_obj.GetSpacing()\n", 181 | "\n", 182 | " resample.SetOutputSpacing(new_spacing)\n", 183 | " RES_COE = np.array(mov_spacing) * 1.0 / np.array(new_spacing)\n", 184 | " new_size = np.array(mov_img_obj.GetSize()) * RES_COE \n", 185 | "\n", 186 | " resample.SetSize( [int(sz+1) for sz in new_size] )\n", 187 | " if logging:\n", 188 | " print(\"Spacing: {} -> {}\".format(mov_spacing, new_spacing))\n", 189 | " print(\"Size {} -> {}\".format( mov_img_obj.GetSize(), new_size ))\n", 190 | "\n", 191 | " return resample.Execute(mov_img_obj)\n", 192 | "\n", 193 | "def resample_lb_by_res(mov_lb_obj, new_spacing, interpolator = sitk.sitkLinear, ref_img = None, logging = True):\n", 194 | " src_mat = sitk.GetArrayFromImage(mov_lb_obj)\n", 195 | " lbvs = np.unique(src_mat)\n", 196 | " if logging:\n", 197 | " print(\"Label values: {}\".format(lbvs))\n", 198 | " for idx, lbv in enumerate(lbvs):\n", 199 | " _src_curr_mat = np.float32(src_mat == lbv) \n", 200 | " _src_curr_obj = sitk.GetImageFromArray(_src_curr_mat)\n", 201 | " _src_curr_obj.CopyInformation(mov_lb_obj)\n", 202 | " _tar_curr_obj = resample_by_res( _src_curr_obj, new_spacing, interpolator, logging )\n", 203 | " _tar_curr_mat = np.rint(sitk.GetArrayFromImage(_tar_curr_obj)) * lbv\n", 204 | " if idx == 0:\n", 205 | " out_vol = _tar_curr_mat\n", 206 | " else:\n", 207 | " out_vol[_tar_curr_mat == lbv] = lbv\n", 208 | " out_obj = sitk.GetImageFromArray(out_vol)\n", 209 | " out_obj.SetSpacing( _tar_curr_obj.GetSpacing() )\n", 210 | " if ref_img != None:\n", 211 | " out_obj.CopyInformation(ref_img)\n", 212 | " return out_obj\n", 213 | " \n", 214 | "def get_label_center(label):\n", 215 | " nnz = np.sum(label > 1e-5)\n", 216 | " return np.int32(np.rint(np.sum(np.nonzero(label), axis = 1) * 1.0 / nnz))\n", 217 | "\n", 218 | "def image_crop(ori_vol, crop_size, referece_ctr_idx, padval = 0., only_2d = True):\n", 219 | " \"\"\" crop a 3d matrix given the index of the new volume on the original volume\n", 220 | " Args:\n", 221 | " refernce_ctr_idx: the center of the new volume on the original volume (in indices)\n", 222 | " only_2d: only do cropping on first two dimensions\n", 223 | " \"\"\"\n", 224 | " _expand_cropsize = [x + 1 for x in crop_size] # to deal with boundary case\n", 225 | " if only_2d:\n", 226 | " assert len(crop_size) == 2, \"Actual len {}\".format(len(crop_size))\n", 227 | " assert len(referece_ctr_idx) == 2, \"Actual len {}\".format(len(referece_ctr_idx))\n", 228 | " _expand_cropsize.append(ori_vol.shape[-1])\n", 229 | " \n", 230 | " image_patch = np.ones(tuple(_expand_cropsize)) * 
padval\n", 231 | "\n", 232 | " half_size = tuple( [int(x * 1.0 / 2) for x in _expand_cropsize] )\n", 233 | " _min_idx = [0,0,0]\n", 234 | " _max_idx = list(ori_vol.shape)\n", 235 | "\n", 236 | " # bias of actual cropped size to the beginning and the end of this volume\n", 237 | " _bias_start = [0,0,0]\n", 238 | " _bias_end = [0,0,0]\n", 239 | "\n", 240 | " for dim,hsize in enumerate(half_size):\n", 241 | " if dim == 2 and only_2d:\n", 242 | " break\n", 243 | "\n", 244 | " _bias_start[dim] = np.min([hsize, referece_ctr_idx[dim]])\n", 245 | " _bias_end[dim] = np.min([hsize, ori_vol.shape[dim] - referece_ctr_idx[dim]])\n", 246 | "\n", 247 | " _min_idx[dim] = referece_ctr_idx[dim] - _bias_start[dim]\n", 248 | " _max_idx[dim] = referece_ctr_idx[dim] + _bias_end[dim]\n", 249 | " \n", 250 | " if only_2d:\n", 251 | " image_patch[ half_size[0] - _bias_start[0]: half_size[0] +_bias_end[0], \\\n", 252 | " half_size[1] - _bias_start[1]: half_size[1] +_bias_end[1], ... ] = \\\n", 253 | " ori_vol[ referece_ctr_idx[0] - _bias_start[0]: referece_ctr_idx[0] +_bias_end[0], \\\n", 254 | " referece_ctr_idx[1] - _bias_start[1]: referece_ctr_idx[1] +_bias_end[1], ... ]\n", 255 | "\n", 256 | " image_patch = image_patch[ 0: crop_size[0], 0: crop_size[1], : ]\n", 257 | " # then goes back to original volume\n", 258 | " else:\n", 259 | " image_patch[ half_size[0] - _bias_start[0]: half_size[0] +_bias_end[0], \\\n", 260 | " half_size[1] - _bias_start[1]: half_size[1] +_bias_end[1], \\\n", 261 | " half_size[2] - _bias_start[2]: half_size[2] +_bias_end[2] ] = \\\n", 262 | " ori_vol[ referece_ctr_idx[0] - _bias_start[0]: referece_ctr_idx[0] +_bias_end[0], \\\n", 263 | " referece_ctr_idx[1] - _bias_start[1]: referece_ctr_idx[1] +_bias_end[1], \\\n", 264 | " referece_ctr_idx[2] - _bias_start[2]: referece_ctr_idx[2] +_bias_end[2] ]\n", 265 | "\n", 266 | " image_patch = image_patch[ 0: crop_size[0], 0: crop_size[1], 0: crop_size[2] ]\n", 267 | " return image_patch\n", 268 | "\n", 269 | "def copy_spacing_ori(src, dst):\n", 270 | " dst.SetSpacing(src.GetSpacing())\n", 271 | " dst.SetOrigin(src.GetOrigin())\n", 272 | " dst.SetDirection(src.GetDirection())\n", 273 | " return dst\n", 274 | "\n", 275 | "s2n = sitk.GetArrayFromImage" 276 | ] 277 | }, 278 | { 279 | "cell_type": "code", 280 | "execution_count": 7, 281 | "metadata": { 282 | "scrolled": false 283 | }, 284 | "outputs": [ 285 | { 286 | "name": "stdout", 287 | "output_type": "stream", 288 | "text": [ 289 | "Failed to create the output folder.\n" 290 | ] 291 | }, 292 | { 293 | "ename": "NameError", 294 | "evalue": "name 'copy_spacing_ori' is not defined", 295 | "output_type": "error", 296 | "traceback": [ 297 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 298 | "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", 299 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[0mhis_img_o\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msitk\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mGetImageFromArray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 26\u001b[0;31m \u001b[0mhis_img_o\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcopy_spacing_ori\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimg_obj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhis_img_o\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 27\u001b[0m 
\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[0;31m# resampling\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 300 | "\u001b[0;31mNameError\u001b[0m: name 'copy_spacing_ori' is not defined" 301 | ] 302 | } 303 | ], 304 | "source": [ 305 | "import copy\n", 306 | "try:\n", 307 | " os.mkdir(OUT_FOLDER)\n", 308 | "except:\n", 309 | " print(\"Failed to create the output folder.\")\n", 310 | " \n", 311 | "HIST_CUT_TOP = 0.5 # cut top 0.5% of the intensity histogram to alleviate the off-resonance effect\n", 312 | "\n", 313 | "NEW_SPA = [1.25, 1.25, 7.70] # unified voxel spacing\n", 314 | "\n", 315 | "for img_fid, seg_fid, pid in zip(imgs, segs, pids):\n", 316 | "\n", 317 | " lb_n = nio.read_nii_bysitk(seg_fid)\n", 318 | " resample_flg = True\n", 319 | "\n", 320 | " img_obj = sitk.ReadImage( img_fid )\n", 321 | " seg_obj = sitk.ReadImage( seg_fid )\n", 322 | "\n", 323 | " array = sitk.GetArrayFromImage(img_obj)\n", 324 | "\n", 325 | " # cut histogram\n", 326 | " hir = float(np.percentile(array, 100.0 - HIST_CUT_TOP))\n", 327 | " array[array > hir] = hir\n", 328 | "\n", 329 | " his_img_o = sitk.GetImageFromArray(array)\n", 330 | " his_img_o = copy_spacing_ori(img_obj, his_img_o)\n", 331 | "\n", 332 | " # resampling\n", 333 | " img_spa_ori = img_obj.GetSpacing()\n", 334 | " res_img_o = resample_by_res(his_img_o, [NEW_SPA[0], NEW_SPA[1], NEW_SPA[2]],\n", 335 | " interpolator = sitk.sitkLinear, logging = True)\n", 336 | "\n", 337 | "\n", 338 | "\n", 339 | " ## label\n", 340 | " lb_arr = sitk.GetArrayFromImage(seg_obj)\n", 341 | "\n", 342 | " # resampling\n", 343 | " res_lb_o = resample_lb_by_res(seg_obj, [NEW_SPA[0], NEW_SPA[1], NEW_SPA[2] ], interpolator = sitk.sitkLinear,\n", 344 | " ref_img = None, logging = True)\n", 345 | "\n", 346 | "\n", 347 | " # crop out ROIs\n", 348 | " res_img_a = s2n(res_img_o)\n", 349 | "\n", 350 | " crop_img_a = image_crop(res_img_a.transpose(1,2,0), [256, 256],\n", 351 | " referece_ctr_idx = [res_img_a.shape[1] // 2, res_img_a.shape[2] //2],\n", 352 | " padval = res_img_a.min(), only_2d = True).transpose(2,0,1)\n", 353 | "\n", 354 | " out_img_obj = copy_spacing_ori(res_img_o, sitk.GetImageFromArray(crop_img_a))\n", 355 | "\n", 356 | " res_lb_a = s2n(res_lb_o)\n", 357 | "\n", 358 | " crop_lb_a = image_crop(res_lb_a.transpose(1,2,0), [256, 256],\n", 359 | " referece_ctr_idx = [res_lb_a.shape[1] // 2, res_lb_a.shape[2] //2],\n", 360 | " padval = 0, only_2d = True).transpose(2,0,1)\n", 361 | "\n", 362 | " out_lb_obj = copy_spacing_ori(res_img_o, sitk.GetImageFromArray(crop_lb_a))\n", 363 | "\n", 364 | "\n", 365 | " out_img_fid = os.path.join( OUT_FOLDER, f'image_{pid}.nii.gz' )\n", 366 | " out_lb_fid = os.path.join( OUT_FOLDER, f'label_{pid}.nii.gz' ) \n", 367 | "\n", 368 | " # then save pre-processed images\n", 369 | " sitk.WriteImage(out_img_obj, out_img_fid, True) \n", 370 | " sitk.WriteImage(out_lb_obj, out_lb_fid, True) \n", 371 | " print(\"{} has been saved\".format(out_img_fid))" 372 | ] 373 | } 374 | ], 375 | "metadata": { 376 | "kernelspec": { 377 | "display_name": "Python 3", 378 | "language": "python", 379 | "name": "python3" 380 | }, 381 | "language_info": { 382 | "codemirror_mode": { 383 | "name": "ipython", 384 | "version": 3 385 | }, 386 | "file_extension": ".py", 387 | "mimetype": "text/x-python", 388 | "name": "python", 389 | "nbconvert_exporter": "python", 390 | "pygments_lexer": "ipython3", 391 | "version": "3.6.0" 392 | } 393 | }, 394 | "nbformat": 4, 395 | "nbformat_minor": 2 396 | } 397 | 
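A quick sanity check on the notebook output can catch resampling or cropping mistakes before training. The snippet below is an illustrative sketch, not part of the repository: it assumes the cells above have been run so that `OUT_FOLDER` contains the pre-processed `image_*.nii.gz` volumes, and verifies that each volume carries the unified spacing `[1.25, 1.25, 7.70]` and a 256 x 256 in-plane size.

```python
# Sanity-check the pre-processed CHAOS T2 volumes (illustrative sketch, not
# part of the repository; assumes the notebook above has been run).
import glob
import numpy as np
import SimpleITK as sitk

OUT_FOLDER = "./chaos_MR_T2_normalized/"

for fid in sorted(glob.glob(OUT_FOLDER + "image_*.nii.gz")):
    img_obj = sitk.ReadImage(fid)
    size = img_obj.GetSize()        # (x, y, z); in-plane size should be 256 x 256
    spacing = img_obj.GetSpacing()  # should be close to (1.25, 1.25, 7.70)
    assert size[0] == 256 and size[1] == 256, f"unexpected in-plane size {size} in {fid}"
    assert np.allclose(spacing, (1.25, 1.25, 7.70), atol=1e-3), f"unexpected spacing {spacing} in {fid}"
    print(f"{fid}: size {size}, spacing {spacing}")
```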
-------------------------------------------------------------------------------- /dataloaders/ManualAnnoDatasetv2.py: -------------------------------------------------------------------------------- 1 | """ 2 | Manually labeled dataset 3 | TODO: 4 | 1. Merge with superpixel dataset 5 | """ 6 | import glob 7 | import numpy as np 8 | import dataloaders.augutils as myaug 9 | import torch 10 | import random 11 | import os 12 | import copy 13 | import platform 14 | import json 15 | import re 16 | from dataloaders.common import BaseDataset, Subset 17 | # from common import BaseDataset, Subset 18 | from dataloaders.dataset_utils import * 19 | from pdb import set_trace 20 | from util.utils import CircularList 21 | 22 | class ManualAnnoDataset(BaseDataset): 23 | def __init__(self, which_dataset, base_dir, idx_split, mode, transforms, scan_per_load, min_fg = '', fix_length = None, tile_z_dim = 3, nsup = 1, exclude_list = [], extern_normalize_func = None,**kwargs): 24 | """ 25 | Manually labeled dataset 26 | Args: 27 | which_dataset: name of the dataset to use 28 | base_dir: directory of the dataset 29 | idx_split: index of the data split, as we will do cross-validation 30 | mode: 'train' or 'val' 31 | transforms: data transform (augmentation) function 32 | min_fg: minimum number of positive pixels in a 2D slice, mainly to stabilize training on a manually labeled dataset 33 | scan_per_load: load only a portion of the entire dataset, in case the dataset is too large to fit into memory. Set to -1 to load the entire dataset at once 34 | tile_z_dim: number of identical slices to tile along the channel dimension, for fitting 2D single-channel medical images into off-the-shelf networks designed for RGB natural images 35 | nsup: number of support scans 36 | fix_length: fix the (virtual) length of the dataset; must be no smaller than the actual number of slices 37 | exclude_list: labels to be excluded 38 | extern_normalize_func: normalization function used for data pre-processing 39 | """ 40 | super(ManualAnnoDataset, self).__init__(base_dir) 41 | self.img_modality = DATASET_INFO[which_dataset]['MODALITY'] 42 | self.sep = DATASET_INFO[which_dataset]['_SEP'] 43 | self.label_name = DATASET_INFO[which_dataset]['REAL_LABEL_NAME'] 44 | self.transforms = transforms 45 | self.is_train = True if mode == 'train' else False 46 | self.phase = mode 47 | self.fix_length = fix_length 48 | self.all_label_names = self.label_name 49 | self.nclass = len(self.label_name) 50 | self.tile_z_dim = tile_z_dim 51 | self.base_dir = base_dir 52 | self.nsup = nsup 53 | self.img_pids = [ re.findall(r'\d+', fid)[-1] for fid in glob.glob(self.base_dir + "/image_*.nii.gz") ] 54 | self.img_pids = CircularList(sorted( self.img_pids, key = lambda x: int(x))) # make it circular for ease of splitting folds 55 | 56 | self.exclude_lbs = exclude_list 57 | if len(exclude_list) > 0: 58 | print(f'###### Dataset: the following classes have been excluded {exclude_list} ######') 59 | 60 | self.idx_split = idx_split 61 | self.scan_ids = self.get_scanids(mode, idx_split) # patient ids of the entire fold 62 | self.min_fg = min_fg if isinstance(min_fg, str) else str(min_fg) 63 | 64 | self.scan_per_load = scan_per_load 65 | 66 | self.info_by_scan = None 67 | self.img_lb_fids = self.organize_sample_fids() # information of scans of the entire fold 68 | 69 | if extern_normalize_func is not None: # helps to keep normalization consistent between training and testing datasets 
70 | self.norm_func = extern_normalize_func 71 | print(f'###### Dataset: using external normalization statistics ######') 72 | else: 73 | self.norm_func = get_normalize_op(self.img_modality, [ fid_pair['img_fid'] for _, fid_pair in self.img_lb_fids.items()]) 74 | print(f'###### Dataset: using normalization statistics calculated from loaded data ######') 75 | 76 | if self.is_train: 77 | if scan_per_load > 0: # buffer needed 78 | self.pid_curr_load = np.random.choice( self.scan_ids, replace = False, size = self.scan_per_load) 79 | else: # load the entire set without a buffer 80 | self.pid_curr_load = self.scan_ids 81 | elif mode == 'val': 82 | self.pid_curr_load = self.scan_ids 83 | self.potential_support_sid = [] 84 | else: 85 | raise Exception 86 | self.actual_dataset = self.read_dataset() 87 | self.size = len(self.actual_dataset) 88 | self.overall_slice_by_cls = self.read_classfiles() 89 | self.update_subclass_lookup() 90 | 91 | def get_scanids(self, mode, idx_split): 92 | val_ids = copy.deepcopy(self.img_pids[self.sep[idx_split]: self.sep[idx_split + 1] + self.nsup]) 93 | self.potential_support_sid = val_ids[-self.nsup:] # this is the actual file scan id, not an index 94 | if mode == 'train': 95 | return [ ii for ii in self.img_pids if ii not in val_ids ] 96 | elif mode == 'val': 97 | return val_ids 98 | 99 | def reload_buffer(self): 100 | """ 101 | Reload a portion of the entire dataset, if the dataset is too large 102 | 1. delete the original buffer 103 | 2. update self.ids_this_batch 104 | 3. update other internal variables like __len__ 105 | """ 106 | if self.scan_per_load <= 0: 107 | print("We are not using the reload buffer, doing nothing") 108 | return -1 109 | 110 | del self.actual_dataset 111 | del self.info_by_scan 112 | self.pid_curr_load = np.random.choice( self.scan_ids, size = self.scan_per_load, replace = False ) 113 | self.actual_dataset = self.read_dataset() 114 | self.size = len(self.actual_dataset) 115 | self.update_subclass_lookup() 116 | print(f'Loader buffer reloaded with a new size of {self.size} slices') 117 | 118 | def organize_sample_fids(self): 119 | out_list = {} 120 | for curr_id in self.scan_ids: 121 | curr_dict = {} 122 | 123 | _img_fid = os.path.join(self.base_dir, f'image_{curr_id}.nii.gz') 124 | _lb_fid = os.path.join(self.base_dir, f'label_{curr_id}.nii.gz') 125 | 126 | curr_dict["img_fid"] = _img_fid 127 | curr_dict["lbs_fid"] = _lb_fid 128 | out_list[str(curr_id)] = curr_dict 129 | return out_list 130 | 131 | def read_dataset(self): 132 | """ 133 | Build index pointers to individual slices 134 | Also keep a look-up table from (scan_id, z_id) to the global index 135 | """ 136 | out_list = [] 137 | self.scan_z_idx = {} 138 | self.info_by_scan = {} # meta data of each scan 139 | glb_idx = 0 # global index of a certain slice in a certain scan in the entire dataset 140 | 141 | for scan_id, itm in self.img_lb_fids.items(): 142 | if scan_id not in self.pid_curr_load: 143 | continue 144 | 145 | img, _info = read_nii_bysitk(itm["img_fid"], peel_info = True) # get the meta information out 146 | 147 | img = img.transpose(1,2,0) 148 | 149 | self.info_by_scan[scan_id] = _info 150 | 151 | img = np.float32(img) 152 | img = self.norm_func(img) 153 | 154 | self.scan_z_idx[scan_id] = [-1 for _ in range(img.shape[-1])] 155 | 156 | lb = read_nii_bysitk(itm["lbs_fid"]) 157 | lb = lb.transpose(1,2,0) 158 | 159 | lb = np.float32(lb) 160 | 161 | img = img[:256, :256, :] # FIXME a bug in shape from the pre-processing code 162 | lb = lb[:256, :256, :] 163 | 164 | assert img.shape[-1] == lb.shape[-1] 165
| base_idx = img.shape[-1] // 2 # index of the middle slice 166 | 167 | # write the beginning frame 168 | out_list.append( {"img": img[..., 0: 1], 169 | "lb":lb[..., 0: 0 + 1], 170 | "is_start": True, 171 | "is_end": False, 172 | "nframe": img.shape[-1], 173 | "scan_id": scan_id, 174 | "z_id":0}) 175 | 176 | self.scan_z_idx[scan_id][0] = glb_idx 177 | glb_idx += 1 178 | 179 | for ii in range(1, img.shape[-1] - 1): 180 | out_list.append( {"img": img[..., ii: ii + 1], 181 | "lb":lb[..., ii: ii + 1], 182 | "is_start": False, 183 | "is_end": False, 184 | "nframe": -1, 185 | "scan_id": scan_id, 186 | "z_id": ii 187 | }) 188 | self.scan_z_idx[scan_id][ii] = glb_idx 189 | glb_idx += 1 190 | 191 | ii = img.shape[-1] - 1 # last frame, note the is_end flag (robust even if the loop above is empty) 192 | out_list.append( {"img": img[..., ii: ii + 1], 193 | "lb":lb[..., ii: ii+ 1], 194 | "is_start": False, 195 | "is_end": True, 196 | "nframe": -1, 197 | "scan_id": scan_id, 198 | "z_id": ii 199 | }) 200 | 201 | self.scan_z_idx[scan_id][ii] = glb_idx 202 | glb_idx += 1 203 | 204 | return out_list 205 | 206 | def read_classfiles(self): 207 | with open( os.path.join(self.base_dir, f'classmap_{self.min_fg}.json') , 'r' ) as fopen: 208 | cls_map = json.load( fopen) 209 | fopen.close() 210 | 211 | with open( os.path.join(self.base_dir, 'classmap_1.json') , 'r' ) as fopen: 212 | self.tp1_cls_map = json.load( fopen) 213 | fopen.close() 214 | 215 | return cls_map 216 | 217 | def __getitem__(self, index): 218 | index = index % len(self.actual_dataset) 219 | curr_dict = self.actual_dataset[index] 220 | if self.is_train: 221 | if len(self.exclude_lbs) > 0: 222 | for _ex_cls in self.exclude_lbs: 223 | if curr_dict["z_id"] in self.tp1_cls_map[self.label_name[_ex_cls]][curr_dict["scan_id"]]: # this slice needs to be excluded since it contains a label that is supposed to be unseen 224 | return self.__getitem__(index + torch.randint(low = 0, high = self.__len__() - 1, size = (1,))) 225 | 226 | comp = np.concatenate( [curr_dict["img"], curr_dict["lb"]], axis = -1 ) 227 | img, lb = self.transforms(comp, c_img = 1, c_label = 1, nclass = self.nclass, use_onehot = False) 228 | 229 | else: 230 | img = curr_dict['img'] 231 | lb = curr_dict['lb'] 232 | 233 | img = np.float32(img) 234 | lb = np.float32(lb).squeeze(-1) # NOTE: to be suitable for the PANet structure 235 | 236 | img = torch.from_numpy( np.transpose(img, (2, 0, 1)) ) 237 | lb = torch.from_numpy( lb) 238 | 239 | if self.tile_z_dim: 240 | img = img.repeat( [ self.tile_z_dim, 1, 1] ) 241 | assert img.ndimension() == 3, f'actual dim {img.ndimension()}' 242 | 243 | is_start = curr_dict["is_start"] 244 | is_end = curr_dict["is_end"] 245 | nframe = np.int32(curr_dict["nframe"]) 246 | scan_id = curr_dict["scan_id"] 247 | z_id = curr_dict["z_id"] 248 | 249 | sample = {"image": img, 250 | "label":lb, 251 | "is_start": is_start, 252 | "is_end": is_end, 253 | "nframe": nframe, 254 | "scan_id": scan_id, 255 | "z_id": z_id 256 | } 257 | # Add auxiliary attributes 258 | if self.aux_attrib is not None: 259 | for key_prefix in self.aux_attrib: 260 | # Process the data sample, create new attributes and save them in a dictionary 261 | aux_attrib_val = self.aux_attrib[key_prefix](sample, **self.aux_attrib_args[key_prefix]) 262 | for key_suffix in aux_attrib_val: 263 | # one function may create multiple attributes, so we need a suffix to distinguish them 264 | sample[key_prefix + '_' + key_suffix] = aux_attrib_val[key_suffix] 265 | 266 | return sample 267 | 268 | def __len__(self): 269 | """ 270 | copy-paste from the basic naive dataset 
configuration 271 | """ 272 | if self.fix_length is not None: 273 | assert self.fix_length >= len(self.actual_dataset) 274 | return self.fix_length 275 | else: 276 | return len(self.actual_dataset) 277 | 278 | def update_subclass_lookup(self): 279 | """ 280 | Updating the class-slice indexing list 281 | Args: 282 | [internal] overall_slice_by_cls: 283 | { 284 | class1: {pid1: [slice1, slice2, ....], 285 | pid2: [slice1, slice2]}, 286 | ...} 287 | class2: 288 | ... 289 | } 290 | out[internal]: 291 | { 292 | class1: [ idx1, idx2, ... ], 293 | class2: [ idx1, idx2, ... ], 294 | ... 295 | } 296 | 297 | """ 298 | # delete previous ones if any 299 | assert self.overall_slice_by_cls is not None 300 | 301 | if not hasattr(self, 'idx_by_class'): 302 | self.idx_by_class = {} 303 | # filter the new one given the actual list 304 | for cls in self.label_name: 305 | if cls not in self.idx_by_class.keys(): 306 | self.idx_by_class[cls] = [] 307 | else: 308 | del self.idx_by_class[cls][:] 309 | for cls, dict_by_pid in self.overall_slice_by_cls.items(): 310 | for pid, slice_list in dict_by_pid.items(): 311 | if pid not in self.pid_curr_load: 312 | continue 313 | self.idx_by_class[cls] += [ self.scan_z_idx[pid][_sli] for _sli in slice_list ] 314 | print("###### index-by-class table has been reloaded ######") 315 | 316 | def getMaskMedImg(self, label, class_id, class_ids): 317 | """ 318 | Generate the FG/BG mask from the segmentation mask. Used when getting the support 319 | """ 320 | # Dense Mask 321 | fg_mask = torch.where(label == class_id, 322 | torch.ones_like(label), torch.zeros_like(label)) 323 | bg_mask = torch.where(label != class_id, 324 | torch.ones_like(label), torch.zeros_like(label)) 325 | for _cls_id in class_ids: # note: do not shadow the class_id argument here 326 | bg_mask[label == _cls_id] = 0 327 | 328 | return {'fg_mask': fg_mask, 329 | 'bg_mask': bg_mask} 330 | 331 | def subsets(self, sub_args_lst=None): 332 | """ 333 | Override the base-class subsets method 334 | Create one subset per class 335 | 336 | output: a list of Subset objects 337 | """ 338 | 339 | if sub_args_lst is not None: 340 | subsets = [] 341 | ii = 0 342 | for cls_name, index_list in self.idx_by_class.items(): 343 | subsets.append( Subset(dataset = self, indices = index_list, sub_attrib_args = sub_args_lst[ii]) ) 344 | ii += 1 345 | else: 346 | subsets = [Subset(dataset=self, indices=index_list) for _, index_list in self.idx_by_class.items()] 347 | return subsets 348 | 349 | def get_support(self, curr_class: int, class_idx: list, scan_idx: list, npart: int): 350 | """ 351 | getting a (possibly multi-shot) support set for evaluation 352 | support slices are taken at fixed depth percentiles: 50% for npart = 1, or evenly spaced percentiles (e.g. 10%, 30%, 50%, 70%, 90% for npart = 5) 353 | Args: 354 | curr_class: current class to segment, starts from 1 355 | class_idx: a list of all foreground classes in n-ways, starts from 1 356 | npart: how many chunks to split the support scan into 357 | scan_idx: a list indicating the current **i-th** (note: this is an index, not a pid) training scan 358 | being served as support, in self.pid_curr_load 359 | """ 360 | assert npart % 2 == 1 361 | assert curr_class != 0; assert 0 not in class_idx 362 | assert not self.is_train 363 | 364 | self.potential_support_sid = [self.pid_curr_load[ii] for ii in scan_idx ] 365 | print(f'###### Using {len(scan_idx)} shot evaluation!') 366 | 367 | if npart == 1: 368 | pcts = [0.5] 369 | else: 370 | half_part = 1 / (npart * 2) 371 | part_interval = (1.0 - 1.0 / npart) / (npart - 1) 372 | pcts = [ half_part + part_interval * ii for ii in range(npart) ] 373 | 374 | print(f'###### Parts percentage: {pcts} ######') 375 | 376 | out_buffer = []
# [{scanid, img, lb}] 377 | for _part in range(npart): 378 | concat_buffer = [] # for each fold do a concat in image and mask in batch dimension 379 | for scan_order in scan_idx: 380 | _scan_id = self.pid_curr_load[ scan_order ] 381 | print(f'Using scan {_scan_id} as support!') 382 | 383 | # for _pc in pcts: 384 | _zlist = self.tp1_cls_map[self.label_name[curr_class]][_scan_id] # list of indices 385 | _zid = _zlist[int(pcts[_part] * len(_zlist))] 386 | _glb_idx = self.scan_z_idx[_scan_id][_zid] 387 | 388 | # almost copy-paste __getitem__ but no augmentation 389 | curr_dict = self.actual_dataset[_glb_idx] 390 | img = curr_dict['img'] 391 | lb = curr_dict['lb'] 392 | 393 | img = np.float32(img) 394 | lb = np.float32(lb).squeeze(-1) # NOTE: to be suitable for the PANet structure 395 | 396 | img = torch.from_numpy( np.transpose(img, (2, 0, 1)) ) 397 | lb = torch.from_numpy( lb ) 398 | 399 | if self.tile_z_dim: 400 | img = img.repeat( [ self.tile_z_dim, 1, 1] ) 401 | assert img.ndimension() == 3, f'actual dim {img.ndimension()}' 402 | 403 | is_start = curr_dict["is_start"] 404 | is_end = curr_dict["is_end"] 405 | nframe = np.int32(curr_dict["nframe"]) 406 | scan_id = curr_dict["scan_id"] 407 | z_id = curr_dict["z_id"] 408 | 409 | sample = {"image": img, 410 | "label":lb, 411 | "is_start": is_start, 412 | "inst": None, 413 | "scribble": None, 414 | "is_end": is_end, 415 | "nframe": nframe, 416 | "scan_id": scan_id, 417 | "z_id": z_id 418 | } 419 | 420 | concat_buffer.append(sample) 421 | out_buffer.append({ 422 | "image": torch.stack([itm["image"] for itm in concat_buffer], dim = 0), 423 | "label": torch.stack([itm["label"] for itm in concat_buffer], dim = 0), 424 | 425 | }) 426 | 427 | # do the concat, and add to output_buffer 428 | 429 | # post-processing, including keeping the foreground and suppressing background. 430 | support_images = [] 431 | support_mask = [] 432 | support_class = [] 433 | for itm in out_buffer: 434 | support_images.append(itm["image"]) 435 | support_class.append(curr_class) 436 | support_mask.append( self.getMaskMedImg( itm["label"], curr_class, class_idx )) 437 | 438 | return {'class_ids': [support_class], 439 | 'support_images': [support_images], # 440 | 'support_mask': [support_mask], 441 | } 442 | 443 | --------------------------------------------------------------------------------
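As a closing note on `ManualAnnoDataset.get_support`: the percentile arithmetic spreads the support slices evenly across the z-range in which the target class appears. The standalone sketch below reproduces that arithmetic for illustration (`support_pcts` is a hypothetical helper, not a function exported by the repository).

```python
# Standalone illustration of the support-slice percentiles computed inside
# ManualAnnoDataset.get_support (hypothetical helper, not a repo import).
def support_pcts(npart: int):
    if npart == 1:
        return [0.5]
    half_part = 1 / (npart * 2)                        # center of the first chunk
    part_interval = (1.0 - 1.0 / npart) / (npart - 1)  # distance between chunk centers
    return [half_part + part_interval * ii for ii in range(npart)]

print(support_pcts(1))  # [0.5] -> the middle slice of the labeled range
print(support_pcts(5))  # ~[0.1, 0.3, 0.5, 0.7, 0.9] (up to float rounding)

# For a support scan whose target class appears on slices _zlist, chunk p then
# contributes the slice _zlist[int(pcts[p] * len(_zlist))].
```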