├── .gitignore
├── LICENSE
├── README.md
├── pictures
│   ├── i-drlse.png
│   ├── mgnet.png
│   ├── refinement.png
│   └── uncertainty.png
├── requirements.txt
├── uncertainty_demo
│   ├── config
│   │   ├── image_test.csv
│   │   └── mgnet.cfg
│   ├── data
│   │   ├── a26_12.nii.gz
│   │   └── b15_05.nii.gz
│   ├── model
│   │   └── mgnet_20000.pt
│   ├── result
│   │   ├── a26_12.nii.gz
│   │   ├── a26_12_var.nii.gz
│   │   ├── b15_05.nii.gz
│   │   ├── b15_05_var.nii.gz
│   │   └── uncertainty.png
│   └── show_uncertanty.py
└── util
    ├── custom_net_run.py
    ├── level_set
    │   ├── __init__.py
    │   ├── data
    │   │   ├── a03_04_11img.png
    │   │   ├── a03_04_11scrb.png
    │   │   ├── a03_04_11seg.png
    │   │   ├── a10_12_22img.png
    │   │   ├── a10_12_22scrb.png
    │   │   ├── a10_12_22seg.png
    │   │   └── gourd.bmp
    │   ├── demo
    │   │   └── demo_idrlse.py
    │   └── ls_util
    │       ├── __init__.py
    │       ├── drlse_reion.py
    │       ├── get_gradient.py
    │       └── interactive_ls.py
    └── network
        ├── MGNet.py
        └── unet2dres.py

/.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2020, University of Electronic Science and Technology of China. 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission.
19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ### UGIR: Uncertainty-Guided Interactive Refinement for Segmentation 2 | This repository provides the code for the following MICCAI 2020 paper ([arXiv link][arxiv_link], [Demo][demo_link]). If you use any module of this repository, please cite this paper. 3 | * Guotai Wang, Michael Aertsen, Jan Deprest, Sébastien Ourselin, Tom Vercauteren, Shaoting Zhang: 4 | Uncertainty-Guided Efficient Interactive Refinement of Fetal Brain Segmentation from Stacks of MRI Slices. MICCAI (4) 2020: 279-288. 5 | 6 | The code contains two modules: 1) MG-Net, a novel CNN based on convolution in Multiple Groups that simultaneously obtains an initial segmentation and an estimate of its uncertainty; 2) I-DRLSE, an interaction-based level set for fast refinement that extends the DRLSE algorithm. 7 | 8 | ![mg_net](./pictures/mgnet.png) 9 | Fig. 1. Structure of MG-Net. 10 | 11 | ![uncertainty](./pictures/uncertainty.png) 12 | Fig. 2. Segmentation with uncertainty estimation.
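The uncertainty in Fig. 2 comes from MG-Net's multiple predictions. In `util/custom_net_run.py`, `save_ouputs()` applies a softmax to each output, stacks the results, takes the per-voxel variance across outputs as the uncertainty map and, with `multi_pred_avg = True`, averages the outputs before the argmax. The following is a minimal NumPy/SciPy restatement of that step, not the repository's code: `fuse_predictions` is a hypothetical name, and it assumes each output is a logit array shaped `[N, C, ...]` with class 1 as the foreground (`class_num = 2` in `mgnet.cfg`).

```python
import numpy as np
from scipy.special import softmax


def fuse_predictions(logits_list, average=True):
    """Fuse several network outputs into labels and a foreground variance map.

    Hypothetical helper restating save_ouputs() in util/custom_net_run.py.
    logits_list: list of arrays shaped [N, C, ...], one per prediction.
    """
    # softmax over the class axis for every prediction, then stack: [K, N, C, ...]
    probs = np.asarray([softmax(x, axis=1) for x in logits_list], np.float32)
    # per-voxel variance across the K predictions is the uncertainty estimate
    var = np.var(probs, axis=0)
    # average the predictions (multi_pred_avg = True) or keep only the first one
    prob = probs.mean(axis=0) if average else probs[0]
    labels = np.asarray(np.argmax(prob, axis=1), np.uint8)
    # keep the variance of the foreground class (index 1), as the script does
    return labels, var[:, 1]


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    outputs = [rng.normal(size=(1, 2, 4, 4)) for _ in range(4)]  # output_num = 4
    labels, uncertainty = fuse_predictions(outputs)
    print(labels.shape, uncertainty.shape)  # (1, 4, 4) (1, 4, 4)
```

When all outputs agree, the variance is zero. It grows where the outputs disagree, which is what the uncertainty map highlights.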
13 | 14 | ![refinement](./pictures/refinement.png) 15 | 16 | Fig. 3. Using I-DRLSE for interactive refinement. 17 | 18 | ### Requirements 19 | Some important required packages include: 20 | * [PyTorch][torch_link] version >=1.0.1. 21 | * [PyMIC][pymic_link], a PyTorch-based toolkit for medical image computing. Version 0.2.3 is required. 22 | * [GeodisTK][geodistk_link], a geodesic distance transform toolkit for 2D and 3D images. 23 | 24 | Follow the official guidance to install [PyTorch][torch_link]. Install the other required packages with: 25 | ``` 26 | pip install -r requirements.txt 27 | ``` 28 | [arxiv_link]:https://arxiv.org/abs/2007.00833 29 | [demo_link]:https://www.youtube.com/watch?v=H_-XQlWsRvM 30 | [torch_link]:https://pytorch.org/ 31 | [pymic_link]:https://github.com/HiLab-git/PyMIC 32 | [geodistk_link]:https://github.com/taigw/GeodisTK 33 | 34 | ### How to use 35 | After installing the required packages, add the path of `UGIR` to the PYTHONPATH environment variable. 36 | ### Demo of MG-Net 37 | 1. Run the following commands to use MG-Net for simultaneous segmentation and uncertainty estimation: 38 | ``` 39 | cd uncertainty_demo 40 | python ../util/custom_net_run.py test config/mgnet.cfg 41 | ``` 42 | 2. The results will be saved to `uncertainty_demo/result`. To visualize the uncertainty estimation on an example slice, run: 43 | ``` 44 | python show_uncertanty.py 45 | ``` 46 | 47 | ### Demo of I-DRLSE 48 | To see a demo of I-DRLSE, run the following commands: 49 | ``` 50 | cd util/level_set 51 | python demo/demo_idrlse.py 52 | ``` 53 | The result should look like the following: 54 | ![i-drlse](./pictures/i-drlse.png) 55 | 56 | ### Copyright and License 57 | Copyright (c) 2020, University of Electronic Science and Technology of China. 58 | All rights reserved. This code is made available as open-source software under the BSD-3-Clause License.
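The I-DRLSE demo builds on the region-based update in `drlse_region()` (`util/level_set/ls_util/drlse_reion.py`). Its area term compares each pixel with the mean intensity of the `phi > 0` and `phi < 0` regions, normalizes by the maximum, and restricts the force to a narrow band around the zero level set through a smoothed Dirac function. Below is a minimal NumPy sketch of that single term. The repository's own `Dirac()` is defined elsewhere in that module, so the `dirac()` here assumes the cosine-smoothed delta of the original DRLSE formulation. `region_area_term` is a hypothetical name, and the zero-division guard is an addition not present in the original.

```python
import numpy as np


def dirac(x, epsilon):
    """Cosine-smoothed Dirac delta (assumed DRLSE form), zero outside |x| <= epsilon."""
    f = (1.0 / (2.0 * epsilon)) * (1.0 + np.cos(np.pi * x / epsilon))
    return f * (np.abs(x) <= epsilon)


def region_area_term(image, phi, epsilon=1.5):
    """Region-based area force restated from drlse_region() (hypothetical helper)."""
    mean_pos = image[phi > 0].mean()  # called mean_in in drlse_region()
    mean_neg = image[phi < 0].mean()  # called mean_out in drlse_region()
    term = np.square(image - mean_pos) - np.square(image - mean_neg)
    scale = abs(term.max()) or 1.0    # added guard: avoid dividing by zero
    return dirac(phi, epsilon) * (term / scale)


if __name__ == "__main__":
    img = np.zeros((4, 4))
    img[:, 2:] = 1.0                    # bright right half
    phi = np.where(img > 0, 1.0, -1.0)  # phi > 0 on the bright half
    print(np.round(region_area_term(img, phi), 3))
```

In `drlse_region()` this term is multiplied by `alfa` and added to the distance-regularization and curvature terms at every iteration.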
59 | -------------------------------------------------------------------------------- /pictures/i-drlse.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/UGIR/d4dfd0a92750a60351dd2e34a4f0a926fb6a757e/pictures/i-drlse.png -------------------------------------------------------------------------------- /pictures/mgnet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/UGIR/d4dfd0a92750a60351dd2e34a4f0a926fb6a757e/pictures/mgnet.png -------------------------------------------------------------------------------- /pictures/refinement.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/UGIR/d4dfd0a92750a60351dd2e34a4f0a926fb6a757e/pictures/refinement.png -------------------------------------------------------------------------------- /pictures/uncertainty.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/UGIR/d4dfd0a92750a60351dd2e34a4f0a926fb6a757e/pictures/uncertainty.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | Cython>=0.29.14 2 | GeodisTK>=0.1.6 3 | matplotlib>=3.1.2 4 | numpy>=1.17.4 5 | pandas>=0.25.3 6 | scikit-image>=0.16.2 7 | scikit-learn>=0.22 8 | scipy>=1.3.3 9 | SimpleITK>=1.2.4 10 | tensorboard>=2.1.0 11 | tensorboardX>=1.9 12 | torchvision>=0.4.2 13 | PYMIC==0.2.3 14 | -------------------------------------------------------------------------------- /uncertainty_demo/config/image_test.csv: -------------------------------------------------------------------------------- 1 | image,label 2 | a26_12.nii.gz, 3 | b15_05.nii.gz, 4 | 
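`image_test.csv` lists one test volume per row and leaves the `label` column empty. In `save_ouputs()` (`util/custom_net_run.py`), each entry is resolved relative to `root_dir`. With the default `filename_ignore_dir = True`, the segmentation is written to `output_dir` under its base name, and the variance map gets a `_var` suffix, which matches the files under `uncertainty_demo/result`. This is a standard-library sketch of that naming only. `plan_outputs` is a hypothetical helper, and it assumes the CSV layout shown here and the `root_dir = ./data` and `output_dir = result` values from `mgnet.cfg`.

```python
import csv
import io


def plan_outputs(csv_text, root_dir="./data", output_dir="result"):
    """Map each CSV row to (input path, segmentation path, variance path).

    Hypothetical helper mirroring the naming in save_ouputs() of
    util/custom_net_run.py with filename_ignore_dir = True.
    """
    plans = []
    for row in csv.DictReader(io.StringIO(csv_text)):
        name = row["image"].strip()
        if not name:
            continue  # ignore blank rows
        input_path = root_dir + "/" + name  # names are relative to root_dir
        save_name = "{0}/{1}".format(output_dir, name.split("/")[-1])
        parts = save_name.split(".")
        # .nii.gz has two dots, so it is handled separately from other formats
        if ".nii.gz" in save_name:
            prefix, fmt = ".".join(parts[:-2]), "nii.gz"
        else:
            prefix, fmt = ".".join(parts[:-1]), parts[-1]
        var_path = "{0}_var.{1}".format(prefix, fmt)
        plans.append((input_path, save_name, var_path))
    return plans


if __name__ == "__main__":
    demo_csv = "image,label\na26_12.nii.gz,\nb15_05.nii.gz,\n"
    for paths in plan_outputs(demo_csv):
        print(paths)
```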
-------------------------------------------------------------------------------- /uncertainty_demo/config/mgnet.cfg: -------------------------------------------------------------------------------- 1 | [dataset] 2 | tensor_type = float 3 | task_type = seg 4 | 5 | #root_dir = /home/disk2t/data/fetalMR/DLLS_brain 6 | root_dir = ./data 7 | train_csv = config/image_train.csv 8 | valid_csv = config/image_valid.csv 9 | test_csv = config/image_test.csv 10 | 11 | load_pixelwise_weight = False 12 | # number of modalities 13 | modal_num = 1 14 | 15 | # data transforms 16 | train_transform = [ChannelWiseThresholdWithNormalize, RandomFlip, RandomRotate, Pad, RandomCrop, LabelToProbability] 17 | test_transform = [ChannelWiseThresholdWithNormalize, Pad] 18 | 19 | ChannelWiseThresholdWithNormalize_threshold_lower = [0] 20 | ChannelWiseThresholdWithNormalize_threshold_upper = [None] 21 | ChannelWiseThresholdWithNormalize_mean_std_mode = True 22 | ChannelWiseThresholdWithNormalize_inverse = False 23 | 24 | RandomFlip_flip_depth = True 25 | RandomFlip_flip_height = True 26 | RandomFlip_flip_width = True 27 | RandomFlip_inverse = True 28 | 29 | RandomRotate_angle_range_d = [-180, 180] 30 | RandomRotate_angle_range_h = None 31 | RandomRotate_angle_range_w = None 32 | RandomRotate_inverse = True 33 | 34 | Pad_output_size = [16, 192, 192] 35 | Pad_ceil_mode = False 36 | Pad_inverse = True 37 | 38 | RandomCrop_output_size = [12, 144, 144] 39 | RandomCrop_foreground_focus = False 40 | RandomCrop_foreground_ratio = None 41 | RandomCrop_mask_label = None 42 | RandomCrop_inverse = False 43 | 44 | LabelToProbability_class_num = 2 45 | LabelToProbability_inverse = False 46 | 47 | [network] 48 | # this section gives parameters for the network 49 | # the keys may differ between networks 50 | 51 | # type of network 52 | net_type = MGNet 53 | 54 | # number of classes, required for segmentation tasks 55 | class_num = 2 56 | in_chns = 1 57 | block_type = UNetBlock 58 | feature_chns = [64, 128, 256,
512, 512] 59 | feature_grps = [ 4, 4, 4, 4, 1] 60 | norm_type = group_norm 61 | acti_func = leakyrelu 62 | leakyrelu_negative_slope = 0.01 63 | dropout = True 64 | depth_sep_deconv = False 65 | deep_supervision = False 66 | 67 | [training] 68 | gpus = [0] 69 | 70 | batch_size = 1 71 | loss_type = MultiScaleDiceLoss 72 | MultiScaleDiceLoss_Enable_Pixel_Weight = False 73 | MultiScaleDiceLoss_Enable_Class_Weight = False 74 | MultiScaleDiceLoss_Scale_Weight = [1.0, 1.0, 1.0, 1.0] 75 | 76 | # for optimizers 77 | optimizer = Adam 78 | learning_rate = 1e-4 79 | momentum = 0.9 80 | weight_decay = 1e-5 81 | 82 | # for the lr scheduler (MultiStepLR) 83 | lr_gamma = 0.5 84 | lr_milestones = [5000, 10000, 15000, 20000, 25000, 30000] 85 | 86 | ckpt_save_dir = exp_uncertain/model/unet2d_mg 87 | ckpt_save_prefix = mgnet 88 | 89 | # start iteration 90 | iter_start = 0 91 | iter_max = 20000 92 | iter_valid = 100 93 | iter_save = 5000 94 | 95 | [testing] 96 | gpus = [0] 97 | 98 | ckpt_mode = 2 99 | ckpt_name = model/mgnet_20000.pt 100 | evaluation_mode = False 101 | multi_pred_avg = True 102 | output_num = 4 103 | 104 | # test-time augmentation mode 105 | tta_mode = 0 106 | infer_sliding_window = False 107 | sliding_window_size = None 108 | sliding_window_stride = None 109 | 110 | label_source = None 111 | label_target = None 112 | 113 | filename_replace_source = None 114 | filename_replace_target = None 115 | 116 | #output_dir = exp_uncertain/result/unet2d_mg/predict 117 | output_dir = result 118 | save_probability = False 119 | save_multi_pred_var = True 120 | 121 | -------------------------------------------------------------------------------- /uncertainty_demo/data/a26_12.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/UGIR/d4dfd0a92750a60351dd2e34a4f0a926fb6a757e/uncertainty_demo/data/a26_12.nii.gz --------------------------------------------------------------------------------
/uncertainty_demo/data/b15_05.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/UGIR/d4dfd0a92750a60351dd2e34a4f0a926fb6a757e/uncertainty_demo/data/b15_05.nii.gz -------------------------------------------------------------------------------- /uncertainty_demo/model/mgnet_20000.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/UGIR/d4dfd0a92750a60351dd2e34a4f0a926fb6a757e/uncertainty_demo/model/mgnet_20000.pt -------------------------------------------------------------------------------- /uncertainty_demo/result/a26_12.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/UGIR/d4dfd0a92750a60351dd2e34a4f0a926fb6a757e/uncertainty_demo/result/a26_12.nii.gz -------------------------------------------------------------------------------- /uncertainty_demo/result/a26_12_var.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/UGIR/d4dfd0a92750a60351dd2e34a4f0a926fb6a757e/uncertainty_demo/result/a26_12_var.nii.gz -------------------------------------------------------------------------------- /uncertainty_demo/result/b15_05.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/UGIR/d4dfd0a92750a60351dd2e34a4f0a926fb6a757e/uncertainty_demo/result/b15_05.nii.gz -------------------------------------------------------------------------------- /uncertainty_demo/result/b15_05_var.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/UGIR/d4dfd0a92750a60351dd2e34a4f0a926fb6a757e/uncertainty_demo/result/b15_05_var.nii.gz -------------------------------------------------------------------------------- 
/uncertainty_demo/result/uncertainty.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/UGIR/d4dfd0a92750a60351dd2e34a4f0a926fb6a757e/uncertainty_demo/result/uncertainty.png -------------------------------------------------------------------------------- /uncertainty_demo/show_uncertanty.py: -------------------------------------------------------------------------------- 1 | import os 2 | import SimpleITK as sitk 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | from scipy.ndimage import gaussian_filter 6 | from PIL import Image 7 | from PIL import ImageFilter 8 | 9 | def add_countor(In, Seg, Color=(0, 255, 0)): 10 | Out = In.copy() 11 | [W, H] = In.size # PIL reports the size as (width, height) 12 | for i in range(W): 13 | for j in range(H): 14 | if(i==0 or i==W-1 or j==0 or j == H-1): 15 | if(Seg.getpixel((i,j))!=0): 16 | Out.putpixel((i,j), Color) 17 | elif(Seg.getpixel((i,j))!=0 and \ 18 | not(Seg.getpixel((i-1,j))!=0 and \ 19 | Seg.getpixel((i+1,j))!=0 and \ 20 | Seg.getpixel((i,j-1))!=0 and \ 21 | Seg.getpixel((i,j+1))!=0)): 22 | Out.putpixel((i,j), Color) 23 | return Out 24 | 25 | def gray_to_rgb(image): 26 | image_cat = np.asarray([image, image, image]) 27 | image_cat = np.transpose(image_cat, [1, 2, 0]) 28 | return image_cat 29 | 30 | 31 | def map_scalar_to_color(x): 32 | x_list = [0.0, 0.25, 0.5, 0.75, 1.0] 33 | c_list = [[0, 0, 255], 34 | [0, 255, 255], 35 | [0, 255, 0], 36 | [255, 255, 0], 37 | [255, 0, 0]] 38 | for i in range(len(x_list) - 1): # interpolate within the segment containing x (expects 0 <= x <= 1) 39 | if(x <= x_list[i + 1]): 40 | x0 = x_list[i] 41 | x1 = x_list[i + 1] 42 | c0 = c_list[i] 43 | c1 = c_list[i + 1] 44 | alpha = (x - x0)/(x1 - x0) 45 | c = [c0[j]*(1 - alpha) + c1[j] * alpha for j in range(3)] 46 | c = [int(item) for item in c] 47 | return tuple(c) 48 | 49 | def get_attention_map(image, att): 50 | [W, H] = image.size 51 | img = Image.new('RGB', image.size, (255, 0, 0)) 52 | 53 | for i in range(W): 54 | for j in range(H): 55 | p0 =
image.getpixel((i,j)) 56 | alpha = att.getpixel((i,j)) 57 | p1 = map_scalar_to_color(alpha) 58 | # alpha = 0.1 + alpha*0.9 59 | p = [int(p0[c] * (1 - alpha) + p1[c]*alpha) for c in range(3)] 60 | p = tuple(p) 61 | img.putpixel((i,j), p) 62 | return img 63 | 64 | 65 | def show_seg_uncertainty(): 66 | img_folder = "data" 67 | seg_folder = "result" 68 | uncertain_folder = "result" 69 | patient_id = "a26_12" 70 | slice_id = 14 71 | img_name = img_folder + '/' + patient_id + ".nii.gz" 72 | seg_name = seg_folder + '/' + patient_id + ".nii.gz" 73 | uncertain_name = uncertain_folder + '/' + patient_id + "_var.nii.gz" 74 | img_obj = sitk.ReadImage(img_name) 75 | seg_obj = sitk.ReadImage(seg_name) 76 | uct_obj = sitk.ReadImage(uncertain_name) 77 | img3d = sitk.GetArrayFromImage(img_obj) 78 | seg3d = sitk.GetArrayFromImage(seg_obj) 79 | uct3d = sitk.GetArrayFromImage(uct_obj) 80 | 81 | img3d = (img3d - img3d.min()) * 255.0 / (img3d.max() - img3d.min()) 82 | img3d = np.asarray(img3d, np.uint8) 83 | uct3d = uct3d / uct3d.max() 84 | 85 | img = img3d[slice_id] 86 | seg = seg3d[slice_id] 87 | uct = uct3d[slice_id] 88 | img_show_raw = gray_to_rgb(img) 89 | img_show_raw = Image.fromarray(img_show_raw) 90 | seg = Image.fromarray(seg) 91 | img_show_seg = add_countor(img_show_raw, seg) 92 | 93 | uct = Image.fromarray(uct) 94 | img_show_uct = get_attention_map(img_show_raw, uct) 95 | fig = plt.figure(figsize=(6, 3)) 96 | plt.subplot(1, 2, 1); plt.axis('off'); plt.title('segmentation result') 97 | plt.imshow(img_show_seg) 98 | plt.subplot(1, 2, 2); plt.axis('off'); plt.title('uncertainty') 99 | plt.imshow(img_show_uct) 100 | plt.show() 101 | fig.savefig('./result/uncertainty.png') 102 | 103 | if __name__ == "__main__": 104 | show_seg_uncertainty() 105 | -------------------------------------------------------------------------------- /util/custom_net_run.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ 
import print_function, division 3 | 4 | import os 5 | import scipy.special 6 | import sys 7 | import time 8 | import torch 9 | import numpy as np 10 | from pymic.io.image_read_write import save_nd_array_as_image 11 | from pymic.util.parse_config import parse_config 12 | from pymic.net_run.agent_seg import SegmentationAgent 13 | from pymic.net_run.infer_func import Inferer 14 | from network.unet2dres import UNet2DRes 15 | from network.MGNet import MGNet 16 | 17 | net_dict = { 18 | 'UNet2DRes': UNet2DRes, 19 | 'MGNet': MGNet 20 | } 21 | 22 | class CustomSegAgent(SegmentationAgent): 23 | def __init__(self, config, stage = 'train'): 24 | super(CustomSegAgent, self).__init__(config, stage) 25 | 26 | def infer(self): 27 | device_ids = self.config['testing']['gpus'] 28 | device = torch.device("cuda:{0:}".format(device_ids[0])) 29 | self.net.to(device) 30 | # load network parameters and optionally switch the network to evaluation mode 31 | checkpoint_name = self.get_checkpoint_name() 32 | checkpoint = torch.load(checkpoint_name, map_location = device) 33 | self.net.load_state_dict(checkpoint['model_state_dict']) 34 | 35 | if(self.config['testing']['evaluation_mode'] == True): 36 | self.net.eval() 37 | if(self.config['testing'].get('test_time_dropout', False) == True): 38 | def test_time_dropout(m): 39 | if(isinstance(m, torch.nn.Dropout)): 40 | print('dropout layer') 41 | m.train() 42 | self.net.apply(test_time_dropout) 43 | 44 | infer_cfg = self.config['testing'] 45 | infer_cfg['class_num'] = self.config['network']['class_num'] 46 | infer_obj = Inferer(self.net, infer_cfg) 47 | infer_time_list = [] 48 | with torch.no_grad(): 49 | for data in self.test_loder: 50 | images = self.convert_tensor_type(data['image']) 51 | images = images.to(device) 52 | 53 | start_time = time.time() 54 | pred = infer_obj.run(images) 55 | if(isinstance(pred, (tuple, list))): 56 | pred = [item.cpu().numpy() for item in pred] 57 | else: 58 | pred = pred.cpu().numpy() 59 | data['predict'] = pred 60 | # inverse transform 61 | for transform
in self.transform_list[::-1]: 62 | if (transform.inverse): 63 | data = transform.inverse_transform_for_prediction(data) 64 | 65 | infer_time = time.time() - start_time 66 | infer_time_list.append(infer_time) 67 | self.save_ouputs(data) 68 | infer_time_list = np.asarray(infer_time_list) 69 | time_avg, time_std = infer_time_list.mean(), infer_time_list.std() 70 | print("testing time {0:} +/- {1:}".format(time_avg, time_std)) 71 | 72 | def save_ouputs(self, data): 73 | output_dir = self.config['testing']['output_dir'] 74 | ignore_dir = self.config['testing'].get('filename_ignore_dir', True) 75 | save_prob = self.config['testing'].get('save_probability', False) 76 | save_var = self.config['testing'].get('save_multi_pred_var', False) 77 | multi_pred_avg = self.config['testing'].get('multi_pred_avg', False) 78 | label_source = self.config['testing'].get('label_source', None) 79 | label_target = self.config['testing'].get('label_target', None) 80 | filename_replace_source = self.config['testing'].get('filename_replace_source', None) 81 | filename_replace_target = self.config['testing'].get('filename_replace_target', None) 82 | if(not os.path.exists(output_dir)): 83 | os.makedirs(output_dir) 84 | 85 | names, pred = data['names'], data['predict'] 86 | if(isinstance(pred, (tuple, list))): 87 | prob_list = [scipy.special.softmax(predi,axis=1) for predi in pred] 88 | prob_stack = np.asarray(prob_list, np.float32) 89 | var = np.var(prob_stack, axis = 0) 90 | if(multi_pred_avg): 91 | prob = np.mean(prob_stack, axis = 0) 92 | else: 93 | prob = prob_list[0] 94 | else: 95 | prob = scipy.special.softmax(pred, axis = 1) 96 | output = np.asarray(np.argmax(prob, axis = 1), np.uint8) 97 | if((label_source is not None) and (label_target is not None)): 98 | output = convert_label(output, label_source, label_target) 99 | # save the output and (optionally) probability predictions 100 | root_dir = self.config['dataset']['root_dir'] 101 | for i in range(len(names)): 102 | save_name =
names[i].split('/')[-1] if ignore_dir else \ 103 | names[i].replace('/', '_') 104 | if((filename_replace_source is not None) and (filename_replace_target is not None)): 105 | save_name = save_name.replace(filename_replace_source, filename_replace_target) 106 | print(save_name) 107 | save_name = "{0:}/{1:}".format(output_dir, save_name) 108 | save_nd_array_as_image(output[i], save_name, root_dir + '/' + names[i]) 109 | save_name_split = save_name.split('.') 110 | 111 | if('.nii.gz' in save_name): 112 | save_prefix = '.'.join(save_name_split[:-2]) 113 | save_format = 'nii.gz' 114 | else: 115 | save_prefix = '.'.join(save_name_split[:-1]) 116 | save_format = save_name_split[-1] 117 | 118 | if(save_prob): 119 | class_num = prob.shape[1] 120 | for c in range(0, class_num): 121 | temp_prob = prob[i][c] 122 | prob_save_name = "{0:}_prob_{1:}.{2:}".format(save_prefix, c, save_format) 123 | if(len(temp_prob.shape) == 2): 124 | temp_prob = np.asarray(temp_prob * 255, np.uint8) 125 | save_nd_array_as_image(temp_prob, prob_save_name, root_dir + '/' + names[i]) 126 | 127 | if(save_var and isinstance(pred, (tuple, list))): # var only exists for multiple predictions 128 | var_i = var[i][1] # variance of the foreground-class probability for case i 129 | var_save_name = "{0:}_var.{1:}".format(save_prefix, save_format) 130 | save_nd_array_as_image(var_i, var_save_name, root_dir + '/' + names[i]) 131 | 132 | def main(): 133 | if(len(sys.argv) < 3): 134 | print('Number of arguments should be 3.
e.g.') 135 | print(' python custom_net_run.py train config.cfg') 136 | sys.exit(1) 137 | stage = str(sys.argv[1]) 138 | cfg_file = str(sys.argv[2]) 139 | config = parse_config(cfg_file) 140 | 141 | # use a customized CNN 142 | agent = CustomSegAgent(config, stage) 143 | net_name = config['network']['net_type'] 144 | if(net_name in net_dict): 145 | net = net_dict[net_name](config['network']) 146 | agent.set_network(net) 147 | agent.run() 148 | else: 149 | raise ValueError("undefined network {0:}".format(net_name)) 150 | 151 | if __name__ == "__main__": 152 | main() 153 | -------------------------------------------------------------------------------- /util/level_set/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/UGIR/d4dfd0a92750a60351dd2e34a4f0a926fb6a757e/util/level_set/__init__.py -------------------------------------------------------------------------------- /util/level_set/data/a03_04_11img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/UGIR/d4dfd0a92750a60351dd2e34a4f0a926fb6a757e/util/level_set/data/a03_04_11img.png -------------------------------------------------------------------------------- /util/level_set/data/a03_04_11scrb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/UGIR/d4dfd0a92750a60351dd2e34a4f0a926fb6a757e/util/level_set/data/a03_04_11scrb.png -------------------------------------------------------------------------------- /util/level_set/data/a03_04_11seg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/UGIR/d4dfd0a92750a60351dd2e34a4f0a926fb6a757e/util/level_set/data/a03_04_11seg.png -------------------------------------------------------------------------------- /util/level_set/data/a10_12_22img.png:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/UGIR/d4dfd0a92750a60351dd2e34a4f0a926fb6a757e/util/level_set/data/a10_12_22img.png -------------------------------------------------------------------------------- /util/level_set/data/a10_12_22scrb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/UGIR/d4dfd0a92750a60351dd2e34a4f0a926fb6a757e/util/level_set/data/a10_12_22scrb.png -------------------------------------------------------------------------------- /util/level_set/data/a10_12_22seg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/UGIR/d4dfd0a92750a60351dd2e34a4f0a926fb6a757e/util/level_set/data/a10_12_22seg.png -------------------------------------------------------------------------------- /util/level_set/data/gourd.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/UGIR/d4dfd0a92750a60351dd2e34a4f0a926fb6a757e/util/level_set/data/gourd.bmp -------------------------------------------------------------------------------- /util/level_set/demo/demo_idrlse.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | import scipy.ndimage as filters # scipy.ndimage.filters is deprecated 6 | from PIL import Image 7 | from scipy import ndimage 8 | from skimage import measure 9 | from mpl_toolkits.mplot3d import Axes3D 10 | from level_set.ls_util.interactive_ls import * 11 | 12 | def refine_dlls(image_name, seg_name, seed_name, display = True, intensity = False): 13 | # read the images as grayscale and scale the initial segmentation to [0, 1] 14 | img = Image.open(image_name).convert('L') 15 | seg = Image.open(seg_name).convert('L') 16 | seg = np.asarray(seg, np.float32)/255.0 17 |
seed = Image.open(seed_name).convert('L') 18 | seed = np.asarray(seed) 19 | seed_f = seed == 127 20 | seed_b = seed == 255 21 | 22 | params = {} 23 | params['mu'] = 0.003 24 | params['lambda'] = 0.3 25 | params['alpha'] = 0.1 26 | params['beta'] = 0.5 27 | new_seg, runtime = interactive_level_set(img, seg, seed_f, seed_b, params, display, intensity) 28 | 29 | return new_seg, runtime 30 | 31 | def get_result_for_one_case(): 32 | data_root = 'data/' 33 | img_name = 'a03_04_11' # alternative case: 'a10_12_22' 34 | img_full_name = data_root + "{0:}img.png".format(img_name) 35 | seg_full_name = data_root + "{0:}seg.png".format(img_name) 36 | scrb_full_name = data_root + "{0:}scrb.png".format(img_name) 37 | refine_dlls(img_full_name, seg_full_name, scrb_full_name, intensity = False) 38 | 39 | if __name__ == "__main__": 40 | get_result_for_one_case() -------------------------------------------------------------------------------- /util/level_set/ls_util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/UGIR/d4dfd0a92750a60351dd2e34a4f0a926fb6a757e/util/level_set/ls_util/__init__.py -------------------------------------------------------------------------------- /util/level_set/ls_util/drlse_reion.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.ndimage as filters # scipy.ndimage.filters is deprecated 3 | 4 | 5 | def del2(M): 6 | dx = 1 7 | dy = 1 8 | rows, cols = M.shape 9 | dx = dx * np.ones((1, cols - 1)) 10 | dy = dy * np.ones((rows - 1, 1)) 11 | 12 | mr, mc = M.shape 13 | D = np.zeros((mr, mc)) 14 | 15 | if (mc >= 3): 16 | ## x direction 17 | ## left and right boundary 18 | D[:, 0] = (M[:, 0] - 2 * M[:, 1] + M[:, 2]) / (dx[:, 0] * dx[:, 1]) 19 | D[:, mc - 1] = (M[:, mc - 3] - 2 * M[:, mc - 2] + M[:, mc - 1]) \ 20 | / (dx[:, mc - 3] * dx[:, mc - 2]) 21 | 22 | ## interior points 23 | tmp1 = D[:, 1:mc - 1] 24 | tmp2 = (M[:, 2:mc] - 2 * M[:, 1:mc - 1] + M[:,
0:mc - 2]) 25 | tmp3 = np.kron(dx[:, 0:mc - 2] * dx[:, 1:mc - 1], np.ones((mr, 1))) 26 | D[:, 1:mc - 1] = tmp1 + tmp2 / tmp3 27 | 28 | if (mr >= 3): 29 | ## y direction 30 | ## top and bottom boundary 31 | D[0, :] = D[0, :] + \ 32 | (M[0, :] - 2 * M[1, :] + M[2, :]) / (dy[0, :] * dy[1, :]) 33 | 34 | D[mr - 1, :] = D[mr - 1, :] \ 35 | + (M[mr - 3, :] - 2 * M[mr - 2, :] + M[mr - 1, :]) \ 36 | / (dy[mr - 3, :] * dy[mr - 2, :]) 37 | 38 | ## interior points 39 | tmp1 = D[1:mr - 1, :] 40 | tmp2 = (M[2:mr, :] - 2 * M[1:mr - 1, :] + M[0:mr - 2, :]) 41 | tmp3 = np.kron(dy[0:mr - 2, :] * dy[1:mr - 1, :], np.ones((1, mc))) 42 | D[1:mr - 1, :] = tmp1 + tmp2 / tmp3 43 | 44 | return D / 4 45 | 46 | 47 | def drlse_edge(phi_0, g, lmda, mu, alfa, epsilon, timestep, iters, potentialFunction): # returns the updated level set function 48 | """ 49 | 50 | :param phi_0: level set function to be updated by level set evolution 51 | :param g: edge indicator function 52 | :param lmda: weight of the weighted length term 53 | :param mu: weight of distance regularization term 54 | :param alfa: weight of the weighted area term 55 | :param epsilon: width of Dirac Delta function 56 | :param timestep: time step 57 | :param iters: number of iterations 58 | :param potentialFunction: choice of potential function in distance regularization term. 59 | As described in the DRLSE paper (Li et al., IEEE TIP 2010), two choices are provided: potentialFunction='single-well' or 60 | potentialFunction='double-well', which correspond to the potential functions p1 (single-well) 61 | and p2 (double-well), respectively.
62 | """ 63 | phi = phi_0.copy() 64 | [vy, vx] = np.gradient(g) 65 | for k in range(iters): 66 | phi = NeumannBoundCond(phi) 67 | [phi_y, phi_x] = np.gradient(phi) 68 | s = np.sqrt(np.square(phi_x) + np.square(phi_y)) 69 | smallNumber = 1e-10 70 | Nx = phi_x / (s + smallNumber) # add a small positive number to avoid division by zero 71 | Ny = phi_y / (s + smallNumber) 72 | curvature = div(Nx, Ny) 73 | if potentialFunction == 'single-well': 74 | distRegTerm = filters.laplace(phi, mode='wrap') - curvature # compute distance regularization term in equation (13) with the single-well potential p1. 75 | elif potentialFunction == 'double-well': 76 | distRegTerm = distReg_p2(phi) # compute the distance regularization term in equation (13) with the double-well potential p2. 77 | else: 78 | raise ValueError('Wrong choice of potential function. Please pass the string "single-well" or "double-well" to drlse_edge.') 79 | diracPhi = Dirac(phi, epsilon) 80 | areaTerm = diracPhi * g # balloon/pressure force 81 | edgeTerm = diracPhi * (vx * Nx + vy * Ny) + diracPhi * g * curvature 82 | phi = phi + timestep * (mu * distRegTerm + lmda * edgeTerm + alfa * areaTerm) 83 | return phi 84 | 85 | def drlse_region(phi_0, I, lmda, mu, alfa, epsilon, timestep, iters, potentialFunction): # returns the updated level set function 86 | """ 87 | 88 | :param phi_0: level set function to be updated by level set evolution 89 | :param I: the input image to be segmented 90 | :param lmda: weight of the length (curvature) term 91 | :param mu: weight of distance regularization term 92 | :param alfa: weight of the region-based area term 93 | :param epsilon: width of Dirac Delta function 94 | :param timestep: time step 95 | :param iters: number of iterations 96 | :param potentialFunction: choice of potential function in distance regularization term.
97 | % As mentioned in the above paper, two choices are provided: potentialFunction='single-well' or 98 | % potentialFunction='double-well', which correspond to the potential functions p1 (single-well) 99 | % and p2 (double-well), respectively. 100 | """ 101 | phi = phi_0.copy() 102 | # [vy, vx] = np.gradient(g) 103 | for k in range(iters): 104 | phi = NeumannBoundCond(phi) 105 | [phi_y, phi_x] = np.gradient(phi) 106 | s = np.sqrt(np.square(phi_x) + np.square(phi_y)) 107 | smallNumber = 1e-10 108 | Nx = phi_x / (s + smallNumber) # add a small positive number to avoid division by zero 109 | Ny = phi_y / (s + smallNumber) 110 | curvature = div(Nx, Ny) 111 | if potentialFunction == 'single-well': 112 | distRegTerm = filters.laplace(phi, mode='wrap') - curvature # compute distance regularization term in equation (13) with the single-well potential p1. 113 | elif potentialFunction == 'double-well': 114 | distRegTerm = distReg_p2(phi) # compute the distance regularization term in equation (13) with the double-well potential p2. 115 | else: 116 | raise ValueError('Wrong choice of potential function.
Please input the string "single-well" or "double-well" in the drlse_region function.') 117 | diracPhi = Dirac(phi, epsilon) 118 | # areaTerm = diracPhi * g # balloon/pressure force 119 | # edgeTerm = diracPhi * (vx * Nx + vy * Ny) + diracPhi * g * curvature 120 | mean_in = I[phi>0].mean() 121 | mean_out = I[phi<0].mean() 122 | areaTerm = np.square(I - mean_in) - np.square(I - mean_out) 123 | areaTerm = diracPhi * (areaTerm/abs(areaTerm.max())) 124 | edgeTerm = diracPhi * curvature 125 | phi = phi + timestep * (mu * distRegTerm + lmda * edgeTerm + alfa * areaTerm) 126 | return phi 127 | 128 | def drlse_region_interaction(phi_0, I, P,lmda, mu, alfa, beta, 129 | epsilon, timestep, iters, potentialFunction): # Updated Level Set Function 130 | """ 131 | 132 | :param phi_0: level set function to be updated by level set evolution 133 | :param I: the input image to be segmented 134 | :param P: the probability of being foreground, based on user interaction 135 | :param lmda: weight of the weighted length term 136 | :param mu: weight of distance regularization term 137 | :param alfa: weight of the weighted area term 138 | :param beta: weight of user interaction term 139 | :param epsilon: width of Dirac Delta function 140 | :param timestep: time step 141 | :param iters: number of iterations 142 | :param potentialFunction: choice of potential function in distance regularization term. 143 | % As mentioned in the above paper, two choices are provided: potentialFunction='single-well' or 144 | % potentialFunction='double-well', which correspond to the potential functions p1 (single-well) 145 | % and p2 (double-well), respectively.
146 | """ 147 | phi = phi_0.copy() 148 | # lu: log-odds of the interaction-based foreground probability P 149 | lu = np.log(P) - np.log(1.0-P) 150 | # [vy, vx] = np.gradient(g) 151 | for k in range(iters): 152 | phi = NeumannBoundCond(phi) 153 | [phi_y, phi_x] = np.gradient(phi) 154 | s = np.sqrt(np.square(phi_x) + np.square(phi_y)) 155 | smallNumber = 1e-10 156 | Nx = phi_x / (s + smallNumber) # add a small positive number to avoid division by zero 157 | Ny = phi_y / (s + smallNumber) 158 | curvature = div(Nx, Ny) 159 | if potentialFunction == 'single-well': 160 | distRegTerm = filters.laplace(phi, mode='wrap') - curvature # compute distance regularization term in equation (13) with the single-well potential p1. 161 | elif potentialFunction == 'double-well': 162 | distRegTerm = distReg_p2(phi) # compute the distance regularization term in equation (13) with the double-well potential p2. 163 | else: 164 | raise ValueError('Wrong choice of potential function. Please input the string "single-well" or "double-well" in the drlse_region_interaction function.') 165 | diracPhi = Dirac(phi, epsilon) 166 | # areaTerm = diracPhi * g # balloon/pressure force 167 | # edgeTerm = diracPhi * (vx * Nx + vy * Ny) + diracPhi * g * curvature 168 | mean_in = I[phi>0].mean() 169 | mean_out = I[phi<0].mean() 170 | 171 | areaTerm = np.square(I - mean_in) - np.square(I - mean_out) 172 | areaTerm = diracPhi * (areaTerm/abs(areaTerm.max())) 173 | 174 | edgeTerm = diracPhi * curvature 175 | phi = phi + timestep * (mu * distRegTerm + lmda * edgeTerm - alfa * areaTerm + beta * lu) 176 | return phi 177 | 178 | def drlse_region_interaction2(phi_0, I, D,lmda, mu, alfa, 179 | epsilon, timestep, iters, potentialFunction): # Updated Level Set Function 180 | """ 181 | 182 | :param phi_0: level set function to be updated by level set evolution 183 | :param I: the input image to be segmented 184 | :param D: the distance transform of seed points 185 | :param lmda: weight of the weighted length term 186 | :param mu: weight of distance regularization
term 187 | :param alfa: weight of the weighted area term 188 | (no beta here: D is used in place of the edge indicator g) 189 | :param epsilon: width of Dirac Delta function 190 | :param timestep: time step 191 | :param iters: number of iterations 192 | :param potentialFunction: choice of potential function in distance regularization term. 193 | % As mentioned in the above paper, two choices are provided: potentialFunction='single-well' or 194 | % potentialFunction='double-well', which correspond to the potential functions p1 (single-well) 195 | % and p2 (double-well), respectively. 196 | """ 197 | phi = phi_0.copy() 198 | g = D 199 | # lu = (P - 0.5)*(P - 0.5)*(P - ) 200 | # lu = np.log(P) - np.log(1-P) 201 | [vy, vx] = np.gradient(g) 202 | for k in range(iters): 203 | phi = NeumannBoundCond(phi) 204 | [phi_y, phi_x] = np.gradient(phi) 205 | s = np.sqrt(np.square(phi_x) + np.square(phi_y)) 206 | smallNumber = 1e-10 207 | Nx = phi_x / (s + smallNumber) # add a small positive number to avoid division by zero 208 | Ny = phi_y / (s + smallNumber) 209 | curvature = div(Nx, Ny) 210 | if potentialFunction == 'single-well': 211 | distRegTerm = filters.laplace(phi, mode='wrap') - curvature # compute distance regularization term in equation (13) with the single-well potential p1. 212 | elif potentialFunction == 'double-well': 213 | distRegTerm = distReg_p2(phi) # compute the distance regularization term in equation (13) with the double-well potential p2. 214 | else: 215 | raise ValueError('Wrong choice of potential function.
Please input the string "single-well" or "double-well" in the drlse_region_interaction2 function.') 216 | diracPhi = Dirac(phi, epsilon) 217 | # areaTerm = diracPhi * g # balloon/pressure force 218 | edgeTerm = diracPhi * (vx * Nx + vy * Ny) + diracPhi * g * curvature 219 | mean_in = I[phi>0].mean() 220 | mean_out = I[phi<0].mean() 221 | areaTerm = np.square(I - mean_in) - np.square(I - mean_out) 222 | areaTerm = diracPhi * (areaTerm/abs(areaTerm.max())) 223 | # edgeTerm = diracPhi * curvature 224 | phi = phi + timestep * (mu * distRegTerm + lmda * edgeTerm - alfa * areaTerm ) 225 | return phi 226 | 227 | def drlse_region_edge_interaction(phi_0, I, G, P, lmda, mu, alfa, beta, 228 | epsilon, timestep, iters, potentialFunction): # Updated Level Set Function 229 | """ 230 | 231 | :param phi_0: level set function to be updated by level set evolution 232 | :param I: the input image to be segmented 233 | :param G: edge indicator computed from the image gradient (used like g in drlse_edge) 234 | :param P: the probability of being foreground, based on user interaction 235 | :param lmda: weight of the weighted length term (edge term) 236 | :param mu: weight of distance regularization term 237 | :param alfa: weight of the weighted area term 238 | :param beta: weight of user interaction term 239 | :param epsilon: width of Dirac Delta function 240 | :param timestep: time step 241 | :param iters: number of iterations 242 | :param potentialFunction: choice of potential function in distance regularization term. 243 | % As mentioned in the above paper, two choices are provided: potentialFunction='single-well' or 244 | % potentialFunction='double-well', which correspond to the potential functions p1 (single-well) 245 | % and p2 (double-well), respectively.
246 | """ 247 | phi = phi_0.copy() 248 | lu = np.log(P) - np.log(1-P) 249 | [vy, vx] = np.gradient(G) 250 | for k in range(iters): 251 | phi = NeumannBoundCond(phi) 252 | [phi_y, phi_x] = np.gradient(phi) 253 | s = np.sqrt(np.square(phi_x) + np.square(phi_y)) 254 | smallNumber = 1e-10 255 | Nx = phi_x / (s + smallNumber) # add a small positive number to avoid division by zero 256 | Ny = phi_y / (s + smallNumber) 257 | curvature = div(Nx, Ny) 258 | if potentialFunction == 'single-well': 259 | distRegTerm = filters.laplace(phi, mode='wrap') - curvature # compute distance regularization term in equation (13) with the single-well potential p1. 260 | elif potentialFunction == 'double-well': 261 | distRegTerm = distReg_p2(phi) # compute the distance regularization term in equation (13) with the double-well potential p2. 262 | else: 263 | raise ValueError('Wrong choice of potential function. Please input the string "single-well" or "double-well" in the drlse_region_edge_interaction function.') 264 | diracPhi = Dirac(phi, epsilon) 265 | # areaTerm = diracPhi * g # balloon/pressure force 266 | edgeTerm = diracPhi * (vx * Nx + vy * Ny) + diracPhi * G * curvature 267 | # edgeTerm = diracPhi * curvature 268 | mean_in = I[phi>0].mean() 269 | mean_out = I[phi<0].mean() 270 | areaTerm = np.square(I - mean_in) - np.square(I - mean_out) 271 | areaTerm = diracPhi * (areaTerm/abs(areaTerm.max())) 272 | # areaTerm = diracPhi * G 273 | phi = phi + timestep * (mu * distRegTerm + lmda * edgeTerm - alfa * areaTerm + beta * lu) 274 | return phi 275 | 276 | def distReg_p2(phi): 277 | """ 278 | compute the distance regularization term with the double-well potential p2 in equation (16) 279 | """ 280 | [phi_y, phi_x] = np.gradient(phi) 281 | s = np.sqrt(np.square(phi_x) + np.square(phi_y)) 282 | a = (s >= 0) & (s <= 1) 283 | b = (s > 1) 284 | ps = a * np.sin(2 * np.pi * s) / (2 * np.pi) + b * (s - 1) # compute first order derivative of the double-well potential p2 in equation (16) 285 | dps = ((ps != 0) * ps + (ps
== 0)) / ((s != 0) * s + (s == 0)) # compute d_p(s)=p'(s)/s in equation (10). As s-->0, we have d_p(s)-->1 according to equation (18) 286 | return div(dps * phi_x - phi_x, dps * phi_y - phi_y) + filters.laplace(phi, mode='wrap') 287 | 288 | 289 | def div(nx, ny): 290 | [junk, nxx] = np.gradient(nx) 291 | [nyy, junk] = np.gradient(ny) 292 | return nxx + nyy 293 | 294 | 295 | def Dirac(x, sigma): 296 | f = (1.0 / 2 / sigma) * (1 + np.cos(np.pi * x / sigma)) 297 | b = (x <= sigma) & (x >= -sigma) 298 | return f * b 299 | 300 | 301 | def NeumannBoundCond(f): 302 | """ 303 | Make a function satisfy Neumann boundary condition 304 | """ 305 | [ny, nx] = f.shape 306 | g = f.copy() 307 | 308 | g[0, 0] = g[2, 2] 309 | g[0, nx-1] = g[2, nx-3] 310 | g[ny-1, 0] = g[ny-3, 2] 311 | g[ny-1, nx-1] = g[ny-3, nx-3] 312 | 313 | g[0, 1:-1] = g[2, 1:-1] 314 | g[ny-1, 1:-1] = g[ny-3, 1:-1] 315 | 316 | g[1:-1, 0] = g[1:-1, 2] 317 | g[1:-1, nx-1] = g[1:-1, nx-3] 318 | 319 | return g 320 | -------------------------------------------------------------------------------- /util/level_set/ls_util/get_gradient.py: -------------------------------------------------------------------------------- 1 | 2 | from PIL import Image 3 | import scipy.ndimage.filters as filters 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | 7 | img_name = 'drlse_region/a13_08_35seg_cut_edge.png' 8 | img = Image.open(img_name).convert('L') 9 | img = np.asarray(img, np.float32) 10 | img = (img - img.mean())/img.std() 11 | 12 | sigma = 1.5 # scale parameter in Gaussian kernel 13 | img_smooth = filters.gaussian_filter(img, sigma) # smooth image by Gaussian convolution 14 | [Iy, Ix] = np.gradient(img_smooth) 15 | f = np.square(Ix) + np.square(Iy) 16 | g = 1 / (1+f) # edge indicator function. 
17 | plt.imshow(g) 18 | plt.show() -------------------------------------------------------------------------------- /util/level_set/ls_util/interactive_ls.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os 3 | import GeodisTK 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | import scipy.ndimage.filters as filters 7 | from PIL import Image 8 | from scipy import ndimage 9 | from skimage import measure 10 | from mpl_toolkits.mplot3d import Axes3D 11 | from level_set.ls_util.drlse_reion import * 12 | 13 | def show_leve_set(fig, phi): 14 | ax1 = fig.add_subplot(111, projection='3d') 15 | y, x = phi.shape 16 | x = np.arange(0, x, 1) 17 | y = np.arange(0, y, 1) 18 | X, Y = np.meshgrid(x, y) 19 | ax1.plot_surface(X, Y, phi, rstride=2, cstride=2, color='r', linewidth=0, alpha=0.6, antialiased=True) 20 | ax1.contour(X, Y, phi, 0, colors='g', linewidths=2) 21 | 22 | def show_image_and_segmentation(fig, img, contours, seeds = None): 23 | ax2 = fig.add_subplot(111) 24 | ax2.imshow(img, interpolation='nearest', cmap=plt.cm.gray) 25 | for n, contour in enumerate(contours): 26 | ax2.plot(contour[:, 1], contour[:, 0], linewidth=2, color='green') 27 | if(seeds is not None): 28 | h_idx, w_idx = np.where(seeds[0] > 0) 29 | ax2.plot(w_idx, h_idx, linewidth=2, color='red') 30 | h_idx, w_idx = np.where(seeds[1] > 0) 31 | ax2.plot(w_idx, h_idx, linewidth=2, color='blue') 32 | ax2.axis('off') 33 | 34 | def get_distance_based_likelihood(img, seed, D): 35 | if(seed.sum() > 0): 36 | geoD = GeodisTK.geodesic2d_raster_scan(img, seed, 0.1, 2) 37 | geoD[geoD > D] = D 38 | else: 39 | geoD = np.ones_like(img)*D 40 | geoD = np.exp(-geoD) 41 | return geoD 42 | 43 | def interactive_level_set(img, seg, seed_f, seed_b, param, display = True, intensity = False): 44 | """ 45 | Refine an initial segmentation with an interaction-based level set 46 | Params: 47 | img: a 2D image array 48 | seg: a 2D image array representing the initial
binary segmentation 49 | seed_f: a binary array representing the existence of foreground scribbles 50 | seed_b: a binary array representing the existence of background scribbles 51 | display: a bool value, whether to display the segmentation result 52 | intensity: a bool value, whether to define the region term on the image intensity (True) or on the initial segmentation (False) 53 | """ 54 | img = np.asarray(img, np.float32) 55 | img = (img - img.mean())/img.std() 56 | seg = np.asarray(seg, np.float32) 57 | Df = get_distance_based_likelihood(img, seed_f, 4) 58 | Db = get_distance_based_likelihood(img, seed_b, 4) 59 | 60 | Pfexp = np.exp(Df); Pbexp = np.exp(Db) 61 | Pf = Pfexp / (Pfexp + Pbexp) 62 | # if(display): 63 | # plt.subplot(1,3,1) 64 | # plt.imshow(Df) 65 | # plt.subplot(1,3,2) 66 | # plt.imshow(Db) 67 | # plt.subplot(1,3,3) 68 | # plt.imshow(Pf) 69 | # plt.show() 70 | 71 | [H, D] = img.shape 72 | zoom = [64.0/H, 64.0/D] 73 | img_d = ndimage.interpolation.zoom(img, zoom) 74 | seg_d = ndimage.interpolation.zoom(seg, zoom) 75 | Pf_d = ndimage.interpolation.zoom(Pf, zoom) 76 | if(intensity is True): 77 | print("use intensity") 78 | ls_img = img_d 79 | else: 80 | print("use segmentation") 81 | ls_img = seg_d 82 | 83 | # parameters 84 | timestep = 1 # time step 85 | iter_inner = 50 86 | iter_outer_max = 10 87 | mu = param['mu']/timestep # coefficient of the distance regularization term R(phi) 88 | lmda = param['lambda'] # coefficient of the weighted length term L(phi) 89 | alfa = param['alpha'] # coefficient of the weighted area term A(phi) 90 | beta = param['beta'] # coefficient for user interaction term 91 | epsilon = 1.5 # parameter that specifies the width of the DiracDelta function 92 | # initialize LSF as binary step function 93 | # the level set has positive value inside the contour and negative value outside 94 | # this is opposite to DRLSE 95 | c0 = 20 96 | initialLSF = -c0 * np.ones(seg_d.shape) 97 | initialLSF[seg_d > 0.5] = c0 98 | phi = initialLSF.copy() 99 | 100 | t0 = time.time() 101 | # start level
set evolution 102 | seg_size0 = np.asarray(phi > 0).sum() 103 | for n in range(iter_outer_max): 104 | phi = drlse_region_interaction(phi, ls_img, Pf_d, lmda, mu, alfa, beta, epsilon, timestep, iter_inner, 'double-well') 105 | seg_size = np.asarray(phi > 0).sum() 106 | ratio = (seg_size - seg_size0)/float(seg_size0) 107 | if(abs(ratio) < 1e-3): 108 | print('iteration', n*iter_inner, ratio) 109 | break 110 | else: 111 | seg_size0 = seg_size 112 | runtime = time.time() - t0 113 | print('iteration', (n + 1)*iter_inner) 114 | print('running time', runtime) 115 | 116 | 117 | finalLSF = phi.copy() 118 | finalLSF = ndimage.interpolation.zoom(finalLSF, [1.0/item for item in zoom]) 119 | if(display): 120 | plt.ion() 121 | fig1 = plt.figure(1) 122 | fig2 = plt.figure(2) 123 | fig3 = plt.figure(3) 124 | 125 | fig1.clf() 126 | init_contours = measure.find_contours(seg, 0.5) 127 | show_image_and_segmentation(fig1, img, init_contours, [seed_f, seed_b]) 128 | fig1.suptitle("(a) Initial Segmentation") 129 | # fig1.savefig("init_seg.png") 130 | 131 | fig2.clf() 132 | final_contours = measure.find_contours(finalLSF, 0) 133 | show_image_and_segmentation(fig2, img, final_contours) 134 | fig2.suptitle("(b) Refined Result") 135 | # fig2.savefig("refine_seg.png") 136 | 137 | fig3.clf() 138 | show_leve_set(fig3, finalLSF) 139 | fig3.suptitle("(c) Final Level Set Function") 140 | # fig3.savefig("levelset_func.png") 141 | plt.pause(10) 142 | plt.show() 143 | return finalLSF > 0, runtime 144 | 145 | -------------------------------------------------------------------------------- /util/network/MGNet.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | from pymic.layer.activation import get_acti_func 6 | from pymic.layer.convolution import ConvolutionLayer, DepthSeperableConvolutionLayer 7 | from pymic.layer.deconvolution import DeconvolutionLayer, 
DepthSeperableDeconvolutionLayer 8 | from network.unet2dres import get_acti_func, get_deconv_layer, get_unet_block 9 | 10 | def interleaved_concate(f1, f2): 11 | f1_shape = list(f1.shape) 12 | f2_shape = list(f2.shape) 13 | c1 = f1_shape[1] 14 | c2 = f2_shape[1] 15 | 16 | f1_shape_new = f1_shape[:1] + [c1, 1] + f1_shape[2:] 17 | f2_shape_new = f2_shape[:1] + [c2, 1] + f2_shape[2:] 18 | 19 | f1_reshape = torch.reshape(f1, f1_shape_new) 20 | f2_reshape = torch.reshape(f2, f2_shape_new) 21 | output = torch.cat((f1_reshape, f2_reshape), dim = 2) 22 | out_shape = f1_shape[:1] + [c1 + c2] + f1_shape[2:] 23 | output = torch.reshape(output, out_shape) 24 | return output 25 | 26 | class MGNet(nn.Module): 27 | def __init__(self, params): 28 | super(MGNet, self).__init__() 29 | self.params = params 30 | self.in_chns = self.params['in_chns'] 31 | self.ft_chns = self.params['feature_chns'] 32 | self.ft_groups = self.params['feature_grps'] 33 | self.norm_type = self.params['norm_type'] 34 | self.block_type= self.params['block_type'] 35 | self.n_class = self.params['class_num'] 36 | self.acti_func = self.params['acti_func'] 37 | self.dropout = self.params['dropout'] 38 | self.depth_sep_deconv= self.params['depth_sep_deconv'] 39 | self.deep_spv = self.params['deep_supervision'] 40 | self.resolution_level = len(self.ft_chns) 41 | assert(self.resolution_level == 5 or self.resolution_level == 4) 42 | 43 | Block = get_unet_block(self.block_type) 44 | self.block1 = Block(self.in_chns, self.ft_chns[0], self.norm_type, self.ft_groups[0], 45 | self.acti_func, self.params) 46 | 47 | self.block2 = Block(self.ft_chns[0], self.ft_chns[1], self.norm_type, self.ft_groups[1], 48 | self.acti_func, self.params) 49 | 50 | self.block3 = Block(self.ft_chns[1], self.ft_chns[2], self.norm_type, self.ft_groups[2], 51 | self.acti_func, self.params) 52 | 53 | self.block4 = Block(self.ft_chns[2], self.ft_chns[3], self.norm_type, self.ft_groups[3], 54 | self.acti_func, self.params) 55 | 56 | 
if(self.resolution_level == 5): 57 | self.block5 = Block(self.ft_chns[3], self.ft_chns[4], self.norm_type, self.ft_groups[4], 58 | self.acti_func, self.params) 59 | 60 | self.block6 = Block(self.ft_chns[3] * 2, self.ft_chns[3], self.norm_type, self.ft_groups[3], 61 | self.acti_func, self.params) 62 | 63 | self.block7 = Block(self.ft_chns[2] * 2, self.ft_chns[2], self.norm_type, self.ft_groups[2], 64 | self.acti_func, self.params) 65 | 66 | self.block8 = Block(self.ft_chns[1] * 2, self.ft_chns[1], self.norm_type, self.ft_groups[1], 67 | self.acti_func, self.params) 68 | 69 | self.block9= Block(self.ft_chns[0] * 2, self.ft_chns[0], self.norm_type, self.ft_groups[0], 70 | self.acti_func, self.params) 71 | 72 | 73 | self.down1 = nn.MaxPool2d(kernel_size = 2) 74 | self.down2 = nn.MaxPool2d(kernel_size = 2) 75 | self.down3 = nn.MaxPool2d(kernel_size = 2) 76 | 77 | DeconvLayer = get_deconv_layer(self.depth_sep_deconv) 78 | if(self.resolution_level == 5): 79 | self.down4 = nn.MaxPool2d(kernel_size = 2) 80 | self.up1 = DeconvLayer(self.ft_chns[4], self.ft_chns[3], kernel_size = 2, 81 | dim = 2, stride = 2, groups = self.ft_groups[3], acti_func = get_acti_func(self.acti_func, self.params)) 82 | self.up2 = DeconvLayer(self.ft_chns[3], self.ft_chns[2], kernel_size = 2, 83 | dim = 2, stride = 2, groups = self.ft_groups[2], acti_func = get_acti_func(self.acti_func, self.params)) 84 | self.up3 = DeconvLayer(self.ft_chns[2], self.ft_chns[1], kernel_size = 2, 85 | dim = 2, stride = 2, groups = self.ft_groups[1], acti_func = get_acti_func(self.acti_func, self.params)) 86 | self.up4 = DeconvLayer(self.ft_chns[1], self.ft_chns[0], kernel_size = 2, 87 | dim = 2, stride = 2, groups = self.ft_groups[0], acti_func = get_acti_func(self.acti_func, self.params)) 88 | 89 | if(self.dropout): 90 | self.drop1 = nn.Dropout(p=0.1) 91 | self.drop2 = nn.Dropout(p=0.2) 92 | self.drop3 = nn.Dropout(p=0.3) 93 | self.drop4 = nn.Dropout(p=0.4) 94 | if(self.resolution_level == 5): 95 | self.drop5 = 
nn.Dropout(p=0.5) 96 | 97 | self.conv9= nn.Conv2d(self.ft_chns[0], self.n_class * self.ft_groups[0], 98 | kernel_size = 3, padding = 1, groups = self.ft_groups[0]) 99 | 100 | def forward(self, x): 101 | x_shape = list(x.shape) 102 | if(len(x_shape)==5): 103 | [N, C, D, H, W] = x_shape 104 | new_shape = [N*D, C, H, W] 105 | x = torch.transpose(x, 1, 2) 106 | x = torch.reshape(x, new_shape) 107 | f1 = self.block1(x) 108 | if(self.dropout): 109 | f1 = self.drop1(f1) 110 | d1 = self.down1(f1) 111 | 112 | f2 = self.block2(d1) 113 | if(self.dropout): 114 | f2 = self.drop2(f2) 115 | d2 = self.down2(f2) 116 | 117 | f3 = self.block3(d2) 118 | if(self.dropout): 119 | f3 = self.drop3(f3) 120 | d3 = self.down3(f3) 121 | 122 | f4 = self.block4(d3) 123 | if(self.dropout): 124 | f4 = self.drop4(f4) 125 | 126 | if(self.resolution_level == 5): 127 | d4 = self.down4(f4) 128 | f5 = self.block5(d4) 129 | if(self.dropout): 130 | f5 = self.drop5(f5) 131 | 132 | f5up = self.up1(f5) 133 | f4cat = interleaved_concate(f4, f5up) 134 | f6 = self.block6(f4cat) 135 | f6up = self.up2(f6) 136 | f3cat = interleaved_concate(f3, f6up) 137 | else: 138 | f4up = self.up2(f4) 139 | f3cat = interleaved_concate(f3, f4up) 140 | f7 = self.block7(f3cat) 141 | f7up = self.up3(f7) 142 | 143 | f2cat = interleaved_concate(f2, f7up) 144 | f8 = self.block8(f2cat) 145 | f8up = self.up4(f8) 146 | 147 | f1cat = interleaved_concate(f1, f8up) 148 | f9 = self.block9(f1cat) 149 | 150 | output = self.conv9(f9) 151 | 152 | if(len(x_shape)==5): 153 | new_shape = [N, D] + list(output.shape)[1:] 154 | output = torch.reshape(output, new_shape) 155 | output = torch.transpose(output, 1, 2) 156 | 157 | output_list = torch.chunk(output, self.ft_groups[0], dim = 1) 158 | return output_list 159 | -------------------------------------------------------------------------------- /util/network/unet2dres.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ 
import print_function, division 3 | 4 | import time 5 | import torch 6 | import torch.nn as nn 7 | import numpy as np 8 | from pymic.layer.activation import get_acti_func 9 | from pymic.layer.convolution import ConvolutionLayer, DepthSeperableConvolutionLayer 10 | from pymic.layer.deconvolution import DeconvolutionLayer, DepthSeperableDeconvolutionLayer 11 | 12 | def channel_shuffle(x, groups): 13 | B, C, H, W = x.data.size() 14 | channels_per_group = C // groups 15 | 16 | # reshape 17 | x = x.view(B, groups, channels_per_group, H, W) 18 | x = torch.transpose(x, 1, 2).contiguous() 19 | 20 | # flatten 21 | x = x.view(B, -1, H, W) 22 | return x 23 | 24 | class UNetBlock(nn.Module): 25 | def __init__(self,in_channels, out_channels, norm_type, groups, acti_func, acti_func_param): 26 | super(UNetBlock, self).__init__() 27 | 28 | self.in_chns = in_channels 29 | self.out_chns = out_channels 30 | self.acti_func = acti_func 31 | 32 | group1 = 1 if (in_channels < 8) else groups 33 | self.conv1 = ConvolutionLayer(in_channels, out_channels, 1, 34 | dim = 2, padding = 0, conv_group = group1, norm_type = norm_type, norm_group = group1, 35 | acti_func=get_acti_func(acti_func, acti_func_param)) 36 | self.conv2 = ConvolutionLayer(out_channels, out_channels, 3, 37 | dim = 2, padding = 1, conv_group = groups, norm_type = norm_type, norm_group = groups, 38 | acti_func=get_acti_func(acti_func, acti_func_param)) 39 | 40 | def forward(self, x): 41 | f1 = self.conv1(x) 42 | f2 = self.conv2(f1) 43 | return f2 44 | 45 | class UNetBlock_DW(nn.Module): 46 | """UNet block with depthwise separable convolution 47 | """ 48 | def __init__(self,in_channels, out_channels, norm_type, groups, acti_func, acti_func_param): 49 | super(UNetBlock_DW, self).__init__() 50 | self.in_chns = in_channels 51 | self.out_chns = out_channels 52 | self.acti_func = acti_func 53 | self.groups = groups 54 | 55 | self.conv1 = DepthSeperableConvolutionLayer(in_channels, out_channels, 3, 56 | dim = 2, padding = 1,
conv_group = groups, norm_type = norm_type, norm_group = groups, 57 | acti_func=get_acti_func(acti_func, acti_func_param)) 58 | self.conv2 = DepthSeperableConvolutionLayer(out_channels, out_channels, 3, 59 | dim = 2, padding = 1, conv_group = groups, norm_type = norm_type, norm_group = groups, 60 | acti_func=get_acti_func(acti_func, acti_func_param)) 61 | 62 | def forward(self, x): 63 | f1 = self.conv1(x) 64 | f2 = self.conv2(f1) 65 | return f2 66 | 67 | class UNetBlock_DW_CF(UNetBlock_DW): 68 | """UNet block with depthwise separable convolution and channel shuffle 69 | """ 70 | def __init__(self,in_channels, out_channels, norm_type, groups, acti_func, acti_func_param): 71 | super(UNetBlock_DW_CF, self).__init__(in_channels, out_channels, norm_type, groups, acti_func, acti_func_param) 72 | 73 | def forward(self, x): 74 | f1 = self.conv1(x) 75 | if(self.groups > 1): 76 | f1 = channel_shuffle(f1, groups = self.groups) 77 | f2 = self.conv2(f1) 78 | if(self.groups > 1): 79 | f2 = channel_shuffle(f2, groups = int(self.out_chns / self.groups)) 80 | return f2 81 | 82 | class UNetBlock_DW_CF_Res(UNetBlock_DW): 83 | """UNet block with depthwise separable convolution, channel shuffle and a residual connection 84 | """ 85 | def __init__(self,in_channels, out_channels, norm_type, groups, acti_func, acti_func_param): 86 | super(UNetBlock_DW_CF_Res, self).__init__(in_channels, out_channels, norm_type, groups, acti_func, acti_func_param) 87 | 88 | def forward(self, x): 89 | f1 = self.conv1(x) 90 | if(self.groups > 1): 91 | f1 = channel_shuffle(f1, groups = self.groups) 92 | f2 = self.conv2(f1) 93 | if(self.groups > 1): 94 | f2 = channel_shuffle(f2, groups = int(self.out_chns / self.groups)) 95 | return f1 + f2 96 | 97 | class VanillaBlock(nn.Module): 98 | def __init__(self,in_channels, out_channels, norm_type, groups, acti_func, acti_func_param): 99 | super(VanillaBlock, self).__init__() 100 | 101 | self.in_chns = in_channels 102 | self.out_chns = out_channels 103 | self.acti_func = acti_func 104 | 105 | group1 = 1 if (in_channels < 8) else
groups 106 | self.conv1 = ConvolutionLayer(in_channels, out_channels, 1, 107 | dim = 2, padding = 0, conv_group = group1, norm_type = norm_type, norm_group = group1, 108 | acti_func=get_acti_func(acti_func, acti_func_param)) 109 | self.conv2 = ConvolutionLayer(out_channels, out_channels, 3, 110 | dim = 2, padding = 1, conv_group = groups, norm_type = norm_type, norm_group = groups, 111 | acti_func=get_acti_func(acti_func, acti_func_param)) 112 | self.conv3 = ConvolutionLayer(out_channels, out_channels, 3, 113 | dim = 2, padding = 1, conv_group = groups, norm_type = norm_type, norm_group = groups, 114 | acti_func=get_acti_func(acti_func, acti_func_param)) 115 | 116 | def forward(self, x): 117 | f1 = self.conv1(x) 118 | f2 = self.conv2(f1) 119 | f3 = self.conv3(f2) 120 | return f3 121 | 122 | class ResBlock(VanillaBlock): 123 | def __init__(self,in_channels, out_channels, norm_type, groups, acti_func, acti_func_param): 124 | super(ResBlock, self).__init__(in_channels, out_channels, norm_type, groups, acti_func, acti_func_param) 125 | 126 | def forward(self, x): 127 | f1 = self.conv1(x) 128 | f2 = self.conv2(f1) 129 | f3 = self.conv3(f2) 130 | return f1 + f3 131 | 132 | class ResBlock_DW(nn.Module): 133 | """Residual UNet block with depthwise separable convolution 134 | """ 135 | def __init__(self,in_channels, out_channels, norm_type, groups, acti_func, acti_func_param): 136 | super(ResBlock_DW, self).__init__() 137 | self.in_chns = in_channels 138 | self.out_chns = out_channels 139 | self.acti_func = acti_func 140 | self.groups = groups 141 | 142 | self.conv1 = ConvolutionLayer(in_channels, out_channels, 1, 143 | dim = 2, padding = 0, conv_group = 1, norm_type = norm_type, norm_group = 1, 144 | acti_func=get_acti_func(acti_func, acti_func_param)) 145 | self.conv2 = DepthSeperableConvolutionLayer(out_channels, out_channels, 3, 146 | dim = 2, padding = 1, conv_group = groups, norm_type = norm_type, norm_group = groups, 147 | acti_func=get_acti_func(acti_func, acti_func_param))
148 | self.conv3 = DepthSeperableConvolutionLayer(out_channels, out_channels, 3, 149 | dim = 2, padding = 1, conv_group = groups, norm_type = norm_type, norm_group = groups, 150 | acti_func=get_acti_func(acti_func, acti_func_param)) 151 | 152 | def forward(self, x): 153 | f1 = self.conv1(x) 154 | f2 = self.conv2(f1) 155 | f3 = self.conv3(f2) 156 | return f1 + f3 157 | 158 | class ResBlock_DWGC_CF(nn.Module): 159 | """UNet block with depthwise separable convolution and group convolution + channel shuffle 160 | """ 161 | def __init__(self,in_channels, out_channels, norm_type, groups, acti_func, acti_func_param): 162 | super(ResBlock_DWGC_CF, self).__init__() 163 | self.in_chns = in_channels 164 | self.out_chns = out_channels 165 | self.acti_func = acti_func 166 | self.groups = groups 167 | groups2 = int(out_channels / groups) 168 | self.conv1 = ConvolutionLayer(in_channels, out_channels, 1, 169 | dim = 2, padding = 0, conv_group = 1, norm_type = norm_type, norm_group = 1, 170 | acti_func=get_acti_func(acti_func, acti_func_param)) 171 | self.conv2 = DepthSeperableConvolutionLayer(out_channels, out_channels, 3, 172 | dim = 2, padding = 1, conv_group = groups, norm_type = norm_type, norm_group = groups, 173 | acti_func=get_acti_func(acti_func, acti_func_param)) 174 | self.conv3 = DepthSeperableConvolutionLayer(out_channels, out_channels, 3, 175 | dim = 2, padding = 1, conv_group = groups2, norm_type = norm_type, norm_group = groups2, 176 | acti_func=get_acti_func(acti_func, acti_func_param)) 177 | 178 | def forward(self, x): 179 | f1 = self.conv1(x) 180 | f2 = self.conv2(f1) 181 | if(self.groups > 1): 182 | f2 = channel_shuffle(f2, groups = self.groups) 183 | f3 = self.conv3(f2) 184 | if(self.groups > 1): 185 | f3 = channel_shuffle(f3, groups = int(self.out_chns / self.groups)) 186 | return f1 + f3 187 | 188 | class PEBlock(nn.Module): 189 | def __init__(self, channels, acti_func, acti_func_param): 190 | super(PEBlock, self).__init__() 191 | 192 | self.channels =
channels 193 | self.acti_func = acti_func 194 | 195 | self.conv1 = ConvolutionLayer(channels, int(channels / 2), 1, 196 | dim = 2, padding = 0, conv_group = 1, norm_type = None, norm_group = 1, 197 | acti_func=get_acti_func(acti_func, acti_func_param)) 198 | self.conv2 = ConvolutionLayer(int(channels / 2), channels, 1, 199 | dim = 2, padding = 0, conv_group = 1, norm_type = None, norm_group = 1, 200 | acti_func=nn.Sigmoid()) 201 | 202 | def forward(self, x): 203 | # projection along each dimension 204 | x_shape = list(x.shape) 205 | [N, C, H, W] = x_shape 206 | p_w = torch.sum(x, dim = -1, keepdim = True) / W # the shape becomes [N, C, H, 1] 207 | p_h = torch.sum(x, dim = -2, keepdim = True) / H # the shape becomes [N, C, 1, W] 208 | p_w_repeat = p_w.repeat(1, 1, 1, W) # the shape is [N, C, H, W] 209 | p_h_repeat = p_h.repeat(1, 1, H, 1) # the shape is [N, C, H, W] 210 | f = p_w_repeat + p_h_repeat 211 | f = self.conv1(f) 212 | f = self.conv2(f) # get attention coefficient 213 | out = f*x + x # use a residual connection 214 | return out 215 | 216 | class ResBlock_DWGC_CF_PE(ResBlock_DW): 217 | """UNet block with depthwise separable convolution and group convolution + channel shuffle + PEBlock attention 218 | """ 219 | def __init__(self,in_channels, out_channels, norm_type, groups, acti_func, acti_func_param): 220 | super(ResBlock_DWGC_CF_PE, self).__init__(in_channels, out_channels, 221 | norm_type, groups, acti_func, acti_func_param) 222 | self.pe_block = PEBlock(out_channels, acti_func, acti_func_param) 223 | 224 | def forward(self, x): 225 | f1 = self.conv1(x) 226 | if(self.groups > 1): 227 | f1 = channel_shuffle(f1, groups = self.groups) 228 | f2 = self.conv2(f1) 229 | f3 = self.conv3(f2) 230 | if(self.groups > 1): 231 | f3 = channel_shuffle(f3, groups = self.groups) 232 | out = f1 + f3 233 | out = self.pe_block(out) 234 | return out 235 | 236 | class ResBlock_DWGC_CF_BE(nn.Module): 237 | """UNet block with depthwise separable convolution and group convolution + channel shuffle
238 | """ 239 | def __init__(self,in_channels, out_channels, norm_type, groups, acti_func, acti_func_param): 240 | super(ResBlock_DWGC_CF_BE, self).__init__() 241 | 242 | self.in_chns = in_channels 243 | self.out_chns = out_channels 244 | self.acti_func = acti_func 245 | self.groups = groups 246 | 247 | self.conv1 = ConvolutionLayer(in_channels, out_channels, 1, 248 | dim = 2, padding = 0, conv_group = groups, norm_type = norm_type, norm_group = groups, 249 | acti_func=get_acti_func(acti_func, acti_func_param)) 250 | self.conv2 = DepthSeperableConvolutionLayer(out_channels, out_channels*2, 3, 251 | dim = 2, padding = 1, conv_group = groups, norm_type = norm_type, norm_group = groups, 252 | acti_func=get_acti_func(acti_func, acti_func_param)) 253 | self.conv3 = DepthSeperableConvolutionLayer(out_channels*2, out_channels, 3, 254 | dim = 2, padding = 1, conv_group = groups, norm_type = norm_type, norm_group = groups, 255 | acti_func=get_acti_func(acti_func, acti_func_param)) 256 | 257 | def forward(self, x): 258 | f1 = self.conv1(x) 259 | if(self.groups > 1): 260 | f1 = channel_shuffle(f1, groups = self.groups) 261 | f2 = self.conv2(f1) 262 | f3 = self.conv3(f2) 263 | if(self.groups > 1): 264 | f3 = channel_shuffle(f3, groups = self.groups) 265 | return f1 + f3 266 | 267 | 268 | 269 | class ResBlock_DWGC_BE_CPF(nn.Module): 270 | """UNet block with depthwise separable convolution and group convolution + bottleneck with expansion layer 271 | + channel shuffle and channel split 272 | """ 273 | def __init__(self,in_channels, out_channels, norm_type, groups, acti_func, acti_func_param): 274 | super(ResBlock_DWGC_BE_CPF, self).__init__() 275 | 276 | self.in_chns = in_channels 277 | self.out_chns = out_channels 278 | self.acti_func = acti_func 279 | 280 | chns_half = int(out_channels / 2) 281 | group1 = 1 if (in_channels < 8) else groups 282 | self.conv1 = ConvolutionLayer(in_channels, out_channels, 1, 283 | dim = 2, padding = 0, conv_group = group1, norm_type = norm_type,
norm_group = group1, 284 | acti_func=get_acti_func(acti_func, acti_func_param)) 285 | self.conv2 = DepthSeperableConvolutionLayer(chns_half, self.out_chns, 3, 286 | dim = 2, padding = 1, conv_group = groups, norm_type = norm_type, norm_group = groups, 287 | acti_func=get_acti_func(acti_func, acti_func_param)) 288 | self.conv3 = DepthSeperableConvolutionLayer(self.out_chns, chns_half, 3, 289 | dim = 2, padding = 1, conv_group = groups, norm_type = norm_type, norm_group = groups, 290 | acti_func=get_acti_func(acti_func, acti_func_param)) 291 | 292 | def forward(self, x): 293 | chns_half = int(self.out_chns / 2) 294 | 295 | f1 = self.conv1(x) 296 | f1_shuffle = channel_shuffle(f1, groups = 2) 297 | f1_a = f1_shuffle[:,0:chns_half, :, :] 298 | f1_b = f1_shuffle[:,chns_half:, :, :] 299 | 300 | f2 = self.conv2(f1_b) 301 | f3 = self.conv3(f2) 302 | 303 | f3cat = torch.cat([f1_a, f3], dim = 1) 304 | out = channel_shuffle(f3cat, groups = 2) 305 | return out 306 | 307 | 308 | def get_unet_block(block_type): 309 | if(block_type == "UNetBlock"): 310 | return UNetBlock 311 | elif(block_type == "UNetBlock_DW"): 312 | return UNetBlock_DW 313 | elif(block_type == "UNetBlock_DW_CF"): 314 | return UNetBlock_DW_CF 315 | elif(block_type == "UNetBlock_DW_CF_Res"): 316 | return UNetBlock_DW_CF_Res 317 | elif(block_type == "VanillaBlock"): 318 | return VanillaBlock 319 | elif(block_type == "ResBlock"): 320 | return ResBlock 321 | elif(block_type == "ResBlock_DW"): 322 | return ResBlock_DW 323 | elif(block_type == "ResBlock_DWGC_CF"): 324 | return ResBlock_DWGC_CF 325 | elif(block_type == "ResBlock_DWGC_CF_BE"): 326 | return ResBlock_DWGC_CF_BE 327 | elif(block_type == "ResBlock_DWGC_CF_PE"): 328 | return ResBlock_DWGC_CF_PE 329 | else: 330 | raise ValueError("undefined type name {0:}".format(block_type)) 331 | 332 | def get_deconv_layer(depth_sep_deconv): 333 | if(depth_sep_deconv): 334 | return DepthSeperableDeconvolutionLayer 335 | else: 336 | return DeconvolutionLayer 337 | 338 | 
class UNet2DRes(nn.Module): 339 | def __init__(self, params): 340 | super(UNet2DRes, self).__init__() 341 | self.params = params 342 | self.in_chns = self.params['in_chns'] 343 | self.ft_chns = self.params['feature_chns'] 344 | self.ft_groups = self.params['feature_grps'] 345 | self.norm_type = self.params['norm_type'] 346 | self.block_type= self.params['block_type'] 347 | self.n_class = self.params['class_num'] 348 | self.acti_func = self.params['acti_func'] 349 | self.dropout = self.params['dropout'] 350 | self.depth_sep_deconv= self.params['depth_sep_deconv'] 351 | self.deep_spv = self.params['deep_supervision'] 352 | self.pe_block = self.params.get('pe_block', False) 353 | self.resolution_level = len(self.ft_chns) 354 | assert(self.resolution_level == 5 or self.resolution_level == 4) 355 | 356 | Block = get_unet_block(self.block_type) 357 | self.block1 = Block(self.in_chns, self.ft_chns[0], self.norm_type, self.ft_groups[0], 358 | self.acti_func, self.params) 359 | 360 | self.block2 = Block(self.ft_chns[0], self.ft_chns[1], self.norm_type, self.ft_groups[1], 361 | self.acti_func, self.params) 362 | 363 | self.block3 = Block(self.ft_chns[1], self.ft_chns[2], self.norm_type, self.ft_groups[2], 364 | self.acti_func, self.params) 365 | 366 | self.block4 = Block(self.ft_chns[2], self.ft_chns[3], self.norm_type, self.ft_groups[3], 367 | self.acti_func, self.params) 368 | 369 | if(self.resolution_level == 5): 370 | self.block5 = Block(self.ft_chns[3], self.ft_chns[4], self.norm_type, self.ft_groups[4], 371 | self.acti_func, self.params) 372 | 373 | self.block6 = Block(self.ft_chns[3] * 2, self.ft_chns[3], self.norm_type, self.ft_groups[3], 374 | self.acti_func, self.params) 375 | 376 | self.block7 = Block(self.ft_chns[2] * 2, self.ft_chns[2], self.norm_type, self.ft_groups[2], 377 | self.acti_func, self.params) 378 | 379 | self.block8 = Block(self.ft_chns[1] * 2, self.ft_chns[1], self.norm_type, self.ft_groups[1], 380 | self.acti_func, self.params) 381 | 382 | 
self.block9 = Block(self.ft_chns[0] * 2, self.ft_chns[0], self.norm_type, self.ft_groups[0], 383 | self.acti_func, self.params) 384 | 385 | if(self.pe_block): 386 | self.pe1 = PEBlock(self.ft_chns[0], self.acti_func, self.params) 387 | self.pe2 = PEBlock(self.ft_chns[1], self.acti_func, self.params) 388 | self.pe3 = PEBlock(self.ft_chns[2], self.acti_func, self.params) 389 | self.pe4 = PEBlock(self.ft_chns[3], self.acti_func, self.params) 390 | self.pe7 = PEBlock(self.ft_chns[2], self.acti_func, self.params) 391 | self.pe8 = PEBlock(self.ft_chns[1], self.acti_func, self.params) 392 | self.pe9 = PEBlock(self.ft_chns[0], self.acti_func, self.params) 393 | if(self.resolution_level == 5): 394 | self.pe5 = PEBlock(self.ft_chns[4], self.acti_func, self.params) 395 | self.pe6 = PEBlock(self.ft_chns[3], self.acti_func, self.params) 396 | 397 | self.down1 = nn.MaxPool2d(kernel_size = 2) 398 | self.down2 = nn.MaxPool2d(kernel_size = 2) 399 | self.down3 = nn.MaxPool2d(kernel_size = 2) 400 | 401 | DeconvLayer = get_deconv_layer(self.depth_sep_deconv) 402 | if(self.resolution_level == 5): 403 | self.down4 = nn.MaxPool2d(kernel_size = 2) 404 | self.up1 = DeconvLayer(self.ft_chns[4], self.ft_chns[3], kernel_size = 2, 405 | dim = 2, stride = 2, groups = 1, acti_func = get_acti_func(self.acti_func, self.params)) 406 | self.up2 = DeconvLayer(self.ft_chns[3], self.ft_chns[2], kernel_size = 2, 407 | dim = 2, stride = 2, groups = 1, acti_func = get_acti_func(self.acti_func, self.params)) 408 | self.up3 = DeconvLayer(self.ft_chns[2], self.ft_chns[1], kernel_size = 2, 409 | dim = 2, stride = 2, groups = 1, acti_func = get_acti_func(self.acti_func, self.params)) 410 | self.up4 = DeconvLayer(self.ft_chns[1], self.ft_chns[0], kernel_size = 2, 411 | dim = 2, stride = 2, groups = 1, acti_func = get_acti_func(self.acti_func, self.params)) 412 | 413 | if(self.dropout): 414 | self.drop1 = nn.Dropout(p=0.1) 415 | self.drop2 = nn.Dropout(p=0.2) 416 | self.drop3 = nn.Dropout(p=0.3) 417 | self.drop4 
= nn.Dropout(p=0.4) 418 | if(self.resolution_level == 5): 419 | self.drop5 = nn.Dropout(p=0.5) 420 | 421 | if(self.deep_spv): 422 | self.conv7 = nn.Conv2d(self.ft_chns[2], self.n_class, 423 | kernel_size = 3, padding = 1) 424 | self.conv8 = nn.Conv2d(self.ft_chns[1], self.n_class, 425 | kernel_size = 3, padding = 1) 426 | 427 | self.conv9 = nn.Conv2d(self.ft_chns[0], self.n_class, 428 | kernel_size = 3, padding = 1) 429 | 430 | 431 | def forward(self, x): 432 | x_shape = list(x.shape) 433 | if(len(x_shape)==5): 434 | [N, C, D, H, W] = x_shape 435 | new_shape = [N*D, C, H, W] 436 | x = torch.transpose(x, 1, 2) 437 | x = torch.reshape(x, new_shape) 438 | f1 = self.block1(x) 439 | if(self.pe_block): 440 | f1 = self.pe1(f1) 441 | if(self.dropout): 442 | f1 = self.drop1(f1) 443 | d1 = self.down1(f1) 444 | 445 | f2 = self.block2(d1) 446 | if(self.pe_block): 447 | f2 = self.pe2(f2) 448 | if(self.dropout): 449 | f2 = self.drop2(f2) 450 | d2 = self.down2(f2) 451 | 452 | f3 = self.block3(d2) 453 | if(self.pe_block): 454 | f3 = self.pe3(f3) 455 | if(self.dropout): 456 | f3 = self.drop3(f3) 457 | d3 = self.down3(f3) 458 | 459 | f4 = self.block4(d3) 460 | if(self.pe_block): 461 | f4 = self.pe4(f4) 462 | if(self.dropout): 463 | f4 = self.drop4(f4) 464 | 465 | if(self.resolution_level == 5): 466 | d4 = self.down4(f4) 467 | f5 = self.block5(d4) 468 | if(self.pe_block): 469 | f5 = self.pe5(f5) 470 | if(self.dropout): 471 | f5 = self.drop5(f5) 472 | 473 | f5up = self.up1(f5) 474 | f4cat = torch.cat((f4, f5up), dim = 1) 475 | f6 = self.block6(f4cat) 476 | if(self.pe_block): 477 | f6 = self.pe6(f6) 478 | f6up = self.up2(f6) 479 | f3cat = torch.cat((f3, f6up), dim = 1) 480 | else: 481 | f4up = self.up2(f4) 482 | f3cat = torch.cat((f3, f4up), dim = 1) 483 | f7 = self.block7(f3cat) 484 | if(self.pe_block): 485 | f7 = self.pe7(f7) 486 | f7up = self.up3(f7) 487 | if(self.deep_spv): 488 | f7pred = self.conv7(f7) 489 | f7predup_out = nn.functional.interpolate(f7pred, 490 | size = 
list(x.shape)[2:], mode = 'bilinear') 491 | 492 | f2cat = torch.cat((f2, f7up), dim = 1) 493 | f8 = self.block8(f2cat) 494 | if(self.pe_block): 495 | f8 = self.pe8(f8) 496 | f8up = self.up4(f8) 497 | if(self.deep_spv): 498 | f8pred = self.conv8(f8) 499 | f8predup_out = nn.functional.interpolate(f8pred, 500 | size = list(x.shape)[2:], mode = 'bilinear') 501 | 502 | f1cat = torch.cat((f1, f8up), dim = 1) 503 | f9 = self.block9(f1cat) 504 | if(self.pe_block): 505 | f9 = self.pe9(f9) 506 | output = self.conv9(f9) 507 | 508 | if(len(x_shape)==5): 509 | new_shape = [N, D] + list(output.shape)[1:] 510 | output = torch.reshape(output, new_shape) 511 | output = torch.transpose(output, 1, 2) 512 | 513 | if(self.deep_spv): 514 | f7predup_out = torch.reshape(f7predup_out, new_shape) 515 | f7predup_out = torch.transpose(f7predup_out, 1, 2) 516 | f8predup_out = torch.reshape(f8predup_out, new_shape) 517 | f8predup_out = torch.transpose(f8predup_out, 1, 2) 518 | if(self.deep_spv): 519 | return output, f8predup_out, f7predup_out 520 | else: 521 | return output 522 | 523 | if __name__ == "__main__": 524 | methods = ["ResBlock", 525 | "ResBlock_DW", 526 | "ResBlock_DW", # GC 527 | "ResBlock_DWGC_CF", # GC 528 | "ResBlock_DWGC_CF_BE", 529 | "ResBlock_DWGC_CF_PE"] 530 | method_id = 5 531 | if(method_id > 1): 532 | feature_grps = [1, 2, 2, 4, 4] 533 | else: 534 | feature_grps = [1, 1, 1, 1, 1] 535 | 536 | params = {'in_chns':1, 537 | 'feature_chns':[32, 64, 128, 256, 512], 538 | 'feature_grps':feature_grps, 539 | 'class_num' : 2, 540 | 'block_type' : methods[method_id], 541 | 'norm_type' : 'batch_norm', 542 | 'acti_func': 'relu', 543 | 'dropout' : True, 544 | 'depth_sep_deconv' : True, 545 | 'deep_supervision': True} 546 | Net = UNet2DRes(params) 547 | Net = Net.double() 548 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 549 | Net.to(device) 550 | x = np.random.rand(1, 1, 12, 144, 144) # N, C, D, H, W 551 | xt = torch.from_numpy(x) 552 | xt = torch.tensor(xt) 553 | xt = xt.to(device) 554 | t_list = []
555 | for i in range(10): 556 | t0 = time.time() 557 | y, y8, y7 = Net(xt) # deep supervision returns the final output plus the block8 and block7 predictions 558 | t = time.time() - t0 559 | t_list.append(t) 560 | t_array = np.asarray(t_list) 561 | print('time', t_array.mean()) 562 | print(len(y.size())) 563 | y = y.detach().cpu().numpy() 564 | print(y.shape) 565 | 566 | # device = torch.device('cpu') 567 | 568 | # param = {'acti_func':'relu'} 569 | # Net = PEBlock(12, 'relu', param) 570 | # Net = Net.double() 571 | # Net.to(device) 572 | # x = np.random.rand(1, 12, 144, 144) # N, C, H, W 573 | # xt = torch.from_numpy(x) 574 | # xt = torch.tensor(xt) 575 | # xt = xt.to(device) 576 | # y = Net(xt) 577 | # y = y.detach().numpy() 578 | # print(y.shape) --------------------------------------------------------------------------------
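Several blocks in `unet2dres.py` call a `channel_shuffle(x, groups = ...)` helper whose definition lies outside this excerpt. Assuming it implements the standard ShuffleNet channel shuffle (split the channels into groups, swap the group and within-group axes, flatten), the NumPy sketch below shows the channel order it produces. `channel_shuffle_np` is a hypothetical name used only for illustration, not the repository's function.

```python
import numpy as np

def channel_shuffle_np(x, groups):
    """ShuffleNet-style channel shuffle for an array of shape [N, C, H, W].

    Hypothetical NumPy illustration; assumes C is divisible by `groups`.
    """
    n, c, h, w = x.shape
    if c % groups != 0:
        raise ValueError("number of channels must be divisible by groups")
    # [N, C, H, W] -> [N, groups, C/groups, H, W]
    x = x.reshape(n, groups, c // groups, h, w)
    # swap the group axis with the within-group axis
    x = x.transpose(0, 2, 1, 3, 4)
    # flatten back to [N, C, H, W]; channels from different groups now interleave
    return x.reshape(n, c, h, w)

# Fill each of 6 channels with its own index, then shuffle with 2 groups.
x = np.arange(6, dtype=float).reshape(1, 6, 1, 1) * np.ones((1, 6, 2, 2))
y = channel_shuffle_np(x, groups=2)
print(y[0, :, 0, 0])  # -> [0. 3. 1. 4. 2. 5.]
```

Under this definition, a shuffle with `C // groups` groups is the inverse permutation of a shuffle with `groups` groups; `ResBlock_DWGC_CF` uses exactly this pair of group counts (`self.groups`, then `int(self.out_chns / self.groups)`) for its two `channel_shuffle` calls.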
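`PEBlock.forward` averages the feature map along the width and the height, expands both profiles back to [N, C, H, W] with `repeat`, and turns their sum into per-pixel attention weights through two 1x1 convolutions (ReLU after the first when `acti_func = 'relu'`, sigmoid after the second), before applying them with a residual connection. The NumPy sketch below mirrors that flow under simplifying assumptions: the 1x1 convolutions are plain bias-free weight matrices, and no normalization is applied (the block itself uses `norm_type = None`). The function names are hypothetical. It also illustrates that broadcasting `p_w + p_h` yields the same tensor as the explicit `repeat` calls.

```python
import numpy as np

def pe_projection(x):
    """Projection step of a PE-style block on an [N, C, H, W] array."""
    p_w = x.mean(axis=-1, keepdims=True)  # [N, C, H, 1]: average over width
    p_h = x.mean(axis=-2, keepdims=True)  # [N, C, 1, W]: average over height
    return p_w + p_h                      # broadcasts to [N, C, H, W]

def pe_attention(x, w_squeeze, w_excite):
    """PE-style recalibration; the 1x1 convs are stand-in [C_out, C_in] matrices (no bias)."""
    f = pe_projection(x)
    # 1x1 convolution C -> C/2 followed by ReLU
    f = np.maximum(np.einsum('oc,nchw->nohw', w_squeeze, f), 0.0)
    # 1x1 convolution C/2 -> C followed by sigmoid: attention coefficients in (0, 1)
    f = 1.0 / (1.0 + np.exp(-np.einsum('oc,nchw->nohw', w_excite, f)))
    return f * x + x                      # residual connection, as in PEBlock.forward

x = np.random.rand(1, 12, 8, 8)
out = pe_attention(x, np.random.randn(6, 12) * 0.1, np.random.randn(12, 6) * 0.1)
print(out.shape)  # (1, 12, 8, 8)
```

Because PyTorch broadcasts in the same way, `p_w + p_h` in `PEBlock.forward` would produce the same [N, C, H, W] tensor without the two `repeat` calls.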
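`UNet2DRes.forward` treats a 5D input as a stack of independent 2D slices: it moves the slice axis next to the batch axis with `torch.transpose(x, 1, 2)`, folds it into the batch with `torch.reshape`, and undoes both steps after `conv9`. The NumPy sketch below reproduces only that reshaping logic, with `np.swapaxes` standing in for `torch.transpose`, and checks that it is lossless. The helper names are hypothetical.

```python
import numpy as np

def volume_to_slices(x):
    """[N, C, D, H, W] -> [N*D, C, H, W], mirroring the start of UNet2DRes.forward."""
    n, c, d, h, w = x.shape
    x = np.swapaxes(x, 1, 2)          # [N, D, C, H, W]
    return x.reshape(n * d, c, h, w)  # every slice becomes one batch item

def slices_to_volume(y, n, d):
    """[N*D, K, H, W] -> [N, K, D, H, W], mirroring the end of UNet2DRes.forward."""
    k, h, w = y.shape[1:]
    y = y.reshape(n, d, k, h, w)      # [N, D, K, H, W]
    return np.swapaxes(y, 1, 2)       # [N, K, D, H, W]

x = np.random.rand(2, 1, 12, 16, 16)  # N=2 volumes with 12 slices each
s = volume_to_slices(x)
print(s.shape)  # (24, 1, 16, 16)
# slice d of volume n lands at batch index n*D + d
assert np.array_equal(s[1 * 12 + 5], x[1, :, 5])
```

Since every slice is processed independently, the 2D network sees N*D images per forward pass; with the `__main__` test input of shape (1, 1, 12, 144, 144), that is a batch of 12 slices.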