├── bowel_coarseseg ├── tools │ ├── __init__.py │ └── loss.py ├── datasets │ ├── __init__.py │ └── BowelDatasetCoarseSeg.py ├── crop_ROI.py ├── loc_model.py ├── preprocessing.py ├── infer.py └── train.py ├── bowel_fineseg ├── tools │ ├── __init__.py │ ├── loss.py │ └── utils.py ├── datasets │ ├── __init__.py │ └── BowelDatasetFineSeg.py ├── arch │ ├── data.png │ ├── example.png │ ├── pipeline.png │ ├── seg_demo.png │ └── segmentors.png ├── create_boundary.py ├── create_skeleflux.py ├── infer.py ├── seg_model.py └── train.py └── README.md /bowel_coarseseg/tools/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /bowel_fineseg/tools/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /bowel_coarseseg/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /bowel_fineseg/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /bowel_fineseg/arch/data.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cwangrun/BowelNet/HEAD/bowel_fineseg/arch/data.png -------------------------------------------------------------------------------- /bowel_fineseg/arch/example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cwangrun/BowelNet/HEAD/bowel_fineseg/arch/example.png -------------------------------------------------------------------------------- /bowel_fineseg/arch/pipeline.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/cwangrun/BowelNet/HEAD/bowel_fineseg/arch/pipeline.png -------------------------------------------------------------------------------- /bowel_fineseg/arch/seg_demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cwangrun/BowelNet/HEAD/bowel_fineseg/arch/seg_demo.png -------------------------------------------------------------------------------- /bowel_fineseg/arch/segmentors.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cwangrun/BowelNet/HEAD/bowel_fineseg/arch/segmentors.png -------------------------------------------------------------------------------- /bowel_fineseg/create_boundary.py: -------------------------------------------------------------------------------- 1 | from matplotlib import pylab as plt 2 | import nibabel as nib 3 | from glob import glob 4 | from cc3d import connected_components 5 | import numpy as np 6 | import os, cv2 7 | import shutil 8 | import scipy 9 | from scipy import ndimage 10 | from skimage import morphology 11 | from scipy.ndimage import gaussian_filter 12 | from shutil import copy 13 | 14 | 15 | 16 | def mkdir(folder_path): 17 | if not os.path.exists(folder_path): 18 | os.makedirs(folder_path) 19 | 20 | 21 | 22 | files = glob('/mnt/c/chong/data/Bowel/crop_stage1_ROI_small/*/*/*/masks_crop.nii.gz') 23 | 24 | 25 | 26 | for iii, file in enumerate(files): 27 | 28 | 29 | # file = '/home/chong/Desktop/small bowel skeleton/masks_crop.nii.gz' 30 | 31 | 32 | mask_info = nib.load(file) 33 | mask_arr_ori = np.round(mask_info.get_fdata()).astype(np.uint8) 34 | 35 | mask_arr = mask_arr_ori.copy() 36 | 37 | erosion_bin_map = ndimage.binary_erosion(mask_arr.astype(int), structure=np.ones((3, 3, 1))) 38 | # erosion_bin_map = ndimage.binary_erosion(mask_arr.astype(int), 
structure=np.ones((3, 3, 3))) 39 | 40 | edge_map = mask_arr - erosion_bin_map.astype(int) 41 | 42 | edge_heatmap = np.zeros_like(mask_arr).astype(np.float32) 43 | for z in range(mask_arr.shape[2]): 44 | gray = edge_map[:, :, z].astype(np.float32).copy() 45 | edge_heatmap[:, :, z] = gaussian_filter(gray, sigma=1.0, truncate=2.0) # larger sigma, more blur 46 | 47 | save_path = file.replace('.nii.gz', '_edge_heatmap.nii.gz') 48 | nib.save(nib.Nifti1Image(edge_heatmap, header=mask_info.header, affine=mask_info.affine), save_path) 49 | 50 | save_path = file.replace('.nii.gz', '_edge.nii.gz') 51 | nib.save(nib.Nifti1Image(edge_map, header=mask_info.header, affine=mask_info.affine), save_path) 52 | 53 | break 54 | -------------------------------------------------------------------------------- /bowel_fineseg/tools/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def dice_loss(prob, target, epsilon=1e-7): 6 | """Computes the Sørensen–Dice loss. 7 | Args: 8 | target: a tensor of shape [B, 1, D, H, W]. 9 | prob: a tensor of shape [B, C, D, H, W]. Corresponds to 10 | the prob output of the model. 11 | epsilon: added to the denominator for numerical stability. 12 | Returns: 13 | dice: the Sørensen–Dice loss. 14 | """ 15 | # target_ = torch.cat((1 - target, target), dim=1) 16 | num_classes = prob.shape[1] 17 | target = F.one_hot(target.squeeze(1).long(), num_classes).permute(0, 4, 1, 2, 3) # [B, c, D, H, W] 18 | target = target.type(prob.type()) 19 | 20 | assert prob.size() == target.size(), "the size of predict and target must be equal." 21 | 22 | intersection = torch.sum(prob * target, dim=[0, 2, 3, 4]) 23 | union = torch.sum(prob + target, dim=[0, 2, 3, 4]) 24 | dice = (2. * intersection / (union + epsilon)).mean() # average over classes 25 | 26 | return 1 - dice 27 | 28 | 29 | def dice_loss_PL(prob, target, epsilon=1e-7): 30 | """Computes the Sørensen–Dice loss. 
31 | Args: 32 | target: a tensor of shape [B, C, D, H, W]. 33 | prob: a tensor of shape [B, C, D, H, W]. Corresponds to 34 | the prob output of the model. 35 | epsilon: added to the denominator for numerical stability. 36 | Returns: 37 | dice: the Sørensen–Dice loss. 38 | """ 39 | assert prob.size() == target.size(), "the size of predict and target must be equal." 40 | 41 | intersection = torch.sum(prob * target, dim=[0, 2, 3, 4]) 42 | union = torch.sum(prob + target, dim=[0, 2, 3, 4]) 43 | dice = (2. * intersection / (union + epsilon)).mean() # average over classes 44 | 45 | return 1 - dice 46 | 47 | 48 | def dice_score_metric(input, target, epsilon=1e-7): 49 | result = torch.argmax(input, dim=1).unsqueeze(1) 50 | intersection = torch.sum(result * target) 51 | union = torch.sum(result) + torch.sum(target) 52 | dice = 2. * intersection / (union + epsilon) 53 | return dice 54 | -------------------------------------------------------------------------------- /bowel_coarseseg/tools/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn.functional as F 4 | import torch.nn as nn 5 | 6 | 7 | def dice_loss_PL(prob, target, epsilon=1e-7): 8 | """Computes the Sørensen–Dice loss. 9 | Args: 10 | target: a tensor of shape [B, C, D, H, W]. 11 | prob: a tensor of shape [B, C, D, H, W]. Corresponds to 12 | the prob output of the model. 13 | epsilon: added to the denominator for numerical stability. 14 | Returns: 15 | dice: the Sørensen–Dice loss. 16 | """ 17 | assert prob.size() == target.size(), "the size of predict and target must be equal." 18 | 19 | intersection = torch.sum(prob * target, dim=[0, 2, 3, 4]) 20 | union = torch.sum(prob + target, dim=[0, 2, 3, 4]) 21 | dice = (2. 
* intersection / (union + epsilon)).mean() # average over classes 22 | 23 | return 1 - dice 24 | 25 | 26 | def dice_similarity(output, target, smooth=1e-7): 27 | """Computes the Dice similarity""" 28 | 29 | output = output.float() 30 | target = target.float() 31 | 32 | seg_channel = output.view(output.size(0), -1) # (batch, D*H*W) 33 | target_channel = target.view(target.size(0), -1) 34 | 35 | intersection = (seg_channel * target_channel).sum(-1) 36 | union = (seg_channel + target_channel).sum(-1) 37 | dice = (2. * intersection) / (union + smooth) 38 | 39 | return torch.mean(dice) 40 | 41 | 42 | def dice_score_partial(output, target, data_type): 43 | """Computes the Dice scores, given foreground classes""" 44 | 45 | assert data_type in ['fully_labeled', 'smallbowel', 'colon_sigmoid'] 46 | 47 | result = torch.argmax(output, dim=1, keepdim=True) 48 | if data_type == 'fully_labeled': 49 | valid_class = [1, 2, 3, 4, 5] 50 | if data_type == 'smallbowel': 51 | valid_class = [4] 52 | if data_type == 'colon_sigmoid': 53 | valid_class = [2, 3] 54 | 55 | total_dice = [] 56 | for c in valid_class: 57 | target_c = (target == c).long() 58 | output_c = (result == c).long() 59 | dice_c = dice_similarity(output_c, target_c) 60 | total_dice.append(dice_c.item()) 61 | return total_dice 62 | 63 | 64 | -------------------------------------------------------------------------------- /bowel_fineseg/create_skeleflux.py: -------------------------------------------------------------------------------- 1 | from matplotlib import pylab as plt 2 | import nibabel as nib 3 | from glob import glob 4 | from cc3d import connected_components 5 | import numpy as np 6 | import os, cv2 7 | import scipy 8 | from scipy import ndimage 9 | from skimage import morphology 10 | from scipy.ndimage import gaussian_filter 11 | from shutil import copy 12 | import scipy.ndimage as ndi 13 | 14 | 15 | 16 | def mkdir(folder_path): 17 | if not os.path.exists(folder_path): 18 | os.makedirs(folder_path) 19 | 20 | 21 | 
22 | files = glob('/mnt/c/chong/data/Bowel/crop_stage1_ROI_small/*/*/*/masks_crop.nii.gz') 23 | 24 | 25 | 26 | for iii, file in enumerate(files): 27 | 28 | 29 | # file = '/home/chong/Desktop/small bowel skeleton/masks_crop.nii.gz' 30 | 31 | 32 | mask_info = nib.load(file) 33 | mask_arr_ori = np.round(mask_info.get_fdata()).astype(np.uint8) 34 | 35 | skeleton = mask_arr_ori.copy() 36 | 37 | #################### 38 | # filtering for smoothing, larger sigma, more blur 39 | skeleton = ndi.gaussian_filter(skeleton.astype(np.float), sigma=(1.5, 1.5, 1.5), order=0, truncate=2.0) 40 | skeleton = (skeleton > 0.5).astype(np.uint8) 41 | ##################### 42 | 43 | skeleton = morphology.skeletonize_3d(skeleton).astype(np.uint8) 44 | 45 | save_path = file.replace('.nii.gz', '_skele.nii.gz') 46 | skeleton_dilate = ndimage.binary_dilation(skeleton, structure=np.ones((3, 3, 5))) 47 | nib.save(nib.Nifti1Image(skeleton_dilate, header=mask_info.header, affine=mask_info.affine), save_path) 48 | 49 | 50 | pseudo_one_mask = np.ones_like(mask_arr_ori).astype(np.uint8) 51 | pseudo_one_mask[skeleton == 1] = 0 52 | skeleton_dist, skeleton_index = ndi.distance_transform_edt(pseudo_one_mask, return_indices=True) # distance to the nearest zero point 53 | 54 | grid = np.indices(skeleton_dist.shape).astype(float) 55 | diff = grid - skeleton_index 56 | dist = np.sqrt(np.sum(diff ** 2, axis=0)) 57 | 58 | direction_0 = np.divide(diff[0, ...], dist + 1e-7) # avoid divide zero 59 | direction_1 = np.divide(diff[1, ...], dist + 1e-7) 60 | direction_2 = np.divide(diff[2, ...], dist + 1e-7) 61 | 62 | skeleton_dist[mask_arr_ori != 1] = 0 63 | direction_0[mask_arr_ori != 1] = 0 64 | direction_1[mask_arr_ori != 1] = 0 65 | direction_2[mask_arr_ori != 1] = 0 66 | 67 | save_path = file.replace('.nii.gz', '_skele_dist.nii.gz') 68 | nib.save(nib.Nifti1Image(skeleton_dist, header=mask_info.header, affine=mask_info.affine), save_path) 69 | 70 | save_path = file.replace('.nii.gz', '_skele_fluxx.nii.gz') 71 | 
nib.save(nib.Nifti1Image(direction_0, header=mask_info.header, affine=mask_info.affine), save_path) 72 | 73 | save_path = file.replace('.nii.gz', '_skele_fluxy.nii.gz') 74 | nib.save(nib.Nifti1Image(direction_1, header=mask_info.header, affine=mask_info.affine), save_path) 75 | 76 | save_path = file.replace('.nii.gz', '_skele_fluxz.nii.gz') 77 | nib.save(nib.Nifti1Image(direction_2, header=mask_info.header, affine=mask_info.affine), save_path) 78 | 79 | np_save = np.sqrt(direction_0 ** 2 + direction_1 ** 2 + direction_2 ** 2) 80 | nib.save(nib.Nifti1Image(np_save, header=mask_info.header, affine=mask_info.affine), file.replace('.nii.gz', '_skele_fluxmag.nii.gz')) 81 | 82 | 83 | break 84 | 85 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # BowelNet 2 | 3 | 4 | ### Pytorch code for the paper "BowelNet: Joint Semantic-Geometric Ensemble Learning for Bowel Segmentation from Both Partially and Fully Labeled CT Images" at IEEE TMI 2022. 5 | Email: chong.wang@adelaide.edu.au 6 | 7 | 8 | 9 | 10 | 11 | ## Introduction: 12 | 13 | The BowelNet is a two-stage coarse-to-fine framework for the sgmentation of entire bowel (i.e., duodenum, jejunum-ileum, colon, sigmoid, and rectum) from abdominal CT images. The first stage jointly localizes all types of the bowel, trained robustly on both partially and fully labeled samples (see examples below). 
The second stage finely segments each localized bowel type using geometric bowel representations and hybrid pseudo labels: 14 | 15 | [(1) Joint localization of the five bowel parts using both partially- and fully-labeled images](https://github.com/runningcw/BowelNet/tree/master/bowel_coarseseg) 16 | 17 | [(2) Fine segmentation of each part using geometric (i.e., boundary and skeleton) guidance](https://github.com/runningcw/BowelNet/tree/master/bowel_fineseg) 18 | 19 | 20 | Examples of fully (a) and partially (b, c) labeled training data used in our work: 21 | 22 |
23 | 24 |
25 | 26 | 27 | ## Dataset: 28 | 29 | We utilize a large private abdominal CT dataset that includes both partially and fully labeled segmentation annotations. The dataset is structured as follows: 30 | 31 | ``` 32 | BowelSegData 33 | ├── Fully_labeled_5C 34 | │ ├── abdomen 35 | │ │ ├── .nii.gz 36 | │ │ ... 37 | │ ├── male 38 | │ │ ├── .nii.gz 39 | │ │ ... 40 | │ └── female 41 | │ ├── .nii.gz 42 | │ ... 43 | ├── Colon_Sigmoid 44 | │ ├── abdomen 45 | │ │ ├── .nii.gz 46 | │ │ ... 47 | │ ├── male 48 | │ │ ├── .nii.gz 49 | │ │ ... 50 | │ └── female 51 | │ ├── .nii.gz 52 | │ ... 53 | └── Smallbowel 54 | ├── abdomen 55 | │ ├── .nii.gz 56 | │ ... 57 | ├── male 58 | │ ├── .nii.gz 59 | │ ... 60 | └── female 61 | ├── .nii.gz 62 | ... 63 | ``` 64 | 65 | ## Data Preprocessing: 66 | 67 | [Preprocessing](https://github.com/runningcw/BowelNet/blob/master/bowel_coarseseg/preprocessing.py) includes cropping the abdominal body region. We average all 2D CT slices of a volume to form a mean image and then apply thresholding to it to obtain the abdominal body region (excluding the CT bed). 68 | 69 | 70 | ## Demo segmentation: 71 | 72 | Our BowelNet demonstrates improved performance over prior approaches in entire-bowel segmentation.
73 | 74 | ![image](https://github.com/cwangrun/BowelNet/blob/master/bowel_fineseg/arch/seg_demo.png) 75 | 76 | 77 | ## Citation: 78 | ``` 79 | @article{wang2022bowelnet, 80 | title={BowelNet: Joint Semantic-Geometric Ensemble Learning for Bowel Segmentation From Both Partially and Fully Labeled CT Images}, 81 | author={Wang, Chong and Cui, Zhiming and Yang, Junwei and Han, Miaofei and Carneiro, Gustavo and Shen, Dinggang}, 82 | journal={IEEE Transactions on Medical Imaging}, 83 | volume={42}, 84 | number={4}, 85 | pages={1225--1236}, 86 | year={2022}, 87 | publisher={IEEE} 88 | } 89 | ``` 90 | -------------------------------------------------------------------------------- /bowel_coarseseg/crop_ROI.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import numpy as np 3 | from glob import glob 4 | import os.path 5 | import nibabel as nib 6 | import scipy 7 | from cc3d import connected_components 8 | 9 | 10 | name_to_label = {'rectum': 1, 'sigmoid': 2, 'colon': 3, 'small': 4, 'duodenum': 5} 11 | 12 | 13 | def img2box(img): 14 | r = np.any(img, axis=(1, 2)) 15 | c = np.any(img, axis=(0, 2)) 16 | z = np.any(img, axis=(0, 1)) 17 | rmin, rmax = np.where(r)[0][[0, -1]] 18 | cmin, cmax = np.where(c)[0][[0, -1]] 19 | zmin, zmax = np.where(z)[0][[0, -1]] 20 | return rmin, rmax, cmin, cmax, zmin, zmax 21 | 22 | 23 | def normalize_volume(img): 24 | img_array = (img - img.min()) / (img.max() - img.min()) 25 | return img_array 26 | 27 | 28 | def crop_ROI_using_coarse_seg_mask(mask_ori, bowel_name): 29 | assert bowel_name in name_to_label.keys() 30 | 31 | pos_label = name_to_label[bowel_name] 32 | mask = mask_ori.copy() 33 | mask[mask != pos_label] = 0 34 | mask[mask == pos_label] = 1 35 | 36 | # use the largest connected region to eliminate small noisy points, or use the mask directly 37 | labels_out, N = connected_components(mask, connectivity=26, return_N=True) 38 | numPix = [] 39 | for segid in range(1, 
N + 1): 40 | numPix.append([segid, (labels_out == segid).astype(np.int8).sum()]) 41 | numPix = np.array(numPix) 42 | 43 | if len(numPix) != 0: 44 | max_connected_image = np.int8(labels_out == numPix[np.argmax(numPix[:, 1]), 0]) 45 | min_x, max_x, min_y, max_y, min_z, max_z = img2box(max_connected_image) 46 | else: 47 | print('coarse stage does not detect the organ, will skip this case') 48 | return (None, None), (None, None), (None, None), None 49 | 50 | # extend 51 | x_extend, y_extend, z_extend = (30, 30, 30) 52 | 53 | max_x = min(max_x + x_extend, mask.shape[0]) 54 | max_y = min(max_y + y_extend, mask.shape[1]) 55 | max_z = min(max_z + z_extend, mask.shape[2]) 56 | min_x = max(min_x - x_extend, 0) 57 | min_y = max(min_y - y_extend, 0) 58 | min_z = max(min_z - z_extend, 0) 59 | 60 | return (min_x, max_x), (min_y, max_y), (min_z, max_z), mask 61 | 62 | 63 | if __name__ == '__main__': 64 | 65 | imgs = glob('/mnt/c/chong/data/Bowel/crop_ori_res/*/*.nii.gz') 66 | 67 | bowel_name = 'rectum' 68 | # bowel_name = 'sigmoid' 69 | # bowel_name = 'colon' 70 | # bowel_name = 'small' 71 | # bowel_name = 'duodenum' 72 | 73 | for iii, img_file in enumerate(imgs): 74 | 75 | img_info = nib.load(img_file) 76 | img = img_info.get_fdata() 77 | img = np.clip(img, -1024, 1000) 78 | img = normalize_volume(img) 79 | 80 | coarse_mask_file = img_file.replace('image.nii.gz', 'masks_partial_C5.nii.gz') 81 | assert os.path.isfile(coarse_mask_file) 82 | coarse_mask = nib.load(coarse_mask_file).get_fdata() 83 | coarse_mask = scipy.ndimage.interpolation.zoom((coarse_mask).astype(np.float32), 84 | zoom=np.array(img.shape) / np.array(coarse_mask.shape), 85 | mode='nearest', 86 | order=0) # order = 0 nearest interpolation 87 | coarse_mask = np.round(coarse_mask).astype(np.int32) 88 | 89 | crop_x, crop_y, crop_z, mask_return = crop_ROI_using_coarse_seg_mask(coarse_mask, bowel_name) 90 | 91 | if mask_return is None: 92 | continue 93 | 94 | img = img[crop_x[0]:crop_x[1], crop_y[0]:crop_y[1], 
crop_z[0]:crop_z[1]] 95 | mask_return = mask_return[crop_x[0]:crop_x[1], crop_y[0]:crop_y[1], crop_z[0]:crop_z[1]] 96 | 97 | save_dir = os.path.dirname(img_file.replace('crop_ori_res', 'crop_ori_res_' + bowel_name)) 98 | os.makedirs(save_dir, exist_ok=True) 99 | ############################################################################################ 100 | crop_coords_patient = [crop_x[0], ',', crop_x[1], ',', crop_y[0], ',', crop_y[1], ',', crop_z[0], ',', crop_z[1]] 101 | with open(save_dir + '/crop_info_' + bowel_name + '.txt', 'w') as file: 102 | file.writelines(map(lambda x: str(x), crop_coords_patient)) 103 | file.close() 104 | nib.save(nib.Nifti1Image(img, header=img_info.header, affine=img_info.affine), os.path.join(save_dir, 'image_crop.nii.gz')) 105 | nib.save(nib.Nifti1Image(mask_return, header=img_info.header, affine=img_info.affine), os.path.join(save_dir, 'masks_partial_C5_crop.nii.gz')) 106 | ############################################################################################ 107 | 108 | # print(iii, crop_x[1] - crop_x[0], crop_y[1] - crop_y[0], crop_z[1] - crop_z[0], img_file) 109 | 110 | -------------------------------------------------------------------------------- /bowel_coarseseg/loc_model.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | def passthrough(x, **kwargs): 8 | return x 9 | 10 | 11 | def ELUCons(elu, nchan): 12 | if elu: 13 | return nn.ELU(inplace=True) 14 | else: 15 | return nn.PReLU(nchan) 16 | 17 | # normalization between sub-volumes is necessary for good performance 18 | class ContBatchNorm3d(nn.modules.batchnorm._BatchNorm): 19 | def _check_input_dim(self, input): 20 | if input.dim() != 5: 21 | raise ValueError('expected 5D input (got {}D input)' 22 | .format(input.dim())) 23 | # super(ContBatchNorm3d, self)._check_input_dim(input) 24 | 25 | def forward(self, 
input): 26 | self._check_input_dim(input) 27 | return F.batch_norm( 28 | input, self.running_mean, self.running_var, self.weight, self.bias, 29 | True, self.momentum, self.eps) 30 | 31 | 32 | class LUConv(nn.Module): 33 | def __init__(self, nchan, elu): 34 | super(LUConv, self).__init__() 35 | self.relu1 = ELUCons(elu, nchan) 36 | self.conv1 = nn.Conv3d(nchan, nchan, kernel_size=5, padding=2) 37 | self.bn1 = ContBatchNorm3d(nchan) 38 | 39 | def forward(self, x): 40 | out = self.relu1(self.bn1(self.conv1(x))) 41 | return out 42 | 43 | 44 | def _make_nConv(nchan, depth, elu): 45 | layers = [] 46 | for _ in range(depth): 47 | layers.append(LUConv(nchan, elu)) 48 | return nn.Sequential(*layers) 49 | 50 | 51 | class InputTransition(nn.Module): 52 | def __init__(self, inChans, outChans, elu): 53 | super(InputTransition, self).__init__() 54 | self.conv1 = nn.Conv3d(inChans, outChans, kernel_size=5, padding=2) 55 | self.bn1 = ContBatchNorm3d(outChans) 56 | self.relu1 = ELUCons(elu, outChans) 57 | 58 | def forward(self, x): 59 | out = self.bn1(self.conv1(x)) 60 | x16 = torch.cat((x, x, x, x, x, x, x, x, x, x, x, x), 1) 61 | out = self.relu1(torch.add(out, x16)) 62 | return out 63 | 64 | 65 | class DownTransition(nn.Module): 66 | def __init__(self, inChans, nConvs, elu, dropout=False): 67 | super(DownTransition, self).__init__() 68 | outChans = 2*inChans 69 | self.down_conv = nn.Conv3d(inChans, outChans, kernel_size=2, stride=2) 70 | self.bn1 = ContBatchNorm3d(outChans) 71 | self.do1 = passthrough 72 | self.relu1 = ELUCons(elu, outChans) 73 | self.relu2 = ELUCons(elu, outChans) 74 | if dropout: 75 | self.do1 = nn.Dropout3d(p=0.2) 76 | self.ops = _make_nConv(outChans, nConvs, elu) 77 | 78 | def forward(self, x): 79 | down = self.relu1(self.bn1(self.down_conv(x))) 80 | out = self.do1(down) 81 | out = self.ops(out) 82 | out = self.relu2(torch.add(out, down)) 83 | return out 84 | 85 | 86 | class UpTransition(nn.Module): 87 | def __init__(self, inChans, outChans, nConvs, elu, 
dropout=False): 88 | super(UpTransition, self).__init__() 89 | self.up_conv = nn.ConvTranspose3d(inChans, outChans // 2, kernel_size=2, stride=2) 90 | self.bn1 = ContBatchNorm3d(outChans // 2) 91 | self.do1 = passthrough 92 | self.do2 = nn.Dropout3d(p=0.2) 93 | self.relu1 = ELUCons(elu, outChans // 2) 94 | self.relu2 = ELUCons(elu, outChans) 95 | if dropout: 96 | self.do1 = nn.Dropout3d(p=0.2) 97 | self.ops = _make_nConv(outChans, nConvs, elu) 98 | 99 | def forward(self, x, skipx): 100 | out = self.do1(x) 101 | skipxdo = self.do2(skipx) 102 | out = self.relu1(self.bn1(self.up_conv(out))) 103 | xcat = torch.cat((out, skipxdo), 1) 104 | out = self.ops(xcat) 105 | out = self.relu2(torch.add(out, xcat)) 106 | return out 107 | 108 | 109 | class OutputTransition(nn.Module): 110 | def __init__(self, inChans, outChans, elu): 111 | super(OutputTransition, self).__init__() 112 | self.conv1 = nn.Conv3d(inChans, outChans, kernel_size=5, padding=2) 113 | self.bn1 = ContBatchNorm3d(outChans) 114 | self.relu1 = ELUCons(elu, outChans) 115 | self.conv2 = nn.Conv3d(outChans, outChans, kernel_size=1) 116 | 117 | def forward(self, x): 118 | out = self.relu1(self.bn1(self.conv1(x))) 119 | out = self.conv2(out) 120 | return out 121 | 122 | 123 | class BowelLocNet(nn.Module): 124 | 125 | def __init__(self, elu=True): 126 | super(BowelLocNet, self).__init__() 127 | num_baseC = 12 # 16 128 | self.in_tr = InputTransition(1, 1*num_baseC, elu) 129 | self.down_tr32 = DownTransition(1*num_baseC, 1, elu) 130 | self.down_tr64 = DownTransition(2*num_baseC, 2, elu) 131 | self.down_tr128 = DownTransition(4*num_baseC, 2, elu, dropout=True) 132 | self.down_tr256 = DownTransition(8*num_baseC, 2, elu, dropout=True) 133 | self.up_tr256 = UpTransition(16*num_baseC, 16*num_baseC, 2, elu, dropout=True) 134 | self.up_tr128 = UpTransition(16*num_baseC, 8*num_baseC, 2, elu, dropout=True) 135 | self.up_tr64 = UpTransition(8*num_baseC, 4*num_baseC, 1, elu) 136 | self.up_tr32 = UpTransition(4*num_baseC, 
2*num_baseC, 1, elu) 137 | self.out_tr = OutputTransition(2*num_baseC, 6, elu) 138 | 139 | def forward(self, x): 140 | out16 = self.in_tr(x) 141 | out32 = self.down_tr32(out16) 142 | out64 = self.down_tr64(out32) 143 | out128 = self.down_tr128(out64) 144 | out256 = self.down_tr256(out128) 145 | 146 | out = self.up_tr256(out256, out128) 147 | out = self.up_tr128(out, out64) 148 | out = self.up_tr64(out, out32) 149 | out = self.up_tr32(out, out16) 150 | out = self.out_tr(out) 151 | return F.softmax(out, dim=1) -------------------------------------------------------------------------------- /bowel_coarseseg/preprocessing.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | from matplotlib import pylab as plt 3 | import nibabel as nib 4 | import SimpleITK as sitk 5 | from glob import glob 6 | from cc3d import connected_components 7 | import numpy as np 8 | import os, cv2 9 | from scipy import ndimage 10 | from skimage import morphology 11 | from skimage.morphology import convex_hull_image, disk, binary_closing 12 | 13 | 14 | def mkdir(folder_path): 15 | if not os.path.exists(folder_path): 16 | os.makedirs(folder_path) 17 | 18 | 19 | def resample_img(image_file, out_spacing=[0.97, 0.97, 1.5], is_label=False): 20 | 21 | itk_image = sitk.ReadImage(image_file) 22 | 23 | original_spacing = itk_image.GetSpacing() 24 | original_size = itk_image.GetSize() 25 | 26 | out_size = [ 27 | int(np.round(original_size[0] * (original_spacing[0] / out_spacing[0]))), 28 | int(np.round(original_size[1] * (original_spacing[1] / out_spacing[1]))), 29 | int(np.round(original_size[2] * (original_spacing[2] / out_spacing[2])))] 30 | 31 | resample = sitk.ResampleImageFilter() 32 | resample.SetOutputSpacing(out_spacing) 33 | resample.SetSize(out_size) 34 | resample.SetOutputDirection(itk_image.GetDirection()) 35 | resample.SetOutputOrigin(itk_image.GetOrigin()) 36 | resample.SetTransform(sitk.Transform()) 37 | 
resample.SetDefaultPixelValue(itk_image.GetPixelIDValue()) 38 | 39 | if is_label: 40 | resample.SetInterpolator(sitk.sitkNearestNeighbor) 41 | else: 42 | resample.SetInterpolator(sitk.sitkBSpline) 43 | 44 | itk_new = resample.Execute(itk_image) 45 | 46 | return sitk.GetArrayFromImage(itk_new).transpose(2, 1, 0) 47 | 48 | 49 | def select_largest_region(img_bin): 50 | 51 | # N is the number of connected components 52 | labels_out, N = connected_components(img_bin, connectivity=26, return_N=True) 53 | num_labels = [] 54 | for segid in range(1, N + 1): 55 | extracted_image = labels_out * (labels_out == segid) 56 | num_labels.append(np.array([segid, extracted_image.sum()])) 57 | 58 | num_labels = np.array(num_labels) 59 | topk = np.argsort(num_labels, axis=0)[::-1] 60 | top1 = num_labels[topk[0][1]][0] 61 | 62 | largest_mask = np.zeros(img_bin.shape, dtype=np.uint8) 63 | largest_mask[labels_out == top1] = 255 64 | 65 | return largest_mask 66 | 67 | 68 | def crop_body(image): 69 | 70 | image = ((image - image.min()) / (image.max() - image.min())) * 255 71 | image = image.astype("uint8") 72 | blur = cv2.GaussianBlur(image, (5, 5), 0) 73 | thresh, mask = cv2.threshold(blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU) 74 | mask_body = select_largest_region(mask) # remove CT bed 75 | 76 | # plt.imshow(image) 77 | # plt.show() 78 | # plt.imshow(mask) 79 | # plt.show() 80 | # plt.imshow(mask_body) 81 | # plt.show() 82 | 83 | x_range, y_range = np.where(mask_body != 0) 84 | x_min, y_min = x_range.min(), y_range.min() 85 | x_max, y_max = x_range.max(), y_range.max() 86 | 87 | return x_min, x_max, y_min, y_max, mask_body 88 | 89 | 90 | files = glob('/mnt/c/chong/data/Bowel/alldata/small_bowel/*/*/image.nii.gz') # resampled data 91 | 92 | for iii, image_file in enumerate(files): 93 | 94 | label_file = image_file.replace('image.nii.gz', 'masks.nii.gz') 95 | 96 | image_info = nib.load(image_file) 97 | image_arr = image_info.get_fdata() 98 | 99 | label_info = nib.load(label_file) 
100 | label_arr = np.round(label_info.get_fdata()).astype(int) 101 | 102 | x, y, z = image_arr.shape 103 | 104 | image = image_arr[:, :, :].copy() 105 | image = np.clip(image, -1024, 1000) # remove abnormal intensity 106 | 107 | image_mean = np.mean(image, axis=2) 108 | body_x_min, body_x_max, body_y_min, body_y_max, mask_body = crop_body(image_mean) 109 | 110 | xs = body_x_min - 0 111 | xe = body_x_max + 0 112 | ys = body_y_min - 20 113 | ye = body_y_max + 0 114 | 115 | xs = 0 if xs < 0 else xs 116 | xe = x if xe > x else xe 117 | 118 | ys = 0 if ys < 0 else ys 119 | ye = y if ye > y else ye 120 | 121 | image_body = image_arr[xs:xe, ys:ye, :] 122 | masks_body = label_arr[xs:xe, ys:ye, :] 123 | 124 | print(iii, 'crop_shape', image_body.shape, 'ori_shape', image_arr.shape, image_file) 125 | 126 | # save preprocessed data 127 | ####################################################################################################### 128 | if not os.path.exists(os.path.dirname(image_file.replace('/alldata', '/crop_preprocessed'))): 129 | os.makedirs(os.path.dirname(image_file.replace('/alldata', '/crop_preprocessed'))) 130 | 131 | mask_save_path = os.path.dirname(os.path.dirname(image_file)).replace('/alldata', '/crop_preprocessed') 132 | cv2.imwrite(mask_save_path + '/' + image_file.split('/')[-2] + '.jpg', mask_body.transpose()) 133 | 134 | nib.save(nib.Nifti1Image(image_body, header=image_info.header, affine=image_info.affine), image_file.replace('/alldata', '/crop_preprocessed')) 135 | nib.save(nib.Nifti1Image(masks_body, header=label_info.header, affine=label_info.affine), label_file.replace('/alldata', '/crop_preprocessed')) 136 | ####################################################################################################### 137 | 138 | # save downsampled data 139 | ####################################################################################################### 140 | image_downsampled = image_body[::2, ::2, ::2] 141 | masks_downsampled = 
masks_body[::2, ::2, ::2] 142 | 143 | if not os.path.exists(os.path.dirname(image_file.replace('/alldata', '/crop_downsample'))): 144 | os.makedirs(os.path.dirname(image_file.replace('/alldata', '/crop_downsample'))) 145 | 146 | nib.save(nib.Nifti1Image(image_downsampled, header=image_info.header, affine=image_info.affine), image_file.replace('/alldata', '/crop_downsample')) 147 | nib.save(nib.Nifti1Image(masks_downsampled, header=label_info.header, affine=label_info.affine), label_file.replace('/alldata', '/crop_downsample')) 148 | ####################################################################################################### 149 | 150 | # break 151 | -------------------------------------------------------------------------------- /bowel_coarseseg/infer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | from __future__ import division 3 | import argparse 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch.autograd import Variable 8 | from scipy.ndimage.filters import gaussian_filter 9 | from datasets.BowelDatasetCoarseSeg import BowelCoarseSeg 10 | from tools.loss import * 11 | import os, math 12 | import numpy as np 13 | import nibabel as nib 14 | import loc_model 15 | 16 | 17 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' # "1, 2" 18 | print("GPU ID:", os.environ['CUDA_VISIBLE_DEVICES']) 19 | 20 | 21 | name_to_label = {'rectum': 1, 'sigmoid': 2, 'colon': 3, 'small': 4, 'duodenum': 5} 22 | 23 | 24 | def main(): 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument('--ngpu', type=int, default=1) 27 | parser.add_argument('--batch_size', type=int, default=8) # 8 (2 GPU) 28 | parser.add_argument('--start-epoch', default=0, type=int, help='manual epoch number (useful on restarts)') 29 | parser.add_argument('--nEpochs', type=int, default=1500, help='total training epoch') 30 | 31 | parser.add_argument('--lr', default=1e-3, type=float, help='learning 
rate') 32 | parser.add_argument('--weight_decay', default=1e-8, type=float, metavar='W', help='weight decay (default: 1e-8)') 33 | parser.add_argument('--eval_interval', default=5, type=int, help='evaluate interval on validation set') 34 | parser.add_argument('--save_dir', type=str, default=None) 35 | parser.add_argument('--seed', type=int, default=1) 36 | parser.add_argument('--deterministic', type=bool, default=True) 37 | args = parser.parse_args() 38 | 39 | 40 | 41 | print("build Bowel Localisation Network") 42 | model = loc_model.BowelLocNet(elu=False) 43 | 44 | model_path = "./exp/BowelLocNet.20230208_1536/partial_5C_dict_1300.pth" 45 | print('load checkpoint:', model_path) 46 | model.load_state_dict(torch.load(model_path)) 47 | 48 | model = model.cuda() 49 | model = nn.parallel.DataParallel(model) 50 | model.eval() 51 | 52 | 53 | print("loading fully_labeled dataset") 54 | fully_labeled_dir = '/mnt/c/chong/data/Bowel/crop_downsample/Fully_labeled_5C' 55 | testSet_fully_labeled = BowelCoarseSeg(fully_labeled_dir, mode="test", transform=False, dataset_name="fully_labeled", save_dir=None) 56 | 57 | print("loading small dataset") 58 | smallbowel_dir = '/mnt/c/chong/data/Bowel/crop_downsample/Smallbowel' 59 | testSet_small = BowelCoarseSeg(smallbowel_dir, mode="test", transform=False, dataset_name="smallbowel", save_dir=None) 60 | 61 | print("loading colon_sigmoid dataset") 62 | colon_sigmoid_dir = '/mnt/c/chong/data/Bowel/crop_downsample/Colon_Sigmoid/colon_sigmoid' 63 | testSet_colon_sigmoid = BowelCoarseSeg(colon_sigmoid_dir, mode="test", transform=False, dataset_name="colon_sigmoid", save_dir=None) 64 | 65 | 66 | patch_size = (64, 128, 128) 67 | overlap = (32, 64, 64) 68 | 69 | 70 | n_test_samples = 0. 
71 | dice_score = 0.0 72 | with torch.no_grad(): 73 | # for data_name in testSet_fully_labeled.imgs: 74 | for data_name in testSet_small.imgs: 75 | # for data_name in testSet_colon_sigmoid.imgs: 76 | 77 | n_test_samples = n_test_samples + 1 78 | 79 | img_info = nib.load(data_name) 80 | img = img_info.get_fdata() 81 | img = np.clip(img, -1024, 1000) 82 | img = normalize_volume(img) 83 | 84 | label = nib.load(data_name.replace('image.nii.gz', 'masks.nii.gz')).get_fdata() 85 | label = np.round(label) 86 | 87 | img_arr, label_arr, zs_pad, ze_pad = padding_z(img, label, min_z=100) 88 | img_arr, label_arr, xs_pad, xe_pad = padding_x(img_arr, label_arr, min_x=160) 89 | img_arr, label_arr, ys_pad, ye_pad = padding_y(img_arr, label_arr, min_y=160) 90 | 91 | img_arr = img_arr.transpose((2, 0, 1)) 92 | label_arr = label_arr.transpose((2, 0, 1)) 93 | 94 | data = torch.from_numpy(img_arr[np.newaxis, np.newaxis, :].astype(np.float32)) 95 | target = torch.from_numpy(label_arr[np.newaxis, np.newaxis, :].astype(np.int32)) 96 | 97 | data, target = Variable(data.cuda()), Variable(target.cuda()) 98 | 99 | b, _, z, x, y = data.shape 100 | 101 | zs = list(range(0, z, patch_size[0] - overlap[0])) 102 | xs = list(range(0, x, patch_size[1] - overlap[1])) 103 | ys = list(range(0, y, patch_size[2] - overlap[2])) 104 | 105 | gaussian_map = torch.from_numpy(_get_gaussian(patch_size, sigma_scale=1. 
/ 4)).cuda() 106 | output_all = torch.zeros((b, 6, z, x, y)).cuda() 107 | for zzz in zs: 108 | for xxx in xs: 109 | for yyy in ys: 110 | if xxx + patch_size[1] > x: 111 | xxx = x - patch_size[1] 112 | if yyy + patch_size[2] > y: 113 | yyy = y - patch_size[2] 114 | if zzz + patch_size[0] > z: 115 | zzz = z - patch_size[0] 116 | candidate_patch = data[:, :, zzz:zzz+patch_size[0], xxx:xxx+patch_size[1], yyy:yyy+patch_size[2]] 117 | output = model(candidate_patch) # b, 6, z, x, y prob output 118 | output_all[:, :, zzz:zzz+patch_size[0], xxx:xxx+patch_size[1], yyy:yyy+patch_size[2]] += output * gaussian_map 119 | # remove padded slice 120 | output_all = output_all[:, :, zs_pad:ze_pad, xs_pad:xe_pad, ys_pad:ye_pad] 121 | target = target[:, :, zs_pad:ze_pad, xs_pad:xe_pad, ys_pad:ye_pad] 122 | result = torch.argmax(output_all, dim=1, keepdim=True) 123 | 124 | 125 | # pos_cls = [1, 2, 3, 4, 5] # fully labeled dataset 126 | pos_cls = [4] # small bowel dataset 127 | # pos_cls = [2, 3] # colon_sigmoid dataset 128 | 129 | 130 | dice_c = [] 131 | for cls in pos_cls: 132 | dice_c.append(dice_similarity((result == cls).contiguous(), (target == cls).contiguous()).item()) 133 | dice_c = np.array(dice_c) 134 | dice_score += dice_c 135 | 136 | print('Name: {}, Dice: {}'.format(data_name, dice_c)) 137 | 138 | # np_save = result.squeeze().permute(1, 2, 0).cpu().numpy().astype(np.int32) 139 | # nib.save(nib.Nifti1Image(np_save, header=img_info.header, affine=img_info.affine), data_name.replace('image.nii.gz', 'masks_partial_C5.nii.gz')) 140 | 141 | dice_score /= n_test_samples 142 | 143 | print('\nMean test: Dice: {}\n'.format(dice_score)) 144 | 145 | 146 | def _get_gaussian(patch_size, sigma_scale=1. 
/ 8) -> np.ndarray: 147 | tmp = np.zeros(patch_size) 148 | center_coords = [i // 2 for i in patch_size] 149 | sigmas = [i * sigma_scale for i in patch_size] 150 | tmp[tuple(center_coords)] = 1 151 | gaussian_importance_map = gaussian_filter(tmp, sigmas, 0, mode='constant', cval=0) 152 | gaussian_importance_map = gaussian_importance_map / np.max(gaussian_importance_map) * 1 153 | gaussian_importance_map = gaussian_importance_map.astype(np.float32) 154 | 155 | # gaussian_importance_map cannot be 0, otherwise we may end up with nans! 156 | gaussian_importance_map[gaussian_importance_map == 0] = np.min( 157 | gaussian_importance_map[gaussian_importance_map != 0]) 158 | 159 | return gaussian_importance_map 160 | 161 | 162 | def normalize_volume(img): 163 | img_array = (img - img.min()) / (img.max() - img.min()) 164 | return img_array 165 | 166 | 167 | def padding_z(img, label, min_z): 168 | x, y, z = img.shape 169 | if z >= min_z: 170 | return img, label, 0, z 171 | else: 172 | num_pad = min_z - z 173 | num_top_pad = num_pad // 2 174 | top_pad = np.zeros((x, y, num_top_pad), dtype=np.float64) 175 | bottom_pad = np.zeros((x, y, num_pad - num_top_pad), dtype=np.float64) 176 | img = np.concatenate((bottom_pad, img, top_pad), axis=2) 177 | label = np.concatenate((bottom_pad, label, top_pad), axis=2) 178 | return img, label, bottom_pad.shape[2], img.shape[2] - top_pad.shape[2] 179 | 180 | 181 | def padding_x(img, label, min_x): 182 | x, y, z = img.shape 183 | if x >= min_x: 184 | return img, label, 0, x 185 | else: 186 | num_pad = min_x - x 187 | num_top_pad = num_pad // 2 188 | top_pad = np.zeros((num_top_pad, y, z), dtype=np.float64) 189 | bottom_pad = np.zeros((num_pad - num_top_pad, y, z), dtype=np.float64) 190 | img = np.concatenate((bottom_pad, img, top_pad), axis=0) 191 | label = np.concatenate((bottom_pad, label, top_pad), axis=0) 192 | return img, label, bottom_pad.shape[0], img.shape[0] - top_pad.shape[0] 193 | 194 | 195 | def padding_y(img, label, min_y): 196 | x, 
y, z = img.shape 197 | if y >= min_y: 198 | return img, label, 0, y 199 | else: 200 | num_pad = min_y - y 201 | num_top_pad = num_pad // 2 202 | top_pad = np.zeros((x, num_top_pad, z), dtype=np.float64) 203 | bottom_pad = np.zeros((x, num_pad - num_top_pad, z), dtype=np.float64) 204 | img = np.concatenate((bottom_pad, img, top_pad), axis=1) 205 | label = np.concatenate((bottom_pad, label, top_pad), axis=1) 206 | return img, label, bottom_pad.shape[1], img.shape[1] - top_pad.shape[1] 207 | 208 | 209 | if __name__ == '__main__': 210 | main() 211 | -------------------------------------------------------------------------------- /bowel_fineseg/infer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | from __future__ import division 3 | import argparse 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch.autograd import Variable 8 | import numpy as np 9 | from torch.utils.data import DataLoader 10 | from scipy.ndimage.filters import gaussian_filter 11 | from tools.loss import * 12 | from tools import utils 13 | import os, math 14 | import nibabel as nib 15 | import seg_model 16 | from datasets.BowelDatasetFineSeg import BowelFineSeg 17 | 18 | 19 | os.environ['CUDA_VISIBLE_DEVICES'] = '1' # "1, 2" 20 | print("GPU ID:", os.environ['CUDA_VISIBLE_DEVICES']) 21 | 22 | 23 | name_to_label = {'rectum': 1, 'sigmoid': 2, 'colon': 3, 'small': 4, 'duodenum': 5} 24 | 25 | 26 | def main(): 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument('--ngpu', type=int, default=1) 29 | parser.add_argument('--batch_size', type=int, default=4) # 4 (single GPU) 30 | parser.add_argument('--start-epoch', default=0, type=int, help='manual epoch number (useful on restarts)') 31 | parser.add_argument('--nEpochs_base', type=int, default=501, help='total epoch number for base segmentor') 32 | parser.add_argument('--nEpochs_meta', type=int, default=201, help='total epoch number for meta 
segmentor') 33 | 34 | parser.add_argument('--lr', default=1e-3, type=float, help='learning rate') 35 | parser.add_argument('--weight-decay', '--wd', default=1e-8, type=float, metavar='W', help='weight decay') 36 | parser.add_argument('--eval_interval', default=1, type=int, help='evaluate interval on validation set') 37 | parser.add_argument('--temp', default=0.7, type=float, help='temperature for meta segmentor') 38 | parser.add_argument('--save_dir', type=str, default=None) 39 | parser.add_argument('--seed', type=int, default=1) 40 | parser.add_argument('--deterministic', type=bool, default=True) 41 | args = parser.parse_args() 42 | 43 | 44 | print("build BowelNet") 45 | model = seg_model.BowelNet(elu=False) 46 | 47 | model_path = "./BowelNet.20230207_1744/rectum_meta_195.pth" 48 | print('load checkpoint:', model_path) 49 | model.load_state_dict(torch.load(model_path)) 50 | 51 | model = model.cuda() 52 | model = nn.parallel.DataParallel(model) 53 | model.eval() 54 | 55 | 56 | # data_dir = '/mnt/c/chong/data/Bowel/crop_stage1_ROI_small/' 57 | # bowel_name = 'small' 58 | 59 | # data_dir = '/mnt/c/chong/data/Bowel/crop_stage1_ROI_colon/' 60 | # bowel_name = 'colon' 61 | 62 | # data_dir = '/mnt/c/chong/data/Bowel/crop_stage1_ROI_sigmoid/' 63 | # bowel_name = 'sigmoid' 64 | 65 | # data_dir = '/mnt/c/chong/data/Bowel/crop_stage1_ROI_duodenum/' 66 | # bowel_name = 'duodenum' 67 | 68 | data_dir = '/mnt/c/chong/data/Bowel/crop_stage1_ROI_rectum/' 69 | bowel_name = 'rectum' 70 | 71 | 72 | print("loading test set") 73 | testSet = BowelFineSeg(root=data_dir, mode="test", transform=False, bowel_name=bowel_name, save_dir=None) 74 | 75 | if bowel_name == 'small': 76 | patch_size = (64, 192, 192) 77 | overlap = (32, 96, 96) 78 | if bowel_name == 'colon': 79 | patch_size = (64, 192, 192) 80 | overlap = (32, 96, 96) 81 | if bowel_name == 'sigmoid': 82 | patch_size = (64, 160, 160) 83 | overlap = (32, 80, 80) 84 | if bowel_name == 'duodenum': 85 | patch_size = (64, 160, 160) 86 | 
overlap = (32, 80, 80) 87 | if bowel_name == 'rectum': 88 | patch_size = (64, 96, 96) 89 | overlap = (32, 48, 48) 90 | 91 | n_test_samples = 0. 92 | dice_score_meta = 0.0 93 | dice_score_edge = 0.0 94 | dice_score_skele = 0.0 95 | assd_score_meta = 0.0 96 | assd_score_edge = 0.0 97 | assd_score_skele = 0.0 98 | with torch.no_grad(): 99 | for iii, data_name in enumerate(testSet.imgs): 100 | 101 | n_test_samples = n_test_samples + 1 102 | 103 | img_info = nib.load(data_name) 104 | img = img_info.get_fdata() 105 | # img = np.clip(img, -1024, 1000) 106 | # img = normalize_volume(img) 107 | 108 | label = nib.load(data_name.replace('image_crop.nii.gz', 'masks_crop.nii.gz')).get_fdata() 109 | label = np.round(label) 110 | 111 | img_arr, label_arr, zs_pad, ze_pad = padding_z(img, label, min_z=80) 112 | img_arr, label_arr, xs_pad, xe_pad = padding_x(img_arr, label_arr, min_x=120) 113 | img_arr, label_arr, ys_pad, ye_pad = padding_y(img_arr, label_arr, min_y=120) 114 | 115 | img_arr = img_arr.transpose((2, 0, 1)) 116 | label_arr = label_arr.transpose((2, 0, 1)) 117 | 118 | data = torch.from_numpy(img_arr[np.newaxis, np.newaxis, :].astype(np.float32)) 119 | target = torch.from_numpy(label_arr[np.newaxis, np.newaxis, :].astype(np.int32)) 120 | data, target = Variable(data.cuda()), Variable(target.cuda()) 121 | 122 | b, _, z, x, y = data.shape 123 | 124 | zs = list(range(0, z, patch_size[0] - overlap[0])) 125 | xs = list(range(0, x, patch_size[1] - overlap[1])) 126 | ys = list(range(0, y, patch_size[2] - overlap[2])) 127 | 128 | gaussian_map = torch.from_numpy(_get_gaussian(patch_size, sigma_scale=1. 
/ 4)).cuda() 129 | out_meta_all = torch.zeros((b, 2, z, x, y)).cuda() 130 | out_edge_all = torch.zeros((b, 2, z, x, y)).cuda() 131 | out_skele_all = torch.zeros((b, 2, z, x, y)).cuda() 132 | for zzz in zs: 133 | for xxx in xs: 134 | for yyy in ys: 135 | if xxx + patch_size[1] > x: 136 | xxx = x - patch_size[1] 137 | if yyy + patch_size[2] > y: 138 | yyy = y - patch_size[2] 139 | if zzz + patch_size[0] > z: 140 | zzz = z - patch_size[0] 141 | candidate_patch = data[:, :, zzz:zzz + patch_size[0], xxx:xxx + patch_size[1], yyy:yyy + patch_size[2]] 142 | out_meta, out_edge, out_skele = model(candidate_patch, 'meta') # b, 2, z, x, y 143 | prob_meta = F.softmax(out_meta, dim=1) 144 | prob_edge = F.softmax(out_edge, dim=1) 145 | prob_skele = F.softmax(out_skele, dim=1) 146 | out_meta_all[:, :, zzz:zzz + patch_size[0], xxx:xxx + patch_size[1], yyy:yyy + patch_size[2]] += prob_meta * gaussian_map 147 | out_edge_all[:, :, zzz:zzz + patch_size[0], xxx:xxx + patch_size[1], yyy:yyy + patch_size[2]] += prob_edge * gaussian_map 148 | out_skele_all[:, :, zzz:zzz + patch_size[0], xxx:xxx + patch_size[1], yyy:yyy + patch_size[2]] += prob_skele * gaussian_map 149 | # remove padded slice 150 | out_meta_all = out_meta_all[:, :, zs_pad:ze_pad, xs_pad:xe_pad, ys_pad:ye_pad] 151 | out_edge_all = out_edge_all[:, :, zs_pad:ze_pad, xs_pad:xe_pad, ys_pad:ye_pad] 152 | out_skele_all = out_skele_all[:, :, zs_pad:ze_pad, xs_pad:xe_pad, ys_pad:ye_pad] 153 | target = target[:, :, zs_pad:ze_pad, xs_pad:xe_pad, ys_pad:ye_pad] 154 | 155 | dice_meta = dice_score_metric(out_meta_all, target) 156 | dice_edge = dice_score_metric(out_edge_all, target) 157 | dice_skele = dice_score_metric(out_skele_all, target) 158 | dice_score_meta += dice_meta 159 | dice_score_edge += dice_edge 160 | dice_score_skele += dice_skele 161 | 162 | assd_meta = utils.assd(torch.argmax(out_meta_all, dim=1).squeeze().permute(1, 2, 0).cpu().numpy(), target.squeeze().permute(1, 2, 0).cpu().numpy()) # x, y, z 163 | assd_edge = 
utils.assd(torch.argmax(out_edge_all, dim=1).squeeze().permute(1, 2, 0).cpu().numpy(), target.squeeze().permute(1, 2, 0).cpu().numpy()) 164 | assd_skele = utils.assd(torch.argmax(out_skele_all, dim=1).squeeze().permute(1, 2, 0).cpu().numpy(), target.squeeze().permute(1, 2, 0).cpu().numpy()) 165 | assd_score_meta += assd_meta 166 | assd_score_edge += assd_edge 167 | assd_score_skele += assd_skele 168 | 169 | print('Name: {}, Dice_meta:{:.4f}, Dice_edge:{:.4f}, Dice_skele:{:.4f}, ASSD_meta:{:.4f}, ASSD_edge:{:.4f}, ASSD_skele:{:.4f}'. 170 | format(data_name, dice_meta, dice_edge, dice_skele, assd_meta, assd_edge, assd_skele)) 171 | 172 | # np_save = torch.argmax(out_meta_all, dim=1).squeeze().permute(1, 2, 0).cpu().numpy().astype(np.int32) 173 | # np_save[np_save == 1] = name_to_label[bowel_name] 174 | # nib.save(nib.Nifti1Image(np_save, header=img_info.header, affine=img_info.affine), data_name.replace('image_crop.nii.gz', 'masks_crop_meta.nii.gz')) 175 | # 176 | # np_save = torch.argmax(out_edge_all, dim=1).squeeze().permute(1, 2, 0).cpu().numpy().astype(np.int32) 177 | # np_save[np_save == 1] = name_to_label[bowel_name] 178 | # nib.save(nib.Nifti1Image(np_save, header=img_info.header, affine=img_info.affine), data_name.replace('image_crop.nii.gz', 'masks_crop_edge.nii.gz')) 179 | # 180 | # np_save = torch.argmax(out_skele_all, dim=1).squeeze().permute(1, 2, 0).cpu().numpy().astype(np.int32) 181 | # np_save[np_save == 1] = name_to_label[bowel_name] 182 | # nib.save(nib.Nifti1Image(np_save, header=img_info.header, affine=img_info.affine), data_name.replace('image_crop.nii.gz', 'masks_crop_skele.nii.gz')) 183 | 184 | # break 185 | 186 | dice_score_meta /= n_test_samples 187 | dice_score_edge /= n_test_samples 188 | dice_score_skele /= n_test_samples 189 | assd_score_meta /= n_test_samples 190 | assd_score_edge /= n_test_samples 191 | assd_score_skele /= n_test_samples 192 | 193 | print('\nMean test: Dice_meta:{:.4f}, Dice_edge:{:.4f}, Dice_skele:{:.4f}, ' 194 | 
'ASSD_meta: {:.4f}, ASSD_edge: {:.4f}, ASSD_skele: {:.4f}\n'. 195 | format(dice_score_meta, dice_score_edge, dice_score_skele, 196 | assd_score_meta, assd_score_edge, assd_score_skele)) 197 | 198 | 199 | def _get_gaussian(patch_size, sigma_scale=1. / 8) -> np.ndarray: 200 | tmp = np.zeros(patch_size) 201 | center_coords = [i // 2 for i in patch_size] 202 | sigmas = [i * sigma_scale for i in patch_size] 203 | tmp[tuple(center_coords)] = 1 204 | gaussian_importance_map = gaussian_filter(tmp, sigmas, 0, mode='constant', cval=0) 205 | gaussian_importance_map = gaussian_importance_map / np.max(gaussian_importance_map) * 1 206 | gaussian_importance_map = gaussian_importance_map.astype(np.float32) 207 | 208 | # gaussian_importance_map cannot be 0, otherwise we may end up with nans! 209 | gaussian_importance_map[gaussian_importance_map == 0] = np.min( 210 | gaussian_importance_map[gaussian_importance_map != 0]) 211 | 212 | return gaussian_importance_map 213 | 214 | 215 | def normalize_volume(img): 216 | img_array = (img - img.min()) / (img.max() - img.min()) 217 | return img_array 218 | 219 | 220 | def padding_z(img, label, min_z): 221 | x, y, z = img.shape 222 | if z >= min_z: 223 | return img, label, 0, z 224 | else: 225 | num_pad = min_z - z 226 | num_top_pad = num_pad // 2 227 | top_pad = np.zeros((x, y, num_top_pad), dtype=np.float64) 228 | bottom_pad = np.zeros((x, y, num_pad - num_top_pad), dtype=np.float64) 229 | img = np.concatenate((bottom_pad, img, top_pad), axis=2) 230 | label = np.concatenate((bottom_pad, label, top_pad), axis=2) 231 | return img, label, bottom_pad.shape[2], img.shape[2] - top_pad.shape[2] 232 | 233 | 234 | def padding_x(img, label, min_x): 235 | x, y, z = img.shape 236 | if x >= min_x: 237 | return img, label, 0, x 238 | else: 239 | num_pad = min_x - x 240 | num_top_pad = num_pad // 2 241 | top_pad = np.zeros((num_top_pad, y, z), dtype=np.float64) 242 | bottom_pad = np.zeros((num_pad - num_top_pad, y, z), dtype=np.float64) 243 | img = 
np.concatenate((bottom_pad, img, top_pad), axis=0) 244 | label = np.concatenate((bottom_pad, label, top_pad), axis=0) 245 | return img, label, bottom_pad.shape[0], img.shape[0] - top_pad.shape[0] 246 | 247 | 248 | def padding_y(img, label, min_y): 249 | x, y, z = img.shape 250 | if y >= min_y: 251 | return img, label, 0, y 252 | else: 253 | num_pad = min_y - y 254 | num_top_pad = num_pad // 2 255 | top_pad = np.zeros((x, num_top_pad, z), dtype=np.float64) 256 | bottom_pad = np.zeros((x, num_pad - num_top_pad, z), dtype=np.float64) 257 | img = np.concatenate((bottom_pad, img, top_pad), axis=1) 258 | label = np.concatenate((bottom_pad, label, top_pad), axis=1) 259 | return img, label, bottom_pad.shape[1], img.shape[1] - top_pad.shape[1] 260 | 261 | 262 | if __name__ == '__main__': 263 | main() 264 | -------------------------------------------------------------------------------- /bowel_fineseg/seg_model.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | def passthrough(x, **kwargs): 8 | return x 9 | 10 | 11 | def ELUCons(elu, nchan): 12 | if elu: 13 | return nn.ELU(inplace=True) 14 | else: 15 | return nn.PReLU(nchan) 16 | 17 | 18 | # normalization between sub-volumes is necessary 19 | # for good performance 20 | class ContBatchNorm3d(nn.modules.batchnorm._BatchNorm): 21 | def _check_input_dim(self, input): 22 | if input.dim() != 5: 23 | raise ValueError('expected 5D input (got {}D input)' 24 | .format(input.dim())) 25 | # super(ContBatchNorm3d, self)._check_input_dim(input) 26 | 27 | def forward(self, input): 28 | self._check_input_dim(input) 29 | return F.batch_norm( 30 | input, self.running_mean, self.running_var, self.weight, self.bias, 31 | True, self.momentum, self.eps) 32 | 33 | 34 | class LUConv(nn.Module): 35 | def __init__(self, nchan, elu): 36 | super(LUConv, self).__init__() 37 | self.relu1 = 
ELUCons(elu, nchan) 38 | self.conv1 = nn.Conv3d(nchan, nchan, kernel_size=5, padding=2) 39 | self.bn1 = ContBatchNorm3d(nchan) 40 | 41 | def forward(self, x): 42 | out = self.relu1(self.bn1(self.conv1(x))) 43 | return out 44 | 45 | 46 | def _make_nConv(nchan, depth, elu): 47 | layers = [] 48 | for _ in range(depth): 49 | layers.append(LUConv(nchan, elu)) 50 | return nn.Sequential(*layers) 51 | 52 | 53 | class InputTransition(nn.Module): 54 | def __init__(self, inChans, outChans, elu): 55 | super(InputTransition, self).__init__() 56 | self.conv1 = nn.Conv3d(inChans, outChans, kernel_size=5, padding=2) 57 | self.bn1 = ContBatchNorm3d(outChans) 58 | self.relu1 = ELUCons(elu, outChans) 59 | 60 | def forward(self, x): 61 | out = self.relu1(self.bn1(self.conv1(x))) 62 | return out 63 | 64 | 65 | class DownTransition(nn.Module): 66 | def __init__(self, inChans, nConvs, elu, dropout=False): 67 | super(DownTransition, self).__init__() 68 | outChans = 2*inChans 69 | self.down_conv = nn.Conv3d(inChans, outChans, kernel_size=2, stride=2) 70 | self.bn1 = ContBatchNorm3d(outChans) 71 | self.do1 = passthrough 72 | self.relu1 = ELUCons(elu, outChans) 73 | self.relu2 = ELUCons(elu, outChans) 74 | if dropout: 75 | self.do1 = nn.Dropout3d(p=0.2) # 0.5 76 | self.ops = _make_nConv(outChans, nConvs, elu) 77 | 78 | def forward(self, x): 79 | down = self.relu1(self.bn1(self.down_conv(x))) 80 | out = self.do1(down) 81 | out = self.ops(out) 82 | out = self.relu2(torch.add(out, down)) 83 | return out 84 | 85 | 86 | class UpTransition(nn.Module): 87 | def __init__(self, inChans, outChans, nConvs, elu, dropout=False): 88 | super(UpTransition, self).__init__() 89 | self.up_conv = nn.ConvTranspose3d(inChans, outChans // 2, kernel_size=2, stride=2) 90 | self.bn1 = ContBatchNorm3d(outChans // 2) 91 | self.do1 = passthrough 92 | self.do2 = nn.Dropout3d(p=0.2) # 0.5 93 | self.relu1 = ELUCons(elu, outChans // 2) 94 | self.relu2 = ELUCons(elu, outChans) 95 | if dropout: 96 | self.do1 = 
nn.Dropout3d(p=0.2) # 0.5 97 | self.ops = _make_nConv(outChans, nConvs, elu) 98 | 99 | def forward(self, x, skipx): 100 | out = self.do1(x) 101 | skipxdo = self.do2(skipx) 102 | out = self.relu1(self.bn1(self.up_conv(out))) 103 | xcat = torch.cat((out, skipxdo), 1) 104 | out = self.ops(xcat) 105 | out = self.relu2(torch.add(out, xcat)) 106 | return out 107 | 108 | 109 | class OutputTransition(nn.Module): 110 | def __init__(self, inChans, outChans, elu): 111 | super(OutputTransition, self).__init__() 112 | self.conv1 = nn.Conv3d(16, 16, kernel_size=3, padding=1) 113 | self.bn1 = ContBatchNorm3d(16) 114 | self.relu1 = ELUCons(elu, 16) 115 | self.conv2 = nn.Conv3d(16, 16, kernel_size=1) 116 | 117 | self.conv3 = nn.Conv3d(inChans, 16, kernel_size=1) 118 | self.relu3 = ELUCons(elu, 16) 119 | 120 | self.conv5 = nn.Conv3d(16, 16, kernel_size=3, padding=1) 121 | self.bn5 = ContBatchNorm3d(16) 122 | self.relu5 = ELUCons(elu, 16) 123 | 124 | self.conv6 = nn.Conv3d(16, outChans, kernel_size=1) 125 | 126 | self.dropout = nn.Dropout3d(p=0.2) 127 | 128 | def forward(self, x): 129 | 130 | out_2 = self.conv2(self.relu1(self.bn1(self.conv1(x)))) 131 | out_2 = self.dropout(out_2) 132 | 133 | out_3 = self.relu3(self.conv3(out_2)) 134 | 135 | out_6 = self.conv6(self.relu5(self.bn5(self.conv5(out_3)))) 136 | 137 | return out_6 138 | 139 | 140 | 141 | class OutputTransition_Edge(nn.Module): 142 | def __init__(self, inChans, outChans, elu): 143 | super(OutputTransition_Edge, self).__init__() 144 | self.conv1 = nn.Conv3d(16, 16, kernel_size=3, padding=1) 145 | self.bn1 = ContBatchNorm3d(16) 146 | self.relu1 = ELUCons(elu, 16) 147 | self.conv2 = nn.Conv3d(16, 16, kernel_size=1) 148 | 149 | self.conv3 = nn.Conv3d(inChans, 16, kernel_size=1) 150 | self.relu3 = ELUCons(elu, 16) 151 | 152 | self.conv5 = nn.Conv3d(16, 16, kernel_size=3, padding=1) 153 | self.bn5 = ContBatchNorm3d(16) 154 | self.relu5 = ELUCons(elu, 16) 155 | 156 | self.conv6 = nn.Conv3d(16, outChans, kernel_size=1) 157 | 158 | 
self.dropout = nn.Dropout3d(p=0.2) 159 | 160 | def forward(self, x): 161 | 162 | out_2 = self.conv2(self.relu1(self.bn1(self.conv1(x)))) 163 | out_2 = self.dropout(out_2) # added 164 | 165 | out_3 = self.relu3(self.conv3(out_2)) 166 | 167 | out_6 = self.conv6(self.relu5(self.bn5(self.conv5(out_3)))) 168 | 169 | return out_6 170 | 171 | 172 | 173 | class OutputTransition_Skele(nn.Module): 174 | def __init__(self, inChans, outChans, elu): 175 | super(OutputTransition_Skele, self).__init__() 176 | self.conv1 = nn.Conv3d(16, 16, kernel_size=3, padding=1) 177 | self.bn1 = ContBatchNorm3d(16) 178 | self.relu1 = ELUCons(elu, 16) 179 | self.conv2 = nn.Conv3d(16, 16, kernel_size=1) 180 | 181 | self.conv3 = nn.Conv3d(inChans, 16, kernel_size=1) 182 | 183 | self.conv5 = nn.Conv3d(16, 16, kernel_size=3, padding=1) 184 | self.bn5 = ContBatchNorm3d(16) 185 | 186 | self.conv6 = nn.Conv3d(16, outChans, kernel_size=1) 187 | 188 | self.dropout = nn.Dropout3d(p=0.2) 189 | 190 | def forward(self, x): 191 | 192 | out_2 = self.conv2(self.relu1(self.bn1(self.conv1(x)))) 193 | out_2 = self.dropout(out_2) # added 194 | 195 | out_3 = self.conv3(out_2) 196 | 197 | out_6 = self.conv6(self.bn5(self.conv5(out_3))) 198 | 199 | return out_6 200 | 201 | 202 | class MetaTransition_In(nn.Module): 203 | def __init__(self, inChans, outChans, elu): 204 | super(MetaTransition_In, self).__init__() 205 | self.conv1 = nn.Conv3d(inChans, outChans, kernel_size=3, padding=1) 206 | self.bn1 = ContBatchNorm3d(outChans) 207 | self.relu1 = ELUCons(elu, outChans) 208 | 209 | self.conv2 = nn.Conv3d(outChans, outChans // 2, kernel_size=3, padding=1) 210 | self.bn2 = ContBatchNorm3d(outChans // 2) 211 | self.relu2 = ELUCons(elu, outChans // 2) 212 | 213 | self.conv3 = nn.Conv3d(outChans//2, outChans // 2, kernel_size=1) 214 | 215 | self.dropout = nn.Dropout3d(p=0.2) 216 | 217 | def forward(self, x): 218 | 219 | out_1 = self.relu1(self.bn1(self.conv1(x))) 220 | out_2 = 
self.dropout(self.relu2(self.bn2(self.conv2(out_1)))) 221 | out_3 = self.conv3(out_2) 222 | 223 | return out_3 224 | 225 | 226 | class MetaTransition_Fused(nn.Module): 227 | def __init__(self, inChans, outChans, elu): 228 | super(MetaTransition_Fused, self).__init__() 229 | self.conv1 = nn.Conv3d(inChans, inChans, kernel_size=3, padding=1) 230 | self.bn1 = ContBatchNorm3d(inChans) 231 | self.relu1 = ELUCons(elu, inChans) 232 | 233 | self.conv2 = nn.Conv3d(inChans, inChans // 2, kernel_size=3, padding=1) 234 | self.bn2 = ContBatchNorm3d(inChans // 2) 235 | self.relu2 = ELUCons(elu, inChans // 2) 236 | 237 | self.conv3 = nn.Conv3d(inChans // 2, outChans, kernel_size=1) 238 | 239 | self.dropout = nn.Dropout3d(p=0.2) 240 | 241 | def forward(self, x): 242 | 243 | out_1 = self.relu1(self.bn1(self.conv1(x))) 244 | out_1 = self.dropout(out_1) 245 | out_2 = self.relu2(self.bn2(self.conv2(out_1))) 246 | out_3 = self.conv3(out_2) 247 | 248 | return out_3 249 | 250 | 251 | class BowelNet(nn.Module): 252 | def __init__(self, elu=True): 253 | super(BowelNet, self).__init__() 254 | 255 | num_baseC = 4 # 8 256 | 257 | self.in_tr = InputTransition(1, 2 * num_baseC, elu) 258 | 259 | # skeleton segmentor 260 | ######################################################################### 261 | self.down_tr32_skele = DownTransition(2 * num_baseC, 1, elu) 262 | self.down_tr64_skele = DownTransition(4 * num_baseC, 2, elu) 263 | self.down_tr128_skele = DownTransition(8 * num_baseC, 2, elu, dropout=True) 264 | self.down_tr256_skele = DownTransition(16 * num_baseC, 2, elu, dropout=True) 265 | self.up_tr256_skele = UpTransition(32 * num_baseC, 32 * num_baseC, 2, elu, dropout=True) 266 | self.up_tr128_skele = UpTransition(32 * num_baseC, 16 * num_baseC, 2, elu, dropout=True) 267 | self.up_tr64_skele = UpTransition(16 * num_baseC, 8 * num_baseC, 1, elu) 268 | self.up_tr32_skele = UpTransition(8 * num_baseC, 4 * num_baseC, 1, elu) 269 | self.out_tr_skele = OutputTransition_Skele(4 * num_baseC, 3, 
elu) # no relu for negative skeleton flux value 270 | self.out_tr_skele_mask = OutputTransition(4 * num_baseC, 2, elu) 271 | 272 | # boundary segmentor 273 | ######################################################################### 274 | self.down_tr32_edge = DownTransition(2 * num_baseC, 1, elu) 275 | self.down_tr64_edge = DownTransition(4 * num_baseC, 2, elu) 276 | self.down_tr128_edge = DownTransition(8 * num_baseC, 2, elu, dropout=True) 277 | self.down_tr256_edge = DownTransition(16 * num_baseC, 2, elu, dropout=True) 278 | self.up_tr256_edge = UpTransition(32 * num_baseC, 32 * num_baseC, 2, elu, dropout=True) 279 | self.up_tr128_edge = UpTransition(32 * num_baseC, 16 * num_baseC, 2, elu, dropout=True) 280 | self.up_tr64_edge = UpTransition(16 * num_baseC, 8 * num_baseC, 1, elu) 281 | self.up_tr32_edge = UpTransition(8 * num_baseC, 4 * num_baseC, 1, elu) 282 | 283 | self.out_tr_edge = OutputTransition_Edge(4 * num_baseC, 1, elu) 284 | self.out_tr_edge_mask = OutputTransition(4 * num_baseC, 2, elu) 285 | 286 | # meta segmentor 287 | ######################################################################### 288 | self.in_seg_meta = MetaTransition_In(1, 4 * num_baseC, elu) 289 | self.in_img_meta = MetaTransition_In(1, 4 * num_baseC, elu) 290 | self.fused_mask_meta = MetaTransition_Fused(4 * num_baseC, 2, elu) 291 | 292 | 293 | def forward(self, x, train_segmentor): 294 | 295 | assert train_segmentor in ['base', 'meta'] 296 | 297 | out16 = self.in_tr(x) 298 | 299 | out32_skele = self.down_tr32_skele(out16) 300 | out64_skele = self.down_tr64_skele(out32_skele) 301 | out128_skele = self.down_tr128_skele(out64_skele) 302 | out256_skele = self.down_tr256_skele(out128_skele) 303 | out_skele = self.up_tr256_skele(out256_skele, out128_skele) 304 | out_skele = self.up_tr128_skele(out_skele, out64_skele) 305 | out_skele = self.up_tr64_skele(out_skele, out32_skele) 306 | out_skele = self.up_tr32_skele(out_skele, out16) # same resolution as input 307 | out_skele_seg = 
self.out_tr_skele(out_skele) 308 | out_skele_mask_seg = self.out_tr_skele_mask(out_skele) 309 | 310 | out32_edge = self.down_tr32_edge(out16) 311 | out64_edge = self.down_tr64_edge(out32_edge) 312 | out128_edge = self.down_tr128_edge(out64_edge) 313 | out256_edge = self.down_tr256_edge(out128_edge) 314 | out_edge = self.up_tr256_edge(out256_edge, out128_edge) 315 | out_edge = self.up_tr128_edge(out_edge, out64_edge) 316 | out_edge = self.up_tr64_edge(out_edge, out32_edge) 317 | out_edge = self.up_tr32_edge(out_edge, out16) # same resolution as input 318 | out_edge_seg = self.out_tr_edge(out_edge) 319 | out_edge_mask_seg = self.out_tr_edge_mask(out_edge) 320 | 321 | if train_segmentor == 'base': 322 | return out_skele_seg, out_skele_mask_seg, out_edge_seg, out_edge_mask_seg 323 | else: 324 | # in_seg_avg = 1.0*torch.argmax(0.5 * (F.softmax(out_edge_mask_seg, dim=1) + F.softmax(out_skele_mask_seg, dim=1)), dim=1).unsqueeze(1) 325 | in_seg_avg = 0.5 * (torch.argmax(out_edge_mask_seg, dim=1).unsqueeze(1) + torch.argmax(out_skele_mask_seg, dim=1).unsqueeze(1)) 326 | seg_feat = self.in_seg_meta(in_seg_avg) 327 | img_feat = self.in_img_meta(x) 328 | comb_feat = torch.cat((img_feat, seg_feat), dim=1) 329 | out = self.fused_mask_meta(comb_feat) 330 | return out, out_edge_mask_seg, out_skele_mask_seg -------------------------------------------------------------------------------- /bowel_coarseseg/datasets/BowelDatasetCoarseSeg.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import numpy as np 3 | import torch 4 | import torch.utils.data as data 5 | from glob import glob 6 | import os 7 | import os.path 8 | import nibabel as nib 9 | import random 10 | import cv2 11 | 12 | 13 | name_to_label = {1: 'rectum', 2: 'sigmoid', 3: 'colon', 4: 'small', 5: 'duodenum'} 14 | 15 | 16 | def split_dataset(dir, current_test, test_fraction, dataset_name, save_dir): 17 | 18 | test_split = [] 19 | train_split = [] 20 | 
21 | sub_folder = ['Fully_labeled_5C', 'Colon_Sigmoid', 'Smallbowel'] 22 | 23 | if dataset_name == 'fully_labeled': 24 | all_volumes = sorted(glob(os.path.join(dir, "*/*/image.nii.gz"))) 25 | test_num = int(len(all_volumes) * test_fraction) 26 | test_volumes = all_volumes[test_num * (current_test - 1): test_num * current_test] 27 | train_volumes = sorted(list(set(all_volumes) - set(test_volumes))) 28 | test_split.extend(test_volumes) 29 | train_split.extend(train_volumes) 30 | 31 | if dataset_name == 'smallbowel': 32 | all_volumes = sorted(glob(os.path.join(dir, "*/*/image.nii.gz"))) 33 | test_num = int(len(all_volumes) * test_fraction) 34 | test_volumes = all_volumes[test_num * (current_test - 1): test_num * current_test] 35 | train_volumes = sorted(list(set(all_volumes) - set(test_volumes))) 36 | test_split.extend(test_volumes) 37 | train_split.extend(train_volumes) 38 | 39 | if dataset_name == 'colon_sigmoid': 40 | all_volumes = sorted(glob(os.path.join(dir, "*/image.nii.gz"))) 41 | test_num = int(len(all_volumes) * test_fraction) 42 | test_volumes = all_volumes[test_num * (current_test - 1): test_num * current_test] 43 | train_volumes = sorted(list(set(all_volumes) - set(test_volumes))) 44 | test_split.extend(test_volumes) 45 | train_split.extend(train_volumes) 46 | 47 | if save_dir is not None: 48 | with open(os.path.join(save_dir, dataset_name + '_train_' + str(current_test) + ".txt"), 'w') as f: 49 | for i in train_split: 50 | f.write(i + '\n') 51 | 52 | with open(os.path.join(save_dir, dataset_name + '_test_' + str(current_test) + ".txt"), 'w') as f: 53 | for i in test_split: 54 | f.write(i + '\n') 55 | 56 | return train_split, test_split 57 | 58 | 59 | def load_image_and_label(img_file, transform=False): 60 | 61 | # img_file = '/mnt/c/chong/data/Bowel/crop_downsample/Colon_Sigmoid/colon_sigmoid/0005_male/image.nii.gz' 62 | 63 | img = nib.load(img_file).get_fdata() 64 | img = np.clip(img, -1024, 1000) # remove abnormal intensity 65 | img = 
normalize_volume(img) 66 | 67 | label = nib.load(img_file.replace('image.nii.gz', 'masks.nii.gz')).get_fdata() 68 | label = np.round(label) 69 | 70 | if transform: 71 | op = random.choice(['ori', 'rotate', 'crop']) 72 | if op == 'rotate': 73 | img, label = rotate(img, label, degree=random.uniform(-10, 10)) 74 | if op == 'crop': 75 | img, label = crop_resize(img, label, shift_size_x=10, shift_size_y=10) 76 | 77 | # zero padding 78 | img_arr, label_arr = padding_z(img, label, min_z=100) 79 | img_arr, label_arr = padding_x(img_arr, label_arr, min_x=160) 80 | img_arr, label_arr = padding_y(img_arr, label_arr, min_y=160) 81 | 82 | return img_arr.transpose((2, 0, 1)), label_arr.transpose((2, 0, 1)) 83 | 84 | 85 | def padding_z(img, label, min_z): 86 | x, y, z = img.shape 87 | if z >= min_z: 88 | return img, label 89 | else: 90 | num_pad = min_z - z 91 | num_top_pad = num_pad // 2 92 | top_pad = np.zeros((x, y, num_top_pad), dtype=np.float64) 93 | bottom_pad = np.zeros((x, y, num_pad - num_top_pad), dtype=np.float64) 94 | img = np.concatenate((bottom_pad, img, top_pad), axis=2) 95 | label = np.concatenate((bottom_pad, label, top_pad), axis=2) 96 | return img, label 97 | 98 | 99 | def padding_x(img, label, min_x): 100 | x, y, z = img.shape 101 | if x >= min_x: 102 | return img, label 103 | else: 104 | num_pad = min_x - x 105 | num_top_pad = num_pad // 2 106 | top_pad = np.zeros((num_top_pad, y, z), dtype=np.float64) 107 | bottom_pad = np.zeros((num_pad - num_top_pad, y, z), dtype=np.float64) 108 | img = np.concatenate((bottom_pad, img, top_pad), axis=0) 109 | label = np.concatenate((bottom_pad, label, top_pad), axis=0) 110 | return img, label 111 | 112 | 113 | def padding_y(img, label, min_y): 114 | x, y, z = img.shape 115 | if y >= min_y: 116 | return img, label 117 | else: 118 | num_pad = min_y - y 119 | num_top_pad = num_pad // 2 120 | top_pad = np.zeros((x, num_top_pad, z), dtype=np.float64) 121 | bottom_pad = np.zeros((x, num_pad - num_top_pad, z), dtype=np.float64) 
122 | img = np.concatenate((bottom_pad, img, top_pad), axis=1) 123 | label = np.concatenate((bottom_pad, label, top_pad), axis=1) 124 | return img, label 125 | 126 | 127 | def rotate(img_ori, label_ori, degree): 128 | height, width, depth = img_ori.shape 129 | matRotation = cv2.getRotationMatrix2D((width / 2, height / 2), degree, 1) 130 | 131 | imgRotation = np.zeros_like(img_ori) 132 | labelRotation = np.zeros_like(label_ori) 133 | for z in range(depth): 134 | imgRotation[:, :, z] = cv2.warpAffine(img_ori[:, :, z], matRotation, (width, height), flags=cv2.INTER_NEAREST, borderValue=0) 135 | temp = label_ori[:, :, z] 136 | unique_labels = np.unique(temp) 137 | result = np.zeros_like(temp, temp.dtype) 138 | for i, c in enumerate(unique_labels): 139 | res_new = cv2.warpAffine((temp == c).astype(float), matRotation, (width, height), flags=cv2.INTER_NEAREST) 140 | result[res_new > 0.5] = c 141 | labelRotation[:, :, z] = result 142 | return imgRotation, labelRotation 143 | 144 | 145 | def crop_resize(img_ori, label_ori, shift_size_x, shift_size_y): 146 | H, W, C = img_ori.shape 147 | x_small = np.random.randint(0, shift_size_x) 148 | x_large = np.random.randint(H - shift_size_x, H) 149 | y_small = np.random.randint(0, shift_size_y) 150 | y_large = np.random.randint(W - shift_size_y, W) 151 | 152 | imgCropresize = np.zeros_like(img_ori) 153 | labelCropresize = np.zeros_like(label_ori) 154 | for z in range(C): 155 | imgCropresize[:, :, z] = cv2.resize(img_ori[x_small:x_large, y_small:y_large, z], (W, H), interpolation=cv2.INTER_NEAREST) 156 | temp = label_ori[x_small:x_large, y_small:y_large, z] 157 | unique_labels = np.unique(temp) 158 | result = np.zeros((H, W), label_ori.dtype) 159 | for i, c in enumerate(unique_labels): 160 | res_new = cv2.resize((temp == c).astype(float), (W, H), interpolation=cv2.INTER_NEAREST) 161 | result[res_new > 0.5] = c 162 | labelCropresize[:, :, z] = result 163 | return imgCropresize, labelCropresize 164 | 165 | 166 | def 
normalize_volume(img): 167 | img_array = (img - img.min()) / (img.max() - img.min()) 168 | return img_array 169 | 170 | 171 | class BowelCoarseSeg(data.Dataset): 172 | def __init__(self, root='', transform=None, mode="train", test_fraction=0.2, dataset_name='', save_dir=''): 173 | 174 | assert dataset_name in ["fully_labeled", "smallbowel", "colon_sigmoid"] 175 | 176 | current_test = 5 177 | train_split, test_split = split_dataset(root, current_test, test_fraction, dataset_name, save_dir) 178 | 179 | if mode == "infer" or mode == "test": 180 | self.imgs = test_split 181 | else: 182 | self.imgs = train_split 183 | 184 | self.mode = mode 185 | self.root = root 186 | self.patch_size = (64, 128, 128) # z, x, y 187 | self.transform = transform 188 | 189 | def __len__(self): 190 | return len(self.imgs) 191 | 192 | def __getitem__(self, index): 193 | patch_size = self.patch_size 194 | image_name = self.imgs[index] 195 | 196 | # image_name = '/mnt/c/chong/data/Bowel/crop_downsample/Smallbowel/abdomen/260719/image.nii.gz' 197 | # image_name = '/mnt/c/chong/data/Bowel/crop_downsample/Fully_labeled_5C/rectum_sigmoid_colon_small_duodenum/20210310-092019-464/image.nii.gz' 198 | # image_name = '/mnt/c/chong/data/Bowel/crop_downsample/Colon_Sigmoid/colon_sigmoid/0013_female/image.nii.gz' 199 | 200 | image, label = load_image_and_label(image_name, transform=self.transform) 201 | 202 | positive_labels = sorted(np.unique(label)[1:].astype(int)) 203 | for pos in positive_labels: 204 | assert pos in [1, 2, 3, 4, 5], image_name + ", wrong class label!!!" 
205 | 206 | # name_to_label = {1: 'rectum', 2: 'sigmoid', 3: 'colon', 4: 'small', 5: 'duodenum'} 207 | 208 | z, x, y = image.shape 209 | # fully labeled and Colon_Sigmoid dataset 210 | if 'rectum' in image_name or 'duodenum' in image_name or 'sigmoid' in image_name: 211 | 212 | if np.random.choice([False, False, False, True]): 213 | force_fg = True 214 | else: 215 | force_fg = False 216 | 217 | if force_fg: 218 | 219 | pos_labels = sorted(np.unique(label)[1:].astype(int)) 220 | label_will_be_chosen = pos_labels + list(set(pos_labels) - set([3, 4])) # double chance to sample the 3 bowels above 221 | label_chosen = random.choice(label_will_be_chosen) 222 | 223 | z_range, x_range, y_range = np.where(label == label_chosen) 224 | z_min, x_min, y_min = z_range.min(), x_range.min(), y_range.min() 225 | z_max, x_max, y_max = z_range.max(), x_range.max(), y_range.max() 226 | center_z = random.randint(z_min, z_max) 227 | center_x = random.randint(x_min, x_max) 228 | center_y = random.randint(y_min, y_max) 229 | 230 | zs = center_z - patch_size[0] // 2 # start 231 | ze = center_z + patch_size[0] // 2 # end 232 | xs = center_x - patch_size[1] // 2 233 | xe = center_x + patch_size[1] // 2 234 | ys = center_y - patch_size[2] // 2 235 | ye = center_y + patch_size[2] // 2 236 | 237 | if zs < 0: 238 | zs = 0 239 | ze = zs + patch_size[0] 240 | if ze > z: 241 | ze = z 242 | zs = ze - patch_size[0] 243 | 244 | if xs < 0: 245 | xs = 0 246 | xe = xs + patch_size[1] 247 | if xe > x: 248 | xe = x 249 | xs = xe - patch_size[1] 250 | 251 | if ys < 0: 252 | ys = 0 253 | ye = ys + patch_size[2] 254 | if ye > y: 255 | ye = y 256 | ys = ye - patch_size[2] 257 | else: 258 | zs = random.randint(0, z - patch_size[0]) 259 | ze = zs + self.patch_size[0] 260 | xs = random.randint(0, x - patch_size[1]) 261 | xe = xs + self.patch_size[1] 262 | ys = random.randint(0, y - patch_size[2]) 263 | ye = ys + self.patch_size[2] 264 | 265 | else: # small bowel dataset 266 | 267 | if np.random.choice([False, 
False, False, True]): 268 | force_fg = True 269 | else: 270 | force_fg = False 271 | 272 | if force_fg: 273 | z_range, x_range, y_range = np.where(label != 0) 274 | z_min, x_min, y_min = z_range.min(), x_range.min(), y_range.min() 275 | z_max, x_max, y_max = z_range.max(), x_range.max(), y_range.max() 276 | center_z = random.randint(z_min, z_max) 277 | center_x = random.randint(x_min, x_max) 278 | center_y = random.randint(y_min, y_max) 279 | 280 | zs = center_z - patch_size[0] // 2 281 | ze = center_z + patch_size[0] // 2 282 | xs = center_x - patch_size[1] // 2 283 | xe = center_x + patch_size[1] // 2 284 | ys = center_y - patch_size[2] // 2 285 | ye = center_y + patch_size[2] // 2 286 | 287 | if zs < 0: 288 | zs = 0 289 | ze = zs + patch_size[0] 290 | if ze > z: 291 | ze = z 292 | zs = ze - patch_size[0] 293 | 294 | if xs < 0: 295 | xs = 0 296 | xe = xs + patch_size[1] 297 | if xe > x: 298 | xe = x 299 | xs = xe - patch_size[1] 300 | 301 | if ys < 0: 302 | ys = 0 303 | ye = ys + patch_size[2] 304 | if ye > y: 305 | ye = y 306 | ys = ye - patch_size[2] 307 | else: 308 | zs = random.randint(0, z - patch_size[0]) 309 | ze = zs + self.patch_size[0] 310 | xs = random.randint(0, x - patch_size[1]) 311 | xe = xs + self.patch_size[1] 312 | ys = random.randint(0, y - patch_size[2]) 313 | ye = ys + self.patch_size[2] 314 | 315 | data = torch.from_numpy(image[zs:ze, xs:xe, ys:ye][np.newaxis, :].astype(np.float32)) 316 | mask = torch.from_numpy(label[zs:ze, xs:xe, ys:ye][np.newaxis, :].astype(np.int32)) 317 | 318 | return data, mask, image_name 319 | 320 | -------------------------------------------------------------------------------- /bowel_fineseg/datasets/BowelDatasetFineSeg.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import numpy as np 3 | import torch 4 | import torch.utils.data as data 5 | from glob import glob 6 | import os 7 | import os.path 8 | import nibabel as nib 9 | import random 
10 | import cv2 11 | 12 | 13 | name_to_label = {'rectum': 1, 'sigmoid': 2, 'colon': 3, 'small': 4, 'duodenum': 5} 14 | 15 | 16 | def split_dataset(dir, current_test, test_fraction, bowel_name, save_dir): 17 | 18 | test_split = [] 19 | train_split = [] 20 | 21 | sub_folder = ['Fully_labeled_5C', 'Colon_Sigmoid', 'Smallbowel'] 22 | 23 | assert bowel_name in name_to_label.keys() 24 | 25 | if bowel_name == 'rectum': 26 | all_volumes = sorted(glob(os.path.join(dir, "Fully_labeled_5C/*/*/image_crop.nii.gz"))) 27 | test_num = int(len(all_volumes) * test_fraction) 28 | test_volumes = all_volumes[test_num * (current_test - 1): test_num * current_test] 29 | train_volumes = sorted(list(set(all_volumes) - set(test_volumes))) 30 | test_split.extend(test_volumes) 31 | train_split.extend(train_volumes) 32 | 33 | if bowel_name == 'sigmoid': 34 | all_volumes = sorted(glob(os.path.join(dir, "Colon_Sigmoid/*/*/image_crop.nii.gz"))) 35 | test_num = int(len(all_volumes) * test_fraction) 36 | test_volumes = all_volumes[test_num * (current_test - 1): test_num * current_test] 37 | train_volumes = sorted(list(set(all_volumes) - set(test_volumes))) 38 | test_split.extend(test_volumes) 39 | train_split.extend(train_volumes) 40 | all_volumes = sorted(glob(os.path.join(dir, "Fully_labeled_5C/*/*/image_crop.nii.gz"))) 41 | test_num = int(len(all_volumes) * test_fraction) 42 | test_volumes = all_volumes[test_num * (current_test - 1): test_num * current_test] 43 | train_volumes = sorted(list(set(all_volumes) - set(test_volumes))) 44 | test_split.extend(test_volumes) 45 | train_split.extend(train_volumes) 46 | 47 | if bowel_name == 'colon': 48 | all_volumes = sorted(glob(os.path.join(dir, "Colon_Sigmoid/*/*/image_crop.nii.gz"))) 49 | test_num = int(len(all_volumes) * test_fraction) 50 | test_volumes = all_volumes[test_num * (current_test - 1): test_num * current_test] 51 | train_volumes = sorted(list(set(all_volumes) - set(test_volumes))) 52 | test_split.extend(test_volumes) 53 | 
train_split.extend(train_volumes) 54 | all_volumes = sorted(glob(os.path.join(dir, "Fully_labeled_5C/*/*/image_crop.nii.gz"))) 55 | test_num = int(len(all_volumes) * test_fraction) 56 | test_volumes = all_volumes[test_num * (current_test - 1): test_num * current_test] 57 | train_volumes = sorted(list(set(all_volumes) - set(test_volumes))) 58 | test_split.extend(test_volumes) 59 | train_split.extend(train_volumes) 60 | 61 | if bowel_name == 'small': 62 | all_volumes = sorted(glob(os.path.join(dir, "Smallbowel/*/*/image_crop.nii.gz"))) 63 | test_num = int(len(all_volumes) * test_fraction) 64 | test_volumes = all_volumes[test_num * (current_test - 1): test_num * current_test] 65 | train_volumes = sorted(list(set(all_volumes) - set(test_volumes))) 66 | test_split.extend(test_volumes) 67 | train_split.extend(train_volumes) 68 | all_volumes = sorted(glob(os.path.join(dir, "Fully_labeled_5C/*/*/image_crop.nii.gz"))) 69 | test_num = int(len(all_volumes) * test_fraction) 70 | test_volumes = all_volumes[test_num * (current_test - 1): test_num * current_test] 71 | train_volumes = sorted(list(set(all_volumes) - set(test_volumes))) 72 | test_split.extend(test_volumes) 73 | train_split.extend(train_volumes) 74 | 75 | if bowel_name == 'duodenum': 76 | all_volumes = sorted(glob(os.path.join(dir, "Fully_labeled_5C/*/*/image_crop.nii.gz"))) 77 | test_num = int(len(all_volumes) * test_fraction) 78 | test_volumes = all_volumes[test_num * (current_test - 1): test_num * current_test] 79 | train_volumes = sorted(list(set(all_volumes) - set(test_volumes))) 80 | test_split.extend(test_volumes) 81 | train_split.extend(train_volumes) 82 | 83 | if save_dir is not None: 84 | with open(os.path.join(save_dir, bowel_name + '_train_' + str(current_test) + ".txt"), 'w') as f: 85 | for i in train_split: 86 | f.write(i + '\n') 87 | 88 | with open(os.path.join(save_dir, bowel_name + '_test_' + str(current_test) + ".txt"), 'w') as f: 89 | for i in test_split: 90 | f.write(i + '\n') 91 | 92 | return 
train_split, test_split 93 | 94 | 95 | def load_image_and_label(img_file, transform=False): 96 | 97 | img_info = nib.load(img_file) 98 | img = img_info.get_fdata() 99 | # img = np.clip(img, -1024, 1000) # remove abnormal intensity 100 | # img = normalize_volume(img) 101 | 102 | label = nib.load(img_file.replace('image_crop.nii.gz', 'masks_crop.nii.gz')).get_fdata() 103 | label = np.round(label) 104 | 105 | skele_x = nib.load(img_file.replace('image_crop.nii.gz', 'skele_crop_fluxx.nii.gz')).get_fdata() # skele_crop_fluxx.nii.gz 106 | skele_y = nib.load(img_file.replace('image_crop.nii.gz', 'skele_crop_fluxy.nii.gz')).get_fdata() # skele_crop_fluxy.nii.gz 107 | skele_z = nib.load(img_file.replace('image_crop.nii.gz', 'skele_crop_fluxz.nii.gz')).get_fdata() # skele_crop_fluxz.nii.gz 108 | edge = nib.load(img_file.replace('image_crop.nii.gz', 'edge_reg_crop.nii.gz')).get_fdata() # edge_heatmap.nii.gz 109 | 110 | if transform: 111 | op = random.choice(['ori', 'rotate', 'crop']) 112 | if op == 'rotate': 113 | img, label, edge, skele_x, skele_y, skele_z = rotate(img, label, edge, skele_x, skele_y, skele_z, degree=random.randint(-10, 10)) 114 | if op == 'crop': 115 | img, label, edge, skele_x, skele_y, skele_z = crop_resize(img, label, edge, skele_x, skele_y, skele_z, shift_size_x=10, shift_size_y=10) 116 | 117 | # zero padding in case that cropped ROI is smaller than patch size 118 | img_arr, label_arr, edge_arr, skele_x_arr, skele_y_arr, skele_z_arr = padding_z(img, label, edge, skele_x, skele_y, skele_z, min_z=80) 119 | img_arr, label_arr, edge_arr, skele_x_arr, skele_y_arr, skele_z_arr = padding_x(img_arr, label_arr, edge_arr, skele_x_arr, skele_y_arr, skele_z_arr, min_x=120) 120 | img_arr, label_arr, edge_arr, skele_x_arr, skele_y_arr, skele_z_arr = padding_y(img_arr, label_arr, edge_arr, skele_x_arr, skele_y_arr, skele_z_arr, min_y=120) 121 | 122 | return img_arr.transpose((2, 0, 1)), label_arr.transpose((2, 0, 1)), edge_arr.transpose((2, 0, 1)), 
skele_x_arr.transpose((2, 0, 1)), skele_y_arr.transpose((2, 0, 1)), skele_z_arr.transpose((2, 0, 1)) 123 | 124 | 125 | def padding_z(img, label, edge, skele_x, skele_y, skele_z, min_z): 126 | x, y, z = img.shape 127 | if z >= min_z: 128 | return img, label, edge, skele_x, skele_y, skele_z 129 | else: 130 | num_pad = min_z - z 131 | num_top_pad = num_pad // 2 132 | top_pad = np.zeros((x, y, num_top_pad), dtype=np.float64) 133 | bottom_pad = np.zeros((x, y, num_pad - num_top_pad), dtype=np.float64) 134 | img = np.concatenate((bottom_pad, img, top_pad), axis=2) 135 | label = np.concatenate((bottom_pad, label, top_pad), axis=2) 136 | edge = np.concatenate((bottom_pad, edge, top_pad), axis=2) 137 | skele_x = np.concatenate((bottom_pad, skele_x, top_pad), axis=2) 138 | skele_y = np.concatenate((bottom_pad, skele_y, top_pad), axis=2) 139 | skele_z = np.concatenate((bottom_pad, skele_z, top_pad), axis=2) 140 | return img, label, edge, skele_x, skele_y, skele_z 141 | 142 | 143 | def padding_x(img, label, edge, skele_x, skele_y, skele_z, min_x): 144 | x, y, z = img.shape 145 | if x >= min_x: 146 | return img, label, edge, skele_x, skele_y, skele_z 147 | else: 148 | num_pad = min_x - x 149 | num_top_pad = num_pad // 2 150 | top_pad = np.zeros((num_top_pad, y, z), dtype=np.float64) 151 | bottom_pad = np.zeros((num_pad - num_top_pad, y, z), dtype=np.float64) 152 | img = np.concatenate((bottom_pad, img, top_pad), axis=0) 153 | label = np.concatenate((bottom_pad, label, top_pad), axis=0) 154 | edge = np.concatenate((bottom_pad, edge, top_pad), axis=0) 155 | skele_x = np.concatenate((bottom_pad, skele_x, top_pad), axis=0) 156 | skele_y = np.concatenate((bottom_pad, skele_y, top_pad), axis=0) 157 | skele_z = np.concatenate((bottom_pad, skele_z, top_pad), axis=0) 158 | return img, label, edge, skele_x, skele_y, skele_z 159 | 160 | 161 | def padding_y(img, label, edge, skele_x, skele_y, skele_z, min_y): 162 | x, y, z = img.shape 163 | if y >= min_y: 164 | return img, label, edge, 
skele_x, skele_y, skele_z 165 | else: 166 | num_pad = min_y - y 167 | num_top_pad = num_pad // 2 168 | top_pad = np.zeros((x, num_top_pad, z), dtype=np.float64) 169 | bottom_pad = np.zeros((x, num_pad - num_top_pad, z), dtype=np.float64) 170 | img = np.concatenate((bottom_pad, img, top_pad), axis=1) 171 | label = np.concatenate((bottom_pad, label, top_pad), axis=1) 172 | edge = np.concatenate((bottom_pad, edge, top_pad), axis=1) 173 | skele_x = np.concatenate((bottom_pad, skele_x, top_pad), axis=1) 174 | skele_y = np.concatenate((bottom_pad, skele_y, top_pad), axis=1) 175 | skele_z = np.concatenate((bottom_pad, skele_z, top_pad), axis=1) 176 | return img, label, edge, skele_x, skele_y, skele_z 177 | 178 | 179 | def rotate(img_ori, label_ori, edge_ori, skele_x_ori, skele_y_ori, skele_z_ori, degree): 180 | height, width, depth = img_ori.shape 181 | matRotation = cv2.getRotationMatrix2D((width / 2, height / 2), degree, 1) 182 | 183 | imgRotation = np.zeros_like(img_ori) 184 | labelRotation = np.zeros_like(label_ori) 185 | skele_x_Rotation = np.zeros_like(skele_x_ori) 186 | skele_y_Rotation = np.zeros_like(skele_y_ori) 187 | skele_z_Rotation = np.zeros_like(skele_z_ori) 188 | edgeRotation = np.zeros_like(edge_ori) 189 | for z in range(depth): 190 | imgRotation[:, :, z] = cv2.warpAffine(img_ori[:, :, z], matRotation, (width, height), borderValue=0) 191 | labelRotation[:, :, z] = cv2.warpAffine(label_ori[:, :, z], matRotation, (width, height)) 192 | edgeRotation[:, :, z] = cv2.warpAffine(edge_ori[:, :, z], matRotation, (width, height)) 193 | skele_x_Rotation[:, :, z] = cv2.warpAffine(skele_x_ori[:, :, z], matRotation, (width, height)) 194 | skele_y_Rotation[:, :, z] = cv2.warpAffine(skele_y_ori[:, :, z], matRotation, (width, height)) 195 | skele_z_Rotation[:, :, z] = cv2.warpAffine(skele_z_ori[:, :, z], matRotation, (width, height)) 196 | labelRotation = np.round(labelRotation) 197 | return imgRotation, labelRotation, edgeRotation, skele_x_Rotation, skele_y_Rotation, 
skele_z_Rotation 198 | 199 | 200 | def crop_resize(img_ori, label_ori, edge_ori, skele_x_ori, skele_y_ori, skele_z_ori, shift_size_x, shift_size_y): 201 | H, W, C = img_ori.shape 202 | x_small = np.random.randint(0, shift_size_x) 203 | x_large = np.random.randint(H - shift_size_x, H) 204 | y_small = np.random.randint(0, shift_size_y) 205 | y_large = np.random.randint(W - shift_size_y, W) 206 | 207 | imgCropresize = np.zeros_like(img_ori) 208 | labelCropresize = np.zeros_like(label_ori) 209 | skele_x_Cropresize = np.zeros_like(skele_x_ori) 210 | skele_y_Cropresize = np.zeros_like(skele_y_ori) 211 | skele_z_Cropresize = np.zeros_like(skele_z_ori) 212 | edgeCropresize = np.zeros_like(edge_ori) 213 | for z in range(C): 214 | imgCropresize[:, :, z] = cv2.resize(img_ori[x_small:x_large, y_small:y_large, z], (W, H)) 215 | labelCropresize[:, :, z] = cv2.resize(label_ori[x_small:x_large, y_small:y_large, z], (W, H)) 216 | edgeCropresize[:, :, z] = cv2.resize(edge_ori[x_small:x_large, y_small:y_large, z], (W, H)) 217 | skele_x_Cropresize[:, :, z] = cv2.resize(skele_x_ori[x_small:x_large, y_small:y_large, z], (W, H)) 218 | skele_y_Cropresize[:, :, z] = cv2.resize(skele_y_ori[x_small:x_large, y_small:y_large, z], (W, H)) 219 | skele_z_Cropresize[:, :, z] = cv2.resize(skele_z_ori[x_small:x_large, y_small:y_large, z], (W, H)) 220 | labelCropresize = np.round(labelCropresize) 221 | return imgCropresize, labelCropresize, edgeCropresize, skele_x_Cropresize, skele_y_Cropresize, skele_z_Cropresize 222 | 223 | 224 | def normalize_volume(img): 225 | img_array = (img - img.min()) / (img.max() - img.min()) 226 | return img_array 227 | 228 | 229 | class BowelFineSeg(data.Dataset): 230 | def __init__(self, root='', transform=None, mode="train", test_fraction=0.2, bowel_name='', save_dir=''): 231 | 232 | assert bowel_name in name_to_label.keys() 233 | 234 | current_test = 5 235 | train_split, test_split = split_dataset(root, current_test, test_fraction, bowel_name, save_dir) 236 | 237 | if 
mode == "infer" or mode == "test": 238 | self.imgs = test_split 239 | else: 240 | self.imgs = train_split 241 | 242 | self.bowel_name = bowel_name 243 | self.mode = mode 244 | self.root = root 245 | self.patch_size = (64, 192, 192) 246 | 247 | if bowel_name == 'colon': 248 | self.patch_size = (64, 192, 192) 249 | 250 | if bowel_name == 'sigmoid': 251 | self.patch_size = (64, 160, 160) 252 | 253 | if bowel_name == 'duodenum': 254 | self.patch_size = (64, 160, 160) 255 | 256 | if bowel_name == 'rectum': 257 | self.patch_size = (64, 96, 96) 258 | 259 | self.transform = transform 260 | 261 | def __len__(self): 262 | return len(self.imgs) 263 | 264 | def __getitem__(self, index): 265 | 266 | patch_size = self.patch_size 267 | img_name = self.imgs[index] 268 | 269 | img, label, edge, skele_x, skele_y, skele_z = load_image_and_label(img_name, transform=self.transform) 270 | z, x, y = img.shape 271 | 272 | if np.random.choice([False, False, False, True]): 273 | force_fg = True 274 | else: 275 | force_fg = False 276 | 277 | if force_fg: 278 | # sample a foreground region 279 | z_range, x_range, y_range = np.where(label != 0) 280 | z_min, x_min, y_min = z_range.min(), x_range.min(), y_range.min() 281 | z_max, x_max, y_max = z_range.max(), x_range.max(), y_range.max() 282 | center_z = random.randint(z_min, z_max) 283 | center_x = random.randint(x_min, x_max) 284 | center_y = random.randint(y_min, y_max) 285 | 286 | zs = center_z - patch_size[0] // 2 # start 287 | ze = center_z + patch_size[0] // 2 # end 288 | xs = center_x - patch_size[1] // 2 289 | xe = center_x + patch_size[1] // 2 290 | ys = center_y - patch_size[2] // 2 291 | ye = center_y + patch_size[2] // 2 292 | 293 | if zs < 0: 294 | zs = 0 295 | ze = zs + patch_size[0] 296 | if ze > z: 297 | ze = z 298 | zs = ze - patch_size[0] 299 | 300 | if xs < 0: 301 | xs = 0 302 | xe = xs + patch_size[1] 303 | if xe > x: 304 | xe = x 305 | xs = xe - patch_size[1] 306 | 307 | if ys < 0: 308 | ys = 0 309 | ye = ys + patch_size[2] 
310 | if ye > y: 311 | ye = y 312 | ys = ye - patch_size[2] 313 | else: 314 | zs = random.randint(0, z - patch_size[0]) 315 | ze = zs + self.patch_size[0] 316 | xs = random.randint(0, x - patch_size[1]) 317 | xe = xs + self.patch_size[1] 318 | ys = random.randint(0, y - patch_size[2]) 319 | ye = ys + self.patch_size[2] 320 | 321 | img_p = img[zs:ze, xs:xe, ys:ye][np.newaxis, :] 322 | label_p = label[zs:ze, xs:xe, ys:ye][np.newaxis, :] 323 | edge_p = edge[zs:ze, xs:xe, ys:ye][np.newaxis, :] 324 | skele_x_p = skele_x[zs:ze, xs:xe, ys:ye][np.newaxis, :] 325 | skele_y_p = skele_y[zs:ze, xs:xe, ys:ye][np.newaxis, :] 326 | skele_z_p = skele_z[zs:ze, xs:xe, ys:ye][np.newaxis, :] 327 | 328 | img_p = torch.from_numpy(img_p.astype(np.float32)) 329 | label_p = torch.from_numpy(label_p.astype(np.int32)) 330 | edge_p = torch.from_numpy(edge_p.astype(np.float32)) 331 | skele_x_p = torch.from_numpy(skele_x_p.astype(np.float32)) 332 | skele_y_p = torch.from_numpy(skele_y_p.astype(np.float32)) 333 | skele_z_p = torch.from_numpy(skele_z_p.astype(np.float32)) 334 | 335 | return img_p, label_p, edge_p, skele_x_p, skele_y_p, skele_z_p 336 | 337 | -------------------------------------------------------------------------------- /bowel_coarseseg/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | from __future__ import division 3 | import time 4 | import argparse 5 | import torch 6 | import random 7 | import numpy as np 8 | import torch.nn as nn 9 | import torch.optim as optim 10 | import torch.nn.functional as F 11 | from torch.autograd import Variable 12 | from torch.utils.data import DataLoader 13 | from datasets.BowelDatasetCoarseSeg import BowelCoarseSeg 14 | from tools.loss import * 15 | import os, math 16 | import shutil 17 | import loc_model 18 | import wandb 19 | 20 | 21 | os.environ['CUDA_VISIBLE_DEVICES'] = '0, 1' # "1, 2" 22 | print("GPU ID:", os.environ['CUDA_VISIBLE_DEVICES']) 23 | 24 | 25 | def 
weights_init(m): 26 | classname = m.__class__.__name__ 27 | if classname.find('Conv3d') != -1: 28 | # nn.init.kaiming_normal_(m.weight) 29 | nn.init.xavier_normal_(m.weight, gain=0.02) 30 | m.bias.data.zero_() 31 | 32 | 33 | def datestr(): 34 | now = time.gmtime() 35 | return '{}{:02}{:02}_{:02}{:02}'.format(now.tm_year, now.tm_mon, now.tm_mday, now.tm_hour, now.tm_min) 36 | 37 | 38 | def save_checkpoint(state, is_best, path, prefix, filename='checkpoint.pth.tar'): 39 | prefix_save = os.path.join(path, prefix) 40 | name = prefix_save + '_' + filename 41 | torch.save(state, name) 42 | if is_best: 43 | shutil.copyfile(name, prefix_save + '_model_best.pth.tar') 44 | 45 | 46 | def main(): 47 | parser = argparse.ArgumentParser() 48 | parser.add_argument('--ngpu', type=int, default=1) 49 | parser.add_argument('--batch_size', type=int, default=8) # 8 (2 GPU) 50 | parser.add_argument('--start-epoch', default=0, type=int, help='manual epoch number (useful on restarts)') 51 | parser.add_argument('--nEpochs', type=int, default=1500, help='total training epoch') 52 | 53 | parser.add_argument('--lr', default=1e-3, type=float, help='learning rate') 54 | parser.add_argument('--weight_decay', default=1e-8, type=float, metavar='W', help='weight decay (default: 1e-8)') 55 | parser.add_argument('--eval_interval', default=5, type=int, help='evaluate interval on validation set') 56 | parser.add_argument('--save_dir', type=str, default=None) 57 | parser.add_argument('--seed', type=int, default=1) 58 | parser.add_argument('--deterministic', type=bool, default=True) 59 | args = parser.parse_args() 60 | 61 | 62 | args.save_dir = 'exp/BowelLocNet.{}'.format(datestr()) 63 | if os.path.exists(args.save_dir): 64 | shutil.rmtree(args.save_dir) 65 | os.makedirs(args.save_dir, exist_ok=True) 66 | 67 | 68 | 69 | shutil.copy(src=os.path.join(os.getcwd(), 'train.py'), dst=args.save_dir) 70 | shutil.copy(src=os.path.join(os.getcwd(), 'infer.py'), dst=args.save_dir) 71 | 
shutil.copy(src=os.path.join(os.getcwd(), 'crop_ROI.py'), dst=args.save_dir) 72 | shutil.copy(src=os.path.join(os.getcwd(), 'loc_model.py'), dst=args.save_dir) 73 | shutil.copy(src=os.path.join(os.getcwd(), 'tools/loss.py'), dst=args.save_dir) 74 | shutil.copy(src=os.path.join(os.getcwd(), 'datasets/BowelDatasetCoarseSeg.py'), dst=args.save_dir) 75 | 76 | 77 | # if args.seed is not None: 78 | # random.seed(args.seed) 79 | # np.random.seed(args.seed) 80 | # torch.manual_seed(args.seed) 81 | # torch.cuda.manual_seed(args.seed) 82 | # torch.backends.cudnn.deterministic = args.deterministic 83 | 84 | 85 | print("build Bowel Localisation Network") 86 | model = loc_model.BowelLocNet(elu=False) 87 | model.apply(weights_init) 88 | 89 | 90 | # model.load_state_dict(torch.load("./exp/BowelLocNet.20230208_1536/partial_5C_dict_1300.pth")) 91 | 92 | 93 | model = model.cuda() 94 | model = nn.parallel.DataParallel(model) 95 | 96 | print(' + Number of params: {}'.format(sum([p.data.nelement() for p in model.parameters()]))) 97 | 98 | 99 | 100 | # WandB – Initialize a new run 101 | wandb.init(project='BowelNet', mode='disabled') # mode='disabled' 102 | wandb.run.name = 'BowelLoc_' + wandb.run.id 103 | 104 | 105 | 106 | print("loading fully_labeled dataset") 107 | fully_labeled_dir = '/mnt/c/chong/data/Bowel/crop_downsample/Fully_labeled_5C' 108 | trainSet_fully_labeled = BowelCoarseSeg(fully_labeled_dir, mode="train", transform=True, dataset_name="fully_labeled", save_dir=args.save_dir) 109 | trainLoader_fully_labeled = DataLoader(trainSet_fully_labeled, batch_size=args.batch_size, shuffle=True, num_workers=6, pin_memory=False) 110 | testSet_fully_labeled = BowelCoarseSeg(fully_labeled_dir, mode="test", transform=False, dataset_name="fully_labeled", save_dir=args.save_dir) 111 | testLoader_fully_labeled = DataLoader(testSet_fully_labeled, batch_size=args.batch_size, shuffle=False, num_workers=6, pin_memory=False) 112 | 113 | print("loading smallbowel dataset") 114 | smallbowel_dir 
= '/mnt/c/chong/data/Bowel/crop_downsample/Smallbowel' 115 | trainSet_small = BowelCoarseSeg(smallbowel_dir, mode="train", transform=True, dataset_name="smallbowel", save_dir=args.save_dir) 116 | trainLoader_small = DataLoader(trainSet_small, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=False) 117 | testSet_small = BowelCoarseSeg(smallbowel_dir, mode="test", transform=False, dataset_name="smallbowel", save_dir=args.save_dir) 118 | testLoader_small = DataLoader(testSet_small, batch_size=args.batch_size, shuffle=False, num_workers=4, pin_memory=False) 119 | 120 | print("loading colon_sigmoid dataset") 121 | colon_sigmoid_dir = '/mnt/c/chong/data/Bowel/crop_downsample/Colon_Sigmoid/colon_sigmoid' 122 | trainSet_colon_sigmoid = BowelCoarseSeg(colon_sigmoid_dir, mode="train", transform=True, dataset_name="colon_sigmoid", save_dir=args.save_dir) 123 | trainLoader_colon_sigmoid = DataLoader(trainSet_colon_sigmoid, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=False) 124 | testSet_colon_sigmoid = BowelCoarseSeg(colon_sigmoid_dir, mode="test", transform=False, dataset_name="colon_sigmoid", save_dir=args.save_dir) 125 | testLoader_colon_sigmoid = DataLoader(testSet_colon_sigmoid, batch_size=args.batch_size, shuffle=False, num_workers=4, pin_memory=False) 126 | 127 | 128 | optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 129 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.85) # 0.90 130 | 131 | 132 | best_dice = 0.0 133 | trainF = open(os.path.join(args.save_dir, 'train.csv'), 'w') 134 | testF = open(os.path.join(args.save_dir, 'test.csv'), 'w') 135 | for epoch in range(0, args.nEpochs + 1): 136 | if epoch < 0: # 200 137 | train_loc(epoch, model, trainLoader_fully_labeled, optimizer, trainF, data_type='fully_labeled') 138 | # train_loc(epoch, model, trainLoader_small, optimizer, trainF, data_type='smallbowel') 139 | # train_loc(epoch, model, 
trainLoader_colon_sigmoid, optimizer, trainF, data_type='colon_sigmoid') 140 | else: 141 | if epoch % 3 == 0: 142 | train_loc(epoch, model, trainLoader_fully_labeled, optimizer, trainF, data_type='fully_labeled') 143 | elif epoch % 3 == 1: 144 | train_loc(epoch, model, trainLoader_small, optimizer, trainF, data_type='smallbowel') 145 | else: 146 | train_loc(epoch, model, trainLoader_colon_sigmoid, optimizer, trainF, data_type='colon_sigmoid') 147 | 148 | scheduler.step() 149 | wandb.log({ 150 | "LR": optimizer.param_groups[0]['lr'], 151 | }) 152 | 153 | if epoch % args.eval_interval == 0: # 5 154 | dice_fully_labeled = test_loc(epoch, model, testLoader_fully_labeled, testF, data_type='fully_labeled') 155 | dice_small = test_loc(epoch, model, testLoader_small, testF, data_type='smallbowel') 156 | dice_colon_sigmoid = test_loc(epoch, model, testLoader_colon_sigmoid, testF, data_type='colon_sigmoid') 157 | dice_mean = (np.mean(dice_fully_labeled) + np.mean(dice_small) + np.mean(dice_colon_sigmoid)) / 3.0 158 | 159 | wandb.log({ 160 | "Test Mean Dice Score": dice_mean, 161 | }) 162 | 163 | is_best = False 164 | if dice_mean > best_dice: 165 | is_best = True 166 | best_dice = dice_mean 167 | save_checkpoint({'epoch': epoch, 'state_dict': model.module.state_dict(), 'best_acc': best_dice}, 168 | is_best, args.save_dir, "partial_5C") 169 | torch.save(model.module.state_dict(), args.save_dir + '/partial_5C_dict_' + str(epoch) + '.pth') 170 | 171 | trainF.close() 172 | testF.close() 173 | 174 | 175 | def train_loc(epoch, model, trainLoader, optimizer, trainF, data_type): 176 | 177 | assert data_type in ['fully_labeled', 'smallbowel', 'colon_sigmoid'] 178 | 179 | model.train() 180 | nProcessed = 0 181 | nTrain = len(trainLoader) 182 | 183 | for batch_idx, (data, target, image_name) in enumerate(trainLoader): 184 | 185 | data, target = Variable(data.cuda()), Variable(target.cuda()) 186 | 187 | optimizer.zero_grad() 188 | 189 | out = model(data) 190 | out = torch.clamp(out, 
min=1e-10, max=1) # prevent overflow 191 | 192 | # {1: 'rectum', 2: 'sigmoid', 3: 'colon', 4: 'small', 5: 'duodenum'} 193 | if data_type == 'fully_labeled': 194 | output_p = out 195 | target_ce = target.long() 196 | target_dice = torch.cat((target_ce == 0, target_ce == 1, target_ce == 2, 197 | target_ce == 3, target_ce == 4, target_ce == 5), dim=1).long() 198 | loss_entropy = torch.tensor(0.0).cuda() 199 | 200 | if data_type == 'smallbowel': 201 | output_p1 = out[:, 4, :, :, :] # smallbowel 202 | output_p0 = out[:, 0, :, :, :] + \ 203 | out[:, 1, :, :, :] + \ 204 | out[:, 2, :, :, :] + \ 205 | out[:, 3, :, :, :] + \ 206 | out[:, 5, :, :, :] 207 | output_p = torch.stack((output_p0, output_p1), dim=1) 208 | target_ce = (target == 4).long() 209 | target_dice = torch.cat((target_ce == 0, target_ce == 1), dim=1).long() 210 | prob_neg = torch.stack((out[:, 0], 211 | out[:, 1], 212 | out[:, 2], 213 | out[:, 3], 214 | out[:, 5]), dim=1) 215 | entropy = (-prob_neg * torch.log(prob_neg)).sum(dim=1) 216 | neg_mask = (target_ce == 0).squeeze(1) 217 | loss_entropy = (entropy * neg_mask).sum() / ((neg_mask).sum() + 1e-7) 218 | 219 | if data_type == 'colon_sigmoid': 220 | output_p1 = out[:, 2, :, :, :] # sigmoid 221 | output_p2 = out[:, 3, :, :, :] # colon 222 | output_p0 = out[:, 0, :, :, :] + \ 223 | out[:, 1, :, :, :] + \ 224 | out[:, 4, :, :, :] + \ 225 | out[:, 5, :, :, :] 226 | output_p = torch.stack((output_p0, output_p1, output_p2), dim=1) 227 | target_ce = torch.zeros(target.shape).long().cuda() 228 | target_ce[target == 2] = 1 229 | target_ce[target == 3] = 2 230 | target_dice = torch.cat((target_ce == 0, target_ce == 1, target_ce == 2), dim=1).long() 231 | prob_neg = torch.stack((out[:, 0], 232 | out[:, 1], 233 | out[:, 4], 234 | out[:, 5]), dim=1) 235 | entropy = (-prob_neg * torch.log(prob_neg)).sum(dim=1) 236 | neg_mask = (target_ce == 0).squeeze(1) 237 | loss_entropy = (entropy * neg_mask).sum() / ((neg_mask).sum() + 1e-7) 238 | 239 | loss_ce = 
F.nll_loss(output_p.log(), target_ce.squeeze(1)) 240 | loss_dice = dice_loss_PL(output_p, target_dice) 241 | loss = loss_dice + loss_ce + loss_entropy 242 | 243 | loss.backward() 244 | optimizer.step() 245 | 246 | dice_score = dice_score_partial(out, target, data_type) 247 | 248 | pred = torch.argmax(output_p, dim=1, keepdim=True) 249 | correct = pred.eq(target_ce.data).cpu().sum() 250 | acc = correct / target_ce.numel() 251 | 252 | correct_pos = torch.logical_and((pred != 0), (target_ce != 0)).sum().item() 253 | sen_pos = round(correct_pos / ((target_ce != 0).sum().item() + 0.0001), 3) 254 | 255 | correct_neg = torch.logical_and((pred == 0), (target_ce == 0)).sum().item() 256 | sen_neg = round(correct_neg / ((target_ce == 0).sum().item() + 0.0001), 3) 257 | 258 | nProcessed += len(data) 259 | partialEpoch = epoch + batch_idx / nTrain 260 | print('Train Epoch: {:.2f} [{}/{} ({:.0f}%)]\tData_type: {}\tLoss: {:.4f}\tLoss_Dice: {:.4f}\tLoss_CE: {:.4f} ' 261 | '\tLoss_Ent: {:.4f}' 262 | '\tAcc: {:.3f}\tSen_pos: {:.3f}\tSen_neg: {:.3f}\tDice: {}'.format( 263 | partialEpoch, nProcessed, len(trainLoader.dataset), 100. 
* batch_idx / nTrain, data_type, 264 | loss.data, loss_dice.data, loss_ce.data, loss_entropy.data, acc, sen_pos, sen_neg, dice_score)) 265 | 266 | wandb.log({ 267 | "Train Loss": loss, 268 | "Train CE Loss": loss_ce, 269 | "Train Dice Loss": loss_dice, 270 | "Train Entropy Loss": loss_entropy, 271 | }) 272 | 273 | trainF.write( 274 | '{},{},{},{},{},{},{},{},{}, {}\n'.format(partialEpoch, data_type, loss.data, loss_dice.data, loss_ce.data, loss_entropy.data, acc, sen_pos, sen_neg, dice_score)) 275 | trainF.flush() 276 | 277 | 278 | def test_loc(epoch, model, testLoader, testF, data_type): 279 | 280 | assert data_type in ['fully_labeled', 'smallbowel', 'colon_sigmoid'] 281 | 282 | model.eval() 283 | test_loss = 0 284 | dice_score = 0 285 | acc_all = 0 286 | sen_pos_all = 0 287 | sen_neg_all = 0 288 | with torch.no_grad(): 289 | for data, target, image_name in testLoader: 290 | data, target = Variable(data.cuda()), Variable(target.cuda()) 291 | 292 | out = model(data) 293 | 294 | # {1: 'rectum', 2: 'sigmoid', 3: 'colon', 4: 'small', 5: 'duodenum'} 295 | 296 | if data_type == 'fully_labeled': 297 | output_p = out 298 | target_ce = target.long() 299 | target_dice = torch.cat((target_ce == 0, target_ce == 1, target_ce == 2, 300 | target_ce == 3, target_ce == 4, target_ce == 5), dim=1).long() 301 | 302 | if data_type == 'smallbowel': 303 | output_p1 = out[:, 4, :, :, :] # smallbowel 304 | output_p0 = out[:, 0, :, :, :] + \ 305 | out[:, 1, :, :, :] + \ 306 | out[:, 2, :, :, :] + \ 307 | out[:, 3, :, :, :] + \ 308 | out[:, 5, :, :, :] 309 | output_p = torch.stack((output_p0, output_p1), dim=1) 310 | target_ce = (target == 4).long() 311 | target_dice = torch.cat((target_ce == 0, target_ce == 1), dim=1).long() 312 | 313 | if data_type == 'colon_sigmoid': 314 | output_p1 = out[:, 2, :, :, :] # sigmoid 315 | output_p2 = out[:, 3, :, :, :] # colon 316 | output_p0 = out[:, 0, :, :, :] + \ 317 | out[:, 1, :, :, :] + \ 318 | out[:, 4, :, :, :] + \ 319 | out[:, 5, :, :, :] 320 | 
output_p = torch.stack((output_p0, output_p1, output_p2), dim=1) 321 | target_ce = torch.zeros(target.shape).long().cuda() 322 | target_ce[target == 2] = 1 323 | target_ce[target == 3] = 2 324 | target_dice = torch.cat((target_ce == 0, target_ce == 1, target_ce == 2), dim=1).long() 325 | 326 | 327 | loss_dice = dice_loss_PL(output_p, target_dice) 328 | test_loss += loss_dice 329 | 330 | dice_score += np.array(dice_score_partial(out, target, data_type)) 331 | 332 | pred = torch.argmax(output_p, dim=1, keepdim=True) 333 | correct = pred.eq(target_ce.data).cpu().sum() 334 | acc_all += correct / target_ce.numel() 335 | 336 | correct_pos = torch.logical_and((pred != 0), (target_ce != 0)).sum().item() 337 | sen_pos = round(correct_pos / ((target_ce != 0).sum().item() + 0.0001), 3) 338 | sen_pos_all += sen_pos 339 | 340 | correct_neg = torch.logical_and((pred == 0), (target_ce == 0)).sum().item() 341 | sen_neg = round(correct_neg / ((target_ce == 0).sum().item() + 0.0001), 3) 342 | sen_neg_all += sen_neg 343 | 344 | test_loss /= len(testLoader) 345 | dice_score /= len(testLoader) 346 | acc_all /= len(testLoader) 347 | sen_pos_all /= len(testLoader) 348 | sen_neg_all /= len(testLoader) 349 | print('\nTest online: Data_type: {}, Dice loss: {:.4f}, Acc:{:.3f}, Sen_pos:{:.3f}, Sen_neg:{:.3f}, Dice: {}\n'. 
350 | format(data_type, test_loss, acc_all, sen_pos_all, sen_neg_all, dice_score)) 351 | 352 | testF.write('{},{},{},{},{},{},{}\n'.format(epoch, data_type, test_loss, acc_all, sen_pos_all, sen_neg_all, dice_score)) 353 | testF.flush() 354 | 355 | wandb.log({ 356 | "Test Loss": test_loss, 357 | "Test Dice Score": dice_score.mean(), 358 | }) 359 | 360 | return dice_score 361 | 362 | 363 | if __name__ == '__main__': 364 | main() 365 | -------------------------------------------------------------------------------- /bowel_fineseg/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | from __future__ import division 3 | import time 4 | import argparse 5 | import torch 6 | import random 7 | import numpy as np 8 | import torch.nn as nn 9 | import torch.optim as optim 10 | import torch.nn.functional as F 11 | from torch.autograd import Variable 12 | import torchvision.transforms as transforms 13 | from torch.utils.data import DataLoader 14 | from datasets.BowelDatasetFineSeg import BowelFineSeg 15 | from tools.loss import * 16 | import os, math 17 | import shutil 18 | import seg_model 19 | 20 | import wandb 21 | 22 | 23 | os.environ['CUDA_VISIBLE_DEVICES'] = '1' # "0, 1" 24 | print("GPU ID:", os.environ['CUDA_VISIBLE_DEVICES']) 25 | 26 | 27 | def weights_init(m): 28 | classname = m.__class__.__name__ 29 | if classname.find('Conv3d') != -1: 30 | # nn.init.kaiming_normal_(m.weight) 31 | nn.init.xavier_normal_(m.weight, gain=0.02) 32 | m.bias.data.zero_() 33 | 34 | 35 | def datestr(): 36 | now = time.gmtime() 37 | return '{}{:02}{:02}_{:02}{:02}'.format(now.tm_year, now.tm_mon, now.tm_mday, now.tm_hour, now.tm_min) 38 | 39 | 40 | def save_checkpoint(state, is_best, path, prefix, filename='checkpoint.pth.tar'): 41 | prefix_save = os.path.join(path, prefix) 42 | name = prefix_save + '_' + filename 43 | torch.save(state, name) 44 | if is_best: 45 | shutil.copyfile(name, prefix_save + '_model_best.pth.tar') 46 | 
47 | 48 | def main(): 49 | parser = argparse.ArgumentParser() 50 | parser.add_argument('--ngpu', type=int, default=1) 51 | parser.add_argument('--batch_size', type=int, default=4) # 4 (single GPU) 52 | parser.add_argument('--start-epoch', default=0, type=int, help='manual epoch number (useful on restarts)') 53 | parser.add_argument('--nEpochs_base', type=int, default=501, help='total epoch number for base segmentor') 54 | parser.add_argument('--nEpochs_meta', type=int, default=301, help='total epoch number for meta segmentor') 55 | 56 | parser.add_argument('--lr', default=1e-3, type=float, help='learning rate') 57 | parser.add_argument('--weight_decay', default=1e-8, type=float, metavar='W', help='weight decay (default: 1e-8)') 58 | parser.add_argument('--eval_interval', default=5, type=int, help='evaluate interval on validation set') 59 | parser.add_argument('--temp', default=0.7, type=float, help='temperature for meta segmentor') 60 | parser.add_argument('--save_dir', type=str, default=None) 61 | parser.add_argument('--seed', type=int, default=1) 62 | parser.add_argument('--deterministic', type=bool, default=True) 63 | args = parser.parse_args() 64 | 65 | 66 | args.save_dir = 'exp/BowelNet.{}'.format(datestr()) 67 | if os.path.exists(args.save_dir): 68 | shutil.rmtree(args.save_dir) 69 | os.makedirs(args.save_dir, exist_ok=True) 70 | 71 | shutil.copy(src=os.path.join(os.getcwd(), 'train.py'), dst=args.save_dir) 72 | shutil.copy(src=os.path.join(os.getcwd(), 'infer.py'), dst=args.save_dir) 73 | shutil.copy(src=os.path.join(os.getcwd(), 'seg_model.py'), dst=args.save_dir) 74 | shutil.copy(src=os.path.join(os.getcwd(), 'tools/loss.py'), dst=args.save_dir) 75 | shutil.copy(src=os.path.join(os.getcwd(), 'tools/utils.py'), dst=args.save_dir) 76 | shutil.copy(src=os.path.join(os.getcwd(), 'datasets/BowelDatasetFineSeg.py'), dst=args.save_dir) 77 | 78 | 79 | # if args.seed is not None: 80 | # random.seed(args.seed) 81 | # np.random.seed(args.seed) 82 | # 
torch.manual_seed(args.seed) 83 | # torch.cuda.manual_seed(args.seed) 84 | # torch.backends.cudnn.deterministic = args.deterministic 85 | 86 | 87 | print("build BowelNet") 88 | model = seg_model.BowelNet(elu=False) 89 | model.apply(weights_init) 90 | 91 | 92 | 93 | # model.load_state_dict(torch.load("./BowelNet.20230207_1744/rectum_meta_195.pth")) 94 | 95 | 96 | 97 | 98 | model = model.cuda() 99 | model = nn.parallel.DataParallel(model) 100 | 101 | 102 | print(' + Number of params: {}'.format(sum([p.data.nelement() for p in model.parameters()]))) 103 | 104 | 105 | 106 | 107 | 108 | 109 | # data_dir = '/mnt/c/chong/data/Bowel/crop_stage1_ROI_small/' 110 | # bowel_name = 'small' 111 | 112 | # data_dir = '/mnt/c/chong/data/Bowel/crop_stage1_ROI_colon/' 113 | # bowel_name = 'colon' 114 | 115 | # data_dir = '/mnt/c/chong/data/Bowel/crop_stage1_ROI_sigmoid/' 116 | # bowel_name = 'sigmoid' 117 | 118 | # data_dir = '/mnt/c/chong/data/Bowel/crop_stage1_ROI_duodenum/' 119 | # bowel_name = 'duodenum' 120 | 121 | data_dir = '/mnt/c/chong/data/Bowel/crop_stage1_ROI_rectum/' 122 | bowel_name = 'rectum' 123 | 124 | 125 | 126 | # WandB – Initialize a new run 127 | wandb.init(project='BowelNet', mode='disabled') # mode='disabled' 128 | wandb.run.name = bowel_name + '_' + wandb.run.id 129 | 130 | 131 | 132 | print("loading training set") 133 | trainSet = BowelFineSeg(root=data_dir, mode="train", transform=True, bowel_name=bowel_name, save_dir=args.save_dir) 134 | trainLoader = DataLoader(trainSet, batch_size=args.batch_size, shuffle=True, num_workers=8, pin_memory=False) # 8 135 | 136 | print("loading test set") 137 | testSet = BowelFineSeg(root=data_dir, mode="test", transform=False, bowel_name=bowel_name, save_dir=args.save_dir) 138 | testLoader = DataLoader(testSet, batch_size=args.batch_size, shuffle=False, num_workers=0, pin_memory=False) # 8 139 | 140 | 141 | 142 | # base segmentor training 143 | ########################################################################## 144 | 
params_base = [param for name, param in model.named_parameters() if 'meta' not in name] 145 | optimizer_base = optim.Adam(params_base, lr=args.lr, weight_decay=args.weight_decay) 146 | base_scheduler = torch.optim.lr_scheduler.StepLR(optimizer_base, step_size=50, gamma=0.85) 147 | best_dice = 0.0 148 | trainF_base = open(os.path.join(args.save_dir, 'train_base.csv'), 'w') 149 | testF_base = open(os.path.join(args.save_dir, 'test_base.csv'), 'w') 150 | for epoch in range(1, args.nEpochs_base + 1): 151 | train_base(epoch, model, trainLoader, optimizer_base, trainF_base, train_segmentor='base') 152 | # base_scheduler.step() 153 | wandb.log({ 154 | "Base LR": optimizer_base.param_groups[0]['lr'], 155 | }) 156 | if epoch % args.eval_interval == 0: # 5 157 | dice = test_base(epoch, model, testLoader, testF_base, train_segmentor='base') 158 | is_best = False 159 | if dice > best_dice: 160 | is_best = True 161 | best_dice = dice 162 | save_checkpoint({'epoch': epoch, 'state_dict': model.module.state_dict(), 'best_acc': best_dice}, 163 | is_best, args.save_dir, bowel_name + '_base') 164 | torch.save(model.module.state_dict(), args.save_dir + '/' + bowel_name + '_base_' + str(epoch) + '.pth') 165 | trainF_base.close() 166 | testF_base.close() 167 | 168 | # best_base_model_path = os.path.join(args.save_dir, bowel_name + '_base' + '_model_best.pth.tar') 169 | # model.load_state_dict(torch.load(best_base_model_path)['state_dict']) 170 | 171 | # meta segmentor training 172 | ########################################################################## 173 | params_meta = [param for name, param in model.named_parameters() if 'meta' in name] 174 | optimizer_meta = optim.Adam(params_meta, lr=args.lr, weight_decay=args.weight_decay) 175 | meta_scheduler = torch.optim.lr_scheduler.StepLR(optimizer_meta, step_size=30, gamma=0.85) 176 | best_dice = 0.0 177 | trainF_meta = open(os.path.join(args.save_dir, 'train_meta.csv'), 'w') 178 | testF_meta = open(os.path.join(args.save_dir, 
'test_meta.csv'), 'w') 179 | for epoch in range(1, args.nEpochs_meta + 1): 180 | train_meta(epoch, model, trainLoader, optimizer_meta, trainF_meta, train_segmentor='meta', temperature=args.temp) 181 | # meta_scheduler.step() 182 | wandb.log({ 183 | "Meta LR": optimizer_meta.param_groups[0]['lr'], 184 | }) 185 | if epoch % args.eval_interval == 0: # 5 186 | dice = test_meta(epoch, model, testLoader, testF_meta, train_segmentor='meta') 187 | is_best = False 188 | if dice > best_dice: 189 | is_best = True 190 | best_dice = dice 191 | save_checkpoint({'epoch': epoch, 'state_dict': model.module.state_dict(), 'best_acc': best_dice}, 192 | is_best, args.save_dir, bowel_name + '_meta') 193 | torch.save(model.module.state_dict(), args.save_dir + '/' + bowel_name + '_meta_' + str(epoch) + '.pth') 194 | trainF_meta.close() 195 | testF_meta.close() 196 | 197 | 198 | 199 | def train_base(epoch, model, trainLoader, optimizer, trainF, train_segmentor='base'): 200 | model.train() 201 | nProcessed = 0 202 | nTrain = len(trainLoader.dataset) 203 | for batch_idx, (data, target, edge, skele_x, skele_y, skele_z) in enumerate(trainLoader): 204 | data, target, edge, skele_x, skele_y, skele_z = Variable(data.cuda()), Variable(target.cuda()), \ 205 | Variable(edge.cuda()), Variable(skele_x.cuda()), \ 206 | Variable(skele_y.cuda()), Variable(skele_z.cuda()) 207 | optimizer.zero_grad() 208 | out_flux_reg, out_skele_mask_seg, out_edge_reg, out_edge_mask_seg = model(data, train_segmentor) 209 | 210 | ##### boundary segmentor loss 211 | loss_ce_edge_mask = F.cross_entropy(out_edge_mask_seg, target.squeeze(1).long()) 212 | loss_dice_edge_mask = dice_loss(F.softmax(out_edge_mask_seg, dim=1), target) 213 | 214 | edge_temp_mask = (edge != 0).float() 215 | if edge_temp_mask.sum() == 0: 216 | weight_matrix = torch.ones_like(edge_temp_mask) * 1.0 217 | else: 218 | pos_sum = (edge_temp_mask != 0).sum() 219 | neg_sum = (edge_temp_mask == 0).sum() 220 | pos_matrix = (neg_sum / pos_sum * 0.5) * 
edge_temp_mask 221 | neg_matrix = 1.0 * (1 - edge_temp_mask) 222 | weight_matrix = pos_matrix + neg_matrix 223 | loss_edge = (weight_matrix * torch.square(out_edge_reg - edge)).mean() 224 | # print('edge re-weight:', weight_matrix.max().item(), weight_matrix.min().item()) 225 | 226 | ##### skeleton segmentor loss 227 | loss_ce_skele_mask = F.cross_entropy(out_skele_mask_seg, target.squeeze(1).long()) 228 | loss_dice_skele_mask = dice_loss(F.softmax(out_skele_mask_seg, dim=1), target) # notice 229 | 230 | skele_xyz = torch.cat((skele_x, skele_y, skele_z), dim=1) 231 | loss_flux = F.mse_loss(out_flux_reg, skele_xyz) 232 | skele_square = skele_xyz[:, 0, ...] ** 2 + skele_xyz[:, 1, ...] ** 2 + skele_xyz[:, 2, ...] ** 2 233 | out_flux_square = out_flux_reg[:, 0, ...] ** 2 + out_flux_reg[:, 1, ...] ** 2 + out_flux_reg[:, 2, ...] ** 2 234 | loss_flux_square = F.mse_loss(out_flux_square, skele_square) 235 | loss_flux = loss_flux + loss_flux_square 236 | 237 | 238 | loss = (loss_dice_skele_mask + loss_ce_skele_mask) + \ 239 | (loss_dice_edge_mask + loss_ce_edge_mask) + \ 240 | 0.8 * loss_flux + \ 241 | 0.8 * loss_edge 242 | 243 | loss.backward() 244 | optimizer.step() 245 | 246 | dice_loss_edge = dice_score_metric(out_edge_mask_seg, target) 247 | dice_loss_skele = dice_score_metric(out_skele_mask_seg, target) 248 | pred = torch.argmax(out_edge_mask_seg, dim=1).unsqueeze(1) 249 | correct = pred.eq(target.data).cpu().sum() 250 | acc = correct / target.numel() 251 | 252 | correct_pos = torch.logical_and((pred == 1), (target == 1)).sum().item() 253 | sen_pos = round(correct_pos / ((target == 1).sum().item() + 0.0001), 3) 254 | 255 | correct_neg = torch.logical_and((pred == 0), (target == 0)).sum().item() 256 | sen_neg = round(correct_neg / ((target == 0).sum().item() + 0.0001), 3) 257 | 258 | nProcessed += len(data) 259 | partialEpoch = epoch + batch_idx / len(trainLoader) - 1 260 | print('Base Train, Epoch: {:.2f} [{}/{} ({:.0f}%)]\t' 261 | 'Loss: {:.5f}\t' 262 | 
'L_Dice_skele: {:.5f}\tL_CE_skele: {:.5f}\t L_flux_reg: {:.5f}\t' 263 | 'L_Dice_edge: {:.5f}\tL_CE_edge: {:.5f}\t L_edge_seg: {:.5f}\t' 264 | 'Acc: {:.3f}\tSen_pos: {:.3f}\tSen_neg: {:.3f}\tDice: {:.5f}\tDice_skele: {:.5f}'.format( 265 | partialEpoch, nProcessed, nTrain, 100. * batch_idx / len(trainLoader), 266 | loss.data, 267 | loss_dice_skele_mask.data, loss_ce_skele_mask.data, loss_flux.data, 268 | loss_dice_edge_mask.data, loss_ce_edge_mask.data, loss_edge.data, 269 | acc, sen_pos, sen_neg, dice_loss_edge, dice_loss_skele)) 270 | 271 | trainF.write('{},{},{},{},{},{},{},{},{}\n'.format(partialEpoch, loss.data, loss_dice_skele_mask.data, 272 | loss_dice_edge_mask.data, acc, sen_pos, sen_neg, 273 | dice_loss_edge, dice_loss_skele)) 274 | trainF.flush() 275 | 276 | wandb.log({ 277 | "Train Dice Skele Mask Loss": loss_dice_skele_mask.item(), 278 | "Train Dice Edge Mask Loss": loss_dice_edge_mask.item(), 279 | "Train Skele Loss": loss_flux.item(), 280 | "Train Edge Loss": loss_edge.item(), 281 | }) 282 | 283 | 284 | def test_base(epoch, model, testLoader, testF, train_segmentor='base'): 285 | model.eval() 286 | test_loss = 0 287 | edge_error = 0 288 | skele_error = 0 289 | dice_score_edge = 0 290 | dice_score_skele = 0 291 | acc_all = 0 292 | sen_pos_all = 0 293 | sen_neg_all = 0 294 | with torch.no_grad(): 295 | for data, target, edge, skele_x, skele_y, skele_z in testLoader: 296 | data, target, edge, skele_x, skele_y, skele_z = Variable(data.cuda()), Variable(target.cuda()), Variable( 297 | edge.cuda()), Variable(skele_x.cuda()), Variable(skele_y.cuda()), Variable(skele_z.cuda()) 298 | 299 | out_flux_reg, out_skele_mask_seg, out_edge_reg, out_edge_mask_seg = model(data, train_segmentor) 300 | 301 | loss_dice_skele = dice_loss(F.softmax(out_skele_mask_seg, dim=1), target).data 302 | loss_dice_edge = dice_loss(F.softmax(out_edge_mask_seg, dim=1), target).data 303 | test_loss += loss_dice_skele + loss_dice_edge 304 | 305 | loss_edge = F.mse_loss(out_edge_reg, edge) 
306 | edge_error += loss_edge 307 | 308 | skele_xyz = torch.cat((skele_x, skele_y, skele_z), dim=1) 309 | loss_flux = F.mse_loss(out_flux_reg, skele_xyz) 310 | skele_error += loss_flux 311 | 312 | dice_score_edge += dice_score_metric(out_edge_mask_seg, target) 313 | dice_score_skele += dice_score_metric(out_skele_mask_seg, target) 314 | 315 | pred = torch.argmax(out_edge_mask_seg, dim=1).unsqueeze(1) 316 | correct = pred.eq(target.data).cpu().sum() 317 | acc_all += correct / target.numel() 318 | 319 | correct_pos = torch.logical_and((pred == 1), (target == 1)).sum().item() 320 | sen_pos = round(correct_pos / ((target == 1).sum().item() + 0.0001), 3) 321 | sen_pos_all += sen_pos 322 | 323 | correct_neg = torch.logical_and((pred == 0), (target == 0)).sum().item() 324 | sen_neg = round(correct_neg / ((target == 0).sum().item() + 0.0001), 3) 325 | sen_neg_all += sen_neg 326 | 327 | test_loss /= len(testLoader) 328 | dice_score_edge /= len(testLoader) 329 | dice_score_skele /= len(testLoader) 330 | edge_error /= len(testLoader) 331 | skele_error /= len(testLoader) 332 | 333 | acc_all /= len(testLoader) 334 | sen_pos_all /= len(testLoader) 335 | sen_neg_all /= len(testLoader) 336 | 337 | print('\nTest online: Dice loss: {:.6f}, edge_error: {:.6f}, flux_error: {:.6f}, ' 338 | '\tAcc:{:.3f}, Sen_pos:{:.3f}, Sen_neg:{:.3f}, ' 339 | '\tDice_edge: {:.6f}, Dice_skele: {:.6f}\n'.format( 340 | test_loss, edge_error, skele_error, 341 | acc_all, sen_pos_all, sen_neg_all, 342 | dice_score_edge, dice_score_skele)) 343 | 344 | testF.write( 345 | '{},{},{},{},{},{},{},{},{}\n'.format(epoch, test_loss, edge_error, skele_error, acc_all, sen_pos_all, 346 | sen_neg_all, dice_score_edge, dice_score_skele)) 347 | testF.flush() 348 | 349 | wandb.log({ 350 | "Test Loss": test_loss, 351 | "Test Skele Error": skele_error, 352 | "Test Edge Error": edge_error, 353 | "Test Skele Dice Score": dice_score_skele, 354 | "Test Edge Dice Score": dice_score_edge, 355 | }) 356 | 357 | return 
(dice_score_skele + dice_score_edge) * 0.5 358 | 359 | 360 | def train_meta(epoch, model, trainLoader, optimizer, trainF, train_segmentor='meta', temperature=0.7): 361 | model.train() 362 | 363 | ################################################# 364 | for name, para in model.module.named_parameters(): 365 | if 'meta' in name: 366 | para.requires_grad = True 367 | else: 368 | para.requires_grad = False 369 | ################################################# 370 | 371 | for batch_idx, (data, target, edge, skele_x, skele_y, skele_z) in enumerate(trainLoader): 372 | data, target, edge, skele_x, skele_y, skele_z = Variable(data.cuda()), Variable(target.cuda()), \ 373 | Variable(edge.cuda()), Variable(skele_x.cuda()), \ 374 | Variable(skele_y.cuda()), Variable(skele_z.cuda()) 375 | optimizer.zero_grad() 376 | 377 | out_meta_mask_seg, out_edge, out_skele = model(data, train_segmentor) 378 | 379 | prob_edge = F.softmax(out_edge / temperature, dim=1) 380 | prob_skele = F.softmax(out_skele / temperature, dim=1) 381 | 382 | alpha = np.random.uniform(0.0, 1.0, 1)[0] 383 | soft_pseudo_target = alpha * prob_edge + (1 - alpha) * prob_skele 384 | loss_dice = dice_loss_PL(F.softmax(out_meta_mask_seg / temperature, dim=1), soft_pseudo_target) 385 | loss_ce = torch.mean(-(soft_pseudo_target * F.log_softmax(out_meta_mask_seg / temperature, dim=1)).sum(dim=1)) 386 | loss_kl = F.kl_div(F.log_softmax(out_meta_mask_seg / temperature, dim=1), soft_pseudo_target, reduction='none').sum(dim=1).mean() 387 | 388 | loss = loss_ce + loss_dice 389 | 390 | loss.backward() 391 | optimizer.step() 392 | 393 | dice_score_meta = dice_score_metric(out_meta_mask_seg, target) 394 | 395 | pred = torch.argmax(out_meta_mask_seg, dim=1).unsqueeze(1) 396 | correct = pred.eq(target.data).cpu().sum() 397 | acc = correct / target.numel() 398 | 399 | correct_pos = torch.logical_and((pred == 1), (target == 1)).sum().item() 400 | sen_pos = round(correct_pos / ((target == 1).sum().item() + 0.0001), 3) 401 | 402 | 
correct_neg = torch.logical_and((pred == 0), (target == 0)).sum().item() 403 | sen_neg = round(correct_neg / ((target == 0).sum().item() + 0.0001), 3) 404 | 405 | partialEpoch = epoch + batch_idx / len(trainLoader) - 1 406 | 407 | print('Meta Train, Epoch: {:.2f}\tloss: {:.5f}\tLoss_dice: {:.5f}\tLoss_ce: {:.5f}\tAcc: {:.5f}\tSen: {:.5f}\tSpe: {:.5f}\tDice_score: {:.5f}'. 408 | format(partialEpoch, loss.item(), loss_dice.item(), loss_ce.item(), acc, sen_pos, sen_neg, dice_score_meta)) 409 | 410 | trainF.write('{},{},{},{},{},{},{},{}\n'. 411 | format(partialEpoch, loss.item(), loss_dice.item(), loss_ce.item(), acc, sen_pos, sen_neg, dice_score_meta)) 412 | 413 | wandb.log({ 414 | "Train Meta Mask Loss": loss.item(), 415 | "Train Meta CE Loss": loss_ce.item(), 416 | "Train Meta Dice Loss": loss_dice.item(), 417 | }) 418 | 419 | trainF.flush() 420 | 421 | 422 | def test_meta(epoch, model, testLoader, testF, train_segmentor='meta'): 423 | model.eval() 424 | test_loss = 0 425 | dice_score_meta = 0 426 | acc_all = 0 427 | sen_pos_all = 0 428 | sen_neg_all = 0 429 | with torch.no_grad(): 430 | for data, target, edge, skele_x, skele_y, skele_z in testLoader: 431 | data, target, edge, skele_x, skele_y, skele_z = Variable(data.cuda()), Variable(target.cuda()), Variable( 432 | edge.cuda()), Variable(skele_x.cuda()), Variable(skele_y.cuda()), Variable(skele_z.cuda()) 433 | 434 | out_meta_mask_seg, prob_edge, prob_skele = model(data, train_segmentor) 435 | 436 | loss_dice_meta = dice_loss(F.softmax(out_meta_mask_seg, dim=1), target).data 437 | test_loss += loss_dice_meta 438 | 439 | dice_score_meta += dice_score_metric(out_meta_mask_seg, target) 440 | 441 | pred = torch.argmax(out_meta_mask_seg, dim=1).unsqueeze(1) 442 | correct = pred.eq(target.data).cpu().sum() 443 | acc_all += correct / target.numel() 444 | 445 | correct_pos = torch.logical_and((pred == 1), (target == 1)).sum().item() 446 | sen_pos = round(correct_pos / ((target == 1).sum().item() + 0.0001), 3) 447 | 
sen_pos_all += sen_pos 448 | 449 | correct_neg = torch.logical_and((pred == 0), (target == 0)).sum().item() 450 | sen_neg = round(correct_neg / ((target == 0).sum().item() + 0.0001), 3) 451 | sen_neg_all += sen_neg 452 | 453 | test_loss /= len(testLoader) 454 | dice_score_meta /= len(testLoader) 455 | 456 | acc_all /= len(testLoader) 457 | sen_pos_all /= len(testLoader) 458 | sen_neg_all /= len(testLoader) 459 | print('\nTest online: Dice loss: {:.5f}, Acc:{:.3f}, Sen:{:.3f}, Spe:{:.3f}, Dice_meta: {:.5f}\n'.format( 460 | test_loss, acc_all, sen_pos_all, sen_neg_all, dice_score_meta)) 461 | 462 | testF.write( 463 | '{},{},{},{},{},{}\n'.format(epoch, test_loss, acc_all, sen_pos_all, sen_neg_all, dice_score_meta)) 464 | 465 | 466 | wandb.log({ 467 | "Test Meta Loss": test_loss, 468 | "Test Meta Dice Score": dice_score_meta, 469 | }) 470 | 471 | 472 | testF.flush() 473 | 474 | return dice_score_meta 475 | 476 | 477 | if __name__ == '__main__': 478 | main() 479 | -------------------------------------------------------------------------------- /bowel_fineseg/tools/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2013 Oskar Maier 2 | # 3 | # This program is free software: you can redistribute it and/or modify 4 | # it under the terms of the GNU General Public License as published by 5 | # the Free Software Foundation, either version 3 of the License, or 6 | # (at your option) any later version. 7 | # 8 | # This program is distributed in the hope that it will be useful, 9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 11 | # GNU General Public License for more details. 12 | # 13 | # You should have received a copy of the GNU General Public License 14 | # along with this program. If not, see . 
15 | # 16 | # author Oskar Maier 17 | # version r0.1.1 18 | # since 2014-03-13 19 | # status Release 20 | 21 | # build-in modules 22 | 23 | # third-party modules 24 | import numpy 25 | from scipy.ndimage import _ni_support 26 | from scipy.ndimage.morphology import distance_transform_edt, binary_erosion, \ 27 | generate_binary_structure 28 | from scipy.ndimage.measurements import label, find_objects 29 | from scipy.stats import pearsonr 30 | 31 | 32 | # own modules 33 | 34 | # code 35 | def dc(result, reference): 36 | r""" 37 | Dice coefficient 38 | 39 | Computes the Dice coefficient (also known as Sorensen index) between the binary 40 | objects in two images. 41 | 42 | The metric is defined as 43 | 44 | .. math:: 45 | 46 | DC=\frac{2|A\cap B|}{|A|+|B|} 47 | 48 | , where :math:`A` is the first and :math:`B` the second set of samples (here: binary objects). 49 | 50 | Parameters 51 | ---------- 52 | result : array_like 53 | Input data containing objects. Can be any type but will be converted 54 | into binary: background where 0, object everywhere else. 55 | reference : array_like 56 | Input data containing objects. Can be any type but will be converted 57 | into binary: background where 0, object everywhere else. 58 | 59 | Returns 60 | ------- 61 | dc : float 62 | The Dice coefficient between the object(s) in ```result``` and the 63 | object(s) in ```reference```. It ranges from 0 (no overlap) to 1 (perfect overlap). 64 | 65 | Notes 66 | ----- 67 | This is a real metric. The binary images can therefore be supplied in any order. 68 | """ 69 | result = numpy.atleast_1d(result.astype(numpy.bool)) 70 | reference = numpy.atleast_1d(reference.astype(numpy.bool)) 71 | 72 | intersection = numpy.count_nonzero(result & reference) 73 | 74 | size_i1 = numpy.count_nonzero(result) 75 | size_i2 = numpy.count_nonzero(reference) 76 | 77 | try: 78 | dc = 2. 
* intersection / float(size_i1 + size_i2) 79 | except ZeroDivisionError: 80 | dc = 0.0 81 | 82 | return dc 83 | 84 | 85 | def jc(result, reference): 86 | """ 87 | Jaccard coefficient 88 | 89 | Computes the Jaccard coefficient between the binary objects in two images. 90 | 91 | Parameters 92 | ---------- 93 | result: array_like 94 | Input data containing objects. Can be any type but will be converted 95 | into binary: background where 0, object everywhere else. 96 | reference: array_like 97 | Input data containing objects. Can be any type but will be converted 98 | into binary: background where 0, object everywhere else. 99 | 100 | Returns 101 | ------- 102 | jc: float 103 | The Jaccard coefficient between the object(s) in `result` and the 104 | object(s) in `reference`. It ranges from 0 (no overlap) to 1 (perfect overlap). 105 | 106 | Notes 107 | ----- 108 | This is a real metric. The binary images can therefore be supplied in any order. 109 | """ 110 | result = numpy.atleast_1d(result.astype(numpy.bool)) 111 | reference = numpy.atleast_1d(reference.astype(numpy.bool)) 112 | 113 | intersection = numpy.count_nonzero(result & reference) 114 | union = numpy.count_nonzero(result | reference) 115 | 116 | jc = float(intersection) / float(union) 117 | 118 | return jc 119 | 120 | 121 | def precision(result, reference): 122 | """ 123 | Precison. 124 | 125 | Parameters 126 | ---------- 127 | result : array_like 128 | Input data containing objects. Can be any type but will be converted 129 | into binary: background where 0, object everywhere else. 130 | reference : array_like 131 | Input data containing objects. Can be any type but will be converted 132 | into binary: background where 0, object everywhere else. 133 | 134 | Returns 135 | ------- 136 | precision : float 137 | The precision between two binary datasets, here mostly binary objects in images, 138 | which is defined as the fraction of retrieved instances that are relevant. The 139 | precision is not symmetric. 
140 | 141 | See also 142 | -------- 143 | :func:`recall` 144 | 145 | Notes 146 | ----- 147 | Not symmetric. The inverse of the precision is :func:`recall`. 148 | High precision means that an algorithm returned substantially more relevant results than irrelevant. 149 | 150 | References 151 | ---------- 152 | .. [1] http://en.wikipedia.org/wiki/Precision_and_recall 153 | .. [2] http://en.wikipedia.org/wiki/Confusion_matrix#Table_of_confusion 154 | """ 155 | result = numpy.atleast_1d(result.astype(numpy.bool)) 156 | reference = numpy.atleast_1d(reference.astype(numpy.bool)) 157 | 158 | tp = numpy.count_nonzero(result & reference) 159 | fp = numpy.count_nonzero(result & ~reference) 160 | 161 | try: 162 | precision = tp / float(tp + fp) 163 | except ZeroDivisionError: 164 | precision = 0.0 165 | 166 | return precision 167 | 168 | 169 | def recall(result, reference): 170 | """ 171 | Recall. 172 | 173 | Parameters 174 | ---------- 175 | result : array_like 176 | Input data containing objects. Can be any type but will be converted 177 | into binary: background where 0, object everywhere else. 178 | reference : array_like 179 | Input data containing objects. Can be any type but will be converted 180 | into binary: background where 0, object everywhere else. 181 | 182 | Returns 183 | ------- 184 | recall : float 185 | The recall between two binary datasets, here mostly binary objects in images, 186 | which is defined as the fraction of relevant instances that are retrieved. The 187 | recall is not symmetric. 188 | 189 | See also 190 | -------- 191 | :func:`precision` 192 | 193 | Notes 194 | ----- 195 | Not symmetric. The inverse of the recall is :func:`precision`. 196 | High recall means that an algorithm returned most of the relevant results. 197 | 198 | References 199 | ---------- 200 | .. [1] http://en.wikipedia.org/wiki/Precision_and_recall 201 | .. 
[2] http://en.wikipedia.org/wiki/Confusion_matrix#Table_of_confusion 202 | """ 203 | result = numpy.atleast_1d(result.astype(numpy.bool)) 204 | reference = numpy.atleast_1d(reference.astype(numpy.bool)) 205 | 206 | tp = numpy.count_nonzero(result & reference) 207 | fn = numpy.count_nonzero(~result & reference) 208 | 209 | try: 210 | recall = tp / float(tp + fn) 211 | except ZeroDivisionError: 212 | recall = 0.0 213 | 214 | return recall 215 | 216 | 217 | def sensitivity(result, reference): 218 | """ 219 | Sensitivity. 220 | Same as :func:`recall`, see there for a detailed description. 221 | 222 | See also 223 | -------- 224 | :func:`specificity` 225 | """ 226 | return recall(result, reference) 227 | 228 | 229 | def specificity(result, reference): 230 | """ 231 | Specificity. 232 | 233 | Parameters 234 | ---------- 235 | result : array_like 236 | Input data containing objects. Can be any type but will be converted 237 | into binary: background where 0, object everywhere else. 238 | reference : array_like 239 | Input data containing objects. Can be any type but will be converted 240 | into binary: background where 0, object everywhere else. 241 | 242 | Returns 243 | ------- 244 | specificity : float 245 | The specificity between two binary datasets, here mostly binary objects in images, 246 | which denotes the fraction of correctly returned negatives. The 247 | specificity is not symmetric. 248 | 249 | See also 250 | -------- 251 | :func:`sensitivity` 252 | 253 | Notes 254 | ----- 255 | Not symmetric. The completment of the specificity is :func:`sensitivity`. 256 | High recall means that an algorithm returned most of the irrelevant results. 257 | 258 | References 259 | ---------- 260 | .. [1] https://en.wikipedia.org/wiki/Sensitivity_and_specificity 261 | .. 
[2] http://en.wikipedia.org/wiki/Confusion_matrix#Table_of_confusion 262 | """ 263 | result = numpy.atleast_1d(result.astype(numpy.bool)) 264 | reference = numpy.atleast_1d(reference.astype(numpy.bool)) 265 | 266 | tn = numpy.count_nonzero(~result & ~reference) 267 | fp = numpy.count_nonzero(result & ~reference) 268 | 269 | try: 270 | specificity = tn / float(tn + fp) 271 | except ZeroDivisionError: 272 | specificity = 0.0 273 | 274 | return specificity 275 | 276 | 277 | def true_negative_rate(result, reference): 278 | """ 279 | True negative rate. 280 | Same as :func:`specificity`, see there for a detailed description. 281 | 282 | See also 283 | -------- 284 | :func:`true_positive_rate` 285 | :func:`positive_predictive_value` 286 | """ 287 | return specificity(result, reference) 288 | 289 | 290 | def true_positive_rate(result, reference): 291 | """ 292 | True positive rate. 293 | Same as :func:`recall` and :func:`sensitivity`, see there for a detailed description. 294 | 295 | See also 296 | -------- 297 | :func:`positive_predictive_value` 298 | :func:`true_negative_rate` 299 | """ 300 | return recall(result, reference) 301 | 302 | 303 | def positive_predictive_value(result, reference): 304 | """ 305 | Positive predictive value. 306 | Same as :func:`precision`, see there for a detailed description. 307 | 308 | See also 309 | -------- 310 | :func:`true_positive_rate` 311 | :func:`true_negative_rate` 312 | """ 313 | return precision(result, reference) 314 | 315 | 316 | def hd(result, reference, voxelspacing=None, connectivity=1): 317 | """ 318 | Hausdorff Distance. 319 | 320 | Computes the (symmetric) Hausdorff Distance (HD) between the binary objects in two 321 | images. It is defined as the maximum surface distance between the objects. 322 | 323 | Parameters 324 | ---------- 325 | result : array_like 326 | Input data containing objects. Can be any type but will be converted 327 | into binary: background where 0, object everywhere else. 
328 | reference : array_like 329 | Input data containing objects. Can be any type but will be converted 330 | into binary: background where 0, object everywhere else. 331 | voxelspacing : float or sequence of floats, optional 332 | The voxelspacing in a distance unit i.e. spacing of elements 333 | along each dimension. If a sequence, must be of length equal to 334 | the input rank; if a single number, this is used for all axes. If 335 | not specified, a grid spacing of unity is implied. 336 | connectivity : int 337 | The neighbourhood/connectivity considered when determining the surface 338 | of the binary objects. This value is passed to 339 | `scipy.ndimage.morphology.generate_binary_structure` and should usually be :math:`> 1`. 340 | Note that the connectivity influences the result in the case of the Hausdorff distance. 341 | 342 | Returns 343 | ------- 344 | hd : float 345 | The symmetric Hausdorff Distance between the object(s) in ```result``` and the 346 | object(s) in ```reference```. The distance unit is the same as for the spacing of 347 | elements along each dimension, which is usually given in mm. 348 | 349 | See also 350 | -------- 351 | :func:`assd` 352 | :func:`asd` 353 | 354 | Notes 355 | ----- 356 | This is a real metric. The binary images can therefore be supplied in any order. 357 | """ 358 | hd1 = __surface_distances(result, reference, voxelspacing, connectivity).max() 359 | hd2 = __surface_distances(reference, result, voxelspacing, connectivity).max() 360 | hd = max(hd1, hd2) 361 | return hd 362 | 363 | 364 | def hd95(result, reference, voxelspacing=None, connectivity=1): 365 | """ 366 | 95th percentile of the Hausdorff Distance. 367 | 368 | Computes the 95th percentile of the (symmetric) Hausdorff Distance (HD) between the binary objects in two 369 | images. Compared to the Hausdorff Distance, this metric is slightly more stable to small outliers and is 370 | commonly used in Biomedical Segmentation challenges. 
371 | 372 | Parameters 373 | ---------- 374 | result : array_like 375 | Input data containing objects. Can be any type but will be converted 376 | into binary: background where 0, object everywhere else. 377 | reference : array_like 378 | Input data containing objects. Can be any type but will be converted 379 | into binary: background where 0, object everywhere else. 380 | voxelspacing : float or sequence of floats, optional 381 | The voxelspacing in a distance unit i.e. spacing of elements 382 | along each dimension. If a sequence, must be of length equal to 383 | the input rank; if a single number, this is used for all axes. If 384 | not specified, a grid spacing of unity is implied. 385 | connectivity : int 386 | The neighbourhood/connectivity considered when determining the surface 387 | of the binary objects. This value is passed to 388 | `scipy.ndimage.morphology.generate_binary_structure` and should usually be :math:`> 1`. 389 | Note that the connectivity influences the result in the case of the Hausdorff distance. 390 | 391 | Returns 392 | ------- 393 | hd : float 394 | The symmetric Hausdorff Distance between the object(s) in ```result``` and the 395 | object(s) in ```reference```. The distance unit is the same as for the spacing of 396 | elements along each dimension, which is usually given in mm. 397 | 398 | See also 399 | -------- 400 | :func:`hd` 401 | 402 | Notes 403 | ----- 404 | This is a real metric. The binary images can therefore be supplied in any order. 405 | """ 406 | hd1 = __surface_distances(result, reference, voxelspacing, connectivity) 407 | hd2 = __surface_distances(reference, result, voxelspacing, connectivity) 408 | hd95 = numpy.percentile(numpy.hstack((hd1, hd2)), 95) 409 | return hd95 410 | 411 | 412 | def assd(result, reference, voxelspacing=None, connectivity=1): 413 | """ 414 | Average symmetric surface distance. 415 | 416 | Computes the average symmetric surface distance (ASD) between the binary objects in 417 | two images. 
418 | 419 | Parameters 420 | ---------- 421 | result : array_like 422 | Input data containing objects. Can be any type but will be converted 423 | into binary: background where 0, object everywhere else. 424 | reference : array_like 425 | Input data containing objects. Can be any type but will be converted 426 | into binary: background where 0, object everywhere else. 427 | voxelspacing : float or sequence of floats, optional 428 | The voxelspacing in a distance unit i.e. spacing of elements 429 | along each dimension. If a sequence, must be of length equal to 430 | the input rank; if a single number, this is used for all axes. If 431 | not specified, a grid spacing of unity is implied. 432 | connectivity : int 433 | The neighbourhood/connectivity considered when determining the surface 434 | of the binary objects. This value is passed to 435 | `scipy.ndimage.morphology.generate_binary_structure` and should usually be :math:`> 1`. 436 | The decision on the connectivity is important, as it can influence the results 437 | strongly. If in doubt, leave it as it is. 438 | 439 | Returns 440 | ------- 441 | assd : float 442 | The average symmetric surface distance between the object(s) in ``result`` and the 443 | object(s) in ``reference``. The distance unit is the same as for the spacing of 444 | elements along each dimension, which is usually given in mm. 445 | 446 | See also 447 | -------- 448 | :func:`asd` 449 | :func:`hd` 450 | 451 | Notes 452 | ----- 453 | This is a real metric, obtained by calling and averaging 454 | 455 | >>> asd(result, reference) 456 | 457 | and 458 | 459 | >>> asd(reference, result) 460 | 461 | The binary images can therefore be supplied in any order. 462 | """ 463 | assd = numpy.mean( 464 | (asd(result, reference, voxelspacing, connectivity), asd(reference, result, voxelspacing, connectivity))) 465 | return assd 466 | 467 | 468 | def asd(result, reference, voxelspacing=None, connectivity=1): 469 | """ 470 | Average surface distance metric. 
471 | 472 | Computes the average surface distance (ASD) between the binary objects in two images. 473 | 474 | Parameters 475 | ---------- 476 | result : array_like 477 | Input data containing objects. Can be any type but will be converted 478 | into binary: background where 0, object everywhere else. 479 | reference : array_like 480 | Input data containing objects. Can be any type but will be converted 481 | into binary: background where 0, object everywhere else. 482 | voxelspacing : float or sequence of floats, optional 483 | The voxelspacing in a distance unit i.e. spacing of elements 484 | along each dimension. If a sequence, must be of length equal to 485 | the input rank; if a single number, this is used for all axes. If 486 | not specified, a grid spacing of unity is implied. 487 | connectivity : int 488 | The neighbourhood/connectivity considered when determining the surface 489 | of the binary objects. This value is passed to 490 | `scipy.ndimage.morphology.generate_binary_structure` and should usually be :math:`> 1`. 491 | The decision on the connectivity is important, as it can influence the results 492 | strongly. If in doubt, leave it as it is. 493 | 494 | Returns 495 | ------- 496 | asd : float 497 | The average surface distance between the object(s) in ``result`` and the 498 | object(s) in ``reference``. The distance unit is the same as for the spacing 499 | of elements along each dimension, which is usually given in mm. 500 | 501 | See also 502 | -------- 503 | :func:`assd` 504 | :func:`hd` 505 | 506 | 507 | Notes 508 | ----- 509 | This is not a real metric, as it is directed. See `assd` for a real metric of this. 510 | 511 | The method is implemented making use of distance images and simple binary morphology 512 | to achieve high computational speed. 513 | 514 | Examples 515 | -------- 516 | The `connectivity` determines what pixels/voxels are considered the surface of a 517 | binary object. 
Take the following binary image showing a cross 518 | 519 | >>> from scipy.ndimage.morphology import generate_binary_structure 520 | >>> cross = generate_binary_structure(2, 1) 521 | array([[0, 1, 0], 522 | [1, 1, 1], 523 | [0, 1, 0]]) 524 | 525 | With `connectivity` set to `1` a 4-neighbourhood is considered when determining the 526 | object surface, resulting in the surface 527 | 528 | .. code-block:: python 529 | 530 | array([[0, 1, 0], 531 | [1, 0, 1], 532 | [0, 1, 0]]) 533 | 534 | Changing `connectivity` to `2`, a 8-neighbourhood is considered and we get: 535 | 536 | .. code-block:: python 537 | 538 | array([[0, 1, 0], 539 | [1, 1, 1], 540 | [0, 1, 0]]) 541 | 542 | , as a diagonal connection does no longer qualifies as valid object surface. 543 | 544 | This influences the results `asd` returns. Imagine we want to compute the surface 545 | distance of our cross to a cube-like object: 546 | 547 | >>> cube = generate_binary_structure(2, 1) 548 | array([[1, 1, 1], 549 | [1, 1, 1], 550 | [1, 1, 1]]) 551 | 552 | , which surface is, independent of the `connectivity` value set, always 553 | 554 | .. code-block:: python 555 | 556 | array([[1, 1, 1], 557 | [1, 0, 1], 558 | [1, 1, 1]]) 559 | 560 | Using a `connectivity` of `1` we get 561 | 562 | >>> asd(cross, cube, connectivity=1) 563 | 0.0 564 | 565 | while a value of `2` returns us 566 | 567 | >>> asd(cross, cube, connectivity=2) 568 | 0.20000000000000001 569 | 570 | due to the center of the cross being considered surface as well. 571 | 572 | """ 573 | sds = __surface_distances(result, reference, voxelspacing, connectivity) 574 | asd = sds.mean() 575 | return asd 576 | 577 | 578 | def ravd(result, reference): 579 | """ 580 | Relative absolute volume difference. 581 | 582 | Compute the relative absolute volume difference between the (joined) binary objects 583 | in the two images. 584 | 585 | Parameters 586 | ---------- 587 | result : array_like 588 | Input data containing objects. 
Can be any type but will be converted 589 | into binary: background where 0, object everywhere else. 590 | reference : array_like 591 | Input data containing objects. Can be any type but will be converted 592 | into binary: background where 0, object everywhere else. 593 | 594 | Returns 595 | ------- 596 | ravd : float 597 | The relative absolute volume difference between the object(s) in ``result`` 598 | and the object(s) in ``reference``. This is a percentage value in the range 599 | :math:`[-1.0, +inf]` for which a :math:`0` denotes an ideal score. 600 | 601 | Raises 602 | ------ 603 | RuntimeError 604 | If the reference object is empty. 605 | 606 | See also 607 | -------- 608 | :func:`dc` 609 | :func:`precision` 610 | :func:`recall` 611 | 612 | Notes 613 | ----- 614 | This is not a real metric, as it is directed. Negative values denote a smaller 615 | and positive values a larger volume than the reference. 616 | This implementation does not check, whether the two supplied arrays are of the same 617 | size. 
618 | 619 | Examples 620 | -------- 621 | Considering the following inputs 622 | 623 | >>> import numpy 624 | >>> arr1 = numpy.asarray([[0,1,0],[1,1,1],[0,1,0]]) 625 | >>> arr1 626 | array([[0, 1, 0], 627 | [1, 1, 1], 628 | [0, 1, 0]]) 629 | >>> arr2 = numpy.asarray([[0,1,0],[1,0,1],[0,1,0]]) 630 | >>> arr2 631 | array([[0, 1, 0], 632 | [1, 0, 1], 633 | [0, 1, 0]]) 634 | 635 | comparing `arr1` to `arr2` we get 636 | 637 | >>> ravd(arr1, arr2) 638 | -0.2 639 | 640 | and reversing the inputs the directivness of the metric becomes evident 641 | 642 | >>> ravd(arr2, arr1) 643 | 0.25 644 | 645 | It is important to keep in mind that a perfect score of `0` does not mean that the 646 | binary objects fit exactely, as only the volumes are compared: 647 | 648 | >>> arr1 = numpy.asarray([1,0,0]) 649 | >>> arr2 = numpy.asarray([0,0,1]) 650 | >>> ravd(arr1, arr2) 651 | 0.0 652 | 653 | """ 654 | result = numpy.atleast_1d(result.astype(numpy.bool)) 655 | reference = numpy.atleast_1d(reference.astype(numpy.bool)) 656 | 657 | vol1 = numpy.count_nonzero(result) 658 | vol2 = numpy.count_nonzero(reference) 659 | 660 | if 0 == vol2: 661 | raise RuntimeError('The second supplied array does not contain any binary object.') 662 | 663 | return (vol1 - vol2) / float(vol2) 664 | 665 | 666 | def volume_correlation(results, references): 667 | r""" 668 | Volume correlation. 669 | 670 | Computes the linear correlation in binary object volume between the 671 | contents of the successive binary images supplied. Measured through 672 | the Pearson product-moment correlation coefficient. 673 | 674 | Parameters 675 | ---------- 676 | results : sequence of array_like 677 | Ordered list of input data containing objects. Each array_like will be 678 | converted into binary: background where 0, object everywhere else. 679 | references : sequence of array_like 680 | Ordered list of input data containing objects. Each array_like will be 681 | converted into binary: background where 0, object everywhere else. 
682 | The order must be the same as for ``results``. 683 | 684 | Returns 685 | ------- 686 | r : float 687 | The correlation coefficient between -1 and 1. 688 | p : float 689 | The two-side p value. 690 | 691 | """ 692 | results = numpy.atleast_2d(numpy.array(results).astype(numpy.bool)) 693 | references = numpy.atleast_2d(numpy.array(references).astype(numpy.bool)) 694 | 695 | results_volumes = [numpy.count_nonzero(r) for r in results] 696 | references_volumes = [numpy.count_nonzero(r) for r in references] 697 | 698 | return pearsonr(results_volumes, references_volumes) # returns (Pearson' 699 | 700 | 701 | def volume_change_correlation(results, references): 702 | r""" 703 | Volume change correlation. 704 | 705 | Computes the linear correlation of change in binary object volume between 706 | the contents of the successive binary images supplied. Measured through 707 | the Pearson product-moment correlation coefficient. 708 | 709 | Parameters 710 | ---------- 711 | results : sequence of array_like 712 | Ordered list of input data containing objects. Each array_like will be 713 | converted into binary: background where 0, object everywhere else. 714 | references : sequence of array_like 715 | Ordered list of input data containing objects. Each array_like will be 716 | converted into binary: background where 0, object everywhere else. 717 | The order must be the same as for ``results``. 718 | 719 | Returns 720 | ------- 721 | r : float 722 | The correlation coefficient between -1 and 1. 723 | p : float 724 | The two-side p value. 
725 | 726 | """ 727 | results = numpy.atleast_2d(numpy.array(results).astype(numpy.bool)) 728 | references = numpy.atleast_2d(numpy.array(references).astype(numpy.bool)) 729 | 730 | results_volumes = numpy.asarray([numpy.count_nonzero(r) for r in results]) 731 | references_volumes = numpy.asarray([numpy.count_nonzero(r) for r in references]) 732 | 733 | results_volumes_changes = results_volumes[1:] - results_volumes[:-1] 734 | references_volumes_changes = references_volumes[1:] - references_volumes[:-1] 735 | 736 | return pearsonr(results_volumes_changes, 737 | references_volumes_changes) # returns (Pearson's correlation coefficient, 2-tailed p-value) 738 | 739 | 740 | def obj_assd(result, reference, voxelspacing=None, connectivity=1): 741 | """ 742 | Average symmetric surface distance. 743 | 744 | Computes the average symmetric surface distance (ASSD) between the binary objects in 745 | two images. 746 | 747 | Parameters 748 | ---------- 749 | result : array_like 750 | Input data containing objects. Can be any type but will be converted 751 | into binary: background where 0, object everywhere else. 752 | reference : array_like 753 | Input data containing objects. Can be any type but will be converted 754 | into binary: background where 0, object everywhere else. 755 | voxelspacing : float or sequence of floats, optional 756 | The voxelspacing in a distance unit i.e. spacing of elements 757 | along each dimension. If a sequence, must be of length equal to 758 | the input rank; if a single number, this is used for all axes. If 759 | not specified, a grid spacing of unity is implied. 760 | connectivity : int 761 | The neighbourhood/connectivity considered when determining what accounts 762 | for a distinct binary object as well as when determining the surface 763 | of the binary objects. This value is passed to 764 | `scipy.ndimage.morphology.generate_binary_structure` and should usually be :math:`> 1`. 
765 | The decision on the connectivity is important, as it can influence the results 766 | strongly. If in doubt, leave it as it is. 767 | 768 | Returns 769 | ------- 770 | assd : float 771 | The average symmetric surface distance between all mutually existing distinct 772 | binary object(s) in ``result`` and ``reference``. The distance unit is the same as for 773 | the spacing of elements along each dimension, which is usually given in mm. 774 | 775 | See also 776 | -------- 777 | :func:`obj_asd` 778 | 779 | Notes 780 | ----- 781 | This is a real metric, obtained by calling and averaging 782 | 783 | >>> obj_asd(result, reference) 784 | 785 | and 786 | 787 | >>> obj_asd(reference, result) 788 | 789 | The binary images can therefore be supplied in any order. 790 | """ 791 | assd = numpy.mean((obj_asd(result, reference, voxelspacing, connectivity), 792 | obj_asd(reference, result, voxelspacing, connectivity))) 793 | return assd 794 | 795 | 796 | def obj_asd(result, reference, voxelspacing=None, connectivity=1): 797 | """ 798 | Average surface distance between objects. 799 | 800 | First correspondences between distinct binary objects in reference and result are 801 | established. Then the average surface distance is only computed between corresponding 802 | objects. Correspondence is defined as unique and at least one voxel overlap. 803 | 804 | Parameters 805 | ---------- 806 | result : array_like 807 | Input data containing objects. Can be any type but will be converted 808 | into binary: background where 0, object everywhere else. 809 | reference : array_like 810 | Input data containing objects. Can be any type but will be converted 811 | into binary: background where 0, object everywhere else. 812 | voxelspacing : float or sequence of floats, optional 813 | The voxelspacing in a distance unit i.e. spacing of elements 814 | along each dimension. If a sequence, must be of length equal to 815 | the input rank; if a single number, this is used for all axes. 
If 816 | not specified, a grid spacing of unity is implied. 817 | connectivity : int 818 | The neighbourhood/connectivity considered when determining what accounts 819 | for a distinct binary object as well as when determining the surface 820 | of the binary objects. This value is passed to 821 | `scipy.ndimage.morphology.generate_binary_structure` and should usually be :math:`> 1`. 822 | The decision on the connectivity is important, as it can influence the results 823 | strongly. If in doubt, leave it as it is. 824 | 825 | Returns 826 | ------- 827 | asd : float 828 | The average surface distance between all mutually existing distinct binary 829 | object(s) in ``result`` and ``reference``. The distance unit is the same as for the 830 | spacing of elements along each dimension, which is usually given in mm. 831 | 832 | See also 833 | -------- 834 | :func:`obj_assd` 835 | :func:`obj_tpr` 836 | :func:`obj_fpr` 837 | 838 | Notes 839 | ----- 840 | This is not a real metric, as it is directed. See `obj_assd` for a real metric of this. 841 | 842 | For the understanding of this metric, both the notions of connectedness and surface 843 | distance are essential. Please see :func:`obj_tpr` and :func:`obj_fpr` for more 844 | information on the first and :func:`asd` on the second. 
845 | 846 | Examples 847 | -------- 848 | >>> arr1 = numpy.asarray([[1,1,1],[1,1,1],[1,1,1]]) 849 | >>> arr2 = numpy.asarray([[0,1,0],[0,1,0],[0,1,0]]) 850 | >>> arr1 851 | array([[1, 1, 1], 852 | [1, 1, 1], 853 | [1, 1, 1]]) 854 | >>> arr2 855 | array([[0, 1, 0], 856 | [0, 1, 0], 857 | [0, 1, 0]]) 858 | >>> obj_asd(arr1, arr2) 859 | 1.5 860 | >>> obj_asd(arr2, arr1) 861 | 0.333333333333 862 | 863 | With the `voxelspacing` parameter, the distances between the voxels can be set for 864 | each dimension separately: 865 | 866 | >>> obj_asd(arr1, arr2, voxelspacing=(1,2)) 867 | 1.5 868 | >>> obj_asd(arr2, arr1, voxelspacing=(1,2)) 869 | 0.333333333333 870 | 871 | More examples depicting the notion of object connectedness: 872 | 873 | >>> arr1 = numpy.asarray([[1,0,1],[1,0,0],[0,0,0]]) 874 | >>> arr2 = numpy.asarray([[1,0,1],[1,0,0],[0,0,1]]) 875 | >>> arr1 876 | array([[1, 0, 1], 877 | [1, 0, 0], 878 | [0, 0, 0]]) 879 | >>> arr2 880 | array([[1, 0, 1], 881 | [1, 0, 0], 882 | [0, 0, 1]]) 883 | >>> obj_asd(arr1, arr2) 884 | 0.0 885 | >>> obj_asd(arr2, arr1) 886 | 0.0 887 | 888 | >>> arr1 = numpy.asarray([[1,0,1],[1,0,1],[0,0,1]]) 889 | >>> arr2 = numpy.asarray([[1,0,1],[1,0,0],[0,0,1]]) 890 | >>> arr1 891 | array([[1, 0, 1], 892 | [1, 0, 1], 893 | [0, 0, 1]]) 894 | >>> arr2 895 | array([[1, 0, 1], 896 | [1, 0, 0], 897 | [0, 0, 1]]) 898 | >>> obj_asd(arr1, arr2) 899 | 0.6 900 | >>> obj_asd(arr2, arr1) 901 | 0.0 902 | 903 | Influence of `connectivity` parameter can be seen in the following example, where 904 | with the (default) connectivity of `1` the first array is considered to contain two 905 | objects, while with an increase connectivity of `2`, just one large object is 906 | detected. 
907 | 908 | >>> arr1 = numpy.asarray([[1,0,0],[0,1,1],[0,1,1]]) 909 | >>> arr2 = numpy.asarray([[1,0,0],[0,0,0],[0,0,0]]) 910 | >>> arr1 911 | array([[1, 0, 0], 912 | [0, 1, 1], 913 | [0, 1, 1]]) 914 | >>> arr2 915 | array([[1, 0, 0], 916 | [0, 0, 0], 917 | [0, 0, 0]]) 918 | >>> obj_asd(arr1, arr2) 919 | 0.0 920 | >>> obj_asd(arr1, arr2, connectivity=2) 921 | 1.742955328 922 | 923 | Note that the connectivity also influence the notion of what is considered an object 924 | surface voxels. 925 | """ 926 | sds = list() 927 | labelmap1, labelmap2, _a, _b, mapping = __distinct_binary_object_correspondences(result, reference, connectivity) 928 | slicers1 = find_objects(labelmap1) 929 | slicers2 = find_objects(labelmap2) 930 | for lid2, lid1 in list(mapping.items()): 931 | window = __combine_windows(slicers1[lid1 - 1], slicers2[lid2 - 1]) 932 | object1 = labelmap1[window] == lid1 933 | object2 = labelmap2[window] == lid2 934 | sds.extend(__surface_distances(object1, object2, voxelspacing, connectivity)) 935 | asd = numpy.mean(sds) 936 | return asd 937 | 938 | 939 | def obj_fpr(result, reference, connectivity=1): 940 | """ 941 | The false positive rate of distinct binary object detection. 942 | 943 | The false positive rates gives a percentage measure of how many distinct binary 944 | objects in the second array do not exists in the first array. A partial overlap 945 | (of minimum one voxel) is here considered sufficient. 946 | 947 | In cases where two distinct binary object in the second array overlap with a single 948 | distinct object in the first array, only one is considered to have been detected 949 | successfully and the other is added to the count of false positives. 950 | 951 | Parameters 952 | ---------- 953 | result : array_like 954 | Input data containing objects. Can be any type but will be converted 955 | into binary: background where 0, object everywhere else. 956 | reference : array_like 957 | Input data containing objects. 
Can be any type but will be converted 958 | into binary: background where 0, object everywhere else. 959 | connectivity : int 960 | The neighbourhood/connectivity considered when determining what accounts 961 | for a distinct binary object. This value is passed to 962 | `scipy.ndimage.morphology.generate_binary_structure` and should usually be :math:`> 1`. 963 | The decision on the connectivity is important, as it can influence the results 964 | strongly. If in doubt, leave it as it is. 965 | 966 | Returns 967 | ------- 968 | tpr : float 969 | A percentage measure of how many distinct binary objects in ``results`` have no 970 | corresponding binary object in ``reference``. It has the range :math:`[0, 1]`, where a :math:`0` 971 | denotes an ideal score. 972 | 973 | Raises 974 | ------ 975 | RuntimeError 976 | If the second array is empty. 977 | 978 | See also 979 | -------- 980 | :func:`obj_tpr` 981 | 982 | Notes 983 | ----- 984 | This is not a real metric, as it is directed. Whatever array is considered as 985 | reference should be passed second. A perfect score of :math:`0` tells that there are no 986 | distinct binary objects in the second array that do not exists also in the reference 987 | array, but does not reveal anything about objects in the reference array also 988 | existing in the second array (use :func:`obj_tpr` for this). 
989 | 990 | Examples 991 | -------- 992 | >>> arr2 = numpy.asarray([[1,0,0],[1,0,1],[0,0,1]]) 993 | >>> arr1 = numpy.asarray([[0,0,1],[1,0,1],[0,0,1]]) 994 | >>> arr2 995 | array([[1, 0, 0], 996 | [1, 0, 1], 997 | [0, 0, 1]]) 998 | >>> arr1 999 | array([[0, 0, 1], 1000 | [1, 0, 1], 1001 | [0, 0, 1]]) 1002 | >>> obj_fpr(arr1, arr2) 1003 | 0.0 1004 | >>> obj_fpr(arr2, arr1) 1005 | 0.0 1006 | 1007 | Example of directedness: 1008 | 1009 | >>> arr2 = numpy.asarray([1,0,1,0,1]) 1010 | >>> arr1 = numpy.asarray([1,0,1,0,0]) 1011 | >>> obj_fpr(arr1, arr2) 1012 | 0.0 1013 | >>> obj_fpr(arr2, arr1) 1014 | 0.3333333333333333 1015 | 1016 | Examples of multiple overlap treatment: 1017 | 1018 | >>> arr2 = numpy.asarray([1,0,1,0,1,1,1]) 1019 | >>> arr1 = numpy.asarray([1,1,1,0,1,0,1]) 1020 | >>> obj_fpr(arr1, arr2) 1021 | 0.3333333333333333 1022 | >>> obj_fpr(arr2, arr1) 1023 | 0.3333333333333333 1024 | 1025 | >>> arr2 = numpy.asarray([1,0,1,1,1,0,1]) 1026 | >>> arr1 = numpy.asarray([1,1,1,0,1,1,1]) 1027 | >>> obj_fpr(arr1, arr2) 1028 | 0.0 1029 | >>> obj_fpr(arr2, arr1) 1030 | 0.3333333333333333 1031 | 1032 | >>> arr2 = numpy.asarray([[1,0,1,0,0], 1033 | [1,0,0,0,0], 1034 | [1,0,1,1,1], 1035 | [0,0,0,0,0], 1036 | [1,0,1,0,0]]) 1037 | >>> arr1 = numpy.asarray([[1,1,1,0,0], 1038 | [0,0,0,0,0], 1039 | [1,1,1,0,1], 1040 | [0,0,0,0,0], 1041 | [1,1,1,0,0]]) 1042 | >>> obj_fpr(arr1, arr2) 1043 | 0.0 1044 | >>> obj_fpr(arr2, arr1) 1045 | 0.2 1046 | """ 1047 | _, _, _, n_obj_reference, mapping = __distinct_binary_object_correspondences(reference, result, connectivity) 1048 | return (n_obj_reference - len(mapping)) / float(n_obj_reference) 1049 | 1050 | 1051 | def obj_tpr(result, reference, connectivity=1): 1052 | """ 1053 | The true positive rate of distinct binary object detection. 1054 | 1055 | The true positive rates gives a percentage measure of how many distinct binary 1056 | objects in the first array also exists in the second array. 
A partial overlap 1057 | (of minimum one voxel) is here considered sufficient. 1058 | 1059 | In cases where two distinct binary object in the first array overlaps with a single 1060 | distinct object in the second array, only one is considered to have been detected 1061 | successfully. 1062 | 1063 | Parameters 1064 | ---------- 1065 | result : array_like 1066 | Input data containing objects. Can be any type but will be converted 1067 | into binary: background where 0, object everywhere else. 1068 | reference : array_like 1069 | Input data containing objects. Can be any type but will be converted 1070 | into binary: background where 0, object everywhere else. 1071 | connectivity : int 1072 | The neighbourhood/connectivity considered when determining what accounts 1073 | for a distinct binary object. This value is passed to 1074 | `scipy.ndimage.morphology.generate_binary_structure` and should usually be :math:`> 1`. 1075 | The decision on the connectivity is important, as it can influence the results 1076 | strongly. If in doubt, leave it as it is. 1077 | 1078 | Returns 1079 | ------- 1080 | tpr : float 1081 | A percentage measure of how many distinct binary objects in ``result`` also exists 1082 | in ``reference``. It has the range :math:`[0, 1]`, where a :math:`1` denotes an ideal score. 1083 | 1084 | Raises 1085 | ------ 1086 | RuntimeError 1087 | If the reference object is empty. 1088 | 1089 | See also 1090 | -------- 1091 | :func:`obj_fpr` 1092 | 1093 | Notes 1094 | ----- 1095 | This is not a real metric, as it is directed. Whatever array is considered as 1096 | reference should be passed second. A perfect score of :math:`1` tells that all distinct 1097 | binary objects in the reference array also exist in the result array, but does not 1098 | reveal anything about additional binary objects in the result array 1099 | (use :func:`obj_fpr` for this). 
1100 | 1101 | Examples 1102 | -------- 1103 | >>> arr2 = numpy.asarray([[1,0,0],[1,0,1],[0,0,1]]) 1104 | >>> arr1 = numpy.asarray([[0,0,1],[1,0,1],[0,0,1]]) 1105 | >>> arr2 1106 | array([[1, 0, 0], 1107 | [1, 0, 1], 1108 | [0, 0, 1]]) 1109 | >>> arr1 1110 | array([[0, 0, 1], 1111 | [1, 0, 1], 1112 | [0, 0, 1]]) 1113 | >>> obj_tpr(arr1, arr2) 1114 | 1.0 1115 | >>> obj_tpr(arr2, arr1) 1116 | 1.0 1117 | 1118 | Example of directedness: 1119 | 1120 | >>> arr2 = numpy.asarray([1,0,1,0,1]) 1121 | >>> arr1 = numpy.asarray([1,0,1,0,0]) 1122 | >>> obj_tpr(arr1, arr2) 1123 | 0.6666666666666666 1124 | >>> obj_tpr(arr2, arr1) 1125 | 1.0 1126 | 1127 | Examples of multiple overlap treatment: 1128 | 1129 | >>> arr2 = numpy.asarray([1,0,1,0,1,1,1]) 1130 | >>> arr1 = numpy.asarray([1,1,1,0,1,0,1]) 1131 | >>> obj_tpr(arr1, arr2) 1132 | 0.6666666666666666 1133 | >>> obj_tpr(arr2, arr1) 1134 | 0.6666666666666666 1135 | 1136 | >>> arr2 = numpy.asarray([1,0,1,1,1,0,1]) 1137 | >>> arr1 = numpy.asarray([1,1,1,0,1,1,1]) 1138 | >>> obj_tpr(arr1, arr2) 1139 | 0.6666666666666666 1140 | >>> obj_tpr(arr2, arr1) 1141 | 1.0 1142 | 1143 | >>> arr2 = numpy.asarray([[1,0,1,0,0], 1144 | [1,0,0,0,0], 1145 | [1,0,1,1,1], 1146 | [0,0,0,0,0], 1147 | [1,0,1,0,0]]) 1148 | >>> arr1 = numpy.asarray([[1,1,1,0,0], 1149 | [0,0,0,0,0], 1150 | [1,1,1,0,1], 1151 | [0,0,0,0,0], 1152 | [1,1,1,0,0]]) 1153 | >>> obj_tpr(arr1, arr2) 1154 | 0.8 1155 | >>> obj_tpr(arr2, arr1) 1156 | 1.0 1157 | """ 1158 | _, _, n_obj_result, _, mapping = __distinct_binary_object_correspondences(reference, result, connectivity) 1159 | return len(mapping) / float(n_obj_result) 1160 | 1161 | 1162 | def __distinct_binary_object_correspondences(reference, result, connectivity=1): 1163 | """ 1164 | Determines all distinct (where connectivity is defined by the connectivity parameter 1165 | passed to scipy's `generate_binary_structure`) binary objects in both of the input 1166 | parameters and returns a 1to1 mapping from the labelled objects in 
reference to the 1167 | corresponding (whereas a one-voxel overlap suffices for correspondence) objects in 1168 | result. 1169 | 1170 | All stems from the problem, that the relationship is non-surjective many-to-many. 1171 | 1172 | @return (labelmap1, labelmap2, n_lables1, n_labels2, labelmapping2to1) 1173 | """ 1174 | result = numpy.atleast_1d(result.astype(numpy.bool)) 1175 | reference = numpy.atleast_1d(reference.astype(numpy.bool)) 1176 | 1177 | # binary structure 1178 | footprint = generate_binary_structure(result.ndim, connectivity) 1179 | 1180 | # label distinct binary objects 1181 | labelmap1, n_obj_result = label(result, footprint) 1182 | labelmap2, n_obj_reference = label(reference, footprint) 1183 | 1184 | # find all overlaps from labelmap2 to labelmap1; collect one-to-one relationships and store all one-two-many for later processing 1185 | slicers = find_objects(labelmap2) # get windows of labelled objects 1186 | mapping = dict() # mappings from labels in labelmap2 to corresponding object labels in labelmap1 1187 | used_labels = set() # set to collect all already used labels from labelmap2 1188 | one_to_many = list() # list to collect all one-to-many mappings 1189 | for l1id, slicer in enumerate(slicers): # iterate over object in labelmap2 and their windows 1190 | l1id += 1 # labelled objects have ids sarting from 1 1191 | bobj = (l1id) == labelmap2[slicer] # find binary object corresponding to the label1 id in the segmentation 1192 | l2ids = numpy.unique(labelmap1[slicer][ 1193 | bobj]) # extract all unique object identifiers at the corresponding positions in the reference (i.e. 
the mapping) 1194 | l2ids = l2ids[0 != l2ids] # remove background identifiers (=0) 1195 | if 1 == len( 1196 | l2ids): # one-to-one mapping: if target label not already used, add to final list of object-to-object mappings and mark target label as used 1197 | l2id = l2ids[0] 1198 | if not l2id in used_labels: 1199 | mapping[l1id] = l2id 1200 | used_labels.add(l2id) 1201 | elif 1 < len(l2ids): # one-to-many mapping: store relationship for later processing 1202 | one_to_many.append((l1id, set(l2ids))) 1203 | 1204 | # process one-to-many mappings, always choosing the one with the least labelmap2 correspondences first 1205 | while True: 1206 | one_to_many = [(l1id, l2ids - used_labels) for l1id, l2ids in 1207 | one_to_many] # remove already used ids from all sets 1208 | one_to_many = [x for x in one_to_many if x[1]] # remove empty sets 1209 | one_to_many = sorted(one_to_many, key=lambda x: len(x[1])) # sort by set length 1210 | if 0 == len(one_to_many): 1211 | break 1212 | l2id = one_to_many[0][1].pop() # select an arbitrary target label id from the shortest set 1213 | mapping[one_to_many[0][0]] = l2id # add to one-to-one mappings 1214 | used_labels.add(l2id) # mark target label as used 1215 | one_to_many = one_to_many[1:] # delete the processed set from all sets 1216 | 1217 | return labelmap1, labelmap2, n_obj_result, n_obj_reference, mapping 1218 | 1219 | 1220 | def __surface_distances(result, reference, voxelspacing=None, connectivity=1): 1221 | """ 1222 | The distances between the surface voxel of binary objects in result and their 1223 | nearest partner surface voxel of a binary object in reference. 
1224 | """ 1225 | result = numpy.atleast_1d(result.astype(numpy.bool)) 1226 | reference = numpy.atleast_1d(reference.astype(numpy.bool)) 1227 | if voxelspacing is not None: 1228 | voxelspacing = _ni_support._normalize_sequence(voxelspacing, result.ndim) 1229 | voxelspacing = numpy.asarray(voxelspacing, dtype=numpy.float64) 1230 | if not voxelspacing.flags.contiguous: 1231 | voxelspacing = voxelspacing.copy() 1232 | 1233 | # binary structure 1234 | footprint = generate_binary_structure(result.ndim, connectivity) 1235 | 1236 | # test for emptiness 1237 | if 0 == numpy.count_nonzero(result): 1238 | raise RuntimeError('The first supplied array does not contain any binary object.') 1239 | if 0 == numpy.count_nonzero(reference): 1240 | raise RuntimeError('The second supplied array does not contain any binary object.') 1241 | 1242 | # extract only 1-pixel border line of objects 1243 | result_border = result ^ binary_erosion(result, structure=footprint, iterations=1) 1244 | reference_border = reference ^ binary_erosion(reference, structure=footprint, iterations=1) 1245 | 1246 | # compute average surface distance 1247 | # Note: scipys distance transform is calculated only inside the borders of the 1248 | # foreground objects, therefore the input has to be reversed 1249 | dt = distance_transform_edt(~reference_border, sampling=voxelspacing) 1250 | sds = dt[result_border] 1251 | 1252 | return sds 1253 | 1254 | 1255 | def __combine_windows(w1, w2): 1256 | """ 1257 | Joins two windows (defined by tuple of slices) such that their maximum 1258 | combined extend is covered by the new returned window. 1259 | """ 1260 | res = [] 1261 | for s1, s2 in zip(w1, w2): 1262 | res.append(slice(min(s1.start, s2.start), max(s1.stop, s2.stop))) 1263 | return tuple(res) --------------------------------------------------------------------------------