├── utils ├── __init__.py ├── common_utils.py ├── shot_batch_sampler.py ├── preprocessor.py ├── config.py ├── convert_h5.py ├── data_utils.py ├── log_utils.py └── evaluator_slow.py ├── dataset ├── __init__.py ├── transform.py └── dataset_us.py ├── modules ├── __init__.py ├── cre.py ├── se_modules.py ├── voxelmorph_us.py └── conv_blocks.py ├── README.md ├── settings.ini ├── settings.yaml ├── cus_loss.py ├── settings.py ├── train.py ├── test.py ├── losses.py ├── RAP.py └── solver_us_pro.py /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /modules/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /utils/common_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | def create_if_not(path): 5 | if not os.path.exists(path): 6 | os.makedirs(path) 7 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Learning What and Where to Segment: A New Perspective on Medical Image Few-Shot Segmentation 2 | --- 3 | This repository provides the official PyTorch implementation of RAP (Registration and Prototype). 4 | 5 | 6 | # Requirements 7 | * torch == 1.6.0 8 | 9 | --- 10 | # Getting Started 11 | Run `test.py`; training is driven by `train.py` with the options in `settings.yaml`. 12 | 13 | --- 14 | # Citation 15 | ``` 16 | @article{FENG2023102834, 17 | title = {Learning what and where to segment: A new perspective on medical image few-shot segmentation}, 18 | journal = {Medical Image Analysis}, 19 | volume = {87}, 20 | pages = {102834}, 21 | year = {2023}, 22 | issn = {1361-8415}, 23 | doi = {https://doi.org/10.1016/j.media.2023.102834}, 24 | url = {https://www.sciencedirect.com/science/article/pii/S1361841523000944}, 25 | author = {Yong Feng and Yonghuai Wang and Honghe Li and Mingjun Qu and Jinzhu Yang}, 26 | keywords = {Few-shot segmentation, Organ segmentation, Echocardiography, Medical prior knowledge} 27 | } 28 | ``` 29 | -------------------------------------------------------------------------------- /modules/cre.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | def coords_grid(batch, ht, wd): 5 | coords = torch.meshgrid(torch.arange(ht), torch.arange(wd)) 6 | coords = torch.stack(coords[::-1], dim=0).float() # shape = 2,h,w 7 | return coords[None].repeat(batch, 1, 1, 1) 8 | 9 | 10 | def bilinear_sampler(img, coords, mode='bilinear', mask=False): 11 | """ Wrapper for grid_sample, uses pixel coordinates """ 12 | H, W = img.shape[-2:] 13 | xgrid, ygrid = coords.split([1,1], dim=-1) 14 | xgrid = 2*xgrid/(W-1) - 1 15 | ygrid = 2*ygrid/(H-1) - 1 16 | 17 | grid = torch.cat([xgrid, ygrid], dim=-1) 18 | img = F.grid_sample(img, grid, align_corners=True) 19 | 20 | if mask: 21 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) 22 | return img, mask.float() 23 | 24 | return img 25 | 26 | def Correlation(fmap1, fmap2, r=3): 27 | batch, dim, ht, wd = fmap1.shape 28 | fmap1 = fmap1.view(batch, dim, ht * wd) 29 | fmap2 = fmap2.view(batch, dim, 
ht * wd) 30 | 31 | corr = torch.matmul(fmap1.transpose(1, 2), fmap2) 32 | corr = corr.view(batch, ht, wd, 1, ht, wd) 33 | corr = corr / torch.sqrt(torch.tensor(dim).float()) 34 | corr = corr.view(-1, 1, ht, wd) 35 | # corr = F.adaptive_avg_pool2d(corr, (64, 64)) 36 | # corr = corr.view(batch, ht, wd, -1) 37 | # corr = corr.permute(0, 3, 1, 2).contiguous() 38 | 39 | coords = coords_grid(batch, ht, wd).to(fmap1.device) 40 | coords = coords.permute(0, 2, 3, 1) 41 | batch, h1, w1, _ = coords.shape 42 | dx = torch.linspace(-r, r, 2 * r + 1) 43 | dy = torch.linspace(-r, r, 2 * r + 1) 44 | delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device) 45 | 46 | centroid_lvl = coords.reshape(batch * h1 * w1, 1, 1, 2) 47 | delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2) 48 | coords_lvl = centroid_lvl + delta_lvl 49 | 50 | corr = bilinear_sampler(corr, coords_lvl) 51 | corr = corr.view(batch, h1, w1, -1) 52 | out = corr.permute(0, 3, 1, 2).contiguous().float() 53 | 54 | return out 55 | 56 | if __name__ == '__main__': 57 | fmap1 = torch.ones(1,1,3,3) 58 | fmap2 = torch.ones(1, 1, 3, 3) 59 | out = Correlation(fmap1, fmap2) 60 | print(out.shape) 61 | -------------------------------------------------------------------------------- /settings.ini: -------------------------------------------------------------------------------- 1 | [COMMON] 2 | save_model_dir = "saved_models" 3 | model_name = "laa" 4 | log_dir = "logs" 5 | device = "cuda:0" 6 | exp_dir = "experiments" 7 | 8 | [DATA] 9 | data_dir = "datasets/silver_corpus" 10 | train_data_file = "Data_train.h5" 11 | train_label_file = "Label_train.h5" 12 | train_class_weights_file = "Class_Weight_train.h5" 13 | train_weights_file = "Weight_train.h5" 14 | test_data_file = "Data_test.h5" 15 | test_label_file = "Label_test.h5" 16 | test_class_weights_file = "Class_Weight_test.h5" 17 | test_weights_file = "Weight_test.h5" 18 | labels = ["Background", "Left WM", "Left Cortex", "Left Lateral ventricle", "Left Inf LatVentricle", "Left Cerebellum WM", "Left Cerebellum Cortex", "Left Thalamus", "Left Caudate", "Left Putamen", "Left Pallidum", "3rd Ventricle", "4th Ventricle", "Brain Stem", "Left Hippocampus", "Left Amygdala", "CSF (Cranial)", "Left Accumbens", "Left Ventral DC", "Right WM", "Right Cortex", "Right Lateral Ventricle", "Right Inf LatVentricle", "Right Cerebellum WM", "Right Cerebellum Cortex", "Right Thalamus", "Right Caudate", "Right Putamen", "Right Pallidum", "Right Hippocampus", "Right Amygdala", "Right Accumbens", "Right Ventral DC"] 19 | 20 | [NETWORK] 21 | num_class = 2 22 | num_channels = 1 23 | num_filters = 64 24 | kernel_h = 5 25 | kernel_w = 5 26 | kernel_c = 1 27 | stride_conv = 1 28 | pool = 2 29 | stride_pool = 2 30 | se_block = "SSE" #Valid options : NONE, CSE, SSE, CSSE 31 | drop_out = 0 32 | 33 | [TRAINING] 34 | fold = 'fold2' 35 | exp_name = "laa" 36 | final_model_file = "laa.pth.tar" 37 | learning_rate = 1e-3 38 | momentum = 0.95 39 | optim_weight_decay = 0.00001 40 | train_batch_size = 2 41 | val_batch_size = 2 42 | log_nth = 500 43 | num_epochs = 400 44 | optim_betas = (0.9, 0.999) 45 | optim_eps = 1e-8 46 | lr_scheduler_step_size = 10 47 | lr_scheduler_gamma = 0.5 48 | iterations=1000 49 | test_iterations=1000 50 | pre_trained_path = "" 51 | 52 | #Uses the last checkpoint file from the exp_dir_name folder 53 | use_last_checkpoint = True 54 | 55 | [EVAL] 56 | eval_model_path = "few_shot_segmentation" 57 | data_dir = "/home/deeplearning/Abhijit/nas_drive/Abhijit/WholeBody/CT_ce/Data/Visceral" 58 | label_dir = 
"/home/deeplearning/Abhijit/nas_drive/Abhijit/WholeBody/CT_ce/Data/Visceral" 59 | volumes_txt_file = "datasets/MALC/test_volumes.txt" 60 | query_txt_file = "datasets/eval_query.txt" 61 | support_txt_file = "datasets/eval_support.txt" 62 | remap_config = "WholeBody" #Valid options : Neo, FS, WholeBody 63 | orientation = "AXI" #Valid options : COR, AXI, SAG 64 | save_predictions_dir = "copy_over" 65 | -------------------------------------------------------------------------------- /settings.yaml: -------------------------------------------------------------------------------- 1 | COMMON: 2 | save_model_dir : saved_models 3 | model_name : laa 4 | log_dir : logs 5 | device : cuda:0 6 | exp_dir : experiments 7 | 8 | DATA: 9 | data_dir : datasets/silver_corpus 10 | train_data_file : Data_train.h5 11 | train_label_file : Label_train.h5 12 | train_class_weights_file : Class_Weight_train.h5 13 | train_weights_file : Weight_train.h5 14 | test_data_file : Data_test.h5 15 | test_label_file : Label_test.h5 16 | test_class_weights_file : Class_Weight_test.h5 17 | test_weights_file : Weight_test.h5 18 | labels : [ Background , Left WM , Left Cortex , Left Lateral ventricle , 19 | Left Inf LatVentricle , Left Cerebellum WM , Left Cerebellum Cortex , 20 | Left Thalamus , Left Caudate , Left Putamen , Left Pallidum , 3rd Ventricle , 21 | 4th Ventricle , Brain Stem , Left Hippocampus , Left Amygdala , CSF (Cranial) , 22 | Left Accumbens , Left Ventral DC , Right WM , Right Cortex , Right Lateral Ventricle , 23 | Right Inf LatVentricle , Right Cerebellum WM , Right Cerebellum Cortex , Right Thalamus , 24 | Right Caudate , Right Putamen , Right Pallidum , Right Hippocampus , Right Amygdala , 25 | Right Accumbens , Right Ventral DC ] 26 | 27 | NETWORK: 28 | num_class : 2 29 | num_channels : 1 30 | num_filters : 64 31 | kernel_h : 5 32 | kernel_w : 5 33 | kernel_c : 1 34 | stride_conv : 1 35 | pool : 2 36 | stride_pool : 2 37 | se_block : SSE #Valid options : NONE, CSE, SSE, CSSE 38 | drop_out : 0 39 | 40 | TRAINING: 41 | fold : fold2 42 | exp_name : laa 43 | final_model_file : laa.pth.tar 44 | learning_rate : 0.001 45 | momentum : 0.95 46 | optim_weight_decay : 0.00001 47 | train_batch_size : 2 48 | val_batch_size : 2 49 | log_nth : 10 50 | num_epochs : 200 51 | optim_betas : (0.9, 0.999) 52 | optim_eps : 0.00000001 53 | lr_scheduler_step_size : 10 54 | lr_scheduler_gamma : 0.5 55 | iterations: 1000 56 | test_iterations: 1000 57 | pre_trained_path : 58 | 59 | #Uses the last checkpoint file from the exp_dir_name folder 60 | use_last_checkpoint : True 61 | 62 | EVAL: 63 | eval_model_path : few_shot_segmentation 64 | data_dir : /home/deeplearning/Abhijit/nas_drive/Abhijit/WholeBody/CT_ce/Data/Visceral 65 | label_dir : /home/deeplearning/Abhijit/nas_drive/Abhijit/WholeBody/CT_ce/Data/Visceral 66 | volumes_txt_file : datasets/MALC/test_volumes.txt 67 | query_txt_file : datasets/eval_query.txt 68 | support_txt_file : datasets/eval_support.txt 69 | remap_config : WholeBody #Valid options : Neo, FS, WholeBody 70 | orientation : AXI #Valid options : COR, AXI, SAG 71 | save_predictions_dir : copy_over 72 | -------------------------------------------------------------------------------- /utils/shot_batch_sampler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | lab_list_fold = {"fold1": {"train": [2, 6, 7, 8, 9], "val": [1]}, 4 | "fold2": {"train": [1, 6, 7, 8, 9], "val": [2]}, 5 | "fold3": {"train": [1, 2, 8, 9], "val": [6, 7]}, 6 | "fold4": {"train": [1, 
2, 6, 7], "val": [8, 9]}} 7 | 8 | 9 | def get_lab_list(phase, fold): 10 | return lab_list_fold[fold][phase] 11 | 12 | 13 | # 14 | def get_class_slices(labels, i): 15 | num_slices, H, W = labels.shape 16 | thresh = 0.005 17 | total_slices = labels == i 18 | pixel_sum = np.sum(total_slices, axis=(1, 2)).squeeze() 19 | pixel_sum = pixel_sum / (H * W) 20 | threshold_list = [idx for idx, slice in enumerate( 21 | pixel_sum) if slice > thresh] 22 | return threshold_list 23 | 24 | 25 | def get_index_dict(labels, lab_list): 26 | index_list = {i: get_class_slices(labels, i) for i in lab_list} 27 | p = [1 - (len(val) / len(labels)) for val in index_list.values()] 28 | p = p / np.sum(p) 29 | return index_list, p 30 | 31 | 32 | class OneShotBatchSampler: 33 | ''' 34 | 35 | ''' 36 | 37 | def _gen_query_label(self): 38 | """ 39 | Returns a query label uniformly from the label list of current phase. Also returns indexes of the slices which contain that label 40 | 41 | :return: random query label, index list of slices with generated class available 42 | """ 43 | query_label = np.random.choice(self.lab_list, 1, p=self.p)[0] 44 | return query_label 45 | 46 | def __init__(self, labels, phase, fold, batch_size, iteration=500): 47 | ''' 48 | 49 | ''' 50 | super(OneShotBatchSampler, self).__init__() 51 | 52 | self.index_list = None 53 | self.query_label = None 54 | self.batch_size = batch_size 55 | self.iteration = iteration 56 | self.labels = labels 57 | self.phase = phase 58 | self.lab_list = get_lab_list(phase, fold) 59 | self.index_dict, self.p = get_index_dict(labels, self.lab_list) 60 | 61 | def __iter__(self): 62 | ''' 63 | yield a batch of indexes 64 | ''' 65 | self.n = 0 66 | return self 67 | 68 | def __next__(self): 69 | """ 70 | Called on each iteration to return slices a random class label. On each iteration gets a random class label from label list and selects 2 x batch_size slices uniformly from index list 71 | :return: randomly select 2 x batch_size slices of a class label for the given iteration 72 | """ 73 | if self.n > self.iteration: 74 | raise StopIteration 75 | 76 | self.query_label = self._gen_query_label() 77 | self.index_list = self.index_dict[self.query_label] 78 | batch = np.random.choice(self.index_list, size=2 * self.batch_size) 79 | self.n += 1 80 | return batch 81 | 82 | def __len__(self): 83 | """ 84 | returns the number of iterations (episodes) per epoch 85 | :return: number os iterations 86 | """ 87 | 88 | return self.iteration 89 | -------------------------------------------------------------------------------- /cus_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from torch.nn import functional as F 5 | 6 | def make_one_hot(input, num_classes): 7 | """Convert class index tensor to one hot encoding tensor. 8 | Args: 9 | input: A tensor of shape [N, 1, *] 10 | num_classes: An int of number of class 11 | Returns: 12 | A tensor of shape [N, num_classes, *] 13 | """ 14 | # input = input.clone().unsqueeze(1).contiguous() # clone(). 
15 | # input[input==255]=0 16 | # print('unique:', torch.unique(input)) 17 | shape = np.array(input.shape) 18 | shape[1] = num_classes 19 | shape = tuple(shape) 20 | result = torch.zeros(shape).cuda() 21 | result = result.scatter_(1, input, 1) 22 | return result 23 | 24 | 25 | class onecls_DiceLoss(nn.Module): 26 | def __init__(self, smooth=1e-6): 27 | super().__init__() 28 | self.smooth = smooth 29 | 30 | def forward(self, predict, target): 31 | # print(predict.shape, target.shape) 32 | assert predict.shape == target.shape, "predict & target shape don't match" 33 | predict = predict.contiguous().view(predict.shape[0], -1) 34 | target = target.contiguous().view(target.shape[0], -1) 35 | # print('unique:', target.unique()) 36 | # print(predict.max(), target.max()) 37 | num = torch.sum(torch.mul(predict, target), dim=1) + self.smooth 38 | # den = torch.sum(torch.pow(predict, 2) + torch.pow(target, 2), dim=1) + self.smooth 39 | den = torch.sum(predict,dim=1) + torch.sum(target,dim=1) + self.smooth 40 | # print(num, den) 41 | loss = 1 - 2*num / den 42 | return torch.mean(loss) 43 | 44 | class onecls_TverskyLoss(nn.Module): 45 | def __init__(self, smooth=1e-6, alpha=0.6): 46 | super().__init__() 47 | self.smooth = smooth 48 | self.alpha = alpha 49 | self.beta = 1 - self.alpha 50 | 51 | def forward(self, predict, target): 52 | # print(predict.shape, target.shape) 53 | assert predict.shape == target.shape, "predict & target shape don't match" 54 | predict = predict.contiguous().view(predict.shape[0], -1) 55 | target = target.contiguous().view(target.shape[0], -1) 56 | 57 | tp = torch.sum(torch.mul(predict, target), dim=1) 58 | fp = torch.sum(torch.mul(predict, 1-target), dim=1) 59 | fn = torch.sum(torch.mul(1-predict, target), dim=1) 60 | tversky = (tp + self.smooth) / (tp + self.beta*fp + self.alpha*fn + self.smooth) 61 | loss = 1 - tversky 62 | return loss 63 | 64 | class multicls_DiceLoss(nn.Module): # For images that have ground truth 65 | ''' 66 | input: 67 | predict shape: batch_size * class_num * H * W 68 | target shape: batch_size * H * W 69 | ''' 70 | def __init__(self, n_class = 2): 71 | super().__init__() 72 | # self.loss = onecls_TverskyLoss(alpha=0.7) 73 | self.loss = onecls_DiceLoss() 74 | self.n_class = n_class 75 | 76 | def forward(self, predict, target): 77 | # target [B, 5, 1, H, W] 78 | target = make_one_hot(target, self.n_class) 79 | #print('channel:', chl) 80 | assert predict.shape == target.shape, "predict & target shape don't match" 81 | # print(target.shape) 82 | w = [0.5, 0.5] 83 | total_loss = 0 84 | # predict = predict.permute(0, 2, 1, 3, 4) 85 | # target = target.permute(0, 2, 1, 3, 4) 86 | for i in range(target.shape[1]): 87 | # print(target.shape) 88 | #print('sum_%d :'%i, torch.sum(target[:, i, 1:, :, :]), torch.sum(target[:, i, :, :, :])) 89 | # loss_i = torch.pow(w[i]*self.loss(predict[:, :, i], target[:,:, i]), 0.75) 90 | loss_i = w[i]*self.loss(predict[:, i], target[:, i]) 91 | # print('loss_%d_0 :'%i, loss_i, loss_i.shape) 92 | # loss_i += 0.3*(-torch.log(predict[:, :, i]+1e-6)*target[:,:, i]).mean((1,2,3)) 93 | # print('loss_%d_1 :' % i, loss_i, loss_i.shape) 94 | total_loss += loss_i 95 | #print(target.shape) 96 | return total_loss # / target.shape[1] 97 | 98 | -------------------------------------------------------------------------------- /utils/preprocessor.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | ORIENTATION = { 4 | 'coronal': "COR", 5 | 'axial': "AXI", 6 | 'sagital': "SAG" 7 | } 8 
| 9 | 10 | def rotate_orientation(volume_data, volume_label, orientation=ORIENTATION['coronal']): 11 | if orientation == ORIENTATION['coronal']: 12 | return volume_data.transpose((2, 0, 1)), volume_label.transpose((2, 0, 1)) 13 | elif orientation == ORIENTATION['axial']: 14 | return volume_data.transpose((1, 2, 0)), volume_label.transpose((1, 2, 0)) 15 | elif orientation == ORIENTATION['sagital']: 16 | return volume_data, volume_label 17 | else: 18 | raise ValueError("Invalid value for orientation. Pleas see help") 19 | 20 | 21 | def estimate_weights_mfb(labels): 22 | class_weights = np.zeros_like(labels) 23 | unique, counts = np.unique(labels, return_counts=True) 24 | median_freq = np.median(counts) 25 | weights = np.zeros(len(unique)) 26 | for i, label in enumerate(unique): 27 | class_weights += (median_freq // counts[i]) * np.array(labels == label) 28 | try: 29 | weights[int(label)] = median_freq // counts[i] 30 | except IndexError as e: 31 | print("Exception in processing") 32 | continue 33 | 34 | grads = np.gradient(labels) 35 | edge_weights = (grads[0] ** 2 + grads[1] ** 2) > 0 36 | class_weights += 2 * edge_weights 37 | return class_weights, weights 38 | 39 | 40 | def remap_labels(labels, remap_config): 41 | """ 42 | Function to remap the label values into the desired range of algorithm 43 | """ 44 | if remap_config == 'FS': 45 | label_list = [2, 3, 4, 5, 7, 8, 10, 11, 12, 13, 14, 15, 16, 17, 18, 24, 26, 28, 41, 42, 43, 44, 46, 47, 49, 50, 46 | 51, 52, 53, 54, 58, 60] 47 | elif remap_config == 'Neo': 48 | labels[(labels >= 100) & (labels % 2 == 0)] = 210 49 | labels[(labels >= 100) & (labels % 2 == 1)] = 211 50 | label_list = [45, 211, 52, 50, 41, 39, 60, 37, 58, 56, 4, 11, 35, 48, 32, 46, 30, 62, 44, 210, 51, 49, 40, 38, 51 | 59, 36, 57, 55, 47, 31, 23, 61] 52 | 53 | elif remap_config == 'WholeBody': 54 | label_list = [1, 2, 7, 8, 9, 13, 14, 17, 18] 55 | 56 | elif remap_config == 'brain_fewshot': 57 | labels[(labels >= 100) & (labels % 2 == 0)] = 210 58 | labels[(labels >= 100) & (labels % 2 == 1)] = 211 59 | label_list = [[210, 211], [45, 44], [52, 51], [35], [39, 41, 40, 38], [36, 37, 57, 58, 60, 59, 56, 55]] 60 | else: 61 | raise ValueError("Invalid argument value for remap config, only valid options are FS and Neo") 62 | 63 | new_labels = np.zeros_like(labels) 64 | 65 | k = isinstance(label_list[0], list) 66 | 67 | if not k: 68 | for i, label in enumerate(label_list): 69 | label_present = np.zeros_like(labels) 70 | label_present[labels == label] = 1 71 | new_labels = new_labels + (i + 1) * label_present 72 | else: 73 | for i, label in enumerate(label_list): 74 | label_present = np.zeros_like(labels) 75 | for j in label: 76 | label_present[labels == j] = 1 77 | new_labels = new_labels + (i + 1) * label_present 78 | return new_labels 79 | 80 | 81 | def reduce_slices(data, labels, skip_Frame=40): 82 | """ 83 | This function removes the useless black slices from the start and end. And then selects every even numbered frame. 
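Concretely, a binary mask keeps every even-indexed slice, while the first `skip_Frame` slices and the trailing `skip_Frame` slices (except the very last one) are always dropped.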
84 | """ 85 | no_slices, H, W = data.shape 86 | mask_vector = np.zeros(no_slices, dtype=int) 87 | mask_vector[::2], mask_vector[1::2] = 1, 0 88 | mask_vector[:skip_Frame], mask_vector[-skip_Frame:-1] = 0, 0 89 | 90 | data_reduced = np.compress(mask_vector, data, axis=0).reshape(-1, H, W) 91 | labels_reduced = np.compress(mask_vector, labels, axis=0).reshape(-1, H, W) 92 | 93 | return data_reduced, labels_reduced 94 | 95 | 96 | def remove_black(data, labels): 97 | clean_data, clean_labels = [], [] 98 | for i, frame in enumerate(labels): 99 | unique, counts = np.unique(frame, return_counts=True) 100 | if counts[0] / sum(counts) < .99: 101 | clean_labels.append(frame) 102 | clean_data.append(data[i]) 103 | return np.array(clean_data), np.array(clean_labels) 104 | -------------------------------------------------------------------------------- /settings.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import configparser # parse .ini 3 | from collections.abc import Mapping 4 | 5 | import yaml 6 | import os 7 | from ast import literal_eval 8 | 9 | 10 | class Settings(Mapping): 11 | def __init__(self, setting_file='settings.ini'): 12 | config = configparser.ConfigParser() 13 | config.read(setting_file) 14 | self.settings_dict = _parse_values(config) 15 | 16 | def __getitem__(self, key): 17 | return self.settings_dict[key] 18 | 19 | def __len__(self): 20 | return len(self.settings_dict) 21 | 22 | def __iter__(self): 23 | return self.settings_dict.items() 24 | 25 | 26 | def _parse_values(config): 27 | config_parsed = {} 28 | for section in config.sections(): 29 | config_parsed[section] = {} 30 | for key, value in config[section].items(): 31 | config_parsed[section][key] = ast.literal_eval(value) # safer than eval(): string to ori type 32 | return config_parsed 33 | 34 | 35 | # ----------------------------------------------------------------------------- 36 | # Functions for parsing args 37 | # ----------------------------------------------------------------------------- 38 | class CfgNode(dict): 39 | """ 40 | CfgNode represents an internal node in the configuration tree. It's a simple 41 | dict-like container that allows for attribute-based access to keys. 
42 | """ 43 | 44 | def __init__(self, init_dict=None, key_list=None, new_allowed=False): 45 | # Recursively convert nested dictionaries in init_dict into CfgNodes 46 | init_dict = {} if init_dict is None else init_dict 47 | key_list = [] if key_list is None else key_list 48 | for k, v in init_dict.items(): 49 | if type(v) is dict: 50 | # Convert dict to CfgNode 51 | init_dict[k] = CfgNode(v, key_list=key_list + [k]) 52 | super(CfgNode, self).__init__(init_dict) 53 | 54 | def __getattr__(self, name): 55 | if name in self: 56 | return self[name] 57 | else: 58 | raise AttributeError(name) 59 | 60 | def __setattr__(self, name, value): 61 | self[name] = value 62 | 63 | def __str__(self): 64 | def _indent(s_, num_spaces): 65 | s = s_.split("\n") 66 | if len(s) == 1: 67 | return s_ 68 | first = s.pop(0) 69 | s = [(num_spaces * " ") + line for line in s] 70 | s = "\n".join(s) 71 | s = first + "\n" + s 72 | return s 73 | 74 | r = "" 75 | s = [] 76 | for k, v in sorted(self.items()): 77 | seperator = "\n" if isinstance(v, CfgNode) else " " 78 | attr_str = "{}:{}{}".format(str(k), seperator, str(v)) 79 | attr_str = _indent(attr_str, 2) 80 | s.append(attr_str) 81 | r += "\n".join(s) 82 | return r 83 | 84 | def __repr__(self): 85 | return "{}({})".format(self.__class__.__name__, super(CfgNode, self).__repr__()) 86 | 87 | 88 | def load_cfg_from_cfg_file(file): 89 | cfg = {} 90 | assert os.path.isfile(file) and file.endswith('.yaml'), \ 91 | '{} is not a yaml file'.format(file) 92 | 93 | with open(file, 'r') as f: 94 | cfg_from_file = yaml.safe_load(f) 95 | 96 | # for key in cfg_from_file: 97 | # for k, v in cfg_from_file[key].items(): 98 | # cfg[k] = v 99 | 100 | cfg = CfgNode(cfg_from_file) 101 | return cfg 102 | 103 | def _decode_cfg_value(v): 104 | """Decodes a raw config value (e.g., from a yaml config files or command 105 | line argument) into a Python object. 106 | """ 107 | # All remaining processing is only applied to strings 108 | if not isinstance(v, str): 109 | return v 110 | # Try to interpret `v` as a: 111 | # string, number, tuple, list, dict, boolean, or None 112 | try: 113 | v = literal_eval(v) 114 | # The following two excepts allow v to pass through when it represents a 115 | # string. 116 | # 117 | # Longer explanation: 118 | # The type of v is always a string (before calling literal_eval), but 119 | # sometimes it *represents* a string and other times a data structure, like 120 | # a list. In the case that v represents a string, what we got back from the 121 | # yaml parser is 'foo' *without quotes* (so, not '"foo"'). literal_eval is 122 | # ok with '"foo"', but will raise a ValueError if given 'foo'. In other 123 | # cases, like paths (v = 'foo/bar' and not v = '"foo/bar"'), literal_eval 124 | # will raise a SyntaxError. 
125 | except ValueError: 126 | pass 127 | except SyntaxError: 128 | pass 129 | return v 130 | 131 | 132 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | import torch.nn as nn 5 | from torch.utils.data import DataLoader 6 | from torch.optim.lr_scheduler import MultiStepLR 7 | # from dataloaders_medical.prostate import * 8 | # import few_shot_segmentor as fs 9 | import RAP as fs 10 | # import v9_only_pro as fs 11 | # import v9_only_stn as fs 12 | # import v9_pro_stn as fs 13 | # import v9_decoder as fs 14 | # import lba_cbam_v9_noskip as fs 15 | from settings import Settings 16 | from solver_us_pro import Solver 17 | # from solver import Solver 18 | # from solver_se import Solver 19 | from dataset import transform, dataset_us 20 | from utils import config 21 | 22 | torch.set_default_tensor_type('torch.FloatTensor') 23 | 24 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 25 | 26 | def get_configs(): 27 | parser = argparse.ArgumentParser(description='PyTorch Few Shot Semantic Segmentation') 28 | parser.add_argument('--config', type=str, default='./settings.yaml', help='config file') 29 | parser.add_argument('--mode', '-m', default='train', 30 | help='run mode, valid values are train and eval') 31 | parser.add_argument('--device', '-d', default=0, 32 | help='device to run on') 33 | args = parser.parse_args() 34 | assert args.config is not None 35 | cfg = config.load_cfg_from_cfg_file(args.config) 36 | return cfg, args # cfg.xxx 37 | 38 | 39 | def train(train_params, common_params, data_params, net_params): 40 | 41 | train_transform = [ 42 | # transform.RandScale([0.8, 1.25]), 43 | # transform.RandRotate([-10, 10], padding=mean, ignore_label=0), 44 | transform.RandomGaussianBlur(), 45 | # transform.RandomHorizontalFlip(), 46 | # transform.Crop([256, 256], crop_type='rand', padding=mean, ignore_label=0), 47 | transform.Resize_fy([256, 256]), 48 | transform.ToTensor(), 49 | transform.Scaling()] 50 | train_transform = transform.Compose(train_transform) 51 | train_data = dataset_us.SemData_US(shot=1, transform=train_transform, mode='train') 52 | train_sampler = None 53 | trainloader = torch.utils.data.DataLoader(train_data, batch_size=2, shuffle=(train_sampler is None), 54 | num_workers=1, pin_memory=True, sampler=train_sampler, 55 | drop_last=False) 56 | 57 | validationloader = trainloader 58 | 59 | 60 | 61 | final_model_path = os.path.join( 62 | common_params['save_model_dir'], 'last.pth.tar') 63 | 64 | few_shot_model = fs.RAP(net_params) 65 | 66 | 67 | solver = Solver(few_shot_model, 68 | device=common_params['device'], 69 | num_class=net_params['num_class'], 70 | optim_args={"lr": train_params['learning_rate'], 71 | "weight_decay": train_params['optim_weight_decay'], 72 | "momentum": train_params['momentum']}, 73 | model_name=common_params['model_name'], 74 | exp_name=train_params['exp_name'], 75 | labels=data_params['labels'], 76 | log_nth=train_params['log_nth'], 77 | num_epochs=train_params['num_epochs'], 78 | lr_scheduler_step_size=train_params['lr_scheduler_step_size'], 79 | lr_scheduler_gamma=train_params['lr_scheduler_gamma'], 80 | use_last_checkpoint=train_params['use_last_checkpoint'], 81 | log_dir=common_params['log_dir'], 82 | exp_dir=common_params['exp_dir']) 83 | 84 | solver.train(trainloader, validationloader) 85 | # solver.save_best_model(final_model_path) 86 | print("final model saved @ " + 
str(final_model_path)) 87 | 88 | 89 | 90 | 91 | 92 | if __name__ == '__main__': 93 | 94 | # parser = argparse.ArgumentParser() 95 | # parser.add_argument('--mode', '-m', default='train', 96 | # help='run mode, valid values are train and eval') 97 | # parser.add_argument('--device', '-d', default=0, 98 | # help='device to run on') 99 | # args = parser.parse_args() 100 | # 101 | # settings = Settings() # parse .ini 102 | # common_params, data_params, net_params, train_params, eval_params = settings['COMMON'], settings['DATA'], settings[ 103 | # 'NETWORK'], settings['TRAINING'], settings['EVAL'] 104 | # 105 | # if args.device is not None: 106 | # common_params['device'] = args.device 107 | # 108 | # if args.mode == 'train': 109 | # train(train_params, common_params, data_params, net_params) 110 | # elif args.mode == 'eval': 111 | # pass 112 | # else: 113 | # raise ValueError( 114 | # 'Invalid value for mode. only support values are train and eval') 115 | 116 | cfgs, args = get_configs() 117 | print(cfgs.DATA) 118 | common_params, data_params, net_params, train_params, eval_params = cfgs.COMMON, cfgs.DATA, cfgs.NETWORK, \ 119 | cfgs.TRAINING, cfgs.EVAL 120 | 121 | if args.device is not None: 122 | common_params['device'] = args.device 123 | 124 | if args.mode == 'train': 125 | train(train_params, common_params, data_params, net_params) 126 | -------------------------------------------------------------------------------- /modules/se_modules.py: -------------------------------------------------------------------------------- 1 | """ 2 | Squeeze and Excitation Module 3 | ***************************** 4 | 5 | Collection of squeeze and excitation classes where each can be inserted as a block into a neural network architechture 6 | 7 | 1. `Channel Squeeze and Excitation `_ 8 | 2. `Spatial Squeeze and Excitation `_ 9 | 3. 
`Channel and Spatial Squeeze and Excitation `_ 10 | 11 | """ 12 | 13 | from enum import Enum 14 | 15 | import torch 16 | import torch.nn as nn 17 | import torch.nn.functional as F 18 | 19 | 20 | class ChannelSELayer(nn.Module): 21 | """ 22 | Re-implementation of Squeeze-and-Excitation (SE) block described in: 23 | *Hu et al., Squeeze-and-Excitation Networks, arXiv:1709.01507* 24 | 25 | """ 26 | 27 | def __init__(self, num_channels, reduction_ratio=2): 28 | """ 29 | 30 | :param num_channels: No of input channels 31 | :param reduction_ratio: By how much should the num_channels should be reduced 32 | """ 33 | super(ChannelSELayer, self).__init__() 34 | num_channels_reduced = num_channels // reduction_ratio 35 | self.reduction_ratio = reduction_ratio 36 | self.fc1 = nn.Linear(num_channels, num_channels_reduced, bias=True) 37 | self.fc2 = nn.Linear(num_channels_reduced, num_channels, bias=True) 38 | self.relu = nn.ReLU() 39 | self.sigmoid = nn.Sigmoid() 40 | 41 | def forward(self, input_tensor): 42 | """ 43 | 44 | :param input_tensor: X, shape = (batch_size, num_channels, H, W) 45 | :return: output tensor 46 | """ 47 | batch_size, num_channels, H, W = input_tensor.size() 48 | # Average along each channel 49 | squeeze_tensor = input_tensor.view(batch_size, num_channels, -1).mean(dim=2) 50 | 51 | # channel excitation 52 | fc_out_1 = self.relu(self.fc1(squeeze_tensor)) 53 | fc_out_2 = self.sigmoid(self.fc2(fc_out_1)) 54 | 55 | a, b = squeeze_tensor.size() 56 | output_tensor = torch.mul(input_tensor, fc_out_2.view(a, b, 1, 1)) 57 | return output_tensor 58 | 59 | 60 | class SpatialSELayer(nn.Module): 61 | """ 62 | Re-implementation of SE block -- squeezing spatially and exciting channel-wise described in: 63 | *Roy et al., Concurrent Spatial and Channel Squeeze & Excitation in Fully Convolutional Networks, MICCAI 2018* 64 | """ 65 | 66 | def __init__(self, num_channels): 67 | """ 68 | 69 | :param num_channels: No of input channels 70 | """ 71 | super(SpatialSELayer, self).__init__() 72 | self.conv = nn.Conv2d(num_channels, 1, 1) 73 | self.sigmoid = nn.Sigmoid() 74 | 75 | def forward(self, input_tensor, weights=None): 76 | """ 77 | 78 | :param weights: weights for few shot learning 79 | :param input_tensor: X, shape = (batch_size, num_channels, H, W) 80 | :return: output_tensor 81 | """ 82 | # spatial squeeze 83 | batch_size, channel, a, b = input_tensor.size() 84 | 85 | if weights is not None: 86 | weights = torch.mean(weights, dim=0) 87 | weights = weights.view(1, channel, 1, 1) 88 | out = F.conv2d(input_tensor, weights) 89 | else: 90 | out = self.conv(input_tensor) 91 | squeeze_tensor = self.sigmoid(out) 92 | 93 | # spatial excitation 94 | # print(input_tensor.size(), squeeze_tensor.size()) 95 | squeeze_tensor = squeeze_tensor.view(batch_size, 1, a, b) 96 | output_tensor = torch.mul(input_tensor, squeeze_tensor) 97 | #output_tensor = torch.mul(input_tensor, squeeze_tensor) 98 | return output_tensor 99 | 100 | 101 | class ChannelSpatialSELayer(nn.Module): 102 | """ 103 | Re-implementation of concurrent spatial and channel squeeze & excitation: 104 | *Roy et al., Concurrent Spatial and Channel Squeeze & Excitation in Fully Convolutional Networks, MICCAI 2018, arXiv:1803.02579* 105 | """ 106 | 107 | def __init__(self, num_channels, reduction_ratio=2): 108 | """ 109 | 110 | :param num_channels: No of input channels 111 | :param reduction_ratio: By how much should the num_channels should be reduced 112 | """ 113 | super(ChannelSpatialSELayer, self).__init__() 114 | self.cSE = 
ChannelSELayer(num_channels, reduction_ratio) 115 | self.sSE = SpatialSELayer(num_channels) 116 | 117 | def forward(self, input_tensor): 118 | """ 119 | 120 | :param input_tensor: X, shape = (batch_size, num_channels, H, W) 121 | :return: output_tensor 122 | """ 123 | output_tensor = torch.max(self.cSE(input_tensor), self.sSE(input_tensor)) 124 | return output_tensor 125 | 126 | 127 | class SELayer(Enum): 128 | """ 129 | Enum restricting the type of SE Blockes available. So that type checking can be adding when adding these blockes to 130 | a neural network:: 131 | 132 | if self.se_block_type == se.SELayer.CSE.value: 133 | self.SELayer = se.ChannelSpatialSELayer(params['num_filters']) 134 | 135 | elif self.se_block_type == se.SELayer.SSE.value: 136 | self.SELayer = se.SpatialSELayer(params['num_filters']) 137 | 138 | elif self.se_block_type == se.SELayer.CSSE.value: 139 | self.SELayer = se.ChannelSpatialSELayer(params['num_filters']) 140 | """ 141 | NONE = 'NONE' 142 | CSE = 'CSE' 143 | SSE = 'SSE' 144 | CSSE = 'CSSE' 145 | -------------------------------------------------------------------------------- /modules/voxelmorph_us.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.distributions.normal import Normal 5 | 6 | 7 | class U_Network(nn.Module): 8 | def __init__(self, dim, enc_nf, dec_nf, bn=None, full_size=True): 9 | super(U_Network, self).__init__() 10 | self.bn = bn 11 | self.dim = dim 12 | self.enc_nf = enc_nf 13 | self.full_size = full_size 14 | self.vm2 = len(dec_nf) == 7 15 | # Encoder functions 16 | self.enc = nn.ModuleList() 17 | for i in range(len(enc_nf)): 18 | # prev_nf = 2 if i == 0 else enc_nf[i - 1] 19 | prev_nf = 6 if i == 0 else enc_nf[i - 1] 20 | self.enc.append(self.conv_block(dim, prev_nf, enc_nf[i], 4, 2, batchnorm=bn)) 21 | # Decoder functions 22 | self.dec = nn.ModuleList() 23 | self.dec.append(self.conv_block(dim, enc_nf[-1], dec_nf[0], batchnorm=bn)) # 1 24 | self.dec.append(self.conv_block(dim, dec_nf[0] * 2, dec_nf[1], batchnorm=bn)) # 2 25 | self.dec.append(self.conv_block(dim, dec_nf[1] * 2, dec_nf[2], batchnorm=bn)) # 3 26 | self.dec.append(self.conv_block(dim, dec_nf[2] + enc_nf[0], dec_nf[3], batchnorm=bn)) # 4 27 | self.dec.append(self.conv_block(dim, dec_nf[3], dec_nf[4], batchnorm=bn)) # 5 28 | 29 | if self.full_size: 30 | self.dec.append(self.conv_block(dim, dec_nf[4] + 6, dec_nf[5], batchnorm=bn)) 31 | # self.dec.append(self.conv_block(dim, dec_nf[4] + 2, dec_nf[5], batchnorm=bn)) 32 | if self.vm2: 33 | self.vm2_conv = self.conv_block(dim, dec_nf[5], dec_nf[6], batchnorm=bn) 34 | self.upsample = nn.Upsample(scale_factor=2, mode='nearest') 35 | 36 | # One conv to get the flow field 37 | conv_fn = getattr(nn, 'Conv%dd' % dim) 38 | self.flow = conv_fn(dec_nf[-1], dim, kernel_size=3, padding=1) 39 | # Make flow weights + bias small. Not sure this is necessary. 
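# With near-zero weights and a zero bias, the predicted flow starts out close to an all-zero displacement field, i.e. the spatial transformer initially applies (approximately) an identity warp.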
40 | nd = Normal(0, 1e-5) 41 | self.flow.weight = nn.Parameter(nd.sample(self.flow.weight.shape)) 42 | self.flow.bias = nn.Parameter(torch.zeros(self.flow.bias.shape)) 43 | self.batch_norm = getattr(nn, "BatchNorm{0}d".format(dim))(3) 44 | 45 | def conv_block(self, dim, in_channels, out_channels, kernel_size=3, stride=1, padding=1, batchnorm=False): 46 | conv_fn = getattr(nn, "Conv{0}d".format(dim)) 47 | bn_fn = getattr(nn, "BatchNorm{0}d".format(dim)) 48 | if batchnorm: 49 | layer = nn.Sequential( 50 | conv_fn(in_channels, out_channels, kernel_size, stride=stride, padding=padding), 51 | bn_fn(out_channels), 52 | nn.LeakyReLU(0.2)) 53 | else: 54 | layer = nn.Sequential( 55 | conv_fn(in_channels, out_channels, kernel_size, stride=stride, padding=padding), 56 | nn.LeakyReLU(0.2)) 57 | return layer 58 | 59 | def forward(self, src, tgt): 60 | x = torch.cat([src, tgt], dim=1) 61 | # Get encoder activations 62 | x_enc = [x] 63 | for i, l in enumerate(self.enc): 64 | x = l(x_enc[-1]) 65 | x_enc.append(x) 66 | # Three conv + upsample + concatenate series 67 | y = x_enc[-1] 68 | for i in range(3): 69 | y = self.dec[i](y) 70 | y = self.upsample(y) 71 | y = torch.cat([y, x_enc[-(i + 2)]], dim=1) 72 | # Two convs at full_size/2 res 73 | y = self.dec[3](y) 74 | y = self.dec[4](y) 75 | # Upsample to full res, concatenate and conv 76 | if self.full_size: 77 | y = self.upsample(y) 78 | y = torch.cat([y, x_enc[0]], dim=1) 79 | y = self.dec[5](y) 80 | # Extra conv for vm2 81 | if self.vm2: 82 | y = self.vm2_conv(y) 83 | flow = self.flow(y) 84 | if self.bn: 85 | flow = self.batch_norm(flow) 86 | return flow 87 | 88 | 89 | class SpatialTransformer(nn.Module): 90 | def __init__(self, size, mode='bilinear'): 91 | super(SpatialTransformer, self).__init__() 92 | # Create sampling grid 93 | vectors = [torch.arange(0, s) for s in size] 94 | grids = torch.meshgrid(vectors) 95 | grid = torch.stack(grids) # y, x, z 96 | grid = torch.unsqueeze(grid, 0) # add batch 97 | grid = grid.type(torch.FloatTensor) 98 | self.register_buffer('grid', grid) 99 | 100 | self.mode = mode 101 | 102 | def forward(self, src, flow): 103 | new_locs = self.grid + flow 104 | shape = flow.shape[2:] 105 | 106 | # Need to normalize grid values to [-1, 1] for resampler 107 | for i in range(len(shape)): 108 | new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] 
/ (shape[i] - 1) - 0.5) 109 | 110 | if len(shape) == 2: 111 | new_locs = new_locs.permute(0, 2, 3, 1) 112 | new_locs = new_locs[..., [1, 0]] 113 | elif len(shape) == 3: 114 | new_locs = new_locs.permute(0, 2, 3, 4, 1) 115 | new_locs = new_locs[..., [2, 1, 0]] 116 | 117 | return F.grid_sample(src, new_locs, mode=self.mode) 118 | 119 | if __name__ == '__main__': 120 | unet = U_Network(2, [16, 32, 32, 32], [32, 32, 32, 32, 8, 8]) 121 | src = torch.ones((1, 3, 256,256)) 122 | trg = torch.ones((1, 3, 256, 256)) 123 | mask = torch.ones((1, 1, 256, 256)) 124 | flow = unet(src, trg) 125 | print(flow.shape) 126 | 127 | stn = SpatialTransformer((256, 256)) 128 | print(stn(mask, flow).shape) 129 | 130 | -------------------------------------------------------------------------------- /utils/config.py: -------------------------------------------------------------------------------- 1 | # ----------------------------------------------------------------------------- 2 | # Functions for parsing args 3 | # ----------------------------------------------------------------------------- 4 | import yaml 5 | import os 6 | from ast import literal_eval 7 | import copy 8 | 9 | 10 | class CfgNode(dict): 11 | """ 12 | CfgNode represents an internal node in the configuration tree. It's a simple 13 | dict-like container that allows for attribute-based access to keys. 14 | """ 15 | 16 | def __init__(self, init_dict=None, key_list=None, new_allowed=False): 17 | # Recursively convert nested dictionaries in init_dict into CfgNodes 18 | init_dict = {} if init_dict is None else init_dict 19 | key_list = [] if key_list is None else key_list 20 | for k, v in init_dict.items(): 21 | if type(v) is dict: 22 | # Convert dict to CfgNode 23 | init_dict[k] = CfgNode(v, key_list=key_list + [k]) 24 | super(CfgNode, self).__init__(init_dict) 25 | 26 | def __getattr__(self, name): 27 | if name in self: 28 | return self[name] 29 | else: 30 | raise AttributeError(name) 31 | 32 | def __setattr__(self, name, value): 33 | self[name] = value 34 | 35 | def __str__(self): 36 | def _indent(s_, num_spaces): 37 | s = s_.split("\n") 38 | if len(s) == 1: 39 | return s_ 40 | first = s.pop(0) 41 | s = [(num_spaces * " ") + line for line in s] 42 | s = "\n".join(s) 43 | s = first + "\n" + s 44 | return s 45 | 46 | r = "" 47 | s = [] 48 | for k, v in sorted(self.items()): 49 | seperator = "\n" if isinstance(v, CfgNode) else " " 50 | attr_str = "{}:{}{}".format(str(k), seperator, str(v)) 51 | attr_str = _indent(attr_str, 2) 52 | s.append(attr_str) 53 | r += "\n".join(s) 54 | return r 55 | 56 | def __repr__(self): 57 | return "{}({})".format(self.__class__.__name__, super(CfgNode, self).__repr__()) 58 | 59 | 60 | def load_cfg_from_cfg_file(file): 61 | cfg = {} 62 | assert os.path.isfile(file) and file.endswith('.yaml'), \ 63 | '{} is not a yaml file'.format(file) 64 | 65 | with open(file, 'r') as f: 66 | cfg_from_file = yaml.safe_load(f) 67 | 68 | # for key in cfg_from_file: 69 | # for k, v in cfg_from_file[key].items(): 70 | # cfg[k] = v 71 | 72 | cfg = CfgNode(cfg_from_file) 73 | return cfg 74 | 75 | # 76 | # def merge_cfg_from_list(cfg, cfg_list): 77 | # new_cfg = copy.deepcopy(cfg) 78 | # assert len(cfg_list) % 2 == 0 79 | # for full_key, v in zip(cfg_list[0::2], cfg_list[1::2]): 80 | # subkey = full_key.split('.')[-1] 81 | # assert subkey in cfg, 'Non-existent key: {}'.format(full_key) 82 | # value = _decode_cfg_value(v) 83 | # value = _check_and_coerce_cfg_value_type( 84 | # value, cfg[subkey], subkey, full_key 85 | # ) 86 | # setattr(new_cfg, 
subkey, value) 87 | # 88 | # return new_cfg 89 | 90 | 91 | def _decode_cfg_value(v): 92 | """Decodes a raw config value (e.g., from a yaml config files or command 93 | line argument) into a Python object. 94 | """ 95 | # All remaining processing is only applied to strings 96 | if not isinstance(v, str): 97 | return v 98 | # Try to interpret `v` as a: 99 | # string, number, tuple, list, dict, boolean, or None 100 | try: 101 | v = literal_eval(v) 102 | # The following two excepts allow v to pass through when it represents a 103 | # string. 104 | # 105 | # Longer explanation: 106 | # The type of v is always a string (before calling literal_eval), but 107 | # sometimes it *represents* a string and other times a data structure, like 108 | # a list. In the case that v represents a string, what we got back from the 109 | # yaml parser is 'foo' *without quotes* (so, not '"foo"'). literal_eval is 110 | # ok with '"foo"', but will raise a ValueError if given 'foo'. In other 111 | # cases, like paths (v = 'foo/bar' and not v = '"foo/bar"'), literal_eval 112 | # will raise a SyntaxError. 113 | except ValueError: 114 | pass 115 | except SyntaxError: 116 | pass 117 | return v 118 | 119 | 120 | # def _check_and_coerce_cfg_value_type(replacement, original, key, full_key): 121 | # """Checks that `replacement`, which is intended to replace `original` is of 122 | # the right type. The type is correct if it matches exactly or is one of a few 123 | # cases in which the type can be easily coerced. 124 | # """ 125 | # original_type = type(original) 126 | # replacement_type = type(replacement) 127 | # 128 | # # The types must match (with some exceptions) 129 | # if replacement_type == original_type: 130 | # return replacement 131 | # 132 | # # Cast replacement from from_type to to_type if the replacement and original 133 | # # types match from_type and to_type 134 | # def conditional_cast(from_type, to_type): 135 | # if replacement_type == from_type and original_type == to_type: 136 | # return True, to_type(replacement) 137 | # else: 138 | # return False, None 139 | # 140 | # # Conditionally casts 141 | # # list <-> tuple 142 | # casts = [(tuple, list), (list, tuple)] 143 | # # For py2: allow converting from str (bytes) to a unicode string 144 | # try: 145 | # casts.append((str, unicode)) # noqa: F821 146 | # except Exception: 147 | # pass 148 | # 149 | # for (from_type, to_type) in casts: 150 | # converted, converted_value = conditional_cast(from_type, to_type) 151 | # if converted: 152 | # return converted_value 153 | # 154 | # raise ValueError( 155 | # "Type mismatch ({} vs. {}) with values ({} vs. {}) for config " 156 | # "key: {}".format( 157 | # original_type, replacement_type, original, replacement, full_key 158 | # ) 159 | # ) 160 | # 161 | # 162 | # def _assert_with_logging(cond, msg): 163 | # if not cond: 164 | # logger.debug(msg) 165 | # assert cond, msg 166 | # 167 | -------------------------------------------------------------------------------- /utils/convert_h5.py: -------------------------------------------------------------------------------- 1 | """ 2 | Convert to h5 utility. 
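Reads the volumes of each split, preprocesses them (label remapping, slice reduction, black-slice removal, weight estimation) and writes Data/Label/Weight/Class_Weight .h5 files for the train and test sets into the destination folder.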
3 | Sample command to create new dataset - python utils/convert_h5.py -dd /home/masterthesis/shayan/nas_drive/Data_Neuro/OASISchallenge/FS -ld /home/masterthesis/shayan/nas_drive/Data_Neuro/OASISchallenge -trv datasets/train_volumes.txt -tev datasets/test_volumes.txt -rc Neo -o COR -df datasets/MALC/coronal 4 | 5 | - python utils/convert_h5.py -dd /home/masterthesis/shayan/nas_drive/Data_Neuro/IXI/IXI_FS -ld /home/masterthesis/shayan/nas_drive/Data_Neuro/IXI/IXI_FS -ds 98,2 -rc FS -o COR -df datasets/IXI/coronal 6 | 7 | - python3.6 utils/convert_h5.py -dd /home/deeplearning/Abhijit/nas_drive/Abhijit/WholeBody/CT_ce/Data/SilverCorpus -ld /home/deeplearning/Abhijit/nas_drive/Abhijit/WholeBody/CT_ce/Data/SilverCorpus -trv datasets/test_volumes_silver.txt -tev datasets/test_volumes_silver.txt -rc WholeBody -o AXI -df datasets/silver_corpus 8 | """ 9 | 10 | import argparse 11 | import os 12 | 13 | import h5py 14 | import numpy as np 15 | 16 | import common_utils 17 | import data_utils as du 18 | import preprocessor 19 | 20 | 21 | def apply_split(data_split, data_dir, label_dir): 22 | file_paths = du.load_file_paths(data_dir, label_dir) 23 | print("Total no of volumes to process : %d" % len(file_paths)) 24 | train_ratio, test_ratio = data_split.split(",") 25 | train_len = int((int(train_ratio) / 100) * len(file_paths)) 26 | train_idx = np.random.choice(len(file_paths), train_len, replace=False) 27 | test_idx = np.array([i for i in range(len(file_paths)) if i not in train_idx]) 28 | train_file_paths = [file_paths[i] for i in train_idx] 29 | test_file_paths = [file_paths[i] for i in test_idx] 30 | return train_file_paths, test_file_paths 31 | 32 | 33 | def _write_h5(data, label, class_weights, weights, f, mode): 34 | no_slices, H, W = data[0].shape 35 | with h5py.File(f[mode]['data'], "w") as data_handle: 36 | data_handle.create_dataset("data", data=np.concatenate(data).reshape((-1, H, W))) 37 | with h5py.File(f[mode]['label'], "w") as label_handle: 38 | label_handle.create_dataset("label", data=np.concatenate(label).reshape((-1, H, W))) 39 | with h5py.File(f[mode]['weights'], "w") as weights_handle: 40 | weights_handle.create_dataset("weights", data=np.concatenate(weights)) 41 | with h5py.File(f[mode]['class_weights'], "w") as class_weights_handle: 42 | class_weights_handle.create_dataset("class_weights", data=np.concatenate( 43 | class_weights).reshape((-1, H, W))) 44 | 45 | 46 | def convert_h5(data_dir, label_dir, data_split, train_volumes, test_volumes, f, remap_config='Neo', 47 | orientation=preprocessor.ORIENTATION['coronal']): 48 | # Data splitting 49 | if data_split: 50 | train_file_paths, test_file_paths = apply_split(data_split, data_dir, label_dir) 51 | elif train_volumes and test_volumes: 52 | train_file_paths = du.load_file_paths_brain(data_dir, label_dir, train_volumes) 53 | test_file_paths = du.load_file_paths_brain(data_dir, label_dir, test_volumes) 54 | else: 55 | raise ValueError('You must either provide the split ratio or a train, train dataset list') 56 | 57 | print("Train dataset size: %d, Test dataset size: %d" % (len(train_file_paths), len(test_file_paths))) 58 | # loading,pre-processing and writing train data 59 | print("===Train data===") 60 | data_train, label_train, class_weights_train, weights_train = du.load_dataset(train_file_paths, 61 | orientation, 62 | remap_config=remap_config, 63 | return_weights=True, 64 | reduce_slices=True, 65 | remove_black=True) 66 | 67 | _write_h5(data_train, label_train, class_weights_train, weights_train, f, mode='train') 68 | 69 | # 
loading,pre-processing and writing test data 70 | print("===Test data===") 71 | data_test, label_test, class_weights_test, weights_test = du.load_dataset(test_file_paths, 72 | orientation, 73 | remap_config=remap_config, 74 | return_weights=True, 75 | reduce_slices=True, 76 | remove_black=True) 77 | 78 | _write_h5(data_test, label_test, class_weights_test, weights_test, f, mode='test') 79 | 80 | 81 | if __name__ == "__main__": 82 | print("* Start *") 83 | parser = argparse.ArgumentParser() 84 | parser.add_argument('--data_dir', '-dd', required=True, 85 | help='Base directory of the data folder. This folder should contain one folder per volume.') 86 | parser.add_argument('--label_dir', '-ld', required=True, 87 | help='Base directory of all the label files. This folder should have one file per volumn with same name as the corresponding volumn folder name inside data_dir') 88 | parser.add_argument('--data_split', '-ds', required=False, 89 | help='Ratio to split data randomly into train and test. input e.g. 80,20') 90 | parser.add_argument('--train_volumes', '-trv', required=False, 91 | help='Path to a text file containing the list of volumes to be used for training') 92 | parser.add_argument('--test_volumes', '-tev', required=False, 93 | help='Path to a text file containing the list of volumes to be used for testing') 94 | parser.add_argument('--remap_config', '-rc', required=True, help='Valid options are "FS" and "Neo"') 95 | parser.add_argument('--orientation', '-o', required=True, help='Valid options are COR, AXI, SAG') 96 | parser.add_argument('--destination_folder', '-df', required=True, help='Path where to generate the h5 files') 97 | 98 | args = parser.parse_args() 99 | 100 | common_utils.create_if_not(args.destination_folder) 101 | 102 | f = { 103 | 'train': { 104 | "data": os.path.join(args.destination_folder, "Data_train.h5"), 105 | "label": os.path.join(args.destination_folder, "Label_train.h5"), 106 | "weights": os.path.join(args.destination_folder, "Weight_train.h5"), 107 | "class_weights": os.path.join(args.destination_folder, "Class_Weight_train.h5"), 108 | }, 109 | 'test': { 110 | "data": os.path.join(args.destination_folder, "Data_test.h5"), 111 | "label": os.path.join(args.destination_folder, "Label_test.h5"), 112 | "weights": os.path.join(args.destination_folder, "Weight_test.h5"), 113 | "class_weights": os.path.join(args.destination_folder, "Class_Weight_test.h5") 114 | } 115 | } 116 | 117 | convert_h5(args.data_dir, args.label_dir, args.data_split, args.train_volumes, args.test_volumes, f, 118 | args.remap_config, 119 | args.orientation) 120 | print("* Finish *") 121 | -------------------------------------------------------------------------------- /utils/data_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import h5py 4 | import numpy as np 5 | import torch 6 | import torch.utils.data as data 7 | import scipy.io as sio 8 | import utils.preprocessor as preprocessor 9 | import nibabel as nb 10 | import math 11 | from torchvision import transforms 12 | 13 | 14 | # import utils.preprocessor as preprocessor 15 | 16 | 17 | # transform_train = transforms.Compose([ 18 | # transforms.RandomCrop((480, 220), padding=(32, 36)), 19 | # transforms.ToTensor(), 20 | # ]) 21 | 22 | 23 | class ImdbData(data.Dataset): 24 | def __init__(self, X, y, w=None, transforms=None): 25 | # TODO:Improve later 26 | # lung_mask_1 = (y == 4) 27 | # lung_mask_2 = (y == 5) 28 | # lung_mask = 0.5 * (lung_mask_1 + lung_mask_2) 29 | # X = X + 
lung_mask 30 | 31 | self.X = X if len(X.shape) == 4 else X[:, np.newaxis, :, :] 32 | self.y = y 33 | self.w = w 34 | self.transforms = transforms 35 | 36 | def __getitem__(self, index): 37 | img = torch.from_numpy(self.X[index]) 38 | label = torch.from_numpy(self.y[index]) 39 | if self.w is not None: 40 | weight = torch.from_numpy(self.w[index]) 41 | return img, label, weight 42 | else: 43 | return img, label 44 | 45 | def __len__(self): 46 | return len(self.y) 47 | 48 | 49 | def get_imdb_dataset(data_params): 50 | data_train = h5py.File(os.path.join(data_params['data_dir'], data_params['train_data_file']), 'r') 51 | label_train = h5py.File(os.path.join(data_params['data_dir'], data_params['train_label_file']), 'r') 52 | class_weight_train = h5py.File(os.path.join(data_params['data_dir'], data_params['train_class_weights_file']), 'r') 53 | weight_train = h5py.File(os.path.join(data_params['data_dir'], data_params['train_weights_file']), 'r') 54 | 55 | data_test = h5py.File(os.path.join(data_params['data_dir'], data_params['test_data_file']), 'r') 56 | label_test = h5py.File(os.path.join(data_params['data_dir'], data_params['test_label_file']), 'r') 57 | class_weight_test = h5py.File(os.path.join(data_params['data_dir'], data_params['test_class_weights_file']), 'r') 58 | weight_test = h5py.File(os.path.join(data_params['data_dir'], data_params['test_weights_file']), 'r') 59 | 60 | return (ImdbData(data_train['data'][()], label_train['label'][()], class_weight_train['class_weights'][()]), 61 | ImdbData(data_test['data'][()], label_test['label'][()], class_weight_test['class_weights'][()])) 62 | 63 | 64 | def load_dataset(file_paths, 65 | orientation, 66 | remap_config, 67 | return_weights=False, 68 | reduce_slices=False, 69 | remove_black=False): 70 | print("Loading and preprocessing data...") 71 | volume_list, labelmap_list, headers, class_weights_list, weights_list = [], [], [], [], [] 72 | 73 | for file_path in file_paths: 74 | volume, labelmap, class_weights, weights = load_and_preprocess(file_path, orientation, 75 | remap_config=remap_config, 76 | reduce_slices=reduce_slices, 77 | remove_black=remove_black, 78 | return_weights=return_weights) 79 | 80 | volume_list.append(volume) 81 | labelmap_list.append(labelmap) 82 | 83 | if return_weights: 84 | class_weights_list.append(class_weights) 85 | weights_list.append(weights) 86 | 87 | print("#", end='', flush=True) 88 | print("100%", flush=True) 89 | if return_weights: 90 | return volume_list, labelmap_list, class_weights_list, weights_list 91 | else: 92 | return volume_list, labelmap_list 93 | 94 | 95 | def load_and_preprocess(file_path, orientation, remap_config, reduce_slices=False, 96 | remove_black=False, 97 | return_weights=False): 98 | print(file_path) 99 | volume, labelmap = load_data_mat(file_path, orientation) 100 | 101 | volume, labelmap, class_weights, weights = preprocess(volume, labelmap, remap_config=remap_config, 102 | reduce_slices=reduce_slices, 103 | remove_black=remove_black, 104 | return_weights=return_weights) 105 | return volume, labelmap, class_weights, weights 106 | 107 | 108 | def load_data(file_path, orientation): 109 | print(file_path[0], file_path[1]) 110 | volume_nifty, labelmap_nifty = nb.load(file_path[0]), nb.load(file_path[1]) 111 | volume, labelmap = volume_nifty.get_fdata(), labelmap_nifty.get_fdata() 112 | volume = (volume - np.min(volume)) / (np.max(volume) - np.min(volume)) 113 | volume, labelmap = preprocessor.rotate_orientation(volume, labelmap, orientation) 114 | return volume, labelmap, 
volume_nifty.header 115 | 116 | 117 | def load_data_mat(file_path, orientation): 118 | data = sio.loadmat(file_path) 119 | volume = data['DatVol'] 120 | labelmap = data['LabVol'] 121 | volume = (volume - np.min(volume)) / (np.max(volume) - np.min(volume)) 122 | volume, labelmap = preprocessor.rotate_orientation(volume, labelmap, orientation) 123 | return volume, labelmap 124 | 125 | 126 | def preprocess(volume, labelmap, remap_config, reduce_slices=False, remove_black=False, return_weights=False): 127 | if reduce_slices: 128 | volume, labelmap = preprocessor.reduce_slices(volume, labelmap) 129 | 130 | if remap_config: 131 | labelmap = preprocessor.remap_labels(labelmap, remap_config) 132 | if remove_black: 133 | volume, labelmap = preprocessor.remove_black(volume, labelmap) 134 | 135 | if return_weights: 136 | class_weights, weights = preprocessor.estimate_weights_mfb(labelmap) 137 | return volume, labelmap, class_weights, weights 138 | else: 139 | return volume, labelmap, None, None 140 | 141 | 142 | def load_file_paths_brain(data_dir, label_dir, volumes_txt_file=None): 143 | """ 144 | This function returns the file paths combined as a list where each element is a 2 element tuple, 0th being data and 1st being label. 145 | It should be modified to suit the need of the project 146 | :param data_dir: Directory which contains the data files 147 | :param label_dir: Directory which contains the label files 148 | :param volumes_txt_file: (Optional) Path to the a csv file, when provided only these data points will be read 149 | :return: list of file paths as string 150 | """ 151 | 152 | volume_exclude_list = ['IXI290', 'IXI423'] 153 | if volumes_txt_file: 154 | with open(volumes_txt_file) as file_handle: 155 | volumes_to_use = file_handle.read().splitlines() 156 | else: 157 | volumes_to_use = [name for name in os.listdir(data_dir) if name not in volume_exclude_list] 158 | 159 | file_paths = [ 160 | [os.path.join(data_dir, vol, 'mri/orig.mgz'), os.path.join(label_dir, vol+'_glm.mgz')] 161 | for 162 | vol in volumes_to_use] 163 | return file_paths 164 | 165 | 166 | def load_file_paths(data_dir, label_dir, volumes_txt_file=None): 167 | """ 168 | This function returns the file paths combined as a list where each element is a 2 element tuple, 0th being data and 1st being label. 169 | It should be modified to suit the need of the project 170 | :param data_dir: Directory which contains the data files 171 | :param label_dir: Directory which contains the label files 172 | :param volumes_txt_file: (Optional) Path to the a csv file, when provided only these data points will be read 173 | :return: list of file paths as string 174 | """ 175 | 176 | with open(volumes_txt_file) as file_handle: 177 | volumes_to_use = file_handle.read().splitlines() 178 | file_paths = [os.path.join(data_dir, vol) for vol in volumes_to_use] 179 | 180 | return file_paths 181 | 182 | 183 | def split_batch(X, y, query_label): 184 | batch_size = len(X) // 2 185 | input1 = X[0:batch_size, :, :, :] 186 | input2 = X[batch_size:, :, :, :] 187 | y1 = (y[0:batch_size, :, :] == query_label).type(torch.FloatTensor) 188 | y2 = (y[batch_size:, :, :] == query_label).type(torch.LongTensor) 189 | # y2 = (y[batch_size:, :, :] == query_label).type(torch.FloatTensor) 190 | # y2 = y2.unsqueeze(1) 191 | # Why? 
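# The commented-out line below would append the support mask to the support image as an extra input channel; the current code instead returns images and masks separately.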
192 | # input1 = torch.cat([input1, y1.unsqueeze(1)], dim=1) 193 | 194 | return input1, input2, y1, y2 195 | -------------------------------------------------------------------------------- /utils/log_utils.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import os 3 | import re 4 | import shutil 5 | from textwrap import wrap 6 | 7 | import matplotlib 8 | import matplotlib.pyplot as plt 9 | import numpy as np 10 | from torch.utils.tensorboard import SummaryWriter 11 | # from tensorboard import SummaryWriter 12 | import torch 13 | import utils.evaluator as eu 14 | import logging 15 | 16 | logging.basicConfig( 17 | level=logging.INFO, 18 | format="%(asctime)s [%(threadName)-12.12s] [%(levelname)-5.5s] %(message)s") 19 | 20 | plt.switch_backend('agg') 21 | plt.axis('scaled') 22 | 23 | 24 | # TODO: Add custom phase names 25 | class LogWriter(object): 26 | def __init__(self, num_class, log_dir_name, exp_name, use_last_checkpoint=False, labels=None, 27 | cm_cmap=plt.cm.Blues): 28 | self.num_class = num_class 29 | train_log_path, val_log_path = os.path.join(log_dir_name, exp_name, "train"), os.path.join(log_dir_name, 30 | exp_name, 31 | "val") 32 | if not use_last_checkpoint: 33 | if os.path.exists(train_log_path): 34 | shutil.rmtree(train_log_path) 35 | if os.path.exists(val_log_path): 36 | shutil.rmtree(val_log_path) 37 | 38 | self.writer = { 39 | 'train': SummaryWriter(log_dir=train_log_path, comment='Train Summary', flush_secs=30), 40 | 'val': SummaryWriter(log_dir=val_log_path, comment='Val Summary', flush_secs=30) 41 | } 42 | self.curr_iter = 1 43 | self.cm_cmap = cm_cmap 44 | self.labels = self.beautify_labels(labels) 45 | self.logger = logging.getLogger() 46 | file_handler = logging.FileHandler("{0}/{1}.log".format(os.path.join(log_dir_name, exp_name), "console_logs")) 47 | # console_handler = logging.StreamHandler() 48 | self.logger.addHandler(file_handler) 49 | # self.logger.addHandler(console_handler) 50 | 51 | def log(self, text, phase='train'): 52 | self.logger.info(text) 53 | 54 | def loss_per_iter(self, loss_value, i_batch, current_iteration): 55 | self.log('[Iteration : ' + str(i_batch) + '] Loss -> ' + str(loss_value)) 56 | self.writer['train'].add_scalar('loss/per_iteration', loss_value, current_iteration) 57 | 58 | def loss_per_epoch(self, loss_arr, phase, epoch): 59 | if phase == 'train': 60 | loss = loss_arr[-1] 61 | else: 62 | loss = np.mean(loss_arr) 63 | self.writer[phase].add_scalar('loss/per_epoch', loss, epoch) 64 | self.log('epoch ' + phase + ' loss = ' + str(loss)) 65 | 66 | return loss 67 | 68 | def cm_per_epoch(self, phase, output, correct_labels, epoch): 69 | 70 | self.log("Confusion Matrix...") 71 | _, cm = eu.dice_confusion_matrix(output, correct_labels, self.num_class, mode='train') 72 | self.plot_cm('confusion_matrix', phase, cm, epoch) 73 | self.log("DONE") 74 | 75 | def plot_cm(self, caption, phase, cm, step=None): 76 | fig = matplotlib.figure.Figure(figsize=(8, 8), dpi=180, facecolor='w', edgecolor='k') 77 | ax = fig.add_subplot(1, 1, 1) 78 | 79 | ax.imshow(cm, interpolation='nearest', cmap=self.cm_cmap) 80 | ax.set_xlabel('Predicted', fontsize=7) 81 | ax.set_xticks(np.arange(self.num_class)) 82 | c = ax.set_xticklabels(self.labels, fontsize=4, rotation=-90, ha='center') 83 | ax.xaxis.set_label_position('bottom') 84 | ax.xaxis.tick_bottom() 85 | 86 | ax.set_ylabel('True Label', fontsize=7) 87 | ax.set_yticks(np.arange(self.num_class)) 88 | ax.set_yticklabels(self.labels, fontsize=4, va='center') 89 
| ax.yaxis.set_label_position('left') 90 | ax.yaxis.tick_left() 91 | 92 | thresh = cm.max() / 2. 93 | for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])): 94 | ax.text(j, i, format(cm[i, j], '.2f') if cm[i, j] != 0 else '.', horizontalalignment="center", fontsize=6, 95 | verticalalignment='center', color="white" if cm[i, j] > thresh else "black") 96 | 97 | fig.set_tight_layout(True) 98 | np.set_printoptions(precision=2) 99 | if step: 100 | self.writer[phase].add_figure(caption + '/' + phase, fig, step) 101 | else: 102 | self.writer[phase].add_figure(caption + '/' + phase, fig) 103 | 104 | def dice_score_per_epoch(self, phase, output, correct_labels, epoch): 105 | self.log("Dice Score...") 106 | 107 | # TODO: multiclass vs binary 108 | ds = eu.dice_score_binary(output, correct_labels, self.num_class, phase) 109 | self.log('Dice score is ' + str(ds)) 110 | # self.plot_dice_score(phase, 'dice_score_per_epoch', ds, 'Dice Score', epoch) 111 | 112 | self.log("DONE") 113 | return ds 114 | 115 | def dice_score_per_epoch_segmentor(self, phase, output, correct_labels, epoch): 116 | self.log("Dice Score...") 117 | 118 | # TODO: multiclass vs binary 119 | ds = eu.dice_score_perclass(output, correct_labels, self.num_class, mode=phase) 120 | ds_mean = torch.mean(ds[1:]) 121 | self.log('Dice score is ' + str(ds)) 122 | self.log('Dice score mean ' + str(ds_mean)) 123 | self.plot_dice_score(phase, 'dice_score_per_epoch', ds, 'Dice Score', epoch) 124 | self.log("DONE") 125 | return ds_mean 126 | 127 | def plot_dice_score(self, phase, caption, ds, title, step=None): 128 | fig = matplotlib.figure.Figure(figsize=(8, 6), dpi=180, facecolor='w', edgecolor='k') 129 | ax = fig.add_subplot(1, 1, 1) 130 | ax.set_xlabel(title, fontsize=10) 131 | ax.xaxis.set_label_position('top') 132 | ax.bar(np.arange(self.num_class), ds) 133 | ax.set_xticks(np.arange(self.num_class)) 134 | c = ax.set_xticklabels(self.labels, fontsize=6, rotation=-90, ha='center') 135 | ax.xaxis.tick_bottom() 136 | if step: 137 | self.writer[phase].add_figure(caption + '/' + phase, fig, step) 138 | else: 139 | self.writer[phase].add_figure(caption + '/' + phase, fig) 140 | 141 | def plot_eval_box_plot(self, caption, class_dist, title): 142 | fig = matplotlib.figure.Figure(figsize=(8, 6), dpi=180, facecolor='w', edgecolor='k') 143 | ax = fig.add_subplot(1, 1, 1) 144 | ax.set_xlabel(title, fontsize=10) 145 | ax.xaxis.set_label_position('top') 146 | ax.boxplot(class_dist) 147 | ax.set_xticks(np.arange(self.num_class)) 148 | c = ax.set_xticklabels(self.labels, fontsize=6, rotation=-90, ha='center') 149 | ax.xaxis.tick_bottom() 150 | self.writer['val'].add_figure(caption, fig) 151 | 152 | def image_per_epoch_segmentor(self, prediction, ground_truth, phase, epoch): 153 | self.log("Sample Images...") 154 | ncols = 2 155 | nrows = len(prediction) 156 | 157 | fig, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=(10, 20)) 158 | 159 | for i in range(nrows): 160 | ax[i][0].imshow(torch.squeeze(ground_truth[i]), cmap='CMRmap', vmin=0, vmax=self.num_class - 1) 161 | ax[i][0].set_title("Ground Truth", fontsize=10, color="blue") 162 | ax[i][0].axis('off') 163 | ax[i][1].imshow(torch.squeeze(prediction[i]), cmap='CMRmap', vmin=0, vmax=self.num_class - 1) 164 | ax[i][1].set_title("Predicted", fontsize=10, color="blue") 165 | ax[i][1].axis('off') 166 | fig.set_tight_layout(True) 167 | self.writer[phase].add_figure('sample_prediction/' + phase, fig, epoch) 168 | self.log('DONE') 169 | 170 | def image_per_epoch(self, prediction, ground_truth, 
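`dice_score_per_epoch` and the related methods delegate to `utils.evaluator` (imported as `eu`), which is not part of this listing (only `evaluator_slow.py` appears in the tree). As a point of reference only, a minimal binary Dice score over a stack of thresholded predictions could look like the sketch below; the real `dice_score_binary` signature may differ.

```python
import torch

def dice_score_binary_sketch(pred, target, eps=1e-7):
    """Dice coefficient for binary masks.

    pred, target: (N, H, W) tensors containing {0, 1}; returns a scalar tensor.
    """
    pred = pred.float().reshape(pred.shape[0], -1)
    target = target.float().reshape(target.shape[0], -1)
    intersection = (pred * target).sum(dim=1)
    union = pred.sum(dim=1) + target.sum(dim=1)
    return ((2 * intersection + eps) / (union + eps)).mean()

# toy check
p = torch.zeros(2, 8, 8); p[:, :4, :4] = 1
t = torch.zeros(2, 8, 8); t[:, :4, :6] = 1
print(dice_score_binary_sketch(p, t))   # ~0.80
```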
phase, epoch, additional_image=None): 171 | self.log("Sample Images...") 172 | ncols = 3 173 | nrows = len(prediction) 174 | 175 | fig, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=(10, 20)) 176 | for i in range(nrows): 177 | ax[i][0].imshow(additional_image[i].squeeze().transpose(0,2).cpu(), cmap='gray', vmin=0, vmax=5) 178 | ax[i][0].set_title("Input Image", fontsize=10, color="blue") 179 | ax[i][0].axis('off') 180 | ax[i][1].imshow(ground_truth[i].squeeze(), cmap='jet', vmin=0, vmax=5) 181 | ax[i][1].set_title("Ground Truth", fontsize=10, color="blue") 182 | ax[i][1].axis('off') 183 | ax[i][2].imshow(prediction[i].squeeze(), cmap='jet', vmin=0, vmax=5) 184 | ax[i][2].set_title("Predicted", fontsize=10, color="blue") 185 | ax[i][2].axis('off') 186 | fig.set_tight_layout(True) 187 | self.writer[phase].add_figure('sample_prediction/' + phase, fig, epoch) 188 | self.log('DONE') 189 | 190 | def graph(self, model, X): 191 | self.writer['train'].add_graph(model, X) 192 | 193 | def close(self): 194 | self.writer['train'].close() 195 | self.writer['val'].close() 196 | 197 | def beautify_labels(self, labels): 198 | classes = [re.sub(r'([a-z](?=[A-Z])|[A-Z](?=[A-Z][a-z]))', r'\1 ', x) for x in labels] 199 | classes = ['\n'.join(wrap(l, 40)) for l in classes] 200 | return classes 201 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch.nn as nn 3 | import torch 4 | import nrrd 5 | import cv2 6 | import torch.nn.functional as F 7 | import random 8 | import os 9 | import pandas as pd 10 | import numpy 11 | import copy 12 | import RAP as fs 13 | from settings import Settings 14 | import shutil 15 | 16 | 17 | 18 | support_path = r'your dir' 19 | global_path = r'your dir' 20 | 21 | 22 | def MR_normalize(x_in): 23 | # return (x_in - x_in.mean()) / x_in.std() 24 | # return (x_in - np.min(x_in)) / (np.max(x_in) - np.min(x_in)) 25 | return x_in/255 26 | 27 | def ts_main(ckpt_path): 28 | 29 | settings = Settings() # parse .ini 30 | common_params, data_params, net_params, train_params, eval_params = settings['COMMON'], settings['DATA'], settings[ 31 | 'NETWORK'], settings['TRAINING'], settings['EVAL'] 32 | 33 | model = fs.RAP(net_params) 34 | 35 | model.load_state_dict(torch.load(ckpt_path, map_location='cpu')['state_dict']) 36 | model.cuda() 37 | model.eval() 38 | 39 | # some params 40 | query_root = global_path 41 | shot = 5 42 | size = 256 43 | all_img_path = glob.glob(query_root+'/*_im.nrrd') 44 | all_support_path = glob.glob(support_path+'/*_im.nrrd') 45 | 46 | save_path = './prediction_la_dice_1000' 47 | if not os.path.exists(save_path): 48 | os.mkdir(save_path) 49 | else: 50 | shutil.rmtree(save_path) 51 | os.mkdir(save_path) 52 | 53 | # data flow and pred 54 | with torch.no_grad(): 55 | for pid in all_img_path: 56 | print('qid:', pid) 57 | query_name = pid.split('\\')[-1].split('.')[0] 58 | # if query_name != '08-63 WANGQIAN_im': 59 | # continue 60 | img_query = nrrd.read(pid)[0].transpose(2, 1, 0) 61 | mask_query = nrrd.read(pid.replace('im', 'm'))[0].transpose(2, 1, 0) 62 | 63 | tmp_support_path = copy.deepcopy(all_support_path) 64 | try: 65 | tmp_support_path.remove(pid) 66 | except: 67 | pass 68 | 69 | pred_mask = [] 70 | tmp_sprior = [] 71 | sp_mask = [] 72 | sp_slices = 3 73 | for query_slice in range(img_query.shape[0]): 74 | if sp_slices == 1: 75 | input = cv2.resize(img_query[query_slice], dsize=(size, size), 
interpolation=cv2.INTER_LINEAR) 76 | input = MR_normalize(input) 77 | # 3 or 1 channel input 78 | # input = torch.from_numpy(np.repeat(input[np.newaxis, np.newaxis, ...], 3, 1)).float().cuda() 79 | query = torch.from_numpy(input[np.newaxis, np.newaxis, ...]).float().cuda() 80 | 81 | else: 82 | # sp_slices == 3 83 | input = cv2.resize(img_query[query_slice], dsize=(size, size), interpolation=cv2.INTER_LINEAR) 84 | input = MR_normalize(input) 85 | query = torch.from_numpy(input[np.newaxis, np.newaxis, ...]).float() 86 | if query_slice == 0: 87 | query_pre = query 88 | else: 89 | input = cv2.resize(img_query[query_slice-1], dsize=(size, size), interpolation=cv2.INTER_LINEAR) 90 | input = MR_normalize(input) 91 | query_pre = torch.from_numpy(input[np.newaxis, np.newaxis, ...]).float() 92 | if query_slice == img_query.shape[0]-1: 93 | query_next = query 94 | else: 95 | input = cv2.resize(img_query[query_slice+1], dsize=(size, size), interpolation=cv2.INTER_LINEAR) 96 | input = MR_normalize(input) 97 | query_next = torch.from_numpy(input[np.newaxis, np.newaxis, ...]).float() 98 | # finish read query img(1 or 3 slices) and mask (1 slice) 99 | query = torch.cat([query_pre, query, query_next], dim=1).cuda() 100 | mask_query = cv2.resize(mask_query[query_slice], dsize=(size, size), interpolation=cv2.INTER_NEAREST) 101 | 102 | 103 | # every slice a support 104 | support_paths = random.sample(tmp_support_path, shot) 105 | print('sids:', support_paths) 106 | sp_imgs = [] 107 | sp_msks = [] 108 | for i in range(shot): 109 | img_support = nrrd.read(support_paths[i])[0].transpose(2, 1, 0) 110 | mask_support = nrrd.read(support_paths[i].replace('_im', '_m'))[0].transpose(2, 1, 0).astype( 111 | np.uint8) 112 | sp_imgs.append(img_support) 113 | sp_msks.append(mask_support) 114 | 115 | # get cur_slice support 116 | s_inputs = [] 117 | s_masks = [] 118 | cond_inputs = [] 119 | for i in range(shot): 120 | img_support = sp_imgs[i] 121 | mask_support = sp_msks[i] 122 | sp_shp0 = img_support.shape[0] 123 | if sp_slices == 1: 124 | sp_index = int(query_slice/img_query.shape[0]*img_support.shape[0]) 125 | img_support = cv2.resize(img_support[sp_index], dsize=(size, size), interpolation=cv2.INTER_LINEAR) 126 | img_support = MR_normalize(img_support) 127 | s_input = torch.from_numpy(img_support[np.newaxis, np.newaxis, np.newaxis,...]).float().cuda() 128 | msk_support = cv2.resize(mask_support[sp_index], dsize=(size, size), interpolation=cv2.INTER_NEAREST) 129 | s_mask = torch.from_numpy(msk_support[np.newaxis, np.newaxis, np.newaxis, ...]).float().cuda() 130 | else: 131 | # S1 132 | # sp_index = sp_shp0//2 133 | 134 | # S2 135 | # bias = sp_shp0 / 3 / 2 136 | # ratio = query_slice / img_query.shape[0] 137 | # if ratio < 1 / 3: 138 | # sp_index = int(bias) 139 | # elif ratio >= 1 / 3 and ratio < 2 / 3: 140 | # sp_index = int(1 / 3 * sp_shp0 + bias) 141 | # else: 142 | # sp_index = int(2 / 3 * sp_shp0 + bias) 143 | 144 | # S3 145 | sp_index = int(query_slice / img_query.shape[0] * sp_shp0) 146 | 147 | 148 | sp_indexes= [max(sp_index-1, 0), sp_index, min(sp_index+1, sp_shp0-1)] 149 | sp_imgs_tmp = [] 150 | sp_masks_tmp =[] 151 | for sp_index in sp_indexes: 152 | img_support_r = cv2.resize(img_support[sp_index], dsize=(size, size), 153 | interpolation=cv2.INTER_LINEAR) 154 | img_support_r = MR_normalize(img_support_r) 155 | s_input = torch.from_numpy(img_support_r[np.newaxis, np.newaxis, np.newaxis, ...]).float().cuda() 156 | sp_imgs_tmp.append(s_input) 157 | msk_support = cv2.resize(mask_support[sp_index], dsize=(size, 
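The commented S1/S2 blocks and the active S3 strategy above decide which support slice is paired with the current query slice: S3 maps the query slice's relative depth onto the support volume and then takes a clamped ±1 neighbourhood. A stand-alone sketch of that index arithmetic (the function name is illustrative only):

```python
def support_slice_indices(query_slice, n_query, n_support):
    """S3-style proportional matching: the support slice at the same relative
    depth as the query slice, plus its clamped neighbours."""
    sp_index = int(query_slice / n_query * n_support)
    sp_index = min(sp_index, n_support - 1)   # defensive clamp (not needed when query_slice < n_query)
    return [max(sp_index - 1, 0), sp_index, min(sp_index + 1, n_support - 1)]

# query volume with 10 slices, support volume with 25 slices
print(support_slice_indices(0, 10, 25))   # [0, 0, 1]
print(support_slice_indices(9, 10, 25))   # [21, 22, 23]
```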
size), 158 | interpolation=cv2.INTER_NEAREST)==1 159 | s_mask = torch.from_numpy(msk_support[np.newaxis, np.newaxis, np.newaxis, ...]).float().cuda() 160 | sp_masks_tmp.append(s_mask) 161 | 162 | s_input = torch.cat(sp_imgs_tmp, 2) # [1,1,slice,H,W] 163 | s_mask = torch.cat(sp_masks_tmp, 2) # [1,1,slice,H,W] 164 | 165 | s_inputs.append(s_input) 166 | s_masks.append(s_mask) 167 | 168 | # finish read support img and mask 169 | s_input = torch.cat(s_inputs, 1) # 1, Kshot, slice, h, w 170 | s_mask = torch.cat(s_masks, 1) 171 | 172 | # # run model 173 | support = torch.cat([s_input, s_mask], 2) # b, Kshot, slice*2, h, w 174 | cond_inputs_ = support.permute(1,0,2,3,4) # Kshot, b, slice*2, h, w 175 | 176 | # forward 177 | out, sp_pred, max_corr2 = model.segmentor(query, cond_inputs_, s_mask.permute(1, 0, 2, 3, 4)) 178 | 179 | tmp_sprior.append(out.detach().cpu().numpy()) 180 | out = F.interpolate(out, size=img_query.shape[1:], mode='bilinear', align_corners=True) 181 | sp_pred = F.interpolate(sp_pred, size=img_query.shape[1:], mode='bilinear', align_corners=True) 182 | 183 | output = (out >.5).squeeze(1) 184 | sp_pred = (sp_pred > .5).squeeze(1) 185 | 186 | pred_mask.append(output.cpu().numpy()) 187 | sp_mask.append(sp_pred.squeeze(1).cpu().numpy()) 188 | 189 | pred = np.concatenate(pred_mask, 0) 190 | sp = np.concatenate(sp_mask, 0) 191 | 192 | nrrd.write(f'{save_path}/{query_name}_pred.nrrd', pred.transpose(2,1,0).astype(np.uint8)) 193 | nrrd.write(f'{save_path}/{query_name}_sp.nrrd', sp.transpose(2,1,0).astype(np.uint8)) 194 | 195 | -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Description 3 | ++++++++++++++++++++++ 4 | Addition losses module defines classses which are commonly used particularly in segmentation and are not part of standard pytorch library. 5 | 6 | Usage 7 | ++++++++++++++++++++++ 8 | Import the package and Instantiate any loss class you want to you:: 9 | 10 | from nn_common_modules import losses as additional_losses 11 | loss = additional_losses.DiceLoss() 12 | 13 | Note: If you use DiceLoss, insert Softmax layer in the architecture. 
In case of combined loss, do not put softmax as it is in-built 14 | 15 | Members 16 | ++++++++++++++++++++++ 17 | """ 18 | import torch 19 | import torch.nn as nn 20 | import torch.nn.functional as F 21 | from torch.nn.modules.loss import _Loss, _WeightedLoss 22 | import numpy as np 23 | from torch.autograd import Variable 24 | 25 | 26 | class DiceLoss(_WeightedLoss): 27 | """ 28 | Dice Loss for a batch of samples 29 | """ 30 | 31 | def forward(self, output, target, weights=torch.tensor([0.3,0.7]).cuda(), binary=False): 32 | """ 33 | Forward pass 34 | 35 | :param output: NxCxHxW logits 36 | :param target: NxHxW LongTensor 37 | :param weights: C FloatTensor 38 | :param binary: bool for binarized one chaneel(C=1) input 39 | :return: torch.tensor 40 | """ 41 | if binary: 42 | return self._dice_loss_binary(output, target) 43 | # output = F.softmax(output, dim=1) 44 | return self._dice_loss_multichannel(output, target, weights) 45 | 46 | @staticmethod 47 | def _dice_loss_binary(output, target): 48 | """ 49 | Dice loss for one channel binarized input 50 | 51 | :param output: Nx1xHxW logits 52 | :param target: NxHxW LongTensor 53 | :return: 54 | """ 55 | eps = 0.0001 56 | 57 | intersection = output * target 58 | numerator = 2 * intersection.sum(0).sum(1).sum(1) 59 | denominator = output + target 60 | denominator = denominator.sum(0).sum(1).sum(1) + eps 61 | loss_per_channel = 1 - (numerator / denominator) 62 | 63 | return loss_per_channel.sum() / output.size(1) 64 | 65 | @staticmethod 66 | def _dice_loss_multichannel(output, target, weights=None): 67 | """ 68 | Forward pass 69 | 70 | :param output: NxCxHxW Variable 71 | :param target: NxHxW LongTensor 72 | :param weights: C FloatTensor 73 | :param binary: bool for binarized one chaneel(C=1) input 74 | :return: 75 | """ 76 | 77 | # output = F.softmax(output, dim=1) 78 | eps = 0.0001 79 | target = target.unsqueeze(1) 80 | encoded_target = torch.zeros_like(output) 81 | 82 | encoded_target = encoded_target.scatter(1, target, 1) 83 | 84 | intersection = output * encoded_target 85 | intersection = intersection.sum(2).sum(2) 86 | 87 | num_union_pixels = output.sum(2).sum(2) + encoded_target.sum(2).sum(2) 88 | # num_union_pixels = num_union_pixels.sum(2).sum(2) 89 | 90 | loss_per_class = 1 - ((2 * intersection) / (num_union_pixels + eps)) 91 | # loss_per_class = 1 - ((2 * intersection + 1) / (num_union_pixels + 1)) 92 | if weights is None: 93 | weights = torch.ones_like(loss_per_class) 94 | loss_per_class *= weights 95 | 96 | return loss_per_class.sum(1).mean() # / (num_union_pixels != 0).sum(1).float()).mean() 97 | 98 | 99 | class IoULoss(_WeightedLoss): 100 | """ 101 | IoU Loss for a batch of samples 102 | """ 103 | 104 | def forward(self, output, target, weights=None, ignore_index=None): 105 | """Forward pass 106 | 107 | :param output: shape = NxCxHxW 108 | :type output: torch.tensor [FloatTensor] 109 | :param target: shape = NxHxW 110 | :type target: torch.tensor [LongTensor] 111 | :param weights: shape = C, defaults to None 112 | :type weights: torch.tensor [FloatTensor], optional 113 | :param ignore_index: index to ignore from loss, defaults to None 114 | :type ignore_index: int, optional 115 | :return: loss value 116 | :rtype: torch.tensor 117 | """ 118 | 119 | output = F.softmax(output, dim=1) 120 | 121 | eps = 0.0001 122 | encoded_target = output.detach() * 0 123 | 124 | if ignore_index is not None: 125 | mask = target == ignore_index 126 | target = target.clone() 127 | target[mask] = 0 128 | encoded_target.scatter_(1, target.unsqueeze(1), 1) 
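Following the module docstring's note, `DiceLoss` expects class probabilities, so the caller applies the softmax itself. A minimal usage sketch, assuming this file is importable as `losses`; note that the default `weights` argument of `DiceLoss.forward` is built with `.cuda()` at import time, so a CUDA-enabled PyTorch build is assumed, and weights are passed explicitly here to keep the computation on CPU:

```python
import torch
import torch.nn.functional as F
import losses  # this repository's losses.py

dice = losses.DiceLoss()

logits = torch.randn(2, 2, 64, 64)           # N x C x H x W
target = torch.randint(0, 2, (2, 64, 64))    # N x H x W LongTensor
probs = F.softmax(logits, dim=1)             # DiceLoss expects probabilities

# pass weights explicitly so the .cuda() default is never touched
loss = dice(probs, target, weights=torch.tensor([0.3, 0.7]))
print(loss)
```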
129 | mask = mask.unsqueeze(1).expand_as(encoded_target) 130 | encoded_target[mask] = 0 131 | else: 132 | encoded_target.scatter_(1, target.unsqueeze(1), 1) 133 | 134 | if weights is None: 135 | weights = 1 136 | 137 | intersection = output * encoded_target 138 | numerator = intersection.sum(0).sum(1).sum(1) 139 | denominator = (output + encoded_target) - (output * encoded_target) 140 | 141 | if ignore_index is not None: 142 | denominator[mask] = 0 143 | denominator = denominator.sum(0).sum(1).sum(1) + eps 144 | loss_per_channel = weights * (1 - (numerator / denominator)) 145 | 146 | return loss_per_channel.sum() / output.size(1) 147 | 148 | 149 | class CrossEntropyLoss2d(_WeightedLoss): 150 | """ 151 | Standard pytorch weighted nn.CrossEntropyLoss 152 | """ 153 | 154 | def __init__(self, weight=None): 155 | super(CrossEntropyLoss2d, self).__init__() 156 | self.nll_loss = nn.CrossEntropyLoss(weight) 157 | 158 | def forward(self, inputs, targets): 159 | """ 160 | Forward pass 161 | 162 | :param inputs: torch.tensor (NxC) 163 | :param targets: torch.tensor (N) 164 | :return: scalar 165 | """ 166 | return self.nll_loss(inputs, targets) 167 | 168 | 169 | class CombinedLoss(_Loss): 170 | """ 171 | A combination of dice and cross entropy loss 172 | """ 173 | 174 | def __init__(self): 175 | super(CombinedLoss, self).__init__() 176 | self.cross_entropy_loss = nn.CrossEntropyLoss(reduction='none') # CrossEntropyLoss2d() 177 | self.dice_loss = DiceLoss() 178 | self.focal_loss = FocalLoss() 179 | self.l2_loss = nn.MSELoss() 180 | 181 | def forward(self, input, target, weight=None): 182 | """ 183 | Forward pass 184 | 185 | :param input: torch.tensor (NxCxHxW) 186 | :param target: torch.tensor (NxHxW) 187 | :param weight: torch.tensor (NxHxW) 188 | :return: scalar 189 | """ 190 | # input_soft = F.softmax(input, dim=1) 191 | y_2 = torch.mean(self.dice_loss(input, target)) 192 | if weight is None: 193 | y_1 = torch.mean(self.cross_entropy_loss.forward(input, target)) 194 | else: 195 | y_1 = torch.mean( 196 | torch.mul(self.cross_entropy_loss.forward(input, target), weight)) 197 | return y_1 + y_2 198 | 199 | 200 | class CombinedLoss_KLdiv(_Loss): 201 | """ 202 | A combination of dice and cross entropy loss 203 | """ 204 | 205 | def __init__(self): 206 | super(CombinedLoss_KLdiv, self).__init__() 207 | self.cross_entropy_loss = CrossEntropyLoss2d() 208 | self.dice_loss = DiceLoss() 209 | 210 | def forward(self, input, target, weight=None): 211 | """ 212 | Forward pass 213 | 214 | """ 215 | input, kl_div_loss = input 216 | # input_soft = F.softmax(input, dim=1) 217 | y_2 = torch.mean(self.dice_loss(input, target)) 218 | if weight is None: 219 | y_1 = torch.mean(self.cross_entropy_loss.forward(input, target)) 220 | else: 221 | y_1 = torch.mean( 222 | torch.mul(self.cross_entropy_loss.forward(input, target), weight)) 223 | return y_1, y_2, kl_div_loss 224 | 225 | 226 | # Credit to https://github.com/clcarwin/focal_loss_pytorch 227 | class FocalLoss(nn.Module): 228 | """ 229 | Focal Loss for Dense Object Detection 230 | """ 231 | 232 | def __init__(self, gamma=2, alpha=None, size_average=True): 233 | 234 | super(FocalLoss, self).__init__() 235 | self.gamma = gamma 236 | self.alpha = alpha 237 | if isinstance(alpha, (float, int)): 238 | self.alpha = torch.Tensor([alpha, 1 - alpha]) 239 | if isinstance(alpha, list): 240 | self.alpha = torch.Tensor(alpha) 241 | self.size_average = size_average 242 | 243 | def forward(self, input, target): 244 | """Forward pass 245 | 246 | :param input: shape = NxCxHxW 247 | 
:type input: torch.tensor 248 | :param target: shape = NxHxW 249 | :type target: torch.tensor 250 | :return: loss value 251 | :rtype: torch.tensor 252 | """ 253 | 254 | if input.dim() > 2: 255 | # N,C,H,W => N,C,H*W 256 | input = input.view(input.size(0), input.size(1), -1) 257 | input = input.transpose(1, 2) # N,C,H*W => N,H*W,C 258 | input = input.contiguous().view(-1, input.size(2)) # N,H*W,C => N*H*W,C 259 | target = target.view(-1, 1) 260 | 261 | logpt = F.log_softmax(input, dim=1) 262 | logpt = logpt.gather(1, target) 263 | logpt = logpt.view(-1) 264 | pt = Variable(logpt.data.exp()) 265 | 266 | if self.alpha is not None: 267 | if self.alpha.type() != input.data.type(): 268 | self.alpha = self.alpha.type_as(input.data) 269 | at = self.alpha.gather(0, target.data.view(-1)) 270 | logpt = logpt * Variable(at) 271 | 272 | loss = -1 * (1 - pt) ** self.gamma * logpt 273 | if self.size_average: 274 | return loss.mean() 275 | else: 276 | return loss.sum() 277 | 278 | 279 | from torch import einsum 280 | from torch import Tensor 281 | from scipy.ndimage import distance_transform_edt as distance 282 | from scipy.spatial.distance import directed_hausdorff 283 | 284 | from typing import Any, Callable, Iterable, List, Set, Tuple, TypeVar, Union 285 | 286 | 287 | def class2one_hot(seg: Tensor, C: int) -> Tensor: 288 | if len(seg.shape) == 2: # Only w, h, used by the dataloader 289 | seg = seg.unsqueeze(dim=0) 290 | 291 | b, w, h = seg.shape # type: Tuple[int, int, int] 292 | 293 | res = torch.stack([seg == c for c in range(C)], dim=1).type(torch.int32) 294 | assert res.shape == (b, C, w, h) 295 | 296 | return res 297 | 298 | def one_hot2dist(seg: np.ndarray) -> np.ndarray: 299 | C: int = len(seg) 300 | 301 | res = np.zeros_like(seg) 302 | for c in range(C): 303 | posmask = seg[c].astype(np.bool) 304 | 305 | if posmask.any(): 306 | negmask = ~posmask 307 | # print('negmask:', negmask) 308 | # print('distance(negmask):', distance(negmask)) 309 | res[c] = distance(negmask) * negmask - (distance(posmask) - 1) * posmask 310 | # print('res[c]', res[c]) 311 | return res 312 | 313 | class SurfaceLoss(): 314 | def __init__(self): 315 | # Self.idc is used to filter out some classes of the target mask. Use fancy indexing 316 | self.idc: List[int] = [0, 1] #这里忽略背景类 https://github.com/LIVIAETS/surface-loss/issues/3 317 | 318 | # probs: bcwh, dist_maps: bcwh 319 | def __call__(self, probs: Tensor, dist_maps: Tensor) -> Tensor: 320 | 321 | pc = probs[:, self.idc, ...].type(torch.float32) 322 | dc = dist_maps[:, self.idc, ...].type(torch.float32) 323 | 324 | multipled = einsum("bcwh,bcwh->bcwh", pc, dc) 325 | 326 | loss = multipled.mean() 327 | 328 | return loss 329 | -------------------------------------------------------------------------------- /modules/conv_blocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from modules import se_modules as se 4 | import torch.nn.functional as F 5 | import torch.nn.init as init 6 | 7 | class GenericBlock(nn.Module): 8 | """ 9 | Generic parent class for a conv encoder/decoder block. 
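`SurfaceLoss` consumes pre-computed signed distance maps rather than raw labels, so the ground truth goes through `class2one_hot` and then `one_hot2dist` (a per-sample NumPy step). A usage sketch under the same `import losses` assumption as above; `one_hot2dist` uses `np.bool`, which NumPy 1.24 removed, so an older NumPy (or replacing it with `bool`) is assumed:

```python
import numpy as np
import torch
import torch.nn.functional as F
import losses  # this repository's losses.py

C = 2
logits = torch.randn(1, C, 32, 32)
gt = torch.randint(0, C, (1, 32, 32))
probs = F.softmax(logits, dim=1)

# one-hot encode, then signed distance maps per sample (NumPy, per the helpers above)
one_hot = losses.class2one_hot(gt, C)                              # (1, C, 32, 32) int32
dist = losses.one_hot2dist(one_hot[0].numpy().astype(np.float32))  # (C, 32, 32) float32
dist_maps = torch.from_numpy(dist)[None].float()                   # back to (1, C, 32, 32)

surface_loss = losses.SurfaceLoss()
print(surface_loss(probs, dist_maps))
```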
10 | 11 | :param params: {'kernel_h': 5 12 | 'kernel_w': 5 13 | 'num_channels':64 14 | 'num_filters':64 15 | 'stride_conv':1 16 | } 17 | :type params: dict 18 | :param se_block_type: Squeeze and Excite block type to be included, defaults to None 19 | :type se_block_type: str, valid options are {'NONE', 'CSE', 'SSE', 'CSSE'}, optional 20 | :return: forward passed tensor 21 | :rtype: torch.tensor [FloatTensor] 22 | """ 23 | 24 | def __init__(self, params, se_block_type=None): 25 | super(GenericBlock, self).__init__() 26 | if se_block_type == se.SELayer.CSE.value: 27 | self.SELayer = se.ChannelSpatialSELayer(params['num_filters']) 28 | 29 | elif se_block_type == se.SELayer.SSE.value: 30 | self.SELayer = se.SpatialSELayer(params['num_filters']) 31 | 32 | elif se_block_type == se.SELayer.CSSE.value: 33 | self.SELayer = se.ChannelSpatialSELayer(params['num_filters']) 34 | else: 35 | self.SELayer = None 36 | padding_h = int((params['kernel_h'] - 1) / 2) 37 | padding_w = int((params['kernel_w'] - 1) / 2) 38 | self.out_channel = params['num_filters'] 39 | self.conv = nn.Conv2d(in_channels=params['num_channels'], out_channels=params['num_filters'], 40 | kernel_size=( 41 | params['kernel_h'], params['kernel_w']), 42 | padding=(padding_h, padding_w), 43 | stride=params['stride_conv']) 44 | self.prelu = nn.PReLU() 45 | self.batchnorm = nn.BatchNorm2d(num_features=params['num_filters']) 46 | if params['drop_out'] > 0: 47 | self.drop_out_needed = True 48 | self.drop_out = nn.Dropout2d(params['drop_out']) 49 | else: 50 | self.drop_out_needed = False 51 | 52 | def forward(self, input, weights=None): 53 | """Forward pass 54 | 55 | :param input: Input tensor, shape = (N x C x H x W) 56 | :type input: torch.tensor [FloatTensor] 57 | :param weights: Custom weights for convolution, defaults to None 58 | :type weights: torch.tensor [FloatTensor], optional 59 | :return: [description] 60 | :rtype: [type] 61 | """ 62 | 63 | _, c, h, w = input.shape 64 | if weights is None: 65 | x1 = self.conv(input) 66 | else: 67 | weights, _ = torch.max(weights, dim=0) 68 | weights = weights.view(self.out_channel, c, 1, 1) 69 | x1 = F.conv2d(input, weights) 70 | x2 = self.prelu(x1) 71 | x3 = self.batchnorm(x2) 72 | return x3 73 | 74 | 75 | class SDnetEncoderBlock(GenericBlock): 76 | """ 77 | A standard conv -> prelu -> batchnorm-> maxpool block without dense connections 78 | 79 | :param params: { 80 | 'num_channels':1, 81 | 'num_filters':64, 82 | 'kernel_h':5, 83 | 'kernel_w':5, 84 | 'stride_conv':1, 85 | 'pool':2, 86 | 'stride_pool':2, 87 | 'num_classes':28, 88 | 'se_block': se.SELayer.None, 89 | 'drop_out':0,2} 90 | :type params: dict 91 | :param se_block_type: Squeeze and Excite block type to be included, defaults to None 92 | :type se_block_type: str, valid options are {'NONE', 'CSE', 'SSE', 'CSSE'}, optional 93 | :return: output tensor with maxpool, output tensor without maxpool, indices for unpooling 94 | :rtype: torch.tensor [FloatTensor], torch.tensor [FloatTensor], torch.tensor [LongTensor] 95 | """ 96 | 97 | def __init__(self, params, se_block_type=None): 98 | super(SDnetEncoderBlock, self).__init__(params, se_block_type) 99 | self.maxpool = nn.MaxPool2d( 100 | kernel_size=params['pool'], stride=params['stride_pool'], return_indices=True) 101 | 102 | def forward(self, input, weights=None): 103 | """Forward pass 104 | 105 | :param input: Input tensor, shape = (N x C x H x W) 106 | :type input: torch.tensor [FloatTensor] 107 | :param weights: Weights used for squeeze and excitation, shape depends on the type of SE block, 
defaults to None 108 | :type weights: torch.tensor, optional 109 | :return: output tensor with maxpool, output tensor without maxpool, indices for unpooling 110 | :rtype: torch.tensor [FloatTensor], torch.tensor [FloatTensor], torch.tensor [LongTensor] 111 | """ 112 | 113 | out_block = super(SDnetEncoderBlock, self).forward(input, weights) 114 | 115 | if self.SELayer: 116 | out_block = self.SELayer(out_block, weights) 117 | if self.drop_out_needed: 118 | out_block = self.drop_out(out_block) 119 | 120 | out_encoder, indices = self.maxpool(out_block) 121 | return out_encoder, out_block, indices 122 | 123 | 124 | class SDnetDecoderBlock(GenericBlock): 125 | """Standard decoder block with maxunpool -> skipconnections -> conv -> prelu -> batchnorm, without dense connections and an optional SE blocks 126 | 127 | :param params: { 128 | 'num_channels':1, 129 | 'num_filters':64, 130 | 'kernel_h':5, 131 | 'kernel_w':5, 132 | 'stride_conv':1, 133 | 'pool':2, 134 | 'stride_pool':2, 135 | 'num_classes':28, 136 | 'se_block': se.SELayer.None, 137 | 'drop_out':0,2} 138 | :type params: dict 139 | :param se_block_type: Squeeze and Excite block type to be included, defaults to None 140 | :type se_block_type: str, valid options are {'NONE', 'CSE', 'SSE', 'CSSE'}, optional 141 | :return: forward passed tensor 142 | :rtype: torch.tensor [FloatTensor] 143 | """ 144 | 145 | def __init__(self, params, se_block_type=None): 146 | super(SDnetDecoderBlock, self).__init__(params, se_block_type) 147 | self.unpool = nn.MaxUnpool2d( 148 | kernel_size=params['pool'], stride=params['stride_pool']) 149 | # self.conv1 = nn.Conv2d(in_channels=params['num_channels'], out_channels=params['num_filters'], 150 | # kernel_size=(1,1), 151 | # stride=params['stride_conv']) 152 | 153 | 154 | def forward(self, input, out_block=None, indices=None, weights=None): 155 | """Forward pass 156 | 157 | :param input: Input tensor, shape = (N x C x H x W) 158 | :type input: torch.tensor [FloatTensor] 159 | :param out_block: Tensor for skip connection, shape = (N x C x H x W), defaults to None 160 | :type out_block: torch.tensor [FloatTensor], optional 161 | :param indices: Indices used for unpooling operation, defaults to None 162 | :type indices: torch.tensor, optional 163 | :param weights: Weights used for squeeze and excitation, shape depends on the type of SE block, defaults to None 164 | :type weights: torch.tensor, optional 165 | :return: Forward pass 166 | :rtype: torch.tensor 167 | """ 168 | 169 | # unpool = self.unpool(input, indices) # , out_block.shape) 170 | unpool = F.interpolate(input, scale_factor=2, mode='nearest') 171 | # unpool = self.conv1(unpool) 172 | if out_block is not None: 173 | concat = torch.cat((out_block, unpool), dim=1) 174 | else: 175 | concat = unpool 176 | out_block = super(SDnetDecoderBlock, self).forward(concat, weights) 177 | if self.SELayer: 178 | out_block = self.SELayer(out_block, weights) 179 | 180 | if self.drop_out_needed: 181 | out_block = self.drop_out(out_block) 182 | return out_block 183 | 184 | class ClassifierBlock(nn.Module): 185 | """ 186 | Last layer 187 | 188 | :param params: { 189 | 'num_channels':1, 190 | 'num_filters':64, 191 | 'kernel_c':5, 192 | 'stride_conv':1, 193 | 'pool':2, 194 | 'stride_pool':2, 195 | 'num_classes':28, 196 | 'se_block': se.SELayer.None, 197 | 'drop_out':0,2} 198 | :type params: dict 199 | :return: forward passed tensor 200 | :rtype: torch.tensor [FloatTensor] 201 | """ 202 | 203 | def __init__(self, params): 204 | super(ClassifierBlock, self).__init__() 205 | self.conv 
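The encoder and decoder blocks above are configured entirely through a `params` dict. A small shape-check sketch, assuming this repository's `modules` package is on the import path (the filter counts are chosen only for illustration):

```python
import torch
from modules import conv_blocks as cm

params = {
    'num_channels': 1, 'num_filters': 16,
    'kernel_h': 5, 'kernel_w': 5, 'stride_conv': 1,
    'pool': 2, 'stride_pool': 2, 'drop_out': 0,
}
enc = cm.SDnetEncoderBlock(params)          # conv -> PReLU -> BN -> maxpool

x = torch.randn(2, 1, 256, 256)
pooled, skip, indices = enc(x)
print(pooled.shape, skip.shape)             # (2, 16, 128, 128) (2, 16, 256, 256)

params['num_channels'] = 16                 # decoder convolves the upsampled tensor
dec = cm.SDnetDecoderBlock(params)
up = dec(pooled, out_block=None, indices=indices)
print(up.shape)                             # (2, 16, 256, 256)
```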
= nn.Conv2d( 206 | params['num_channels'], params['num_class'], params['kernel_c'], params['stride_conv']) 207 | 208 | def forward(self, input, weights=None): 209 | """Forward pass 210 | 211 | :param input: Input tensor, shape = (N x C x H x W) 212 | :type input: torch.tensor [FloatTensor] 213 | :param weights: Weights for classifier regression, defaults to None 214 | :type weights: torch.tensor (N), optional 215 | :return: logits 216 | :rtype: torch.tensor 217 | """ 218 | batch_size, channel, a, b = input.size() 219 | if weights is not None: 220 | weights, _ = torch.max(weights, dim=0) 221 | weights = weights.view(1, channel, 1, 1) 222 | out_conv = F.conv2d(input, weights) 223 | else: 224 | out_conv = self.conv(input) 225 | return out_conv 226 | 227 | 228 | class AddCoords(nn.Module): 229 | 230 | def __init__(self, with_r=False): 231 | super().__init__() 232 | self.with_r = with_r 233 | 234 | def forward(self, input_tensor): 235 | """ 236 | Args: 237 | input_tensor: shape(batch, channel, x_dim, y_dim) 238 | """ 239 | batch_size, _, x_dim, y_dim = input_tensor.size() 240 | 241 | xx_channel = torch.arange(x_dim).repeat(1, y_dim, 1) 242 | yy_channel = torch.arange(y_dim).repeat(1, x_dim, 1).transpose(1, 2) 243 | 244 | xx_channel = xx_channel.float() / (x_dim - 1) 245 | yy_channel = yy_channel.float() / (y_dim - 1) 246 | 247 | xx_channel = xx_channel * 2 - 1 248 | yy_channel = yy_channel * 2 - 1 249 | 250 | xx_channel = xx_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3) 251 | yy_channel = yy_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3) 252 | 253 | ret = torch.cat([ 254 | input_tensor, 255 | xx_channel.type_as(input_tensor), 256 | yy_channel.type_as(input_tensor)], dim=1) 257 | 258 | if self.with_r: 259 | rr = torch.sqrt(torch.pow(xx_channel.type_as(input_tensor) - 0.5, 2) + torch.pow(yy_channel.type_as(input_tensor) - 0.5, 2)) 260 | ret = torch.cat([ret, rr], dim=1) 261 | 262 | return ret 263 | 264 | 265 | class CoordConv(nn.Module): 266 | 267 | def __init__(self, in_channels, out_channels, with_r=False, **kwargs): 268 | super().__init__() 269 | self.addcoords = AddCoords(with_r=with_r) 270 | in_size = in_channels+2 271 | if with_r: 272 | in_size += 1 273 | self.conv = nn.Conv2d(in_size, out_channels, **kwargs) 274 | 275 | def forward(self, x): 276 | ret = self.addcoords(x) 277 | ret = self.conv(ret) 278 | return ret 279 | 280 | 281 | class AF(nn.Module): 282 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, groups=1, bias=False): 283 | super(AF, self).__init__() 284 | self.out_channels = out_channels 285 | self.kernel_size = kernel_size 286 | self.stride = stride 287 | self.padding = padding 288 | self.groups = groups 289 | 290 | assert self.out_channels % self.groups == 0, "out_channels should be divided by groups. 
(example: out_channels: 40, groups: 4)" 291 | 292 | self.rel_h = nn.Parameter(torch.randn(out_channels // 2, 1, 1, kernel_size, 1), requires_grad=True) 293 | self.rel_w = nn.Parameter(torch.randn(out_channels // 2, 1, 1, 1, kernel_size), requires_grad=True) 294 | 295 | self.key_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=bias) 296 | self.query_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=bias) 297 | self.value_conv = nn.Conv2d(1, 1, kernel_size=5, padding=2) 298 | 299 | self.reset_parameters() 300 | 301 | def forward(self, x, v): 302 | batch, channels, height, width = x.size() 303 | 304 | padded_x = F.pad(x, [self.padding, self.padding, self.padding, self.padding]) 305 | padded_v = F.pad(v, [self.padding, self.padding, self.padding, self.padding]) 306 | q_out = self.query_conv(x) 307 | k_out = self.key_conv(padded_x) 308 | v_out = padded_v 309 | 310 | k_out = k_out.unfold(2, self.kernel_size, self.stride).unfold(3, self.kernel_size, self.stride) 311 | v_out = v_out.unfold(2, self.kernel_size, self.stride).unfold(3, self.kernel_size, self.stride) 312 | 313 | 314 | k_out_h, k_out_w = k_out.split(self.out_channels // 2, dim=1) 315 | k_out = torch.cat((k_out_h + self.rel_h, k_out_w + self.rel_w), dim=1) 316 | 317 | k_out = k_out.contiguous().view(batch, self.groups, self.out_channels // self.groups, height, width, -1) # -1 = k*k 318 | v_out = v_out.contiguous().view(batch, self.groups, self.out_channels // self.groups, height, width, -1) 319 | 320 | 321 | q_out = q_out.view(batch, self.groups, self.out_channels // self.groups, height, width, 1) 322 | 323 | out = q_out * k_out 324 | out = F.softmax(out, dim=-1) 325 | out = torch.einsum('bnchwk,bnchwk -> bnchw', out, v_out).view(batch, -1, height, width) 326 | 327 | return out 328 | 329 | def reset_parameters(self): 330 | init.kaiming_normal_(self.key_conv.weight, mode='fan_out', nonlinearity='relu') 331 | init.kaiming_normal_(self.value_conv.weight, mode='fan_out', nonlinearity='relu') 332 | init.kaiming_normal_(self.query_conv.weight, mode='fan_out', nonlinearity='relu') 333 | 334 | init.normal_(self.rel_h, 0, 1) 335 | init.normal_(self.rel_w, 0, 1) 336 | -------------------------------------------------------------------------------- /RAP.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | sys.path.append(os.getcwd()) 3 | import torch 4 | import torch.nn as nn 5 | import modules.conv_blocks as cm 6 | import torch.nn.functional as F 7 | from modules.voxelmorph_us import U_Network, SpatialTransformer 8 | from modules.cre import Correlation 9 | 10 | 11 | class SDnetSegmentor(nn.Module): 12 | """ 13 | Segmentor Code 14 | """ 15 | def __init__(self, params): 16 | super(SDnetSegmentor, self).__init__() 17 | params['num_channels'] = 3 18 | params['num_filters'] = 16 19 | self.encode1 = cm.SDnetEncoderBlock(params) 20 | params['num_channels'] = 16 21 | params['num_filters'] = 32 22 | self.encode2 = cm.SDnetEncoderBlock(params) 23 | params['num_channels'] = 32 24 | params['num_filters'] = 64 25 | self.encode3 = cm.SDnetEncoderBlock(params) 26 | params['num_channels'] = 64 27 | params['num_filters'] = 64 28 | self.encode4 = cm.SDnetEncoderBlock(params) 29 | params['num_channels'] = 64 30 | params['num_filters'] = 64 31 | self.bottleneck = cm.GenericBlock(params) 32 | 33 | params['num_channels'] = 128 34 | params['num_filters'] = 32 35 | self.decode3 = cm.SDnetDecoderBlock(params) 36 | params['num_channels'] = 128 37 | params['num_filters'] = 64 38 | 
self.decode4 = cm.SDnetDecoderBlock(params) 39 | params['num_channels'] = 64 40 | params['num_class'] = 1 41 | self.classifier = cm.ClassifierBlock(params) 42 | params['num_channels'] = 1 43 | 44 | cood_conv = cm.CoordConv(16*3, 1, kernel_size=3, padding=1) #24 45 | self.soft_max = nn.Softmax2d() 46 | self.sigmoid = nn.Sigmoid() 47 | self.mask_conv = nn.Sequential(cood_conv, nn.Sigmoid()) 48 | 49 | self.unet = U_Network(2, [16, 32, 32, 32], [32, 32, 32, 32, 8, 8]) 50 | self.stn = SpatialTransformer((256, 256)) 51 | 52 | self.conv00 = nn.Sequential( 53 | nn.Conv2d(3, 16, 5, 1, 2), 54 | nn.ReLU(inplace=True) 55 | ) 56 | 57 | self.conv11 = nn.Sequential( 58 | nn.Conv2d(3, 16, 5, 1, 2), 59 | nn.ReLU(inplace=True) 60 | ) 61 | 62 | self.conv22 = nn.Sequential( 63 | nn.Conv2d(3, 16, 7, 1, 3), 64 | nn.ReLU(inplace=True) 65 | ) 66 | 67 | 68 | self.q4 = nn.Sequential( 69 | nn.Conv2d(64 + (3 * 2 + 1) ** 2, 64, 1), 70 | nn.BatchNorm2d(64), 71 | nn.ReLU(inplace=True) 72 | ) 73 | self.q3 = nn.Sequential( 74 | nn.Conv2d(64 + (3 * 2 + 1) ** 2, 64, 1), 75 | nn.BatchNorm2d(64), 76 | nn.ReLU(inplace=True) 77 | ) 78 | self.q2 = nn.Sequential( 79 | nn.Conv2d(32 + (3 * 2 + 1) ** 2, 32, 1), 80 | nn.BatchNorm2d(32), 81 | nn.ReLU(inplace=True) 82 | ) 83 | 84 | self.af_4 = cm.AF(128, 16, 3, 1, 1) 85 | self.af_3 = cm.AF(128, 16, 5, 1, 2) 86 | self.af_2 = cm.AF(64, 16, 7, 1, 3) 87 | 88 | self._init_weights() 89 | 90 | def seg_branch_encoder(self, input): 91 | e1, _, ind1 = self.encode1(input) 92 | e2, _, ind2 = self.encode2(e1) 93 | e3, _, ind3 = self.encode3(e2) 94 | e4, out4, ind4 = self.encode4(e3) 95 | bn = self.bottleneck(e4) 96 | 97 | return bn, ind4, ind3, ind2, ind1, e4, e3, e2, e1 98 | 99 | def _init_weights(self): 100 | for m in self.modules(): 101 | if isinstance(m, nn.Conv2d): 102 | torch.nn.init.kaiming_normal_(m.weight, nonlinearity='relu') 103 | 104 | # def forward(self, inpt, weights=None, inpt_sp=None, inpt_mask=None, weights_pure=None, base_map=None, sp_prior_test=None, query_mask=None): 105 | def forward(self, inpt, inpt_sp=None, inpt_mask=None): 106 | 107 | 108 | bn, ind4, ind3, ind2, ind1, e4, e3, e2, e1 = self.seg_branch_encoder(inpt) 109 | 110 | 111 | if inpt_sp is not None and len(inpt_sp.shape) == 5: 112 | tmp_sp_poriors = [] 113 | tmp_sp_img_poriors = [] 114 | for i in range(inpt_sp.shape[0]): 115 | 116 | flow = self.unet(inpt_sp[i][:, :3], inpt) 117 | 118 | tmp_sp_porior = self.stn(inpt_mask[i][:, 1:2, ...], flow) 119 | tmp_sp_img_porior = self.stn(inpt_sp[i][:, 1:2, ...], flow) 120 | 121 | tmp_sp_poriors.append(tmp_sp_porior) 122 | tmp_sp_img_poriors.append(tmp_sp_img_porior) 123 | 124 | 125 | sp_prior = torch.cat(tmp_sp_poriors, 1).mean(1, keepdim=True) 126 | sp_img_prior = torch.cat(tmp_sp_img_poriors, 1).mean(1, keepdim=True) 127 | 128 | # sp_prior = inpt_mask[:, :, 1:2, :, :].permute(1, 0, 2, 3, 4).mean(1) 129 | 130 | if inpt_sp is not None and len(inpt_sp.shape) == 5: 131 | # k shot atten 132 | sp_feats_fg_e4 = [] 133 | sp_feats_bg_e4 = [] 134 | sp_feats_fg_e3 = [] 135 | sp_feats_bg_e3 = [] 136 | sp_feats_fg_e2 = [] 137 | sp_feats_bg_e2 = [] 138 | 139 | for i in range(inpt_sp.shape[0]): 140 | bn, ind4, ind3, ind2, ind1, e4_sp, e3_sp, e2_sp, e1_sp = self.seg_branch_encoder(inpt_sp[i][:, :3]) 141 | 142 | sp_level_features = e4_sp 143 | sp_mask = F.interpolate(inpt_mask[i][:, 1:2, ...], size=(sp_level_features.shape[-2:]), mode='nearest') 144 | 145 | corr_fg = Correlation(sp_level_features * sp_mask, sp_level_features * (1 - sp_mask)) 146 | corr_bg = Correlation(sp_level_features * (1 - 
sp_mask), sp_level_features * sp_mask) 147 | corr_fg = self.q4(torch.cat([corr_fg, sp_level_features * sp_mask], dim=1)) 148 | corr_bg = self.q4(torch.cat([corr_bg, sp_level_features * (1 - sp_mask)], dim=1)) 149 | 150 | 151 | fore_avg_feat = torch.sum(torch.sum(corr_fg * sp_mask, dim=3), dim=2) / (torch.sum(sp_mask) + torch.tensor(1e-10).to(sp_mask.device)) 152 | bg_avg_feat = torch.sum(torch.sum(corr_bg * (1 - sp_mask), dim=3), dim=2) / torch.sum(1 - sp_mask) 153 | 154 | sp_feats_fg_e4.append(fore_avg_feat) 155 | sp_feats_bg_e4.append(bg_avg_feat) 156 | 157 | sp_level_features = e3_sp 158 | sp_mask = F.interpolate(inpt_mask[i][:, 1:2, ...], size=(sp_level_features.shape[-2:]), mode='nearest') 159 | 160 | # aug each level features 161 | corr_fg = Correlation(sp_level_features * sp_mask, sp_level_features * (1 - sp_mask)) 162 | corr_bg = Correlation(sp_level_features * (1 - sp_mask), sp_level_features * sp_mask) 163 | corr_fg = self.q3(torch.cat([corr_fg, sp_level_features * sp_mask], dim=1)) 164 | corr_bg = self.q3(torch.cat([corr_bg, sp_level_features * (1 - sp_mask)], dim=1)) 165 | 166 | fore_avg_feat = torch.sum(torch.sum(corr_fg * sp_mask, dim=3), dim=2) / (torch.sum(sp_mask) + torch.tensor(1e-10).to(sp_mask.device)) 167 | bg_avg_feat = torch.sum(torch.sum(corr_bg * (1 - sp_mask), dim=3), dim=2) / torch.sum(1 - sp_mask) 168 | 169 | sp_feats_fg_e3.append(fore_avg_feat) 170 | sp_feats_bg_e3.append(bg_avg_feat) 171 | 172 | # repeat 173 | sp_level_features = e2_sp 174 | sp_mask = F.interpolate(inpt_mask[i][:,1:2,...], size=(sp_level_features.shape[-2:]), mode='nearest') 175 | 176 | # aug each level features 177 | corr_fg = Correlation(sp_level_features * sp_mask, sp_level_features * (1 - sp_mask)) 178 | corr_bg = Correlation(sp_level_features * (1 - sp_mask), sp_level_features * sp_mask) 179 | corr_fg = self.q2(torch.cat([corr_fg, sp_level_features * sp_mask], dim=1)) 180 | corr_bg = self.q2(torch.cat([corr_bg, sp_level_features * (1 - sp_mask)], dim=1)) 181 | 182 | fore_avg_feat = torch.sum(torch.sum(corr_fg * sp_mask, dim=3), dim=2) / (torch.sum(sp_mask) + torch.tensor(1e-10).to(sp_mask.device)) 183 | bg_avg_feat = torch.sum(torch.sum(corr_bg * (1 - sp_mask), dim=3), dim=2) / torch.sum(1 - sp_mask) 184 | 185 | sp_feats_fg_e2.append(fore_avg_feat) 186 | sp_feats_bg_e2.append(bg_avg_feat) 187 | 188 | 189 | sp_feat_fg_e4 = torch.mean(torch.stack(sp_feats_fg_e4), 0).unsqueeze(-1) 190 | sp_feat_bg_e4 = torch.mean(torch.stack(sp_feats_bg_e4), 0).unsqueeze(-1) 191 | 192 | sp_feat_fg_e3 = torch.mean(torch.stack(sp_feats_fg_e3), 0).unsqueeze(-1) 193 | sp_feat_bg_e3 = torch.mean(torch.stack(sp_feats_bg_e3), 0).unsqueeze(-1) 194 | 195 | sp_feat_fg_e2 = torch.mean(torch.stack(sp_feats_fg_e2), 0).unsqueeze(-1) 196 | sp_feat_bg_e2 = torch.mean(torch.stack(sp_feats_bg_e2), 0).unsqueeze(-1) 197 | 198 | q_level_features = e4 199 | # aug query each level features 200 | sp_prior_r = F.interpolate(sp_prior, size=(q_level_features.shape[-2:]), mode='nearest') 201 | 202 | corr_fg = Correlation(q_level_features *sp_prior_r, q_level_features * (1 - sp_prior_r)) 203 | corr_bg = Correlation(q_level_features * (1 - sp_prior_r), q_level_features * sp_prior_r) 204 | 205 | q_level_features_fg = self.q4(torch.cat([corr_fg, q_level_features * sp_prior_r], dim=1)) 206 | q_level_features_bg = self.q4(torch.cat([corr_bg, q_level_features * (1-sp_prior_r)], dim=1)) 207 | 208 | q_b, q_n, q_h, q_w = q_level_features_fg.shape 209 | q_level_features_fg = q_level_features_fg.view(q_b, q_n, -1) 210 | q_level_features_bg = 
q_level_features_bg.view(q_b, q_n, -1) 211 | 212 | correlative_map_fg = F.cosine_similarity(q_level_features_fg, sp_feat_fg_e4) 213 | correlative_map_fg_e4 = correlative_map_fg.view(q_b, 1, q_h, q_w) 214 | 215 | correlative_map_bg = F.cosine_similarity(q_level_features_bg, sp_feat_bg_e4) 216 | correlative_map_bg_e4 = correlative_map_bg.view(q_b, 1, q_h, q_w) 217 | 218 | q_level_features = e3 219 | # aug q 220 | sp_prior_r = F.interpolate(sp_prior, size=(q_level_features.shape[-2:]), mode='nearest') 221 | 222 | corr_fg = Correlation(q_level_features *sp_prior_r, q_level_features * (1 - sp_prior_r)) 223 | corr_bg = Correlation(q_level_features * (1 - sp_prior_r), q_level_features * sp_prior_r) 224 | 225 | q_level_features_fg = self.q3(torch.cat([corr_fg, q_level_features * sp_prior_r], dim=1)) 226 | q_level_features_bg = self.q3(torch.cat([corr_bg, q_level_features * (1-sp_prior_r)], dim=1)) 227 | 228 | q_b, q_n, q_h, q_w = q_level_features_fg.shape 229 | q_level_features_fg = q_level_features_fg.view(q_b, q_n, -1) 230 | q_level_features_bg = q_level_features_bg.view(q_b, q_n, -1) 231 | 232 | 233 | correlative_map_fg = F.cosine_similarity(q_level_features_fg, sp_feat_fg_e3) 234 | correlative_map_fg_e3 = correlative_map_fg.view(q_b, 1, q_h, q_w) 235 | 236 | correlative_map_bg = F.cosine_similarity(q_level_features_bg, sp_feat_bg_e3) 237 | correlative_map_bg_e3 = correlative_map_bg.view(q_b, 1, q_h, q_w) 238 | 239 | _, max_corr_e3 = torch.max(self.soft_max(torch.cat([correlative_map_bg_e3, correlative_map_fg_e3], 1)), dim=1, keepdim=True) 240 | 241 | # repeat 242 | q_level_features = e2 243 | # aug q level 244 | sp_prior_r = F.interpolate(sp_prior, size=(q_level_features.shape[-2:]), mode='nearest') 245 | 246 | corr_fg = Correlation(q_level_features *sp_prior_r, q_level_features * (1 - sp_prior_r)) 247 | corr_bg = Correlation(q_level_features * (1 - sp_prior_r), q_level_features * sp_prior_r) 248 | 249 | q_level_features_fg = self.q2(torch.cat([corr_fg, q_level_features * sp_prior_r], dim=1)) 250 | q_level_features_bg = self.q2(torch.cat([corr_bg, q_level_features * (1-sp_prior_r)], dim=1)) 251 | 252 | q_b, q_n, q_h, q_w = q_level_features_fg.shape 253 | q_level_features_fg = q_level_features_fg.view(q_b, q_n, -1) 254 | q_level_features_bg = q_level_features_bg.view(q_b, q_n, -1) 255 | 256 | 257 | correlative_map_fg = F.cosine_similarity(q_level_features_fg, sp_feat_fg_e2) 258 | correlative_map_fg_e2 = correlative_map_fg.view(q_b, 1, q_h, q_w) 259 | 260 | correlative_map_bg = F.cosine_similarity(q_level_features_bg, sp_feat_bg_e2) 261 | correlative_map_bg_e2 = correlative_map_bg.view(q_b, 1, q_h, q_w) 262 | 263 | 264 | sp_prior_e4 = F.interpolate(sp_prior, size=(e4.shape[-2:]), mode='nearest') 265 | attention_e4 = self.conv00(torch.cat([correlative_map_bg_e4, correlative_map_fg_e4, sp_prior_e4], 1)) 266 | 267 | sp_prior_e3 = F.interpolate(sp_prior, size=(e3.shape[-2:]), mode='nearest') 268 | attention_e3 = self.conv11(torch.cat([correlative_map_bg_e3, correlative_map_fg_e3, sp_prior_e3], 1)) 269 | 270 | sp_prior_e2 = F.interpolate(sp_prior, size=(e2.shape[-2:]), mode='nearest') 271 | attention_e2 = self.conv22(torch.cat([correlative_map_bg_e2, correlative_map_fg_e2, sp_prior_e2], 1)) 272 | 273 | bn = torch.cat([bn, e4], 1) 274 | bn_sa = self.af_4(bn, attention_e4) 275 | d4 = self.decode4(bn, None, ind4) 276 | 277 | d4 = torch.cat([d4, e3], 1) 278 | d4_sa = self.af_3(d4, attention_e3) 279 | d3 = self.decode3(torch.cat([d4], 1), None, ind3) 280 | 281 | d3 = torch.cat([d3, e2], 1) 282 | d3_sa = 
self.af_2(d3, attention_e2) 283 | 284 | d4_sa = F.interpolate(d4_sa, size=(d3_sa.shape[-2:]), mode='nearest') 285 | bn_sa = F.interpolate(bn_sa, size=(d3_sa.shape[-2:]), mode='nearest') 286 | 287 | logit = self.mask_conv(torch.cat([bn_sa, d4_sa, d3_sa], 1)) 288 | 289 | 290 | logit = F.interpolate(logit, size=(inpt.shape[-2:]), mode='nearest') 291 | 292 | 293 | max_corr = F.interpolate(max_corr_e3.float(), size=(inpt.shape[-2:]), mode='nearest').requires_grad_() 294 | 295 | 296 | return logit, sp_prior, max_corr 297 | 298 | 299 | class RAP(nn.Module): 300 | 301 | def __init__(self, params): 302 | super(RAP, self).__init__() 303 | self.segmentor = SDnetSegmentor(params) 304 | 305 | def forward(self, input1, input_sp, input_mask): 306 | ''' 307 | :param input1: 308 | :param input2: 309 | :return: 310 | ''' 311 | segment = self.segmentor(input1, input_sp, input_mask) 312 | return segment 313 | 314 | @property 315 | def is_cuda(self): 316 | """ 317 | Check if modules parameters are allocated on the GPU. 318 | """ 319 | return next(self.parameters()).is_cuda 320 | 321 | def save(self, path): 322 | """ 323 | Save modules with its parameters to the given path. Conventionally the 324 | path should end with "*.modules". 325 | 326 | Inputs: 327 | - path: path string 328 | """ 329 | print('Saving modules... %s' % path) 330 | torch.save(self, path) 331 | 332 | 333 | if __name__ == '__main__': 334 | 335 | import argparse 336 | from utils import config 337 | 338 | def get_configs(): 339 | parser = argparse.ArgumentParser(description='PyTorch Few Shot Semantic Segmentation') 340 | parser.add_argument('--config', type=str, default='./settings.yaml', help='config file') 341 | parser.add_argument('--mode', '-m', default='train', 342 | help='run mode, valid values are train and eval') 343 | parser.add_argument('--device', '-d', default=0, 344 | help='device to run on') 345 | args = parser.parse_args() 346 | assert args.config is not None 347 | cfg = config.load_cfg_from_cfg_file(args.config) 348 | return cfg, args 349 | 350 | cfgs, args = get_configs() 351 | print(cfgs.DATA) 352 | common_params, data_params, net_params, train_params, eval_params = cfgs.COMMON, cfgs.DATA, cfgs.NETWORK, \ 353 | cfgs.TRAINING, cfgs.EVAL 354 | 355 | few_shot_model = SDnetSegmentor(net_params) 356 | 357 | import torchinfo 358 | 359 | batch_size = 2 360 | shot = 1 361 | torchinfo.summary(few_shot_model, input_size=[(batch_size, 3, 256, 256), (shot, batch_size, 6, 256, 256), (shot, batch_size, 3, 256, 256)]) 362 | -------------------------------------------------------------------------------- /dataset/transform.py: -------------------------------------------------------------------------------- 1 | import random 2 | import math 3 | import numpy as np 4 | import numbers 5 | import collections 6 | import cv2 7 | 8 | import torch 9 | 10 | manual_seed = 123 11 | torch.manual_seed(manual_seed) 12 | np.random.seed(manual_seed) 13 | torch.manual_seed(manual_seed) 14 | torch.cuda.manual_seed_all(manual_seed) 15 | random.seed(manual_seed) 16 | 17 | 18 | class Compose(object): 19 | # Composes segtransforms: segtransform.Compose([segtransform.RandScale([0.5, 2.0]), segtransform.ToTensor()]) 20 | def __init__(self, segtransform): 21 | self.segtransform = segtransform 22 | 23 | def __call__(self, image, label): 24 | for t in self.segtransform: 25 | image, label = t(image, label) 26 | return image, label 27 | 28 | 29 | import time 30 | 31 | 32 | class ToTensor(object): 33 | # Converts numpy.ndarray (H x W x C) to a torch.FloatTensor of shape (C 
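Inside the segmentor, each encoder level builds foreground and background prototypes from the support features by masked average pooling and compares them with the query features via per-pixel cosine similarity (on top of the local `Correlation` volumes and the registration-based prior). Leaving those RAP-specific steps aside, the bare prototype-matching idea can be sketched stand-alone as follows; shapes and names are illustrative, not the model's own:

```python
import torch
import torch.nn.functional as F

def masked_average_prototype(feat, mask):
    """(B, C, H, W) features + (B, 1, H, W) binary mask -> (B, C) prototype."""
    return (feat * mask).sum(dim=(2, 3)) / (mask.sum(dim=(2, 3)) + 1e-10)

def similarity_map(feat, prototype):
    """Per-pixel cosine similarity between query features and a prototype."""
    b, c, h, w = feat.shape
    flat = feat.view(b, c, h * w)
    proto = prototype.view(b, c, 1).expand_as(flat)
    sim = F.cosine_similarity(flat, proto, dim=1)   # (B, H*W)
    return sim.view(b, 1, h, w)

support_feat = torch.randn(1, 64, 32, 32)
support_mask = (torch.rand(1, 1, 32, 32) > 0.7).float()
query_feat = torch.randn(1, 64, 32, 32)

fg_proto = masked_average_prototype(support_feat, support_mask)       # foreground prototype
bg_proto = masked_average_prototype(support_feat, 1 - support_mask)   # background prototype

fg_map = similarity_map(query_feat, fg_proto)
bg_map = similarity_map(query_feat, bg_proto)
pred = torch.softmax(torch.cat([bg_map, fg_map], dim=1), dim=1).argmax(1, keepdim=True)
print(pred.shape)   # (1, 1, 32, 32)
```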
x H x W). 34 | def __call__(self, image, label): 35 | if not isinstance(image, np.ndarray) or not isinstance(label, np.ndarray): 36 | raise (RuntimeError("segtransform.ToTensor() only handle np.ndarray" 37 | "[eg: data readed by cv2.imread()].\n")) 38 | if len(image.shape) > 3 or len(image.shape) < 2: 39 | raise (RuntimeError("segtransform.ToTensor() only handle np.ndarray with 3 dims or 2 dims.\n")) 40 | if len(image.shape) == 2: 41 | image = np.expand_dims(image, axis=2) 42 | if not len(label.shape) == 2: 43 | raise (RuntimeError("segtransform.ToTensor() only handle np.ndarray labellabel with 2 dims.\n")) 44 | 45 | image = torch.from_numpy(image.transpose((2, 0, 1))) 46 | if not isinstance(image, torch.FloatTensor): 47 | image = image.float() 48 | label = torch.from_numpy(label) 49 | if not isinstance(label, torch.LongTensor): 50 | label = label.long() 51 | return image, label 52 | 53 | 54 | class Normalize(object): 55 | # Normalize tensor with mean and standard deviation along channel: channel = (channel - mean) / std 56 | def __init__(self, mean, std=None): 57 | if std is None: 58 | assert len(mean) > 0 59 | else: 60 | assert len(mean) == len(std) 61 | self.mean = mean 62 | self.std = std 63 | 64 | def __call__(self, image, label): 65 | if self.std is None: 66 | for t, m in zip(image, self.mean): 67 | t.sub_(m) 68 | else: 69 | for t, m, s in zip(image, self.mean, self.std): 70 | t.sub_(m).div_(s) 71 | return image, label 72 | 73 | 74 | class Scaling(object): 75 | # Normalize tensor with mean and standard deviation along channel: channel = (channel - mean) / std 76 | def __init__(self, ): 77 | pass 78 | 79 | def __call__(self, image, label): 80 | return image / 255, label 81 | 82 | 83 | class Resize(object): 84 | # Resize the input to the given size, 'size' is a 2-element tuple or list in the order of (h, w). 
85 | def __init__(self, size): 86 | self.size = size 87 | 88 | def __call__(self, image, label): 89 | 90 | value_scale = 255 91 | mean = [0.485, 0.456, 0.406] 92 | mean = [item * value_scale for item in mean] 93 | std = [0.229, 0.224, 0.225] 94 | std = [item * value_scale for item in std] 95 | 96 | def find_new_hw(ori_h, ori_w, test_size): 97 | if ori_h >= ori_w: 98 | ratio = test_size * 1.0 / ori_h 99 | new_h = test_size 100 | new_w = int(ori_w * ratio) 101 | elif ori_w > ori_h: 102 | ratio = test_size * 1.0 / ori_w 103 | new_h = int(ori_h * ratio) 104 | new_w = test_size 105 | 106 | if new_h % 8 != 0: 107 | new_h = (int(new_h / 8)) * 8 108 | else: 109 | new_h = new_h 110 | if new_w % 8 != 0: 111 | new_w = (int(new_w / 8)) * 8 112 | else: 113 | new_w = new_w 114 | return new_h, new_w 115 | 116 | test_size = self.size 117 | new_h, new_w = find_new_hw(image.shape[0], image.shape[1], test_size) 118 | # new_h, new_w = test_size, test_size 119 | image_crop = cv2.resize(image, dsize=(int(new_w), int(new_h)), interpolation=cv2.INTER_LINEAR) 120 | back_crop = np.zeros((test_size, test_size, 3)) 121 | # back_crop[:,:,0] = mean[0] 122 | # back_crop[:,:,1] = mean[1] 123 | # back_crop[:,:,2] = mean[2] 124 | back_crop[:new_h, :new_w, :] = image_crop 125 | image = back_crop 126 | 127 | s_mask = label 128 | new_h, new_w = find_new_hw(s_mask.shape[0], s_mask.shape[1], test_size) 129 | # new_h, new_w = test_size, test_size 130 | s_mask = cv2.resize(s_mask.astype(np.float32), dsize=(int(new_w), int(new_h)), interpolation=cv2.INTER_NEAREST) 131 | back_crop_s_mask = np.ones((test_size, test_size)) * 255 132 | back_crop_s_mask[:new_h, :new_w] = s_mask 133 | label = back_crop_s_mask 134 | 135 | return image, label 136 | 137 | 138 | class test_Resize(object): 139 | # Resize the input to the given size, 'size' is a 2-element tuple or list in the order of (h, w). 
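These transforms operate on `(image, label)` pairs and are chained through `Compose`. A usage sketch of the pipeline defined above, assuming `dataset/transform.py` is importable; `Resize` preserves the aspect ratio, snaps each side to a multiple of 8, and pads the label canvas with 255 (the same value `Crop` later uses as `ignore_label`):

```python
import numpy as np
from dataset import transform as tr

pipeline = tr.Compose([
    tr.Resize(256),   # aspect-preserving resize onto a 256x256 canvas, label padded with 255
    tr.ToTensor(),    # HWC numpy -> CHW float tensor, label -> long tensor
])

image = np.random.rand(300, 400, 3).astype(np.float32)
label = (np.random.rand(300, 400) > 0.5).astype(np.uint8)

img_t, lab_t = pipeline(image, label)
print(img_t.shape, img_t.dtype)   # torch.Size([3, 256, 256]) torch.float32
print(lab_t.shape, lab_t.dtype)   # torch.Size([256, 256]) torch.int64
```

Note that `RandScale` and `Crop` still reference `collections.Iterable`, which Python 3.10+ only exposes as `collections.abc.Iterable`, so those two transforms need that alias available on newer interpreters.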
140 | def __init__(self, size): 141 | self.size = size 142 | 143 | def __call__(self, image, label): 144 | 145 | value_scale = 255 146 | mean = [0.485, 0.456, 0.406] 147 | mean = [item * value_scale for item in mean] 148 | std = [0.229, 0.224, 0.225] 149 | std = [item * value_scale for item in std] 150 | 151 | def find_new_hw(ori_h, ori_w, test_size): 152 | if max(ori_h, ori_w) > test_size: 153 | if ori_h >= ori_w: 154 | ratio = test_size * 1.0 / ori_h 155 | new_h = test_size 156 | new_w = int(ori_w * ratio) 157 | elif ori_w > ori_h: 158 | ratio = test_size * 1.0 / ori_w 159 | new_h = int(ori_h * ratio) 160 | new_w = test_size 161 | 162 | if new_h % 8 != 0: 163 | new_h = (int(new_h / 8)) * 8 164 | else: 165 | new_h = new_h 166 | if new_w % 8 != 0: 167 | new_w = (int(new_w / 8)) * 8 168 | else: 169 | new_w = new_w 170 | return new_h, new_w 171 | else: 172 | return ori_h, ori_w 173 | 174 | test_size = self.size 175 | new_h, new_w = find_new_hw(image.shape[0], image.shape[1], test_size) 176 | if new_w != image.shape[0] or new_h != image.shape[1]: 177 | image_crop = cv2.resize(image, dsize=(int(new_w), int(new_h)), interpolation=cv2.INTER_LINEAR) 178 | else: 179 | image_crop = image.copy() 180 | back_crop = np.zeros((test_size, test_size, 3)) 181 | back_crop[:new_h, :new_w, :] = image_crop 182 | image = back_crop 183 | 184 | s_mask = label 185 | new_h, new_w = find_new_hw(s_mask.shape[0], s_mask.shape[1], test_size) 186 | if new_w != s_mask.shape[0] or new_h != s_mask.shape[1]: 187 | s_mask = cv2.resize(s_mask.astype(np.float32), dsize=(int(new_w), int(new_h)), 188 | interpolation=cv2.INTER_NEAREST) 189 | back_crop_s_mask = np.ones((test_size, test_size)) * 255 190 | back_crop_s_mask[:new_h, :new_w] = s_mask 191 | label = back_crop_s_mask 192 | 193 | return image, label 194 | 195 | 196 | class RandScale(object): 197 | # Randomly resize image & label with scale factor in [scale_min, scale_max] 198 | def __init__(self, scale, aspect_ratio=None): 199 | assert (isinstance(scale, collections.Iterable) and len(scale) == 2) 200 | if isinstance(scale, collections.Iterable) and len(scale) == 2 \ 201 | and isinstance(scale[0], numbers.Number) and isinstance(scale[1], numbers.Number) \ 202 | and 0 < scale[0] < scale[1]: 203 | self.scale = scale 204 | else: 205 | raise (RuntimeError("segtransform.RandScale() scale param error.\n")) 206 | if aspect_ratio is None: 207 | self.aspect_ratio = aspect_ratio 208 | elif isinstance(aspect_ratio, collections.Iterable) and len(aspect_ratio) == 2 \ 209 | and isinstance(aspect_ratio[0], numbers.Number) and isinstance(aspect_ratio[1], numbers.Number) \ 210 | and 0 < aspect_ratio[0] < aspect_ratio[1]: 211 | self.aspect_ratio = aspect_ratio 212 | else: 213 | raise (RuntimeError("segtransform.RandScale() aspect_ratio param error.\n")) 214 | 215 | def __call__(self, image, label): 216 | temp_scale = self.scale[0] + (self.scale[1] - self.scale[0]) * random.random() 217 | temp_aspect_ratio = 1.0 218 | if self.aspect_ratio is not None: 219 | temp_aspect_ratio = self.aspect_ratio[0] + (self.aspect_ratio[1] - self.aspect_ratio[0]) * random.random() 220 | temp_aspect_ratio = math.sqrt(temp_aspect_ratio) 221 | scale_factor_x = temp_scale * temp_aspect_ratio 222 | scale_factor_y = temp_scale / temp_aspect_ratio 223 | image = cv2.resize(image, None, fx=scale_factor_x, fy=scale_factor_y, interpolation=cv2.INTER_LINEAR) 224 | label = cv2.resize(label, None, fx=scale_factor_x, fy=scale_factor_y, interpolation=cv2.INTER_NEAREST) 225 | return image, label 226 | 227 | 228 | class 
Crop(object): 229 | """Crops the given ndarray image (H*W*C or H*W). 230 | Args: 231 | size (sequence or int): Desired output size of the crop. If size is an 232 | int instead of sequence like (h, w), a square crop (size, size) is made. 233 | """ 234 | 235 | def __init__(self, size, crop_type='center', padding=None, ignore_label=255): 236 | self.size = size 237 | if isinstance(size, int): 238 | self.crop_h = size 239 | self.crop_w = size 240 | elif isinstance(size, collections.Iterable) and len(size) == 2 \ 241 | and isinstance(size[0], int) and isinstance(size[1], int) \ 242 | and size[0] > 0 and size[1] > 0: 243 | self.crop_h = size[0] 244 | self.crop_w = size[1] 245 | else: 246 | raise (RuntimeError("crop size error.\n")) 247 | if crop_type == 'center' or crop_type == 'rand': 248 | self.crop_type = crop_type 249 | else: 250 | raise (RuntimeError("crop type error: rand | center\n")) 251 | if padding is None: 252 | self.padding = padding 253 | elif isinstance(padding, list): 254 | if all(isinstance(i, numbers.Number) for i in padding): 255 | self.padding = padding 256 | else: 257 | raise (RuntimeError("padding in Crop() should be a number list\n")) 258 | if len(padding) != 3: 259 | raise (RuntimeError("padding channel is not equal with 3\n")) 260 | else: 261 | raise (RuntimeError("padding in Crop() should be a number list\n")) 262 | if isinstance(ignore_label, int): 263 | self.ignore_label = ignore_label 264 | else: 265 | raise (RuntimeError("ignore_label should be an integer number\n")) 266 | 267 | def __call__(self, image, label): 268 | h, w = label.shape 269 | 270 | pad_h = max(self.crop_h - h, 0) 271 | pad_w = max(self.crop_w - w, 0) 272 | pad_h_half = int(pad_h / 2) 273 | pad_w_half = int(pad_w / 2) 274 | if pad_h > 0 or pad_w > 0: 275 | if self.padding is None: 276 | raise (RuntimeError("segtransform.Crop() need padding while padding argument is None\n")) 277 | image = cv2.copyMakeBorder(image, pad_h_half, pad_h - pad_h_half, pad_w_half, pad_w - pad_w_half, 278 | cv2.BORDER_CONSTANT, value=self.padding) 279 | label = cv2.copyMakeBorder(label, pad_h_half, pad_h - pad_h_half, pad_w_half, pad_w - pad_w_half, 280 | cv2.BORDER_CONSTANT, value=self.ignore_label) 281 | h, w = label.shape 282 | raw_label = label 283 | raw_image = image 284 | 285 | if self.crop_type == 'rand': 286 | h_off = random.randint(0, h - self.crop_h) 287 | w_off = random.randint(0, w - self.crop_w) 288 | else: 289 | h_off = int((h - self.crop_h) / 2) 290 | w_off = int((w - self.crop_w) / 2) 291 | image = image[h_off:h_off + self.crop_h, w_off:w_off + self.crop_w] 292 | label = label[h_off:h_off + self.crop_h, w_off:w_off + self.crop_w] 293 | raw_pos_num = np.sum(raw_label == 1) 294 | pos_num = np.sum(label == 1) 295 | crop_cnt = 0 296 | while (pos_num < 0.85 * raw_pos_num and crop_cnt <= 30): 297 | image = raw_image 298 | label = raw_label 299 | if self.crop_type == 'rand': 300 | h_off = random.randint(0, h - self.crop_h) 301 | w_off = random.randint(0, w - self.crop_w) 302 | else: 303 | h_off = int((h - self.crop_h) / 2) 304 | w_off = int((w - self.crop_w) / 2) 305 | image = image[h_off:h_off + self.crop_h, w_off:w_off + self.crop_w] 306 | label = label[h_off:h_off + self.crop_h, w_off:w_off + self.crop_w] 307 | raw_pos_num = np.sum(raw_label == 1) 308 | pos_num = np.sum(label == 1) 309 | crop_cnt += 1 310 | if crop_cnt >= 50: 311 | image = cv2.resize(raw_image, (self.size[0], self.size[0]), interpolation=cv2.INTER_LINEAR) 312 | label = cv2.resize(raw_label, (self.size[0], self.size[0]), 
interpolation=cv2.INTER_NEAREST) 313 | 314 | if image.shape != (self.size[0], self.size[0], 3): 315 | image = cv2.resize(image, (self.size[0], self.size[0]), interpolation=cv2.INTER_LINEAR) 316 | label = cv2.resize(label, (self.size[0], self.size[0]), interpolation=cv2.INTER_NEAREST) 317 | 318 | return image, label 319 | 320 | class Resize_fy(object): 321 | def __init__(self, size): 322 | self.size = size 323 | 324 | def __call__(self, image, label): 325 | image = cv2.resize(image, (self.size[0], self.size[0]), interpolation=cv2.INTER_LINEAR) 326 | label = cv2.resize(label, (self.size[0], self.size[0]), interpolation=cv2.INTER_NEAREST) 327 | 328 | return image, label 329 | 330 | class RandRotate(object): 331 | # Randomly rotate image & label with rotate factor in [rotate_min, rotate_max] 332 | def __init__(self, rotate, padding, ignore_label=255, p=0.5): 333 | assert (isinstance(rotate, collections.Iterable) and len(rotate) == 2) 334 | if isinstance(rotate[0], numbers.Number) and isinstance(rotate[1], numbers.Number) and rotate[0] < rotate[1]: 335 | self.rotate = rotate 336 | else: 337 | raise (RuntimeError("segtransform.RandRotate() scale param error.\n")) 338 | assert padding is not None 339 | assert isinstance(padding, list) and len(padding) == 3 340 | if all(isinstance(i, numbers.Number) for i in padding): 341 | self.padding = padding 342 | else: 343 | raise (RuntimeError("padding in RandRotate() should be a number list\n")) 344 | assert isinstance(ignore_label, int) 345 | self.ignore_label = ignore_label 346 | self.p = p 347 | 348 | def __call__(self, image, label): 349 | if random.random() < self.p: 350 | angle = self.rotate[0] + (self.rotate[1] - self.rotate[0]) * random.random() 351 | h, w = label.shape 352 | matrix = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1) 353 | image = cv2.warpAffine(image, matrix, (w, h), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT, 354 | borderValue=self.padding) 355 | label = cv2.warpAffine(label, matrix, (w, h), flags=cv2.INTER_NEAREST, borderMode=cv2.BORDER_CONSTANT, 356 | borderValue=self.ignore_label) 357 | return image, label 358 | 359 | 360 | class RandomHorizontalFlip(object): 361 | def __init__(self, p=0.5): 362 | self.p = p 363 | 364 | def __call__(self, image, label): 365 | if random.random() < self.p: 366 | image = cv2.flip(image, 1) # 0 vertical 1 horizon -1 v and h 367 | label = cv2.flip(label, 1) 368 | return image, label 369 | 370 | 371 | class RandomVerticalFlip(object): 372 | def __init__(self, p=0.5): 373 | self.p = p 374 | 375 | def __call__(self, image, label): 376 | if random.random() < self.p: 377 | image = cv2.flip(image, 0) 378 | label = cv2.flip(label, 0) 379 | return image, label 380 | 381 | 382 | class RandomGaussianBlur(object): 383 | def __init__(self, radius=5): 384 | self.radius = radius 385 | 386 | def __call__(self, image, label): 387 | if random.random() < 0.5: 388 | image = cv2.GaussianBlur(image, (self.radius, self.radius), 0) 389 | return image, label 390 | 391 | 392 | class RGB2BGR(object): 393 | # Converts image from RGB order to BGR order, for model initialized from Caffe 394 | def __call__(self, image, label): 395 | image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) 396 | return image, label 397 | 398 | 399 | class BGR2RGB(object): 400 | # Converts image from BGR order to RGB order, for model initialized from Pytorch 401 | def __call__(self, image, label): 402 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 403 | return image, label 404 | 
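A minimal usage sketch of the transforms above (added for illustration only; not a file in this repository). It assumes that `Compose`, as referenced from `dataset/dataset_us.py`, simply applies each `(image, label)` transform in sequence; the output shapes in the comments follow from `Resize_fy`, `ToTensor`, and `Scaling` as defined above.

```
# Hedged example: chains the repository's (image, label) transforms on dummy data.
import numpy as np
from dataset import transform

pipeline = transform.Compose([          # assumption: Compose applies each transform in turn
    transform.RandomGaussianBlur(),
    transform.Resize_fy([256, 256]),
    transform.ToTensor(),
    transform.Scaling(),
])

# Dummy grayscale slice and binary mask standing in for one echocardiography frame.
image = (np.random.rand(300, 400) * 255).astype(np.float32)  # H x W
label = (np.random.rand(300, 400) > 0.5).astype(np.uint8)    # H x W, values in {0, 1}

image_t, label_t = pipeline(image, label)
# image_t: torch.FloatTensor of shape [1, 256, 256], scaled to [0, 1] by Scaling
# label_t: torch.LongTensor of shape [256, 256]
```

`Normalize` can stand in for `Scaling` when per-channel mean/std statistics are preferred over a plain division by 255.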
-------------------------------------------------------------------------------- /dataset/dataset_us.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | import numpy as np 4 | from torch.utils.data import Dataset 5 | import torch 6 | import random 7 | import nibabel as nib 8 | import nrrd 9 | import cv2 10 | 11 | 12 | 13 | def make_dataset(sub_list): 14 | ''' 15 | :param sub_list: specify cls index as train or val set , here can be [LV(2,3,4) RV(4,F) LA(2,3,4) LAA] 16 | :return: 17 | ''' 18 | # same view cannot in train and val set 19 | 20 | data_dict = {'LV2': r'D:\Dataset\PH_30_UNET\PH_LV_A2C_TRAIN', 21 | 'LV3': r'D:\Dataset\PH_30_UNET\PH_LV_A3C_TRAIN', 22 | 'LV4': r'D:\Dataset\PH_30_UNET\PH_LV_A4C_TRAIN', 23 | 'LA2': r'D:\Dataset\PH_30_UNET\PH_LA_A2C_TRAIN', 24 | 'LA3': r'D:\Dataset\PH_30_UNET\PH_LA_A3C_TRAIN', 25 | 'LA4': r'D:\Dataset\PH_30_UNET\PH_LA_A4C_TRAIN', 26 | } 27 | 28 | 29 | 30 | # same view cannot in train and val set 31 | # data_dict = {'LV2': '/home/fengyong/Dataset/PH_30/PH_LV_A2C', 32 | # 'LV3': '/home/fengyong/Dataset/PH_30/PH_LV_A3C', 33 | # 'LV4': '/home/fengyong/Dataset/PH_30/PH_LV_A4C', 34 | # 'RV4': '/home/fengyong/Dataset/PH_30/RV', 35 | # 'LA2': '/home/fengyong/Dataset/PH_30/PH_LA_A2C', 36 | # 'LA3': '/home/fengyong/Dataset/PH_30/PH_LA_A3C', 37 | # 'LA4': '/home/fengyong/Dataset/PH_30/PH_LA_A4C', 38 | # 'LAA': '/home/fengyong/Dataset/PH_30/LAA', 39 | # } 40 | 41 | print(f"Processing data...{sub_list}") 42 | 43 | image_label_list = [] 44 | sub_class_file_list = {} 45 | # for sub_c in sub_list: 46 | # sub_class_file_list[sub_c] = [] 47 | global_cls_id = 1 48 | total_cls = [] 49 | for c in sub_list: 50 | root_c = data_dict[c] 51 | pids_pair = sorted(os.listdir(root_c)) 52 | temp_cls = total_cls 53 | for idx in range(0, len(pids_pair), 2): 54 | pid_img = pids_pair[idx] 55 | print(pid_img) 56 | pid_lab = pids_pair[idx+1] 57 | image_name = os.path.join(root_c, pid_img) 58 | # image_name = [name for name in img_paths if os.path.isfile(name)][0] 59 | 60 | label_name = os.path.join(root_c, pid_lab) 61 | # label_name = [name for name in msk_paths if os.path.isfile(name)][0] 62 | item = (image_name, label_name, c) 63 | 64 | if temp_cls == total_cls: 65 | # load label check cls 66 | if os.path.splitext(label_name)[-1] == '.gz': 67 | mask_nib = nib.load(label_name) 68 | label = mask_nib.get_data().transpose(2, 1, 0) 69 | else: 70 | label = nrrd.read(label_name)[0].transpose((2,1,0)) 71 | 72 | label_class = np.unique(label).tolist() 73 | if 0 in label_class: 74 | label_class.remove(0) 75 | temp_cls = [c+'_'+str(lc) for lc in label_class] 76 | 77 | if len(label_class) > 0: 78 | image_label_list.append(item) 79 | 80 | for t_c in temp_cls: 81 | if t_c not in sub_class_file_list.keys(): 82 | sub_class_file_list[t_c] = [] 83 | sub_class_file_list[t_c].append(item) 84 | 85 | total_cls += temp_cls 86 | np.savez('train_la_list.npz', img_lab=image_label_list, cls_file=sub_class_file_list) 87 | print('DOne make data...') 88 | return image_label_list, sub_class_file_list 89 | 90 | 91 | class SemData_US(Dataset): 92 | def __init__(self, shot=1, transform=None, mode='train'): 93 | assert mode in ['train', 'val', 'test'] 94 | 95 | self.mode = mode 96 | self.shot = shot 97 | self.sub_list = ['LV2', 'LV3', 'LV4'] # ['LA2', 'LA3', 'LA4', 'LAA'] 98 | # self.sub_list = ['LA2', 'LA3', 'LA4'] 99 | # self.sub_list = ['LV2', 'LV3', 'LV4', 'RV4', 'LAA'] 100 | self.sub_val_list = ['LV2'] 101 | print('sub_list: ', self.sub_list) 102 
| print('sub_val_list: ', self.sub_val_list) 103 | 104 | if self.mode == 'train': 105 | self.data_list, self.sub_class_file_list = make_dataset(self.sub_list) 106 | print('Load Train Set') 107 | # file_list = np.load('./coco_train_list.npz', allow_pickle=True) 108 | # self.data_list = list(file_list['img_lab']) 109 | # self.sub_class_file_list = file_list['cls_file'].item() 110 | 111 | # assert len(self.sub_class_file_list.keys()) == len(self.sub_list) 112 | elif self.mode == 'val': 113 | print('Load Val Set') 114 | self.data_list, self.sub_class_file_list = make_dataset(self.sub_val_list) 115 | # file_list = np.load('./coco_val_list.npz', allow_pickle=True) 116 | # self.data_list = list(file_list['img_lab']) 117 | # self.sub_class_file_list = file_list['cls_file'].item() 118 | # assert len(self.sub_class_file_list.keys()) == len(self.sub_val_list) 119 | self.transform = transform 120 | self.slice_num = 3 121 | self.w_ = (self.slice_num-1)//2 122 | 123 | def __len__(self): 124 | return len(self.data_list) 125 | 126 | def __getitem__(self, index): 127 | label_class = [] 128 | image_path, label_path, mode_name = self.data_list[index] 129 | query_path = image_path 130 | 131 | # if 'LA_A2C' in image_path: 132 | # choosen_cls = 0 133 | # sp_prior = torch.from_numpy(np.load('la2.npy')) 134 | # if 'LA_A3C' in image_path: 135 | # choosen_cls = 0 136 | # sp_prior = torch.from_numpy(np.load('la3.npy')) 137 | # if 'LA_A4C' in image_path: 138 | # choosen_cls = 0 139 | # sp_prior = torch.from_numpy(np.load('la4.npy')) 140 | # if 'LV_A2C' in image_path: 141 | # choosen_cls = 0 142 | # sp_prior = torch.from_numpy(np.load('lv2.npy')) 143 | # if 'LV_A3C' in image_path: 144 | # choosen_cls = 0 145 | # sp_prior = torch.from_numpy(np.load('lv3.npy')) 146 | # if 'LV_A4C' in image_path: 147 | # choosen_cls = 0 148 | # sp_prior = torch.from_numpy(np.load('lv4.npy')) 149 | # if 'RV' in image_path: 150 | # choosen_cls = 1 151 | # sp_prior = torch.from_numpy(np.load('rv4.npy')) 152 | # if 'LAA' in image_path: 153 | # choosen_cls = 2 154 | # sp_prior = torch.from_numpy(np.load('laa.npy')) 155 | 156 | # load label check cls 157 | 158 | if os.path.splitext(label_path)[-1] == '.gz': 159 | mask_nib = nib.load(label_path) 160 | label = mask_nib.get_data().transpose(2, 1, 0) 161 | else: 162 | label = nrrd.read(label_path)[0].transpose((2, 1, 0)) 163 | 164 | image = nrrd.read(image_path)[0].transpose((2,1,0)) 165 | # slice_idx = random.randint(0, image.shape[0]-1) #[a,b] 166 | slice_idx = random.randint(self.w_, image.shape[0] - 1 - self.w_) # [a,b] 167 | # image = image[slice_idx].astype(np.float32) 168 | image = image[slice_idx - self.w_:slice_idx + self.w_ + 1].astype(np.float32) 169 | # image = np.repeat(image[slice_idx][..., np.newaxis], 3, -1) 170 | label = label[slice_idx - self.w_:slice_idx + self.w_ + 1].astype(np.uint8) 171 | raw_label = label.copy() 172 | 173 | if image.shape[0] != label.shape[0] or image.shape[1] != label.shape[1]: 174 | raise (RuntimeError("Query Image & label shape mismatch: " + image_path + " " + label_path + "\n")) 175 | label_class = np.unique(label).tolist() 176 | # print('cls unique', label_class) 177 | 178 | if 0 in label_class: 179 | label_class.remove(0) 180 | 181 | if len(label_class) == 0: 182 | print('lb = 0,', image_path, slice_idx) 183 | label_class = [1] 184 | # assert len(label_class) > 0 185 | 186 | chosen_idx = random.randint(1, len(label_class)) - 1 187 | class_chosen = label_class[chosen_idx] 188 | # class_chosen = class_chosen 189 | # if self.mode == 'val': 190 | # 
print('chosen class:', class_chosen, 'from: ', label_class, query_path) 191 | target_pix = np.where(label == class_chosen) 192 | 193 | if target_pix[0].shape[0] > 0 and len(target_pix)==3: 194 | label[:, :, :] = 0 195 | label[target_pix[0], target_pix[1], target_pix[2]] = 1 196 | elif len(target_pix)==2: 197 | label[:, :] = 0 198 | label[target_pix[0], target_pix[1]] = 1 199 | 200 | chosen_mode_name = mode_name + '_' + str(class_chosen) 201 | file_class_chosen = self.sub_class_file_list[chosen_mode_name] 202 | num_file = len(file_class_chosen) 203 | 204 | support_image_path_list = [] 205 | support_label_path_list = [] 206 | support_idx_list = [] 207 | for k in range(self.shot): 208 | support_idx = random.randint(1, num_file) - 1 209 | support_image_path = image_path 210 | support_label_path = label_path 211 | while (( 212 | support_image_path == image_path and support_label_path == label_path) or support_idx in support_idx_list): 213 | support_idx = random.randint(1, num_file) - 1 214 | support_image_path, support_label_path, mode = file_class_chosen[support_idx] 215 | support_idx_list.append(support_idx) 216 | support_image_path_list.append(support_image_path) 217 | support_label_path_list.append(support_label_path) 218 | # print('sup list:', support_image_path_list) 219 | 220 | support_image_list = [] 221 | support_label_list = [] 222 | subcls_list = [] 223 | for k in range(self.shot): 224 | if self.mode == 'train': 225 | # subcls_list.append(self.sub_list.index(class_chosen)) 226 | subcls_list.append(chosen_mode_name) 227 | else: 228 | # subcls_list.append(self.sub_val_list.index(class_chosen)) 229 | subcls_list.append(chosen_mode_name) 230 | support_image_path = support_image_path_list[k] 231 | support_label_path = support_label_path_list[k] 232 | # print(support_image_path, support_label_path) 233 | 234 | support_image = nrrd.read(support_image_path)[0].transpose((2, 1, 0)) 235 | # slice_idx = random.randint(0, support_image.shape[0] - 1) # [a,b] 236 | slice_idx = random.randint(0+self.w_, support_image.shape[0] - 1 -self.w_) # [a,b] 237 | # support_image = np.repeat(support_image[slice_idx][..., np.newaxis], 3, -1).astype(np.float32) 238 | support_image = support_image[slice_idx-self.w_:slice_idx+self.w_+1].astype(np.float32) 239 | 240 | # load label check cls 241 | if os.path.splitext(support_label_path)[-1] == '.gz': 242 | mask_nib = nib.load(support_label_path) 243 | support_label = mask_nib.get_data().transpose(2, 1, 0) 244 | else: 245 | support_label = nrrd.read(support_label_path)[0].transpose((2, 1, 0)) 246 | # support_label = support_label[slice_idx].astype(np.uint8) 247 | support_label = support_label[slice_idx - self.w_: slice_idx + self.w_ + 1].astype(np.uint8) 248 | 249 | target_pix = np.where(support_label == class_chosen) 250 | # ignore_pix = np.where(support_label == 255) 251 | support_label[:, :, :] = 0 252 | support_label[target_pix[0], target_pix[1], target_pix[2]] = 1 253 | # support_label[ignore_pix[0], ignore_pix[1]] = 255 254 | if support_image.shape[0] != support_label.shape[0] or support_image.shape[1] != support_label.shape[1]: 255 | raise (RuntimeError( 256 | "Support Image & label shape mismatch: " + support_image_path + " " + support_label_path + "\n")) 257 | support_image_list.append(support_image) 258 | support_label_list.append(support_label) 259 | assert len(support_label_list) == self.shot and len(support_image_list) == self.shot 260 | 261 | flip = (random.random() > 0.5) 262 | if self.transform is not None: 263 | trans_image = [] 264 | trans_label = 
[] 265 | if flip: 266 | for i in range(self.slice_num): 267 | image[i] = cv2.flip(image[i], 1) 268 | label[i] = cv2.flip(label[i], 1) 269 | for i in range(self.slice_num): 270 | trans_image_i, trans_label_i = self.transform(image[i], label[i]) 271 | trans_image.append(trans_image_i) 272 | trans_label.append(trans_label_i) 273 | # image, label = self.transform(image, label) 274 | 275 | image = torch.stack(trans_image) 276 | label = torch.stack(trans_label).unsqueeze(1) 277 | 278 | 279 | for k in range(self.shot): 280 | trans_image_si = [] 281 | trans_label_si = [] 282 | if flip: 283 | for i in range(self.slice_num): 284 | support_image_list[k][i] = cv2.flip(support_image_list[k][i], 1) 285 | support_label_list[k][i] = cv2.flip(support_label_list[k][i], 1) 286 | for i in range(self.slice_num): 287 | support_image_ki, support_label_ki = self.transform(support_image_list[k][i], 288 | support_label_list[k][i]) 289 | 290 | trans_image_si.append(support_image_ki) 291 | trans_label_si.append(support_label_ki.unsqueeze(0)) 292 | 293 | 294 | support_image_list[k] = torch.stack(trans_image_si) 295 | support_label_list[k] = torch.stack(trans_label_si) 296 | 297 | s_xs = torch.stack(support_image_list, 0) 298 | s_ys = torch.stack(support_label_list, 0) 299 | 300 | # sp_prior = torch.mean(s_ys.float(), 0) # slice 1 256, 256 301 | 302 | if self.mode == 'train': 303 | return image, label, s_xs, s_ys, subcls_list#, sp_prior, choosen_cls 304 | else: 305 | return image, label, s_xs, s_ys, subcls_list, raw_label, query_path, support_image_path_list 306 | 307 | # model: s_x=torch.FloatTensor(1,1,3,473,473).cuda(), s_y=torch.FloatTensor(1,1,473,473).cuda() 308 | # raw_label: 0-1 2D 309 | # image: 3D cv2 read query image 310 | # label: 2D cv2 read label image 311 | # s_x : [1, n, 3, h, w] 312 | # s_x : [1, n, h, w] 313 | # subcls_liit : [chosen cls] 314 | 315 | 316 | if __name__ == '__main__': 317 | 318 | # l1, d1 = make_dataset(['LV4', 'LAA']) 319 | # print(len(l1)) 320 | # print(d1.keys()) 321 | 322 | 323 | # import shutil 324 | # root = r'D:\Dataset\NRRD\Train' 325 | # tar = r'D:\Dataset\NRRD\BI_V_ALL' 326 | # pids = os.listdir(root) 327 | # 328 | # for pid in pids: 329 | # ori_img = os.path.join(root, pid, 'im.nrrd') 330 | # tar_img = os.path.join(tar, f'{pid}_im.nrrd') 331 | # 332 | # ori_msk = os.path.join(root, pid, 'm.nrrd') 333 | # tar_msk = os.path.join(tar, f'{pid}_m.nrrd') 334 | # shutil.copyfile(ori_img, tar_img) 335 | # shutil.copyfile(ori_msk, tar_msk) 336 | 337 | 338 | import transform 339 | 340 | # value_scale = 255 341 | # mean = [0.485, 0.456, 0.406] 342 | # mean = [item * value_scale for item in mean] 343 | # std = [0.229, 0.224, 0.225] 344 | # std = [item * value_scale for item in std] 345 | # train_transform = [ 346 | # # transform.RandScale([0.8, 1.25]), 347 | # # transform.RandRotate([-10, 10], padding=[0,0,0], ignore_label=0), 348 | # # transform.RandomGaussianBlur(), 349 | # # transform.RandomHorizontalFlip(), 350 | # transform.Crop([641, 641], crop_type='rand', padding=mean, ignore_label=255), 351 | # # transform.Resize([256, 256]), 352 | # transform.ToTensor()] 353 | # # transform.Normalize(mean=mean, std=std)] 354 | # train_transform = transform.Compose(train_transform) 355 | # 356 | # data = SemData_US(shot=5, transform=train_transform, mode='val') 357 | # 358 | train_transform = [ 359 | # transform.RandScale([0.8, 1.25]), 360 | # transform.RandRotate([-10, 10], padding=mean, ignore_label=0), 361 | transform.RandomGaussianBlur(), 362 | # transform.RandomHorizontalFlip(), 363 | # 
transform.Crop([256, 256], crop_type='rand', padding=mean, ignore_label=0), 364 | transform.Resize_fy([256, 256]), 365 | transform.ToTensor(), 366 | transform.Scaling()] 367 | train_transform = transform.Compose(train_transform) 368 | train_data = SemData_US(shot=1, transform=train_transform, mode='train') 369 | train_sampler = None 370 | trainloader = torch.utils.data.DataLoader(train_data, batch_size=1, shuffle=(train_sampler is None), 371 | num_workers=1, pin_memory=True, sampler=train_sampler, 372 | drop_last=False) 373 | 374 | for i in trainloader: 375 | image, label, s_xs, s_ys, subcls_list, sp_prior = i 376 | break 377 | # image, label, s_xs, s_ys, subcls_list, raw_label, query_path, support_image_path_list = data[0] 378 | 379 | print(image.shape, label.shape, s_xs.shape, s_ys.shape) 380 | print(subcls_list) 381 | # 382 | nrrd.write('t.nrrd', s_xs.squeeze().cpu().numpy().transpose(2,1,0)) 383 | nrrd.write('tl.nrrd', s_ys.squeeze().cpu().numpy().transpose(2, 1, 0).astype(np.uint8)) 384 | 385 | -------------------------------------------------------------------------------- /solver_us_pro.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import pdb 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch.optim import lr_scheduler 9 | 10 | import utils.common_utils as common_utils 11 | import losses 12 | # from nn_additional_losses import losses 13 | from utils.data_utils import split_batch 14 | from utils.log_utils import LogWriter 15 | from cus_loss import * 16 | import nrrd 17 | 18 | CHECKPOINT_DIR = 'checkpoints' 19 | CHECKPOINT_EXTENSION = 'pth.tar' 20 | 21 | def gradient_loss(ss, penalty='l2'): 22 | '''s = flow''' 23 | d = 0 24 | for s in ss: 25 | dy = torch.abs(s[:, :, 1:, :] - s[:, :, :-1, :]) 26 | dx = torch.abs(s[:, :, :, 1:] - s[:, :, :, :-1]) 27 | # dz = torch.abs(s[:, :, :, :, 1:] - s[:, :, :, :, :-1]) 28 | 29 | if (penalty == 'l2'): 30 | dy = dy * dy 31 | dx = dx * dx 32 | # dz = dz * dz 33 | 34 | d += (torch.mean(dx) + torch.mean(dy))/2 35 | return d / s.shape[0] 36 | 37 | class Solver(object): 38 | 39 | def __init__(self, 40 | model, 41 | exp_name, 42 | device, 43 | num_class, 44 | optim=torch.optim.Adam, 45 | optim_args={}, 46 | loss_func=losses.DiceLoss(), 47 | model_name='OneShotSegmentor', 48 | labels=None, 49 | num_epochs=10, 50 | log_nth=5, 51 | lr_scheduler_step_size=5, 52 | lr_scheduler_gamma=0.5, 53 | use_last_checkpoint=True, 54 | exp_dir='experiments', 55 | log_dir='logs'): 56 | 57 | self.device = device 58 | self.model = model 59 | 60 | self.model_name = model_name 61 | self.labels = labels 62 | self.num_epochs = num_epochs 63 | if torch.cuda.is_available(): 64 | self.loss_func = loss_func.cuda(device) 65 | else: 66 | self.loss_func = loss_func 67 | 68 | self.optim = optim([{'params': filter(lambda p: p.requires_grad, self.model.parameters()), 'lr': 5e-5, 'weight_decay': 0.00001}]) 69 | 70 | exp_dir_path = os.path.join(exp_dir, exp_name) 71 | common_utils.create_if_not(exp_dir_path) 72 | common_utils.create_if_not(os.path.join(exp_dir_path, CHECKPOINT_DIR)) 73 | self.exp_dir_path = exp_dir_path 74 | 75 | self.log_nth = log_nth 76 | self.logWriter = LogWriter( 77 | num_class, log_dir, exp_name, use_last_checkpoint, labels) 78 | 79 | self.use_last_checkpoint = use_last_checkpoint 80 | self.start_epoch = 1 81 | self.start_iteration = 1 82 | 83 | self.best_ds_mean = 0 84 | self.best_ds_mean_epoch = 0 85 | self.bloss = 
losses.SurfaceLoss() 86 | self.onedice = onecls_DiceLoss() 87 | self.multidice = multicls_DiceLoss() 88 | 89 | # if use_last_checkpoint: 90 | # self.load_checkpoint('180_lv5hot') # 100_1shot_la 91 | # self.load_stn_checkpoint(1000) 92 | 93 | 94 | def train(self, train_loader, test_loader): 95 | """ 96 | Train a given model with the provided data. 97 | 98 | Inputs: 99 | - train_loader: train data in torch.utils.data.DataLoader 100 | - val_loader: val data in torch.utils.data.DataLoader 101 | """ 102 | model, optim = self.model, self.optim 103 | 104 | data_loader = { 105 | 'train': train_loader, 106 | 'val': test_loader 107 | } 108 | 109 | if torch.cuda.is_available(): 110 | torch.cuda.empty_cache() 111 | model.cuda(self.device) 112 | 113 | self.logWriter.log('START TRAINING. : model name = %s, device = %s' % ( 114 | self.model_name, torch.cuda.get_device_name(self.device))) 115 | current_iteration = self.start_iteration 116 | 117 | for epoch in range(self.start_epoch, self.num_epochs + 1): 118 | self.logWriter.log( 119 | 'train', "\n==== Epoch [ %d / %d ] START ====" % (epoch, self.num_epochs)) 120 | phase = 'train' 121 | if phase == 'train': 122 | model.train() 123 | else: 124 | model.eval() 125 | # print('len data:', len(data_loader[phase])) 126 | print('epoch: ', epoch) 127 | for i_batch, sampled_batch in enumerate(data_loader[phase]): 128 | 129 | s_x = sampled_batch[2] # [B, Support, slice_num, 1, 256, 256] 130 | s_y = sampled_batch[3] # [B, Support, slice_num, 1, 256, 256] 131 | q_x = sampled_batch[0] # [B, slice_num, 1, 256, 256] 132 | q_y = sampled_batch[1].type(torch.LongTensor) # [B, slice_num, 1, 256, 256] 133 | # sp_prior = sampled_batch[5][:,1].cuda(self.device, non_blocking=True) 134 | # sp_prior = sampled_batch[5].unsqueeze(1).float().cuda(self.device, non_blocking=True) 135 | # cls = sampled_batch[6] 136 | 137 | #s_x = sampled_batch['s_x'] # [B, Support, slice_num=1, 1, 256, 256] 138 | X = s_x.squeeze(2) # [B, Support, 1, 256, 256] 139 | #s_y = sampled_batch['s_y'] # [B, Support, slice_num, 1, 256, 256] 140 | Y = s_y.squeeze(2) # [B, Support, 1, 256, 256] 141 | Y = Y.squeeze(2) # [B, Support, 256, 256] 142 | #q_x = sampled_batch['q_x'] # [B, slice_num, 1, 256, 256] 143 | q_x = q_x.squeeze(1) # [B, 1, 256, 256] 144 | #q_y = sampled_batch['q_y'] # [B, slice_num, 1, 256, 256] 145 | q_y = q_y.squeeze(1) # [B, 1, 256, 256] 146 | q_y = q_y.squeeze(1) # [B, 256, 256] 147 | 148 | query_input = q_x 149 | # condition_input = torch.cat((input1, y1.unsqueeze(1)), dim=1) slice =1 150 | # condition_input = torch.cat((input1, y1), dim=2) # slice=3 151 | y2 = q_y 152 | y2 = y2.type(torch.LongTensor) 153 | 154 | temp_out = [] 155 | temp_cond = [] 156 | 157 | # # multi-shot method 1 fuse all pred 158 | # loss_cond = 0 159 | # for sp_index in range(X.shape[1]): 160 | # input1 = X[:, sp_index, ...] # use 1 shot at first 161 | # y1 = Y[:, sp_index, ...] 
162 | # y1 = y1.type(torch.LongTensor) 163 | # condition_input = torch.cat((input1, y1), dim=2) # slice=3 164 | # 165 | # if model.is_cuda: 166 | # condition_input, query_input, y2, y1 = condition_input.cuda(self.device, non_blocking=True), \ 167 | # query_input.cuda(self.device, non_blocking=True), \ 168 | # y2.cuda(self.device, non_blocking=True), \ 169 | # y1.cuda(self.device, non_blocking=True) 170 | # 171 | # # sp-se 172 | # # weights = model.conditioner(condition_input) 173 | # # output = model.segmentor(query_input, weights, sp_prior) 174 | # 175 | # # lba-net 176 | # condition_input = condition_input.view(-1, 6, 256, 256) 177 | # attention, cond_output = model.conditioner(condition_input) 178 | # # output = model.segmentor(query_input, attention, condition_input, y1.unsqueeze(1).float()) # slice=1 179 | # query_input = query_input.squeeze(2) 180 | # # output = model.segmentor(query_input, attention, condition_input, y1.unsqueeze(1).float()) # slice=1 181 | # output = model.segmentor(query_input, attention, condition_input, y1.float()) 182 | # temp_out.append(output) 183 | # temp_cond.append(cond_output) 184 | # # cond loss 185 | # loss_cond += self.loss_func(F.softmax(cond_output, dim=1), y1[:, 1].squeeze(1)) 186 | 187 | # method 2 prototype fused 188 | 189 | 190 | # import matplotlib 191 | # matplotlib.use('TkAgg') 192 | # import matplotlib.pyplot as plt 193 | # img = X[0, 0, 1:2, ...].squeeze().detach().cpu().numpy() 194 | # # img2 = tmp_sp_img_porior.squeeze().detach().cpu().numpy() 195 | # # img2 = inpt[:, 1:2, ...].squeeze().detach().cpu().numpy() 196 | # fig, axes = plt.subplots(1, 2) 197 | # axes[0].imshow(img) 198 | # # axes[1].imshow(img2) 199 | # # plt.imshow(img) 200 | # plt.show() 201 | 202 | temp_atten = [] 203 | cond_inputs = [] 204 | cond_y1s = [] 205 | # for sp_index in range(X.shape[1]): 206 | # input1 = X[:, sp_index, ...] # use 1 shot at first 207 | # y1 = Y[:, sp_index, ...] 
208 | # y1 = y1.type(torch.LongTensor) 209 | # condition_input = torch.cat((input1, y1), dim=1) # slice=3 210 | # query_input = query_input.squeeze(2) 211 | # 212 | # if model.is_cuda: 213 | # condition_input, query_input, y2, y1 = condition_input.cuda(self.device, non_blocking=True), \ 214 | # query_input.cuda(self.device, non_blocking=True), \ 215 | # y2.cuda(self.device, non_blocking=True), \ 216 | # y1.cuda(self.device, non_blocking=True) 217 | 218 | # import nrrd 219 | # nrrd.write(f'prediction_la_dice_1000/y1.nrrd', 220 | # y1[0, 1].cpu().numpy().transpose(2, 1, 0).astype(np.uint8)) 221 | 222 | # lba-net 223 | # condition_input = condition_input.view(-1, 6, 256, 256) 224 | 225 | # import matplotlib 226 | # matplotlib.use('TkAgg') 227 | # import matplotlib.pyplot as plt 228 | # img = condition_input[0, 2:3,...].squeeze().detach().cpu().numpy() 229 | # img2 = condition_input[0, 1:2, ...].squeeze().detach().cpu().numpy() 230 | # # img2 = tmp_sp_img_porior.squeeze().detach().cpu().numpy() 231 | # # img2 = inpt[:, 1:2, ...].squeeze().detach().cpu().numpy() 232 | # fig, axes = plt.subplots(1, 2) 233 | # axes[0].imshow(img) 234 | # axes[1].imshow(img2) 235 | # # plt.imshow(img) 236 | # plt.show() 237 | 238 | # attention, cond_output = model.conditioner(condition_input) 239 | # temp_atten.append(attention) 240 | # 241 | # cond_inputs.append(condition_input) 242 | # cond_y1s.append(y1) 243 | 244 | # # pseg 245 | # query_input = query_input.squeeze(2) 246 | # attention_p, cond_output_p = model.psegmentor(query_input) 247 | # 248 | # base_map_list = [] 249 | # c_id_array = torch.arange(3, device='cuda') 250 | # for b_id in range(X.shape[0]): 251 | # c_id = cls[b_id] # cat_idx = fore cls index 252 | # # if c_id == 4: 253 | # # c_id = 2 254 | # c_mask = (c_id_array != c_id) 255 | # pseg_fg = torch.prod(1 - cond_output_p[b_id, c_mask, :, :], dim=0) # one channel only fg, bg 256 | # pseg_fg = pseg_fg * cond_output_p[b_id, c_id, :, :].unsqueeze(0).unsqueeze(0) 257 | # base_map_list.append(pseg_fg) 258 | # base_map = torch.cat(base_map_list, 0) 259 | # 260 | # choosen_cond = [] 261 | # cond_softmax = cond_output_p 262 | # for bs, c in enumerate(cls): 263 | # choosen_cond.append(cond_softmax[bs, c].unsqueeze(0).unsqueeze(0)) 264 | # cond_output_choosen = torch.cat(choosen_cond, 0) 265 | # loss_cond = self.onedice(cond_output_choosen, y2[:, 1]) 266 | # # get fusd atten and prototype for cond 267 | # max_attens = [] 268 | # for att_id in range(9): 269 | # atten_i = [] 270 | # for bs_id in range(X.shape[0]): 271 | # temp = [] 272 | # for sp_id in range(X.shape[1]): 273 | # temp.append(temp_atten[sp_id][0][att_id][bs_id]) 274 | # max_atten = torch.mean(torch.stack(temp), 0, keepdim=True)[0] 275 | # atten_i.append(max_atten) 276 | # max_attens.append(torch.stack(atten_i)) 277 | # 278 | # max_attens.append(None) 279 | # attention = list(attention) 280 | # attention[0] = list(attention[0]) 281 | # attention[0] = max_attens 282 | 283 | if model.is_cuda: 284 | y2 = y2.cuda(self.device, non_blocking=True) 285 | query_input = query_input.squeeze(2).cuda(self.device, non_blocking=True) 286 | cond_inputs = torch.cat([s_x.permute(1,0,2,3,4,5), s_y.permute(1,0,2,3,4,5)], 2).squeeze(3).cuda(self.device, non_blocking=True) 287 | cond_y1s = s_y.permute(1,0,2,3,4,5).squeeze(3).cuda(self.device, non_blocking=True) 288 | else: 289 | return -1 290 | # cond_inputs = torch.stack(cond_inputs) 291 | # cond_y1s = torch.stack(cond_y1s) 292 | sp_prior = None 293 | # output, sp_prior_pred, max_corr_e2, sp_img_prior = 
model.segmentor(query_input, cond_inputs, cond_y1s.float()) 294 | output, sp_prior_pred, max_corr_e2 = model.segmentor(query_input, cond_inputs, cond_y1s.float()) 295 | # sp_prior_pred, sp_img_prior, flow = model.segmentor(query_input, cond_inputs, 296 | # cond_y1s.float()) 297 | 298 | # sp_prior_pred, max_corr_e2, sp_img_prior, flow = model.segmentor(query_input, cond_inputs, 299 | # cond_y1s.float()) 300 | # output, sp_prior_pred, max_corr_e2, sp_img_prior, flow = model.segmentor(query_input, cond_inputs, 301 | # cond_y1s.float()) 302 | # 303 | # loss = self.loss_func(F.softmax(output, dim=1), y2) # slice=3 304 | # loss = self.loss_func(output, y2[:,1].squeeze(1)) # slice=3 305 | # print(output.shape) 306 | 307 | # loss = self.onedice(output, y2[:,1]) 308 | # loss = self.onedice(max_corr_e2, y2[:, 1]) 309 | loss = self.onedice(output, y2[:,1]) 310 | # loss_sp_mse = torch.mean((sp_img_prior - query_input[:, 1:2, ...]) ** 2) 311 | # loss_g = gradient_loss(flow) 312 | # loss += loss_sp_mse 313 | # loss += loss_g 314 | # print(loss_g, loss_sp_mse) 315 | # loss = self.onedice(max_corr_e2, y2[:, 1]) 316 | # loss_cor_e3 = self.onedice(max_corr_e3, y2[:, 1]) 317 | print('seg loss:', loss.item()) 318 | # print('sp loss: ', loss_sp.item()) 319 | # print('sp loss mse: ', loss_sp_mse.item()) 320 | # print('corr loss 2: ', loss_cor_e2.item()) 321 | # print('corr loss 3: ', loss_cor_e3.item()) 322 | print('-'*20) 323 | # loss += 0.3*loss_sp_mse 324 | # loss+= 0.3*loss_sp 325 | # loss+= 0.3*loss_cor_e2# +0.3*loss_cor_e3 326 | #dis_map = torch.tensor(losses.one_hot2dist(losses.class2one_hot(y2[:,1].squeeze(1), 2).cpu().numpy())).cuda(self.device, non_blocking=True) 327 | # loss_b = self.bloss(F.softmax(output, dim=1), dis_map) 328 | # loss +=loss_b 329 | # del condition_input, y2, y1#, dis_map 330 | torch.cuda.empty_cache() 331 | # align loss 332 | ''' 333 | _, q_pred = torch.max(F.softmax(output, dim=1), dim=1, keepdim=True) 334 | align_condition_input = torch.cat((query_input, q_pred.repeat(1,3,1,1)), dim=1) 335 | attention, cond_output = model.conditioner(align_condition_input) 336 | s1 = X[:, 0, ...].squeeze(2).cuda(self.device, non_blocking=True) # use 1 shot at first 337 | y1 = Y[:, 0, ...].squeeze(2).cuda(self.device, non_blocking=True) 338 | align_output = model.segmentor(s1, attention, align_condition_input, q_pred.repeat(1,3,1,1).unsqueeze(2).float()) 339 | align_loss = self.loss_func(F.softmax(align_output, dim=1), y1[:,1].squeeze(1)) 340 | loss += align_loss 341 | ''' 342 | optim.zero_grad() 343 | loss.backward() 344 | if phase == 'train': 345 | optim.step() 346 | 347 | if i_batch % self.log_nth == 0: 348 | self.logWriter.loss_per_iter( 349 | loss.item(), i_batch, current_iteration) 350 | # print('bloss loss: ', loss_b) 351 | # print('cond loss: ', loss_cond/5*0.3) 352 | current_iteration += 1 353 | 354 | if phase == 'val': 355 | if i_batch != len(data_loader[phase]) - 1: 356 | # print("#", end='', flush=True) 357 | pass 358 | else: 359 | print("100%", flush=True) 360 | if phase == 'train' and epoch % 10 ==0: 361 | self.logWriter.log('saving checkpoint ....') 362 | self.save_checkpoint({ 363 | 'epoch': epoch + 1, 364 | 'start_iteration': current_iteration + 1, 365 | 'arch': self.model_name, 366 | 'state_dict': model.state_dict(), 367 | 'optimizer': optim.state_dict(), 368 | #'scheduler_c': scheduler_c.state_dict(), 369 | #'optimizer_s': optim_s.state_dict(), 370 | 'best_ds_mean_epoch': self.best_ds_mean_epoch, 371 | #'scheduler_s': scheduler_s.state_dict() 372 | }, 
os.path.join(self.exp_dir_path, CHECKPOINT_DIR, 373 | 'checkpoint_epoch_' + str(epoch) + '.' + CHECKPOINT_EXTENSION)) 374 | 375 | self.logWriter.log( 376 | "==== Epoch [" + str(epoch) + " / " + str(self.num_epochs) + "] DONE ====") 377 | self.logWriter.log('FINISH.') 378 | self.logWriter.close() 379 | 380 | def save_checkpoint(self, state, filename): 381 | torch.save(state, filename) 382 | 383 | def save_best_model(self, path): 384 | """ 385 | Save model with its parameters to the given path. Conventionally the 386 | path should end with "*.model". 387 | 388 | Inputs: 389 | - path: path string 390 | """ 391 | print('Saving model... %s' % path) 392 | print("Best Epoch... " + str(self.best_ds_mean_epoch)) 393 | self.load_checkpoint(self.best_ds_mean_epoch) 394 | 395 | torch.save(self.model, path) 396 | 397 | def load_checkpoint(self, epoch=None): 398 | if epoch is not None: 399 | checkpoint_path = os.path.join(self.exp_dir_path, CHECKPOINT_DIR, 400 | 'checkpoint_epoch_' + str(epoch) + '.' + CHECKPOINT_EXTENSION) 401 | self._load_checkpoint_file(checkpoint_path) 402 | else: 403 | all_files_path = os.path.join( 404 | self.exp_dir_path, CHECKPOINT_DIR, '*.' + CHECKPOINT_EXTENSION) 405 | list_of_files = glob.glob(all_files_path) 406 | if len(list_of_files) > 0: 407 | checkpoint_path = max(list_of_files, key=os.path.getctime) 408 | self._load_checkpoint_file(checkpoint_path) 409 | else: 410 | self.logWriter.log( 411 | "=> no checkpoint found at '{}' folder".format(os.path.join(self.exp_dir_path, CHECKPOINT_DIR))) 412 | 413 | def load_stn_checkpoint(self, epoch=None): 414 | checkpoint_path = os.path.join(self.exp_dir_path, CHECKPOINT_DIR, 415 | 'checkpoint_epoch_' + str(epoch) + '.' + CHECKPOINT_EXTENSION) 416 | checkpoint = torch.load(checkpoint_path) 417 | state_dict = checkpoint['state_dict'] 418 | # for k in state_dict.keys(): 419 | # print(k) 420 | part_sd = {k: v for k, v in state_dict.items() if 'segmentor.unet' in k} 421 | # print(part_sd.keys()) 422 | self.model.state_dict().update(part_sd) 423 | for p in self.model.segmentor.unet.parameters(): 424 | # print(p) 425 | p.requires_grad = False 426 | return 0 427 | 428 | 429 | 430 | def _load_checkpoint_file(self, file_path): 431 | self.logWriter.log("=> loading checkpoint '{}'".format(file_path)) 432 | checkpoint = torch.load(file_path) 433 | self.start_epoch = checkpoint['epoch'] 434 | self.start_iteration = checkpoint['start_iteration'] 435 | self.best_ds_mean_epoch = checkpoint['best_ds_mean_epoch'] 436 | self.model.load_state_dict(checkpoint['state_dict']) 437 | self.optim.load_state_dict(checkpoint['optimizer']) 438 | 439 | for state in self.optim.state.values(): 440 | for k, v in state.items(): 441 | if torch.is_tensor(v): 442 | state[k] = v.to(self.device) 443 | 444 | 445 | self.logWriter.log("=> loaded checkpoint '{}' (epoch {})".format( 446 | file_path, checkpoint['epoch'])) 447 | -------------------------------------------------------------------------------- /utils/evaluator_slow.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import nibabel as nib 4 | import numpy as np 5 | import torch 6 | 7 | import utils.common_utils as common_utils 8 | import utils.data_utils as du 9 | import torch.nn.functional as F 10 | import shot_batch_sampler as SB 11 | 12 | 13 | def dice_score_binary(vol_output, ground_truth, no_samples=10, phase='train'): 14 | ground_truth = ground_truth.type(torch.FloatTensor) 15 | vol_output = vol_output.type(torch.FloatTensor) 16 | if phase == 'train': 17 | 
samples = np.random.choice(len(vol_output), no_samples) 18 | vol_output, ground_truth = vol_output[samples], ground_truth[samples] 19 | inter = 2 * torch.sum(torch.mul(ground_truth, vol_output)) 20 | union = torch.sum(ground_truth) + torch.sum(vol_output) + 0.0001 21 | 22 | return torch.div(inter, union) 23 | 24 | 25 | def dice_confusion_matrix(vol_output, ground_truth, num_classes, no_samples=10, mode='train'): 26 | dice_cm = torch.zeros(num_classes, num_classes) 27 | if mode == 'train': 28 | samples = np.random.choice(len(vol_output), no_samples) 29 | vol_output, ground_truth = vol_output[samples], ground_truth[samples] 30 | for i in range(num_classes): 31 | GT = (ground_truth == i).float() 32 | for j in range(num_classes): 33 | Pred = (vol_output == j).float() 34 | inter = torch.sum(torch.mul(GT, Pred)) 35 | union = torch.sum(GT) + torch.sum(Pred) + 0.0001 36 | dice_cm[i, j] = 2 * torch.div(inter, union) 37 | avg_dice = torch.mean(torch.diagflat(dice_cm)) 38 | return avg_dice, dice_cm 39 | 40 | 41 | def get_range(volume): 42 | batch, _, _ = volume.size() 43 | slice_with_class = torch.sum(volume.view(batch, -1), dim=1) > 10 44 | index = slice_with_class[:-1] - slice_with_class[1:] > 0 45 | seq = torch.Tensor(range(batch - 1)) 46 | range_index = seq[index].type(torch.LongTensor) 47 | return range_index 48 | 49 | 50 | def dice_score_perclass(vol_output, ground_truth, num_classes, no_samples=10, mode='train'): 51 | dice_perclass = torch.zeros(num_classes) 52 | if mode == 'train': 53 | samples = np.random.choice(len(vol_output), no_samples) 54 | vol_output, ground_truth = vol_output[samples], ground_truth[samples] 55 | for i in range(num_classes): 56 | GT = (ground_truth == i).float() 57 | Pred = (vol_output == i).float() 58 | inter = torch.sum(torch.mul(GT, Pred)) 59 | union = torch.sum(GT) + torch.sum(Pred) + 0.0001 60 | dice_perclass[i] = (2 * torch.div(inter, union)) 61 | return dice_perclass 62 | 63 | 64 | def binarize_label(volume, groud_truth, class_label): 65 | groud_truth = (groud_truth == class_label).type(torch.FloatTensor) 66 | batch, _, _ = groud_truth.size() 67 | slice_with_class = torch.sum(groud_truth.view(batch, -1), dim=1) > 10 68 | index = slice_with_class[:-1] - slice_with_class[1:] > 0 69 | seq = torch.Tensor(range(batch - 1)) 70 | range_index = seq[index].type(torch.LongTensor) 71 | groud_truth = groud_truth[slice_with_class] 72 | volume = volume[slice_with_class] 73 | condition_input = torch.cat((volume, groud_truth.unsqueeze(1)), dim=1) 74 | return condition_input, range_index.cpu().numpy() 75 | 76 | 77 | def evaluate_dice_score(model_path, 78 | num_classes, 79 | query_labels, 80 | data_dir, 81 | query_txt_file, 82 | support_txt_file, 83 | remap_config, 84 | orientation, 85 | prediction_path, device=0, logWriter=None, mode='eval', fold=None): 86 | print("**Starting evaluation. Please check tensorboard for plots if a logWriter is provided in arguments**") 87 | print("Loading model => " + model_path) 88 | batch_size = 20 89 | Num_support = 10 90 | with open(query_txt_file) as file_handle: 91 | volumes_query = file_handle.read().splitlines() 92 | 93 | # with open(support_txt_file) as file_handle: 94 | # volumes_support = file_handle.read().splitlines() 95 | 96 | model = torch.load(model_path) 97 | cuda_available = torch.cuda.is_available() 98 | if cuda_available: 99 | torch.cuda.empty_cache() 100 | model.cuda(device) 101 | 102 | model.eval() 103 | 104 | common_utils.create_if_not(prediction_path) 105 | 106 | print("Evaluating now... 
" + fold) 107 | query_file_paths = du.load_file_paths(data_dir, data_dir, query_txt_file) 108 | support_file_paths = du.load_file_paths(data_dir, data_dir, support_txt_file) 109 | 110 | with torch.no_grad(): 111 | all_query_dice_score_list = [] 112 | for query_label in query_labels: 113 | volume_dice_score_list = [] 114 | 115 | # Loading support 116 | support_volume, support_labelmap, _, _ = du.load_and_preprocess(support_file_paths[0], 117 | orientation=orientation, 118 | remap_config=remap_config) 119 | support_volume = support_volume if len(support_volume.shape) == 4 else support_volume[:, np.newaxis, :, 120 | :] 121 | support_volume, support_labelmap = torch.tensor(support_volume).type(torch.FloatTensor), \ 122 | torch.tensor(support_labelmap).type(torch.LongTensor) 123 | 124 | support_volume, range_index = binarize_label(support_volume, support_labelmap, query_label) 125 | 126 | # slice_gap_support = int(np.ceil(len(support_volume) / Num_support)) 127 | # 128 | # support_slice_indexes = [i for i in range(0, len(support_volume), slice_gap_support)] 129 | # 130 | # if len(support_slice_indexes) < Num_support: 131 | # support_slice_indexes.append(len(support_volume) - 1) 132 | 133 | for vol_idx, file_path in enumerate(query_file_paths): 134 | 135 | query_volume, query_labelmap, _, _ = du.load_and_preprocess(file_path, 136 | orientation=orientation, 137 | remap_config=remap_config) 138 | 139 | query_volume = query_volume if len(query_volume.shape) == 4 else query_volume[:, np.newaxis, :, :] 140 | query_volume, query_labelmap = torch.tensor(query_volume).type(torch.FloatTensor), \ 141 | torch.tensor(query_labelmap).type(torch.LongTensor) 142 | 143 | query_labelmap = query_labelmap == query_label 144 | range_query = get_range(query_labelmap) 145 | query_volume = query_volume[range_query[0]: range_query[1] + 1] 146 | query_labelmap = query_labelmap[range_query[0]: range_query[1] + 1] 147 | 148 | dice_per_slice = [] 149 | vol_output = [] 150 | 151 | for i, query_slice in enumerate(query_volume): 152 | query_batch_x = query_slice.unsqueeze(0) 153 | max_dice = -1.0 154 | max_output = None 155 | for j in range(0, len(support_volume), 10): 156 | support_slice = support_volume[j] 157 | 158 | support_batch_x = support_slice.unsqueeze(0) 159 | if cuda_available: 160 | query_batch_x = query_batch_x.cuda(device) 161 | support_batch_x = support_batch_x.cuda(device) 162 | 163 | weights = model.conditioner(support_batch_x) 164 | out = model.segmentor(query_batch_x, weights) 165 | 166 | _, batch_output = torch.max(F.softmax(out, dim=1), dim=1) 167 | slice_dice_score = dice_score_binary(batch_output, 168 | query_labelmap[i].cuda(device), phase=mode) 169 | dice_per_slice.append(slice_dice_score.item()) 170 | if slice_dice_score.item() >= max_dice: 171 | max_dice = slice_dice_score.item() 172 | max_output = batch_output 173 | # dice_per_slice.append(max_dice) 174 | vol_output.append(max_output) 175 | 176 | vol_output = torch.cat(vol_output) 177 | volume_dice_score = dice_score_binary(vol_output, query_labelmap.cuda(device), phase=mode) 178 | volume_dice_score_list.append(volume_dice_score) 179 | 180 | print(volume_dice_score) 181 | 182 | dice_score_arr = np.asarray(volume_dice_score_list) 183 | avg_dice_score = np.median(dice_score_arr) 184 | print('Query Label -> ' + str(query_label) + ' ' + str(avg_dice_score)) 185 | all_query_dice_score_list.append(avg_dice_score) 186 | 187 | print("DONE") 188 | 189 | return np.mean(all_query_dice_score_list) 190 | 191 | 192 | def evaluate_dice_score_2view(model1_path, 
193 | model2_path, 194 | num_classes, 195 | query_labels, 196 | data_dir, 197 | query_txt_file, 198 | support_txt_file, 199 | remap_config, 200 | orientation1, 201 | prediction_path, device=0, logWriter=None, mode='eval', fold=None): 202 | print("**Starting evaluation. Please check tensorboard for plots if a logWriter is provided in arguments**") 203 | print("Loading model => " + model1_path + " and " + model2_path) 204 | batch_size = 10 205 | 206 | with open(query_txt_file) as file_handle: 207 | volumes_query = file_handle.read().splitlines() 208 | 209 | # with open(support_txt_file) as file_handle: 210 | # volumes_support = file_handle.read().splitlines() 211 | 212 | model1 = torch.load(model1_path) 213 | model2 = torch.load(model2_path) 214 | cuda_available = torch.cuda.is_available() 215 | if cuda_available: 216 | torch.cuda.empty_cache() 217 | model1.cuda(device) 218 | model2.cuda(device) 219 | 220 | model1.eval() 221 | model2.eval() 222 | 223 | common_utils.create_if_not(prediction_path) 224 | 225 | print("Evaluating now... " + fold) 226 | query_file_paths = du.load_file_paths(data_dir, data_dir, query_txt_file) 227 | support_file_paths = du.load_file_paths(data_dir, data_dir, support_txt_file) 228 | 229 | with torch.no_grad(): 230 | all_query_dice_score_list = [] 231 | for query_label in query_labels: 232 | volume_dice_score_list = [] 233 | for vol_idx, file_path in enumerate(support_file_paths): 234 | # Loading support 235 | support_volume1, support_labelmap1, _, _ = du.load_and_preprocess(file_path, 236 | orientation=orientation1, 237 | remap_config=remap_config) 238 | support_volume2, support_labelmap2 = support_volume1.transpose((1, 2, 0)), support_labelmap1.transpose( 239 | (1, 2, 0)) 240 | 241 | support_volume1 = support_volume1 if len(support_volume1.shape) == 4 else support_volume1[:, np.newaxis, 242 | :, :] 243 | support_volume2 = support_volume2 if len(support_volume2.shape) == 4 else support_volume2[:, np.newaxis, 244 | :, :] 245 | 246 | support_volume1, support_labelmap1 = torch.tensor(support_volume1).type( 247 | torch.FloatTensor), torch.tensor( 248 | support_labelmap1).type(torch.LongTensor) 249 | support_volume2, support_labelmap2 = torch.tensor(support_volume2).type( 250 | torch.FloatTensor), torch.tensor( 251 | support_labelmap2).type(torch.LongTensor) 252 | support_volume1 = binarize_label(support_volume1, support_labelmap1, query_label) 253 | support_volume2 = binarize_label(support_volume2, support_labelmap2, query_label) 254 | 255 | for vol_idx, file_path in enumerate(query_file_paths): 256 | query_volume1, query_labelmap1, _, _ = du.load_and_preprocess(file_path, 257 | orientation=orientation1, 258 | remap_config=remap_config) 259 | query_volume2, query_labelmap2 = query_volume1.transpose((1, 2, 0)), query_labelmap1.transpose( 260 | (1, 2, 0)) 261 | 262 | query_volume1 = query_volume1 if len(query_volume1.shape) == 4 else query_volume1[:, np.newaxis, :, :] 263 | query_volume2 = query_volume2 if len(query_volume2.shape) == 4 else query_volume2[:, np.newaxis, :, :] 264 | 265 | query_volume1, query_labelmap1 = torch.tensor(query_volume1).type(torch.FloatTensor), torch.tensor( 266 | query_labelmap1).type(torch.LongTensor) 267 | query_volume2, query_labelmap2 = torch.tensor(query_volume2).type(torch.FloatTensor), torch.tensor( 268 | query_labelmap2).type(torch.LongTensor) 269 | 270 | query_labelmap1 = query_labelmap1 == query_label 271 | query_labelmap2 = query_labelmap2 == query_label 272 | 273 | # Evaluate for orientation 1 274 | support_batch_x = [] 275 | k = 2 
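# Annotation (editorial; not in the original evaluator_slow.py): in the loop below, support_batch_x
# is refreshed only when k is even, and the slice at index batch_size - 1 of the refreshed chunk is
# repeated across the query batch, so every query batch is conditioned on a single support slice.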
276 | volume_prediction1 = [] 277 | for i in range(0, len(query_volume1), batch_size): 278 | query_batch_x = query_volume1[i: i + batch_size] 279 | if k % 2 == 0: 280 | support_batch_x = support_volume1[i: i + batch_size] 281 | sz = query_batch_x.size() 282 | support_batch_x = support_batch_x[batch_size - 1].repeat(sz[0], 1, 1, 1) 283 | k += 1 284 | if cuda_available: 285 | query_batch_x = query_batch_x.cuda(device) 286 | support_batch_x = support_batch_x.cuda(device) 287 | 288 | weights = model1.conditioner(support_batch_x) 289 | out = model1.segmentor(query_batch_x, weights) 290 | 291 | # _, batch_output = torch.max(F.softmax(out, dim=1), dim=1) 292 | volume_prediction1.append(out) 293 | 294 | # Evaluate for orientation 2 295 | support_batch_x = [] 296 | k = 2 297 | volume_prediction2 = [] 298 | for i in range(0, len(query_volume2), batch_size): 299 | query_batch_x = query_volume2[i: i + batch_size] 300 | if k % 2 == 0: 301 | support_batch_x = support_volume2[i: i + batch_size] 302 | sz = query_batch_x.size() 303 | support_batch_x = support_batch_x[batch_size - 1].repeat(sz[0], 1, 1, 1) 304 | k += 1 305 | if cuda_available: 306 | query_batch_x = query_batch_x.cuda(device) 307 | support_batch_x = support_batch_x.cuda(device) 308 | 309 | weights = model2.conditioner(support_batch_x) 310 | out = model2.segmentor(query_batch_x, weights) 311 | volume_prediction2.append(out) 312 | 313 | volume_prediction1 = torch.cat(volume_prediction1) 314 | volume_prediction2 = torch.cat(volume_prediction2) 315 | volume_prediction = 0.5 * volume_prediction1 + 0.5 * volume_prediction2.permute(3, 1, 0, 2) 316 | _, batch_output = torch.max(F.softmax(volume_prediction, dim=1), dim=1) 317 | volume_dice_score = dice_score_binary(batch_output, query_labelmap1.cuda(device), phase=mode) 318 | 319 | batch_output = (batch_output.cpu().numpy()).astype('float32') 320 | nifti_img = nib.MGHImage(np.squeeze(batch_output), np.eye(4)) 321 | nib.save(nifti_img, os.path.join(prediction_path, volumes_query[vol_idx] + '_' + fold + str('.mgz'))) 322 | 323 | # # Save Input 324 | # nifti_img = nib.MGHImage(np.squeeze(query_volume1.cpu().numpy()), np.eye(4)) 325 | # nib.save(nifti_img, os.path.join(prediction_path, volumes_query[vol_idx] + '_Input_' + str('.mgz'))) 326 | # # # Condition Input 327 | # nifti_img = nib.MGHImage(np.squeeze(support_volume1.cpu().numpy()), np.eye(4)) 328 | # nib.save(nifti_img, os.path.join(prediction_path, volumes_query[vol_idx] + '_CondInput_' + str('.mgz'))) 329 | # # # Cond GT 330 | # nifti_img = nib.MGHImage(np.squeeze(support_labelmap1.cpu().numpy()).astype('float32'), np.eye(4)) 331 | # nib.save(nifti_img, 332 | # os.path.join(prediction_path, volumes_query[vol_idx] + '_CondInputGT_' + str('.mgz'))) 333 | # # # # Save Ground Truth 334 | # nifti_img = nib.MGHImage(np.squeeze(query_labelmap1.cpu().numpy()), np.eye(4)) 335 | # nib.save(nifti_img, os.path.join(prediction_path, volumes_query[vol_idx] + '_GT_' + str('.mgz'))) 336 | 337 | # if logWriter: 338 | # logWriter.plot_dice_score('val', 'eval_dice_score', volume_dice_score, volumes_to_use[vol_idx], 339 | # vol_idx) 340 | volume_dice_score = volume_dice_score.cpu().numpy() 341 | volume_dice_score_list.append(volume_dice_score) 342 | 343 | print(volume_dice_score) 344 | 345 | dice_score_arr = np.asarray(volume_dice_score_list) 346 | avg_dice_score = np.median(dice_score_arr) 347 | print('Query Label -> ' + str(query_label) + ' ' + str(avg_dice_score)) 348 | all_query_dice_score_list.append(avg_dice_score) 349 | # class_dist = [dice_score_arr[:, c] 
for c in range(num_classes)] 350 | 351 | # if logWriter: 352 | # logWriter.plot_eval_box_plot('eval_dice_score_box_plot', class_dist, 'Box plot Dice Score') 353 | print("DONE") 354 | 355 | return np.mean(all_query_dice_score_list) 356 | 357 | 358 | def evaluate_dice_score_3view(model1_path, 359 | model2_path, 360 | model3_path, 361 | num_classes, 362 | query_labels, 363 | data_dir, 364 | query_txt_file, 365 | support_txt_file, 366 | remap_config, 367 | orientation1, 368 | prediction_path, device=0, logWriter=None, mode='eval', fold=None): 369 | print("**Starting evaluation. Please check tensorboard for plots if a logWriter is provided in arguments**") 370 | print("Loading model => " + model1_path + " and " + model2_path) 371 | batch_size = 10 372 | 373 | with open(query_txt_file) as file_handle: 374 | volumes_query = file_handle.read().splitlines() 375 | 376 | # with open(support_txt_file) as file_handle: 377 | # volumes_support = file_handle.read().splitlines() 378 | 379 | model1 = torch.load(model1_path) 380 | model2 = torch.load(model2_path) 381 | model3 = torch.load(model3_path) 382 | cuda_available = torch.cuda.is_available() 383 | if cuda_available: 384 | torch.cuda.empty_cache() 385 | model1.cuda(device) 386 | model2.cuda(device) 387 | model3.cuda(device) 388 | 389 | model1.eval() 390 | model2.eval() 391 | model3.eval() 392 | 393 | common_utils.create_if_not(prediction_path) 394 | 395 | print("Evaluating now... " + fold) 396 | query_file_paths = du.load_file_paths(data_dir, data_dir, query_txt_file) 397 | support_file_paths = du.load_file_paths(data_dir, data_dir, support_txt_file) 398 | 399 | with torch.no_grad(): 400 | all_query_dice_score_list = [] 401 | for query_label in query_labels: 402 | volume_dice_score_list = [] 403 | for vol_idx, file_path in enumerate(support_file_paths): 404 | # Loading support 405 | support_volume1, support_labelmap1, _, _ = du.load_and_preprocess(file_path, 406 | orientation=orientation1, 407 | remap_config=remap_config) 408 | support_volume2, support_labelmap2 = support_volume1.transpose((1, 2, 0)), support_labelmap1.transpose( 409 | (1, 2, 0)) 410 | 411 | support_volume3, support_labelmap3 = support_volume1.transpose((2, 0, 1)), support_labelmap1.transpose( 412 | (2, 0, 1)) 413 | 414 | support_volume1 = support_volume1 if len(support_volume1.shape) == 4 else support_volume1[:, np.newaxis, 415 | :, :] 416 | support_volume2 = support_volume2 if len(support_volume2.shape) == 4 else support_volume2[:, np.newaxis, 417 | :, :] 418 | 419 | support_volume3 = support_volume3 if len(support_volume3.shape) == 4 else support_volume3[:, np.newaxis, 420 | :, :] 421 | 422 | support_volume1, support_labelmap1 = torch.tensor(support_volume1).type( 423 | torch.FloatTensor), torch.tensor( 424 | support_labelmap1).type(torch.LongTensor) 425 | support_volume2, support_labelmap2 = torch.tensor(support_volume2).type( 426 | torch.FloatTensor), torch.tensor( 427 | support_labelmap2).type(torch.LongTensor) 428 | support_volume3, support_labelmap3 = torch.tensor(support_volume3).type( 429 | torch.FloatTensor), torch.tensor( 430 | support_labelmap3).type(torch.LongTensor) 431 | 432 | support_volume1 = binarize_label(support_volume1, support_labelmap1, query_label) 433 | support_volume2 = binarize_label(support_volume2, support_labelmap2, query_label) 434 | support_volume3 = binarize_label(support_volume3, support_labelmap3, query_label) 435 | 436 | for vol_idx, file_path in enumerate(query_file_paths): 437 | query_volume1, query_labelmap1, _, _ = 
du.load_and_preprocess(file_path, 438 | orientation=orientation1, 439 | remap_config=remap_config) 440 | query_volume2, query_labelmap2 = query_volume1.transpose((1, 2, 0)), query_labelmap1.transpose( 441 | (1, 2, 0)) 442 | query_volume3, query_labelmap3 = query_volume1.transpose((2, 0, 1)), query_labelmap1.transpose( 443 | (2, 0, 1)) 444 | 445 | query_volume1 = query_volume1 if len(query_volume1.shape) == 4 else query_volume1[:, np.newaxis, :, :] 446 | query_volume2 = query_volume2 if len(query_volume2.shape) == 4 else query_volume2[:, np.newaxis, :, :] 447 | query_volume3 = query_volume3 if len(query_volume3.shape) == 4 else query_volume3[:, np.newaxis, :, :] 448 | 449 | query_volume1, query_labelmap1 = torch.tensor(query_volume1).type(torch.FloatTensor), torch.tensor( 450 | query_labelmap1).type(torch.LongTensor) 451 | query_volume2, query_labelmap2 = torch.tensor(query_volume2).type(torch.FloatTensor), torch.tensor( 452 | query_labelmap2).type(torch.LongTensor) 453 | query_volume3, query_labelmap3 = torch.tensor(query_volume3).type(torch.FloatTensor), torch.tensor( 454 | query_labelmap3).type(torch.LongTensor) 455 | 456 | query_labelmap1 = query_labelmap1 == query_label 457 | # query_labelmap2 = query_labelmap2 == query_label 458 | # query_labelmap3 = query_labelmap3 == query_label 459 | 460 | # Evaluate for orientation 1 461 | support_batch_x = [] 462 | k = 2 463 | volume_prediction1 = [] 464 | for i in range(0, len(query_volume1), batch_size): 465 | query_batch_x = query_volume1[i: i + batch_size] 466 | if k % 2 == 0: 467 | support_batch_x = support_volume1[i: i + batch_size] 468 | sz = query_batch_x.size() 469 | support_batch_x = support_batch_x[batch_size - 1].repeat(sz[0], 1, 1, 1) 470 | k += 1 471 | if cuda_available: 472 | query_batch_x = query_batch_x.cuda(device) 473 | support_batch_x = support_batch_x.cuda(device) 474 | 475 | weights = model1.conditioner(support_batch_x) 476 | out = model1.segmentor(query_batch_x, weights) 477 | 478 | # _, batch_output = torch.max(F.softmax(out, dim=1), dim=1) 479 | volume_prediction1.append(out) 480 | 481 | # Evaluate for orientation 2 482 | support_batch_x = [] 483 | k = 2 484 | volume_prediction2 = [] 485 | for i in range(0, len(query_volume2), batch_size): 486 | query_batch_x = query_volume2[i: i + batch_size] 487 | if k % 2 == 0: 488 | support_batch_x = support_volume2[i: i + batch_size] 489 | sz = query_batch_x.size() 490 | support_batch_x = support_batch_x[batch_size - 1].repeat(sz[0], 1, 1, 1) 491 | k += 1 492 | if cuda_available: 493 | query_batch_x = query_batch_x.cuda(device) 494 | support_batch_x = support_batch_x.cuda(device) 495 | 496 | weights = model2.conditioner(support_batch_x) 497 | out = model2.segmentor(query_batch_x, weights) 498 | volume_prediction2.append(out) 499 | 500 | # Evaluate for orientation 3 501 | support_batch_x = [] 502 | k = 2 503 | volume_prediction3 = [] 504 | for i in range(0, len(query_volume3), batch_size): 505 | query_batch_x = query_volume3[i: i + batch_size] 506 | if k % 2 == 0: 507 | support_batch_x = support_volume3[i: i + batch_size] 508 | sz = query_batch_x.size() 509 | support_batch_x = support_batch_x[batch_size - 1].repeat(sz[0], 1, 1, 1) 510 | k += 1 511 | if cuda_available: 512 | query_batch_x = query_batch_x.cuda(device) 513 | support_batch_x = support_batch_x.cuda(device) 514 | 515 | weights = model3.conditioner(support_batch_x) 516 | out = model3.segmentor(query_batch_x, weights) 517 | volume_prediction3.append(out) 518 | 519 | volume_prediction1 = torch.cat(volume_prediction1) 520 | 
volume_prediction2 = torch.cat(volume_prediction2) 521 | volume_prediction3 = torch.cat(volume_prediction3) 522 | volume_prediction = 0.33 * F.softmax(volume_prediction1, dim=1) + 0.33 * F.softmax( 523 | volume_prediction2.permute(3, 1, 0, 2), dim=1) + 0.33 * F.softmax( 524 | volume_prediction3.permute(2, 1, 3, 0), dim=1) 525 | _, batch_output = torch.max(volume_prediction, dim=1) 526 | volume_dice_score = dice_score_binary(batch_output, query_labelmap1.cuda(device), phase=mode) 527 | 528 | batch_output = (batch_output.cpu().numpy()).astype('float32') 529 | nifti_img = nib.MGHImage(np.squeeze(batch_output), np.eye(4)) 530 | nib.save(nifti_img, os.path.join(prediction_path, volumes_query[vol_idx] + '_' + fold + str('.mgz'))) 531 | 532 | # # Save Input 533 | # nifti_img = nib.MGHImage(np.squeeze(query_volume1.cpu().numpy()), np.eye(4)) 534 | # nib.save(nifti_img, os.path.join(prediction_path, volumes_query[vol_idx] + '_Input_' + str('.mgz'))) 535 | # # # Condition Input 536 | # nifti_img = nib.MGHImage(np.squeeze(support_volume1.cpu().numpy()), np.eye(4)) 537 | # nib.save(nifti_img, os.path.join(prediction_path, volumes_query[vol_idx] + '_CondInput_' + str('.mgz'))) 538 | # # # Cond GT 539 | # nifti_img = nib.MGHImage(np.squeeze(support_labelmap1.cpu().numpy()).astype('float32'), np.eye(4)) 540 | # nib.save(nifti_img, 541 | # os.path.join(prediction_path, volumes_query[vol_idx] + '_CondInputGT_' + str('.mgz'))) 542 | # # # # Save Ground Truth 543 | # nifti_img = nib.MGHImage(np.squeeze(query_labelmap1.cpu().numpy()), np.eye(4)) 544 | # nib.save(nifti_img, os.path.join(prediction_path, volumes_query[vol_idx] + '_GT_' + str('.mgz'))) 545 | 546 | # if logWriter: 547 | # logWriter.plot_dice_score('val', 'eval_dice_score', volume_dice_score, volumes_to_use[vol_idx], 548 | # vol_idx) 549 | volume_dice_score = volume_dice_score.cpu().numpy() 550 | volume_dice_score_list.append(volume_dice_score) 551 | 552 | print(volume_dice_score) 553 | 554 | dice_score_arr = np.asarray(volume_dice_score_list) 555 | avg_dice_score = np.median(dice_score_arr) 556 | print('Query Label -> ' + str(query_label) + ' ' + str(avg_dice_score)) 557 | all_query_dice_score_list.append(avg_dice_score) 558 | # class_dist = [dice_score_arr[:, c] for c in range(num_classes)] 559 | 560 | # if logWriter: 561 | # logWriter.plot_eval_box_plot('eval_dice_score_box_plot', class_dist, 'Box plot Dice Score') 562 | print("DONE") 563 | 564 | return np.mean(all_query_dice_score_list) 565 | --------------------------------------------------------------------------------
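Both evaluators in this file share the same calling convention; the sketch below shows one way the three-view variant might be invoked. The checkpoint paths, data directory, label ids, `remap_config`, and `orientation1` values are placeholders, and it assumes the repository root is on `PYTHONPATH` so that `utils.evaluator_slow` is importable.

```python
# Minimal usage sketch; all paths and configuration values below are hypothetical.
from utils import evaluator_slow as ev

mean_dice = ev.evaluate_dice_score_3view(
    model1_path='checkpoints/model_orient1.pth',
    model2_path='checkpoints/model_orient2.pth',
    model3_path='checkpoints/model_orient3.pth',
    num_classes=2,                              # binary setting: background vs. query organ
    query_labels=[1],                           # label ids evaluated one at a time
    data_dir='data/',
    query_txt_file='data/query_volumes.txt',    # one volume name per line
    support_txt_file='data/support_volumes.txt',
    remap_config='remap',                       # placeholder; depends on the data_utils preprocessing
    orientation1='AXI',                         # placeholder orientation tag
    prediction_path='predictions/',
    device=0,
    fold='fold1',                               # used in printed messages and output filenames
)
print('Mean of per-label median Dice scores:', mean_dice)
```

The returned value is the mean over query labels of the per-label median volume Dice score, matching the aggregation performed at the end of each evaluator.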