├── DSPNet.png
├── LICENSE
├── README.md
├── config_ssl_upload.py
├── dataloaders
│   ├── GenericSuperDatasetv2.py
│   ├── ManualAnnoDatasetv2.py
│   ├── __init__.py
│   ├── augutils.py
│   ├── common.py
│   ├── dataset_utils.py
│   ├── dev_customized_med.py
│   ├── image_transforms.py
│   └── niftiio.py
├── models
│   ├── __init__.py
│   ├── alpmodule.py
│   ├── backbone
│   │   ├── __init__.py
│   │   └── torchvision_backbones.py
│   └── grid_proto_fewshot.py
├── pigeon.jpg
├── requirements.txt
├── train_ssl_abdominal_ct.sh
├── training.py
├── util
│   ├── __init__.py
│   ├── metric.py
│   ├── seed_init.py
│   └── utils.py
└── validation.py

/DSPNet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tntek/DSPNet/a4b71c2ab1229221584ccb2a5100bf66a7d1ef03/DSPNet.png
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2020 Cheng
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DSPNet
2 | 
3 | **Abstract**:
4 | 
5 | As dominant Few-shot Semantic Segmentation (FSS) methods, prototypical schemes suffer from a fundamental limitation: pooling-based prototypes are prone to losing local details, and the complicated, diverse details in medical
6 | images amplify this problem considerably. Unlike the conventional incremental solution that constructs new prototypes to capture more details, this paper introduces a novel Detail Self-refined Prototype Network (DSPNet). Our core idea is
7 | to enhance the prototypes' ability to model details via detail self-refining. To this end, we propose two new attention-like designs. In the foreground semantic prototype attention module, to construct global semantics while maintaining the
8 | captured detail semantics, we fuse cluster-based detail prototypes into a single class prototype in a channel-wise weighting fashion. In the background channel-structural multi-head attention (BCMA) module, considering that the complicated background
9 | often has no apparent semantic relation in the spatial dimensions, we integrate each background detail prototype's channel-structural information for its self-enhancement. Specifically, we introduce a neighbour channel-aware regulation
10 | into the multi-head channel attention, exploiting a local-global adjustment mechanism; elements of each detail prototype are individually refreshed by different heads in BCMA. Extensive experiments on two challenging medical benchmarks
11 | demonstrate the superiority of DSPNet over previous state-of-the-art FSS methods.
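The channel-wise weighted fusion of detail prototypes can be pictured with a minimal sketch. This is an illustrative reading of the abstract only, not the actual layer from this repository: the weighting head `weight_mlp`, the number of detail prototypes `K`, and all tensor shapes are assumptions — see `./models/alpmodule.py` for the real implementation.

```python
import torch
import torch.nn as nn

class DetailFusionSketch(nn.Module):
    """Toy sketch (assumptions, not the paper's exact layer): fuse K cluster-based
    detail prototypes [K, C] into one class prototype [C] with per-channel weights."""
    def __init__(self, channels):
        super().__init__()
        self.weight_mlp = nn.Linear(channels, channels)  # assumed weighting head

    def forward(self, detail_protos):  # detail_protos: [K, C]
        # softmax over K gives every channel its own mixture over the K prototypes
        w = torch.softmax(self.weight_mlp(detail_protos), dim=0)  # [K, C]
        return (w * detail_protos).sum(dim=0)                     # [C]

# e.g. DetailFusionSketch(256)(torch.randn(5, 256)) -> a single [256] prototype
```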
12 | **NOTE: We are actively updating this repository.**
13 | 
14 | If you find this code base useful, please cite our paper. Thanks!
15 | 
16 | ```
17 | @article{tang2024dspnet,
18 |   title={Few-Shot Medical Image Segmentation with Detail Self-Refined Prototypes},
19 |   author={Song Tang and Shaxu Yan and Xiaozhi Qi and Jianxin Gao and Mao Ye and Jianwei Zhang and Xiatian Zhu},
20 |   journal={},
21 |   year={2024}
22 | }
23 | ```
24 | 
25 | ### 1. Dependencies
26 | 
27 | Please install the essential dependencies (see `requirements.txt`):
28 | 
29 | ```
30 | dcm2nii
31 | nibabel==2.5.1
32 | numpy==1.21.6
33 | opencv-python==4.1.1
34 | Pillow==9.5.0
35 | sacred==0.7.5
36 | scikit-image==0.14.0
37 | SimpleITK==1.2.3
38 | torch==1.8.1
39 | torchvision==0.9.1
40 | ```
41 | 
42 | ### 2. Data pre-processing
43 | 
44 | ### Datasets and pre-processing
45 | **NOTE:** The `.ipynb` and `.sh` files used in the steps below can be found in the SSL-ALPNet repository: https://github.com/cheng-01037/Self-supervised-Fewshot-Medical-Image-Segmentation
46 | 
47 | Download:
48 | **Abdominal MRI**
49 | 
50 | 0. Download the [Combined Healthy Abdominal Organ Segmentation dataset](https://chaos.grand-challenge.org/) and put the `/MR` folder under the `./data/CHAOST2/` directory
51 | 
52 | 1. Convert the downloaded data (T2 fold) to 3D `nii` files for ease of reading
53 | 
54 | run `./data/CHAOST2/dcm_img_to_nii.sh` to convert DICOM images to NIfTI files.
55 | 
56 | run `./data/CHAOST2/png_gth_to_nii.ipynb` to convert the `png`-format ground truth to NIfTI.
57 | 
58 | 2. Pre-process the downloaded images
59 | 
60 | run `./data/CHAOST2/image_normalize.ipynb`
61 | 
62 | **Abdominal CT**
63 | 
64 | 0. Download the [Synapse Multi-atlas Abdominal Segmentation dataset](https://www.synapse.org/#!Synapse:syn3193805/wiki/217789) and put the `/img` and `/label` folders under the `./data/SABS/` directory
65 | 
66 | 1. Intensity windowing
67 | 
68 | run `./data/SABS/intensity_normalization.ipynb` to apply the abdominal window.
69 | 
70 | 2. Crop irrelevant empty background and resample the images
71 | 
72 | run `./data/SABS/resampling_and_roi.ipynb`
73 | 
74 | **Shared steps**
75 | 
76 | 3. Build the class-slice index used to set up experiments
77 | 
78 | run `./data/class_slice_index_gen.ipynb`
79 | 
80 | ### Training
81 | 1. Download the pre-trained ResNet-101 weights ([vanilla version](https://download.pytorch.org/models/resnet101-63fe2227.pth) or [deeplabv3 version](https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth)) and put the checkpoint under the `./pretrained_model/hub/checkpoints` folder.
82 | 2. Run `training.py`
83 | 
84 | #### Note:
85 | The α and (w1, w2, w3) coefficients in `./models/alpmodule.py` are hard-coded and must be modified manually:
86 | For setting 1, ABD: α = 0.3, (w1, w2, w3) = (0.2, 0.8, 0.2); CMR: α = 0.3, (w1, w2, w3) = (0.1, 0.9, 0.1)
87 | For setting 2, ABD: α = 0.2, (w1, w2, w3) = (0.3, 0.6, 0.3)
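Since these coefficients live in the source rather than in the sacred config, setting them looks roughly like the snippet below. The names `ALPHA` and `W1, W2, W3` are hypothetical placeholders; check `./models/alpmodule.py` for the actual identifiers.

```python
# Hypothetical excerpt of ./models/alpmodule.py -- identifier names are assumed.
# Setting 1, ABD:
ALPHA = 0.3
W1, W2, W3 = 0.2, 0.8, 0.2
# Setting 2, ABD, would instead be:
# ALPHA = 0.2
# W1, W2, W3 = 0.3, 0.6, 0.3
```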
88 | 
89 | ### Testing
90 | Run `validation.py`
91 | 
92 | ### Acknowledgement
93 | This code is based on [SSL-ALPNet](https://arxiv.org/abs/2007.09886v2) (ECCV'20) by [Ouyang et al.](https://github.com/cheng-01037/Self-supervised-Fewshot-Medical-Image-Segmentation.git)
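A convenience note on the configuration file that follows: `config_ssl_upload.py` is a [sacred](https://github.com/IDSIA/sacred) experiment config, so any option defined in `cfg()` can in principle be overridden from the command line instead of editing the file, assuming `training.py` and `validation.py` run the `ex` experiment defined there. A sketch of such an invocation:

```
python training.py with gpu_id=0 dataset='SABS_Superpix' eval_fold=0 n_steps=50000
```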
--------------------------------------------------------------------------------
/config_ssl_upload.py:
--------------------------------------------------------------------------------
1 | """
2 | Experiment configuration file
3 | Extended from the config file of the original PANet repository
4 | """
5 | import os
6 | import re
7 | import glob
8 | import itertools
9 | 
10 | import sacred
11 | from sacred import Experiment
12 | from sacred.observers import FileStorageObserver
13 | from sacred.utils import apply_backspaces_and_linefeeds
14 | 
15 | from platform import node
16 | from datetime import datetime
17 | 
18 | sacred.SETTINGS['CONFIG']['READ_ONLY_CONFIG'] = False
19 | sacred.SETTINGS.CAPTURE_MODE = 'no'
20 | 
21 | ex = Experiment('mySSL')
22 | ex.captured_out_filter = apply_backspaces_and_linefeeds
23 | 
24 | source_folders = ['.', './dataloaders', './models', './util']
25 | sources_to_save = list(itertools.chain.from_iterable(
26 |     [glob.glob(f'{folder}/*.py') for folder in source_folders]))
27 | for source_file in sources_to_save:
28 |     ex.add_source_file(source_file)
29 | 
30 | @ex.config
31 | def cfg():
32 |     """Default configurations"""
33 |     seed = 1234
34 |     gpu_id = 0
35 |     mode = 'train' # for now only allows 'train'
36 |     num_workers = 4 # 0 for debugging
37 | 
38 |     dataset = 'CHAOST2_Superpix' # options: SABS_Superpix (abdominal CT), CHAOST2_Superpix (abdominal MRI)
39 |     use_coco_init = True # initialize the backbone with MS-COCO pre-trained weights (COCO does not contain medical images anyway)
40 | 
41 |     ### Training
42 |     n_steps = 50000 # 100100
43 |     batch_size = 1
44 |     lr_milestones = [ (ii + 1) * 1000 for ii in range(n_steps // 1000 - 1)]
45 |     lr_step_gamma = 0.95
46 |     ignore_label = 255
47 |     print_interval = 100
48 |     save_snapshot_every = 25000 # 25000
49 |     max_iters_per_load = 1000 # epoch size, interval for reloading the dataset
50 |     scan_per_load = -1 # number of 3D scans per load, for saving memory. If -1, load the entire dataset into memory
51 |     which_aug = 'sabs_aug' # standard data augmentation with intensity and geometric transforms
52 |     input_size = (256, 256)
53 |     min_fg_data = '1' # when training with manual annotations, the minimum number of foreground pixels in a single-class single slice. This empirically stabilizes the training process
54 |     label_sets = 1 # which group of labels to take as training (the rest are for testing)
55 |     exclude_cls_list = [2,3] # testing classes to be excluded in training. Set to [] if testing under setting 1
56 |     usealign = True # see vanilla PANet
57 |     use_wce = True
58 | 
59 |     ### Validation
60 |     z_margin = 0
61 |     eval_fold = 0 # which fold for 5-fold cross validation
62 |     support_idx = [-1] # indicating which scan is used as support in testing
63 |     val_wsize = 2 # L_H, L_W in testing
64 |     n_sup_part = 3 # number of chunks in testing
65 | 
66 |     # Network
67 |     modelname = 'dlfcn_res101' # ResNet-101 backbone from the torchvision FCN-DeepLab
68 |     clsname = "grid_proto" #
69 |     reload_model_path = '/home/kouguozhao/projects/YSX/few shot segmentation to/ssl-image medical -change/Self-supervised-Fewshot-Medical-Image-Segmentation-master(ts) (copy)/Self-supervised-Fewshot-Medical-Image-Segmentation-master/runs/mySSL__SABS_Superpix_sets_1_1shot/5/snapshots/50000.pth' # path for reloading a trained model (overrides ms-coco initialization)
70 |     proto_grid_size = 8 # L_H, L_W = (32, 32) / 8 = (4, 4) in training
71 |     feature_hw = [32, 32] # feature map size; should be coupled with the backbone in the future
72 | 
73 |     # SSL
74 |     superpix_scale = 'MIDDLE' # MIDDLE / LARGE
75 | 
76 |     model = {
77 |         'align': usealign,
78 |         'use_coco_init': use_coco_init,
79 |         'which_model': modelname,
80 |         'cls_name': clsname,
81 |         'proto_grid_size' : proto_grid_size,
82 |         'feature_hw': feature_hw,
83 |         'reload_model_path': reload_model_path
84 |     }
85 | 
86 |     task = {
87 |         'n_ways': 1,
88 |         'n_shots': 1,
89 |         'n_queries': 1,
90 |         'npart': n_sup_part
91 |     }
92 | 
93 |     optim_type = 'sgd'
94 |     optim = {
95 |         'lr': 1e-3,
96 |         'momentum': 0.9,
97 |         'weight_decay': 0.0005,
98 |     }
99 | 
100 |     exp_prefix = ''
101 | 
102 |     exp_str = '_'.join(
103 |         [exp_prefix]
104 |         + [dataset,]
105 |         + [f'sets_{label_sets}_{task["n_shots"]}shot'])
106 | 
107 |     path = {
108 |         'log_dir': './runs',
109 |         'SABS':{'data_dir': "./data/SABS/sabs_CT_normalized"
110 |         },
111 |         'C0':{'data_dir': "feed your dataset path here"
112 |         },
113 |         'CHAOST2':{'data_dir': "./data/CHAOST2/chaos_MR_T2_normalized/"
114 |         },
115 |         'SABS_Superpix':{'data_dir': "./data/SABS/sabs_CT_normalized"},
116 |         'C0_Superpix':{'data_dir': "feed your dataset path here"},
117 |         'CHAOST2_Superpix':{'data_dir': "./data/CHAOST2/chaos_MR_T2_normalized/"},
118 |     }
119 | 
120 | 
121 | @ex.config_hook
122 | def add_observer(config, command_name, logger):
123 |     """A hook function to add an observer"""
124 |     exp_name = f'{ex.path}_{config["exp_str"]}'
125 |     observer = FileStorageObserver.create(os.path.join(config['path']['log_dir'], exp_name))
126 |     ex.observers.append(observer)
127 |     return config
--------------------------------------------------------------------------------
/dataloaders/GenericSuperDatasetv2.py:
--------------------------------------------------------------------------------
1 | """
2 | Dataset for training with pseudolabels
3 | TODO:
4 | 1. Merge with the manually annotated dataset
5 | 2. superpixel_scale -> superpix_config, feed like a dict
6 | """
7 | import glob
8 | import numpy as np
9 | import dataloaders.augutils as myaug
10 | import torch
11 | import random
12 | import os
13 | import copy
14 | import platform
15 | import json
16 | import re
17 | from dataloaders.common import BaseDataset, Subset
18 | from dataloaders.dataset_utils import *
19 | from pdb import set_trace
20 | from util.utils import CircularList
21 | 
22 | class SuperpixelDataset(BaseDataset):
23 |     def __init__(self, which_dataset, base_dir, idx_split, mode, transforms, scan_per_load, num_rep = 2, min_fg = '', nsup = 1, fix_length = None, tile_z_dim = 3, exclude_list = [], superpix_scale = 'SMALL', **kwargs):
24 |         """
25 |         Pseudolabel dataset
26 |         Args:
27 |             which_dataset: name of the dataset to use
28 |             base_dir: directory of the dataset
29 |             idx_split: index of the data split, as we will do cross-validation
30 |             mode: 'train' or 'val'
31 | nsup: number of scans used as support. currently idle for superpixel dataset 32 | transforms: data transform (augmentation) function 33 | scan_per_load: loading a portion of the entire dataset, in case that the dataset is too large to fit into the memory. Set to -1 if loading the entire dataset at one time 34 | num_rep: Number of augmentation applied for a same pseudolabel 35 | tile_z_dim: number of identical slices to tile along channel dimension, for fitting 2D single-channel medical images into off-the-shelf networks designed for RGB natural images 36 | fix_length: fix the length of dataset 37 | exclude_list: Labels to be excluded 38 | superpix_scale: config of superpixels 39 | """ 40 | 41 | super(SuperpixelDataset, self).__init__(base_dir) 42 | 43 | self.img_modality = DATASET_INFO[which_dataset]['MODALITY'] 44 | self.sep = DATASET_INFO[which_dataset]['_SEP'] 45 | self.pseu_label_name = DATASET_INFO[which_dataset]['PSEU_LABEL_NAME'] 46 | self.real_label_name = DATASET_INFO[which_dataset]['REAL_LABEL_NAME'] 47 | 48 | self.transforms = transforms 49 | self.is_train = True if mode == 'train' else False 50 | assert mode == 'train' 51 | self.fix_length = fix_length 52 | self.nclass = len(self.pseu_label_name) 53 | self.num_rep = num_rep 54 | self.tile_z_dim = tile_z_dim 55 | 56 | # find scans in the data folder 57 | self.nsup = nsup 58 | self.base_dir = base_dir 59 | self.img_pids = [ re.findall('\d+', fid)[-1] for fid in glob.glob(self.base_dir + "/image_*.nii.gz") ] 60 | self.img_pids = CircularList(sorted( self.img_pids, key = lambda x: int(x))) 61 | 62 | # experiment configs 63 | self.exclude_lbs = exclude_list 64 | self.superpix_scale = superpix_scale 65 | if len(exclude_list) > 0: 66 | print(f'###### Dataset: the following classes has been excluded {exclude_list}######') 67 | self.idx_split = idx_split 68 | self.scan_ids = self.get_scanids(mode, idx_split) # patient ids of the entire fold 69 | self.min_fg = min_fg if isinstance(min_fg, str) else str(min_fg) 70 | self.scan_per_load = scan_per_load 71 | 72 | self.info_by_scan = None 73 | self.img_lb_fids = self.organize_sample_fids() # information of scans of the entire fold 74 | self.norm_func = get_normalize_op(self.img_modality, [ fid_pair['img_fid'] for _, fid_pair in self.img_lb_fids.items()]) 75 | 76 | if self.is_train: 77 | if scan_per_load > 0: # if the dataset is too large, only reload a subset in each sub-epoch 78 | self.pid_curr_load = np.random.choice( self.scan_ids, replace = False, size = self.scan_per_load) 79 | else: # load the entire set without a buffer 80 | self.pid_curr_load = self.scan_ids 81 | elif mode == 'val': 82 | self.pid_curr_load = self.scan_ids 83 | else: 84 | raise Exception 85 | 86 | self.actual_dataset = self.read_dataset() 87 | self.size = len(self.actual_dataset) 88 | self.overall_slice_by_cls = self.read_classfiles() 89 | 90 | print("###### Initial scans loaded: ######") 91 | print(self.pid_curr_load) 92 | 93 | def get_scanids(self, mode, idx_split): 94 | """ 95 | Load scans by train-test split 96 | leaving one additional scan as the support scan. 
if the last fold, taking scan 0 as the additional one 97 | Args: 98 | idx_split: index for spliting cross-validation folds 99 | """ 100 | val_ids = copy.deepcopy(self.img_pids[self.sep[idx_split]: self.sep[idx_split + 1] + self.nsup]) 101 | if mode == 'train': 102 | return [ ii for ii in self.img_pids if ii not in val_ids ] 103 | elif mode == 'val': 104 | return val_ids 105 | 106 | def reload_buffer(self): 107 | """ 108 | Reload a only portion of the entire dataset, if the dataset is too large 109 | 1. delete original buffer 110 | 2. update self.ids_this_batch 111 | 3. update other internel variables like __len__ 112 | """ 113 | if self.scan_per_load <= 0: 114 | print("We are not using the reload buffer, doing notiong") 115 | return -1 116 | 117 | del self.actual_dataset 118 | del self.info_by_scan 119 | 120 | self.pid_curr_load = np.random.choice( self.scan_ids, size = self.scan_per_load, replace = False ) 121 | self.actual_dataset = self.read_dataset() 122 | self.size = len(self.actual_dataset) 123 | self.update_subclass_lookup() 124 | print(f'Loader buffer reloaded with a new size of {self.size} slices') 125 | 126 | def organize_sample_fids(self): 127 | out_list = {} 128 | for curr_id in self.scan_ids: 129 | curr_dict = {} 130 | 131 | _img_fid = os.path.join(self.base_dir, f'image_{curr_id}.nii.gz') 132 | _lb_fid = os.path.join(self.base_dir, f'superpix-{self.superpix_scale}_{curr_id}.nii.gz') 133 | 134 | curr_dict["img_fid"] = _img_fid 135 | curr_dict["lbs_fid"] = _lb_fid 136 | out_list[str(curr_id)] = curr_dict 137 | return out_list 138 | 139 | def read_dataset(self): 140 | """ 141 | Read images into memory and store them in 2D 142 | Build tables for the position of an individual 2D slice in the entire dataset 143 | """ 144 | out_list = [] 145 | self.scan_z_idx = {} 146 | self.info_by_scan = {} # meta data of each scan 147 | glb_idx = 0 # global index of a certain slice in a certain scan in entire dataset 148 | 149 | for scan_id, itm in self.img_lb_fids.items(): 150 | if scan_id not in self.pid_curr_load: 151 | continue 152 | 153 | img, _info = read_nii_bysitk(itm["img_fid"], peel_info = True) # get the meta information out 154 | img = img.transpose(1,2,0) 155 | self.info_by_scan[scan_id] = _info 156 | 157 | img = np.float32(img) 158 | img = self.norm_func(img) 159 | 160 | self.scan_z_idx[scan_id] = [-1 for _ in range(img.shape[-1])] 161 | 162 | lb = read_nii_bysitk(itm["lbs_fid"]) 163 | lb = lb.transpose(1,2,0) 164 | lb = np.int32(lb) 165 | 166 | img = img[:256, :256, :] 167 | lb = lb[:256, :256, :] 168 | 169 | # format of slices: [axial_H x axial_W x Z] 170 | 171 | assert img.shape[-1] == lb.shape[-1] 172 | base_idx = img.shape[-1] // 2 # index of the middle slice 173 | 174 | # re-organize 3D images into 2D slices and record essential information for each slice 175 | out_list.append( {"img": img[..., 0: 1], 176 | "lb":lb[..., 0: 0 + 1], 177 | "sup_max_cls": lb[..., 0: 0 + 1].max(), 178 | "is_start": True, 179 | "is_end": False, 180 | "nframe": img.shape[-1], 181 | "scan_id": scan_id, 182 | "z_id":0}) 183 | 184 | self.scan_z_idx[scan_id][0] = glb_idx 185 | glb_idx += 1 186 | 187 | for ii in range(1, img.shape[-1] - 1): 188 | out_list.append( {"img": img[..., ii: ii + 1], 189 | "lb":lb[..., ii: ii + 1], 190 | "is_start": False, 191 | "is_end": False, 192 | "sup_max_cls": lb[..., ii: ii + 1].max(), 193 | "nframe": -1, 194 | "scan_id": scan_id, 195 | "z_id": ii 196 | }) 197 | self.scan_z_idx[scan_id][ii] = glb_idx 198 | glb_idx += 1 199 | 200 | ii += 1 # last slice of a 3D volume 201 | 
out_list.append( {"img": img[..., ii: ii + 1], 202 | "lb":lb[..., ii: ii+ 1], 203 | "is_start": False, 204 | "is_end": True, 205 | "sup_max_cls": lb[..., ii: ii + 1].max(), 206 | "nframe": -1, 207 | "scan_id": scan_id, 208 | "z_id": ii 209 | }) 210 | 211 | self.scan_z_idx[scan_id][ii] = glb_idx 212 | glb_idx += 1 213 | 214 | return out_list 215 | 216 | def read_classfiles(self): 217 | """ 218 | Load the scan-slice-class indexing file 219 | """ 220 | with open( os.path.join(self.base_dir, f'classmap_{self.min_fg}.json') , 'r' ) as fopen: 221 | cls_map = json.load( fopen) 222 | fopen.close() 223 | 224 | with open( os.path.join(self.base_dir, 'classmap_1.json') , 'r' ) as fopen: 225 | self.tp1_cls_map = json.load( fopen) 226 | fopen.close() 227 | 228 | return cls_map 229 | 230 | def supcls_pick_binarize(self, super_map, sup_max_cls, bi_val = None): 231 | """ 232 | pick up a certain super-pixel class or multiple classes, and binarize it into segmentation target 233 | Args: 234 | super_map: super-pixel map 235 | bi_val: if given, pick up a certain superpixel. Otherwise, draw a random one 236 | sup_max_cls: max index of superpixel for avoiding overshooting when selecting superpixel 237 | 238 | """ 239 | if bi_val == None: 240 | bi_val = int(torch.randint(low = 1, high = int(sup_max_cls), size = (1,))) 241 | 242 | return np.float32(super_map == bi_val) 243 | 244 | 245 | def __getitem__(self, index): 246 | index = index % len(self.actual_dataset) 247 | curr_dict = self.actual_dataset[index] 248 | sup_max_cls = curr_dict['sup_max_cls'] 249 | if sup_max_cls < 1: 250 | return self.__getitem__(index + 1) 251 | 252 | image_t = curr_dict["img"] 253 | label_raw = curr_dict["lb"] 254 | 255 | for _ex_cls in self.exclude_lbs: 256 | if curr_dict["z_id"] in self.tp1_cls_map[self.real_label_name[_ex_cls]][curr_dict["scan_id"]]: # if using setting 1, this slice need to be excluded since it contains label which is supposed to be unseen 257 | return self.__getitem__(torch.randint(low = 0, high = self.__len__() - 1, size = (1,))) 258 | 259 | label_t = self.supcls_pick_binarize(label_raw, sup_max_cls) 260 | 261 | pair_buffer = [] 262 | 263 | comp = np.concatenate( [curr_dict["img"], label_t], axis = -1 ) 264 | 265 | for ii in range(self.num_rep): 266 | img, lb = self.transforms(comp, c_img = 1, c_label = 1, nclass = self.nclass, is_train = True, use_onehot = False) 267 | 268 | img = torch.from_numpy( np.transpose( img, (2, 0, 1)) ) 269 | lb = torch.from_numpy( lb.squeeze(-1)) 270 | 271 | if self.tile_z_dim: 272 | img = img.repeat( [ self.tile_z_dim, 1, 1] ) 273 | assert img.ndimension() == 3, f'actual dim {img.ndimension()}' 274 | 275 | is_start = curr_dict["is_start"] 276 | is_end = curr_dict["is_end"] 277 | nframe = np.int32(curr_dict["nframe"]) 278 | scan_id = curr_dict["scan_id"] 279 | z_id = curr_dict["z_id"] 280 | 281 | sample = {"image": img, 282 | "label":lb, 283 | "is_start": is_start, 284 | "is_end": is_end, 285 | "nframe": nframe, 286 | "scan_id": scan_id, 287 | "z_id": z_id 288 | } 289 | 290 | # Add auxiliary attributes 291 | if self.aux_attrib is not None: 292 | for key_prefix in self.aux_attrib: 293 | # Process the data sample, create new attributes and save them in a dictionary 294 | aux_attrib_val = self.aux_attrib[key_prefix](sample, **self.aux_attrib_args[key_prefix]) 295 | for key_suffix in aux_attrib_val: 296 | # one function may create multiple attributes, so we need suffix to distinguish them 297 | sample[key_prefix + '_' + key_suffix] = aux_attrib_val[key_suffix] 298 | 
pair_buffer.append(sample) 299 | 300 | support_images = [] 301 | support_mask = [] 302 | support_class = [] 303 | 304 | query_images = [] 305 | query_labels = [] 306 | query_class = [] 307 | 308 | for idx, itm in enumerate(pair_buffer): 309 | if idx % 2 == 0: 310 | support_images.append(itm["image"]) 311 | support_class.append(1) # pseudolabel class 312 | support_mask.append( self.getMaskMedImg( itm["label"], 1, [1] )) 313 | else: 314 | query_images.append(itm["image"]) 315 | query_class.append(1) 316 | query_labels.append( itm["label"]) 317 | 318 | return {'class_ids': [support_class], 319 | 'support_images': [support_images], # 320 | 'support_mask': [support_mask], 321 | 'query_images': query_images, # 322 | 'query_labels': query_labels, 323 | } 324 | 325 | 326 | def __len__(self): 327 | """ 328 | copy-paste from basic naive dataset configuration 329 | """ 330 | if self.fix_length != None: 331 | assert self.fix_length >= len(self.actual_dataset) 332 | return self.fix_length 333 | else: 334 | return len(self.actual_dataset) 335 | 336 | def getMaskMedImg(self, label, class_id, class_ids): 337 | """ 338 | Generate FG/BG mask from the segmentation mask 339 | 340 | Args: 341 | label: semantic mask 342 | class_id: semantic class of interest 343 | class_ids: all class id in this episode 344 | """ 345 | fg_mask = torch.where(label == class_id, 346 | torch.ones_like(label), torch.zeros_like(label)) 347 | bg_mask = torch.where(label != class_id, 348 | torch.ones_like(label), torch.zeros_like(label)) 349 | for class_id in class_ids: 350 | bg_mask[label == class_id] = 0 351 | 352 | return {'fg_mask': fg_mask, 353 | 'bg_mask': bg_mask} 354 | -------------------------------------------------------------------------------- /dataloaders/ManualAnnoDatasetv2.py: -------------------------------------------------------------------------------- 1 | """ 2 | Manually labeled dataset 3 | TODO: 4 | 1. Merge with superpixel dataset 5 | """ 6 | import glob 7 | import numpy as np 8 | import dataloaders.augutils as myaug 9 | import torch 10 | import random 11 | import os 12 | import copy 13 | import platform 14 | import json 15 | import re 16 | from dataloaders.common import BaseDataset, Subset 17 | # from common import BaseDataset, Subset 18 | from dataloaders.dataset_utils import* 19 | from pdb import set_trace 20 | from util.utils import CircularList 21 | 22 | class ManualAnnoDataset(BaseDataset): 23 | def __init__(self, which_dataset, base_dir, idx_split, mode, transforms, scan_per_load, min_fg = '', fix_length = None, tile_z_dim = 3, nsup = 1, exclude_list = [], extern_normalize_func = None,**kwargs): 24 | """ 25 | Manually labeled dataset 26 | Args: 27 | which_dataset: name of the dataset to use 28 | base_dir: directory of dataset 29 | idx_split: index of data split as we will do cross validation 30 | mode: 'train', 'val'. 31 | transforms: data transform (augmentation) function 32 | min_fg: minimum number of positive pixels in a 2D slice, mainly for stablize training when trained on manually labeled dataset 33 | scan_per_load: loading a portion of the entire dataset, in case that the dataset is too large to fit into the memory. 
Set to -1 if loading the entire dataset at one time 34 | tile_z_dim: number of identical slices to tile along channel dimension, for fitting 2D single-channel medical images into off-the-shelf networks designed for RGB natural images 35 | nsup: number of support scans 36 | fix_length: fix the length of dataset 37 | exclude_list: Labels to be excluded 38 | extern_normalize_function: normalization function used for data pre-processing 39 | """ 40 | super(ManualAnnoDataset, self).__init__(base_dir) 41 | self.img_modality = DATASET_INFO[which_dataset]['MODALITY'] 42 | self.sep = DATASET_INFO[which_dataset]['_SEP'] 43 | self.label_name = DATASET_INFO[which_dataset]['REAL_LABEL_NAME'] 44 | self.transforms = transforms 45 | self.is_train = True if mode == 'train' else False 46 | self.phase = mode 47 | self.fix_length = fix_length 48 | self.all_label_names = self.label_name 49 | self.nclass = len(self.label_name) 50 | self.tile_z_dim = tile_z_dim 51 | self.base_dir = base_dir 52 | self.nsup = nsup 53 | self.img_pids = [ re.findall('\d+', fid)[-1] for fid in glob.glob(self.base_dir + "/image_*.nii.gz") ] 54 | self.img_pids = CircularList(sorted( self.img_pids, key = lambda x: int(x))) # make it circular for the ease of spliting folds 55 | 56 | self.exclude_lbs = exclude_list 57 | if len(exclude_list) > 0: 58 | print(f'###### Dataset: the following classes has been excluded {exclude_list}######') 59 | 60 | self.idx_split = idx_split 61 | self.scan_ids = self.get_scanids(mode, idx_split) # patient ids of the entire fold 62 | self.min_fg = min_fg if isinstance(min_fg, str) else str(min_fg) 63 | 64 | self.scan_per_load = scan_per_load 65 | 66 | self.info_by_scan = None 67 | self.img_lb_fids = self.organize_sample_fids() # information of scans of the entire fold 68 | 69 | if extern_normalize_func is not None: # helps to keep consistent between training and testing dataset. 70 | self.norm_func = extern_normalize_func 71 | print(f'###### Dataset: using external normalization statistics ######') 72 | else: 73 | self.norm_func = get_normalize_op(self.img_modality, [ fid_pair['img_fid'] for _, fid_pair in self.img_lb_fids.items()]) 74 | print(f'###### Dataset: using normalization statistics calculated from loaded data ######') 75 | 76 | if self.is_train: 77 | if scan_per_load > 0: # buffer needed 78 | self.pid_curr_load = np.random.choice( self.scan_ids, replace = False, size = self.scan_per_load) 79 | else: # load the entire set without a buffer 80 | self.pid_curr_load = self.scan_ids 81 | elif mode == 'val': 82 | self.pid_curr_load = self.scan_ids 83 | self.potential_support_sid = [] 84 | else: 85 | raise Exception 86 | self.actual_dataset = self.read_dataset() 87 | self.size = len(self.actual_dataset) 88 | self.overall_slice_by_cls = self.read_classfiles() 89 | self.update_subclass_lookup() 90 | 91 | def get_scanids(self, mode, idx_split): 92 | val_ids = copy.deepcopy(self.img_pids[self.sep[idx_split]: self.sep[idx_split + 1] + self.nsup]) 93 | self.potential_support_sid = val_ids[-self.nsup:] # this is actual file scan id, not index 94 | if mode == 'train': 95 | return [ ii for ii in self.img_pids if ii not in val_ids ] 96 | elif mode == 'val': 97 | return val_ids 98 | 99 | def reload_buffer(self): 100 | """ 101 | Reload a portion of the entire dataset, if the dataset is too large 102 | 1. delete original buffer 103 | 2. update self.ids_this_batch 104 | 3. 
update other internel variables like __len__ 105 | """ 106 | if self.scan_per_load <= 0: 107 | print("We are not using the reload buffer, doing notiong") 108 | return -1 109 | 110 | del self.actual_dataset 111 | del self.info_by_scan 112 | self.pid_curr_load = np.random.choice( self.scan_ids, size = self.scan_per_load, replace = False ) 113 | self.actual_dataset = self.read_dataset() 114 | self.size = len(self.actual_dataset) 115 | self.update_subclass_lookup() 116 | print(f'Loader buffer reloaded with a new size of {self.size} slices') 117 | 118 | def organize_sample_fids(self): 119 | out_list = {} 120 | for curr_id in self.scan_ids: 121 | curr_dict = {} 122 | 123 | _img_fid = os.path.join(self.base_dir, f'image_{curr_id}.nii.gz') 124 | _lb_fid = os.path.join(self.base_dir, f'label_{curr_id}.nii.gz') 125 | 126 | curr_dict["img_fid"] = _img_fid 127 | curr_dict["lbs_fid"] = _lb_fid 128 | out_list[str(curr_id)] = curr_dict 129 | return out_list 130 | 131 | def read_dataset(self): 132 | """ 133 | Build index pointers to individual slices 134 | Also keep a look-up table from scan_id, slice to index 135 | """ 136 | out_list = [] 137 | self.scan_z_idx = {} 138 | self.info_by_scan = {} # meta data of each scan 139 | glb_idx = 0 # global index of a certain slice in a certain scan in entire dataset 140 | 141 | for scan_id, itm in self.img_lb_fids.items(): 142 | if scan_id not in self.pid_curr_load: 143 | continue 144 | 145 | img, _info = read_nii_bysitk(itm["img_fid"], peel_info = True) # get the meta information out 146 | 147 | img = img.transpose(1,2,0) 148 | 149 | self.info_by_scan[scan_id] = _info 150 | 151 | img = np.float32(img) 152 | img = self.norm_func(img) 153 | 154 | self.scan_z_idx[scan_id] = [-1 for _ in range(img.shape[-1])] 155 | 156 | lb = read_nii_bysitk(itm["lbs_fid"]) 157 | lb = lb.transpose(1,2,0) 158 | 159 | lb = np.float32(lb) 160 | 161 | img = img[:256, :256, :] # FIXME a bug in shape from the pre-processing code 162 | lb = lb[:256, :256, :] 163 | 164 | assert img.shape[-1] == lb.shape[-1] 165 | base_idx = img.shape[-1] // 2 # index of the middle slice 166 | 167 | # write the beginning frame 168 | out_list.append( {"img": img[..., 0: 1], 169 | "lb":lb[..., 0: 0 + 1], 170 | "is_start": True, 171 | "is_end": False, 172 | "nframe": img.shape[-1], 173 | "scan_id": scan_id, 174 | "z_id":0}) 175 | 176 | self.scan_z_idx[scan_id][0] = glb_idx 177 | glb_idx += 1 178 | 179 | for ii in range(1, img.shape[-1] - 1): 180 | out_list.append( {"img": img[..., ii: ii + 1], 181 | "lb":lb[..., ii: ii + 1], 182 | "is_start": False, 183 | "is_end": False, 184 | "nframe": -1, 185 | "scan_id": scan_id, 186 | "z_id": ii 187 | }) 188 | self.scan_z_idx[scan_id][ii] = glb_idx 189 | glb_idx += 1 190 | 191 | ii += 1 # last frame, note the is_end flag 192 | out_list.append( {"img": img[..., ii: ii + 1], 193 | "lb":lb[..., ii: ii+ 1], 194 | "is_start": False, 195 | "is_end": True, 196 | "nframe": -1, 197 | "scan_id": scan_id, 198 | "z_id": ii 199 | }) 200 | 201 | self.scan_z_idx[scan_id][ii] = glb_idx 202 | glb_idx += 1 203 | 204 | return out_list 205 | 206 | def read_classfiles(self): 207 | with open( os.path.join(self.base_dir, f'classmap_{self.min_fg}.json') , 'r' ) as fopen: 208 | cls_map = json.load( fopen) 209 | fopen.close() 210 | 211 | with open( os.path.join(self.base_dir, 'classmap_1.json') , 'r' ) as fopen: 212 | self.tp1_cls_map = json.load( fopen) 213 | fopen.close() 214 | 215 | return cls_map 216 | 217 | def __getitem__(self, index): 218 | index = index % len(self.actual_dataset) 219 | 
curr_dict = self.actual_dataset[index] 220 | if self.is_train: 221 | if len(self.exclude_lbs) > 0: 222 | for _ex_cls in self.exclude_lbs: 223 | if curr_dict["z_id"] in self.tp1_cls_map[self.label_name[_ex_cls]][curr_dict["scan_id"]]: # this slice need to be excluded since it contains label which is supposed to be unseen 224 | return self.__getitem__(index + torch.randint(low = 0, high = self.__len__() - 1, size = (1,))) 225 | 226 | comp = np.concatenate( [curr_dict["img"], curr_dict["lb"]], axis = -1 ) 227 | img, lb = self.transforms(comp, c_img = 1, c_label = 1, nclass = self.nclass, use_onehot = False) 228 | 229 | else: 230 | img = curr_dict['img'] 231 | lb = curr_dict['lb'] 232 | 233 | img = np.float32(img) 234 | lb = np.float32(lb).squeeze(-1) # NOTE: to be suitable for the PANet structure 235 | 236 | img = torch.from_numpy( np.transpose(img, (2, 0, 1)) ) 237 | lb = torch.from_numpy( lb) 238 | 239 | if self.tile_z_dim: 240 | img = img.repeat( [ self.tile_z_dim, 1, 1] ) 241 | assert img.ndimension() == 3, f'actual dim {img.ndimension()}' 242 | 243 | is_start = curr_dict["is_start"] 244 | is_end = curr_dict["is_end"] 245 | nframe = np.int32(curr_dict["nframe"]) 246 | scan_id = curr_dict["scan_id"] 247 | z_id = curr_dict["z_id"] 248 | 249 | sample = {"image": img, 250 | "label":lb, 251 | "is_start": is_start, 252 | "is_end": is_end, 253 | "nframe": nframe, 254 | "scan_id": scan_id, 255 | "z_id": z_id 256 | } 257 | # Add auxiliary attributes 258 | if self.aux_attrib is not None: 259 | for key_prefix in self.aux_attrib: 260 | # Process the data sample, create new attributes and save them in a dictionary 261 | aux_attrib_val = self.aux_attrib[key_prefix](sample, **self.aux_attrib_args[key_prefix]) 262 | for key_suffix in aux_attrib_val: 263 | # one function may create multiple attributes, so we need suffix to distinguish them 264 | sample[key_prefix + '_' + key_suffix] = aux_attrib_val[key_suffix] 265 | 266 | return sample 267 | 268 | def __len__(self): 269 | """ 270 | copy-paste from basic naive dataset configuration 271 | """ 272 | if self.fix_length != None: 273 | assert self.fix_length >= len(self.actual_dataset) 274 | return self.fix_length 275 | else: 276 | return len(self.actual_dataset) 277 | 278 | def update_subclass_lookup(self): 279 | """ 280 | Updating the class-slice indexing list 281 | Args: 282 | [internal] overall_slice_by_cls: 283 | { 284 | class1: {pid1: [slice1, slice2, ....], 285 | pid2: [slice1, slice2]}, 286 | ...} 287 | class2: 288 | ... 289 | } 290 | out[internal]: 291 | { 292 | class1: [ idx1, idx2, ... ], 293 | class2: [ idx1, idx2, ... ], 294 | ... 295 | } 296 | 297 | """ 298 | # delete previous ones if any 299 | assert self.overall_slice_by_cls is not None 300 | 301 | if not hasattr(self, 'idx_by_class'): 302 | self.idx_by_class = {} 303 | # filter the new one given the actual list 304 | for cls in self.label_name: 305 | if cls not in self.idx_by_class.keys(): 306 | self.idx_by_class[cls] = [] 307 | else: 308 | del self.idx_by_class[cls][:] 309 | for cls, dict_by_pid in self.overall_slice_by_cls.items(): 310 | for pid, slice_list in dict_by_pid.items(): 311 | if pid not in self.pid_curr_load: 312 | continue 313 | self.idx_by_class[cls] += [ self.scan_z_idx[pid][_sli] for _sli in slice_list ] 314 | print("###### index-by-class table has been reloaded ######") 315 | 316 | def getMaskMedImg(self, label, class_id, class_ids): 317 | """ 318 | Generate FG/BG mask from the segmentation mask. 
Used when getting the support 319 | """ 320 | # Dense Mask 321 | fg_mask = torch.where(label == class_id, 322 | torch.ones_like(label), torch.zeros_like(label)) 323 | bg_mask = torch.where(label != class_id, 324 | torch.ones_like(label), torch.zeros_like(label)) 325 | for class_id in class_ids: 326 | bg_mask[label == class_id] = 0 327 | 328 | return {'fg_mask': fg_mask, 329 | 'bg_mask': bg_mask} 330 | 331 | def subsets(self, sub_args_lst=None): 332 | """ 333 | Override base-class subset method 334 | Create subsets by scan_ids 335 | 336 | output: list [[] , ] 337 | """ 338 | 339 | if sub_args_lst is not None: 340 | subsets = [] 341 | ii = 0 342 | for cls_name, index_list in self.idx_by_class.items(): 343 | subsets.append( Subset(dataset = self, indices = index_list, sub_attrib_args = sub_args_lst[ii]) ) 344 | ii += 1 345 | else: 346 | subsets = [Subset(dataset=self, indices=index_list) for _, index_list in self.idx_by_class.items()] 347 | return subsets 348 | 349 | def get_support(self, curr_class: int, class_idx: list, scan_idx: list, npart: int): 350 | """ 351 | getting (probably multi-shot) support set for evaluation 352 | sample from 50% (1shot) or 20 35 50 65 80 (5shot) 353 | Args: 354 | curr_cls: current class to segment, starts from 1 355 | class_idx: a list of all foreground class in nways, starts from 1 356 | npart: how may chunks used to split the support 357 | scan_idx: a list, indicating the current **i_th** (note this is idx not pid) training scan 358 | being served as support, in self.pid_curr_load 359 | """ 360 | assert npart % 2 == 1 361 | assert curr_class != 0; assert 0 not in class_idx 362 | assert not self.is_train 363 | 364 | self.potential_support_sid = [self.pid_curr_load[ii] for ii in scan_idx ] 365 | print(f'###### Using {len(scan_idx)} shot evaluation!') 366 | 367 | if npart == 1: 368 | pcts = [0.5] 369 | else: 370 | half_part = 1 / (npart * 2) 371 | part_interval = (1.0 - 1.0 / npart) / (npart - 1) 372 | pcts = [ half_part + part_interval * ii for ii in range(npart) ] 373 | 374 | print(f'###### Parts percentage: {pcts} ######') 375 | 376 | out_buffer = [] # [{scanid, img, lb}] 377 | for _part in range(npart): 378 | concat_buffer = [] # for each fold do a concat in image and mask in batch dimension 379 | for scan_order in scan_idx: 380 | _scan_id = self.pid_curr_load[ scan_order ] 381 | print(f'Using scan {_scan_id} as support!') 382 | 383 | # for _pc in pcts: 384 | _zlist = self.tp1_cls_map[self.label_name[curr_class]][_scan_id] # list of indices 385 | _zid = _zlist[int(pcts[_part] * len(_zlist))] 386 | _glb_idx = self.scan_z_idx[_scan_id][_zid] 387 | 388 | # almost copy-paste __getitem__ but no augmentation 389 | curr_dict = self.actual_dataset[_glb_idx] 390 | img = curr_dict['img'] 391 | lb = curr_dict['lb'] 392 | 393 | img = np.float32(img) 394 | lb = np.float32(lb).squeeze(-1) # NOTE: to be suitable for the PANet structure 395 | 396 | img = torch.from_numpy( np.transpose(img, (2, 0, 1)) ) 397 | lb = torch.from_numpy( lb ) 398 | 399 | if self.tile_z_dim: 400 | img = img.repeat( [ self.tile_z_dim, 1, 1] ) 401 | assert img.ndimension() == 3, f'actual dim {img.ndimension()}' 402 | 403 | is_start = curr_dict["is_start"] 404 | is_end = curr_dict["is_end"] 405 | nframe = np.int32(curr_dict["nframe"]) 406 | scan_id = curr_dict["scan_id"] 407 | z_id = curr_dict["z_id"] 408 | 409 | sample = {"image": img, 410 | "label":lb, 411 | "is_start": is_start, 412 | "inst": None, 413 | "scribble": None, 414 | "is_end": is_end, 415 | "nframe": nframe, 416 | "scan_id": scan_id, 417 
| "z_id": z_id 418 | } 419 | 420 | concat_buffer.append(sample) 421 | out_buffer.append({ 422 | "image": torch.stack([itm["image"] for itm in concat_buffer], dim = 0), 423 | "label": torch.stack([itm["label"] for itm in concat_buffer], dim = 0), 424 | 425 | }) 426 | 427 | # do the concat, and add to output_buffer 428 | 429 | # post-processing, including keeping the foreground and suppressing background. 430 | support_images = [] 431 | support_mask = [] 432 | support_class = [] 433 | for itm in out_buffer: 434 | support_images.append(itm["image"]) 435 | support_class.append(curr_class) 436 | support_mask.append( self.getMaskMedImg( itm["label"], curr_class, class_idx )) 437 | 438 | return {'class_ids': [support_class], 439 | 'support_images': [support_images], # 440 | 'support_mask': [support_mask], 441 | } 442 | 443 | -------------------------------------------------------------------------------- /dataloaders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tntek/DSPNet/a4b71c2ab1229221584ccb2a5100bf66a7d1ef03/dataloaders/__init__.py -------------------------------------------------------------------------------- /dataloaders/augutils.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Utilities for augmentation. Partly credit to Dr. Jo Schlemper 3 | ''' 4 | from os.path import join 5 | 6 | import torch 7 | import numpy as np 8 | import torchvision.transforms as deftfx 9 | import dataloaders.image_transforms as myit 10 | import copy 11 | 12 | sabs_aug = { 13 | # turn flipping off as medical data has fixed orientations 14 | 'flip' : { 'v':False, 'h':False, 't': False, 'p':0.25 }, 15 | 'affine' : { 16 | 'rotate':5, 17 | 'shift':(5,5), 18 | 'shear':5, 19 | 'scale':(0.9, 1.2), 20 | }, 21 | 'elastic' : {'alpha':10,'sigma':5}, 22 | 'patch': 256, 23 | 'reduce_2d': True, 24 | 'gamma_range': (0.5, 1.5) 25 | } 26 | 27 | sabs_augv3 = { 28 | 'flip' : { 'v':False, 'h':False, 't': False, 'p':0.25 }, 29 | 'affine' : { 30 | 'rotate':30, 31 | 'shift':(30,30), 32 | 'shear':30, 33 | 'scale':(0.8, 1.3), 34 | }, 35 | 'elastic' : {'alpha':20,'sigma':5}, 36 | 'patch': 256, 37 | 'reduce_2d': True, 38 | 'gamma_range': (0.2, 1.8) 39 | } 40 | 41 | augs = { 42 | 'sabs_aug': sabs_aug, 43 | 'aug_v3': sabs_augv3, # more aggresive 44 | } 45 | 46 | 47 | def get_geometric_transformer(aug, order=3): 48 | """order: interpolation degree. 
Select order=0 for augmenting segmentation """ 49 | affine = aug['aug'].get('affine', 0) 50 | alpha = aug['aug'].get('elastic',{'alpha': 0})['alpha'] 51 | sigma = aug['aug'].get('elastic',{'sigma': 0})['sigma'] 52 | flip = aug['aug'].get('flip', {'v': True, 'h': True, 't': True, 'p':0.125}) 53 | 54 | tfx = [] 55 | if 'flip' in aug['aug']: 56 | tfx.append(myit.RandomFlip3D(**flip)) 57 | 58 | if 'affine' in aug['aug']: 59 | tfx.append(myit.RandomAffine(affine.get('rotate'), 60 | affine.get('shift'), 61 | affine.get('shear'), 62 | affine.get('scale'), 63 | affine.get('scale_iso',True), 64 | order=order)) 65 | 66 | if 'elastic' in aug['aug']: 67 | tfx.append(myit.ElasticTransform(alpha, sigma)) 68 | input_transform = deftfx.Compose(tfx) 69 | return input_transform 70 | 71 | def get_intensity_transformer(aug): 72 | """some basic intensity transforms""" 73 | 74 | def gamma_tansform(img): 75 | gamma_range = aug['aug']['gamma_range'] 76 | if isinstance(gamma_range, tuple): 77 | gamma = np.random.rand() * (gamma_range[1] - gamma_range[0]) + gamma_range[0] 78 | cmin = img.min() 79 | irange = (img.max() - cmin + 1e-5) 80 | 81 | img = img - cmin + 1e-5 82 | img = irange * np.power(img * 1.0 / irange, gamma) 83 | img = img + cmin 84 | 85 | elif gamma_range == False: 86 | pass 87 | else: 88 | raise ValueError("Cannot identify gamma transform range {}".format(gamma_range)) 89 | return img 90 | 91 | return gamma_tansform 92 | 93 | def transform_with_label(aug): 94 | """ 95 | Doing image geometric transform 96 | Proposed image to have the following configurations 97 | [H x W x C + CL] 98 | Where CL is the number of channels for the label. It is NOT in one-hot form 99 | """ 100 | 101 | geometric_tfx = get_geometric_transformer(aug) 102 | intensity_tfx = get_intensity_transformer(aug) 103 | 104 | def transform(comp, c_label, c_img, use_onehot, nclass, **kwargs): 105 | """ 106 | Args 107 | comp: a numpy array with shape [H x W x C + c_label] 108 | c_label: number of channels for a compact label. Note that the current version only supports 1 slice (H x W x 1) 109 | nc_onehot: -1 for not using one-hot representation of mask. otherwise, specify number of classes in the label 110 | 111 | """ 112 | comp = copy.deepcopy(comp) 113 | if (use_onehot is True) and (c_label != 1): 114 | raise NotImplementedError("Only allow compact label, also the label can only be 2d") 115 | assert c_img + 1 == comp.shape[-1], "only allow single slice 2D label" 116 | 117 | # geometric transform 118 | _label = comp[..., c_img ] 119 | _h_label = np.float32(np.arange( nclass ) == (_label[..., None]) ) 120 | comp = np.concatenate( [comp[..., :c_img ], _h_label], -1 ) 121 | comp = geometric_tfx(comp) 122 | # round one_hot labels to 0 or 1 123 | t_label_h = comp[..., c_img : ] 124 | t_label_h = np.rint(t_label_h) 125 | assert t_label_h.max() <= 1 126 | t_img = comp[..., 0 : c_img ] 127 | 128 | # intensity transform 129 | t_img = intensity_tfx(t_img) 130 | 131 | if use_onehot is True: 132 | t_label = t_label_h 133 | else: 134 | t_label = np.expand_dims(np.argmax(t_label_h, axis = -1), -1) 135 | return t_img, t_label 136 | 137 | return transform 138 | 139 | -------------------------------------------------------------------------------- /dataloaders/common.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dataset classes for common uses 3 | Extended from vanilla PANet code by Wang et al. 
4 | """ 5 | import random 6 | import torch 7 | 8 | from torch.utils.data import Dataset 9 | 10 | class BaseDataset(Dataset): 11 | """ 12 | Base Dataset 13 | Args: 14 | base_dir: 15 | dataset directory 16 | """ 17 | def __init__(self, base_dir): 18 | self._base_dir = base_dir 19 | self.aux_attrib = {} 20 | self.aux_attrib_args = {} 21 | self.ids = [] # must be overloaded in subclass 22 | 23 | def add_attrib(self, key, func, func_args): 24 | """ 25 | Add attribute to the data sample dict 26 | 27 | Args: 28 | key: 29 | key in the data sample dict for the new attribute 30 | e.g. sample['click_map'], sample['depth_map'] 31 | func: 32 | function to process a data sample and create an attribute (e.g. user clicks) 33 | func_args: 34 | extra arguments to pass, expected a dict 35 | """ 36 | if key in self.aux_attrib: 37 | raise KeyError("Attribute '{0}' already exists, please use 'set_attrib'.".format(key)) 38 | else: 39 | self.set_attrib(key, func, func_args) 40 | 41 | def set_attrib(self, key, func, func_args): 42 | """ 43 | Set attribute in the data sample dict 44 | 45 | Args: 46 | key: 47 | key in the data sample dict for the new attribute 48 | e.g. sample['click_map'], sample['depth_map'] 49 | func: 50 | function to process a data sample and create an attribute (e.g. user clicks) 51 | func_args: 52 | extra arguments to pass, expected a dict 53 | """ 54 | self.aux_attrib[key] = func 55 | self.aux_attrib_args[key] = func_args 56 | 57 | def del_attrib(self, key): 58 | """ 59 | Remove attribute in the data sample dict 60 | 61 | Args: 62 | key: 63 | key in the data sample dict 64 | """ 65 | self.aux_attrib.pop(key) 66 | self.aux_attrib_args.pop(key) 67 | 68 | def subsets(self, sub_ids, sub_args_lst=None): 69 | """ 70 | Create subsets by ids 71 | 72 | Args: 73 | sub_ids: 74 | a sequence of sequences, each sequence contains data ids for one subset 75 | sub_args_lst: 76 | a list of args for some subset-specific auxiliary attribute function 77 | """ 78 | 79 | indices = [[self.ids.index(id_) for id_ in ids] for ids in sub_ids] 80 | if sub_args_lst is not None: 81 | subsets = [Subset(dataset=self, indices=index, sub_attrib_args=args) 82 | for index, args in zip(indices, sub_args_lst)] 83 | else: 84 | subsets = [Subset(dataset=self, indices=index) for index in indices] 85 | return subsets 86 | 87 | def __len__(self): 88 | pass 89 | 90 | def __getitem__(self, idx): 91 | pass 92 | 93 | 94 | class ReloadPairedDataset(Dataset): 95 | """ 96 | Make pairs of data from dataset 97 | Eable only loading part of the entire data in each epoach and then reload to the next part 98 | Args: 99 | datasets: 100 | source datasets, expect a list of Dataset. 101 | Each dataset indices a certain class. It contains a list of all z-indices of this class for each scan 102 | n_elements: 103 | number of elements in a pair 104 | curr_max_iters: 105 | number of pairs in an epoch 106 | pair_based_transforms: 107 | some transformation performed on a pair basis, expect a list of functions, 108 | each function takes a pair sample and return a transformed one. 
109 | """ 110 | def __init__(self, datasets, n_elements, curr_max_iters, 111 | pair_based_transforms=None): 112 | super().__init__() 113 | self.datasets = datasets 114 | self.n_datasets = len(self.datasets) 115 | self.n_data = [len(dataset) for dataset in self.datasets] 116 | self.n_elements = n_elements 117 | self.curr_max_iters = curr_max_iters 118 | self.pair_based_transforms = pair_based_transforms 119 | self.update_index() 120 | 121 | def update_index(self): 122 | """ 123 | update the order of batches for the next episode 124 | """ 125 | 126 | # update number of elements for each subset 127 | if hasattr(self, 'indices'): 128 | n_data_old = self.n_data # DEBUG 129 | self.n_data = [len(dataset) for dataset in self.datasets] 130 | 131 | if isinstance(self.n_elements, list): 132 | self.indices = [[(dataset_idx, data_idx) for i, dataset_idx in enumerate(random.sample(range(self.n_datasets), k=len(self.n_elements))) # select which way(s) to use 133 | for data_idx in random.sample(range(self.n_data[dataset_idx]), k=self.n_elements[i])] # for each way, which sample to use 134 | for i_iter in range(self.curr_max_iters)] # sample iterations 135 | 136 | elif self.n_elements > self.n_datasets: 137 | raise ValueError("When 'same=False', 'n_element' should be no more than n_datasets") 138 | else: 139 | self.indices = [[(dataset_idx, random.randrange(self.n_data[dataset_idx])) 140 | for dataset_idx in random.sample(range(self.n_datasets), 141 | k=n_elements)] 142 | for i in range(curr_max_iters)] 143 | 144 | def __len__(self): 145 | return self.curr_max_iters 146 | 147 | def __getitem__(self, idx): 148 | sample = [self.datasets[dataset_idx][data_idx] 149 | for dataset_idx, data_idx in self.indices[idx]] 150 | if self.pair_based_transforms is not None: 151 | for transform, args in self.pair_based_transforms: 152 | sample = transform(sample, **args) 153 | return sample 154 | 155 | class Subset(Dataset): 156 | """ 157 | Subset of a dataset at specified indices. Used for seperating a dataset by class in our context 158 | 159 | Args: 160 | dataset: 161 | The whole Dataset 162 | indices: 163 | Indices of samples of the current class in the entire dataset 164 | sub_attrib_args: 165 | Subset-specific arguments for attribute functions, expected a dict 166 | """ 167 | def __init__(self, dataset, indices, sub_attrib_args=None): 168 | self.dataset = dataset 169 | self.indices = indices 170 | self.sub_attrib_args = sub_attrib_args 171 | 172 | def __getitem__(self, idx): 173 | if self.sub_attrib_args is not None: 174 | for key in self.sub_attrib_args: 175 | # Make sure the dataset already has the corresponding attributes 176 | # Here we only make the arguments subset dependent 177 | # (i.e. pass different arguments for each subset) 178 | self.dataset.aux_attrib_args[key].update(self.sub_attrib_args[key]) 179 | return self.dataset[self.indices[idx]] 180 | 181 | def __len__(self): 182 | return len(self.indices) 183 | 184 | class ValidationDataset(Dataset): 185 | """ 186 | Dataset for validation 187 | 188 | Args: 189 | dataset: 190 | source dataset with a __getitem__ method 191 | test_classes: 192 | test classes 193 | npart: int. 
number of parts, used for evaluation when assigning support images 194 | 195 | """ 196 | def __init__(self, dataset, test_classes: list, npart: int): 197 | super().__init__() 198 | self.dataset = dataset 199 | self.__curr_cls = None 200 | self.test_classes = test_classes 201 | self.dataset.aux_attrib = None 202 | self.npart = npart 203 | 204 | def set_curr_cls(self, curr_cls): 205 | assert curr_cls in self.test_classes 206 | self.__curr_cls = curr_cls 207 | 208 | def get_curr_cls(self): 209 | return self.__curr_cls 210 | 211 | def read_dataset(self): 212 | """ 213 | override original read_dataset to allow reading with z_margin 214 | """ 215 | raise NotImplementedError 216 | 217 | def __len__(self): 218 | return len(self.dataset) 219 | 220 | def label_strip(self, label): 221 | """ 222 | mask unrelated labels out 223 | """ 224 | out = torch.where(label == self.__curr_cls, 225 | torch.ones_like(label), torch.zeros_like(label)) 226 | return out 227 | 228 | def __getitem__(self, idx): 229 | if self.__curr_cls is None: 230 | raise Exception("Please initialize current class first") 231 | 232 | sample = self.dataset[idx] 233 | sample["label"] = self.label_strip( sample["label"] ) 234 | sample["label_t"] = sample["label"].unsqueeze(-1).data.numpy() 235 | 236 | labelname = self.dataset.all_label_names[self.__curr_cls] 237 | z_min = min(self.dataset.tp1_cls_map[labelname][sample['scan_id']]) 238 | z_max = max(self.dataset.tp1_cls_map[labelname][sample['scan_id']]) 239 | sample["z_min"], sample["z_max"] = z_min, z_max 240 | try: 241 | part_assign = int((sample["z_id"] - z_min) // ((z_max - z_min) / self.npart)) 242 | except: 243 | part_assign = 0 244 | print("###### DATASET: support only have one valid slice ######") 245 | if part_assign < 0: 246 | part_assign = 0 247 | elif part_assign >= self.npart: 248 | part_assign = self.npart - 1 249 | sample["part_assign"] = part_assign 250 | sample["idx"] = idx 251 | 252 | return sample 253 | 254 | -------------------------------------------------------------------------------- /dataloaders/dataset_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utils for datasets 3 | """ 4 | import numpy as np 5 | 6 | import os 7 | import sys 8 | import nibabel as nib 9 | import numpy as np 10 | import pdb 11 | import SimpleITK as sitk 12 | 13 | DATASET_INFO = { 14 | "CHAOST2": { 15 | 'PSEU_LABEL_NAME': ["BGD", "SUPFG"], 16 | 'REAL_LABEL_NAME': ["BG", "LIVER", "RK", "LK", "SPLEEN"], 17 | '_SEP': [0, 4, 8, 12, 16, 20], 18 | 'MODALITY': 'MR', 19 | 'LABEL_GROUP': { 20 | 'pa_all': set(range(1, 5)), 21 | 0: set([1, 4]), # upper_abdomen, leaving kidneies as testing classes 22 | 1: set([2, 3]), # lower_abdomen 23 | }, 24 | }, 25 | 26 | "SABS": { 27 | 'PSEU_LABEL_NAME': ["BGD", "SUPFG"], 28 | 29 | 'REAL_LABEL_NAME': ["BGD", "SPLEEN", "KID_R", "KID_l", "GALLBLADDER", "ESOPHAGUS", "LIVER", "STOMACH", "AORTA", "IVC",\ 30 | "PS_VEIN", "PANCREAS", "AG_R", "AG_L"], 31 | '_SEP': [0, 6, 12, 18, 24, 30], 32 | 'MODALITY': 'CT', 33 | 'LABEL_GROUP':{ 34 | 'pa_all': set( [1,2,3,6] ), 35 | 0: set([1,6] ), # upper_abdomen: spleen + liver as training, kidneis are testing 36 | 1: set( [2,3] ), # lower_abdomen 37 | } 38 | } 39 | 40 | } 41 | 42 | def read_nii_bysitk(input_fid, peel_info = False): 43 | """ read nii to numpy through simpleitk 44 | 45 | peelinfo: taking direction, origin, spacing and metadata out 46 | """ 47 | img_obj = sitk.ReadImage(input_fid) 48 | img_np = sitk.GetArrayFromImage(img_obj) 49 | if peel_info: 50 | info_obj = { 51 | 
"spacing": img_obj.GetSpacing(), 52 | "origin": img_obj.GetOrigin(), 53 | "direction": img_obj.GetDirection(), 54 | "array_size": img_np.shape 55 | } 56 | return img_np, info_obj 57 | else: 58 | return img_np 59 | 60 | def get_normalize_op(modality, fids): 61 | """ 62 | As title 63 | Args: 64 | modality: CT or MR 65 | fids: fids for the fold 66 | """ 67 | 68 | def get_CT_statistics(scan_fids): 69 | """ 70 | As CT are quantitative, get mean and std for CT images for image normalizing 71 | As in reality we might not be able to load all images at a time, we would better detach statistics calculation with actual data loading 72 | """ 73 | total_val = 0 74 | n_pix = 0 75 | for fid in scan_fids: 76 | in_img = read_nii_bysitk(fid) 77 | total_val += in_img.sum() 78 | n_pix += np.prod(in_img.shape) 79 | del in_img 80 | meanval = total_val / n_pix 81 | 82 | total_var = 0 83 | for fid in scan_fids: 84 | in_img = read_nii_bysitk(fid) 85 | total_var += np.sum((in_img - meanval) ** 2 ) 86 | del in_img 87 | var_all = total_var / n_pix 88 | 89 | global_std = var_all ** 0.5 90 | 91 | return meanval, global_std 92 | 93 | if modality == 'MR': 94 | 95 | def MR_normalize(x_in): 96 | return (x_in - x_in.mean()) / x_in.std() 97 | 98 | return MR_normalize #, {'mean': None, 'std': None} # we do not really need the global statistics for MR 99 | 100 | elif modality == 'CT': 101 | ct_mean, ct_std = get_CT_statistics(fids) 102 | # debug 103 | print(f'###### DEBUG_DATASET CT_STATS NORMALIZED MEAN {ct_mean / 255} STD {ct_std / 255} ######') 104 | 105 | def CT_normalize(x_in): 106 | """ 107 | Normalizing CT images, based on global statistics 108 | """ 109 | return (x_in - ct_mean) / ct_std 110 | 111 | return CT_normalize #, {'mean': ct_mean, 'std': ct_std} 112 | 113 | 114 | -------------------------------------------------------------------------------- /dataloaders/dev_customized_med.py: -------------------------------------------------------------------------------- 1 | """ 2 | Customized dataset. Extended from vanilla PANet script by Wang et al. 
3 | """ 4 | 5 | import os 6 | import random 7 | import torch 8 | import numpy as np 9 | 10 | from dataloaders.common import ReloadPairedDataset, ValidationDataset 11 | from dataloaders.ManualAnnoDatasetv2 import ManualAnnoDataset 12 | 13 | def attrib_basic(_sample, class_id): 14 | """ 15 | Add basic attribute 16 | Args: 17 | _sample: data sample 18 | class_id: class label asscociated with the data 19 | (sometimes indicting from which subset the data are drawn) 20 | """ 21 | return {'class_id': class_id} 22 | 23 | def getMaskOnly(label, class_id, class_ids): 24 | """ 25 | Generate FG/BG mask from the segmentation mask 26 | 27 | Args: 28 | label: 29 | semantic mask 30 | scribble: 31 | scribble mask 32 | class_id: 33 | semantic class of interest 34 | class_ids: 35 | all class id in this episode 36 | """ 37 | # Dense Mask 38 | fg_mask = torch.where(label == class_id, 39 | torch.ones_like(label), torch.zeros_like(label)) 40 | bg_mask = torch.where(label != class_id, 41 | torch.ones_like(label), torch.zeros_like(label)) 42 | for class_id in class_ids: 43 | bg_mask[label == class_id] = 0 44 | 45 | return {'fg_mask': fg_mask, 46 | 'bg_mask': bg_mask} 47 | 48 | def getMasks(*args, **kwargs): 49 | raise NotImplementedError 50 | 51 | def fewshot_pairing(paired_sample, n_ways, n_shots, cnt_query, coco=False, mask_only = True): 52 | """ 53 | Postprocess paired sample for fewshot settings 54 | For now only 1-way is tested but we leave multi-way possible (inherited from original PANet) 55 | 56 | Args: 57 | paired_sample: 58 | data sample from a PairedDataset 59 | n_ways: 60 | n-way few-shot learning 61 | n_shots: 62 | n-shot few-shot learning 63 | cnt_query: 64 | number of query images for each class in the support set 65 | coco: 66 | MS COCO dataset. This is from the original PANet dataset but lets keep it for further extension 67 | mask_only: 68 | only give masks and no scribbles/ instances. 
Suitable for medical images (for now) 69 | """ 70 | if not mask_only: 71 | raise NotImplementedError 72 | ###### Compose the support and query image list ###### 73 | cumsum_idx = np.cumsum([0,] + [n_shots + x for x in cnt_query]) # seperation for supports and queries 74 | 75 | # support class ids 76 | class_ids = [paired_sample[cumsum_idx[i]]['basic_class_id'] for i in range(n_ways)] # class ids for each image (support and query) 77 | 78 | # support images 79 | support_images = [[paired_sample[cumsum_idx[i] + j]['image'] for j in range(n_shots)] 80 | for i in range(n_ways)] # fetch support images for each class 81 | 82 | # support image labels 83 | if coco: 84 | support_labels = [[paired_sample[cumsum_idx[i] + j]['label'][class_ids[i]] 85 | for j in range(n_shots)] for i in range(n_ways)] 86 | else: 87 | support_labels = [[paired_sample[cumsum_idx[i] + j]['label'] for j in range(n_shots)] 88 | for i in range(n_ways)] 89 | 90 | if not mask_only: 91 | support_scribbles = [[paired_sample[cumsum_idx[i] + j]['scribble'] for j in range(n_shots)] 92 | for i in range(n_ways)] 93 | support_insts = [[paired_sample[cumsum_idx[i] + j]['inst'] for j in range(n_shots)] 94 | for i in range(n_ways)] 95 | else: 96 | support_insts = [] 97 | 98 | # query images, masks and class indices 99 | query_images = [paired_sample[cumsum_idx[i+1] - j - 1]['image'] for i in range(n_ways) 100 | for j in range(cnt_query[i])] 101 | if coco: 102 | query_labels = [paired_sample[cumsum_idx[i+1] - j - 1]['label'][class_ids[i]] 103 | for i in range(n_ways) for j in range(cnt_query[i])] 104 | else: 105 | query_labels = [paired_sample[cumsum_idx[i+1] - j - 1]['label'] for i in range(n_ways) 106 | for j in range(cnt_query[i])] 107 | query_cls_idx = [sorted([0,] + [class_ids.index(x) + 1 108 | for x in set(np.unique(query_label)) & set(class_ids)]) 109 | for query_label in query_labels] 110 | 111 | ###### Generate support image masks ###### 112 | if not mask_only: 113 | support_mask = [[getMasks(support_labels[way][shot], support_scribbles[way][shot], 114 | class_ids[way], class_ids) 115 | for shot in range(n_shots)] for way in range(n_ways)] 116 | else: 117 | support_mask = [[getMaskOnly(support_labels[way][shot], 118 | class_ids[way], class_ids) 119 | for shot in range(n_shots)] for way in range(n_ways)] 120 | 121 | ###### Generate query label (class indices in one episode, i.e. the ground truth)###### 122 | query_labels_tmp = [torch.zeros_like(x) for x in query_labels] 123 | for i, query_label_tmp in enumerate(query_labels_tmp): 124 | query_label_tmp[query_labels[i] == 255] = 255 125 | for j in range(n_ways): 126 | query_label_tmp[query_labels[i] == class_ids[j]] = j + 1 127 | 128 | ###### Generate query mask for each semantic class (including BG) ###### 129 | # BG class 130 | query_masks = [[torch.where(query_label == 0, 131 | torch.ones_like(query_label), 132 | torch.zeros_like(query_label))[None, ...],] 133 | for query_label in query_labels] 134 | # Other classes in query image 135 | for i, query_label in enumerate(query_labels): 136 | for idx in query_cls_idx[i][1:]: 137 | mask = torch.where(query_label == class_ids[idx - 1], 138 | torch.ones_like(query_label), 139 | torch.zeros_like(query_label))[None, ...] 
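# (Editor's note) query_masks[i] collects one binary mask per class for query image i:
# index 0 is the BG mask built earlier, and the append below adds one FG mask for each
# support class that actually appears in this query label.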
140 | query_masks[i].append(mask) 141 | 142 | 143 | return {'class_ids': class_ids, 144 | 'support_images': support_images, 145 | 'support_mask': support_mask, 146 | 'support_inst': support_insts, # leave these interfaces 147 | 'support_scribbles': support_scribbles if not mask_only else [], # scribbles are never built in mask_only mode 148 | 149 | 'query_images': query_images, 150 | 'query_labels': query_labels_tmp, 151 | 'query_masks': query_masks, 152 | 'query_cls_idx': query_cls_idx, 153 | } 154 | 155 | 156 | def med_fewshot(dataset_name, base_dir, idx_split, mode, scan_per_load, 157 | transforms, act_labels, n_ways, n_shots, max_iters_per_load, min_fg = '', n_queries=1, fix_parent_len = None, exclude_list = [], **kwargs): 158 | """ 159 | Dataset wrapper 160 | Args: 161 | dataset_name: 162 | indicates what dataset to use 163 | base_dir: 164 | dataset directory 165 | mode: 166 | which mode to use 167 | choose from ('train', 'val', 'trainval', 'trainaug') 168 | idx_split: 169 | index of split 170 | scan_per_load: 171 | number of scans to load into memory as the dataset is large 172 | use that together with reload_buffer 173 | transforms: 174 | transformations to be performed on images/masks 175 | act_labels: 176 | active labels involved in training process. Should be a subset of all labels 177 | n_ways: 178 | n-way few-shot learning, should be no more than # of object class labels 179 | n_shots: 180 | n-shot few-shot learning 181 | max_iters_per_load: 182 | number of pairs per load (epoch size) 183 | n_queries: 184 | number of query images 185 | fix_parent_len: 186 | fixed length of the parent dataset 187 | """ 188 | med_set = ManualAnnoDataset 189 | 190 | 191 | mydataset = med_set(which_dataset = dataset_name, base_dir=base_dir, idx_split = idx_split, mode = mode,\ 192 | scan_per_load = scan_per_load, transforms=transforms, min_fg = min_fg, fix_length = fix_parent_len,\ 193 | exclude_list = exclude_list, **kwargs) 194 | 195 | mydataset.add_attrib('basic', attrib_basic, {}) 196 | 197 | # Create sub-datasets and add class_id attribute. Here the class file is internally loaded and reloaded inside 198 | subsets = mydataset.subsets([{'basic': {'class_id': ii}} 199 | for ii, _ in enumerate(mydataset.label_name)]) 200 | 201 | # Choose the classes of queries 202 | cnt_query = np.bincount(random.choices(population=range(n_ways), k=n_queries), minlength=n_ways) 203 | # Number of queries for each way 204 | # Set the number of images for each class 205 | n_elements = [n_shots + x for x in cnt_query] # n_shots supports + cnt_query[i] queries for way i 206 | # Create paired dataset. We do not include background.
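# (Editor's note) A worked sizing example for the episode construction above, assuming
# the repo's default 1-way / 1-shot / 1-query setting (grid_proto_fewshot.py asserts
# n_ways == 1 and n_queries == 1):
#     cnt_query  = np.bincount(random.choices(range(1), k=1), minlength=1)  # -> array([1])
#     n_elements = [n_shots + cnt_query[0]]                                 # -> [2]
# so the ReloadPairedDataset below draws 2 images (1 support + 1 query) from the chosen
# class for every episode.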
207 | paired_data = ReloadPairedDataset([subsets[ii] for ii in act_labels], n_elements=n_elements, curr_max_iters=max_iters_per_load, 208 | pair_based_transforms=[ 209 | (fewshot_pairing, {'n_ways': n_ways, 'n_shots': n_shots, 210 | 'cnt_query': cnt_query, 'mask_only': True})]) 211 | return paired_data, mydataset 212 | 213 | def update_loader_dset(loader, parent_set): 214 | """ 215 | Update data loader and the parent dataset behind 216 | Args: 217 | loader: actual dataloader 218 | parent_set: parent dataset which actually stores the data 219 | """ 220 | parent_set.reload_buffer() 221 | loader.dataset.update_index() 222 | print(f'###### Loader and dataset have been updated ######' ) 223 | 224 | def med_fewshot_val(dataset_name, base_dir, idx_split, scan_per_load, act_labels, npart, fix_length = None, nsup = 1, **kwargs): 225 | """ 226 | validation set for med images 227 | Args: 228 | dataset_name: 229 | indicates what dataset to use 230 | base_dir: 231 | SABS dataset directory 232 | mode: (original split) 233 | which split to use 234 | choose from ('train', 'val', 'trainval', 'trainaug') 235 | idx_split: 236 | index of split 237 | scan_per_batch: 238 | number of scans to load into memory as the dataset is large 239 | use that together with reload_buffer 240 | act_labels: 241 | actual labels involved in training process. Should be a subset of all labels 242 | npart: number of chunks for splitting a 3d volume 243 | nsup: number of support scans, equivalent to nshot 244 | """ 245 | mydataset = ManualAnnoDataset(which_dataset = dataset_name, base_dir=base_dir, idx_split = idx_split, mode = 'val', scan_per_load = scan_per_load, transforms=None, min_fg = 1, fix_length = fix_length, nsup = nsup, **kwargs) 246 | mydataset.add_attrib('basic', attrib_basic, {}) 247 | 248 | valset = ValidationDataset(mydataset, test_classes = act_labels, npart = npart) 249 | 250 | return valset, mydataset 251 | 252 | -------------------------------------------------------------------------------- /dataloaders/image_transforms.py: -------------------------------------------------------------------------------- 1 | """ 2 | Image transforms functions for data augmentation 3 | Credit to Dr. 
Jo Schlemper 4 | """ 5 | 6 | from collections import Sequence 7 | import cv2 8 | import numpy as np 9 | import scipy 10 | from scipy.ndimage.filters import gaussian_filter 11 | from scipy.ndimage.interpolation import map_coordinates 12 | from numpy.lib.stride_tricks import as_strided 13 | 14 | ###### UTILITIES ###### 15 | def random_num_generator(config, random_state=np.random): 16 | if config[0] == 'uniform': 17 | ret = random_state.uniform(config[1], config[2], 1)[0] 18 | elif config[0] == 'lognormal': 19 | ret = random_state.lognormal(config[1], config[2], 1)[0] 20 | else: 21 | #print(config) 22 | raise Exception('unsupported format') 23 | return ret 24 | 25 | def get_translation_matrix(translation): 26 | """ translation: [tx, ty] """ 27 | tx, ty = translation 28 | translation_matrix = np.array([[1, 0, tx], 29 | [0, 1, ty], 30 | [0, 0, 1]]) 31 | return translation_matrix 32 | 33 | 34 | 35 | def get_rotation_matrix(rotation, input_shape, centred=True): 36 | theta = np.pi / 180 * np.array(rotation) 37 | if centred: 38 | rotation_matrix = cv2.getRotationMatrix2D((input_shape[0]/2, input_shape[1]//2), rotation, 1) 39 | rotation_matrix = np.vstack([rotation_matrix, [0, 0, 1]]) 40 | else: 41 | rotation_matrix = np.array([[np.cos(theta), -np.sin(theta), 0], 42 | [np.sin(theta), np.cos(theta), 0], 43 | [0, 0, 1]]) 44 | return rotation_matrix 45 | 46 | def get_zoom_matrix(zoom, input_shape, centred=True): 47 | zx, zy = zoom 48 | if centred: 49 | zoom_matrix = cv2.getRotationMatrix2D((input_shape[0]/2, input_shape[1]//2), 0, zoom[0]) 50 | zoom_matrix = np.vstack([zoom_matrix, [0, 0, 1]]) 51 | else: 52 | zoom_matrix = np.array([[zx, 0, 0], 53 | [0, zy, 0], 54 | [0, 0, 1]]) 55 | return zoom_matrix 56 | 57 | def get_shear_matrix(shear_angle): 58 | theta = (np.pi * shear_angle) / 180 59 | shear_matrix = np.array([[1, -np.sin(theta), 0], 60 | [0, np.cos(theta), 0], 61 | [0, 0, 1]]) 62 | return shear_matrix 63 | 64 | ###### AFFINE TRANSFORM ###### 65 | class RandomAffine(object): 66 | """Apply random affine transformation on a numpy.ndarray (H x W x C) 67 | Comment by co1818: this is still doing affine on 2d (H x W plane). 68 | A same transform is applied to all C channels 69 | 70 | Parameter: 71 | ---------- 72 | 73 | alpha: Range [0, 4] seems good for small images 74 | 75 | order: interpolation method (c.f. opencv) 76 | """ 77 | 78 | def __init__(self, 79 | rotation_range=None, 80 | translation_range=None, 81 | shear_range=None, 82 | zoom_range=None, 83 | zoom_keep_aspect=False, 84 | interp='bilinear', 85 | order=3): 86 | """ 87 | Perform an affine transforms. 88 | 89 | Arguments 90 | --------- 91 | rotation_range : one integer or float 92 | image will be rotated randomly between (-degrees, degrees) 93 | 94 | translation_range : (x_shift, y_shift) 95 | shifts in pixels 96 | 97 | *NOT TESTED* shear_range : float 98 | image will be sheared randomly between (-degrees, degrees) 99 | 100 | zoom_range : (zoom_min, zoom_max) 101 | list/tuple with two floats between [0, infinity). 102 | first float should be less than the second 103 | lower and upper bounds on percent zoom. 104 | Anything less than 1.0 will zoom in on the image, 105 | anything greater than 1.0 will zoom out on the image. 106 | e.g. 
(0.7, 1.0) will only zoom in, 107 | (1.0, 1.4) will only zoom out, 108 | (0.7, 1.4) will randomly zoom in or out 109 | """ 110 | 111 | self.rotation_range = rotation_range 112 | self.translation_range = translation_range 113 | self.shear_range = shear_range 114 | self.zoom_range = zoom_range 115 | self.zoom_keep_aspect = zoom_keep_aspect 116 | self.interp = interp 117 | self.order = order 118 | 119 | def build_M(self, input_shape): 120 | tfx = [] 121 | final_tfx = np.eye(3) 122 | if self.rotation_range: 123 | rot = np.random.uniform(-self.rotation_range, self.rotation_range) 124 | tfx.append(get_rotation_matrix(rot, input_shape)) 125 | if self.translation_range: 126 | tx = np.random.uniform(-self.translation_range[0], self.translation_range[0]) 127 | ty = np.random.uniform(-self.translation_range[1], self.translation_range[1]) 128 | tfx.append(get_translation_matrix((tx,ty))) 129 | if self.shear_range: 130 | rot = np.random.uniform(-self.shear_range, self.shear_range) 131 | tfx.append(get_shear_matrix(rot)) 132 | if self.zoom_range: 133 | sx = np.random.uniform(self.zoom_range[0], self.zoom_range[1]) 134 | if self.zoom_keep_aspect: 135 | sy = sx 136 | else: 137 | sy = np.random.uniform(self.zoom_range[0], self.zoom_range[1]) 138 | 139 | tfx.append(get_zoom_matrix((sx, sy), input_shape)) 140 | 141 | for tfx_mat in tfx: 142 | final_tfx = np.dot(tfx_mat, final_tfx) 143 | 144 | return final_tfx.astype(np.float32) 145 | 146 | def __call__(self, image): 147 | # build matrix 148 | input_shape = image.shape[:2] 149 | M = self.build_M(input_shape) 150 | 151 | res = np.zeros_like(image) 152 | #if isinstance(self.interp, Sequence): 153 | if type(self.order) is list or type(self.order) is tuple: 154 | for i, intp in enumerate(self.order): 155 | res[..., i] = affine_transform_via_M(image[..., i], M[:2], interp=intp) 156 | else: 157 | # squeeze if needed 158 | orig_shape = image.shape 159 | image_s = np.squeeze(image) 160 | res = affine_transform_via_M(image_s, M[:2], interp=self.order) 161 | res = res.reshape(orig_shape) 162 | 163 | #res = affine_transform_via_M(image, M[:2], interp=self.order) 164 | 165 | return res 166 | 167 | def affine_transform_via_M(image, M, borderMode=cv2.BORDER_CONSTANT, interp=cv2.INTER_NEAREST): 168 | imshape = image.shape 169 | shape_size = imshape[:2] 170 | 171 | # Random affine 172 | warped = cv2.warpAffine(image.reshape(shape_size + (-1,)), M, shape_size[::-1], 173 | flags=interp, borderMode=borderMode) 174 | 175 | #print(imshape, warped.shape) 176 | 177 | warped = warped[..., np.newaxis].reshape(imshape) 178 | 179 | return warped 180 | 181 | ###### ELASTIC TRANSFORM ###### 182 | def elastic_transform(image, alpha=1000, sigma=30, spline_order=1, mode='nearest', random_state=np.random): 183 | """Elastic deformation of image as described in [Simard2003]_. 184 | .. [Simard2003] Simard, Steinkraus and Platt, "Best Practices for 185 | Convolutional Neural Networks applied to Visual Document Analysis", in 186 | Proc. of the International Conference on Document Analysis and 187 | Recognition, 2003. 
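    (Editor's note) In the implementation below, `alpha` scales the random
    displacement field and `sigma` sets the width of the Gaussian filter that
    smooths the field before the image is resampled: a small sigma gives noisy,
    pixel-level distortion, while a large sigma approaches a smooth, affine-like
    warp (see also the parameter notes in elastic_transform_nd further down).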
188 | """ 189 | assert image.ndim == 3 190 | shape = image.shape[:2] 191 | 192 | dx = gaussian_filter((random_state.rand(*shape) * 2 - 1), 193 | sigma, mode="constant", cval=0) * alpha 194 | dy = gaussian_filter((random_state.rand(*shape) * 2 - 1), 195 | sigma, mode="constant", cval=0) * alpha 196 | 197 | x, y = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), indexing='ij') 198 | indices = [np.reshape(x + dx, (-1, 1)), np.reshape(y + dy, (-1, 1))] 199 | result = np.empty_like(image) 200 | for i in range(image.shape[2]): 201 | result[:, :, i] = map_coordinates( 202 | image[:, :, i], indices, order=spline_order, mode=mode).reshape(shape) 203 | return result 204 | 205 | 206 | def elastic_transform_nd(image, alpha, sigma, random_state=None, order=1, lazy=False): 207 | """Expects data to be (nx, ny, n1 ,..., nm) 208 | params: 209 | ------ 210 | 211 | alpha: 212 | the scaling parameter. 213 | E.g.: alpha=2 => distorts images up to 2x scaling 214 | 215 | sigma: 216 | standard deviation of gaussian filter. 217 | E.g. 218 | low (sig~=1e-3) => no smoothing, pixelated. 219 | high (1/5 * imsize) => smooth, more like affine. 220 | very high (1/2*im_size) => translation 221 | """ 222 | 223 | if random_state is None: 224 | random_state = np.random.RandomState(None) 225 | 226 | shape = image.shape 227 | imsize = shape[:2] 228 | dim = shape[2:] 229 | 230 | # Random affine 231 | blur_size = int(4*sigma) | 1 232 | dx = cv2.GaussianBlur(random_state.rand(*imsize)*2-1, 233 | ksize=(blur_size, blur_size), sigmaX=sigma) * alpha 234 | dy = cv2.GaussianBlur(random_state.rand(*imsize)*2-1, 235 | ksize=(blur_size, blur_size), sigmaX=sigma) * alpha 236 | 237 | # use as_strided to copy things over across n1...nn channels 238 | dx = as_strided(dx.astype(np.float32), 239 | strides=(0,) * len(dim) + (4*shape[1], 4), 240 | shape=dim+(shape[0], shape[1])) 241 | dx = np.transpose(dx, axes=(-2, -1) + tuple(range(len(dim)))) 242 | 243 | dy = as_strided(dy.astype(np.float32), 244 | strides=(0,) * len(dim) + (4*shape[1], 4), 245 | shape=dim+(shape[0], shape[1])) 246 | dy = np.transpose(dy, axes=(-2, -1) + tuple(range(len(dim)))) 247 | 248 | coord = np.meshgrid(*[np.arange(shape_i) for shape_i in (shape[1], shape[0]) + dim]) 249 | indices = [np.reshape(e+de, (-1, 1)) for e, de in zip([coord[1], coord[0]] + coord[2:], 250 | [dy, dx] + [0] * len(dim))] 251 | 252 | if lazy: 253 | return indices 254 | 255 | return map_coordinates(image, indices, order=order, mode='reflect').reshape(shape) 256 | 257 | class ElasticTransform(object): 258 | """Apply elastic transformation on a numpy.ndarray (H x W x C) 259 | """ 260 | 261 | def __init__(self, alpha, sigma, order=1): 262 | self.alpha = alpha 263 | self.sigma = sigma 264 | self.order = order 265 | 266 | def __call__(self, image): 267 | if isinstance(self.alpha, Sequence): 268 | alpha = random_num_generator(self.alpha) 269 | else: 270 | alpha = self.alpha 271 | if isinstance(self.sigma, Sequence): 272 | sigma = random_num_generator(self.sigma) 273 | else: 274 | sigma = self.sigma 275 | return elastic_transform_nd(image, alpha=alpha, sigma=sigma, order=self.order) 276 | 277 | class RandomFlip3D(object): 278 | 279 | def __init__(self, h=True, v=True, t=True, p=0.5): 280 | """ 281 | Randomly flip an image horizontally and/or vertically with 282 | some probability. 
283 | 284 | Arguments 285 | --------- 286 | h : boolean 287 | whether to horizontally flip w/ probability p 288 | 289 | v : boolean 290 | whether to vertically flip w/ probability p 291 | 292 | p : float between [0,1] 293 | probability with which to apply allowed flipping operations 294 | """ 295 | self.horizontal = h 296 | self.vertical = v 297 | self.depth = t 298 | self.p = p 299 | 300 | def __call__(self, x, y=None): 301 | # horizontal flip with p = self.p 302 | if self.horizontal: 303 | if np.random.random() < self.p: 304 | x = x[::-1, ...] 305 | 306 | # vertical flip with p = self.p 307 | if self.vertical: 308 | if np.random.random() < self.p: 309 | x = x[:, ::-1, ...] 310 | 311 | if self.depth: 312 | if np.random.random() < self.p: 313 | x = x[..., ::-1] 314 | 315 | return x 316 | 317 | 318 | -------------------------------------------------------------------------------- /dataloaders/niftiio.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utils for datasets 3 | """ 4 | import numpy as np 5 | 6 | import numpy as np 7 | import SimpleITK as sitk 8 | 9 | 10 | def read_nii_bysitk(input_fid, peel_info = False): 11 | """ read nii to numpy through simpleitk 12 | peelinfo: taking direction, origin, spacing and metadata out 13 | """ 14 | img_obj = sitk.ReadImage(input_fid) 15 | img_np = sitk.GetArrayFromImage(img_obj) 16 | if peel_info: 17 | info_obj = {'' 18 | "spacing": img_obj.GetSpacing(), 19 | "origin": img_obj.GetOrigin(), 20 | "direction": img_obj.GetDirection(), 21 | "array_size": img_np.shape 22 | } 23 | return img_np, info_obj 24 | else: 25 | return img_np 26 | 27 | def convert_to_sitk(input_mat, peeled_info): 28 | """ 29 | write a numpy array to sitk image object with essential meta-data 30 | """ 31 | nii_obj = sitk.GetImageFromArray(input_mat) 32 | if peeled_info: 33 | nii_obj.SetSpacing( peeled_info["spacing"] ) 34 | nii_obj.SetOrigin( peeled_info["origin"] ) 35 | nii_obj.SetDirection(peeled_info["direction"] ) 36 | return nii_obj 37 | 38 | def np2itk(img, ref_obj): 39 | """ 40 | img: numpy array 41 | ref_obj: reference sitk object for copying information from 42 | """ 43 | itk_obj = sitk.GetImageFromArray(img) 44 | itk_obj.SetSpacing( ref_obj.GetSpacing() ) 45 | itk_obj.SetOrigin( ref_obj.GetOrigin() ) 46 | itk_obj.SetDirection( ref_obj.GetDirection() ) 47 | return itk_obj 48 | 49 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tntek/DSPNet/a4b71c2ab1229221584ccb2a5100bf66a7d1ef03/models/__init__.py -------------------------------------------------------------------------------- /models/alpmodule.py: -------------------------------------------------------------------------------- 1 | """ 2 | ALPModule 3 | """ 4 | import torch 5 | import math 6 | from torch import nn 7 | from torch.nn import functional as F 8 | import numpy as np 9 | from pdb import set_trace 10 | import matplotlib.pyplot as plt 11 | # for unit test from spatial_similarity_module import NONLocalBlock2D, LayerNorm 12 | 13 | class MultiProtoAsConv(nn.Module): 14 | def __init__(self, proto_grid, feature_hw, upsample_mode = 'bilinear'): 15 | """ 16 | ALPModule 17 | Args: 18 | proto_grid: Grid size when doing multi-prototyping. 
For a 32-by-32 feature map, a size of 16-by-16 leads to a pooling window of 2-by-2 19 | feature_hw: Spatial size of input feature map 20 | 21 | """ 22 | super(MultiProtoAsConv, self).__init__() 23 | self.proto_grid = proto_grid 24 | self.upsample_mode = upsample_mode 25 | self.get_wight() 26 | kernel_size = [ ft_l // grid_l for ft_l, grid_l in zip(feature_hw, proto_grid) ] 27 | self.avg_pool_op = nn.AvgPool2d( kernel_size ) # kernel_size 28 | self.a = 0.2 # α $---set1:[ABD: α = 0.3] // [CMR: α = 0.2] ---$ $---set2:[ABD: α = 0.2] ---$ 29 | 30 | def get_wight(self): 31 | # """ 32 | # ------------(w1, w2, w3)--------------- 33 | # """ 34 | self.wight1 = 0.3 # w1 35 | self.wight2 = 0.8 # w2 36 | self.wight3 = 0.3 # w3 37 | N=256 38 | cal = torch.eye(N).float().cuda() 39 | for i in range(N-2): 40 | cal[i+1][i] = self.wight1 41 | cal[i+1][i+1] = self.wight2 42 | cal[i+1][i+2] = self.wight3 43 | self.cal = nn.Parameter(cal) 44 | 45 | 46 | def forward(self, mol,qry, sup_x, sup_y,s_init_seed, mode, thresh, isval = False, val_wsize = None, vis_sim = False, **kwargs): 47 | """ 48 | Now supports 49 | Args: 50 | mode: 'mask'/ 'grid'. if mask, works as original prototyping 51 | qry: [way(1), nc, h, w] 52 | sup_x: [nb, nc, h, w] 53 | sup_y: [nb, 1, h, w] 54 | vis_sim: visualize raw similarities or not 55 | New 56 | mode: 'mask'/ 'grid'. if mask, works as original prototyping 57 | qry: [way(1), nb(1), nc, h, w] 58 | sup_x: [way(1), shot, nb(1), nc, h, w] 59 | sup_y: [way(1), shot, nb(1), h, w] 60 | vis_sim: visualize raw similarities or not 61 | """ 62 | 63 | qry = qry.squeeze(1) # [way(1), nb(1), nc, hw] -> [way(1), nc, h, w] 64 | sup_x = sup_x.squeeze(0).squeeze(1) # [nshot, nc, h, w] 65 | sup_y = sup_y.squeeze(0) # [nshot, 1, h, w] 66 | 67 | def safe_norm(x, p = 2, dim = 1, eps = 1e-4): 68 | x_norm = torch.norm(x, p = p, dim = dim) # .detach() 69 | x_norm = torch.max(x_norm, torch.ones_like(x_norm).cuda() * eps) 70 | x = x.div(x_norm.unsqueeze(1).expand_as(x)) 71 | return x 72 | 73 | if mode == 'mask': # class-level prototype only 74 | sup_nshot = sup_x.shape[0] 75 | 76 | out_su = self.attention(sup_x,qry) 77 | s_seed_ = s_init_seed[0, :, :] 78 | num_sp = max(len(torch.nonzero(s_seed_[:, 0])), len(torch.nonzero(s_seed_[:, 1]))) 79 | if (num_sp == 0): 80 | proto = torch.sum(out_su * sup_y, dim=(-1, -2)) \ 81 | / (sup_y.sum(dim=(-1, -2)) + 1e-5) # nb x C 82 | cos_sim_map_sup = F.conv2d(out_su, 83 | proto[..., None, None].repeat(1, 1, 1, 1)) 84 | cos_sim_map_sup_t = cos_sim_map_sup.view(out_su.size()[0], 1, -1) 85 | attention = cos_sim_map_sup_t.softmax(dim=-1) 86 | sp_center_t = proto.t().unsqueeze(0) 87 | out = torch.bmm(sp_center_t, attention).view(1, sup_x.size()[1], sup_x.size()[-2], sup_x.size()[-1]) 88 | out1 = out + sup_x 89 | 90 | proto = torch.sum(out1 * sup_y, dim=(-1, -2)) \ 91 | / (sup_y.sum(dim=(-1, -2)) + 1e-5) 92 | else: 93 | if mol == 'alignLoss': 94 | proto = torch.sum(out_su * sup_y, dim=(-1, -2)) \ 95 | / (sup_y.sum(dim=(-1, -2)) + 1e-5) 96 | cos_sim_map_sup = F.conv2d(out_su, 97 | proto[..., None, None].repeat(1, 1, 1, 1)) 98 | cos_sim_map_sup_t = cos_sim_map_sup.view(out_su.size()[0], 1, -1) 99 | attention = cos_sim_map_sup_t.softmax(dim=-1) 100 | sp_center_t = proto.t().unsqueeze(0) 101 | out = torch.bmm(sp_center_t, attention).view(1, sup_x.size()[1], sup_x.size()[-2], sup_x.size()[-1]) 102 | out1 = out + sup_x 103 | 104 | proto = torch.sum(out1 * sup_y, dim=(-1, -2)) \ 105 | / (sup_y.sum(dim=(-1, -2)) + 1e-5) 106 | else: 107 | sp_center_list = [] 108 | sup_nshot = sup_x.shape[0] 
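# (Editor's note) The loop below reuses `sup_nshot` as its loop index, shadowing the
# shot count read above. `range()` is evaluated once before iteration starts, so the
# loop still covers all shots, but a distinct index name would be clearer.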
109 | for sup_nshot in range(sup_nshot): 110 | with torch.no_grad(): 111 | s_seed_ = s_seed_[:num_sp, :] # num_sp x 2 112 | sp_init_center = sup_x[sup_nshot][:, s_seed_[:, 0], s_seed_[:, 1]] 113 | sp_init_center = torch.cat([sp_init_center, s_seed_.transpose(1, 0).float()], dim=0) 114 | sp_center = self.sp_center_iter(sup_x[sup_nshot], sup_y[sup_nshot], sp_init_center, n_iter=10) 115 | sp_center_list.append(sp_center) 116 | y1 = sp_center_list[0].shape[1] 117 | sp_center = torch.cat(sp_center_list) 118 | cos_sim_map_sup = F.conv2d(out_su, 119 | sp_center[..., None, None].repeat(1, 1, 1, 1).permute(1, 0, 2, 3)) 120 | cos_sim_map_sup_t = cos_sim_map_sup.view(sup_x.size()[0], y1, -1) 121 | attention = cos_sim_map_sup_t.softmax(dim=-1) 122 | sp_center_t = sp_center.unsqueeze(0) 123 | out = torch.bmm(sp_center_t, attention).view(1, sup_x.size()[1], sup_x.size()[-2], sup_x.size()[-1]) 124 | out1 = out + sup_x 125 | proto = torch.sum(out1 * sup_y, dim=(-1, -2)) \ 126 | / (sup_y.sum(dim=(-1, -2)) + 1e-5) 127 | 128 | # proto = proto.mean(dim = 0, keepdim = True) # 1 X C, take the mean of everything 129 | pred_mask = F.cosine_similarity(qry, proto[..., None, None], dim=1, eps = 1e-4) * 20.0 # [1, h, w] 130 | 131 | vis_dict = {'proto_assign': None} # things to visualize 132 | if vis_sim: 133 | vis_dict['raw_local_sims'] = pred_mask 134 | return pred_mask.unsqueeze(1), [pred_mask], vis_dict # just a placeholder. pred_mask returned as [1, way(1), h, w] 135 | 136 | # no need to merge with gridconv+ 137 | elif mode == 'gridconv': # using local prototypes only 138 | 139 | input_size = qry.shape # torch.Size([1, 256, 32, 32]) 140 | nch = input_size[1] # 256 141 | out_su = self.attention(sup_x,qry) 142 | sup_nshot = sup_x.shape[0] # torch.Size([1, 256, 32, 32]) 143 | n_sup_x = F.avg_pool2d(sup_x, val_wsize) if isval else self.avg_pool_op( sup_x ) 144 | n_sup_x = n_sup_x.view(sup_nshot, nch, -1) 145 | n_sup_x = n_sup_x.permute(0, 2, 1).unsqueeze(0) 146 | 147 | n_sup_x = n_sup_x.reshape(1, -1, nch).unsqueeze(0) 148 | sup_y_g = F.avg_pool2d(sup_y, val_wsize) if isval else self.avg_pool_op(sup_y) 149 | sup_y_g = sup_y_g.view( sup_nshot, 1, -1 ).permute(1, 0, 2).view(1, -1).unsqueeze(0) 150 | 151 | n_sup_x = n_sup_x.permute(0, 1, 3, 2).squeeze(0).squeeze(0) 152 | 153 | w1 = torch.mm(n_sup_x.float(),n_sup_x.permute(1, 0).float()) # 256 256 154 | softmax_matrix = F.softmax(w1,dim=1) 155 | mask_w = self.wts_near(n_sup_x, 1, 1, 1) 156 | add_res_w2 = mask_w*softmax_matrix.float() 157 | #softmax_matrix = F.softmax(add_res_w2,dim=1) 158 | A = 1+0.2*add_res_w2 159 | w_3=A*self.cal 160 | add_res_new = torch.mm(w_3,n_sup_x.float()) 161 | 162 | n_sup_x = add_res_new 163 | 164 | n_sup_x = n_sup_x.permute(1, 0).unsqueeze(0).unsqueeze(0) 165 | 166 | protos = n_sup_x[sup_y_g > thresh, :] # npro, nc 167 | 168 | pro_n = safe_norm(protos) # 56 256 169 | qry_n = safe_norm(qry) # 1 256 32 32 170 | 171 | dists = F.conv2d(qry_n, pro_n[..., None, None]) * 20 172 | 173 | pred_grid = torch.sum(F.softmax(dists, dim = 1) * dists, dim = 1, keepdim = True) 174 | debug_assign = dists.argmax(dim = 1).float().detach() 175 | 176 | 177 | vis_dict = {'proto_assign': debug_assign} # things to visualize 178 | 179 | if vis_sim: # return the similarity for visualization 180 | vis_dict['raw_local_sims'] = dists.clone().detach() 181 | 182 | return pred_grid, [debug_assign], vis_dict 183 | 184 | 185 | elif mode == 'gridconv+': # local and global prototypes 186 | 187 | input_size = qry.shape 188 | nch = input_size[1] 189 | nb_q = input_size[0] 190 | 191 
| sup_size = sup_x.shape[0] 192 | out_su = self.attention(sup_x,qry) 193 | 194 | n_sup_x = F.avg_pool2d(sup_x, val_wsize) if isval else self.avg_pool_op( sup_x ) # 1 256 16 16 195 | 196 | sup_nshot = sup_x.shape[0] 197 | 198 | n_sup_x = n_sup_x.view(sup_nshot, nch, -1).permute(0,2,1).unsqueeze(0) # 1 1 64 256 199 | n_sup_x = n_sup_x.reshape(1, -1, nch).unsqueeze(0) # 1 1 64 256 200 | 201 | n_sup_x = n_sup_x.permute(0, 1, 3, 2).squeeze(0).squeeze(0) 202 | 203 | w1 = torch.mm(n_sup_x.float(),n_sup_x.permute(1, 0).float()) # 256 256 204 | softmax_matrix = F.softmax(w1,dim=1) 205 | mask_w = self.wts_near(n_sup_x, 1, 1, 1) 206 | add_res_w2 = mask_w*softmax_matrix.float() 207 | A = 1+self.a*add_res_w2 208 | w_3=A*self.cal 209 | add_res_new = torch.mm(w_3,n_sup_x.float()) 210 | n_sup_x = add_res_new 211 | n_sup_x = n_sup_x.permute(1, 0).unsqueeze(0).unsqueeze(0) 212 | sup_y_g = F.avg_pool2d(sup_y, val_wsize) if isval else self.avg_pool_op(sup_y) 213 | 214 | sup_y_g = sup_y_g.view( sup_nshot, 1, -1 ).permute(1, 0, 2).view(1, -1).unsqueeze(0) 215 | 216 | protos = n_sup_x[sup_y_g > thresh, :] 217 | 218 | s_seed_ = s_init_seed[0, :, :] 219 | num_sp = max(len(torch.nonzero(s_seed_[:, 0])), len(torch.nonzero(s_seed_[:, 1]))) 220 | if (num_sp == 0): 221 | proto = torch.sum(out_su * sup_y, dim=(-1, -2)) \ 222 | / (sup_y.sum(dim=(-1, -2)) + 1e-5) # nb x C 223 | cos_sim_map_sup = F.conv2d(out_su, 224 | proto[..., None, None].repeat(1, 1, 1, 1)) 225 | cos_sim_map_sup_t = cos_sim_map_sup.view(out_su.size()[0], 1, -1) 226 | attention = cos_sim_map_sup_t.softmax(dim=-1) 227 | sp_center_t = proto.t().unsqueeze(0) 228 | out = torch.bmm(sp_center_t, attention).view(1, sup_x.size()[1], sup_x.size()[-2], sup_x.size()[-1]) 229 | out1 = out + sup_x 230 | 231 | proto = torch.sum(out1 * sup_y, dim=(-1, -2)) \ 232 | / (sup_y.sum(dim=(-1, -2)) + 1e-5) # nb x C 233 | else: 234 | if mol == 'alignLoss': 235 | proto = torch.sum(out_su * sup_y, dim=(-1, -2)) \ 236 | / (sup_y.sum(dim=(-1, -2)) + 1e-5) # nb x C 237 | cos_sim_map_sup = F.conv2d(out_su, 238 | proto[..., None, None].repeat(1, 1, 1, 1)) 239 | cos_sim_map_sup_t = cos_sim_map_sup.view(out_su.size()[0], 1, -1) 240 | attention = cos_sim_map_sup_t.softmax(dim=-1) 241 | sp_center_t = proto.t().unsqueeze(0) 242 | out = torch.bmm(sp_center_t, attention).view(1, sup_x.size()[1], sup_x.size()[-2], sup_x.size()[-1]) 243 | out1 = out + sup_x 244 | 245 | proto = torch.sum(out1 * sup_y, dim=(-1, -2)) \ 246 | / (sup_y.sum(dim=(-1, -2)) + 1e-5) # nb x C 247 | else: 248 | sp_center_list = [] 249 | sup_nshot = sup_x.shape[0] 250 | for sup_nshot in range(sup_nshot): 251 | with torch.no_grad(): 252 | s_seed_ = s_seed_[:num_sp, :] # num_sp x 2 253 | sp_init_center = sup_x[sup_nshot][:, s_seed_[:, 0], s_seed_[:, 1]] 254 | sp_init_center = torch.cat([sp_init_center, s_seed_.transpose(1, 0).float()], dim=0) 255 | sp_center = self.sp_center_iter(sup_x[sup_nshot], sup_y[sup_nshot], sp_init_center, n_iter=10) 256 | sp_center_list.append(sp_center) 257 | y1 = sp_center_list[0].shape[1] 258 | sp_center = torch.cat(sp_center_list) 259 | cos_sim_map_sup = F.conv2d(out_su, 260 | sp_center[..., None, None].repeat(1, 1, 1, 1).permute(1, 0, 2, 3)) 261 | cos_sim_map_sup_t = cos_sim_map_sup.view(sup_x.size()[0], y1, -1) 262 | attention = cos_sim_map_sup_t.softmax(dim=-1) 263 | sp_center_t = sp_center.unsqueeze(0) 264 | out = torch.bmm(sp_center_t, attention).view(1, sup_x.size()[1], sup_x.size()[-2], sup_x.size()[-1]) 265 | out1 = out + sup_x 266 | proto = torch.sum(out1 * sup_y, dim=(-1, -2)) \ 267 | 
/ (sup_y.sum(dim=(-1, -2)) + 1e-5) 268 | 269 | pro_n = safe_norm( torch.cat( [protos, proto], dim = 0 ) ) 270 | 271 | qry_n = safe_norm(qry) 272 | 273 | dists = F.conv2d(qry_n, pro_n[..., None, None]) * 20 274 | 275 | pred_grid = torch.sum(F.softmax(dists, dim = 1) * dists, dim = 1, keepdim = True) 276 | raw_local_sims = dists.detach() 277 | 278 | 279 | debug_assign = dists.argmax(dim = 1).float() 280 | 281 | vis_dict = {'proto_assign': debug_assign} 282 | if vis_sim: 283 | vis_dict['raw_local_sims'] = dists.clone().detach() 284 | 285 | return pred_grid, [debug_assign], vis_dict 286 | 287 | elif mode == 'mask++': # class-level prototype only 288 | proto = torch.sum(sup_x * sup_y, dim=(-1, -2)) \ 289 | / (sup_y.sum(dim=(-1, -2)) + 1e-5) # nb x C 290 | 291 | proto = proto.mean(dim = 0, keepdim = True) # 1 X C, take the mean of everything 292 | pred_mask = F.cosine_similarity(qry, proto[..., None, None], dim=1, eps = 1e-4) * 20.0 # [1, h, w] 293 | 294 | vis_dict = {'proto_assign': None} # things to visualize 295 | if vis_sim: 296 | vis_dict['raw_local_sims'] = pred_mask 297 | return pred_mask.unsqueeze(1), [pred_mask], vis_dict # just a placeholder. pred_mask returned as [1, way(1), h, w] 298 | 299 | elif mode == 'bg': # using local prototypes only 300 | 301 | input_size = qry.shape 302 | nch = input_size[1] 303 | 304 | sup_nshot = sup_x.shape[0] 305 | 306 | n_sup_x = F.avg_pool2d(sup_x, val_wsize) if isval else self.avg_pool_op( sup_x ) 307 | 308 | n_sup_x = n_sup_x.view(sup_nshot, nch, -1).permute(0,2,1).unsqueeze(0) # way(1),nb, hw, nc 309 | n_sup_x = n_sup_x.reshape(1, -1, nch).unsqueeze(0) 310 | 311 | sup_y_g = F.avg_pool2d(sup_y, val_wsize) if isval else self.avg_pool_op(sup_y) 312 | sup_y_g = sup_y_g.view( sup_nshot, 1, -1 ).permute(1, 0, 2).view(1, -1).unsqueeze(0) 313 | 314 | protos = n_sup_x[sup_y_g > thresh, :] # npro, nc 315 | pro_n = safe_norm(protos) 316 | qry_n = safe_norm(qry) 317 | dists = F.conv2d(qry_n, pro_n[..., None, None]) * 20 318 | 319 | pred_grid = torch.sum(F.softmax(dists, dim = 1) * dists, dim = 1, keepdim = True) 320 | debug_assign = dists.argmax(dim = 1).float().detach() 321 | 322 | vis_dict = {'proto_assign': debug_assign} # things to visualize 323 | 324 | if vis_sim: # return the similarity for visualization 325 | vis_dict['raw_local_sims'] = dists.clone().detach() 326 | 327 | return pred_grid, [debug_assign], vis_dict 328 | 329 | else: 330 | raise NotImplementedError 331 | 332 | 333 | def avg_near(self, tmp): 334 | N = tmp.shape[0] 335 | cal = torch.eye(N).cuda() 336 | for i in range(N-2): 337 | cal[i+1][i] = 1 338 | cal[i+1][i+1] = 1 339 | cal[i+1][i+2] = 1 340 | 341 | add_res = torch.mm(cal.float(),tmp.float()) 342 | 343 | diag = torch.full([N],1.0/3.0) 344 | cal = torch.diag_embed(diag).cuda() 345 | cal[0][0] = 1 346 | cal[N-1][N-1] = 1 347 | res = torch.mm(cal.float(),add_res) 348 | return res 349 | 350 | def wts_near(self, tmp, weights_1, weight_2, weight_3): 351 | 352 | N = tmp.shape[0] 353 | 354 | cal = torch.eye(N).float().cuda() 355 | for i in range(N-2): 356 | cal[i+1][i] = weights_1 357 | cal[i+1][i+1] = weight_2 358 | cal[i+1][i+2] = weight_3 359 | 360 | return cal 361 | 362 | def attention(self,sup_x,qry): 363 | reduce_dim = 256 364 | #key_conv = nn.Conv2d(in_channels=reduce_dim, out_channels=reduce_dim, kernel_size=1).cuda() 365 | #qu_conv = nn.Conv2d(in_channels=reduce_dim, out_channels=reduce_dim, kernel_size=1).cuda() 366 | #v_conv = nn.Conv2d(in_channels=reduce_dim, out_channels=reduce_dim, kernel_size=1).cuda() 367 | x_sup = 
sup_x.view(sup_x.size()[0],sup_x.size()[1], -1) 368 | x_que =qry.view(qry.size()[0], qry.size()[1], -1) 369 | x_sup_g = sup_x.view(sup_x.size()[0], sup_x.size()[1], -1) 370 | 371 | x_que_norm = torch.norm(x_que, p=2, dim=1, keepdim=True) 372 | x_sup_norm = torch.norm(x_sup, p=2, dim=1, keepdim=True) 373 | 374 | x_sup_norm = x_sup_norm.permute(0, 2, 1) 375 | x_qs_norm = torch.matmul( x_sup_norm, x_que_norm) 376 | x_sup = x_sup.permute(0, 2, 1) 377 | x_qs = torch.matmul(x_sup, x_que) 378 | x_qs = x_qs / (x_qs_norm + 1e-5) 379 | R_qs = x_qs 380 | attention = R_qs.softmax(dim=-1) 381 | 382 | proj_value = x_sup_g 383 | 384 | out = torch.bmm(proj_value, attention.permute(0, 2, 1)) 385 | out = out.view(1, sup_x.size()[1], sup_x.size()[2], sup_x.size()[3]) 386 | out = out + sup_x 387 | out_su = out # sup_x[[sup_nshot]] 388 | return out_su 389 | 390 | def sp_center_iter(self, supp_feat, supp_mask, sp_init_center, n_iter): 391 | c_xy, num_sp = sp_init_center.size() 392 | _, h, w = supp_feat.size() 393 | h_coords = torch.arange(h).view(h, 1).contiguous().repeat(1, w).unsqueeze(0).float().cuda() 394 | w_coords = torch.arange(w).repeat(h, 1).unsqueeze(0).float().cuda() 395 | supp_feat = torch.cat([supp_feat, h_coords, w_coords], 0) 396 | supp_feat_roi = supp_feat[:, (supp_mask == 1).squeeze()] 397 | 398 | num_roi = supp_feat_roi.size(1) 399 | supp_feat_roi_rep = supp_feat_roi.unsqueeze(-1).repeat(1, 1, num_sp) 400 | sp_center = torch.zeros_like(sp_init_center).cuda() # (C + xy) x num_sp 401 | 402 | for i in range(n_iter): 403 | # Compute association between each pixel in RoI and superpixel 404 | if i == 0: 405 | sp_center_rep = sp_init_center.unsqueeze(1).repeat(1, num_roi, 1) 406 | else: 407 | sp_center_rep = sp_center.unsqueeze(1).repeat(1, num_roi, 1) 408 | assert supp_feat_roi_rep.shape == sp_center_rep.shape # (C + xy) x num_roi x num_sp 409 | dist = torch.pow(supp_feat_roi_rep - sp_center_rep, 2.0) 410 | feat_dist = dist[:-2, :, :].sum(0) 411 | spat_dist = dist[-2:, :, :].sum(0) 412 | total_dist = torch.pow(feat_dist/100 + spat_dist / 100, 0.5) 413 | p2sp_assoc = torch.neg(total_dist).exp() 414 | p2sp_assoc = p2sp_assoc / (p2sp_assoc.sum(0, keepdim=True)) # num_roi x num_sp 415 | 416 | sp_center = supp_feat_roi_rep * p2sp_assoc.unsqueeze(0) # (C + xy) x num_roi x num_sp 417 | sp_center = sp_center.sum(1) 418 | 419 | return sp_center[:-2, :] 420 | 421 | -------------------------------------------------------------------------------- /models/backbone/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tntek/DSPNet/a4b71c2ab1229221584ccb2a5100bf66a7d1ef03/models/backbone/__init__.py -------------------------------------------------------------------------------- /models/backbone/torchvision_backbones.py: -------------------------------------------------------------------------------- 1 | """ 2 | Backbones supported by torchvison. 
3 | """ 4 | from collections import OrderedDict 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | import torchvision 11 | 12 | class TVDeeplabRes101Encoder(nn.Module): 13 | """ 14 | FCN-Resnet101 backbone from torchvision deeplabv3 15 | No ASPP is used as we found emperically it hurts performance 16 | """ 17 | def __init__(self, use_coco_init, aux_dim_keep = 64, use_aspp = False): 18 | super().__init__() 19 | _model = torchvision.models.segmentation.deeplabv3_resnet101(pretrained=use_coco_init, progress=True, num_classes=21, aux_loss=None) 20 | if use_coco_init: 21 | print("###### NETWORK: Using ms-coco initialization ######") 22 | else: 23 | print("###### NETWORK: Training from scratch ######") 24 | 25 | _model_list = list(_model.children()) 26 | self.aux_dim_keep = aux_dim_keep 27 | self.backbone = _model_list[0] 28 | self.localconv = nn.Conv2d(2048, 256,kernel_size = 1, stride = 1, bias = False) # reduce feature map dimension 29 | self.asppconv = nn.Conv2d(256, 256,kernel_size = 1, bias = False) 30 | 31 | _aspp = _model_list[1][0] 32 | _conv256 = _model_list[1][1] 33 | self.aspp_out = nn.Sequential(*[_aspp, _conv256] ) 34 | self.use_aspp = use_aspp 35 | 36 | def forward(self, x_in, low_level): 37 | """ 38 | Args: 39 | low_level: whether returning aggregated low-level features in FCN 40 | """ 41 | fts = self.backbone(x_in) 42 | if self.use_aspp: 43 | fts256 = self.aspp_out(fts['out']) 44 | high_level_fts = fts256 45 | else: 46 | fts2048 = fts['out'] 47 | high_level_fts = self.localconv(fts2048) 48 | 49 | if low_level: 50 | low_level_fts = fts['aux'][:, : self.aux_dim_keep] 51 | return high_level_fts, low_level_fts 52 | else: 53 | return high_level_fts 54 | 55 | 56 | 57 | 58 | 59 | -------------------------------------------------------------------------------- /models/grid_proto_fewshot.py: -------------------------------------------------------------------------------- 1 | """ 2 | ALPNet 3 | """ 4 | from collections import OrderedDict 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from .alpmodule import MultiProtoAsConv 10 | from .backbone.torchvision_backbones import TVDeeplabRes101Encoder 11 | from util.seed_init import place_seed_points_d 12 | # DEBUG 13 | from pdb import set_trace 14 | 15 | import pickle 16 | import torchvision 17 | 18 | # options for type of prototypes 19 | FG_PROT_MODE = 'gridconv+' 20 | BG_PROT_MODE = 'gridconv' 21 | # thresholds for deciding class of prototypes 22 | FG_THRESH = 0.95 23 | BG_THRESH = 0.95 24 | 25 | 26 | 27 | class FewShotSeg(nn.Module): 28 | """ 29 | ALPNet 30 | Args: 31 | in_channels: Number of input channels 32 | cfg: Model configurations 33 | """ 34 | def __init__(self, in_channels=3, pretrained_path=None, cfg=None): 35 | super(FewShotSeg, self).__init__() 36 | self.pretrained_path = pretrained_path 37 | self.config = cfg or {'align': False} 38 | self.get_encoder(in_channels) 39 | self.get_cls() 40 | 41 | def get_encoder(self, in_channels): 42 | # if self.config['which_model'] == 'deeplab_res101': 43 | if self.config['which_model'] == 'dlfcn_res101': 44 | use_coco_init = self.config['use_coco_init'] 45 | self.encoder = TVDeeplabRes101Encoder(use_coco_init) 46 | 47 | else: 48 | raise NotImplementedError(f'Backbone network {self.config["which_model"]} not implemented') 49 | 50 | if self.pretrained_path: 51 | check = torch.load(self.pretrained_path) 52 | self.load_state_dict(torch.load(self.pretrained_path),False) 53 | print(f'###### Pre-trained model 
f{self.pretrained_path} has been loaded ######') 54 | 55 | 56 | def get_cls(self): 57 | """ 58 | Obtain the similarity-based classifier 59 | """ 60 | proto_hw = self.config["proto_grid_size"] 61 | feature_hw = self.config["feature_hw"] 62 | assert self.config['cls_name'] == 'grid_proto' 63 | if self.config['cls_name'] == 'grid_proto': 64 | self.cls_unit = MultiProtoAsConv(proto_grid = [proto_hw, proto_hw], feature_hw = self.config["feature_hw"]) # when treating it as ordinary prototype 65 | else: 66 | raise NotImplementedError(f'Classifier {self.config["cls_name"]} not implemented') 67 | 68 | 69 | def forward(self, supp_imgs, fore_mask, back_mask, qry_imgs, isval, val_wsize, show_viz = False): 70 | """ 71 | Args: 72 | supp_imgs: support images 73 | way x shot x [B x 3 x H x W], list of lists of tensors 74 | fore_mask: foreground masks for support images 75 | way x shot x [B x H x W], list of lists of tensors 76 | back_mask: background masks for support images 77 | way x shot x [B x H x W], list of lists of tensors 78 | qry_imgs: query images 79 | N x [B x 3 x H x W], list of tensors 80 | show_viz: return the visualization dictionary 81 | """ 82 | mol = 'Loss' 83 | # ('Please go through this piece of code carefully') 84 | n_ways = len(supp_imgs) 85 | n_shots = len(supp_imgs[0]) 86 | n_queries = len(qry_imgs) 87 | 88 | # print("aa", n_ways, "bb", n_shots, "cc", n_queries) 89 | 90 | assert n_ways == 1, "Multi-shot has not been implemented yet" # NOTE: actual shot in support goes in batch dimension 91 | assert n_queries == 1 92 | 93 | sup_bsize = supp_imgs[0][0].shape[0] 94 | img_size = supp_imgs[0][0].shape[-2:] 95 | qry_bsize = qry_imgs[0].shape[0] 96 | 97 | assert sup_bsize == qry_bsize == 1 98 | 99 | imgs_concat = torch.cat([torch.cat(way, dim=0) for way in supp_imgs] 100 | + [torch.cat(qry_imgs, dim=0),], dim=0) 101 | 102 | img_fts = self.encoder(imgs_concat, low_level = False) 103 | fts_size = img_fts.shape[-2:] 104 | 105 | # print("aa", img_fts.shape) 106 | 107 | supp_fts = img_fts[:n_ways * n_shots * sup_bsize].view( 108 | n_ways, n_shots, sup_bsize, -1, *fts_size) # Wa x Sh x B x C x H' x W' 109 | qry_fts = img_fts[n_ways * n_shots * sup_bsize:].view( 110 | n_queries, qry_bsize, -1, *fts_size) # N x B x C x H' x W' 111 | fore_mask = torch.stack([torch.stack(way, dim=0) 112 | for way in fore_mask], dim=0) # Wa x Sh x B x H' x W' 113 | fore_mask = torch.autograd.Variable(fore_mask, requires_grad = True) 114 | back_mask = torch.stack([torch.stack(way, dim=0) 115 | for way in back_mask], dim=0) # Wa x Sh x B x H' x W' 116 | s_y = fore_mask[0][0][0] 117 | 118 | init_seed_list = [] 119 | mask = (s_y == 1).float() # H x W 120 | 121 | init_seed = place_seed_points_d(mask, down_stride=8, max_num_sp=5, 122 | avg_sp_area=100) 123 | init_seed_list.append(init_seed.unsqueeze(0)) 124 | s_init_seed = torch.cat(init_seed_list).cuda() 125 | 126 | ###### Compute loss ###### 127 | align_loss = 0 128 | outputs = [] 129 | visualizes = [] # the buffer for visualization 130 | 131 | for epi in range(1): 132 | fg_masks = [] # keep the way part 133 | res_fg_msk = torch.stack([F.interpolate(fore_mask_w, size = fts_size, mode = 'bilinear') for fore_mask_w in fore_mask], dim = 0) # [nway, ns, nb, nh', nw'] 134 | res_bg_msk = torch.stack([F.interpolate(back_mask_w, size = fts_size, mode = 'bilinear') for back_mask_w in back_mask], dim = 0) # [nway, ns, nb, nh', nw'] 135 | 136 | 137 | scores = [] 138 | assign_maps = [] 139 | bg_sim_maps = [] 140 | fg_sim_maps = [] 141 | 142 | _raw_score, _, aux_attr = 
self.cls_unit(mol,qry_fts, supp_fts, res_bg_msk,s_init_seed, mode = BG_PROT_MODE, thresh = BG_THRESH, isval = isval, val_wsize = val_wsize, vis_sim = show_viz ) 143 | scores.append(_raw_score) 144 | assign_maps.append(aux_attr['proto_assign']) 145 | if show_viz: 146 | bg_sim_maps.append(aux_attr['raw_local_sims']) 147 | 148 | for way, _msk in enumerate(res_fg_msk): #'mask++' 149 | _raw_score, _, aux_attr = self.cls_unit(mol,qry_fts, supp_fts, _msk.unsqueeze(0),s_init_seed, mode = FG_PROT_MODE if F.avg_pool2d(_msk, 4).max() >= FG_THRESH and FG_PROT_MODE != 'mask' else 'mask', thresh = FG_THRESH, isval = isval, val_wsize = val_wsize, vis_sim = show_viz ) 150 | 151 | scores.append(_raw_score) 152 | if show_viz: 153 | fg_sim_maps.append(aux_attr['raw_local_sims']) 154 | 155 | pred = torch.cat(scores, dim=1) # N x (1 + Wa) x H' x W' 156 | 157 | pred_int = F.interpolate(pred, size=img_size, mode='bilinear') 158 | 159 | outputs.append(F.interpolate(pred, size=img_size, mode='bilinear')) 160 | 161 | if self.config['align'] and self.training: 162 | align_loss_epi = self.alignLoss(qry_fts[:, epi], pred, pred_int, supp_fts[:, :, epi], 163 | fore_mask[:, :, epi], back_mask[:, :, epi]) 164 | align_loss += align_loss_epi 165 | 166 | 167 | output = torch.stack(outputs, dim=1) # N x B x (1 + Wa) x H x W 168 | output = output.view(-1, *output.shape[2:]) 169 | assign_maps = torch.stack(assign_maps, dim = 1) 170 | bg_sim_maps = torch.stack(bg_sim_maps, dim = 1) if show_viz else None 171 | fg_sim_maps = torch.stack(fg_sim_maps, dim = 1) if show_viz else None 172 | 173 | return output, align_loss / sup_bsize, [bg_sim_maps, fg_sim_maps], assign_maps 174 | 175 | 176 | # Batch was at the outer loop 177 | def alignLoss(self, qry_fts, pred, pred_int,supp_fts, fore_mask, back_mask): 178 | """ 179 | Compute the loss for the prototype alignment branch 180 | 181 | Args: 182 | qry_fts: embedding features for query images 183 | expect shape: N x C x H' x W' 184 | pred: predicted segmentation score 185 | expect shape: N x (1 + Wa) x H x W 186 | supp_fts: embedding fatures for support images 187 | expect shape: Wa x Sh x C x H' x W' 188 | fore_mask: foreground masks for support images 189 | expect shape: way x shot x H x W 190 | back_mask: background masks for support images 191 | expect shape: way x shot x H x W 192 | """ 193 | mol = 'alignLoss' 194 | n_ways, n_shots = len(fore_mask), len(fore_mask[0]) 195 | 196 | pred_mask = pred.argmax(dim=1).unsqueeze(0) #1 x N x H' x W' 197 | binary_masks = [pred_mask == i for i in range(1 + n_ways)] 198 | 199 | pred_mask_int = pred_int.argmax(dim=1).unsqueeze(0) # N x 1 x H' x W' 200 | binary_masks_int = [pred_mask_int == i for i in range(1 + n_ways)] 201 | pred_mask_int = binary_masks_int[1].float() 202 | s_y = pred_mask_int 203 | 204 | init_seed_list = [] 205 | mask = (s_y[0,:,:] == 1).float() # H x W 473 473 [0.0, 1.0] 206 | 207 | init_seed = place_seed_points_d(mask, down_stride=8, max_num_sp=5, 208 | avg_sp_area=100) 209 | init_seed_list.append(init_seed.unsqueeze(0)) 210 | s_init_seed = torch.cat(init_seed_list) 211 | 212 | # skip_ways = [i for i in range(n_ways) if binary_masks[i + 1].sum() == 0] 213 | # FIXME: fix this in future we here make a stronger assumption that a positive class must be there to avoid undersegmentation/ lazyness 214 | skip_ways = [] 215 | 216 | qry_fts = qry_fts.unsqueeze(0).unsqueeze(2) 217 | 218 | ### end of added part 219 | 220 | loss = [] 221 | for way in range(n_ways): 222 | if way in skip_ways: 223 | continue 224 | # Get the query prototypes 225 | 
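# (Editor's note) Prototype alignment, inherited from PANet: the query features and
# the predicted query masks (binary_masks) now act as pseudo supports, each real
# support image is re-segmented against them, and the cross-entropy below against
# the true support mask is the align_loss reported during training.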
for shot in range(n_shots): 226 | img_fts = supp_fts[way: way + 1, shot: shot + 1] # actual local query [way(1), nb(1, nb is now nshot), nc, h, w] 227 | 228 | qry_pred_fg_msk = F.interpolate(binary_masks[way + 1].float(), size = img_fts.shape[-2:], mode = 'bilinear') # [1 (way), n (shot), h, w] 229 | 230 | # background 231 | qry_pred_bg_msk = F.interpolate(binary_masks[0].float(), size = img_fts.shape[-2:], mode = 'bilinear') # 1, n, h ,w 232 | scores = [] 233 | 234 | _raw_score_bg, _, _ = self.cls_unit(mol=mol,qry = img_fts, sup_x = qry_fts, sup_y = qry_pred_bg_msk.unsqueeze(-3),s_init_seed=s_init_seed, mode = BG_PROT_MODE, thresh = BG_THRESH ) 235 | 236 | scores.append(_raw_score_bg) 237 | 238 | _raw_score_fg, _, _ = self.cls_unit(mol=mol,qry = img_fts, sup_x = qry_fts, sup_y = qry_pred_fg_msk.unsqueeze(-3),s_init_seed=s_init_seed, mode = FG_PROT_MODE if F.avg_pool2d(qry_pred_fg_msk, 4).max() >= FG_THRESH and FG_PROT_MODE != 'mask' else 'mask', thresh = FG_THRESH ) 239 | 240 | scores.append(_raw_score_fg) 241 | 242 | supp_pred = torch.cat(scores, dim=1) # N x (1 + Wa) x H' x W' 243 | supp_pred = F.interpolate(supp_pred, size=fore_mask.shape[-2:], mode='bilinear') 244 | 245 | supp_label = torch.full_like(fore_mask[way, shot], 255, 246 | device=img_fts.device).long() 247 | supp_label[fore_mask[way, shot] == 1] = 1 248 | supp_label[back_mask[way, shot] == 1] = 0 249 | # Compute Loss 250 | loss.append( F.cross_entropy( 251 | supp_pred, supp_label[None, ...], ignore_index=255) / n_shots / n_ways) 252 | 253 | return torch.sum( torch.stack(loss)) 254 | -------------------------------------------------------------------------------- /pigeon.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tntek/DSPNet/a4b71c2ab1229221584ccb2a5100bf66a7d1ef03/pigeon.jpg -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | json5==0.8.5 2 | jupyter==1.0.0 3 | nibabel==2.5.1 4 | numpy==1.15.1 5 | opencv-python==4.2.0.32 6 | Pillow>=8.1.1 7 | sacred==0.7.5 8 | scikit-image==0.14.0 9 | SimpleITK==1.2.3 10 | torch==1.3.0 11 | torchvision==0.4.1 12 | tqdm==4.32.2 13 | dcm2nii 14 | -------------------------------------------------------------------------------- /train_ssl_abdominal_ct.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # train a model to segment abdominal CT 3 | GPUID1=0 4 | export CUDA_VISIBLE_DEVICES=$GPUID1 5 | 6 | ####### Shared configs ###### 7 | PROTO_GRID=8 # using 32 / 8 = 4, 4-by-4 prototype pooling window during training 8 | CPT="myexp" 9 | DATASET='SABS_Superpix' 10 | NWORKER=4 11 | 12 | ALL_EV=( 0) # 5-fold cross validation (0, 1, 2, 3, 4) 13 | ALL_SCALE=( "MIDDLE") # config of pseudolabels 14 | 15 | ### Use L/R kidney as testing classes 16 | LABEL_SETS=0 17 | EXCLU='[2,3]' # setting 2: excluding kidneies in training set to test generalization capability even though they are unlabeled. Use [] for setting 1 by Roy et al. 
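# (Editor's note) In the SABS label map (see dataloaders/dataset_utils.py), labels 2 and 3
# are the right and left kidney, matching the L/R kidney test split selected above.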
18 | 19 | ### Use Liver and spleen as testing classes 20 | # LABEL_SETS=1 21 | # EXCLU='[1,6]' 22 | 23 | ###### Training configs ###### 24 | NSTEP=100100 25 | DECAY=0.95 26 | 27 | MAX_ITER=1000 # defines the size of an epoch 28 | SNAPSHOT_INTERVAL=25000 # interval for saving snapshot 29 | SEED='1234' 30 | 31 | ###### Validation configs ###### 32 | SUPP_ID='[6]' # using the additionally loaded scan as support 33 | 34 | echo =================================== 35 | 36 | for EVAL_FOLD in "${ALL_EV[@]}" 37 | do 38 | for SUPERPIX_SCALE in "${ALL_SCALE[@]}" 39 | do 40 | PREFIX="train_${DATASET}_lbgroup${LABEL_SETS}_scale_${SUPERPIX_SCALE}_vfold${EVAL_FOLD}" 41 | echo $PREFIX 42 | LOGDIR="./exps/${CPT}_${SUPERPIX_SCALE}_${LABEL_SETS}" 43 | 44 | if [ ! -d $LOGDIR ] 45 | then 46 | mkdir $LOGDIR 47 | fi 48 | 49 | ~/anaconda3/envs/syj_alpnet/bin/python training.py with \ 50 | 'modelname=dlfcn_res101' \ 51 | 'usealign=True' \ 52 | 'optim_type=sgd' \ 53 | num_workers=$NWORKER \ 54 | scan_per_load=-1 \ 55 | label_sets=$LABEL_SETS \ 56 | 'use_wce=True' \ 57 | exp_prefix=$PREFIX \ 58 | 'clsname=grid_proto' \ 59 | n_steps=$NSTEP \ 60 | exclude_cls_list=$EXCLU \ 61 | eval_fold=$EVAL_FOLD \ 62 | dataset=$DATASET \ 63 | proto_grid_size=$PROTO_GRID \ 64 | max_iters_per_load=$MAX_ITER \ 65 | min_fg_data=1 seed=$SEED \ 66 | save_snapshot_every=$SNAPSHOT_INTERVAL \ 67 | superpix_scale=$SUPERPIX_SCALE \ 68 | lr_step_gamma=$DECAY \ 69 | path.log_dir=$LOGDIR \ 70 | support_idx=$SUPP_ID 71 | done 72 | done 73 | -------------------------------------------------------------------------------- /training.py: -------------------------------------------------------------------------------- 1 | """ 2 | Training the model 3 | Extended from original implementation of PANet by Wang et al. 
4 | """ 5 | import os 6 | import shutil 7 | import torch 8 | import torch.nn as nn 9 | import torch.optim 10 | from torch.utils.data import DataLoader 11 | from torch.optim.lr_scheduler import MultiStepLR 12 | import torch.backends.cudnn as cudnn 13 | import numpy as np 14 | 15 | from models.grid_proto_fewshot import FewShotSeg 16 | from dataloaders.GenericSuperDatasetv2 import SuperpixelDataset 17 | from dataloaders.dataset_utils import DATASET_INFO 18 | import dataloaders.augutils as myaug 19 | 20 | from util.utils import set_seed, t2n, to01, compose_wt_simple 21 | from util.metric import Metric 22 | 23 | from config_ssl_upload import ex 24 | import tqdm 25 | 26 | # config pre-trained model caching path 27 | os.environ['TORCH_HOME'] = "./pretrained_model" 28 | 29 | @ex.automain 30 | def main(_run, _config, _log): 31 | 32 | if _run.observers: 33 | os.makedirs(f'{_run.observers[0].dir}/snapshots', exist_ok=True) 34 | for source_file, _ in _run.experiment_info['sources']: 35 | os.makedirs(os.path.dirname(f'{_run.observers[0].dir}/source/{source_file}'), 36 | exist_ok=True) 37 | _run.observers[0].save_file(source_file, f'source/{source_file}') 38 | shutil.rmtree(f'{_run.observers[0].basedir}/_sources') 39 | 40 | 41 | set_seed(_config['seed']) 42 | cudnn.enabled = True 43 | cudnn.benchmark = True 44 | torch.cuda.set_device(device=_config['gpu_id']) 45 | torch.set_num_threads(1) 46 | 47 | _log.info('###### Create model ######') 48 | model = FewShotSeg(pretrained_path=None, cfg=_config['model']) 49 | 50 | model = model.cuda() 51 | model.train() 52 | 53 | _log.info('###### Load data ######') 54 | ### Training set 55 | data_name = _config['dataset'] 56 | if data_name == 'SABS_Superpix': 57 | baseset_name = 'SABS' 58 | elif data_name == 'C0_Superpix': 59 | raise NotImplementedError 60 | baseset_name = 'C0' 61 | elif data_name == 'CHAOST2_Superpix': 62 | baseset_name = 'CHAOST2' 63 | else: 64 | raise ValueError(f'Dataset: {data_name} not found') 65 | 66 | ### Transforms for data augmentation 67 | tr_transforms = myaug.transform_with_label({'aug': myaug.augs[_config['which_aug']]}) 68 | assert _config['scan_per_load'] < 0 # by default we load the entire dataset directly 69 | 70 | test_labels = DATASET_INFO[baseset_name]['LABEL_GROUP']['pa_all'] - DATASET_INFO[baseset_name]['LABEL_GROUP'][_config["label_sets"]] 71 | _log.info(f'###### Labels excluded in training : {[lb for lb in _config["exclude_cls_list"]]} ######') 72 | _log.info(f'###### Unseen labels evaluated in testing: {[lb for lb in test_labels]} ######') 73 | 74 | 75 | tr_parent = SuperpixelDataset( # base dataset 76 | which_dataset = baseset_name, 77 | base_dir = _config['path'][data_name]['data_dir'], 78 | idx_split = _config['eval_fold'], 79 | mode = 'train', 80 | min_fg = str(_config["min_fg_data"]), # dummy entry for superpixel dataset 81 | transforms = tr_transforms, 82 | nsup = _config['task']['n_shots'], 83 | scan_per_load = _config['scan_per_load'], 84 | exclude_list = _config["exclude_cls_list"], 85 | superpix_scale = _config["superpix_scale"], 86 | fix_length = _config["max_iters_per_load"] if (data_name == 'C0_Superpix') or (data_name == 'CHAOST2_Superpix') else None 87 | ) 88 | 89 | 90 | ### dataloaders 91 | trainloader = DataLoader( 92 | tr_parent, 93 | batch_size=_config['batch_size'], 94 | shuffle=True, 95 | num_workers=_config['num_workers'], 96 | pin_memory=True, 97 | drop_last=True 98 | ) 99 | param_group = [] 100 | for k,v in model.named_parameters(): 101 | if 'cls_unit' in k: 102 | param_group 
99 |     param_group = []
100 |     for k, v in model.named_parameters():
101 |         if 'cls_unit' in k:  # the classification head trains with a much smaller learning rate
102 |             param_group += [{'params': v, 'lr': _config['optim']['lr'] * 0.0001, 'momentum': _config['optim']['momentum'], 'weight_decay': _config['optim']['weight_decay']}]
103 |         else:
104 |             param_group += [{'params': v, 'lr': _config['optim']['lr'], 'momentum': _config['optim']['momentum'], 'weight_decay': _config['optim']['weight_decay']}]
105 |     _log.info('###### Set optimizer ######')
106 |     if _config['optim_type'] == 'sgd':
107 |         optimizer = torch.optim.SGD(param_group)
108 |     else:
109 |         raise NotImplementedError
110 |
111 |
112 |     scheduler = MultiStepLR(optimizer, milestones=_config['lr_milestones'], gamma=_config['lr_step_gamma'])
113 |
114 |     my_weight = compose_wt_simple(_config["use_wce"], data_name)
115 |
116 |     criterion = nn.CrossEntropyLoss(ignore_index=_config['ignore_label'], weight=my_weight)
117 |
118 |
119 |     i_iter = 0  # total number of iterations
120 |     n_sub_epochs = _config['n_steps'] // _config['max_iters_per_load']  # number of times the dataset is reloaded
121 |
122 |
123 |     log_loss = {'loss': 0, 'align_loss': 0}
124 |
125 |     _log.info('###### Training ######')
126 |     for sub_epoch in range(n_sub_epochs):
127 |         _log.info(f'###### This is epoch {sub_epoch} of {n_sub_epochs} epochs ######')
128 |         for _, sample_batched in enumerate(trainloader):
129 |             # Prepare input
130 |             i_iter += 1
131 |             # add writers
132 |             support_images = [[shot.cuda() for shot in way]
133 |                               for way in sample_batched['support_images']]
134 |
135 |             support_fg_mask = [[shot['fg_mask'].float().cuda() for shot in way]
136 |                                for way in sample_batched['support_mask']]
137 |
138 |             support_bg_mask = [[shot['bg_mask'].float().cuda() for shot in way]
139 |                                for way in sample_batched['support_mask']]
140 |
141 |             query_images = [query_image.cuda()
142 |                             for query_image in sample_batched['query_images']]
143 |
144 |             query_labels = torch.cat(
145 |                 [query_label.long().cuda() for query_label in sample_batched['query_labels']], dim=0)
146 |
147 |
148 |             optimizer.zero_grad()
149 |             # FIXME: in the model definition, filter out the failure case where the pseudolabel falls outside of the image or is too small to compute a prototype
150 |             try:
151 |                 query_pred, align_loss, debug_vis, assign_mats = model(support_images, support_fg_mask, support_bg_mask, query_images, isval=False, val_wsize=None)
152 |             except Exception:
153 |
154 |                 print('Faulty batch detected, skipping')
155 |                 continue
156 |
157 |             query_loss = criterion(query_pred, query_labels)
158 |
159 |             loss = query_loss + align_loss
160 |
161 |
162 |             loss.backward()
163 |             optimizer.step()
164 |             scheduler.step()
165 |             # Log loss
166 |             query_loss = query_loss.detach().data.cpu().numpy()
167 |             align_loss = align_loss.detach().data.cpu().numpy() if align_loss != 0 else 0
168 |
169 |             _run.log_scalar('loss', query_loss)
170 |             _run.log_scalar('align_loss', align_loss)
171 |             log_loss['loss'] += query_loss
172 |             log_loss['align_loss'] += align_loss
173 |
174 |             # print loss and take snapshots
175 |             if (i_iter + 1) % _config['print_interval'] == 0:
176 |
177 |                 loss = log_loss['loss'] / _config['print_interval']
178 |                 align_loss = log_loss['align_loss'] / _config['print_interval']
179 |
180 |                 log_loss['loss'] = 0
181 |                 log_loss['align_loss'] = 0
182 |
183 |                 print(f'step {i_iter+1}: loss: {loss}, align_loss: {align_loss}')
184 |
185 |             if (i_iter + 1) % _config['save_snapshot_every'] == 0:
186 |                 _log.info('###### Taking snapshot ######')
187 |                 torch.save(model.state_dict(),
188 |                            os.path.join(f'{_run.observers[0].dir}/snapshots', f'{i_iter + 1}.pth'))
189 |             if data_name == 'C0_Superpix' or data_name == 'CHAOST2_Superpix':
190 |                 if (i_iter + 1) % _config['max_iters_per_load'] == 0:
191 |                     _log.info('###### Reloading dataset ######')
192 |                     trainloader.dataset.reload_buffer()
193 |                     print(f'###### New dataset with {len(trainloader.dataset)} slices has been loaded ######')
194 |
195 |             if (i_iter - 2) > _config['n_steps']:
196 |                 return 1  # finish up
197 |
198 |
--------------------------------------------------------------------------------
/util/__init__.py:
--------------------------------------------------------------------------------
1 | # from utils import *
--------------------------------------------------------------------------------
/util/metric.py:
--------------------------------------------------------------------------------
1 | """
2 | Metrics for computing evaluation results
3 | Modified from the vanilla PANet code by Wang et al.
4 | """
5 |
6 | import numpy as np
7 |
8 | class Metric(object):
9 |     """
10 |     Compute evaluation result
11 |
12 |     Args:
13 |         max_label:
14 |             max label index in the data (0 denoting background)
15 |         n_scans:
16 |             number of test scans
17 |     """
18 |     def __init__(self, max_label=20, n_scans=None):
19 |         self.labels = list(range(max_label + 1))  # all class labels, e.g. [0, 1, 2, 3, 4]
20 |         self.n_scans = 1 if n_scans is None else n_scans
21 |
22 |         # list of lists of arrays; each array stores the TP/FP/FN statistics of one testing sample
23 |         self.tp_lst = [[] for _ in range(self.n_scans)]  # one list per scan, e.g. [[], [], [], []]
24 |         self.fp_lst = [[] for _ in range(self.n_scans)]
25 |         self.fn_lst = [[] for _ in range(self.n_scans)]
26 |
27 |     def reset(self):
28 |         """
29 |         Reset accumulated evaluation.
30 |         """
31 |         # assert self.n_scans == 1, 'Should not reset accumulated result when we are not doing one-time batch-wise validation'
32 |         del self.tp_lst, self.fp_lst, self.fn_lst
33 |         self.tp_lst = [[] for _ in range(self.n_scans)]
34 |         self.fp_lst = [[] for _ in range(self.n_scans)]
35 |         self.fn_lst = [[] for _ in range(self.n_scans)]
36 |
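To make the per-class bookkeeping in `record()` below concrete, here is a minimal, self-contained sketch (hypothetical toy masks, not data from this repository) of how TP/FP/FN are counted for one label with the same set-based logic:

```
import numpy as np

# Toy 3x3 prediction and ground truth for a single foreground label (1).
pred   = np.array([[0, 1, 1],
                   [0, 1, 0],
                   [0, 0, 0]])
target = np.array([[0, 0, 1],
                   [0, 1, 1],
                   [0, 0, 0]])

label = 1
pred_idx   = set(zip(*np.where(pred == label)))    # pixels predicted as the label
target_idx = set(zip(*np.where(target == label)))  # pixels labelled in ground truth

tp = len(pred_idx & target_idx)   # 2: (0,2) and (1,1)
fp = len(pred_idx - target_idx)   # 1: (0,1)
fn = len(target_idx - pred_idx)   # 1: (1,2)
print(tp, fp, fn)  # 2 1 1
```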
37 |     def record(self, pred, target, labels=None, n_scan=None):
38 |         """
39 |         Record the evaluation result for each sample and each class label, including:
40 |         True Positive, False Positive, False Negative
41 |
42 |         Args:
43 |             pred:
44 |                 predicted mask array, expected shape is H x W
45 |             target:
46 |                 target mask array, expected shape is H x W
47 |             labels:
48 |                 only count specific labels, used when all possible labels are known in advance
49 |         """
50 |         assert pred.shape == target.shape
51 |
52 |         if self.n_scans == 1:
53 |             n_scan = 0
54 |
55 |         # arrays saving the TP/FP/FN statistic for each class (plus BG); NaN marks classes not evaluated
56 |         tp_arr = np.full(len(self.labels), np.nan)
57 |         fp_arr = np.full(len(self.labels), np.nan)
58 |         fn_arr = np.full(len(self.labels), np.nan)
59 |
60 |         if labels is None:
61 |             labels = self.labels
62 |         else:
63 |             labels = [0,] + labels  # always include the background class, e.g. [0, 2]
64 |
65 |         for j, label in enumerate(labels):
66 |             # Get the location of the pixels that are predicted as class j (ignoring the 255 void label)
67 |             idx = np.where(np.logical_and(pred == j, target != 255))
68 |             pred_idx_j = set(zip(idx[0].tolist(), idx[1].tolist()))
69 |             # Get the location of the pixels that belong to class j in the ground truth
70 |             idx = np.where(target == j)
71 |             target_idx_j = set(zip(idx[0].tolist(), idx[1].tolist()))
72 |
73 |             # NOTE: statistics are recorded even when the ground truth does not contain this class;
74 |             # guarding with `if target_idx_j:` would silently drop false positives on empty slices
75 |             tp_arr[label] = len(set.intersection(pred_idx_j, target_idx_j))  # intersection
76 |             fp_arr[label] = len(pred_idx_j - target_idx_j)
77 |             fn_arr[label] = len(target_idx_j - pred_idx_j)
78 |
79 |         self.tp_lst[n_scan].append(tp_arr)
80 |         self.fp_lst[n_scan].append(fp_arr)
81 |         self.fn_lst[n_scan].append(fn_arr)
82 |
83 |     def get_mIoU(self, labels=None, n_scan=None):
84 |         """
85 |         Compute mean IoU
86 |
87 |         Args:
88 |             labels:
89 |                 specify a subset of labels to compute mean IoU, default is using all classes
90 |         """
91 |         if labels is None:
92 |             labels = self.labels
93 |         # Sum TP, FP, FN statistic of all samples
94 |         if n_scan is None:
95 |             tp_sum = [np.nansum(np.vstack(self.tp_lst[_scan]), axis=0).take(labels)
96 |                       for _scan in range(self.n_scans)]
97 |             fp_sum = [np.nansum(np.vstack(self.fp_lst[_scan]), axis=0).take(labels)
98 |                       for _scan in range(self.n_scans)]
99 |             fn_sum = [np.nansum(np.vstack(self.fn_lst[_scan]), axis=0).take(labels)
100 |                       for _scan in range(self.n_scans)]
101 |
102 |             # Compute IoU class-wise
103 |             # Average across n_scans, then average over classes
104 |             mIoU_class = np.vstack([tp_sum[_scan] / (tp_sum[_scan] + fp_sum[_scan] + fn_sum[_scan])
105 |                                     for _scan in range(self.n_scans)])
106 |             mIoU = mIoU_class.mean(axis=1)
107 |
108 |             return (mIoU_class.mean(axis=0), mIoU_class.std(axis=0),
109 |                     mIoU.mean(axis=0), mIoU.std(axis=0))
110 |         else:
111 |             tp_sum = np.nansum(np.vstack(self.tp_lst[n_scan]), axis=0).take(labels)
112 |             fp_sum = np.nansum(np.vstack(self.fp_lst[n_scan]), axis=0).take(labels)
113 |             fn_sum = np.nansum(np.vstack(self.fn_lst[n_scan]), axis=0).take(labels)
114 |
115 |             # Compute IoU class-wise and average over classes
116 |             mIoU_class = tp_sum / (tp_sum + fp_sum + fn_sum)
117 |             mIoU = mIoU_class.mean()
118 |
119 |             return mIoU_class, mIoU
120 |
121 |     def get_mDice(self, labels=None, n_scan=None, give_raw = False):
122 |         """
123 |         Compute mean Dice score (at the 3D scan level)
124 |
125 |         Args:
126 |             labels:
127 |                 specify a subset of labels to compute mean Dice, default is using all classes
128 |         """
129 |         # NOTE: unverified
130 |         if labels is None:
131 |             labels = self.labels
132 |         # Sum TP, FP, FN statistic of all samples
133 |         if n_scan is None:
134 |             tp_sum = [np.nansum(np.vstack(self.tp_lst[_scan]), axis=0).take(labels)
135 |                       for _scan in range(self.n_scans)]
136 |             fp_sum = [np.nansum(np.vstack(self.fp_lst[_scan]), axis=0).take(labels)
137 |                       for _scan in range(self.n_scans)]
138 |             fn_sum = [np.nansum(np.vstack(self.fn_lst[_scan]), axis=0).take(labels)
139 |                       for _scan in range(self.n_scans)]
140 |
141 |             # Average across n_scans, then average over classes
142 |             mDice_class = np.vstack([2 * tp_sum[_scan] / (2 * tp_sum[_scan] + fp_sum[_scan] + fn_sum[_scan])
143 |                                      for _scan in range(self.n_scans)])
144 |             mDice = mDice_class.mean(axis=1)
145 |             print(mDice_class)
146 |             if not give_raw:
147 |                 return (mDice_class.mean(axis=0), mDice_class.std(axis=0),
148 |                         mDice.mean(axis=0), mDice.std(axis=0))
149 |             else:
150 |                 return (mDice_class.mean(axis=0), mDice_class.std(axis=0),
151 |                         mDice.mean(axis=0), mDice.std(axis=0), mDice_class)
152 |
153 |         else:
154 |             tp_sum = np.nansum(np.vstack(self.tp_lst[n_scan]), axis=0).take(labels)
155 |             fp_sum = np.nansum(np.vstack(self.fp_lst[n_scan]), axis=0).take(labels)
156 |             fn_sum = np.nansum(np.vstack(self.fn_lst[n_scan]), axis=0).take(labels)
157 |
158 |             # Compute Dice class-wise and average over classes
159 |             mDice_class = 2 * tp_sum / (2 * tp_sum + fp_sum + fn_sum)
160 |             mDice = mDice_class.mean()
161 |
162 |             return mDice_class, mDice
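`get_mIoU()` and `get_mDice()` combine the same TP/FP/FN sums in two different ways; a quick numeric sketch (toy counts, not repository data) shows the well-known relationship Dice = 2·IoU / (1 + IoU):

```
tp, fp, fn = 80, 10, 10

iou  = tp / (tp + fp + fn)          # 0.8
dice = 2 * tp / (2 * tp + fp + fn)  # ~0.889
assert abs(dice - 2 * iou / (1 + iou)) < 1e-12
print(iou, dice)
```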
163 |
164 |     def get_mPrecRecall(self, labels=None, n_scan=None, give_raw=False):
165 |         """
166 |         Compute precision and recall
167 |
168 |         Args:
169 |             labels:
170 |                 specify a subset of labels to compute precision and recall, default is using all classes
171 |         """
172 |         # NOTE: unverified
173 |         if labels is None:
174 |             labels = self.labels
175 |         # Sum TP, FP, FN statistic of all samples
176 |         if n_scan is None:
177 |             tp_sum = [np.nansum(np.vstack(self.tp_lst[_scan]), axis=0).take(labels)
178 |                       for _scan in range(self.n_scans)]
179 |             fp_sum = [np.nansum(np.vstack(self.fp_lst[_scan]), axis=0).take(labels)
180 |                       for _scan in range(self.n_scans)]
181 |             fn_sum = [np.nansum(np.vstack(self.fn_lst[_scan]), axis=0).take(labels)
182 |                       for _scan in range(self.n_scans)]
183 |
184 |             # Compute precision and recall class-wise
185 |             # Average across n_scans, then average over classes
186 |             mPrec_class = np.vstack([tp_sum[_scan] / (tp_sum[_scan] + fp_sum[_scan])
187 |                                      for _scan in range(self.n_scans)])
188 |
189 |             mRec_class = np.vstack([tp_sum[_scan] / (tp_sum[_scan] + fn_sum[_scan])
190 |                                     for _scan in range(self.n_scans)])
191 |
192 |             mPrec = mPrec_class.mean(axis=1)
193 |             mRec = mRec_class.mean(axis=1)
194 |             if not give_raw:
195 |                 return (mPrec_class.mean(axis=0), mPrec_class.std(axis=0), mPrec.mean(axis=0), mPrec.std(axis=0), mRec_class.mean(axis=0), mRec_class.std(axis=0), mRec.mean(axis=0), mRec.std(axis=0))
196 |             else:
197 |                 return (mPrec_class.mean(axis=0), mPrec_class.std(axis=0), mPrec.mean(axis=0), mPrec.std(axis=0), mRec_class.mean(axis=0), mRec_class.std(axis=0), mRec.mean(axis=0), mRec.std(axis=0), mPrec_class, mRec_class)
198 |
199 |
200 |         else:
201 |             tp_sum = np.nansum(np.vstack(self.tp_lst[n_scan]), axis=0).take(labels)
202 |             fp_sum = np.nansum(np.vstack(self.fp_lst[n_scan]), axis=0).take(labels)
203 |             fn_sum = np.nansum(np.vstack(self.fn_lst[n_scan]), axis=0).take(labels)
204 |
205 |             # Compute precision and recall class-wise and average over classes
206 |             mPrec_class = tp_sum / (tp_sum + fp_sum)
207 |             mPrec = mPrec_class.mean()
208 |
209 |             mRec_class = tp_sum / (tp_sum + fn_sum)
210 |             mRec = mRec_class.mean()
211 |
212 |             return mPrec_class, mPrec, mRec_class, mRec
213 |
214 |     def get_mIoU_binary(self, n_scan=None):
215 |         """
216 |         Compute mean IoU for the binary scenario
217 |         (sum all foreground classes as one class)
218 |         """
219 |         # Sum TP, FP, FN statistic of all samples
220 |         if n_scan is None:
221 |             tp_sum = [np.nansum(np.vstack(self.tp_lst[_scan]), axis=0)
222 |                       for _scan in range(self.n_scans)]
223 |             fp_sum = [np.nansum(np.vstack(self.fp_lst[_scan]), axis=0)
224 |                       for _scan in range(self.n_scans)]
225 |             fn_sum = [np.nansum(np.vstack(self.fn_lst[_scan]), axis=0)
226 |                       for _scan in range(self.n_scans)]
227 |
228 |             # Sum over all foreground classes
229 |             tp_sum = [np.c_[tp_sum[_scan][0], np.nansum(tp_sum[_scan][1:])]
230 |                       for _scan in range(self.n_scans)]
231 |             fp_sum = [np.c_[fp_sum[_scan][0], np.nansum(fp_sum[_scan][1:])]
232 |                       for _scan in range(self.n_scans)]
233 |             fn_sum = [np.c_[fn_sum[_scan][0], np.nansum(fn_sum[_scan][1:])]
234 |                       for _scan in range(self.n_scans)]
235 |
236 |             # Compute IoU class-wise and average across classes
237 |             mIoU_class = np.vstack([tp_sum[_scan] / (tp_sum[_scan] + fp_sum[_scan] + fn_sum[_scan])
238 |                                     for _scan in range(self.n_scans)])
239 |             mIoU = mIoU_class.mean(axis=1)
240 |
241 |             return (mIoU_class.mean(axis=0), mIoU_class.std(axis=0),
242 |                     mIoU.mean(axis=0), mIoU.std(axis=0))
243 |         else:
244 |             tp_sum = np.nansum(np.vstack(self.tp_lst[n_scan]), axis=0)
245 |             fp_sum = np.nansum(np.vstack(self.fp_lst[n_scan]), axis=0)
246 |             fn_sum = np.nansum(np.vstack(self.fn_lst[n_scan]), axis=0)
247 |
248 |             # Sum over all foreground classes
249 |             tp_sum = np.c_[tp_sum[0], np.nansum(tp_sum[1:])]
250 |             fp_sum = np.c_[fp_sum[0], np.nansum(fp_sum[1:])]
251 |             fn_sum = np.c_[fn_sum[0], np.nansum(fn_sum[1:])]
252 |
253 |             mIoU_class = tp_sum / (tp_sum + fp_sum + fn_sum)
254 |             mIoU = mIoU_class.mean()
255 |
256 |             return mIoU_class, mIoU
257 |
--------------------------------------------------------------------------------
/util/seed_init.py:
--------------------------------------------------------------------------------
1 | from scipy.ndimage import distance_transform_edt
2 | from scipy.ndimage import gaussian_filter
3 |
4 | import numpy as np
5 | import cv2
6 | import torch
7 |
8 |
9 | def place_seed_points_d(mask, down_stride=8, max_num_sp=5, avg_sp_area=100):
10 |     '''
11 |     :param mask: the RoI region to do clustering, torch tensor: H x W
12 |     :param down_stride: downsampling stride for the RoI region
13 |     :param max_num_sp: the maximum number of superpixel seeds
14 |     :return: segments: the coordinates of the initial seeds, max_num_sp x 2
15 |     '''
16 |
17 |     segments_x = np.zeros(max_num_sp, dtype=np.int64)
18 |     segments_y = np.zeros(max_num_sp, dtype=np.int64)
19 |
20 |     m_np = mask.cpu().numpy()
21 |     down_h = int((m_np.shape[0] - 1) / down_stride + 1)
22 |     down_w = int((m_np.shape[1] - 1) / down_stride + 1)
23 |     down_size = (down_h, down_w)
24 |     m_np_down = cv2.resize(m_np, dsize=down_size, interpolation=cv2.INTER_NEAREST)  # NOTE: cv2 expects dsize as (width, height); safe here since slices are square
25 |
26 |     nz = np.nonzero(m_np_down)
27 |     if len(nz[0]) != 0:
28 |
29 |         p = [np.min(nz[0]), np.min(nz[1])]  # top-left corner of the RoI bounding box
30 |         pend = [np.max(nz[0]), np.max(nz[1])]  # bottom-right corner of the RoI bounding box
31 |
32 |         # cropping to bounding box around RoI
33 |         m_np_roi = np.copy(m_np_down)[p[0]:pend[0] + 1, p[1]:pend[1] + 1]
34 |
35 |         # num_sp is adaptive: roughly one seed per avg_sp_area pixels of support mask, capped at max_num_sp
36 |         mask_area = (m_np_roi == 1).sum()
37 |         num_sp = int(min((np.array(mask_area) / avg_sp_area).round(), max_num_sp))
38 |
39 |     else:
40 |         num_sp = 0
41 |
42 |     if (num_sp != 0) and (num_sp != 1):
43 |         for i in range(num_sp):
44 |
45 |             # seeds are placed as far as possible from every other seed and from the RoI boundary
46 |
47 |             # STEP 1: conduct the distance transform and choose the maximum point
48 |             dtrans = distance_transform_edt(m_np_roi)
49 |             dtrans = gaussian_filter(dtrans, sigma=0.1)
50 |
51 |             coords1 = np.nonzero(dtrans == np.max(dtrans))
52 |             segments_x[i] = coords1[0][0]
53 |             segments_y[i] = coords1[1][0]
54 |
55 |             # STEP 2: set the chosen point to False and repeat STEP 1
56 |             m_np_roi[segments_x[i], segments_y[i]] = False
57 |             segments_x[i] += p[0]
58 |             segments_y[i] += p[1]
59 |
60 |     segments = np.concatenate([segments_x[..., np.newaxis], segments_y[..., np.newaxis]], axis=1)  # max_num_sp x 2
61 |     segments = torch.from_numpy(segments)
62 |
63 |     return segments
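The distance-transform trick in `place_seed_points_d` generalizes: every new seed lands at the point farthest from the RoI boundary and from previously placed seeds, because each placed seed punches a zero into the mask before the next transform. A minimal standalone sketch of one placement step (a hypothetical toy mask, not repository data):

```
import numpy as np
from scipy.ndimage import distance_transform_edt

# Toy binary RoI: a 7x7 square with one seed already placed at its centre.
roi = np.ones((7, 7), dtype=bool)
roi[3, 3] = False  # an existing seed acts like a hole in the mask

# Distance to the nearest zero pixel; zero-padding makes the RoI border count too.
dist = distance_transform_edt(np.pad(roi, 1))[1:-1, 1:-1]

# The next seed goes to the maximum of the distance map.
ys, xs = np.nonzero(dist == dist.max())
print('next seed at', (ys[0], xs[0]))
```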
--------------------------------------------------------------------------------
/util/utils.py:
--------------------------------------------------------------------------------
1 | """Util functions
2 | Extended from the original PANet code
3 | TODO: move part of dataset configurations to data_utils
4 | """
5 | import random
6 | import torch
7 | import numpy as np
8 | import operator
9 |
10 | def set_seed(seed):
11 |     """
12 |     Set the random seed
13 |     """
14 |     random.seed(seed)
15 |     torch.manual_seed(seed)
16 |     torch.cuda.manual_seed_all(seed)
17 |
18 | CLASS_LABELS = {
19 |     'SABS': {
20 |         'pa_all': set([1, 2, 3, 6]),
21 |         0: set([1, 6]),  # upper abdomen: spleen + liver for training, kidneys for testing
22 |         1: set([2, 3]),  # lower abdomen
23 |     },
24 |     'C0': {
25 |         'pa_all': set(range(1, 4)),
26 |         0: set([2, 3]),
27 |         1: set([1, 3]),
28 |         2: set([1, 2]),
29 |     },
30 |     'CHAOST2': {
31 |         'pa_all': set(range(1, 5)),
32 |         0: set([1, 4]),  # upper abdomen, leaving kidneys as testing classes
33 |         1: set([2, 3]),  # lower abdomen
34 |     },
35 | }
36 |
37 | def get_bbox(fg_mask, inst_mask):
38 |     """
39 |     Get the ground truth bounding boxes
40 |     """
41 |
42 |     fg_bbox = torch.zeros_like(fg_mask, device=fg_mask.device)
43 |     bg_bbox = torch.ones_like(fg_mask, device=fg_mask.device)
44 |
45 |     inst_mask[fg_mask == 0] = 0
46 |     area = torch.bincount(inst_mask.view(-1))
47 |     cls_id = area[1:].argmax() + 1
48 |     cls_ids = np.unique(inst_mask)[1:]
49 |
50 |     mask_idx = np.where(inst_mask[0] == cls_id)
51 |     y_min = mask_idx[0].min()
52 |     y_max = mask_idx[0].max()
53 |     x_min = mask_idx[1].min()
54 |     x_max = mask_idx[1].max()
55 |     fg_bbox[0, y_min:y_max+1, x_min:x_max+1] = 1
56 |
57 |     for i in cls_ids:
58 |         mask_idx = np.where(inst_mask[0] == i)
59 |         y_min = max(mask_idx[0].min(), 0)
60 |         y_max = min(mask_idx[0].max(), fg_mask.shape[1] - 1)
61 |         x_min = max(mask_idx[1].min(), 0)
62 |         x_max = min(mask_idx[1].max(), fg_mask.shape[2] - 1)
63 |         bg_bbox[0, y_min:y_max+1, x_min:x_max+1] = 0
64 |     return fg_bbox, bg_bbox
65 |
66 | def t2n(img_t):
67 |     """
68 |     torch tensor to numpy array, regardless of whether the tensor is on GPU or CPU
69 |     """
70 |     if img_t.is_cuda:
71 |         return img_t.data.cpu().numpy()
72 |     else:
73 |         return img_t.data.numpy()
74 |
75 | def to01(x_np):
76 |     """
77 |     normalize a numpy array to [0, 1] for visualization
78 |     """
79 |     return (x_np - x_np.min()) / (x_np.max() - x_np.min() + 1e-5)
80 |
81 | def compose_wt_simple(is_wce, data_name):
82 |     """
83 |     Weights for cross-entropy loss
84 |     """
85 |     if is_wce:
86 |         if data_name in ['SABS', 'SABS_Superpix', 'C0', 'C0_Superpix', 'CHAOST2', 'CHAOST2_Superpix', 'CMR_Superpix', 'CMR']:
87 |             return torch.FloatTensor([0.05, 1.0]).cuda()
88 |         else:
89 |             raise NotImplementedError
90 |     else:
91 |         return torch.FloatTensor([1.0, 1.0]).cuda()
92 |
93 |
94 | class CircularList(list):
95 |     """
96 |     Helper for splitting training and validation scans
97 |     Originally: https://stackoverflow.com/questions/8951020/pythonic-circular-list/8951224
98 |     """
99 |     def __getitem__(self, x):
100 |         if isinstance(x, slice):
101 |             return [self[x] for x in self._rangeify(x)]
102 |
103 |         index = operator.index(x)
104 |         try:
105 |             return super().__getitem__(index % len(self))
106 |         except ZeroDivisionError:
107 |             raise IndexError('list index out of range')
108 |
109 |     def _rangeify(self, slice):
110 |         start, stop, step = slice.start, slice.stop, slice.step
111 |         if start is None:
112 |             start = 0
113 |         if stop is None:
114 |             stop = len(self)
115 |         if step is None:
116 |             step = 1
117 |         return range(start, stop, step)
118 |
119 |
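`CircularList` lets integer indices and slices wrap around the end of the scan list, which is what makes a sliding-window split of training and validation scans possible. A small usage sketch (toy scan IDs, not repository data):

```
from util.utils import CircularList

scans = CircularList(['scan_0', 'scan_1', 'scan_2', 'scan_3', 'scan_4'])

print(scans[6])    # scan_1 -- integer indices wrap modulo len
print(scans[3:7])  # ['scan_3', 'scan_4', 'scan_0', 'scan_1'] -- slices wrap too
```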
--------------------------------------------------------------------------------
/validation.py:
--------------------------------------------------------------------------------
1 | """
2 | Validation script
3 | """
4 | import os
5 | import shutil
6 | import torch
7 | import torch.nn as nn
8 | import torch.optim
9 | from torch.utils.data import DataLoader
10 | from torch.optim.lr_scheduler import MultiStepLR
11 | import torch.backends.cudnn as cudnn
12 | import numpy as np
13 | import torchvision
14 | import cv2
15 | from PIL import Image
16 |
17 | from models.grid_proto_fewshot import FewShotSeg
18 |
19 | from dataloaders.dev_customized_med import med_fewshot_val
20 | from dataloaders.ManualAnnoDatasetv2 import ManualAnnoDataset
21 | from dataloaders.GenericSuperDatasetv2 import SuperpixelDataset
22 | from dataloaders.dataset_utils import DATASET_INFO, get_normalize_op
23 | from dataloaders.niftiio import convert_to_sitk
24 |
25 | from util.metric import Metric
26 |
27 | from config_ssl_upload import ex
28 |
29 | import tqdm
30 | import SimpleITK as sitk
31 | from torchvision.utils import make_grid
32 |
33 | # configure the pre-trained model caching path
34 | os.environ['TORCH_HOME'] = "./pretrained_model"
35 |
36 | @ex.automain
37 | def main(_run, _config, _log):
38 |     if _run.observers:
39 |         os.makedirs(f'{_run.observers[0].dir}/interm_preds', exist_ok=True)
40 |         for source_file, _ in _run.experiment_info['sources']:
41 |             os.makedirs(os.path.dirname(f'{_run.observers[0].dir}/source/{source_file}'),
42 |                         exist_ok=True)
43 |             _run.observers[0].save_file(source_file, f'source/{source_file}')
44 |         shutil.rmtree(f'{_run.observers[0].basedir}/_sources')
45 |
46 |     cudnn.enabled = True
47 |     cudnn.benchmark = True
48 |     torch.cuda.set_device(device=_config['gpu_id'])
49 |     torch.set_num_threads(1)
50 |
51 |     _log.info(f'###### Reload model {_config["reload_model_path"]} ######')
52 |     model = FewShotSeg(pretrained_path=_config['reload_model_path'], cfg=_config['model'])
53 |     model = model.cuda()
54 |     model.eval()
55 |
56 |     _log.info('###### Load data ######')
57 |     ### Training set
58 |     data_name = _config['dataset']
59 |     if data_name == 'SABS_Superpix':
60 |         baseset_name = 'SABS'
61 |         max_label = 13
62 |     elif data_name == 'C0_Superpix':
63 |         raise NotImplementedError
64 |         baseset_name = 'C0'
65 |         max_label = 3
66 |     elif data_name == 'CHAOST2_Superpix':
67 |         baseset_name = 'CHAOST2'
68 |         max_label = 4
69 |     else:
70 |         raise ValueError(f'Dataset: {data_name} not found')
71 |
72 |     test_labels = DATASET_INFO[baseset_name]['LABEL_GROUP']['pa_all'] - DATASET_INFO[baseset_name]['LABEL_GROUP'][_config["label_sets"]]
73 |
74 |     ### Transforms for data augmentation
75 |     te_transforms = None
76 |
77 |     assert _config['scan_per_load'] < 0  # by default we load the entire dataset directly
78 |
79 |     _log.info(f'###### Labels excluded in training : {[lb for lb in _config["exclude_cls_list"]]} ######')
80 |     _log.info(f'###### Unseen labels evaluated in testing: {[lb for lb in test_labels]} ######')
81 |
82 |     if baseset_name == 'SABS':  # for CT we need the intensity statistics of the training split for normalization
83 |         tr_parent = SuperpixelDataset(  # base dataset
84 |             which_dataset = baseset_name,
85 |             base_dir=_config['path'][data_name]['data_dir'],
86 |             idx_split = _config['eval_fold'],
87 |             mode='train',
88 |             min_fg=str(_config["min_fg_data"]),  # dummy entry for superpixel dataset
89 |             transforms=None,
90 |             nsup = _config['task']['n_shots'],
91 |             scan_per_load = _config['scan_per_load'],
92 |             exclude_list = _config["exclude_cls_list"],
93 |             superpix_scale = _config["superpix_scale"],
94 |             fix_length = _config["max_iters_per_load"] if (data_name == 'C0_Superpix') or (data_name == 'CHAOST2_Superpix') else None
95 |         )
96 |         norm_func = tr_parent.norm_func
97 |     else:
98 |         norm_func = get_normalize_op(modality = 'MR', fids = None)
99 |
100 |
101 |     te_dataset, te_parent = med_fewshot_val(
102 |         dataset_name = baseset_name,
103 |         base_dir=_config['path'][baseset_name]['data_dir'],
104 |         idx_split = _config['eval_fold'],
105 |         scan_per_load = _config['scan_per_load'],
106 |         act_labels=test_labels,
107 |         npart = _config['task']['npart'],
108 |         nsup = _config['task']['n_shots'],
109 |         extern_normalize_func = norm_func
110 |     )
111 |
112 |     ### dataloaders
113 |     testloader = DataLoader(
114 |         te_dataset,
115 |         batch_size = 1,
116 |         shuffle=False,
117 |         num_workers=1,
118 |         pin_memory=False,
119 |         drop_last=False
120 |     )
121 |
122 |     _log.info('###### Set validation nodes ######')
123 |     mar_val_metric_node = Metric(max_label=max_label, n_scans=len(te_dataset.dataset.pid_curr_load) - _config['task']['n_shots'])
124 |
125 |     _log.info('###### Starting validation ######')
126 |     model.eval()
127 |     mar_val_metric_node.reset()
128 |
129 |     with torch.no_grad():
130 |         save_pred_buffer = {}  # indexed by class
131 |
132 |         for curr_lb in test_labels:
133 |             te_dataset.set_curr_cls(curr_lb)
134 |             support_batched = te_parent.get_support(curr_class = curr_lb, class_idx = [curr_lb], scan_idx = _config["support_idx"], npart=_config['task']['npart'])
135 |
136 |             # way (1 for now) x part x shot x 3 x H x W
137 |             support_images = [[shot.cuda() for shot in way]
138 |                               for way in support_batched['support_images']]  # way x part x [shot x C x H x W]
139 |             suffix = 'mask'
140 |             support_fg_mask = [[shot[f'fg_{suffix}'].float().cuda() for shot in way]
141 |                                for way in support_batched['support_mask']]
142 |             support_bg_mask = [[shot[f'bg_{suffix}'].float().cuda() for shot in way]
143 |                                for way in support_batched['support_mask']]
144 |
145 |             curr_scan_count = -1  # counter for the current scan
146 |             _lb_buffer = {}  # indexed by scan
147 |
148 |             last_qpart = 0  # used as an indicator for adding results to the buffer
149 |
150 |             for sample_batched in testloader:
151 |
152 |                 _scan_id = sample_batched["scan_id"][0]  # we assume the batch size for queries is 1
153 |                 if _scan_id in te_parent.potential_support_sid:  # skip the support scan; don't include it in the query set
154 |                     continue
155 |                 if sample_batched["is_start"]:
156 |                     ii = 0
157 |                     curr_scan_count += 1
158 |                     _scan_id = sample_batched["scan_id"][0]
159 |                     outsize = te_dataset.dataset.info_by_scan[_scan_id]["array_size"]
160 |                     outsize = (256, 256, outsize[0])  # original image read by itk: Z, H, W; in prediction we use H, W, Z
161 |                     _pred = np.zeros(outsize)
162 |                     _pred.fill(np.nan)
163 |
164 |                 q_part = sample_batched["part_assign"]  # the chunk of the query volume, used to pick the matching support part
165 |                 query_images = [sample_batched['image'].cuda()]
166 |                 query_labels = torch.cat([sample_batched['label'].cuda()], dim=0)
167 |                 idx = sample_batched['idx'][0]
168 |
169 |                 # [way, [part, [shot x C x H x W]]] ->
170 |                 sup_img_part = [[shot_tensor.unsqueeze(0) for shot_tensor in support_images[0][q_part]]]  # way(1) x shot x [B(1) x C x H x W]
171 |                 sup_fgm_part = [[shot_tensor.unsqueeze(0) for shot_tensor in support_fg_mask[0][q_part]]]
172 |                 sup_bgm_part = [[shot_tensor.unsqueeze(0) for shot_tensor in support_bg_mask[0][q_part]]]
173 |
174 |                 query_pred, _, _, assign_mats = model(sup_img_part, sup_fgm_part, sup_bgm_part, query_images, isval=True, val_wsize=_config["val_wsize"])
175 |                 query_pred_bg = np.array(query_pred.argmin(dim=1)[0].cpu())  # background scores (unused below)
176 |                 query_pred = np.array(query_pred.argmax(dim=1)[0].cpu())
177 |                 _pred[..., ii] = query_pred.copy()
178 |
179 |                 if (sample_batched["z_id"] - sample_batched["z_max"] <= _config['z_margin']) and (sample_batched["z_id"] - sample_batched["z_min"] >= -1 * _config['z_margin']):
180 |                     mar_val_metric_node.record(query_pred, np.array(query_labels[0].cpu()), labels=[curr_lb], n_scan=curr_scan_count)
181 |                 else:
182 |                     pass
183 |
184 |                 ii += 1
185 |                 # at the end of a scan, store the prediction volume in the expected orientation
186 |                 if sample_batched["is_end"]:
187 |                     if _config['dataset'] != 'C0':
188 |                         _lb_buffer[_scan_id] = _pred.transpose(2, 0, 1)  # H, W, Z -> Z, H, W
189 |                     else:
190 |                         _lb_buffer[_scan_id] = _pred
191 |
192 |             save_pred_buffer[str(curr_lb)] = _lb_buffer
193 |
194 |         ### save results
195 |         for curr_lb, _preds in save_pred_buffer.items():
196 |             for _scan_id, _pred in _preds.items():
197 |                 _pred *= float(curr_lb)
198 |                 itk_pred = convert_to_sitk(_pred, te_dataset.dataset.info_by_scan[_scan_id])
199 |                 fid = os.path.join(f'{_run.observers[0].dir}/interm_preds', f'scan_{_scan_id}_label_{curr_lb}.nii.gz')
200 |                 sitk.WriteImage(itk_pred, fid, True)
201 |                 _log.info(f'###### {fid} has been saved ######')
202 |
203 |         del save_pred_buffer
204 |
205 |     del sample_batched, support_images, support_bg_mask, query_images, query_labels, query_pred
206 |
207 |     # compute dice scores by scan
208 |     m_classDice,_, m_meanDice,_, m_rawDice = mar_val_metric_node.get_mDice(labels=sorted(test_labels), n_scan=None, give_raw = True)
209 |
210 |     m_classPrec,_, m_meanPrec,_, m_classRec,_, m_meanRec,_, m_rawPrec, m_rawRec = mar_val_metric_node.get_mPrecRecall(labels=sorted(test_labels), n_scan=None, give_raw = True)
211 |
212 |     mar_val_metric_node.reset()  # reset this calculation node
213 |
214 |     # write validation results to the log file
215 |     _run.log_scalar('mar_val_batches_classDice', m_classDice.tolist())
216 |     _run.log_scalar('mar_val_batches_meanDice', m_meanDice.tolist())
217 |     _run.log_scalar('mar_val_batches_rawDice', m_rawDice.tolist())
218 |
219 |     _run.log_scalar('mar_val_batches_classPrec', m_classPrec.tolist())
220 |     _run.log_scalar('mar_val_batches_meanPrec', m_meanPrec.tolist())
221 |     _run.log_scalar('mar_val_batches_rawPrec', m_rawPrec.tolist())
222 |
223 |     _run.log_scalar('mar_val_batches_classRec', m_classRec.tolist())
224 |     _run.log_scalar('mar_val_batches_meanRec', m_meanRec.tolist())
225 |     _run.log_scalar('mar_val_batches_rawRec', m_rawRec.tolist())
226 |
227 |     _log.info(f'mar_val batches classDice: {m_classDice}')
228 |     _log.info(f'mar_val batches meanDice: {m_meanDice}')
229 |
230 |     _log.info(f'mar_val batches classPrec: {m_classPrec}')
231 |     _log.info(f'mar_val batches meanPrec: {m_meanPrec}')
232 |
233 |     _log.info(f'mar_val batches classRec: {m_classRec}')
234 |     _log.info(f'mar_val batches meanRec: {m_meanRec}')
235 |
236 |     print("============ ============")
237 |
238 |     _log.info('End of validation')
239 |     return 1
240 |
241 |
242 |
--------------------------------------------------------------------------------
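The validation script writes one `scan_<id>_label_<class>.nii.gz` prediction volume per class into `interm_preds`, so the results can be inspected offline with SimpleITK. A small sketch (the run directory below is a placeholder; substitute your own sacred output folder):

```
import SimpleITK as sitk

# Hypothetical path; replace with your own experiment run directory.
fid = './exps/myrun/1/interm_preds/scan_1_label_2.nii.gz'

itk_pred = sitk.ReadImage(fid)
pred = sitk.GetArrayFromImage(itk_pred)  # Z x H x W numpy array

print('volume shape:', pred.shape)
print('foreground voxels:', int((pred > 0).sum()))
```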