├── LICENSE
├── README.md
├── demo2d.py
├── demo_dataset
│   ├── 1.png
│   ├── 115_HC.png
│   ├── 3.png
│   ├── 34.png
│   ├── 4.png
│   ├── 7.png
│   ├── CT-left-kidney-case5_img.png
│   ├── Case16_slice5_points6.png
│   ├── Case31_slice6_points6.png
│   ├── axial.png
│   ├── egd_vis.png
│   ├── pancreas.png
│   ├── placenta.png
│   └── test5.png
├── demo_video
│   ├── pancreas.gif
│   └── pancreas.mp4
├── mideepseg
│   ├── controler.py
│   ├── gui.py
│   ├── iter_15000.pth
│   ├── logo.png
│   ├── main.py
│   ├── network.py
│   └── utils.py
└── requirements.txt

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 xdluo

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## MIDeepSeg: Minimally Interactive Segmentation of Unseen Objects from Medical Images Using Deep Learning [[MedIA](https://www.sciencedirect.com/science/article/pii/S1361841521001481) or [Arxiv](https://arxiv.org/pdf/2104.12166.pdf)] and [[Demo]](https://www.youtube.com/watch?v=eq-tqlJnckE)
This repository provides a 2D interactive method for medical image segmentation and annotation.
![image](https://github.com/HiLab-git/MIDeepSeg/blob/master/demo_video/pancreas.gif)

* This project was originally developed for our previous work [MIDeepSeg](https://arxiv.org/pdf/2104.12166.pdf); if you find it useful for your research, please consider citing the following:

        @article{luo2021mideepseg,
        title={MIDeepSeg: Minimally interactive segmentation of unseen objects from medical images using deep learning},
        author={Luo, Xiangde and Wang, Guotai and Song, Tao and Zhang, Jingyang and Aertsen, Michael and Deprest, Jan and Ourselin, Sebastien and Vercauteren, Tom and Zhang, Shaoting},
        journal={Medical Image Analysis},
        volume={72},
        pages={102102},
        year={2021},
        publisher={Elsevier}}
![2D example](./demo_dataset/egd_vis.png)
A visual comparison of different distance transform methods, following [GeodisTK](https://github.com/taigw/GeodisTK).
## Requirements
Before using this package for image segmentation, you should have:
* PyTorch version >= 1.0.1
* Some common Python packages, such as NumPy, Pandas, SimpleITK, OpenCV, PyQt5, and SciPy
* The [GeodisTK][geos_dis_link] package, for geodesic distance transformation
* The [SimpleCRF][simplecrf_link] package, for interactive refinement
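
As a quick sanity check that GeodisTK works, the minimal sketch below computes a geodesic distance map from a single seed; the image path and seed position are borrowed from `demo2d.py`, and any grayscale image would do:

```python
import GeodisTK
import numpy as np
from PIL import Image

# Load a grayscale image and mark one seed pixel.
img = np.asarray(Image.open('demo_dataset/pancreas.png').convert('L'), np.float32)
seed = np.zeros(img.shape, np.uint8)
seed[121, 182] = 1

# lamb balances image gradient vs. spatial distance (1.0 = gradient only),
# and the last argument is the number of raster-scan iterations.
dist = GeodisTK.geodesic2d_raster_scan(img, seed, 1.0, 2)
print(dist.shape, dist.min(), dist.max())
```

See `demo2d.py` for a full comparison of fast marching, raster scan, Euclidean, and exponentialized geodesic distances.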
## How to use
1. Install the required libraries:

```bash
pip install -r requirements.txt
```
2. Launch the GUI:
```bash
cd mideepseg
python main.py
```
3. Load an image for segmentation. Once the image is loaded, first give some edge points with the left mouse button as the initial interactions, and click the Segmentation button to obtain an initial segmentation. Then click with the left mouse button in under-segmented regions, and with the right mouse button in over-segmented regions. Finally, click the Refinement button, and the segmentation will be updated according to these interactions.

4. Note that the pretrained model was trained only on placenta MR-T2 data.

## Acknowledgment and Statement
* We thank the authors of [Deep_Extreme_Cut][dextr_link], [DeepIGeoS][deepigeos_link] and [BIFSeg][bifseg_link] for their elegant and efficient code bases!

[geos_dis_link]: https://github.com/taigw/GeodisTK
[simplecrf_link]: https://github.com/HiLab-git/SimpleCRF
[dextr_link]: https://openaccess.thecvf.com/content_cvpr_2018/papers/Maninis_Deep_Extreme_Cut_CVPR_2018_paper.pdf
[deepigeos_link]: https://ieeexplore.ieee.org/document/8370732
[bifseg_link]: https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=8270673
* This project was designed for academic research, not for clinical or commercial use, as it is protected by a patent. For commercial use, please contact [Prof. Guotai Wang](https://scholar.google.com/citations?user=Z2sFN4EAAAAJ&hl=en).

--------------------------------------------------------------------------------
/demo2d.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# Author: Xiangde Luo
# Date: 2 Sep., 2021
# Implementation of MIDeepSeg for interactive medical image segmentation and annotation.
# This file was borrowed from [GeodisTK](https://github.com/taigw/GeodisTK).
# Reference:
#   [1] X. Luo and G. Wang et al. MIDeepSeg: Minimally interactive segmentation of unseen objects
#       from medical images using deep learning. Medical Image Analysis, 2021. DOI: https://doi.org/10.1016/j.media.2021.102102.
#   [2] Wang, Guotai, et al. "DeepIGeoS: A deep interactive geodesic framework for medical image segmentation." TPAMI, 2018.

import GeodisTK
import numpy as np
import time
from PIL import Image
import matplotlib.pyplot as plt


def geodesic_distance_2d(I, S, lamb, iter):
    '''
    Get 2D geodesic distance by raster scanning.
    I: input image, can have multiple channels. Type should be np.float32.
    S: binary image where non-zero pixels are used as seeds. Type should be np.uint8.
    lamb: weighting between 0.0 and 1.0
          if lamb==0.0, return spatial Euclidean distance without considering gradient
          if lamb==1.0, the distance is based on gradient only, without using spatial distance
    iter: number of iterations for raster scanning.
    '''
    return GeodisTK.geodesic2d_raster_scan(I, S, lamb, iter)
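
# Illustrative helper added for clarity (the function name is ours, not part
# of the original demo): MIDeepSeg's cue map is an exponentialized geodesic
# distance, exp(-D), which maps distances into (0, 1] with value 1 at the
# seeds. Panel (f) in the figure produced below visualizes this transform.
def exponential_geodesic_distance_2d(I, S, lamb=1.0, iteration=2):
    return np.exp(-geodesic_distance_2d(I, S, lamb, iteration))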

def demo_geodesic_distance2d(img, seed_pos):
    I = np.asanyarray(img, np.float32)
    S = np.zeros((I.shape[0], I.shape[1]), np.uint8)
    S[seed_pos[0]][seed_pos[1]] = 1
    t0 = time.time()
    D1 = GeodisTK.geodesic2d_fast_marching(I, S)
    t1 = time.time()
    D2 = geodesic_distance_2d(I, S, 1.0, 2)
    dt1 = t1 - t0
    dt2 = time.time() - t1
    D3 = geodesic_distance_2d(I, S, 0.0, 2)
    D4 = geodesic_distance_2d(I, S, 0.5, 2)
    print("runtime(s) of fast marching {0:}".format(dt1))
    print("runtime(s) of raster scan {0:}".format(dt2))

    plt.figure(figsize=(18, 6))
    plt.subplot(1, 6, 1)
    plt.imshow(img, "gray")
    plt.autoscale(False)
    plt.plot([seed_pos[1]], [seed_pos[0]], 'ro')
    plt.axis('off')
    plt.title('(a) input image \n with a seed point')

    plt.subplot(1, 6, 2)
    plt.imshow(D1)
    plt.axis('off')
    plt.title('(b) Geodesic distance \n based on fast marching')

    plt.subplot(1, 6, 3)
    plt.imshow(D2)
    plt.axis('off')
    plt.title('(c) Geodesic distance \n based on raster scan')

    plt.subplot(1, 6, 4)
    plt.imshow(D3)
    plt.axis('off')
    plt.title('(d) Euclidean distance')

    plt.subplot(1, 6, 5)
    plt.imshow(D4)
    plt.axis('off')
    plt.title('(e) Mixture of Geodesic \n and Euclidean distance')

    plt.subplot(1, 6, 6)
    plt.imshow(np.exp(-D1))
    plt.axis('off')
    plt.title('(f) Exponential Geodesic distance')
    plt.savefig("demo_dataset/egd_vis.png",
                bbox_inches='tight', dpi=500, pad_inches=0.0)
    plt.show()


def demo_geodesic_distance2d_gray_scale_image():
    img = Image.open('demo_dataset/pancreas.png').convert('L')
    img = np.array(img)[100:400, 100:400]
    img = (img - img.mean()) / img.std()
    seed_position = [121, 182]
    demo_geodesic_distance2d(img, seed_position)


if __name__ == '__main__':
    demo_geodesic_distance2d_gray_scale_image()

--------------------------------------------------------------------------------
/demo_dataset/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/MIDeepSeg/a6eb18d45b9e7e5ed5a8805306c917c22aa25908/demo_dataset/1.png

--------------------------------------------------------------------------------
/demo_dataset/115_HC.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/MIDeepSeg/a6eb18d45b9e7e5ed5a8805306c917c22aa25908/demo_dataset/115_HC.png

--------------------------------------------------------------------------------
/demo_dataset/3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/MIDeepSeg/a6eb18d45b9e7e5ed5a8805306c917c22aa25908/demo_dataset/3.png

--------------------------------------------------------------------------------
/demo_dataset/34.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/MIDeepSeg/a6eb18d45b9e7e5ed5a8805306c917c22aa25908/demo_dataset/34.png

--------------------------------------------------------------------------------
/demo_dataset/4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/MIDeepSeg/a6eb18d45b9e7e5ed5a8805306c917c22aa25908/demo_dataset/4.png

--------------------------------------------------------------------------------
/demo_dataset/7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/MIDeepSeg/a6eb18d45b9e7e5ed5a8805306c917c22aa25908/demo_dataset/7.png

--------------------------------------------------------------------------------
/demo_dataset/CT-left-kidney-case5_img.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/MIDeepSeg/a6eb18d45b9e7e5ed5a8805306c917c22aa25908/demo_dataset/CT-left-kidney-case5_img.png

--------------------------------------------------------------------------------
/demo_dataset/Case16_slice5_points6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/MIDeepSeg/a6eb18d45b9e7e5ed5a8805306c917c22aa25908/demo_dataset/Case16_slice5_points6.png

--------------------------------------------------------------------------------
/demo_dataset/Case31_slice6_points6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/MIDeepSeg/a6eb18d45b9e7e5ed5a8805306c917c22aa25908/demo_dataset/Case31_slice6_points6.png

--------------------------------------------------------------------------------
/demo_dataset/axial.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/MIDeepSeg/a6eb18d45b9e7e5ed5a8805306c917c22aa25908/demo_dataset/axial.png

--------------------------------------------------------------------------------
/demo_dataset/egd_vis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/MIDeepSeg/a6eb18d45b9e7e5ed5a8805306c917c22aa25908/demo_dataset/egd_vis.png

--------------------------------------------------------------------------------
/demo_dataset/pancreas.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/MIDeepSeg/a6eb18d45b9e7e5ed5a8805306c917c22aa25908/demo_dataset/pancreas.png

--------------------------------------------------------------------------------
/demo_dataset/placenta.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/MIDeepSeg/a6eb18d45b9e7e5ed5a8805306c917c22aa25908/demo_dataset/placenta.png

--------------------------------------------------------------------------------
/demo_dataset/test5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/MIDeepSeg/a6eb18d45b9e7e5ed5a8805306c917c22aa25908/demo_dataset/test5.png

--------------------------------------------------------------------------------
/demo_video/pancreas.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/MIDeepSeg/a6eb18d45b9e7e5ed5a8805306c917c22aa25908/demo_video/pancreas.gif

--------------------------------------------------------------------------------
/demo_video/pancreas.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/MIDeepSeg/a6eb18d45b9e7e5ed5a8805306c917c22aa25908/demo_video/pancreas.mp4

--------------------------------------------------------------------------------
/mideepseg/controler.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# Author: Xiangde Luo
# Date: 2 Sep., 2021
# Implementation of MIDeepSeg for interactive medical image segmentation and annotation.
# Reference:
#   X. Luo and G. Wang et al. MIDeepSeg: Minimally interactive segmentation of unseen objects
#   from medical images using deep learning. Medical Image Analysis, 2021. DOI: https://doi.org/10.1016/j.media.2021.102102.

import os
from collections import OrderedDict
from os.path import join as opj

import cv2
import matplotlib.pyplot as plt
import maxflow
import numpy as np
import torch
from PIL import Image
from scipy import ndimage
from scipy.ndimage import zoom
from skimage import color, measure

from network import UNet
from utils import (add_countor, add_overlay, cropped_image, extends_points,
                   extreme_points, get_bbox, get_largest_two_component,
                   get_start_end_points, interaction_euclidean_distance,
                   interaction_gaussian_distance,
                   interaction_geodesic_distance,
                   interaction_refined_geodesic_distance,
                   itensity_normalization, itensity_normalize_one_volume,
                   itensity_standardization, softmax, softmax_seg, zoom_image)

rootPATH = os.path.abspath(".")


class Controler(object):
    seeds = 0
    extreme_points = 5
    foreground = 2
    background = 3
    imageName = "../mideepseg/logo.png"
    model_path = "../mideepseg/iter_15000.pth"

    def __init__(self):
        self.img = None
        self.step = 0
        self.image = None
        self.mask = None
        self.overlay = None
        self.seed_overlay = None
        self.segment_overlay = None
        self.extreme_point_seed = []
        self.background_seeds = []
        self.foreground_seeds = []
        self.current_overlay = self.seeds
        self.load_image(self.imageName)

        self.initial_seg = None
        self.initial_extreme_seed = None

    def initial_param(self):
        self.step = 0
        self.img = None
        self.image = None
        self.mask = None
        self.overlay = None
        self.seed_overlay = None
        self.segment_overlay = None
        self.extreme_point_seed = []
        self.background_seeds = []
        self.foreground_seeds = []
        self.current_overlay = self.seeds
        self.initial_seg = None
        self.initial_extreme_seed = None

    def load_image(self, filename):
        self.filename = filename
        self.initial_param()
        self.init_image = cv2.imread(filename)
        self.image = cv2.imread(filename)
        self.img = np.array(Image.open(filename).convert('L'))
        self.images = cv2.imread(filename)
        self.seed_overlay = np.zeros_like(self.image)
        self.segment_overlay = np.zeros_like(self.image)
        self.mask = None
        self.refined_clicks = 0
        self.refined_iterations = 0
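
    # Two-stage click handling (see the README "How to use"): while step == 0,
    # left clicks are collected as initial edge/extreme points; after the
    # first segmentation sets step = 1, left clicks mark under-segmented
    # (foreground) regions and right clicks mark over-segmented (background)
    # regions for the refinement stage.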

    def add_seed(self, x, y, type):
        if self.image is None:
            print('Please load an image before adding seeds.')
            return
        if type == self.background:
            if (x, y) not in self.background_seeds:
                self.background_seeds.append((x, y))
                cv2.rectangle(self.seed_overlay, (x - 1, y - 1),
                              (x + 1, y + 1), (255, 0, 255), 2)
        elif type == self.foreground:
            if (x, y) not in self.foreground_seeds:
                if self.step == 0:
                    self.extreme_point_seed.append((x, y))
                    cv2.rectangle(self.seed_overlay, (x - 1, y - 1),
                                  (x + 1, y + 1), (255, 255, 0), 2)
                if self.step == 1:
                    self.foreground_seeds.append((x, y))
                    cv2.rectangle(self.seed_overlay, (x - 1, y - 1),
                                  (x + 1, y + 1), (0, 0, 255), 2)
        if len(self.extreme_point_seed) == 1:
            import time
            self.stage1_begin = time.time()
        if len(self.background_seeds) > 0 or len(self.foreground_seeds) > 0:
            self.refined_clicks += 1

            if self.refined_clicks == 1:
                import time
                self.stage2_begin = time.time()
        if self.refined_clicks == 0:
            import time
            self.stage2_begin = None

    def clear_seeds(self):
        self.step = 0
        self.background_seeds = []
        self.foreground_seeds = []
        self.extreme_point_seed = []
        self.background_superseeds = []
        self.foreground_superseeds = []
        self.seed_overlay = np.zeros_like(self.seed_overlay)
        self.image = self.init_image

    def get_image_with_overlay(self, overlayNumber):
        return cv2.addWeighted(self.image, 0.9, self.seed_overlay, 0.7, 0.7)

    def segment_show(self):
        pass

    def save_image(self, filename):
        if self.mask is None:
            print('Please segment the image before saving.')
            return
        self.mask = self.mask * 255
        cv2.imwrite(str(filename), self.mask.astype(int))

    def extreme_segmentation(self):
        if self.step == 0:
            seed = np.zeros_like(self.img)
            for i in self.extreme_point_seed:
                seed[i[1], i[0]] = 1
            if seed.sum() == 0:
                print('Please provide initial seeds for segmentation.')
                return
            seed = extends_points(seed)
            self.initial_extreme_seed = seed
            bbox = get_start_end_points(seed)
            cropped_img = cropped_image(self.img, bbox)
            x, y = cropped_img.shape
            normal_img = itensity_normalization(cropped_img)

            cropped_seed = cropped_image(seed, bbox)
            cropped_geos = interaction_geodesic_distance(
                normal_img, cropped_seed)
            # cropped_geos = itensity_normalization(cropped_geos)

            zoomed_img = zoom_image(normal_img)
            zoomed_geos = zoom_image(cropped_geos)

            inputs = np.asarray([[zoomed_img, zoomed_geos]])
            if torch.cuda.is_available():
                inputs = torch.from_numpy(inputs).float().cuda()
            else:
                inputs = torch.from_numpy(inputs).float().cpu()
            net = self.initial_model()
            net.eval()
            output = net(inputs)
            output = torch.softmax(output, dim=1)
            output = output.squeeze(0)
            predict = output.cpu().detach().numpy()
            fg_prob = predict[1]
            bg_prob = predict[0]

            crf_param = (5.0, 0.1)
            Prob = np.asarray([bg_prob, fg_prob])
            Prob = np.transpose(Prob, [1, 2, 0])
            fix_predict = maxflow.maxflow2d(zoomed_img.astype(
                np.float32), Prob, crf_param)

            fixed_predict = zoom(fix_predict, (x / 96, y / 96), output=None,
                                 order=0, mode='constant', cval=0.0, prefilter=True)
            # fixed_predict = zoom(fg_prob, (x / 96, y / 96), output=None,
            #                      order=0, mode='constant', cval=0.0, prefilter=True)

            pred = np.zeros_like(self.img, dtype=np.float32)

            pred[bbox[0]:bbox[2], bbox[1]:bbox[3]] = fixed_predict
            # Keep a copy, so that the in-place thresholding below does not
            # overwrite the stored probability map used by refined_seg.
            self.initial_seg = pred.copy()

            pred[pred >= 0.5] = 1
            pred[pred < 0.5] = 0

            strt = ndimage.generate_binary_structure(2, 1)
            seg = np.asarray(
                ndimage.morphology.binary_opening(pred, strt), np.uint8)
            seg = np.asarray(
                ndimage.morphology.binary_closing(seg, strt), np.uint8)
            seg = self.largestConnectComponent(seg)
            seg = ndimage.binary_fill_holes(seg)

            seg = np.clip(seg, 0, 255)
            seg = np.array(seg, np.uint8)

            contours, hierarchy = cv2.findContours(
                seg, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2:]
            if len(contours) != 0:
                image_data = cv2.drawContours(
                    self.image, contours, -1, (0, 255, 0), 2)

                self.image = image_data
            self.mask = seg
            self.step = 1
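
    # Note on the initial segmentation above: the clicked points are dilated
    # (extends_points), a bounding box around them crops the image, the crop
    # and its exponentialized geodesic distance map are resized to 96x96 and
    # fed to the CNN as a two-channel input, and the CNN output is regularized
    # by a CRF (maxflow) before being resized back to the original crop size.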

    def largestConnectComponent(self, img):
        binaryimg = img

        label_image, num = measure.label(
            binaryimg, background=0, return_num=True)
        areas = [r.area for r in measure.regionprops(label_image)]
        areas.sort()
        if len(areas) > 1:
            for region in measure.regionprops(label_image):
                if region.area < areas[-1]:
                    for coordinates in region.coords:
                        label_image[coordinates[0], coordinates[1]] = 0
        label_image = label_image.astype(np.int8)
        label_image[np.where(label_image > 0)] = 1
        return label_image

    def initial_model(self):
        model = UNet(2, 2, 16)
        if torch.cuda.is_available():
            model = model.cuda()
        else:
            model = model.cpu()
        # map_location lets the CUDA-trained checkpoint load on CPU-only machines.
        model.load_state_dict(torch.load(
            self.model_path,
            map_location='cuda' if torch.cuda.is_available() else 'cpu'))
        return model
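
    # Note on the refinement below: foreground/background clicks are turned
    # into exponentialized geodesic distance maps, fused with the initial
    # probabilities via an element-wise maximum, normalized with softmax, and
    # finally combined with the user seeds in an interactive CRF
    # (maxflow.interactive_maxflow2d) to produce the refined mask.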

    def refined_seg(self):
        fore_seeds = np.zeros_like(self.img)
        for i in self.foreground_seeds:
            fore_seeds[i[1], i[0]] = 1
        back_seeds = np.zeros_like(self.img)
        for i1 in self.background_seeds:
            back_seeds[i1[1], i1[0]] = 1

        fore_seeds = extends_points(fore_seeds)
        back_seeds = extends_points(back_seeds)

        all_refined_seeds = np.maximum(fore_seeds, back_seeds)
        all_seeds = np.maximum(all_refined_seeds, self.initial_extreme_seed)

        bbox = get_start_end_points(all_seeds)
        cropped_img = cropped_image(self.img, bbox)

        normal_img = itensity_standardization(cropped_img)
        init_seg = [self.initial_seg, 1.0 - self.initial_seg]
        fg_prob = init_seg[0]
        bg_prob = init_seg[1]

        cropped_initial_seg = cropped_image(fg_prob, bbox)
        cropped_fore_seeds = cropped_image(fore_seeds, bbox)

        cropped_fore_geos = interaction_refined_geodesic_distance(
            normal_img, cropped_fore_seeds)
        cropped_back_seeds = cropped_image(back_seeds, bbox)
        cropped_back_geos = interaction_refined_geodesic_distance(
            normal_img, cropped_back_seeds)

        fore_prob = np.maximum(cropped_fore_geos, cropped_initial_seg)

        cropped_back_seg = cropped_image(bg_prob, bbox)
        back_prob = np.maximum(cropped_back_geos, cropped_back_seg)

        crf_seeds = np.zeros_like(cropped_fore_seeds, np.uint8)
        crf_seeds[cropped_fore_seeds > 0] = 170
        crf_seeds[cropped_back_seeds > 0] = 255
        crf_param = (5.0, 0.1)

        crf_seeds = np.asarray([crf_seeds == 255, crf_seeds == 170], np.uint8)
        crf_seeds = np.transpose(crf_seeds, [1, 2, 0])

        x, y = fore_prob.shape
        prob_feature = np.zeros((2, x, y), dtype=np.float32)
        prob_feature[0] = fore_prob
        prob_feature[1] = back_prob
        softmax_feture = np.exp(prob_feature) / \
            np.sum(np.exp(prob_feature), axis=0)
        softmax_feture = np.exp(softmax_feture) / \
            np.sum(np.exp(softmax_feture), axis=0)
        fg_prob = softmax_feture[0].astype(np.float32)
        bg_prob = softmax_feture[1].astype(np.float32)

        Prob = np.asarray([bg_prob, fg_prob])
        Prob = np.transpose(Prob, [1, 2, 0])

        refined_pred = maxflow.interactive_maxflow2d(
            normal_img, Prob, crf_seeds, crf_param)

        pred = np.zeros_like(self.img, dtype=np.float32)
        pred[bbox[0]:bbox[2], bbox[1]:bbox[3]] = refined_pred

        pred = self.largestConnectComponent(pred)
        strt = ndimage.generate_binary_structure(2, 1)
        seg = np.asarray(
            ndimage.morphology.binary_opening(pred, strt), np.uint8)
        seg = np.asarray(
            ndimage.morphology.binary_closing(seg, strt), np.uint8)
        contours, hierarchy = cv2.findContours(
            seg, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2:]
        img = self.images.copy()
        image_data = cv2.drawContours(
            self.images, contours, -1, (0, 255, 0), 2)
        self.images = img

        self.image = image_data
        self.mask = seg

--------------------------------------------------------------------------------
/mideepseg/gui.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# Author: Xiangde Luo
# Date: 2 Sep., 2021
# Implementation of MIDeepSeg for interactive medical image segmentation and annotation.
# Reference:
#   X. Luo and G. Wang et al. MIDeepSeg: Minimally interactive segmentation of unseen objects
#   from medical images using deep learning. Medical Image Analysis, 2021. DOI: https://doi.org/10.1016/j.media.2021.102102.

import random
import sys
import time

import cv2
import numpy as np
from PyQt5.QtCore import *
from PyQt5.QtGui import *
from PyQt5.QtWidgets import *

from controler import Controler


class MIDeepSeg(QWidget):

    def __init__(self):
        super().__init__()

        self.graph_maker = Controler()
        self.seed_type = 1  # annotation type
        self.all_datasets = []

        self.initUI()

    def initUI(self):
        self.a = QApplication(sys.argv)

        self.window = QMainWindow()
        # Setup file menu
        self.window.setWindowTitle('MIDeepSeg')
        mainMenu = self.window.menuBar()
        fileMenu = mainMenu.addMenu('&File')

        openButton = QAction(QIcon('exit24.png'), 'Open Image', self.window)
        openButton.setShortcut('Ctrl+O')
        openButton.setStatusTip('Open a file for segmenting.')
        openButton.triggered.connect(self.on_open)
        fileMenu.addAction(openButton)

        saveButton = QAction(QIcon('exit24.png'), 'Save Image', self.window)
        saveButton.setShortcut('Ctrl+S')
        saveButton.setStatusTip('Save file to disk.')
        saveButton.triggered.connect(self.on_save)
        fileMenu.addAction(saveButton)

        closeButton = QAction(QIcon('exit24.png'), 'Exit', self.window)
        closeButton.setShortcut('Ctrl+Q')
        closeButton.setStatusTip('Exit application')
        closeButton.triggered.connect(self.on_close)
        fileMenu.addAction(closeButton)

        mainWidget = QWidget()

        annotationButton = QPushButton("Load Image")
        annotationButton.setStyleSheet("background-color:white")
        annotationButton.clicked.connect(self.on_open)

        segmentButton = QPushButton("Segment")
        segmentButton.setStyleSheet("background-color:white")
        segmentButton.clicked.connect(self.on_segment)

        refinementButton = QPushButton("Refinement")
        refinementButton.setStyleSheet("background-color:white")
        refinementButton.clicked.connect(self.on_refinement)
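
        # The buttons above map directly onto Controler methods: "Segment"
        # runs the CNN on the initial clicks, while "Refinement" fuses later
        # clicks with the initial prediction via geodesic distance maps and
        # an interactive CRF.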

        CleanButton = QPushButton("Clear all seeds")
        CleanButton.setStyleSheet("background-color:white")
        CleanButton.clicked.connect(self.on_clean)

        NextButton = QPushButton("Save segmentation")
        NextButton.setStyleSheet("background-color:white")
        NextButton.clicked.connect(self.on_save)

        StateLine = QLabel()
        StateLine.setText("Clicks as user input.")
        palette = QPalette()
        palette.setColor(StateLine.foregroundRole(), Qt.blue)
        StateLine.setPalette(palette)

        MethodLine = QLabel()
        MethodLine.setText("Segmentation.")
        mpalette = QPalette()
        mpalette.setColor(MethodLine.foregroundRole(), Qt.blue)
        MethodLine.setPalette(mpalette)

        SaveLine = QLabel()
        SaveLine.setText("Clean or Save.")
        spalette = QPalette()
        spalette.setColor(SaveLine.foregroundRole(), Qt.blue)
        SaveLine.setPalette(spalette)

        hbox = QVBoxLayout()
        hbox.addWidget(StateLine)
        hbox.addWidget(annotationButton)
        hbox.addWidget(MethodLine)
        hbox.addWidget(segmentButton)
        hbox.addWidget(refinementButton)
        hbox.addWidget(SaveLine)
        hbox.addWidget(CleanButton)
        hbox.addWidget(NextButton)
        hbox.addStretch()

        tipsFont = StateLine.font()
        tipsFont.setPointSize(10)
        StateLine.setFixedHeight(30)
        StateLine.setWordWrap(True)
        StateLine.setFont(tipsFont)
        MethodLine.setFixedHeight(30)
        MethodLine.setWordWrap(True)
        MethodLine.setFont(tipsFont)
        SaveLine.setFixedHeight(30)
        SaveLine.setWordWrap(True)
        SaveLine.setFont(tipsFont)

        self.seedLabel = QLabel()
        self.seedLabel.setPixmap(QPixmap.fromImage(
            self.get_qimage(self.graph_maker.get_image_with_overlay(self.graph_maker.seeds))))
        self.seedLabel.mousePressEvent = self.mouse_down
        self.seedLabel.mouseMoveEvent = self.mouse_drag

        imagebox = QHBoxLayout()
        imagebox.addWidget(self.seedLabel)

        vbox = QHBoxLayout()

        vbox.addLayout(imagebox)
        vbox.addLayout(hbox)

        mainWidget.setLayout(vbox)

        self.window.setCentralWidget(mainWidget)
        self.window.show()

    @staticmethod
    def get_qimage(cvimage):
        height, width, bytes_per_pix = cvimage.shape
        bytes_per_line = width * bytes_per_pix
        cv2.cvtColor(cvimage, cv2.COLOR_BGR2RGB, cvimage)
        return QImage(cvimage.data, width, height, bytes_per_line, QImage.Format_RGB888)

    def mouse_down(self, event):
        if event.button() == Qt.LeftButton:
            self.seed_type = 2
        elif event.button() == Qt.RightButton:
            self.seed_type = 3
        self.graph_maker.add_seed(event.x(), event.y(), self.seed_type)
        self.seedLabel.setPixmap(QPixmap.fromImage(
            self.get_qimage(self.graph_maker.get_image_with_overlay(self.graph_maker.seeds))))

    def mouse_drag(self, event):
        self.graph_maker.add_seed(event.x(), event.y(), self.seed_type)
        self.seedLabel.setPixmap(QPixmap.fromImage(
            self.get_qimage(self.graph_maker.get_image_with_overlay(self.graph_maker.seeds))))

    @pyqtSlot()
    def on_open(self):
        f = QFileDialog.getOpenFileName()
        if f[0] is not None and f[0] != "":
            f = f[0]
            self.graph_maker.load_image(str(f))
            self.seedLabel.setPixmap(QPixmap.fromImage(
                self.get_qimage(self.graph_maker.get_image_with_overlay(self.graph_maker.seeds))))
        else:
            pass

    @pyqtSlot()
    def on_save(self):
        print('Saving')
        f = QFileDialog.getSaveFileName()
        # getSaveFileName returns a (path, filter) tuple; use the path.
        if f[0] is not None and f[0] != "":
            self.graph_maker.save_image(str(f[0]))
        else:
            pass

    @pyqtSlot()
    def on_close(self):
        print('Closing')
        self.window.close()

    @pyqtSlot()
    def on_segment(self):
        self.graph_maker.extreme_segmentation()
        # The overlay argument is unused by get_image_with_overlay.
        self.seedLabel.setPixmap(QPixmap.fromImage(
            self.get_qimage(self.graph_maker.get_image_with_overlay(self.graph_maker.seeds))))

    @pyqtSlot()
    def on_clean(self):
        self.graph_maker.clear_seeds()
        self.seedLabel.setPixmap(QPixmap.fromImage(
            self.get_qimage(self.graph_maker.get_image_with_overlay(self.graph_maker.seeds))))

    @pyqtSlot()
    def on_refinement(self):
        self.graph_maker.refined_seg()
        self.seedLabel.setPixmap(QPixmap.fromImage(
            self.get_qimage(self.graph_maker.get_image_with_overlay(self.graph_maker.seeds))))

--------------------------------------------------------------------------------
/mideepseg/iter_15000.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/MIDeepSeg/a6eb18d45b9e7e5ed5a8805306c917c22aa25908/mideepseg/iter_15000.pth

--------------------------------------------------------------------------------
/mideepseg/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/MIDeepSeg/a6eb18d45b9e7e5ed5a8805306c917c22aa25908/mideepseg/logo.png

--------------------------------------------------------------------------------
/mideepseg/main.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# Author: Xiangde Luo
# Date: 2 Sep., 2021
# Implementation of MIDeepSeg for interactive medical image segmentation and annotation.
# Reference:
#   X. Luo and G. Wang et al. MIDeepSeg: Minimally interactive segmentation of unseen objects
#   from medical images using deep learning. Medical Image Analysis, 2021. DOI: https://doi.org/10.1016/j.media.2021.102102.

import os
import sys
from PyQt5.QtGui import *
from PyQt5.QtWidgets import *
from PyQt5.QtCore import *
from gui import MIDeepSeg


if __name__ == '__main__':
    app = QApplication(sys.argv)
    ex = MIDeepSeg()
    sys.exit(app.exec_())

--------------------------------------------------------------------------------
/mideepseg/network.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# Author: Xiangde Luo
# Date: 2 Sep., 2021
# Implementation of MIDeepSeg for interactive medical image segmentation and annotation.
# Reference:
#   X. Luo and G. Wang et al. MIDeepSeg: Minimally interactive segmentation of unseen objects
#   from medical images using deep learning. Medical Image Analysis, 2021. DOI: https://doi.org/10.1016/j.media.2021.102102.
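
# Architecture note: a small 2D U-Net with instance normalization. The GUI
# instantiates it as UNet(in_dim=2, out_dim=2, num_filter=16): two input
# channels (image crop + geodesic cue map) and two output classes.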

import torch
import torch.nn as nn


def conv_block(in_dim, out_dim, act_fn):
    model = nn.Sequential(
        nn.Conv2d(in_dim, out_dim, kernel_size=3, stride=1, padding=1),
        nn.InstanceNorm2d(out_dim),
        act_fn,
    )
    return model


def up_conv(in_dim, out_dim, act_fn):
    model = nn.Sequential(
        nn.ConvTranspose2d(in_dim, out_dim, kernel_size=3,
                           stride=2, padding=1, output_padding=1),
        nn.InstanceNorm2d(out_dim),
        act_fn,
    )
    return model


def double_conv_block(in_dim, out_dim, act_fn):
    model = nn.Sequential(
        conv_block(in_dim, out_dim, act_fn),
        conv_block(out_dim, out_dim, act_fn),
    )
    return model


def out_block(in_dim, out_dim):
    model = nn.Sequential(
        nn.Conv2d(in_dim, out_dim, kernel_size=1, stride=1, padding=0)
    )
    return model


class UNet(nn.Module):

    def __init__(self, in_dim, out_dim, num_filter):
        super(UNet, self).__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.num_filter = num_filter
        act_fn = nn.ReLU(inplace=True)

        self.down_1 = double_conv_block(self.in_dim, self.num_filter, act_fn)
        self.pool_1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.down_2 = double_conv_block(
            self.num_filter, self.num_filter * 2, act_fn)
        self.pool_2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.down_3 = double_conv_block(
            self.num_filter * 2, self.num_filter * 4, act_fn)
        self.pool_3 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)

        self.down_4 = double_conv_block(
            self.num_filter * 4, self.num_filter * 8, act_fn)
        self.pool_4 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)

        self.bridge = double_conv_block(
            self.num_filter * 8, self.num_filter * 16, act_fn)

        self.trans_1 = up_conv(self.num_filter * 16,
                               self.num_filter * 16, act_fn)
        self.up_1 = double_conv_block(
            self.num_filter * 24, self.num_filter * 8, act_fn)

        self.trans_2 = up_conv(self.num_filter * 8,
                               self.num_filter * 8, act_fn)
        self.up_2 = double_conv_block(
            self.num_filter * 12, self.num_filter * 4, act_fn)

        self.trans_3 = up_conv(self.num_filter * 4,
                               self.num_filter * 4, act_fn)
        self.up_3 = double_conv_block(
            self.num_filter * 6, self.num_filter * 2, act_fn)

        self.trans_4 = up_conv(self.num_filter * 2,
                               self.num_filter * 2, act_fn)
        self.up_4 = double_conv_block(
            self.num_filter * 3, self.num_filter, act_fn)

        self.out = out_block(self.num_filter, out_dim)

    def forward(self, x):
        down_1 = self.down_1(x)
        pool_1 = self.pool_1(down_1)
        down_2 = self.down_2(pool_1)
        pool_2 = self.pool_2(down_2)
        down_3 = self.down_3(pool_2)
        pool_3 = self.pool_3(down_3)
        down_4 = self.down_4(pool_3)
        pool_4 = self.pool_4(down_4)

        bridge = self.bridge(pool_4)

        trans_1 = self.trans_1(bridge)
        concat_1 = torch.cat([trans_1, down_4], dim=1)
        up_1 = self.up_1(concat_1)
        trans_2 = self.trans_2(up_1)
        concat_2 = torch.cat([trans_2, down_3], dim=1)
        up_2 = self.up_2(concat_2)
        trans_3 = self.trans_3(up_2)
        concat_3 = torch.cat([trans_3, down_2], dim=1)
        up_3 = self.up_3(concat_3)
        trans_4 = self.trans_4(up_3)
        concat_4 = torch.cat([trans_4, down_1], dim=1)
        up_4 = self.up_4(concat_4)
        out = self.out(up_4)

        return out
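
# Minimal shape check (illustrative sketch, not part of the original training
# code): each pooling halves 96 to 48/24/12/6, and the four transposed
# convolutions double it back, so a two-channel 96x96 input (image + geodesic
# cue map) yields a two-channel 96x96 output.
if __name__ == '__main__':
    net = UNet(in_dim=2, out_dim=2, num_filter=16)
    x = torch.randn(1, 2, 96, 96)
    with torch.no_grad():
        y = net(x)
    print(y.shape)  # expected: torch.Size([1, 2, 96, 96])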
--------------------------------------------------------------------------------
/mideepseg/utils.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# Author: Xiangde Luo
# Date: 2 Sep., 2021
# Implementation of MIDeepSeg for interactive medical image segmentation and annotation.
# Reference:
#   X. Luo and G. Wang et al. MIDeepSeg: Minimally interactive segmentation of unseen objects
#   from medical images using deep learning. Medical Image Analysis, 2021. DOI: https://doi.org/10.1016/j.media.2021.102102.

import os
import random
from os.path import join as opj

import cv2
import GeodisTK
import matplotlib.pyplot as plt
import numpy as np
import scipy
import torchvision.transforms as ts
import torchvision.transforms.functional as TF
from PIL import Image
from scipy import ndimage
from scipy.ndimage import zoom
from skimage import color, measure


def itensity_normalize_one_volume(volume):
    """
    Normalize the intensity of an ND volume based on the mean and std of the non-zero region.
    inputs:
        volume: the input ND volume
    outputs:
        out: the normalized ND volume
    """

    volume = (volume - volume.min()) / (volume.max() - volume.min())
    pixels = volume[volume > 0]
    mean = pixels.mean()
    std = pixels.std()
    out = (volume - mean) / std
    out_random = np.random.normal(0, 1, size=volume.shape)
    out[volume == 0] = out_random[volume == 0]
    out = out.astype(np.float32)
    return out


def extreme_points(mask, pert=0):
    def find_point(id_x, id_y, ids):
        sel_id = ids[0][random.randint(0, len(ids[0]) - 1)]
        return [id_x[sel_id], id_y[sel_id]]

    # List of coordinates of the mask
    inds_y, inds_x = np.where(mask > 0.5)

    # Find extreme points
    return np.array([find_point(inds_x, inds_y, np.where(inds_x <= np.min(inds_x) + pert)),  # left
                     find_point(inds_x, inds_y, np.where(
                         inds_x >= np.max(inds_x) - pert)),  # right
                     find_point(inds_x, inds_y, np.where(
                         inds_y <= np.min(inds_y) + pert)),  # top
                     find_point(inds_x, inds_y, np.where(
                         inds_y >= np.max(inds_y) - pert))  # bottom
                     ])


def get_bbox(mask, points=None, pad=0, zero_pad=False):
    if points is not None:
        inds = np.flip(points.transpose(), axis=0)
    else:
        inds = np.where(mask > 0)

    if inds[0].shape[0] == 0:
        return None

    if zero_pad:
        x_min_bound = -np.inf
        y_min_bound = -np.inf
        x_max_bound = np.inf
        y_max_bound = np.inf
    else:
        x_min_bound = 0
        y_min_bound = 0
        x_max_bound = mask.shape[1] - 1
        y_max_bound = mask.shape[0] - 1

    x_min = max(inds[1].min() - pad, x_min_bound)
    y_min = max(inds[0].min() - pad, y_min_bound)
    x_max = min(inds[1].max() + pad, x_max_bound)
    y_max = min(inds[0].max() + pad, y_max_bound)

    return x_min, y_min, x_max, y_max


def cropped_image(image, bbox, pixel=0):
    random_bbox = [bbox[0] - pixel, bbox[1] -
                   pixel, bbox[2] + pixel, bbox[3] + pixel]
    cropped = image[random_bbox[0]:random_bbox[2],
                    random_bbox[1]:random_bbox[3]]
    return cropped


def zoom_image(data):
    """
    Reshape image to 96*96 pixels.
    """
    x, y = data.shape
    zoomed_image = zoom(data, (96 / x, 96 / y))
    # zoomed_image = zoom(data, (128 / x, 128 / y))
    return zoomed_image


def extends_points(seed):
    if seed.sum() > 0:
        points = ndimage.distance_transform_edt(seed == 0)
        points[points > 2] = 0
        points[points > 0] = 1
    else:
        points = seed
    return points.astype(np.uint8)


def gaussian_kernel(d, bias=0, sigma=10):
    """
    Gaussian kernel.
    input:
        d: distance between each extreme point and every point in the volume
        bias: center of the kernel
        sigma: standard deviation, which can be thought of as an effective radius.
    """
    gaus_dis = (1 / (sigma * np.sqrt(2 * np.pi))) * \
        np.exp(- ((d - bias)**2 / (2 * sigma**2)))
    return gaus_dis
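
# Quick numeric check (illustrative): at d == bias the kernel peaks at
# 1 / (sigma * sqrt(2 * pi)), which is about 0.0399 for the default sigma=10,
# and the value halves roughly 11.8 pixels away (FWHM = 2.355 * sigma).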

def interaction_euclidean_distance(img, seed):
    if seed.sum() > 0:
        euc_dis = ndimage.distance_transform_edt(seed == 0)
    else:
        euc_dis = np.ones_like(seed, dtype=np.float32)
    euc_dis = cstm_normalize(euc_dis)
    return euc_dis


def interaction_gaussian_distance(img, seed, sigma=10, bias=0):
    if seed.sum() > 0:
        euc_dis = ndimage.distance_transform_edt(seed == 0)
        gaus_dis = gaussian_kernel(euc_dis, bias, sigma)
    else:
        gaus_dis = np.zeros_like(seed, dtype=np.float32)
    gaus_dis = cstm_normalize(gaus_dis)
    return gaus_dis


def interaction_geodesic_distance(img, seed, threshold=0):
    if seed.sum() > 0:
        # I = itensity_normalize_one_volume(img)
        I = np.asanyarray(img, np.float32)
        S = seed
        geo_dis = GeodisTK.geodesic2d_fast_marching(I, S)
        # geo_dis = GeodisTK.geodesic2d_raster_scan(I, S, 1.0, 2.0)
        if threshold > 0:
            geo_dis[geo_dis > threshold] = threshold
            geo_dis = geo_dis / threshold
        else:
            geo_dis = np.exp(-geo_dis)
    else:
        geo_dis = np.zeros_like(img, dtype=np.float32)
    return cstm_normalize(geo_dis)


def interaction_refined_geodesic_distance(img, seed, threshold=0):
    if seed.sum() > 0:
        # I = itensity_normalize_one_volume(img)
        I = np.asanyarray(img, np.float32)
        S = seed
        geo_dis = GeodisTK.geodesic2d_fast_marching(I, S)
        if threshold > 0:
            geo_dis[geo_dis > threshold] = threshold
            geo_dis = geo_dis / threshold
        else:
            geo_dis = np.exp(-geo_dis**2)
    else:
        geo_dis = np.zeros_like(img, dtype=np.float32)
    return geo_dis


def cstm_normalize(im, max_value=1.0):
    """
    Normalize image to range 0 - max_value.
    """
    imn = max_value * (im - im.min()) / max((im.max() - im.min()), 1e-8)
    return imn


def get_start_end_points(scribbles):
    points = np.where(scribbles != 0)
    minZidx = int(np.min(points[0]))
    maxZidx = int(np.max(points[0]))
    minXidx = int(np.min(points[1]))
    maxXidx = int(np.max(points[1]))
    start_end_points = [minZidx - 5, minXidx - 5, maxZidx + 5, maxXidx + 5]
    return start_end_points
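
# Illustrative sketch (this helper name is hypothetical, added for clarity):
# it composes the utilities above in the same order as
# Controler.extreme_segmentation to build the two-channel CNN input from an
# image and a click-seed mask.
def build_cnn_input(img, seed):
    seed = extends_points(seed)
    bbox = get_start_end_points(seed)
    crop_norm = itensity_normalization(cropped_image(img, bbox))
    geo = interaction_geodesic_distance(crop_norm, cropped_image(seed, bbox))
    # Shape (1, 2, 96, 96): batch of one, image channel + geodesic cue channel.
    return np.asarray([[zoom_image(crop_norm), zoom_image(geo)]])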

def add_countor(In, Seg, Color=(0, 255, 0)):
    Out = In.copy()
    [H, W] = In.size
    for i in range(H):
        for j in range(W):
            if i == 0 or i == H - 1 or j == 0 or j == W - 1:
                if Seg.getpixel((i, j)) != 0:
                    Out.putpixel((i, j), Color)
            elif (Seg.getpixel((i, j)) != 0 and
                  not(Seg.getpixel((i - 1, j)) != 0 and
                      Seg.getpixel((i + 1, j)) != 0 and
                      Seg.getpixel((i, j - 1)) != 0 and
                      Seg.getpixel((i, j + 1)) != 0)):
                Out.putpixel((i, j), Color)
    return Out


def add_overlay(image, seg_name, Color=(0, 255, 0)):
    seg = Image.open(seg_name).convert('L')
    seg = np.asarray(seg)
    if image.size[1] != seg.shape[0] or image.size[0] != seg.shape[1]:
        print('segmentation has been resized')
        # scipy.misc.imresize was removed from SciPy; cv2 provides the same
        # nearest-neighbour resize.
        seg = cv2.resize(seg, (image.size[0], image.size[1]),
                         interpolation=cv2.INTER_NEAREST)
    strt = ndimage.generate_binary_structure(2, 1)
    seg = np.asarray(ndimage.morphology.binary_opening(seg, strt), np.uint8)
    seg = np.asarray(ndimage.morphology.binary_closing(seg, strt), np.uint8)

    img_show = add_countor(image, Image.fromarray(seg), Color)
    strt = ndimage.generate_binary_structure(2, 1)
    seg = np.asarray(ndimage.morphology.binary_dilation(seg, strt), np.uint8)
    img_show = add_countor(img_show, Image.fromarray(seg), Color)
    return img_show


def get_largest_two_component(img, prt=False, threshold=None):
    s = ndimage.generate_binary_structure(3, 2)  # iterate structure
    labeled_array, numpatches = ndimage.label(img, s)  # labeling
    sizes = ndimage.sum(img, labeled_array, range(1, numpatches + 1))
    sizes_list = [sizes[i] for i in range(len(sizes))]
    sizes_list.sort()
    if prt:
        print("component size", sizes_list)
    if len(sizes) == 1:
        return img
    else:
        if threshold:
            out_img = np.zeros_like(img)
            for temp_size in sizes_list:
                if temp_size > threshold:
                    temp_lab = np.where(sizes == temp_size)[0] + 1
                    temp_cmp = labeled_array == temp_lab
                    out_img = (out_img + temp_cmp) > 0
            return out_img
        else:
            max_size1 = sizes_list[-1]
            max_size2 = sizes_list[-2]
            max_label1 = np.where(sizes == max_size1)[0] + 1
            max_label2 = np.where(sizes == max_size2)[0] + 1
            component1 = labeled_array == max_label1
            component2 = labeled_array == max_label2
            if prt:
                print(max_size2, max_size1, max_size2 / max_size1)
            if max_size2 * 10 > max_size1:
                component1 = (component1 + component2) > 0

            return component1


def softmax_seg(seg):
    m, n = seg.shape
    prob_feature = np.zeros((2, m, n), dtype=np.float32)
    prob_feature[0] = seg
    prob_feature[1] = 1.0 - seg
    softmax_feture = np.exp(prob_feature) / np.sum(np.exp(prob_feature), axis=0)
    fg_prob = softmax_feture[0].astype(np.float32)
    return fg_prob


def softmax(x):
    return np.exp(x) / np.sum(np.exp(x), axis=0).astype(np.float32)


def itensity_standardization(image):
    """
    Normalize the intensity of an image based on the mean and std of the non-zero region.
    inputs:
        image: the input image
    outputs:
        out: the normalized image
    """
    pixels = image[image > 0]
    mean = pixels.mean()
    std = pixels.std()
    out = (image - mean) / std
    out = out.astype(np.float32)
    return out


def itensity_normalization(image):
    out = (image - image.min()) / (image.max() - image.min())
    out = out.astype(np.float32)
    return out

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
scikit_image==0.17.2
matplotlib==3.3.4
torchvision==0.4.0
numpy==1.19.2
SimpleCRF==0.1.0
scipy==1.5.2
torch==1.2.0
GeodisTK==0.1.7
maxflow==0.0.1
Pillow==8.3.1
PyQt5==5.15.4
--------------------------------------------------------------------------------