├── bowel_coarseseg
├── tools
│ ├── __init__.py
│ └── loss.py
├── datasets
│ ├── __init__.py
│ └── BowelDatasetCoarseSeg.py
├── crop_ROI.py
├── loc_model.py
├── preprocessing.py
├── infer.py
└── train.py
├── bowel_fineseg
├── tools
│ ├── __init__.py
│ ├── loss.py
│ └── utils.py
├── datasets
│ ├── __init__.py
│ └── BowelDatasetFineSeg.py
├── arch
│ ├── data.png
│ ├── example.png
│ ├── pipeline.png
│ ├── seg_demo.png
│ └── segmentors.png
├── create_boundary.py
├── create_skeleflux.py
├── infer.py
├── seg_model.py
└── train.py
└── README.md
/bowel_coarseseg/tools/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/bowel_fineseg/tools/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/bowel_coarseseg/datasets/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/bowel_fineseg/datasets/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/bowel_fineseg/arch/data.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cwangrun/BowelNet/HEAD/bowel_fineseg/arch/data.png
--------------------------------------------------------------------------------
/bowel_fineseg/arch/example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cwangrun/BowelNet/HEAD/bowel_fineseg/arch/example.png
--------------------------------------------------------------------------------
/bowel_fineseg/arch/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cwangrun/BowelNet/HEAD/bowel_fineseg/arch/pipeline.png
--------------------------------------------------------------------------------
/bowel_fineseg/arch/seg_demo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cwangrun/BowelNet/HEAD/bowel_fineseg/arch/seg_demo.png
--------------------------------------------------------------------------------
/bowel_fineseg/arch/segmentors.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cwangrun/BowelNet/HEAD/bowel_fineseg/arch/segmentors.png
--------------------------------------------------------------------------------
/bowel_fineseg/create_boundary.py:
--------------------------------------------------------------------------------
1 | from matplotlib import pylab as plt
2 | import nibabel as nib
3 | from glob import glob
4 | from cc3d import connected_components
5 | import numpy as np
6 | import os, cv2
7 | import shutil
8 | import scipy
9 | from scipy import ndimage
10 | from skimage import morphology
11 | from scipy.ndimage import gaussian_filter
12 | from shutil import copy
13 |
14 |
15 |
def mkdir(folder_path):
    """Create ``folder_path`` (and any missing parents) if it does not exist.

    Args:
        folder_path: path of the directory to create.

    Raises:
        FileExistsError: if ``folder_path`` exists but is not a directory.
    """
    # exist_ok=True replaces the check-then-create pair, which could fail if
    # the directory appeared between the existence test and makedirs().
    os.makedirs(folder_path, exist_ok=True)
19 |
20 |
21 |
files = glob('/mnt/c/chong/data/Bowel/crop_stage1_ROI_small/*/*/*/masks_crop.nii.gz')

# For each cropped mask, derive an edge map (mask minus its in-plane erosion)
# and a slice-wise Gaussian-blurred edge heatmap, saved next to the mask.
for iii, file in enumerate(files):

    mask_info = nib.load(file)
    mask_arr_ori = np.round(mask_info.get_fdata()).astype(np.uint8)
    mask_arr = mask_arr_ori.copy()

    # 3x3x1 structuring element: erosion acts within each slice only.
    in_plane = np.ones((3, 3, 1))
    erosion_bin_map = ndimage.binary_erosion(mask_arr.astype(int), structure=in_plane)

    # Voxels removed by the erosion form the boundary.
    edge_map = mask_arr - erosion_bin_map.astype(int)

    # Blur every slice independently; a larger sigma gives more blur.
    edge_heatmap = np.zeros_like(mask_arr).astype(np.float32)
    for z in range(mask_arr.shape[2]):
        slice_edges = edge_map[:, :, z].astype(np.float32).copy()
        edge_heatmap[:, :, z] = gaussian_filter(slice_edges, sigma=1.0, truncate=2.0)

    for suffix, volume in (('_edge_heatmap.nii.gz', edge_heatmap), ('_edge.nii.gz', edge_map)):
        save_path = file.replace('.nii.gz', suffix)
        nib.save(nib.Nifti1Image(volume, header=mask_info.header, affine=mask_info.affine), save_path)

    # NOTE(review): this break stops after the first file; it looks like a
    # debugging leftover -- confirm before running on the full dataset.
    break
54 |
--------------------------------------------------------------------------------
/bowel_fineseg/tools/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
def dice_loss(prob, target, epsilon=1e-7):
    """Soft Sørensen–Dice loss between class probabilities and a label map.

    Args:
        prob: a tensor of shape [B, C, D, H, W] with the probability output
            of the model.
        target: a tensor of shape [B, 1, D, H, W] holding class indices; it is
            one-hot encoded into C channels before comparison.
        epsilon: added to the denominator for numerical stability.
    Returns:
        One minus the Dice coefficient averaged over the C classes.
    """
    n_classes = prob.shape[1]
    labels = target.squeeze(1).long()
    # one_hot puts the class axis last; move it next to the batch axis.
    onehot = F.one_hot(labels, n_classes).permute(0, 4, 1, 2, 3)  # [B, C, D, H, W]
    onehot = onehot.type(prob.type())

    assert prob.size() == onehot.size(), "the size of predict and target must be equal."

    reduce_dims = [0, 2, 3, 4]  # everything except the class axis
    overlap = (prob * onehot).sum(dim=reduce_dims)
    total = (prob + onehot).sum(dim=reduce_dims)
    per_class_dice = 2. * overlap / (total + epsilon)

    return 1 - per_class_dice.mean()
27 |
28 |
def dice_loss_PL(prob, target, epsilon=1e-7):
    """Soft Sørensen–Dice loss for a target that already has C channels.

    Args:
        prob: a tensor of shape [B, C, D, H, W] with the probability output
            of the model.
        target: a tensor of the same shape [B, C, D, H, W].
        epsilon: added to the denominator for numerical stability.
    Returns:
        One minus the Dice coefficient averaged over the C classes.
    """
    assert prob.size() == target.size(), "the size of predict and target must be equal."

    reduce_dims = [0, 2, 3, 4]  # everything except the class axis
    overlap = (prob * target).sum(dim=reduce_dims)
    total = (prob + target).sum(dim=reduce_dims)
    per_class_dice = 2. * overlap / (total + epsilon)

    return 1 - per_class_dice.mean()
46 |
47 |
def dice_score_metric(input, target, epsilon=1e-7):
    """Dice score between the arg-max prediction and ``target``.

    Args:
        input: per-class scores; the arg-max is taken over dim 1 and that
            dim is kept with size 1.
        target: label tensor multiplied element-wise with the arg-max map.
        epsilon: added to the denominator for numerical stability.
    Returns:
        A scalar tensor ``2 * sum(pred * target) / (sum(pred) + sum(target))``.

    NOTE(review): class indices are multiplied and summed directly, so this
    is a conventional Dice score only for binary (0/1) labels -- confirm
    against the callers.
    """
    pred = input.argmax(dim=1, keepdim=True)
    overlap = (pred * target).sum()
    total = pred.sum() + target.sum()
    return 2. * overlap / (total + epsilon)
54 |
--------------------------------------------------------------------------------
/bowel_coarseseg/tools/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import torch.nn.functional as F
4 | import torch.nn as nn
5 |
6 |
def dice_loss_PL(prob, target, epsilon=1e-7):
    """Soft Sørensen–Dice loss for a target that already has C channels.

    Args:
        prob: a tensor of shape [B, C, D, H, W] with the probability output
            of the model.
        target: a tensor of the same shape [B, C, D, H, W].
        epsilon: added to the denominator for numerical stability.
    Returns:
        One minus the Dice coefficient averaged over the C classes.
    """
    assert prob.size() == target.size(), "the size of predict and target must be equal."

    reduce_dims = [0, 2, 3, 4]  # everything except the class axis
    overlap = (prob * target).sum(dim=reduce_dims)
    total = (prob + target).sum(dim=reduce_dims)
    per_class_dice = 2. * overlap / (total + epsilon)

    return 1 - per_class_dice.mean()
24 |
25 |
def dice_similarity(output, target, smooth=1e-7):
    """Return the Dice similarity averaged over the batch.

    Both tensors are cast to float and flattened to (batch, -1); one Dice
    value is computed per sample and the mean of those values is returned.
    ``smooth`` is added to the denominator for numerical stability.
    """
    batch_pred = output.float()
    batch_true = target.float()

    flat_pred = batch_pred.view(batch_pred.size(0), -1)  # (batch, D*H*W)
    flat_true = batch_true.view(batch_true.size(0), -1)

    overlap = (flat_pred * flat_true).sum(-1)
    total = (flat_pred + flat_true).sum(-1)
    per_sample = (2. * overlap) / (total + smooth)

    return per_sample.mean()
40 |
41 |
def dice_score_partial(output, target, data_type):
    """Compute per-class Dice scores for the classes labelled in ``data_type``.

    Args:
        output: per-class scores; the arg-max is taken over dim 1.
        target: label map compared class by class with the arg-max map.
        data_type: one of 'fully_labeled', 'smallbowel', 'colon_sigmoid';
            selects which foreground class ids are evaluated.
    Returns:
        A list of Python floats, one Dice score per evaluated class, in the
        order the class ids are listed below.
    """
    assert data_type in ['fully_labeled', 'smallbowel', 'colon_sigmoid']

    classes_by_type = {
        'fully_labeled': [1, 2, 3, 4, 5],
        'smallbowel': [4],
        'colon_sigmoid': [2, 3],
    }

    pred = torch.argmax(output, dim=1, keepdim=True)
    return [
        dice_similarity((pred == c).long(), (target == c).long()).item()
        for c in classes_by_type[data_type]
    ]
62 |
63 |
64 |
--------------------------------------------------------------------------------
/bowel_fineseg/create_skeleflux.py:
--------------------------------------------------------------------------------
1 | from matplotlib import pylab as plt
2 | import nibabel as nib
3 | from glob import glob
4 | from cc3d import connected_components
5 | import numpy as np
6 | import os, cv2
7 | import scipy
8 | from scipy import ndimage
9 | from skimage import morphology
10 | from scipy.ndimage import gaussian_filter
11 | from shutil import copy
12 | import scipy.ndimage as ndi
13 |
14 |
15 |
def mkdir(folder_path):
    """Create ``folder_path`` (and any missing parents) if it does not exist.

    Args:
        folder_path: path of the directory to create.

    Raises:
        FileExistsError: if ``folder_path`` exists but is not a directory.
    """
    # exist_ok=True replaces the check-then-create pair, which could fail if
    # the directory appeared between the existence test and makedirs().
    os.makedirs(folder_path, exist_ok=True)
19 |
20 |
21 |
files = glob('/mnt/c/chong/data/Bowel/crop_stage1_ROI_small/*/*/*/masks_crop.nii.gz')

# Newer scikit-image releases no longer provide skeletonize_3d; skeletonize
# handles 3D input there, so fall back to it when the old name is missing.
_skeletonize_3d = getattr(morphology, 'skeletonize_3d', morphology.skeletonize)

# For each cropped mask, build a skeleton, the distance of every mask voxel to
# that skeleton, and the unit direction ("flux") from the nearest skeleton
# voxel; all outputs are saved next to the mask.
for iii, file in enumerate(files):

    mask_info = nib.load(file)
    mask_arr_ori = np.round(mask_info.get_fdata()).astype(np.uint8)

    skeleton = mask_arr_ori.copy()

    # Smooth the mask before skeletonising (larger sigma, more blur).
    # np.float64 replaces the removed np.float alias (it was plain `float`).
    skeleton = ndi.gaussian_filter(skeleton.astype(np.float64), sigma=(1.5, 1.5, 1.5), order=0, truncate=2.0)
    skeleton = (skeleton > 0.5).astype(np.uint8)

    skeleton = _skeletonize_3d(skeleton).astype(np.uint8)

    # Save a dilated copy of the skeleton.
    save_path = file.replace('.nii.gz', '_skele.nii.gz')
    skeleton_dilate = ndimage.binary_dilation(skeleton, structure=np.ones((3, 3, 5)))
    nib.save(nib.Nifti1Image(skeleton_dilate, header=mask_info.header, affine=mask_info.affine), save_path)

    # Zero marks skeleton voxels, so the EDT gives, for every voxel, the
    # distance to (and index of) the nearest skeleton voxel.
    pseudo_one_mask = np.ones_like(mask_arr_ori).astype(np.uint8)
    pseudo_one_mask[skeleton == 1] = 0
    skeleton_dist, skeleton_index = ndi.distance_transform_edt(pseudo_one_mask, return_indices=True)

    # Vector from the nearest skeleton voxel to each voxel, and its length.
    grid = np.indices(skeleton_dist.shape).astype(float)
    diff = grid - skeleton_index
    dist = np.sqrt(np.sum(diff ** 2, axis=0))

    # Normalise to unit length; the small constant avoids division by zero.
    direction_0 = np.divide(diff[0, ...], dist + 1e-7)
    direction_1 = np.divide(diff[1, ...], dist + 1e-7)
    direction_2 = np.divide(diff[2, ...], dist + 1e-7)

    # Keep values only where the original mask equals 1.
    outside = mask_arr_ori != 1
    skeleton_dist[outside] = 0
    direction_0[outside] = 0
    direction_1[outside] = 0
    direction_2[outside] = 0

    save_path = file.replace('.nii.gz', '_skele_dist.nii.gz')
    nib.save(nib.Nifti1Image(skeleton_dist, header=mask_info.header, affine=mask_info.affine), save_path)

    save_path = file.replace('.nii.gz', '_skele_fluxx.nii.gz')
    nib.save(nib.Nifti1Image(direction_0, header=mask_info.header, affine=mask_info.affine), save_path)

    save_path = file.replace('.nii.gz', '_skele_fluxy.nii.gz')
    nib.save(nib.Nifti1Image(direction_1, header=mask_info.header, affine=mask_info.affine), save_path)

    save_path = file.replace('.nii.gz', '_skele_fluxz.nii.gz')
    nib.save(nib.Nifti1Image(direction_2, header=mask_info.header, affine=mask_info.affine), save_path)

    # Magnitude of the flux field.
    np_save = np.sqrt(direction_0 ** 2 + direction_1 ** 2 + direction_2 ** 2)
    nib.save(nib.Nifti1Image(np_save, header=mask_info.header, affine=mask_info.affine), file.replace('.nii.gz', '_skele_fluxmag.nii.gz'))

    # NOTE(review): this break stops after the first file; it looks like a
    # debugging leftover -- confirm before running on the full dataset.
    break
84 |
85 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # BowelNet
2 |
3 |
4 | ### Pytorch code for the paper "BowelNet: Joint Semantic-Geometric Ensemble Learning for Bowel Segmentation from Both Partially and Fully Labeled CT Images" at IEEE TMI 2022.
5 | Email: chong.wang@adelaide.edu.au
6 |
7 |
8 |
9 |
10 |
11 | ## Introduction:
12 |
13 | The BowelNet is a two-stage coarse-to-fine framework for the segmentation of the entire bowel (i.e., duodenum, jejunum-ileum, colon, sigmoid, and rectum) from abdominal CT images. The first stage jointly localizes all types of the bowel, trained robustly on both partially and fully labeled samples (see examples below). The second stage finely segments each type of the localized bowel using geometric bowel representations and hybrid pseudo labels:
14 |
15 | [(1) Joint localization of the five bowel parts using both partially- and fully-labeled images](https://github.com/runningcw/BowelNet/tree/master/bowel_coarseseg)
16 |
17 | [(2) Fine segmentation of each part using geometric (i.e., boundary and skeleton) guidance](https://github.com/runningcw/BowelNet/tree/master/bowel_fineseg)
18 |
19 |
20 | Examples of fully (a) and partially (b, c) labeled training data used in our work:
21 |
22 |
23 |

24 |
25 |
26 |
27 | ## Dataset:
28 |
29 | We utilize a large private abdominal CT dataset that includes both partially and fully labeled segmentation annotations. The dataset is structured as follows:
30 |
31 | ```
32 | BowelSegData
33 | ├── Fully_labeled_5C
34 | │ ├── abdomen
35 | │ │ ├── .nii.gz
36 | │ │ ...
37 | │ ├── male
38 | │ │ ├── .nii.gz
39 | │ │ ...
40 | │ └── female
41 | │ ├── .nii.gz
42 | │ ...
43 | ├── Colon_Sigmoid
44 | │ ├── abdomen
45 | │ │ ├── .nii.gz
46 | │ │ ...
47 | │ ├── male
48 | │ │ ├── .nii.gz
49 | │ │ ...
50 | │ └── female
51 | │ ├── .nii.gz
52 | │ ...
53 | └── Smallbowel
54 | ├── abdomen
55 | │ ├── .nii.gz
56 | │ ...
57 | ├── male
58 | │ ├── .nii.gz
59 | │ ...
60 | └── female
61 | ├── .nii.gz
62 | ...
63 | ```
64 |
65 | ## Data Preprocessing:
66 |
67 | [Preprocessing](https://github.com/runningcw/BowelNet/blob/master/bowel_coarseseg/preprocessing.py) includes cropping the abdominal body region. We average all 2D CT slices of a volume to form a mean image and then apply thresholding to it to obtain the abdominal body region (excluding the CT bed).
68 |
69 |
70 | ## Demo segmentation:
71 |
72 | Our BowelNet demonstrates improved performance over prior approaches in the entire bowel segmentation.
73 |
74 | 
75 |
76 |
77 | ## Citation:
78 | ```
79 | @article{wang2022bowelnet,
80 | title={BowelNet: Joint Semantic-Geometric Ensemble Learning for Bowel Segmentation From Both Partially and Fully Labeled CT Images},
81 | author={Wang, Chong and Cui, Zhiming and Yang, Junwei and Han, Miaofei and Carneiro, Gustavo and Shen, Dinggang},
82 | journal={IEEE Transactions on Medical Imaging},
83 | volume={42},
84 | number={4},
85 | pages={1225--1236},
86 | year={2022},
87 | publisher={IEEE}
88 | }
89 | ```
90 |
--------------------------------------------------------------------------------
/bowel_coarseseg/crop_ROI.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import numpy as np
3 | from glob import glob
4 | import os.path
5 | import nibabel as nib
6 | import scipy
7 | from cc3d import connected_components
8 |
9 |
# Maps each bowel segment name to the integer label it has in the mask volumes.
name_to_label = {'rectum': 1, 'sigmoid': 2, 'colon': 3, 'small': 4, 'duodenum': 5}
11 |
12 |
def img2box(img):
    """Return the inclusive bounding box of the non-zero voxels of ``img``.

    Args:
        img: 3D array; any non-zero voxel counts as foreground.
    Returns:
        Tuple ``(rmin, rmax, cmin, cmax, zmin, zmax)`` with the first and last
        occupied index along axes 0, 1 and 2.
    Raises:
        IndexError: if ``img`` contains no non-zero voxel.
    """
    bounds = []
    # For each axis, collapse the other two and look for occupied positions.
    for collapsed_axes in ((1, 2), (0, 2), (0, 1)):
        occupied = np.where(np.any(img, axis=collapsed_axes))[0]
        bounds.append(occupied[0])
        bounds.append(occupied[-1])
    return tuple(bounds)
21 |
22 |
def normalize_volume(img):
    """Min-max normalise ``img`` to the range [0, 1].

    Args:
        img: NumPy array of intensities.
    Returns:
        Array of the same shape with ``(img - min) / (max - min)``. A constant
        volume (max == min) yields all zeros; previously it divided by zero
        and produced NaN.
    """
    lowest = img.min()
    span = img.max() - lowest
    if span == 0:
        # Constant volume: every shifted value is already 0, so divide by 1.
        span = 1
    img_array = (img - lowest) / span
    return img_array
26 |
27 |
def crop_ROI_using_coarse_seg_mask(mask_ori, bowel_name):
    """Compute a padded crop box around one bowel segment of a coarse mask.

    Args:
        mask_ori: 3D integer label volume; it is copied, not modified.
        bowel_name: key of ``name_to_label`` selecting the segment.
    Returns:
        ``((min_x, max_x), (min_y, max_y), (min_z, max_z), mask)``, where the
        box encloses the largest 26-connected component of the selected label,
        extended by 30 voxels per side and clipped to the volume, and ``mask``
        is the full-size binary mask of the selected label (all components).
        If the label is absent, ``((None, None), (None, None), (None, None),
        None)`` is returned.
    """
    assert bowel_name in name_to_label.keys()

    pos_label = name_to_label[bowel_name]
    mask = mask_ori.copy()
    mask[mask != pos_label] = 0
    mask[mask == pos_label] = 1

    # use the largest connected region to eliminate small noisy points
    labels_out, N = connected_components(mask, connectivity=26, return_N=True)
    if N == 0:
        print('coarse stage does not detect the organ, will skip this case')
        return (None, None), (None, None), (None, None), None

    # Count the voxels of every component in a single pass instead of
    # scanning the whole volume once per component. The cast keeps bincount
    # happy whatever unsigned dtype the labelling returns.
    component_sizes = np.bincount(labels_out.ravel().astype(np.int64), minlength=N + 1)
    largest_id = int(np.argmax(component_sizes[1:])) + 1  # skip background 0

    max_connected_image = np.int8(labels_out == largest_id)
    min_x, max_x, min_y, max_y, min_z, max_z = img2box(max_connected_image)

    # extend the box by a fixed margin, clipped to the volume
    x_extend, y_extend, z_extend = (30, 30, 30)

    max_x = min(max_x + x_extend, mask.shape[0])
    max_y = min(max_y + y_extend, mask.shape[1])
    max_z = min(max_z + z_extend, mask.shape[2])
    min_x = max(min_x - x_extend, 0)
    min_y = max(min_y - y_extend, 0)
    min_z = max(min_z - z_extend, 0)

    return (min_x, max_x), (min_y, max_y), (min_z, max_z), mask
61 |
62 |
if __name__ == '__main__':
    # Import the submodule explicitly: a bare `import scipy` does not
    # guarantee that `scipy.ndimage` is loaded, and the
    # `scipy.ndimage.interpolation` namespace used before is deprecated.
    from scipy import ndimage

    imgs = glob('/mnt/c/chong/data/Bowel/crop_ori_res/*/*.nii.gz')

    # Segment to crop; alternatives: 'sigmoid', 'colon', 'small', 'duodenum'.
    bowel_name = 'rectum'

    for iii, img_file in enumerate(imgs):

        img_info = nib.load(img_file)
        img = img_info.get_fdata()
        img = np.clip(img, -1024, 1000)
        img = normalize_volume(img)

        coarse_mask_file = img_file.replace('image.nii.gz', 'masks_partial_C5.nii.gz')
        assert os.path.isfile(coarse_mask_file)
        coarse_mask = nib.load(coarse_mask_file).get_fdata()
        # Resize the coarse mask to the image grid; order=0 is nearest
        # neighbour, so label values are preserved.
        coarse_mask = ndimage.zoom(coarse_mask.astype(np.float32),
                                   zoom=np.array(img.shape) / np.array(coarse_mask.shape),
                                   mode='nearest',
                                   order=0)
        coarse_mask = np.round(coarse_mask).astype(np.int32)

        crop_x, crop_y, crop_z, mask_return = crop_ROI_using_coarse_seg_mask(coarse_mask, bowel_name)

        # The segment was not found in the coarse mask: skip this case.
        if mask_return is None:
            continue

        img = img[crop_x[0]:crop_x[1], crop_y[0]:crop_y[1], crop_z[0]:crop_z[1]]
        mask_return = mask_return[crop_x[0]:crop_x[1], crop_y[0]:crop_y[1], crop_z[0]:crop_z[1]]

        save_dir = os.path.dirname(img_file.replace('crop_ori_res', 'crop_ori_res_' + bowel_name))
        os.makedirs(save_dir, exist_ok=True)

        # Record the crop box as "x0,x1,y0,y1,z0,z1"; the with-block closes
        # the file, so no explicit close() is needed.
        crop_coords_patient = (crop_x[0], crop_x[1], crop_y[0], crop_y[1], crop_z[0], crop_z[1])
        with open(save_dir + '/crop_info_' + bowel_name + '.txt', 'w') as f:
            f.write(','.join(str(v) for v in crop_coords_patient))

        nib.save(nib.Nifti1Image(img, header=img_info.header, affine=img_info.affine), os.path.join(save_dir, 'image_crop.nii.gz'))
        nib.save(nib.Nifti1Image(mask_return, header=img_info.header, affine=img_info.affine), os.path.join(save_dir, 'masks_partial_C5_crop.nii.gz'))
109 |
110 |
--------------------------------------------------------------------------------
/bowel_coarseseg/loc_model.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
def passthrough(x, **_ignored):
    """Identity function: return ``x`` unchanged, ignoring keyword arguments.

    Used as a stand-in where a dropout layer would otherwise be applied.
    """
    return x
9 |
10 |
def ELUCons(elu, nchan):
    """Build the activation layer used throughout the network.

    Args:
        elu: if truthy, return an in-place ``nn.ELU``; otherwise a PReLU.
        nchan: number of PReLU parameters (unused for ELU).
    """
    return nn.ELU(inplace=True) if elu else nn.PReLU(nchan)
16 |
# normalization between sub-volumes is necessary for good performance
class ContBatchNorm3d(nn.modules.batchnorm._BatchNorm):
    """Batch norm for 5D input that always runs in training mode.

    ``forward`` passes ``training=True`` to ``F.batch_norm`` regardless of
    ``self.training``, so batch statistics are used in eval mode as well.
    """

    def _check_input_dim(self, input):
        """Raise ValueError unless ``input`` has exactly five dimensions."""
        ndim = input.dim()
        if ndim == 5:
            return
        raise ValueError('expected 5D input (got {}D input)'.format(ndim))

    def forward(self, input):
        """Validate the input rank, then normalise with batch statistics."""
        self._check_input_dim(input)
        return F.batch_norm(
            input,
            self.running_mean,
            self.running_var,
            weight=self.weight,
            bias=self.bias,
            training=True,
            momentum=self.momentum,
            eps=self.eps,
        )
30 |
31 |
class LUConv(nn.Module):
    """Channel-preserving 5x5x5 convolution, then batch norm and activation."""

    def __init__(self, nchan, elu):
        """Create the layers for ``nchan`` channels; ``elu`` picks ELU or PReLU."""
        super(LUConv, self).__init__()
        self.relu1 = ELUCons(elu, nchan)
        self.conv1 = nn.Conv3d(nchan, nchan, kernel_size=5, padding=2)
        self.bn1 = ContBatchNorm3d(nchan)

    def forward(self, x):
        """Apply conv -> batch norm -> activation to ``x``."""
        features = self.conv1(x)
        features = self.bn1(features)
        return self.relu1(features)
42 |
43 |
def _make_nConv(nchan, depth, elu):
    """Stack ``depth`` LUConv blocks of ``nchan`` channels in a Sequential."""
    return nn.Sequential(*[LUConv(nchan, elu) for _ in range(depth)])
49 |
50 |
class InputTransition(nn.Module):
    """Input block: 5x5x5 convolution with a residual link to the raw input.

    The input is tiled along the channel axis so it can be added to the
    convolution output. The tiling factor is derived from the channel counts
    (it used to be hard-coded to 12, which was only correct for 1 -> 12).
    """

    def __init__(self, inChans, outChans, elu):
        """Create the layers.

        Args:
            inChans: number of input channels.
            outChans: number of output channels; must be a multiple of
                ``inChans`` for the residual addition in ``forward`` to work.
            elu: selects ELU (truthy) or PReLU activation.
        """
        super(InputTransition, self).__init__()
        self.conv1 = nn.Conv3d(inChans, outChans, kernel_size=5, padding=2)
        self.bn1 = ContBatchNorm3d(outChans)
        self.relu1 = ELUCons(elu, outChans)

    def forward(self, x):
        """Return ``activation(bn(conv(x)) + tiled x)`` for a 5D input ``x``."""
        out = self.bn1(self.conv1(x))
        # Replicate the input along the channel axis to match the conv output.
        repeats = out.shape[1] // x.shape[1]
        x_tiled = x.repeat(1, repeats, 1, 1, 1)
        out = self.relu1(torch.add(out, x_tiled))
        return out
63 |
64 |
class DownTransition(nn.Module):
    """Encoder stage: stride-2 convolution that doubles the channels, followed
    by ``nConvs`` LUConv blocks with a residual connection."""

    def __init__(self, inChans, nConvs, elu, dropout=False):
        """Create the layers.

        Args:
            inChans: input channels; the stage outputs ``2 * inChans``.
            nConvs: number of LUConv blocks after the down-convolution.
            elu: selects ELU (truthy) or PReLU activations.
            dropout: if true, apply Dropout3d(p=0.2) before the LUConv blocks.
        """
        super(DownTransition, self).__init__()
        doubled = 2 * inChans
        self.down_conv = nn.Conv3d(inChans, doubled, kernel_size=2, stride=2)
        self.bn1 = ContBatchNorm3d(doubled)
        self.relu1 = ELUCons(elu, doubled)
        self.relu2 = ELUCons(elu, doubled)
        self.do1 = nn.Dropout3d(p=0.2) if dropout else passthrough
        self.ops = _make_nConv(doubled, nConvs, elu)

    def forward(self, x):
        """Downsample ``x``, refine it and add the downsampled tensor back."""
        down = self.relu1(self.bn1(self.down_conv(x)))
        refined = self.ops(self.do1(down))
        return self.relu2(torch.add(refined, down))
84 |
85 |
class UpTransition(nn.Module):
    """Decoder stage: transposed convolution to ``outChans // 2`` channels,
    concatenation with the skip tensor, then ``nConvs`` LUConv blocks with a
    residual connection."""

    def __init__(self, inChans, outChans, nConvs, elu, dropout=False):
        """Create the layers.

        Args:
            inChans: channels of the tensor coming from the deeper stage.
            outChans: channels after concatenation with the skip tensor.
            nConvs: number of LUConv blocks applied to the concatenation.
            elu: selects ELU (truthy) or PReLU activations.
            dropout: if true, apply Dropout3d(p=0.2) to the deeper input;
                the skip tensor always goes through Dropout3d(p=0.2).
        """
        super(UpTransition, self).__init__()
        half = outChans // 2
        self.up_conv = nn.ConvTranspose3d(inChans, half, kernel_size=2, stride=2)
        self.bn1 = ContBatchNorm3d(half)
        self.do2 = nn.Dropout3d(p=0.2)
        self.relu1 = ELUCons(elu, half)
        self.relu2 = ELUCons(elu, outChans)
        self.do1 = nn.Dropout3d(p=0.2) if dropout else passthrough
        self.ops = _make_nConv(outChans, nConvs, elu)

    def forward(self, x, skipx):
        """Upsample ``x``, fuse it with ``skipx`` and refine the result."""
        deep = self.do1(x)
        skip = self.do2(skipx)
        upsampled = self.relu1(self.bn1(self.up_conv(deep)))
        fused = torch.cat((upsampled, skip), 1)
        refined = self.ops(fused)
        return self.relu2(torch.add(refined, fused))
107 |
108 |
class OutputTransition(nn.Module):
    """Output head: 5x5x5 convolution, batch norm, activation and a final
    1x1x1 convolution producing ``outChans`` channels (no softmax here)."""

    def __init__(self, inChans, outChans, elu):
        """Create the layers mapping ``inChans`` to ``outChans`` channels."""
        super(OutputTransition, self).__init__()
        self.conv1 = nn.Conv3d(inChans, outChans, kernel_size=5, padding=2)
        self.bn1 = ContBatchNorm3d(outChans)
        self.relu1 = ELUCons(elu, outChans)
        self.conv2 = nn.Conv3d(outChans, outChans, kernel_size=1)

    def forward(self, x):
        """Return the un-normalised per-class output for ``x``."""
        hidden = self.conv1(x)
        hidden = self.bn1(hidden)
        hidden = self.relu1(hidden)
        return self.conv2(hidden)
121 |
122 |
class BowelLocNet(nn.Module):
    """Encoder-decoder network with four down and four up stages and skip
    connections; takes a 1-channel volume and returns a softmax over 6
    output channels."""

    def __init__(self, elu=True):
        """Build the network; ``elu`` selects ELU (True) or PReLU activations."""
        super(BowelLocNet, self).__init__()
        num_baseC = 12  # 16
        self.in_tr = InputTransition(1, 1*num_baseC, elu)
        self.down_tr32 = DownTransition(1*num_baseC, 1, elu)
        self.down_tr64 = DownTransition(2*num_baseC, 2, elu)
        self.down_tr128 = DownTransition(4*num_baseC, 2, elu, dropout=True)
        self.down_tr256 = DownTransition(8*num_baseC, 2, elu, dropout=True)
        self.up_tr256 = UpTransition(16*num_baseC, 16*num_baseC, 2, elu, dropout=True)
        self.up_tr128 = UpTransition(16*num_baseC, 8*num_baseC, 2, elu, dropout=True)
        self.up_tr64 = UpTransition(8*num_baseC, 4*num_baseC, 1, elu)
        self.up_tr32 = UpTransition(4*num_baseC, 2*num_baseC, 1, elu)
        self.out_tr = OutputTransition(2*num_baseC, 6, elu)

    def forward(self, x):
        """Run the encoder and decoder and return per-class probabilities
        (softmax over dim 1)."""
        # Encoder: keep every stage's output for the skip connections.
        enc0 = self.in_tr(x)
        enc1 = self.down_tr32(enc0)
        enc2 = self.down_tr64(enc1)
        enc3 = self.down_tr128(enc2)
        bottleneck = self.down_tr256(enc3)

        # Decoder: fuse with encoder features from deepest to shallowest.
        dec = self.up_tr256(bottleneck, enc3)
        dec = self.up_tr128(dec, enc2)
        dec = self.up_tr64(dec, enc1)
        dec = self.up_tr32(dec, enc0)
        logits = self.out_tr(dec)
        return F.softmax(logits, dim=1)
--------------------------------------------------------------------------------
/bowel_coarseseg/preprocessing.py:
--------------------------------------------------------------------------------
1 | import matplotlib
2 | from matplotlib import pylab as plt
3 | import nibabel as nib
4 | import SimpleITK as sitk
5 | from glob import glob
6 | from cc3d import connected_components
7 | import numpy as np
8 | import os, cv2
9 | from scipy import ndimage
10 | from skimage import morphology
11 | from skimage.morphology import convex_hull_image, disk, binary_closing
12 |
13 |
def mkdir(folder_path):
    """Create ``folder_path`` (and any missing parents) if it does not exist.

    Args:
        folder_path: path of the directory to create.

    Raises:
        FileExistsError: if ``folder_path`` exists but is not a directory.
    """
    # exist_ok=True replaces the check-then-create pair, which could fail if
    # the directory appeared between the existence test and makedirs().
    os.makedirs(folder_path, exist_ok=True)
17 |
18 |
def resample_img(image_file, out_spacing=(0.97, 0.97, 1.5), is_label=False):
    """Read an image with SimpleITK and resample it to ``out_spacing``.

    Args:
        image_file: path of the image to read.
        out_spacing: target spacing, one value per image axis. The default is
            now an immutable tuple instead of a shared mutable list.
        is_label: if true use nearest-neighbour interpolation (keeps label
            values intact), otherwise B-spline interpolation.
    Returns:
        The resampled volume as a NumPy array with its axes reversed via
        ``transpose(2, 1, 0)`` relative to ``sitk.GetArrayFromImage``.
    """
    itk_image = sitk.ReadImage(image_file)

    original_spacing = itk_image.GetSpacing()
    original_size = itk_image.GetSize()

    out_spacing = [float(s) for s in out_spacing]

    # Keep the physical extent: scale each axis by old spacing / new spacing.
    out_size = [
        int(np.round(size * (spacing / new_spacing)))
        for size, spacing, new_spacing in zip(original_size, original_spacing, out_spacing)
    ]

    resample = sitk.ResampleImageFilter()
    resample.SetOutputSpacing(out_spacing)
    resample.SetSize(out_size)
    resample.SetOutputDirection(itk_image.GetDirection())
    resample.SetOutputOrigin(itk_image.GetOrigin())
    resample.SetTransform(sitk.Transform())
    # NOTE(review): GetPixelIDValue() looks like the pixel *type* id rather
    # than an intensity, so using it as the fill value for voxels outside the
    # input is suspicious -- confirm the intended padding value.
    resample.SetDefaultPixelValue(itk_image.GetPixelIDValue())

    if is_label:
        resample.SetInterpolator(sitk.sitkNearestNeighbor)
    else:
        resample.SetInterpolator(sitk.sitkBSpline)

    itk_new = resample.Execute(itk_image)

    return sitk.GetArrayFromImage(itk_new).transpose(2, 1, 0)
47 |
48 |
def select_largest_region(img_bin):
    """Keep only the largest connected component of ``img_bin``.

    Args:
        img_bin: array whose non-zero entries are foreground.
    Returns:
        A uint8 array of the same shape with 255 on the largest component
        (26-connectivity is requested from the labelling) and 0 elsewhere.
        If there is no foreground, the result is all zeros.
    """
    # N is the number of connected components
    labels_out, N = connected_components(img_bin, connectivity=26, return_N=True)

    largest_mask = np.zeros(img_bin.shape, dtype=np.uint8)
    if N == 0:
        # No component at all (previously this raised IndexError).
        return largest_mask

    # Count the voxels of each component. The earlier code summed the label
    # values themselves, i.e. segid * voxel_count, which biased the choice
    # towards components with a high id.
    sizes = np.bincount(labels_out.ravel().astype(np.int64), minlength=N + 1)
    sizes[0] = 0  # label 0 is background, never a candidate
    top1 = int(np.argmax(sizes))

    largest_mask[labels_out == top1] = 255

    return largest_mask
66 |
67 |
def crop_body(image):
    """Locate the body region in a 2D image by Otsu thresholding.

    The image is rescaled to 0..255, blurred, thresholded with Otsu's method,
    and only the largest connected region is kept (to remove the CT bed).

    Returns:
        ``(x_min, x_max, y_min, y_max, mask_body)``: the inclusive index range
        of the kept region along axis 0 and axis 1, and the region mask.
    """
    lowest, highest = image.min(), image.max()
    scaled = (((image - lowest) / (highest - lowest)) * 255).astype("uint8")

    blurred = cv2.GaussianBlur(scaled, (5, 5), 0)
    _, mask = cv2.threshold(blurred, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    mask_body = select_largest_region(mask)  # remove CT bed

    rows, cols = np.where(mask_body != 0)
    return rows.min(), rows.max(), cols.min(), cols.max(), mask_body
88 |
89 |
files = glob('/mnt/c/chong/data/Bowel/alldata/small_bowel/*/*/image.nii.gz')  # resampled data

# For every image/label pair: crop to the body region found on the mean
# slice, save the cropped volumes, then save a stride-2 downsampled copy.
for iii, image_file in enumerate(files):

    label_file = image_file.replace('image.nii.gz', 'masks.nii.gz')

    image_info = nib.load(image_file)
    image_arr = image_info.get_fdata()

    label_info = nib.load(label_file)
    label_arr = np.round(label_info.get_fdata()).astype(int)

    x, y, z = image_arr.shape

    # np.clip returns a new array, so image_arr itself is left untouched.
    image = np.clip(image_arr, -1024, 1000)  # remove abnormal intensity

    # Average all slices into one 2D image and find the body on it.
    image_mean = np.mean(image, axis=2)
    body_x_min, body_x_max, body_y_min, body_y_max, mask_body = crop_body(image_mean)

    # Crop box with a 20-voxel margin on the low-y side, clipped to the volume.
    xs = max(body_x_min - 0, 0)
    xe = min(body_x_max + 0, x)
    ys = max(body_y_min - 20, 0)
    ye = min(body_y_max + 0, y)

    image_body = image_arr[xs:xe, ys:ye, :]
    masks_body = label_arr[xs:xe, ys:ye, :]

    print(iii, 'crop_shape', image_body.shape, 'ori_shape', image_arr.shape, image_file)

    # save preprocessed data
    crop_image_file = image_file.replace('/alldata', '/crop_preprocessed')
    crop_label_file = label_file.replace('/alldata', '/crop_preprocessed')
    # exist_ok avoids the race of a separate existence check.
    os.makedirs(os.path.dirname(crop_image_file), exist_ok=True)

    # Body mask preview, written one directory above the case folder.
    mask_save_path = os.path.dirname(os.path.dirname(image_file)).replace('/alldata', '/crop_preprocessed')
    cv2.imwrite(mask_save_path + '/' + image_file.split('/')[-2] + '.jpg', mask_body.transpose())

    nib.save(nib.Nifti1Image(image_body, header=image_info.header, affine=image_info.affine), crop_image_file)
    nib.save(nib.Nifti1Image(masks_body, header=label_info.header, affine=label_info.affine), crop_label_file)

    # save downsampled data (every second voxel along each axis)
    # NOTE(review): the original header/affine are reused unchanged for the
    # cropped and downsampled volumes -- confirm downstream code does not
    # rely on the stored voxel spacing or origin.
    image_downsampled = image_body[::2, ::2, ::2]
    masks_downsampled = masks_body[::2, ::2, ::2]

    down_image_file = image_file.replace('/alldata', '/crop_downsample')
    down_label_file = label_file.replace('/alldata', '/crop_downsample')
    os.makedirs(os.path.dirname(down_image_file), exist_ok=True)

    nib.save(nib.Nifti1Image(image_downsampled, header=image_info.header, affine=image_info.affine), down_image_file)
    nib.save(nib.Nifti1Image(masks_downsampled, header=label_info.header, affine=label_info.affine), down_label_file)

    # break
151 |
--------------------------------------------------------------------------------
/bowel_coarseseg/infer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | from __future__ import division
3 | import argparse
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from torch.autograd import Variable
8 | from scipy.ndimage.filters import gaussian_filter
9 | from datasets.BowelDatasetCoarseSeg import BowelCoarseSeg
10 | from tools.loss import *
11 | import os, math
12 | import numpy as np
13 | import nibabel as nib
14 | import loc_model
15 |
16 |
# Restrict the process to a single visible GPU; the trailing comment shows the
# form used to expose several devices instead.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # "1, 2"
print("GPU ID:", os.environ['CUDA_VISIBLE_DEVICES'])


# Bowel segment name -> integer class id (presumably the values stored in the
# label volumes -- confirm against the dataset).
name_to_label = {'rectum': 1, 'sigmoid': 2, 'colon': 3, 'small': 4, 'duodenum': 5}
22 |
23 |
def main():
    """Evaluate the trained bowel localisation network with sliding-window inference.

    Loads the hard-coded checkpoint, builds the three test sets, runs
    Gaussian-weighted patch inference over every volume of the selected test
    set and prints the per-case and the mean Dice score of the classes listed
    in ``pos_cls``.
    """

    def _str2bool(text):
        """Parse a command-line boolean (``type=bool`` maps any non-empty string to True)."""
        lowered = text.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0', ''):
            return False
        raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(text))

    parser = argparse.ArgumentParser()
    parser.add_argument('--ngpu', type=int, default=1)
    parser.add_argument('--batch_size', type=int, default=8)  # 8 (2 GPU)
    parser.add_argument('--start-epoch', default=0, type=int, help='manual epoch number (useful on restarts)')
    parser.add_argument('--nEpochs', type=int, default=1500, help='total training epoch')

    parser.add_argument('--lr', default=1e-3, type=float, help='learning rate')
    parser.add_argument('--weight_decay', default=1e-8, type=float, metavar='W', help='weight decay (default: 1e-8)')
    parser.add_argument('--eval_interval', default=5, type=int, help='evaluate interval on validation set')
    parser.add_argument('--save_dir', type=str, default=None)
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--deterministic', type=_str2bool, default=True)
    # The options are not consumed below; parsing is kept so that the command
    # line is still validated and ``--help`` keeps working.
    parser.parse_args()

    print("build Bowel Localisation Network")
    model = loc_model.BowelLocNet(elu=False)

    model_path = "./exp/BowelLocNet.20230208_1536/partial_5C_dict_1300.pth"
    print('load checkpoint:', model_path)
    model.load_state_dict(torch.load(model_path))

    model = model.cuda()
    model = nn.parallel.DataParallel(model)
    model.eval()

    print("loading fully_labeled dataset")
    fully_labeled_dir = '/mnt/c/chong/data/Bowel/crop_downsample/Fully_labeled_5C'
    testSet_fully_labeled = BowelCoarseSeg(fully_labeled_dir, mode="test", transform=False, dataset_name="fully_labeled", save_dir=None)

    print("loading small dataset")
    smallbowel_dir = '/mnt/c/chong/data/Bowel/crop_downsample/Smallbowel'
    testSet_small = BowelCoarseSeg(smallbowel_dir, mode="test", transform=False, dataset_name="smallbowel", save_dir=None)

    print("loading colon_sigmoid dataset")
    colon_sigmoid_dir = '/mnt/c/chong/data/Bowel/crop_downsample/Colon_Sigmoid/colon_sigmoid'
    testSet_colon_sigmoid = BowelCoarseSeg(colon_sigmoid_dir, mode="test", transform=False, dataset_name="colon_sigmoid", save_dir=None)

    patch_size = (64, 128, 128)  # (z, x, y)
    overlap = (32, 64, 64)
    stride = tuple(p - o for p, o in zip(patch_size, overlap))

    # The importance map depends only on the patch size, so build it once
    # instead of once per volume.
    gaussian_map = torch.from_numpy(_get_gaussian(patch_size, sigma_scale=1. / 4)).cuda()

    n_test_samples = 0
    dice_score = 0.0
    with torch.no_grad():
        # for data_name in testSet_fully_labeled.imgs:
        for data_name in testSet_small.imgs:
        # for data_name in testSet_colon_sigmoid.imgs:

            n_test_samples = n_test_samples + 1

            img_info = nib.load(data_name)
            img = img_info.get_fdata()
            img = np.clip(img, -1024, 1000)
            img = normalize_volume(img)

            label = nib.load(data_name.replace('image.nii.gz', 'masks.nii.gz')).get_fdata()
            label = np.round(label)

            # Pad every axis to at least the patch size so that the clamped
            # window start below can never become negative.
            img_arr, label_arr, zs_pad, ze_pad = padding_z(img, label, min_z=max(100, patch_size[0]))
            img_arr, label_arr, xs_pad, xe_pad = padding_x(img_arr, label_arr, min_x=max(160, patch_size[1]))
            img_arr, label_arr, ys_pad, ye_pad = padding_y(img_arr, label_arr, min_y=max(160, patch_size[2]))

            # (x, y, z) -> (z, x, y)
            img_arr = img_arr.transpose((2, 0, 1))
            label_arr = label_arr.transpose((2, 0, 1))

            data = torch.from_numpy(img_arr[np.newaxis, np.newaxis, :].astype(np.float32)).cuda()
            target = torch.from_numpy(label_arr[np.newaxis, np.newaxis, :].astype(np.int32)).cuda()

            b, _, z, x, y = data.shape

            zs = list(range(0, z, stride[0]))
            xs = list(range(0, x, stride[1]))
            ys = list(range(0, y, stride[2]))

            output_all = torch.zeros((b, 6, z, x, y)).cuda()
            for zzz in zs:
                # Clamp the start so the last window ends exactly at the border.
                # NOTE: clamped starts can coincide, so a border window may be
                # evaluated more than once; this matches the original weighting.
                z0 = min(zzz, z - patch_size[0])
                for xxx in xs:
                    x0 = min(xxx, x - patch_size[1])
                    for yyy in ys:
                        y0 = min(yyy, y - patch_size[2])
                        candidate_patch = data[:, :, z0:z0 + patch_size[0], x0:x0 + patch_size[1], y0:y0 + patch_size[2]]
                        output = model(candidate_patch)  # b, 6, z, x, y  prob output
                        output_all[:, :, z0:z0 + patch_size[0], x0:x0 + patch_size[1], y0:y0 + patch_size[2]] += output * gaussian_map
            # remove padded slice
            output_all = output_all[:, :, zs_pad:ze_pad, xs_pad:xe_pad, ys_pad:ye_pad]
            target = target[:, :, zs_pad:ze_pad, xs_pad:xe_pad, ys_pad:ye_pad]
            result = torch.argmax(output_all, dim=1, keepdim=True)

            # pos_cls = [1, 2, 3, 4, 5]  # fully labeled dataset
            pos_cls = [4]  # small bowel dataset
            # pos_cls = [2, 3]  # colon_sigmoid dataset

            dice_c = np.array([
                dice_similarity((result == cls).contiguous(), (target == cls).contiguous()).item()
                for cls in pos_cls
            ])
            dice_score += dice_c

            print('Name: {}, Dice: {}'.format(data_name, dice_c))

            # np_save = result.squeeze().permute(1, 2, 0).cpu().numpy().astype(np.int32)
            # nib.save(nib.Nifti1Image(np_save, header=img_info.header, affine=img_info.affine), data_name.replace('image.nii.gz', 'masks_partial_C5.nii.gz'))

    if n_test_samples == 0:
        # Nothing was evaluated; avoid dividing 0.0 by 0.
        print('\nNo test volumes found, nothing to evaluate.\n')
        return

    dice_score /= n_test_samples

    print('\nMean test: Dice: {}\n'.format(dice_score))
144 |
145 |
def _get_gaussian(patch_size, sigma_scale=1. / 8) -> np.ndarray:
    """Build a Gaussian importance map used to weight overlapping patches.

    A unit impulse placed at the patch centre is smoothed with a per-axis
    sigma of ``size * sigma_scale``, scaled so its maximum is 1 and returned
    as ``float32``. Entries that underflow to exactly 0 are replaced with the
    smallest non-zero weight so no voxel ends up with zero weight.
    """
    impulse = np.zeros(patch_size)
    impulse[tuple(dim // 2 for dim in patch_size)] = 1
    per_axis_sigma = [dim * sigma_scale for dim in patch_size]

    weights = gaussian_filter(impulse, per_axis_sigma, 0, mode='constant', cval=0)
    weights = (weights / np.max(weights)).astype(np.float32)

    # weights must not be 0, otherwise we may end up with nans downstream
    is_zero = weights == 0
    weights[is_zero] = np.min(weights[~is_zero])
    return weights
160 |
161 |
def normalize_volume(img):
    """Min-max scale ``img`` to the range [0, 1].

    Args:
        img: numpy array holding the volume intensities.

    Returns:
        Array of the same shape with the minimum mapped to 0 and the maximum
        mapped to 1. A constant volume has no intensity range and is mapped to
        all zeros instead of the NaNs a 0/0 division would produce.
    """
    lowest = img.min()
    span = img.max() - lowest
    if span == 0:
        # Constant volume: avoid 0/0; every voxel becomes 0.
        span = 1
    return (img - lowest) / span
165 |
166 |
def padding_z(img, label, min_z):
    """Zero-pad ``img`` and ``label`` along axis 2 up to ``min_z`` slices.

    Returns ``(img, label, start, end)`` where ``[start:end]`` is the range of
    the original data along axis 2. Inputs that are already large enough are
    returned untouched with ``start=0`` and ``end=z``. When the padding amount
    is odd, the extra slice goes in front.
    """
    x, y, z = img.shape
    missing = min_z - z
    if missing <= 0:
        return img, label, 0, z

    n_back = missing // 2
    n_front = missing - n_back
    front = np.zeros((x, y, n_front), dtype=np.float64)
    back = np.zeros((x, y, n_back), dtype=np.float64)
    padded_img = np.concatenate((front, img, back), axis=2)
    padded_label = np.concatenate((front, label, back), axis=2)
    return padded_img, padded_label, front.shape[2], padded_img.shape[2] - back.shape[2]
179 |
180 |
def padding_x(img, label, min_x):
    """Zero-pad ``img`` and ``label`` along axis 0 up to ``min_x`` rows.

    Returns ``(img, label, start, end)`` where ``[start:end]`` is the range of
    the original data along axis 0. Inputs that are already large enough are
    returned untouched with ``start=0`` and ``end=x``. When the padding amount
    is odd, the extra row goes in front.
    """
    x, y, z = img.shape
    missing = min_x - x
    if missing <= 0:
        return img, label, 0, x

    n_back = missing // 2
    n_front = missing - n_back
    front = np.zeros((n_front, y, z), dtype=np.float64)
    back = np.zeros((n_back, y, z), dtype=np.float64)
    padded_img = np.concatenate((front, img, back), axis=0)
    padded_label = np.concatenate((front, label, back), axis=0)
    return padded_img, padded_label, front.shape[0], padded_img.shape[0] - back.shape[0]
193 |
194 |
def padding_y(img, label, min_y):
    """Zero-pad ``img`` and ``label`` along axis 1 up to ``min_y`` columns.

    Returns ``(img, label, start, end)`` where ``[start:end]`` is the range of
    the original data along axis 1. Inputs that are already large enough are
    returned untouched with ``start=0`` and ``end=y``. When the padding amount
    is odd, the extra column goes in front.
    """
    x, y, z = img.shape
    missing = min_y - y
    if missing <= 0:
        return img, label, 0, y

    n_back = missing // 2
    n_front = missing - n_back
    front = np.zeros((x, n_front, z), dtype=np.float64)
    back = np.zeros((x, n_back, z), dtype=np.float64)
    padded_img = np.concatenate((front, img, back), axis=1)
    padded_label = np.concatenate((front, label, back), axis=1)
    return padded_img, padded_label, front.shape[1], padded_img.shape[1] - back.shape[1]
207 |
208 |
# Run the evaluation only when the file is executed as a script.
if __name__ == '__main__':
    main()
211 |
--------------------------------------------------------------------------------
/bowel_fineseg/infer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | from __future__ import division
3 | import argparse
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from torch.autograd import Variable
8 | import numpy as np
9 | from torch.utils.data import DataLoader
10 | from scipy.ndimage.filters import gaussian_filter
11 | from tools.loss import *
12 | from tools import utils
13 | import os, math
14 | import nibabel as nib
15 | import seg_model
16 | from datasets.BowelDatasetFineSeg import BowelFineSeg
17 |
18 |
# Restrict the process to a single visible GPU; the trailing comment shows the
# form used to expose several devices instead.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'  # "1, 2"
print("GPU ID:", os.environ['CUDA_VISIBLE_DEVICES'])


# Bowel segment name -> integer class id (used by the commented-out saving
# code in main() to relabel the binary prediction).
name_to_label = {'rectum': 1, 'sigmoid': 2, 'colon': 3, 'small': 4, 'duodenum': 5}
24 |
25 |
def main():
    """Evaluate the trained BowelNet (meta, boundary and skeleton outputs) on one bowel segment.

    Loads the hard-coded checkpoint and test set, runs Gaussian-weighted
    sliding-window inference over every test volume and prints the per-case
    and the mean Dice / ASSD of the three mask outputs.

    Raises:
        ValueError: if ``bowel_name`` has no patch configuration.
    """

    def _str2bool(text):
        """Parse a command-line boolean (``type=bool`` maps any non-empty string to True)."""
        lowered = text.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0', ''):
            return False
        raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(text))

    def _to_xyz_numpy(volume):
        """Drop singleton dims and reorder a (z, x, y) tensor to an (x, y, z) numpy array."""
        return volume.squeeze().permute(1, 2, 0).cpu().numpy()

    parser = argparse.ArgumentParser()
    parser.add_argument('--ngpu', type=int, default=1)
    parser.add_argument('--batch_size', type=int, default=4)  # 4 (single GPU)
    parser.add_argument('--start-epoch', default=0, type=int, help='manual epoch number (useful on restarts)')
    parser.add_argument('--nEpochs_base', type=int, default=501, help='total epoch number for base segmentor')
    parser.add_argument('--nEpochs_meta', type=int, default=201, help='total epoch number for meta segmentor')

    parser.add_argument('--lr', default=1e-3, type=float, help='learning rate')
    parser.add_argument('--weight-decay', '--wd', default=1e-8, type=float, metavar='W', help='weight decay')
    parser.add_argument('--eval_interval', default=1, type=int, help='evaluate interval on validation set')
    parser.add_argument('--temp', default=0.7, type=float, help='temperature for meta segmentor')
    parser.add_argument('--save_dir', type=str, default=None)
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--deterministic', type=_str2bool, default=True)
    # The options are not consumed below; parsing is kept so that the command
    # line is still validated and ``--help`` keeps working.
    parser.parse_args()

    print("build BowelNet")
    model = seg_model.BowelNet(elu=False)

    model_path = "./BowelNet.20230207_1744/rectum_meta_195.pth"
    print('load checkpoint:', model_path)
    model.load_state_dict(torch.load(model_path))

    model = model.cuda()
    model = nn.parallel.DataParallel(model)
    model.eval()

    # data_dir = '/mnt/c/chong/data/Bowel/crop_stage1_ROI_small/'
    # bowel_name = 'small'

    # data_dir = '/mnt/c/chong/data/Bowel/crop_stage1_ROI_colon/'
    # bowel_name = 'colon'

    # data_dir = '/mnt/c/chong/data/Bowel/crop_stage1_ROI_sigmoid/'
    # bowel_name = 'sigmoid'

    # data_dir = '/mnt/c/chong/data/Bowel/crop_stage1_ROI_duodenum/'
    # bowel_name = 'duodenum'

    data_dir = '/mnt/c/chong/data/Bowel/crop_stage1_ROI_rectum/'
    bowel_name = 'rectum'

    print("loading test set")
    testSet = BowelFineSeg(root=data_dir, mode="test", transform=False, bowel_name=bowel_name, save_dir=None)

    # bowel segment -> (patch size, overlap), both ordered (z, x, y)
    patch_config = {
        'small': ((64, 192, 192), (32, 96, 96)),
        'colon': ((64, 192, 192), (32, 96, 96)),
        'sigmoid': ((64, 160, 160), (32, 80, 80)),
        'duodenum': ((64, 160, 160), (32, 80, 80)),
        'rectum': ((64, 96, 96), (32, 48, 48)),
    }
    if bowel_name not in patch_config:
        # Fail early with a clear message instead of a NameError on first use.
        raise ValueError('no patch configuration for bowel_name {!r}'.format(bowel_name))
    patch_size, overlap = patch_config[bowel_name]
    stride = tuple(p - o for p, o in zip(patch_size, overlap))

    # The importance map depends only on the patch size, so build it once
    # instead of once per volume.
    gaussian_map = torch.from_numpy(_get_gaussian(patch_size, sigma_scale=1. / 4)).cuda()

    n_test_samples = 0
    dice_score_meta = 0.0
    dice_score_edge = 0.0
    dice_score_skele = 0.0
    assd_score_meta = 0.0
    assd_score_edge = 0.0
    assd_score_skele = 0.0
    with torch.no_grad():
        for data_name in testSet.imgs:

            n_test_samples = n_test_samples + 1

            img_info = nib.load(data_name)
            img = img_info.get_fdata()
            # img = np.clip(img, -1024, 1000)
            # img = normalize_volume(img)

            label = nib.load(data_name.replace('image_crop.nii.gz', 'masks_crop.nii.gz')).get_fdata()
            label = np.round(label)

            # Pad every axis to at least the patch size: a smaller volume would
            # give a negative clamped window start and a malformed patch.
            img_arr, label_arr, zs_pad, ze_pad = padding_z(img, label, min_z=max(80, patch_size[0]))
            img_arr, label_arr, xs_pad, xe_pad = padding_x(img_arr, label_arr, min_x=max(120, patch_size[1]))
            img_arr, label_arr, ys_pad, ye_pad = padding_y(img_arr, label_arr, min_y=max(120, patch_size[2]))

            # (x, y, z) -> (z, x, y)
            img_arr = img_arr.transpose((2, 0, 1))
            label_arr = label_arr.transpose((2, 0, 1))

            data = torch.from_numpy(img_arr[np.newaxis, np.newaxis, :].astype(np.float32)).cuda()
            target = torch.from_numpy(label_arr[np.newaxis, np.newaxis, :].astype(np.int32)).cuda()

            b, _, z, x, y = data.shape

            zs = list(range(0, z, stride[0]))
            xs = list(range(0, x, stride[1]))
            ys = list(range(0, y, stride[2]))

            out_meta_all = torch.zeros((b, 2, z, x, y)).cuda()
            out_edge_all = torch.zeros((b, 2, z, x, y)).cuda()
            out_skele_all = torch.zeros((b, 2, z, x, y)).cuda()
            for zzz in zs:
                # Clamp the start so the last window ends exactly at the border.
                # NOTE: clamped starts can coincide, so a border window may be
                # evaluated more than once; this matches the original weighting.
                z0 = min(zzz, z - patch_size[0])
                for xxx in xs:
                    x0 = min(xxx, x - patch_size[1])
                    for yyy in ys:
                        y0 = min(yyy, y - patch_size[2])
                        candidate_patch = data[:, :, z0:z0 + patch_size[0], x0:x0 + patch_size[1], y0:y0 + patch_size[2]]
                        out_meta, out_edge, out_skele = model(candidate_patch, 'meta')  # b, 2, z, x, y
                        prob_meta = F.softmax(out_meta, dim=1)
                        prob_edge = F.softmax(out_edge, dim=1)
                        prob_skele = F.softmax(out_skele, dim=1)
                        out_meta_all[:, :, z0:z0 + patch_size[0], x0:x0 + patch_size[1], y0:y0 + patch_size[2]] += prob_meta * gaussian_map
                        out_edge_all[:, :, z0:z0 + patch_size[0], x0:x0 + patch_size[1], y0:y0 + patch_size[2]] += prob_edge * gaussian_map
                        out_skele_all[:, :, z0:z0 + patch_size[0], x0:x0 + patch_size[1], y0:y0 + patch_size[2]] += prob_skele * gaussian_map
            # remove padded slice
            out_meta_all = out_meta_all[:, :, zs_pad:ze_pad, xs_pad:xe_pad, ys_pad:ye_pad]
            out_edge_all = out_edge_all[:, :, zs_pad:ze_pad, xs_pad:xe_pad, ys_pad:ye_pad]
            out_skele_all = out_skele_all[:, :, zs_pad:ze_pad, xs_pad:xe_pad, ys_pad:ye_pad]
            target = target[:, :, zs_pad:ze_pad, xs_pad:xe_pad, ys_pad:ye_pad]

            dice_meta = dice_score_metric(out_meta_all, target)
            dice_edge = dice_score_metric(out_edge_all, target)
            dice_skele = dice_score_metric(out_skele_all, target)
            dice_score_meta += dice_meta
            dice_score_edge += dice_edge
            dice_score_skele += dice_skele

            # The reference mask is the same for all three outputs: convert once.
            target_np = _to_xyz_numpy(target)  # x, y, z
            assd_meta = utils.assd(_to_xyz_numpy(torch.argmax(out_meta_all, dim=1)), target_np)
            assd_edge = utils.assd(_to_xyz_numpy(torch.argmax(out_edge_all, dim=1)), target_np)
            assd_skele = utils.assd(_to_xyz_numpy(torch.argmax(out_skele_all, dim=1)), target_np)
            assd_score_meta += assd_meta
            assd_score_edge += assd_edge
            assd_score_skele += assd_skele

            print('Name: {}, Dice_meta:{:.4f}, Dice_edge:{:.4f}, Dice_skele:{:.4f}, ASSD_meta:{:.4f}, ASSD_edge:{:.4f}, ASSD_skele:{:.4f}'.
                  format(data_name, dice_meta, dice_edge, dice_skele, assd_meta, assd_edge, assd_skele))

            # np_save = torch.argmax(out_meta_all, dim=1).squeeze().permute(1, 2, 0).cpu().numpy().astype(np.int32)
            # np_save[np_save == 1] = name_to_label[bowel_name]
            # nib.save(nib.Nifti1Image(np_save, header=img_info.header, affine=img_info.affine), data_name.replace('image_crop.nii.gz', 'masks_crop_meta.nii.gz'))
            #
            # np_save = torch.argmax(out_edge_all, dim=1).squeeze().permute(1, 2, 0).cpu().numpy().astype(np.int32)
            # np_save[np_save == 1] = name_to_label[bowel_name]
            # nib.save(nib.Nifti1Image(np_save, header=img_info.header, affine=img_info.affine), data_name.replace('image_crop.nii.gz', 'masks_crop_edge.nii.gz'))
            #
            # np_save = torch.argmax(out_skele_all, dim=1).squeeze().permute(1, 2, 0).cpu().numpy().astype(np.int32)
            # np_save[np_save == 1] = name_to_label[bowel_name]
            # nib.save(nib.Nifti1Image(np_save, header=img_info.header, affine=img_info.affine), data_name.replace('image_crop.nii.gz', 'masks_crop_skele.nii.gz'))

            # break

    if n_test_samples == 0:
        # Nothing was evaluated; avoid dividing 0.0 by 0.
        print('\nNo test volumes found, nothing to evaluate.\n')
        return

    dice_score_meta /= n_test_samples
    dice_score_edge /= n_test_samples
    dice_score_skele /= n_test_samples
    assd_score_meta /= n_test_samples
    assd_score_edge /= n_test_samples
    assd_score_skele /= n_test_samples

    print('\nMean test: Dice_meta:{:.4f}, Dice_edge:{:.4f}, Dice_skele:{:.4f}, '
          'ASSD_meta: {:.4f}, ASSD_edge: {:.4f}, ASSD_skele: {:.4f}\n'.
          format(dice_score_meta, dice_score_edge, dice_score_skele,
                 assd_score_meta, assd_score_edge, assd_score_skele))
197 |
198 |
def _get_gaussian(patch_size, sigma_scale=1. / 8) -> np.ndarray:
    """Build a Gaussian importance map used to weight overlapping patches.

    A unit impulse placed at the patch centre is smoothed with a per-axis
    sigma of ``size * sigma_scale``, scaled so its maximum is 1 and returned
    as ``float32``. Entries that underflow to exactly 0 are replaced with the
    smallest non-zero weight so no voxel ends up with zero weight.
    """
    impulse = np.zeros(patch_size)
    impulse[tuple(dim // 2 for dim in patch_size)] = 1
    per_axis_sigma = [dim * sigma_scale for dim in patch_size]

    weights = gaussian_filter(impulse, per_axis_sigma, 0, mode='constant', cval=0)
    weights = (weights / np.max(weights)).astype(np.float32)

    # weights must not be 0, otherwise we may end up with nans downstream
    is_zero = weights == 0
    weights[is_zero] = np.min(weights[~is_zero])
    return weights
213 |
214 |
def normalize_volume(img):
    """Min-max scale ``img`` to the range [0, 1].

    Args:
        img: numpy array holding the volume intensities.

    Returns:
        Array of the same shape with the minimum mapped to 0 and the maximum
        mapped to 1. A constant volume has no intensity range and is mapped to
        all zeros instead of the NaNs a 0/0 division would produce.
    """
    lowest = img.min()
    span = img.max() - lowest
    if span == 0:
        # Constant volume: avoid 0/0; every voxel becomes 0.
        span = 1
    return (img - lowest) / span
218 |
219 |
def padding_z(img, label, min_z):
    """Zero-pad ``img`` and ``label`` along axis 2 up to ``min_z`` slices.

    Returns ``(img, label, start, end)`` where ``[start:end]`` is the range of
    the original data along axis 2. Inputs that are already large enough are
    returned untouched with ``start=0`` and ``end=z``. When the padding amount
    is odd, the extra slice goes in front.
    """
    x, y, z = img.shape
    missing = min_z - z
    if missing <= 0:
        return img, label, 0, z

    n_back = missing // 2
    n_front = missing - n_back
    front = np.zeros((x, y, n_front), dtype=np.float64)
    back = np.zeros((x, y, n_back), dtype=np.float64)
    padded_img = np.concatenate((front, img, back), axis=2)
    padded_label = np.concatenate((front, label, back), axis=2)
    return padded_img, padded_label, front.shape[2], padded_img.shape[2] - back.shape[2]
232 |
233 |
def padding_x(img, label, min_x):
    """Zero-pad ``img`` and ``label`` along axis 0 up to ``min_x`` rows.

    Returns ``(img, label, start, end)`` where ``[start:end]`` is the range of
    the original data along axis 0. Inputs that are already large enough are
    returned untouched with ``start=0`` and ``end=x``. When the padding amount
    is odd, the extra row goes in front.
    """
    x, y, z = img.shape
    missing = min_x - x
    if missing <= 0:
        return img, label, 0, x

    n_back = missing // 2
    n_front = missing - n_back
    front = np.zeros((n_front, y, z), dtype=np.float64)
    back = np.zeros((n_back, y, z), dtype=np.float64)
    padded_img = np.concatenate((front, img, back), axis=0)
    padded_label = np.concatenate((front, label, back), axis=0)
    return padded_img, padded_label, front.shape[0], padded_img.shape[0] - back.shape[0]
246 |
247 |
def padding_y(img, label, min_y):
    """Zero-pad ``img`` and ``label`` along axis 1 up to ``min_y`` columns.

    Returns ``(img, label, start, end)`` where ``[start:end]`` is the range of
    the original data along axis 1. Inputs that are already large enough are
    returned untouched with ``start=0`` and ``end=y``. When the padding amount
    is odd, the extra column goes in front.
    """
    x, y, z = img.shape
    missing = min_y - y
    if missing <= 0:
        return img, label, 0, y

    n_back = missing // 2
    n_front = missing - n_back
    front = np.zeros((x, n_front, z), dtype=np.float64)
    back = np.zeros((x, n_back, z), dtype=np.float64)
    padded_img = np.concatenate((front, img, back), axis=1)
    padded_label = np.concatenate((front, label, back), axis=1)
    return padded_img, padded_label, front.shape[1], padded_img.shape[1] - back.shape[1]
260 |
261 |
# Run the evaluation only when the file is executed as a script.
if __name__ == '__main__':
    main()
264 |
--------------------------------------------------------------------------------
/bowel_fineseg/seg_model.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
def passthrough(x, **kwargs):
    """Identity stand-in for an optional layer: return ``x`` and ignore any keyword arguments."""
    del kwargs  # accepted for call compatibility only
    return x
9 |
10 |
def ELUCons(elu, nchan):
    """Return the activation module: in-place ELU when ``elu`` is truthy, else a PReLU with ``nchan`` parameters."""
    return nn.ELU(inplace=True) if elu else nn.PReLU(nchan)
16 |
17 |
# normalization between sub-volumes is necessary
# for good performance
class ContBatchNorm3d(nn.modules.batchnorm._BatchNorm):
    """3D batch norm that always normalises with the statistics of the current batch.

    ``forward`` calls ``F.batch_norm`` with ``training=True`` regardless of the
    module's train/eval mode.
    """

    def _check_input_dim(self, input):
        ndim = input.dim()
        if ndim != 5:
            raise ValueError('expected 5D input (got {}D input)'.format(ndim))
        # super(ContBatchNorm3d, self)._check_input_dim(input)

    def forward(self, input):
        self._check_input_dim(input)
        return F.batch_norm(
            input,
            self.running_mean,
            self.running_var,
            weight=self.weight,
            bias=self.bias,
            training=True,  # batch statistics even in eval mode
            momentum=self.momentum,
            eps=self.eps,
        )
32 |
33 |
class LUConv(nn.Module):
    """5x5x5 convolution (same padding) followed by batch norm and activation; channel count is kept."""

    def __init__(self, nchan, elu):
        super().__init__()
        self.relu1 = ELUCons(elu, nchan)
        self.conv1 = nn.Conv3d(nchan, nchan, kernel_size=5, padding=2)
        self.bn1 = ContBatchNorm3d(nchan)

    def forward(self, x):
        features = self.conv1(x)
        features = self.bn1(features)
        return self.relu1(features)
44 |
45 |
def _make_nConv(nchan, depth, elu):
    """Stack ``depth`` LUConv layers of width ``nchan`` into one ``nn.Sequential``."""
    return nn.Sequential(*[LUConv(nchan, elu) for _ in range(depth)])
51 |
52 |
class InputTransition(nn.Module):
    """Stem block: 5x5x5 convolution from ``inChans`` to ``outChans``, then batch norm and activation."""

    def __init__(self, inChans, outChans, elu):
        super().__init__()
        self.conv1 = nn.Conv3d(inChans, outChans, kernel_size=5, padding=2)
        self.bn1 = ContBatchNorm3d(outChans)
        self.relu1 = ELUCons(elu, outChans)

    def forward(self, x):
        features = self.conv1(x)
        features = self.bn1(features)
        return self.relu1(features)
63 |
64 |
class DownTransition(nn.Module):
    """Encoder stage: strided 2x2x2 convolution that doubles the channels, then a residual stack of LUConv layers.

    With ``dropout=True`` a ``Dropout3d(p=0.2)`` is applied before the
    convolution stack; otherwise that step is the identity.
    """

    def __init__(self, inChans, nConvs, elu, dropout=False):
        super().__init__()
        doubled = inChans * 2
        self.down_conv = nn.Conv3d(inChans, doubled, kernel_size=2, stride=2)
        self.bn1 = ContBatchNorm3d(doubled)
        self.do1 = passthrough
        self.relu1 = ELUCons(elu, doubled)
        self.relu2 = ELUCons(elu, doubled)
        if dropout:
            self.do1 = nn.Dropout3d(p=0.2)  # 0.5
        self.ops = _make_nConv(doubled, nConvs, elu)

    def forward(self, x):
        down = self.down_conv(x)
        down = self.relu1(self.bn1(down))
        refined = self.ops(self.do1(down))
        # residual connection around the convolution stack
        return self.relu2(refined + down)
84 |
85 |
class UpTransition(nn.Module):
    """Decoder stage: transposed convolution to ``outChans // 2`` channels, concatenation with the skip
    features, then a residual stack of LUConv layers over the ``outChans``-wide result.

    The skip features always go through ``Dropout3d(p=0.2)``; the incoming
    features only do so when ``dropout=True``.
    """

    def __init__(self, inChans, outChans, nConvs, elu, dropout=False):
        super().__init__()
        half = outChans // 2
        self.up_conv = nn.ConvTranspose3d(inChans, half, kernel_size=2, stride=2)
        self.bn1 = ContBatchNorm3d(half)
        self.do1 = passthrough
        self.do2 = nn.Dropout3d(p=0.2)  # 0.5
        self.relu1 = ELUCons(elu, half)
        self.relu2 = ELUCons(elu, outChans)
        if dropout:
            self.do1 = nn.Dropout3d(p=0.2)  # 0.5
        self.ops = _make_nConv(outChans, nConvs, elu)

    def forward(self, x, skipx):
        deep = self.do1(x)
        skip = self.do2(skipx)
        upsampled = self.relu1(self.bn1(self.up_conv(deep)))
        merged = torch.cat((upsampled, skip), 1)
        refined = self.ops(merged)
        # residual connection around the convolution stack
        return self.relu2(refined + merged)
107 |
108 |
class OutputTransition(nn.Module):
    """Prediction head producing ``outChans`` output channels from ``inChans`` feature channels.

    Layout: 3x3x3 conv -> BN -> activation -> 1x1x1 conv -> dropout ->
    1x1x1 conv -> activation -> 3x3x3 conv -> BN -> activation -> 1x1x1 conv.
    All hidden layers are 16 channels wide.

    The first convolution now takes ``inChans`` input channels and ``conv3``
    takes the 16-channel hidden tensor it is actually fed. Previously
    ``conv1`` was fixed to 16 inputs and ``conv3`` to ``inChans`` inputs, so
    the head only worked for ``inChans == 16``; for that value every parameter
    shape is unchanged, so existing checkpoints still load.

    Args:
        inChans: number of channels of the incoming feature map.
        outChans: number of output channels.
        elu: passed to ``ELUCons`` to choose the activation.
    """

    def __init__(self, inChans, outChans, elu):
        super().__init__()
        self.conv1 = nn.Conv3d(inChans, 16, kernel_size=3, padding=1)
        self.bn1 = ContBatchNorm3d(16)
        self.relu1 = ELUCons(elu, 16)
        self.conv2 = nn.Conv3d(16, 16, kernel_size=1)

        self.conv3 = nn.Conv3d(16, 16, kernel_size=1)
        self.relu3 = ELUCons(elu, 16)

        self.conv5 = nn.Conv3d(16, 16, kernel_size=3, padding=1)
        self.bn5 = ContBatchNorm3d(16)
        self.relu5 = ELUCons(elu, 16)

        self.conv6 = nn.Conv3d(16, outChans, kernel_size=1)

        self.dropout = nn.Dropout3d(p=0.2)

    def forward(self, x):
        out_2 = self.conv2(self.relu1(self.bn1(self.conv1(x))))
        out_2 = self.dropout(out_2)

        out_3 = self.relu3(self.conv3(out_2))

        out_6 = self.conv6(self.relu5(self.bn5(self.conv5(out_3))))

        return out_6
138 |
139 |
140 |
class OutputTransition_Edge(nn.Module):
    """Boundary prediction head producing ``outChans`` output channels from ``inChans`` feature channels.

    Layout: 3x3x3 conv -> BN -> activation -> 1x1x1 conv -> dropout ->
    1x1x1 conv -> activation -> 3x3x3 conv -> BN -> activation -> 1x1x1 conv.
    All hidden layers are 16 channels wide.

    The first convolution now takes ``inChans`` input channels and ``conv3``
    takes the 16-channel hidden tensor it is actually fed. Previously
    ``conv1`` was fixed to 16 inputs and ``conv3`` to ``inChans`` inputs, so
    the head only worked for ``inChans == 16``; for that value every parameter
    shape is unchanged, so existing checkpoints still load.

    Args:
        inChans: number of channels of the incoming feature map.
        outChans: number of output channels.
        elu: passed to ``ELUCons`` to choose the activation.
    """

    def __init__(self, inChans, outChans, elu):
        super().__init__()
        self.conv1 = nn.Conv3d(inChans, 16, kernel_size=3, padding=1)
        self.bn1 = ContBatchNorm3d(16)
        self.relu1 = ELUCons(elu, 16)
        self.conv2 = nn.Conv3d(16, 16, kernel_size=1)

        self.conv3 = nn.Conv3d(16, 16, kernel_size=1)
        self.relu3 = ELUCons(elu, 16)

        self.conv5 = nn.Conv3d(16, 16, kernel_size=3, padding=1)
        self.bn5 = ContBatchNorm3d(16)
        self.relu5 = ELUCons(elu, 16)

        self.conv6 = nn.Conv3d(16, outChans, kernel_size=1)

        self.dropout = nn.Dropout3d(p=0.2)

    def forward(self, x):
        out_2 = self.conv2(self.relu1(self.bn1(self.conv1(x))))
        out_2 = self.dropout(out_2)  # added

        out_3 = self.relu3(self.conv3(out_2))

        out_6 = self.conv6(self.relu5(self.bn5(self.conv5(out_3))))

        return out_6
170 |
171 |
172 |
class OutputTransition_Skele(nn.Module):
    """Skeleton prediction head producing ``outChans`` output channels from ``inChans`` feature channels.

    Layout: 3x3x3 conv -> BN -> activation -> 1x1x1 conv -> dropout ->
    1x1x1 conv -> 3x3x3 conv -> BN -> 1x1x1 conv. Unlike the other heads there
    is no activation after ``conv3`` and ``bn5``. All hidden layers are 16
    channels wide.

    The first convolution now takes ``inChans`` input channels and ``conv3``
    takes the 16-channel hidden tensor it is actually fed. Previously
    ``conv1`` was fixed to 16 inputs and ``conv3`` to ``inChans`` inputs, so
    the head only worked for ``inChans == 16``; for that value every parameter
    shape is unchanged, so existing checkpoints still load.

    Args:
        inChans: number of channels of the incoming feature map.
        outChans: number of output channels.
        elu: passed to ``ELUCons`` to choose the activation.
    """

    def __init__(self, inChans, outChans, elu):
        super().__init__()
        self.conv1 = nn.Conv3d(inChans, 16, kernel_size=3, padding=1)
        self.bn1 = ContBatchNorm3d(16)
        self.relu1 = ELUCons(elu, 16)
        self.conv2 = nn.Conv3d(16, 16, kernel_size=1)

        self.conv3 = nn.Conv3d(16, 16, kernel_size=1)

        self.conv5 = nn.Conv3d(16, 16, kernel_size=3, padding=1)
        self.bn5 = ContBatchNorm3d(16)

        self.conv6 = nn.Conv3d(16, outChans, kernel_size=1)

        self.dropout = nn.Dropout3d(p=0.2)

    def forward(self, x):
        out_2 = self.conv2(self.relu1(self.bn1(self.conv1(x))))
        out_2 = self.dropout(out_2)  # added

        out_3 = self.conv3(out_2)

        out_6 = self.conv6(self.bn5(self.conv5(out_3)))

        return out_6
200 |
201 |
class MetaTransition_In(nn.Module):
    """Input encoder of the meta segmentor.

    Two 3x3x3 conv + BN + activation stages (``inChans`` -> ``outChans`` ->
    ``outChans // 2``), dropout, then a 1x1x1 convolution that keeps
    ``outChans // 2`` channels.
    """

    def __init__(self, inChans, outChans, elu):
        super().__init__()
        half = outChans // 2

        self.conv1 = nn.Conv3d(inChans, outChans, kernel_size=3, padding=1)
        self.bn1 = ContBatchNorm3d(outChans)
        self.relu1 = ELUCons(elu, outChans)

        self.conv2 = nn.Conv3d(outChans, half, kernel_size=3, padding=1)
        self.bn2 = ContBatchNorm3d(half)
        self.relu2 = ELUCons(elu, half)

        self.conv3 = nn.Conv3d(half, half, kernel_size=1)

        self.dropout = nn.Dropout3d(p=0.2)

    def forward(self, x):
        wide = self.relu1(self.bn1(self.conv1(x)))
        narrow = self.relu2(self.bn2(self.conv2(wide)))
        narrow = self.dropout(narrow)
        return self.conv3(narrow)
224 |
225 |
class MetaTransition_Fused(nn.Module):
    """Fusion head of the meta segmentor.

    3x3x3 conv + BN + activation at ``inChans`` channels, dropout, a second
    3x3x3 conv + BN + activation down to ``inChans // 2`` channels, then a
    1x1x1 convolution to ``outChans`` channels.
    """

    def __init__(self, inChans, outChans, elu):
        super().__init__()
        half = inChans // 2

        self.conv1 = nn.Conv3d(inChans, inChans, kernel_size=3, padding=1)
        self.bn1 = ContBatchNorm3d(inChans)
        self.relu1 = ELUCons(elu, inChans)

        self.conv2 = nn.Conv3d(inChans, half, kernel_size=3, padding=1)
        self.bn2 = ContBatchNorm3d(half)
        self.relu2 = ELUCons(elu, half)

        self.conv3 = nn.Conv3d(half, outChans, kernel_size=1)

        self.dropout = nn.Dropout3d(p=0.2)

    def forward(self, x):
        wide = self.relu1(self.bn1(self.conv1(x)))
        wide = self.dropout(wide)
        narrow = self.relu2(self.bn2(self.conv2(wide)))
        return self.conv3(narrow)
249 |
250 |
class BowelNet(nn.Module):
    """Two encoder-decoder branches (skeleton and boundary) on a shared stem, plus a meta segmentor.

    ``forward(x, 'base')`` returns the four outputs of the two branches.
    ``forward(x, 'meta')`` additionally fuses the averaged hard masks of the
    two branches with the input image and returns the fused prediction
    together with the two branch mask outputs.
    """

    def __init__(self, elu=True):
        super().__init__()

        base = 4  # 8

        self.in_tr = InputTransition(1, 2 * base, elu)

        # skeleton segmentor
        #########################################################################
        self.down_tr32_skele = DownTransition(2 * base, 1, elu)
        self.down_tr64_skele = DownTransition(4 * base, 2, elu)
        self.down_tr128_skele = DownTransition(8 * base, 2, elu, dropout=True)
        self.down_tr256_skele = DownTransition(16 * base, 2, elu, dropout=True)
        self.up_tr256_skele = UpTransition(32 * base, 32 * base, 2, elu, dropout=True)
        self.up_tr128_skele = UpTransition(32 * base, 16 * base, 2, elu, dropout=True)
        self.up_tr64_skele = UpTransition(16 * base, 8 * base, 1, elu)
        self.up_tr32_skele = UpTransition(8 * base, 4 * base, 1, elu)
        self.out_tr_skele = OutputTransition_Skele(4 * base, 3, elu)  # no relu for negative skeleton flux value
        self.out_tr_skele_mask = OutputTransition(4 * base, 2, elu)

        # boundary segmentor
        #########################################################################
        self.down_tr32_edge = DownTransition(2 * base, 1, elu)
        self.down_tr64_edge = DownTransition(4 * base, 2, elu)
        self.down_tr128_edge = DownTransition(8 * base, 2, elu, dropout=True)
        self.down_tr256_edge = DownTransition(16 * base, 2, elu, dropout=True)
        self.up_tr256_edge = UpTransition(32 * base, 32 * base, 2, elu, dropout=True)
        self.up_tr128_edge = UpTransition(32 * base, 16 * base, 2, elu, dropout=True)
        self.up_tr64_edge = UpTransition(16 * base, 8 * base, 1, elu)
        self.up_tr32_edge = UpTransition(8 * base, 4 * base, 1, elu)

        self.out_tr_edge = OutputTransition_Edge(4 * base, 1, elu)
        self.out_tr_edge_mask = OutputTransition(4 * base, 2, elu)

        # meta segmentor
        #########################################################################
        self.in_seg_meta = MetaTransition_In(1, 4 * base, elu)
        self.in_img_meta = MetaTransition_In(1, 4 * base, elu)
        self.fused_mask_meta = MetaTransition_Fused(4 * base, 2, elu)

    def _skeleton_features(self, stem):
        """Run the skeleton encoder-decoder on the stem features."""
        enc32 = self.down_tr32_skele(stem)
        enc64 = self.down_tr64_skele(enc32)
        enc128 = self.down_tr128_skele(enc64)
        enc256 = self.down_tr256_skele(enc128)
        dec = self.up_tr256_skele(enc256, enc128)
        dec = self.up_tr128_skele(dec, enc64)
        dec = self.up_tr64_skele(dec, enc32)
        return self.up_tr32_skele(dec, stem)  # same resolution as input

    def _boundary_features(self, stem):
        """Run the boundary encoder-decoder on the stem features."""
        enc32 = self.down_tr32_edge(stem)
        enc64 = self.down_tr64_edge(enc32)
        enc128 = self.down_tr128_edge(enc64)
        enc256 = self.down_tr256_edge(enc128)
        dec = self.up_tr256_edge(enc256, enc128)
        dec = self.up_tr128_edge(dec, enc64)
        dec = self.up_tr64_edge(dec, enc32)
        return self.up_tr32_edge(dec, stem)  # same resolution as input

    def forward(self, x, train_segmentor):

        assert train_segmentor in ['base', 'meta']

        out16 = self.in_tr(x)

        skele_feat = self._skeleton_features(out16)
        out_skele_seg = self.out_tr_skele(skele_feat)
        out_skele_mask_seg = self.out_tr_skele_mask(skele_feat)

        edge_feat = self._boundary_features(out16)
        out_edge_seg = self.out_tr_edge(edge_feat)
        out_edge_mask_seg = self.out_tr_edge_mask(edge_feat)

        if train_segmentor == 'base':
            return out_skele_seg, out_skele_mask_seg, out_edge_seg, out_edge_mask_seg

        # in_seg_avg = 1.0*torch.argmax(0.5 * (F.softmax(out_edge_mask_seg, dim=1) + F.softmax(out_skele_mask_seg, dim=1)), dim=1).unsqueeze(1)
        hard_edge = torch.argmax(out_edge_mask_seg, dim=1).unsqueeze(1)
        hard_skele = torch.argmax(out_skele_mask_seg, dim=1).unsqueeze(1)
        # average of the two hard masks
        in_seg_avg = 0.5 * (hard_edge + hard_skele)

        seg_feat = self.in_seg_meta(in_seg_avg)
        img_feat = self.in_img_meta(x)
        comb_feat = torch.cat((img_feat, seg_feat), dim=1)
        out = self.fused_mask_meta(comb_feat)
        return out, out_edge_mask_seg, out_skele_mask_seg
--------------------------------------------------------------------------------
/bowel_coarseseg/datasets/BowelDatasetCoarseSeg.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import numpy as np
3 | import torch
4 | import torch.utils.data as data
5 | from glob import glob
6 | import os
7 | import os.path
8 | import nibabel as nib
9 | import random
10 | import cv2
11 |
12 |
# NOTE: despite its name, this dict maps the integer label id -> organ name.
name_to_label = {1: 'rectum', 2: 'sigmoid', 3: 'colon', 4: 'small', 5: 'duodenum'}
14 |
15 |
def split_dataset(dir, current_test, test_fraction, dataset_name, save_dir):
    """Split the volumes found under ``dir`` into a train list and a test list.

    The sorted volume list is cut into folds of
    ``int(len(volumes) * test_fraction)`` items; fold number ``current_test``
    (1-based) becomes the test split and every other volume goes to training.

    Args:
        dir: root directory that is searched for ``image.nii.gz`` files.
        current_test: 1-based index of the fold used for testing.
        test_fraction: fraction of the volumes placed in one fold.
        dataset_name: 'fully_labeled', 'smallbowel' or 'colon_sigmoid'; any
            other value yields two empty lists.
        save_dir: directory receiving ``<dataset_name>_train_<k>.txt`` and
            ``<dataset_name>_test_<k>.txt``; nothing is written when None.

    Returns:
        (train_split, test_split): two lists of image paths.
    """
    # glob pattern, relative to ``dir``, for every supported dataset
    patterns = {
        'fully_labeled': "*/*/image.nii.gz",
        'smallbowel': "*/*/image.nii.gz",
        'colon_sigmoid': "*/image.nii.gz",
    }

    train_split, test_split = [], []

    pattern = patterns.get(dataset_name)
    if pattern is not None:
        volumes = sorted(glob(os.path.join(dir, pattern)))
        fold_size = int(len(volumes) * test_fraction)
        held_out = volumes[fold_size * (current_test - 1): fold_size * current_test]
        test_split.extend(held_out)
        train_split.extend(sorted(set(volumes).difference(held_out)))

    if save_dir is not None:
        for tag, paths in (('_train_', train_split), ('_test_', test_split)):
            list_file = os.path.join(save_dir, dataset_name + tag + str(current_test) + ".txt")
            with open(list_file, 'w') as f:
                for path in paths:
                    f.write(path + '\n')

    return train_split, test_split
57 |
58 |
def load_image_and_label(img_file, transform=False):
    """Load one image volume with its mask, optionally augment, and zero-pad.

    The image is clipped to [-1024, 1000] and rescaled with
    ``normalize_volume``; the mask is read from the sibling ``masks.nii.gz``
    and rounded. With ``transform`` set, one of 'ori' (no change), 'rotate'
    or 'crop' is picked at random. Both arrays are then padded to at least
    160 x 160 x 100 and returned with the last axis moved to the front.

    Args:
        img_file: path to an ``image.nii.gz`` file.
        transform: apply a random augmentation when truthy.

    Returns:
        (image, label) arrays, each transposed with ``(2, 0, 1)``.
    """
    volume = nib.load(img_file).get_fdata()
    # remove abnormal intensity, then rescale
    volume = normalize_volume(np.clip(volume, -1024, 1000))

    mask_file = img_file.replace('image.nii.gz', 'masks.nii.gz')
    mask = np.round(nib.load(mask_file).get_fdata())

    if transform:
        op = random.choice(['ori', 'rotate', 'crop'])
        if op == 'rotate':
            volume, mask = rotate(volume, mask, degree=random.uniform(-10, 10))
        elif op == 'crop':
            volume, mask = crop_resize(volume, mask, shift_size_x=10, shift_size_y=10)

    # zero padding
    volume, mask = padding_z(volume, mask, min_z=100)
    volume, mask = padding_x(volume, mask, min_x=160)
    volume, mask = padding_y(volume, mask, min_y=160)

    axes = (2, 0, 1)
    return volume.transpose(axes), mask.transpose(axes)
83 |
84 |
def padding_z(img, label, min_z):
    """Zero-pad ``img`` and ``label`` along the last axis up to ``min_z``.

    Arrays that are already deep enough are returned untouched. Otherwise
    the missing slices are split between both ends, the leading end getting
    the extra slice when the amount is odd.

    Returns:
        (img, label) after padding.
    """
    rows, cols, depth = img.shape
    if depth >= min_z:
        return img, label

    missing = min_z - depth
    trailing = missing // 2
    lead = np.zeros((rows, cols, missing - trailing), dtype=np.float64)
    tail = np.zeros((rows, cols, trailing), dtype=np.float64)

    padded_img = np.concatenate((lead, img, tail), axis=2)
    padded_label = np.concatenate((lead, label, tail), axis=2)
    return padded_img, padded_label
97 |
98 |
def padding_x(img, label, min_x):
    """Zero-pad ``img`` and ``label`` along the first axis up to ``min_x``.

    Arrays that are already large enough are returned untouched. Otherwise
    the missing rows are split between both ends, the leading end getting
    the extra row when the amount is odd.

    Returns:
        (img, label) after padding.
    """
    rows, cols, depth = img.shape
    if rows >= min_x:
        return img, label

    missing = min_x - rows
    trailing = missing // 2
    lead = np.zeros((missing - trailing, cols, depth), dtype=np.float64)
    tail = np.zeros((trailing, cols, depth), dtype=np.float64)

    padded_img = np.concatenate((lead, img, tail), axis=0)
    padded_label = np.concatenate((lead, label, tail), axis=0)
    return padded_img, padded_label
111 |
112 |
def padding_y(img, label, min_y):
    """Zero-pad ``img`` and ``label`` along the second axis up to ``min_y``.

    Arrays that are already large enough are returned untouched. Otherwise
    the missing columns are split between both ends, the leading end getting
    the extra column when the amount is odd.

    Returns:
        (img, label) after padding.
    """
    rows, cols, depth = img.shape
    if cols >= min_y:
        return img, label

    missing = min_y - cols
    trailing = missing // 2
    lead = np.zeros((rows, missing - trailing, depth), dtype=np.float64)
    tail = np.zeros((rows, trailing, depth), dtype=np.float64)

    padded_img = np.concatenate((lead, img, tail), axis=1)
    padded_label = np.concatenate((lead, label, tail), axis=1)
    return padded_img, padded_label
125 |
126 |
def rotate(img_ori, label_ori, degree):
    """Rotate every slice (last axis) of an image/label pair by ``degree``.

    The rotation is about the slice centre with scale 1. The image is
    resampled with nearest-neighbour interpolation. The label is rotated one
    class at a time as a binary mask and the masks are merged back, so only
    label values present in the slice can appear in the output.

    Returns:
        (rotated image, rotated label), same shapes as the inputs.
    """
    rows, cols, n_slices = img_ori.shape
    out_size = (cols, rows)
    affine = cv2.getRotationMatrix2D((cols / 2, rows / 2), degree, 1)

    rotated_img = np.zeros_like(img_ori)
    rotated_label = np.zeros_like(label_ori)
    for k in range(n_slices):
        rotated_img[:, :, k] = cv2.warpAffine(img_ori[:, :, k], affine, out_size, flags=cv2.INTER_NEAREST, borderValue=0)

        label_slice = label_ori[:, :, k]
        merged = np.zeros_like(label_slice, label_slice.dtype)
        for cls in np.unique(label_slice):
            class_mask = (label_slice == cls).astype(float)
            moved = cv2.warpAffine(class_mask, affine, out_size, flags=cv2.INTER_NEAREST)
            merged[moved > 0.5] = cls
        rotated_label[:, :, k] = merged
    return rotated_img, rotated_label
143 |
144 |
def crop_resize(img_ori, label_ori, shift_size_x, shift_size_y):
    """Randomly crop the in-plane borders and resize back to the original size.

    Up to ``shift_size_x`` / ``shift_size_y`` voxels are removed from each
    side of the first two axes; the same crop window is used for every slice.
    The image is resized with nearest-neighbour interpolation. The label is
    resized one class at a time as a binary mask and merged back.

    Returns:
        (image, label) with the same shapes as the inputs.
    """
    H, W, C = img_ori.shape
    # draw order matters for reproducibility: x start, x end, y start, y end
    row_lo = np.random.randint(0, shift_size_x)
    row_hi = np.random.randint(H - shift_size_x, H)
    col_lo = np.random.randint(0, shift_size_y)
    col_hi = np.random.randint(W - shift_size_y, W)

    out_size = (W, H)
    resized_img = np.zeros_like(img_ori)
    resized_label = np.zeros_like(label_ori)
    for k in range(C):
        img_crop = img_ori[row_lo:row_hi, col_lo:col_hi, k]
        resized_img[:, :, k] = cv2.resize(img_crop, out_size, interpolation=cv2.INTER_NEAREST)

        label_crop = label_ori[row_lo:row_hi, col_lo:col_hi, k]
        merged = np.zeros((H, W), label_ori.dtype)
        for cls in np.unique(label_crop):
            class_mask = (label_crop == cls).astype(float)
            scaled = cv2.resize(class_mask, out_size, interpolation=cv2.INTER_NEAREST)
            merged[scaled > 0.5] = cls
        resized_label[:, :, k] = merged
    return resized_img, resized_label
164 |
165 |
def normalize_volume(img):
    """Linearly rescale ``img`` so its minimum maps to 0 and its maximum to 1.

    A constant volume (max == min) is mapped to all zeros instead of
    dividing by zero and producing NaNs.

    Args:
        img: numpy array with at least one element.

    Returns:
        Array of the same shape with values in [0, 1].
    """
    lowest = img.min()
    value_range = img.max() - lowest
    if value_range == 0:
        # flat volume: avoid 0 / 0, the shifted array is already all zeros
        value_range = 1
    return (img - lowest) / value_range
169 |
170 |
class BowelCoarseSeg(data.Dataset):
    """Patch dataset for coarse bowel segmentation.

    Each item is one random (64, 128, 128) patch (z, x, y) cut from a loaded
    volume. With probability 1/4 the patch is centred on a random point
    inside the bounding box of a foreground label; otherwise it is placed
    uniformly at random. If a foreground patch is requested but the volume
    contains no foreground, the uniform placement is used instead of failing.
    """

    def __init__(self, root='', transform=None, mode="train", test_fraction=0.2, dataset_name='', save_dir=''):
        """Build the train or test file list for ``dataset_name``.

        Args:
            root: directory searched by ``split_dataset``.
            transform: forwarded to ``load_image_and_label`` (enables augmentation).
            mode: "infer" or "test" selects the test split, anything else the train split.
            test_fraction: fraction of volumes in the held-out fold.
            dataset_name: "fully_labeled", "smallbowel" or "colon_sigmoid".
            save_dir: where ``split_dataset`` writes the split lists.
        """
        assert dataset_name in ["fully_labeled", "smallbowel", "colon_sigmoid"]

        current_test = 5
        train_split, test_split = split_dataset(root, current_test, test_fraction, dataset_name, save_dir)

        self.imgs = test_split if mode in ("infer", "test") else train_split

        self.mode = mode
        self.root = root
        self.patch_size = (64, 128, 128)  # z, x, y
        self.transform = transform

    def __len__(self):
        """Return the number of volumes in the selected split."""
        return len(self.imgs)

    @staticmethod
    def _random_window(shape, patch_size):
        """Return [(start, end), ...] of a uniformly placed patch, one pair per axis."""
        window = []
        for dim, size in zip(shape, patch_size):
            start = random.randint(0, dim - size)
            window.append((start, start + size))
        return window

    @staticmethod
    def _window_around_foreground(fg_mask, shape, patch_size):
        """Return [(start, end), ...] of a patch centred inside the bounding box of ``fg_mask``.

        The centre is drawn per axis between the smallest and largest
        foreground index; the window is then shifted back inside the volume.
        """
        window = []
        for coords, dim, size in zip(np.where(fg_mask), shape, patch_size):
            center = random.randint(int(coords.min()), int(coords.max()))
            start = center - size // 2
            end = center + size // 2
            if start < 0:
                start, end = 0, size
            if end > dim:
                start, end = dim - size, dim
            window.append((start, end))
        return window

    def __getitem__(self, index):
        """Load volume ``index`` and return ``(data, mask, image_name)``.

        ``data`` is a float32 tensor and ``mask`` an int32 tensor, both with
        a leading singleton axis added in front of the cropped patch.
        """
        patch_size = self.patch_size
        image_name = self.imgs[index]

        image, label = load_image_and_label(image_name, transform=self.transform)

        # the first unique value is dropped as background
        positive_labels = sorted(np.unique(label)[1:].astype(int))
        for pos in positive_labels:
            assert pos in [1, 2, 3, 4, 5], image_name + ", wrong class label!!!"

        # label ids: {1: 'rectum', 2: 'sigmoid', 3: 'colon', 4: 'small', 5: 'duodenum'}

        # fully labeled and Colon_Sigmoid volumes are recognised by their
        # path; everything else is treated as the small bowel dataset
        multi_organ = 'rectum' in image_name or 'duodenum' in image_name or 'sigmoid' in image_name

        force_fg = bool(np.random.choice([False, False, False, True]))

        fg_mask = None
        if force_fg:
            if multi_organ:
                if positive_labels:
                    # double chance to sample every label other than 3 and 4
                    candidates = positive_labels + list(set(positive_labels) - set([3, 4]))
                    fg_mask = label == random.choice(candidates)
            else:
                fg_mask = label != 0

        if fg_mask is not None and fg_mask.any():
            window = self._window_around_foreground(fg_mask, image.shape, patch_size)
        else:
            # no foreground requested, or none available in this volume
            window = self._random_window(image.shape, patch_size)
        (zs, ze), (xs, xe), (ys, ye) = window

        data = torch.from_numpy(image[zs:ze, xs:xe, ys:ye][np.newaxis, :].astype(np.float32))
        mask = torch.from_numpy(label[zs:ze, xs:xe, ys:ye][np.newaxis, :].astype(np.int32))

        return data, mask, image_name
319 |
320 |
--------------------------------------------------------------------------------
/bowel_fineseg/datasets/BowelDatasetFineSeg.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import numpy as np
3 | import torch
4 | import torch.utils.data as data
5 | from glob import glob
6 | import os
7 | import os.path
8 | import nibabel as nib
9 | import random
10 | import cv2
11 |
12 |
# Organ name -> integer label id; the keys are the accepted ``bowel_name`` values.
name_to_label = {'rectum': 1, 'sigmoid': 2, 'colon': 3, 'small': 4, 'duodenum': 5}
14 |
15 |
def split_dataset(dir, current_test, test_fraction, bowel_name, save_dir):
    """Build the train/test volume lists for one bowel segment.

    Each source folder used for ``bowel_name`` is split on its own: its
    sorted volume list is cut into folds of
    ``int(len(volumes) * test_fraction)`` items, fold ``current_test``
    (1-based) goes to the test split and the rest to the train split. The
    per-folder results are concatenated in the order the folders are listed.

    Args:
        dir: root directory containing the dataset sub-folders.
        current_test: 1-based index of the fold used for testing.
        test_fraction: fraction of the volumes placed in one fold.
        bowel_name: one of the keys of ``name_to_label``.
        save_dir: directory receiving ``<bowel_name>_train_<k>.txt`` and
            ``<bowel_name>_test_<k>.txt``; nothing is written when None.

    Returns:
        (train_split, test_split): two lists of ``image_crop.nii.gz`` paths.
    """
    assert bowel_name in name_to_label.keys()

    # dataset sub-folders contributing volumes for each bowel segment
    source_folders = {
        'rectum': ("Fully_labeled_5C",),
        'sigmoid': ("Colon_Sigmoid", "Fully_labeled_5C"),
        'colon': ("Colon_Sigmoid", "Fully_labeled_5C"),
        'small': ("Smallbowel", "Fully_labeled_5C"),
        'duodenum': ("Fully_labeled_5C",),
    }

    train_split, test_split = [], []

    for folder in source_folders.get(bowel_name, ()):
        volumes = sorted(glob(os.path.join(dir, folder + "/*/*/image_crop.nii.gz")))
        fold_size = int(len(volumes) * test_fraction)
        held_out = volumes[fold_size * (current_test - 1): fold_size * current_test]
        test_split.extend(held_out)
        train_split.extend(sorted(set(volumes).difference(held_out)))

    if save_dir is not None:
        for tag, paths in (('_train_', train_split), ('_test_', test_split)):
            list_file = os.path.join(save_dir, bowel_name + tag + str(current_test) + ".txt")
            with open(list_file, 'w') as f:
                for path in paths:
                    f.write(path + '\n')

    return train_split, test_split
93 |
94 |
def load_image_and_label(img_file, transform=False):
    """Load a cropped image with its mask, edge map and three skeleton-flux maps.

    All companion files are located by replacing ``image_crop.nii.gz`` in
    ``img_file``. The mask is rounded; the image intensities are left as
    loaded. With ``transform`` set, one of 'ori' (no change), 'rotate' or
    'crop' is picked at random and applied to all six arrays. The arrays are
    then zero-padded to at least 120 x 120 x 80 and returned with the last
    axis moved to the front.

    Args:
        img_file: path to an ``image_crop.nii.gz`` file.
        transform: apply a random augmentation when truthy.

    Returns:
        Tuple (img, label, edge, skele_x, skele_y, skele_z).
    """
    def read_sibling(filename):
        return nib.load(img_file.replace('image_crop.nii.gz', filename)).get_fdata()

    img = nib.load(img_file).get_fdata()
    label = np.round(read_sibling('masks_crop.nii.gz'))
    skele_x = read_sibling('skele_crop_fluxx.nii.gz')
    skele_y = read_sibling('skele_crop_fluxy.nii.gz')
    skele_z = read_sibling('skele_crop_fluxz.nii.gz')
    edge = read_sibling('edge_reg_crop.nii.gz')

    arrays = (img, label, edge, skele_x, skele_y, skele_z)

    if transform:
        op = random.choice(['ori', 'rotate', 'crop'])
        if op == 'rotate':
            arrays = rotate(*arrays, degree=random.randint(-10, 10))
        elif op == 'crop':
            arrays = crop_resize(*arrays, shift_size_x=10, shift_size_y=10)

    # zero padding in case that cropped ROI is smaller than patch size
    arrays = padding_z(*arrays, min_z=80)
    arrays = padding_x(*arrays, min_x=120)
    arrays = padding_y(*arrays, min_y=120)

    return tuple(arr.transpose((2, 0, 1)) for arr in arrays)
123 |
124 |
def padding_z(img, label, edge, skele_x, skele_y, skele_z, min_z):
    """Zero-pad all six volumes along the last axis up to ``min_z``.

    The size check uses ``img`` only. When padding is needed, the missing
    slices are split between both ends, the leading end getting the extra
    slice when the amount is odd.

    Returns:
        Tuple (img, label, edge, skele_x, skele_y, skele_z).
    """
    rows, cols, depth = img.shape
    volumes = (img, label, edge, skele_x, skele_y, skele_z)
    if depth >= min_z:
        return volumes

    missing = min_z - depth
    trailing = missing // 2
    lead = np.zeros((rows, cols, missing - trailing), dtype=np.float64)
    tail = np.zeros((rows, cols, trailing), dtype=np.float64)
    return tuple(np.concatenate((lead, vol, tail), axis=2) for vol in volumes)
141 |
142 |
def padding_x(img, label, edge, skele_x, skele_y, skele_z, min_x):
    """Zero-pad all six volumes along the first axis up to ``min_x``.

    The size check uses ``img`` only. When padding is needed, the missing
    rows are split between both ends, the leading end getting the extra row
    when the amount is odd.

    Returns:
        Tuple (img, label, edge, skele_x, skele_y, skele_z).
    """
    rows, cols, depth = img.shape
    volumes = (img, label, edge, skele_x, skele_y, skele_z)
    if rows >= min_x:
        return volumes

    missing = min_x - rows
    trailing = missing // 2
    lead = np.zeros((missing - trailing, cols, depth), dtype=np.float64)
    tail = np.zeros((trailing, cols, depth), dtype=np.float64)
    return tuple(np.concatenate((lead, vol, tail), axis=0) for vol in volumes)
159 |
160 |
def padding_y(img, label, edge, skele_x, skele_y, skele_z, min_y):
    """Zero-pad all six volumes along the second axis up to ``min_y``.

    The size check uses ``img`` only. When padding is needed, the missing
    columns are split between both ends, the leading end getting the extra
    column when the amount is odd.

    Returns:
        Tuple (img, label, edge, skele_x, skele_y, skele_z).
    """
    rows, cols, depth = img.shape
    volumes = (img, label, edge, skele_x, skele_y, skele_z)
    if cols >= min_y:
        return volumes

    missing = min_y - cols
    trailing = missing // 2
    lead = np.zeros((rows, missing - trailing, depth), dtype=np.float64)
    tail = np.zeros((rows, trailing, depth), dtype=np.float64)
    return tuple(np.concatenate((lead, vol, tail), axis=1) for vol in volumes)
177 |
178 |
def rotate(img_ori, label_ori, edge_ori, skele_x_ori, skele_y_ori, skele_z_ori, degree):
    """Rotate every slice (last axis) of all six volumes by ``degree``.

    The rotation is about the slice centre with scale 1. The image, edge and
    skeleton-flux volumes use OpenCV's default interpolation. The label
    volume is resampled with nearest-neighbour interpolation so that only
    label values already present can appear in the output; interpolating
    between two class ids and rounding would create ids that belong to
    neither class.

    NOTE(review): if skele_x / skele_y hold the x / y components of a
    direction field, the component values are moved but not rotated here -
    confirm whether that is intended.

    Returns:
        Tuple (img, label, edge, skele_x, skele_y, skele_z), same shapes as
        the inputs.
    """
    height, width, depth = img_ori.shape
    out_size = (width, height)
    matRotation = cv2.getRotationMatrix2D((width / 2, height / 2), degree, 1)

    imgRotation = np.zeros_like(img_ori)
    labelRotation = np.zeros_like(label_ori)
    skele_x_Rotation = np.zeros_like(skele_x_ori)
    skele_y_Rotation = np.zeros_like(skele_y_ori)
    skele_z_Rotation = np.zeros_like(skele_z_ori)
    edgeRotation = np.zeros_like(edge_ori)

    # continuous-valued volumes: (source, destination)
    continuous = (
        (img_ori, imgRotation),
        (edge_ori, edgeRotation),
        (skele_x_ori, skele_x_Rotation),
        (skele_y_ori, skele_y_Rotation),
        (skele_z_ori, skele_z_Rotation),
    )

    for z in range(depth):
        # categorical data: never blend neighbouring class ids
        labelRotation[:, :, z] = cv2.warpAffine(label_ori[:, :, z], matRotation, out_size, flags=cv2.INTER_NEAREST)
        for src, dst in continuous:
            dst[:, :, z] = cv2.warpAffine(src[:, :, z], matRotation, out_size, borderValue=0)

    labelRotation = np.round(labelRotation)
    return imgRotation, labelRotation, edgeRotation, skele_x_Rotation, skele_y_Rotation, skele_z_Rotation
198 |
199 |
def crop_resize(img_ori, label_ori, edge_ori, skele_x_ori, skele_y_ori, skele_z_ori, shift_size_x, shift_size_y):
    """Randomly crop the in-plane borders of all six volumes and resize back.

    Up to ``shift_size_x`` / ``shift_size_y`` voxels are removed from each
    side of the first two axes; the same crop window is used for every slice
    and every volume. The image, edge and skeleton-flux volumes use OpenCV's
    default interpolation. The label volume is resized with
    nearest-neighbour interpolation so that only label values already
    present can appear in the output; interpolating between two class ids
    and rounding would create ids that belong to neither class.

    Returns:
        Tuple (img, label, edge, skele_x, skele_y, skele_z), same shapes as
        the inputs.
    """
    H, W, C = img_ori.shape
    # draw order: x start, x end, y start, y end
    x_small = np.random.randint(0, shift_size_x)
    x_large = np.random.randint(H - shift_size_x, H)
    y_small = np.random.randint(0, shift_size_y)
    y_large = np.random.randint(W - shift_size_y, W)

    imgCropresize = np.zeros_like(img_ori)
    labelCropresize = np.zeros_like(label_ori)
    skele_x_Cropresize = np.zeros_like(skele_x_ori)
    skele_y_Cropresize = np.zeros_like(skele_y_ori)
    skele_z_Cropresize = np.zeros_like(skele_z_ori)
    edgeCropresize = np.zeros_like(edge_ori)

    # continuous-valued volumes: (source, destination)
    continuous = (
        (img_ori, imgCropresize),
        (edge_ori, edgeCropresize),
        (skele_x_ori, skele_x_Cropresize),
        (skele_y_ori, skele_y_Cropresize),
        (skele_z_ori, skele_z_Cropresize),
    )

    out_size = (W, H)
    for z in range(C):
        # categorical data: never blend neighbouring class ids
        label_crop = label_ori[x_small:x_large, y_small:y_large, z]
        labelCropresize[:, :, z] = cv2.resize(label_crop, out_size, interpolation=cv2.INTER_NEAREST)
        for src, dst in continuous:
            dst[:, :, z] = cv2.resize(src[x_small:x_large, y_small:y_large, z], out_size)

    labelCropresize = np.round(labelCropresize)
    return imgCropresize, labelCropresize, edgeCropresize, skele_x_Cropresize, skele_y_Cropresize, skele_z_Cropresize
222 |
223 |
def normalize_volume(img):
    """Linearly rescale ``img`` so its minimum maps to 0 and its maximum to 1.

    A constant volume (max == min) is mapped to all zeros instead of
    dividing by zero and producing NaNs.

    Args:
        img: numpy array with at least one element.

    Returns:
        Array of the same shape with values in [0, 1].
    """
    lowest = img.min()
    value_range = img.max() - lowest
    if value_range == 0:
        # flat volume: avoid 0 / 0, the shifted array is already all zeros
        value_range = 1
    return (img - lowest) / value_range
227 |
228 |
class BowelFineSeg(data.Dataset):
    """Patch dataset for fine segmentation of a single bowel segment.

    Each item is one random patch cut at the same location from the image,
    the label, the edge map and the three skeleton-flux maps. With
    probability 1/4 the patch is centred on a random point inside the
    bounding box of the foreground (label != 0); otherwise it is placed
    uniformly at random. If a foreground patch is requested but the volume
    has no foreground, the uniform placement is used instead of failing.
    Volumes smaller than the patch along any axis are zero-padded first so
    that every returned patch has the full patch size.
    """

    # patch size (z, x, y) per bowel segment; other segments use the default
    _DEFAULT_PATCH_SIZE = (64, 192, 192)
    _PATCH_SIZES = {
        'colon': (64, 192, 192),
        'sigmoid': (64, 160, 160),
        'duodenum': (64, 160, 160),
        'rectum': (64, 96, 96),
    }

    def __init__(self, root='', transform=None, mode="train", test_fraction=0.2, bowel_name='', save_dir=''):
        """Build the train or test file list for ``bowel_name``.

        Args:
            root: dataset root passed to ``split_dataset``.
            transform: forwarded to ``load_image_and_label`` (enables augmentation).
            mode: "infer" or "test" selects the test split, anything else the train split.
            test_fraction: fraction of volumes in the held-out fold.
            bowel_name: one of the keys of ``name_to_label``.
            save_dir: where ``split_dataset`` writes the split lists.
        """
        assert bowel_name in name_to_label.keys()

        current_test = 5
        train_split, test_split = split_dataset(root, current_test, test_fraction, bowel_name, save_dir)

        self.imgs = test_split if mode in ("infer", "test") else train_split

        self.bowel_name = bowel_name
        self.mode = mode
        self.root = root
        self.patch_size = self._PATCH_SIZES.get(bowel_name, self._DEFAULT_PATCH_SIZE)
        self.transform = transform

    def __len__(self):
        """Return the number of volumes in the selected split."""
        return len(self.imgs)

    @staticmethod
    def _pad_to_patch(volume, patch_size):
        """Zero-pad ``volume`` so that no axis is shorter than the patch.

        The missing amount is split between both ends, the leading end
        getting the extra element when it is odd. A volume that is already
        large enough is returned unchanged.
        """
        pad_width = []
        for dim, size in zip(volume.shape, patch_size):
            missing = max(size - dim, 0)
            pad_width.append((missing - missing // 2, missing // 2))
        if not any(before or after for before, after in pad_width):
            return volume
        return np.pad(volume, pad_width, mode='constant', constant_values=0)

    @staticmethod
    def _random_window(shape, patch_size):
        """Return [(start, end), ...] of a uniformly placed patch, one pair per axis."""
        window = []
        for dim, size in zip(shape, patch_size):
            start = random.randint(0, dim - size)
            window.append((start, start + size))
        return window

    @staticmethod
    def _window_around_foreground(fg_mask, shape, patch_size):
        """Return [(start, end), ...] of a patch centred inside the bounding box of ``fg_mask``.

        The centre is drawn per axis between the smallest and largest
        foreground index; the window is then shifted back inside the volume.
        """
        window = []
        for coords, dim, size in zip(np.where(fg_mask), shape, patch_size):
            center = random.randint(int(coords.min()), int(coords.max()))
            start = center - size // 2
            end = center + size // 2
            if start < 0:
                start, end = 0, size
            if end > dim:
                start, end = dim - size, dim
            window.append((start, end))
        return window

    def __getitem__(self, index):
        """Load volume ``index`` and return six co-located patches.

        Returns:
            Tuple (img, label, edge, skele_x, skele_y, skele_z) of tensors,
            each with a leading singleton axis; ``label`` is int32, the
            others float32.
        """
        patch_size = self.patch_size
        img_name = self.imgs[index]

        volumes = load_image_and_label(img_name, transform=self.transform)
        # the loader pads to a fixed minimum that can be smaller than the patch
        volumes = [self._pad_to_patch(vol, patch_size) for vol in volumes]
        img, label = volumes[0], volumes[1]

        force_fg = bool(np.random.choice([False, False, False, True]))

        fg_mask = (label != 0) if force_fg else None
        if fg_mask is not None and fg_mask.any():
            # sample a foreground region
            window = self._window_around_foreground(fg_mask, img.shape, patch_size)
        else:
            # no foreground requested, or none available in this volume
            window = self._random_window(img.shape, patch_size)
        (zs, ze), (xs, xe), (ys, ye) = window

        # order matches ``volumes``: img, label, edge, skele_x, skele_y, skele_z
        dtypes = (np.float32, np.int32, np.float32, np.float32, np.float32, np.float32)
        img_p, label_p, edge_p, skele_x_p, skele_y_p, skele_z_p = (
            torch.from_numpy(vol[zs:ze, xs:xe, ys:ye][np.newaxis, :].astype(dtype))
            for vol, dtype in zip(volumes, dtypes)
        )

        return img_p, label_p, edge_p, skele_x_p, skele_y_p, skele_z_p
336 |
337 |
--------------------------------------------------------------------------------
/bowel_coarseseg/train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | from __future__ import division
3 | import time
4 | import argparse
5 | import torch
6 | import random
7 | import numpy as np
8 | import torch.nn as nn
9 | import torch.optim as optim
10 | import torch.nn.functional as F
11 | from torch.autograd import Variable
12 | from torch.utils.data import DataLoader
13 | from datasets.BowelDatasetCoarseSeg import BowelCoarseSeg
14 | from tools.loss import *
15 | import os, math
16 | import shutil
17 | import loc_model
18 | import wandb
19 |
20 |
# Choose which GPU ids are exposed to this process through the
# CUDA_VISIBLE_DEVICES environment variable, then echo the selection.
# The trailing comment records an alternative id pair.
os.environ['CUDA_VISIBLE_DEVICES'] = '0, 1' # "1, 2"
print("GPU ID:", os.environ['CUDA_VISIBLE_DEVICES'])
23 |
24 |
def weights_init(m):
    """Initialise a module whose class name contains 'Conv3d'.

    Weights are drawn with Xavier-normal initialisation (gain 0.02) and the
    bias, when the layer has one, is set to zero. Other modules are left
    untouched, so the function can be passed to ``model.apply``.

    Args:
        m: the module visited by ``apply``.
    """
    classname = m.__class__.__name__
    if classname.find('Conv3d') == -1:
        return

    nn.init.xavier_normal_(m.weight, gain=0.02)
    # layers built with bias=False expose ``bias`` as None
    if getattr(m, 'bias', None) is not None:
        m.bias.data.zero_()
31 |
32 |
def datestr():
    """Return the current UTC time formatted as ``YYYYMMDD_HHMM``."""
    utc = time.gmtime()
    date_part = '{}{:02}{:02}'.format(utc.tm_year, utc.tm_mon, utc.tm_mday)
    clock_part = '{:02}{:02}'.format(utc.tm_hour, utc.tm_min)
    return date_part + '_' + clock_part
36 |
37 |
def save_checkpoint(state, is_best, path, prefix, filename='checkpoint.pth.tar'):
    """Save ``state`` to ``<path>/<prefix>_<filename>``.

    When ``is_best`` is truthy the saved file is also copied to
    ``<path>/<prefix>_model_best.pth.tar``.

    Args:
        state: object handed to ``torch.save``.
        is_best: whether to refresh the best-model copy.
        path: output directory.
        prefix: file name prefix.
        filename: suffix of the checkpoint file name.
    """
    stem = os.path.join(path, prefix)
    checkpoint_file = stem + '_' + filename
    torch.save(state, checkpoint_file)
    if not is_best:
        return
    shutil.copyfile(checkpoint_file, stem + '_model_best.pth.tar')
44 |
45 |
46 | def main():
47 | parser = argparse.ArgumentParser()
48 | parser.add_argument('--ngpu', type=int, default=1)
49 | parser.add_argument('--batch_size', type=int, default=8) # 8 (2 GPU)
50 | parser.add_argument('--start-epoch', default=0, type=int, help='manual epoch number (useful on restarts)')
51 | parser.add_argument('--nEpochs', type=int, default=1500, help='total training epoch')
52 |
53 | parser.add_argument('--lr', default=1e-3, type=float, help='learning rate')
54 | parser.add_argument('--weight_decay', default=1e-8, type=float, metavar='W', help='weight decay (default: 1e-8)')
55 | parser.add_argument('--eval_interval', default=5, type=int, help='evaluate interval on validation set')
56 | parser.add_argument('--save_dir', type=str, default=None)
57 | parser.add_argument('--seed', type=int, default=1)
58 | parser.add_argument('--deterministic', type=bool, default=True)
59 | args = parser.parse_args()
60 |
61 |
62 | args.save_dir = 'exp/BowelLocNet.{}'.format(datestr())
63 | if os.path.exists(args.save_dir):
64 | shutil.rmtree(args.save_dir)
65 | os.makedirs(args.save_dir, exist_ok=True)
66 |
67 |
68 |
69 | shutil.copy(src=os.path.join(os.getcwd(), 'train.py'), dst=args.save_dir)
70 | shutil.copy(src=os.path.join(os.getcwd(), 'infer.py'), dst=args.save_dir)
71 | shutil.copy(src=os.path.join(os.getcwd(), 'crop_ROI.py'), dst=args.save_dir)
72 | shutil.copy(src=os.path.join(os.getcwd(), 'loc_model.py'), dst=args.save_dir)
73 | shutil.copy(src=os.path.join(os.getcwd(), 'tools/loss.py'), dst=args.save_dir)
74 | shutil.copy(src=os.path.join(os.getcwd(), 'datasets/BowelDatasetCoarseSeg.py'), dst=args.save_dir)
75 |
76 |
77 | # if args.seed is not None:
78 | # random.seed(args.seed)
79 | # np.random.seed(args.seed)
80 | # torch.manual_seed(args.seed)
81 | # torch.cuda.manual_seed(args.seed)
82 | # torch.backends.cudnn.deterministic = args.deterministic
83 |
84 |
85 | print("build Bowel Localisation Network")
86 | model = loc_model.BowelLocNet(elu=False)
87 | model.apply(weights_init)
88 |
89 |
90 | # model.load_state_dict(torch.load("./exp/BowelLocNet.20230208_1536/partial_5C_dict_1300.pth"))
91 |
92 |
93 | model = model.cuda()
94 | model = nn.parallel.DataParallel(model)
95 |
96 | print(' + Number of params: {}'.format(sum([p.data.nelement() for p in model.parameters()])))
97 |
98 |
99 |
100 | # WandB – Initialize a new run
101 | wandb.init(project='BowelNet', mode='disabled') # mode='disabled'
102 | wandb.run.name = 'BowelLoc_' + wandb.run.id
103 |
104 |
105 |
106 | print("loading fully_labeled dataset")
107 | fully_labeled_dir = '/mnt/c/chong/data/Bowel/crop_downsample/Fully_labeled_5C'
108 | trainSet_fully_labeled = BowelCoarseSeg(fully_labeled_dir, mode="train", transform=True, dataset_name="fully_labeled", save_dir=args.save_dir)
109 | trainLoader_fully_labeled = DataLoader(trainSet_fully_labeled, batch_size=args.batch_size, shuffle=True, num_workers=6, pin_memory=False)
110 | testSet_fully_labeled = BowelCoarseSeg(fully_labeled_dir, mode="test", transform=False, dataset_name="fully_labeled", save_dir=args.save_dir)
111 | testLoader_fully_labeled = DataLoader(testSet_fully_labeled, batch_size=args.batch_size, shuffle=False, num_workers=6, pin_memory=False)
112 |
113 | print("loading smallbowel dataset")
114 | smallbowel_dir = '/mnt/c/chong/data/Bowel/crop_downsample/Smallbowel'
115 | trainSet_small = BowelCoarseSeg(smallbowel_dir, mode="train", transform=True, dataset_name="smallbowel", save_dir=args.save_dir)
116 | trainLoader_small = DataLoader(trainSet_small, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=False)
117 | testSet_small = BowelCoarseSeg(smallbowel_dir, mode="test", transform=False, dataset_name="smallbowel", save_dir=args.save_dir)
118 | testLoader_small = DataLoader(testSet_small, batch_size=args.batch_size, shuffle=False, num_workers=4, pin_memory=False)
119 |
120 | print("loading colon_sigmoid dataset")
121 | colon_sigmoid_dir = '/mnt/c/chong/data/Bowel/crop_downsample/Colon_Sigmoid/colon_sigmoid'
122 | trainSet_colon_sigmoid = BowelCoarseSeg(colon_sigmoid_dir, mode="train", transform=True, dataset_name="colon_sigmoid", save_dir=args.save_dir)
123 | trainLoader_colon_sigmoid = DataLoader(trainSet_colon_sigmoid, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=False)
124 | testSet_colon_sigmoid = BowelCoarseSeg(colon_sigmoid_dir, mode="test", transform=False, dataset_name="colon_sigmoid", save_dir=args.save_dir)
125 | testLoader_colon_sigmoid = DataLoader(testSet_colon_sigmoid, batch_size=args.batch_size, shuffle=False, num_workers=4, pin_memory=False)
126 |
127 |
128 | optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
129 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.85) # 0.90
130 |
131 |
132 | best_dice = 0.0
133 | trainF = open(os.path.join(args.save_dir, 'train.csv'), 'w')
134 | testF = open(os.path.join(args.save_dir, 'test.csv'), 'w')
135 | for epoch in range(0, args.nEpochs + 1):
136 | if epoch < 0: # 200
137 | train_loc(epoch, model, trainLoader_fully_labeled, optimizer, trainF, data_type='fully_labeled')
138 | # train_loc(epoch, model, trainLoader_small, optimizer, trainF, data_type='smallbowel')
139 | # train_loc(epoch, model, trainLoader_colon_sigmoid, optimizer, trainF, data_type='colon_sigmoid')
140 | else:
141 | if epoch % 3 == 0:
142 | train_loc(epoch, model, trainLoader_fully_labeled, optimizer, trainF, data_type='fully_labeled')
143 | elif epoch % 3 == 1:
144 | train_loc(epoch, model, trainLoader_small, optimizer, trainF, data_type='smallbowel')
145 | else:
146 | train_loc(epoch, model, trainLoader_colon_sigmoid, optimizer, trainF, data_type='colon_sigmoid')
147 |
148 | scheduler.step()
149 | wandb.log({
150 | "LR": optimizer.param_groups[0]['lr'],
151 | })
152 |
153 | if epoch % args.eval_interval == 0: # 5
154 | dice_fully_labeled = test_loc(epoch, model, testLoader_fully_labeled, testF, data_type='fully_labeled')
155 | dice_small = test_loc(epoch, model, testLoader_small, testF, data_type='smallbowel')
156 | dice_colon_sigmoid = test_loc(epoch, model, testLoader_colon_sigmoid, testF, data_type='colon_sigmoid')
157 | dice_mean = (np.mean(dice_fully_labeled) + np.mean(dice_small) + np.mean(dice_colon_sigmoid)) / 3.0
158 |
159 | wandb.log({
160 | "Test Mean Dice Score": dice_mean,
161 | })
162 |
163 | is_best = False
164 | if dice_mean > best_dice:
165 | is_best = True
166 | best_dice = dice_mean
167 | save_checkpoint({'epoch': epoch, 'state_dict': model.module.state_dict(), 'best_acc': best_dice},
168 | is_best, args.save_dir, "partial_5C")
169 | torch.save(model.module.state_dict(), args.save_dir + '/partial_5C_dict_' + str(epoch) + '.pth')
170 |
171 | trainF.close()
172 | testF.close()
173 |
174 |
def train_loc(epoch, model, trainLoader, optimizer, trainF, data_type):
    """Train the localisation network for one epoch on a single dataset.

    The network output is treated as per-voxel class probabilities over six
    channels (index 0 plus the five labels listed below). For the partially
    labelled datasets, the channels that are not annotated are summed into a
    single "unlabelled" probability so the prediction can be compared against
    the partial ground truth; an entropy term over those merged channels is
    added to the loss on voxels whose partial label is 0.

    Args:
        epoch: Index of the current epoch (only used for logging).
        model: Network to optimise; called as ``model(data)``.
        trainLoader: Loader yielding ``(data, target, image_name)`` batches.
        optimizer: Optimiser stepped once per batch.
        trainF: Open text file receiving one CSV row per batch.
        data_type: One of ``'fully_labeled'``, ``'smallbowel'``,
            ``'colon_sigmoid'``.
    """
    assert data_type in ['fully_labeled', 'smallbowel', 'colon_sigmoid']

    # {1: 'rectum', 2: 'sigmoid', 3: 'colon', 4: 'small', 5: 'duodenum'}
    # Annotated channels vs. channels merged into the unlabelled class.
    if data_type == 'smallbowel':
        fg_channels, bg_channels = (4,), (0, 1, 2, 3, 5)
    elif data_type == 'colon_sigmoid':
        fg_channels, bg_channels = (2, 3), (0, 1, 4, 5)
    else:
        fg_channels, bg_channels = None, None

    model.train()
    n_batches = len(trainLoader)
    n_seen = 0

    for batch_idx, (data, target, image_name) in enumerate(trainLoader):
        data = Variable(data.cuda())
        target = Variable(target.cuda())

        optimizer.zero_grad()

        # Clamp so that the logarithms taken below stay finite.
        out = torch.clamp(model(data), min=1e-10, max=1)  # prevent overflow

        if data_type == 'fully_labeled':
            n_classes = 6
            output_p = out
            target_ce = target.long()
            loss_entropy = torch.tensor(0.0).cuda()
        else:
            n_classes = 1 + len(fg_channels)

            # Sum the unlabelled channels left to right into one probability.
            merged_bg = out[:, bg_channels[0], :, :, :]
            for ch in bg_channels[1:]:
                merged_bg = merged_bg + out[:, ch, :, :, :]
            fg_probs = tuple(out[:, ch, :, :, :] for ch in fg_channels)
            output_p = torch.stack((merged_bg,) + fg_probs, dim=1)

            # Re-index the annotated labels to 1..K; everything else is 0.
            if data_type == 'smallbowel':
                target_ce = (target == 4).long()
            else:
                target_ce = torch.zeros(target.shape).long().cuda()
                target_ce[target == 2] = 1
                target_ce[target == 3] = 2

            # Entropy over the merged channels, averaged over label-0 voxels.
            prob_neg = torch.stack(tuple(out[:, ch] for ch in bg_channels), dim=1)
            entropy = (-prob_neg * torch.log(prob_neg)).sum(dim=1)
            neg_mask = (target_ce == 0).squeeze(1)
            loss_entropy = (entropy * neg_mask).sum() / ((neg_mask).sum() + 1e-7)

        # One-hot style target, one channel per class index.
        target_dice = torch.cat(tuple(target_ce == k for k in range(n_classes)), dim=1).long()

        loss_ce = F.nll_loss(output_p.log(), target_ce.squeeze(1))
        loss_dice = dice_loss_PL(output_p, target_dice)
        loss = loss_dice + loss_ce + loss_entropy

        loss.backward()
        optimizer.step()

        dice_score = dice_score_partial(out, target, data_type)

        pred = torch.argmax(output_p, dim=1, keepdim=True)
        correct = pred.eq(target_ce.data).cpu().sum()
        acc = correct / target_ce.numel()

        # Fraction of non-zero target voxels predicted as non-zero.
        hit_pos = torch.logical_and((pred != 0), (target_ce != 0)).sum().item()
        sen_pos = round(hit_pos / ((target_ce != 0).sum().item() + 0.0001), 3)

        # Fraction of zero target voxels predicted as zero.
        hit_neg = torch.logical_and((pred == 0), (target_ce == 0)).sum().item()
        sen_neg = round(hit_neg / ((target_ce == 0).sum().item() + 0.0001), 3)

        n_seen += len(data)
        partialEpoch = epoch + batch_idx / n_batches
        print('Train Epoch: {:.2f} [{}/{} ({:.0f}%)]\tData_type: {}\tLoss: {:.4f}\tLoss_Dice: {:.4f}\tLoss_CE: {:.4f} '
              '\tLoss_Ent: {:.4f}'
              '\tAcc: {:.3f}\tSen_pos: {:.3f}\tSen_neg: {:.3f}\tDice: {}'.format(
                  partialEpoch, n_seen, len(trainLoader.dataset), 100. * batch_idx / n_batches, data_type,
                  loss.data, loss_dice.data, loss_ce.data, loss_entropy.data, acc, sen_pos, sen_neg, dice_score))

        wandb.log({
            "Train Loss": loss,
            "Train CE Loss": loss_ce,
            "Train Dice Loss": loss_dice,
            "Train Entropy Loss": loss_entropy,
        })

        csv_row = '{},{},{},{},{},{},{},{},{}, {}\n'.format(
            partialEpoch, data_type, loss.data, loss_dice.data, loss_ce.data, loss_entropy.data,
            acc, sen_pos, sen_neg, dice_score)
        trainF.write(csv_row)
        trainF.flush()
276 |
277 |
def test_loc(epoch, model, testLoader, testF, data_type):
    """Evaluate the localisation network on one dataset and log the result.

    For the partially labelled datasets the unannotated output channels are
    summed into a single probability before comparison with the partial
    ground truth (same reduction as in training, without clamping).

    Args:
        epoch: Index of the current epoch (only used for logging).
        model: Network to evaluate; called as ``model(data)``.
        testLoader: Loader yielding ``(data, target, image_name)`` batches.
        testF: Open text file receiving one CSV row per call.
        data_type: One of ``'fully_labeled'``, ``'smallbowel'``,
            ``'colon_sigmoid'``.

    Returns:
        The per-batch average of ``dice_score_partial(out, target, data_type)``.
    """
    assert data_type in ['fully_labeled', 'smallbowel', 'colon_sigmoid']

    # {1: 'rectum', 2: 'sigmoid', 3: 'colon', 4: 'small', 5: 'duodenum'}
    # Annotated channels vs. channels merged into the unlabelled class.
    if data_type == 'smallbowel':
        fg_channels, bg_channels = (4,), (0, 1, 2, 3, 5)
    elif data_type == 'colon_sigmoid':
        fg_channels, bg_channels = (2, 3), (0, 1, 4, 5)
    else:
        fg_channels, bg_channels = None, None

    model.eval()
    test_loss = 0
    dice_score = 0
    acc_all = 0
    sen_pos_all = 0
    sen_neg_all = 0
    with torch.no_grad():
        for data, target, image_name in testLoader:
            data = Variable(data.cuda())
            target = Variable(target.cuda())

            out = model(data)

            if data_type == 'fully_labeled':
                n_classes = 6
                output_p = out
                target_ce = target.long()
            else:
                n_classes = 1 + len(fg_channels)

                # Sum the unlabelled channels left to right.
                merged_bg = out[:, bg_channels[0], :, :, :]
                for ch in bg_channels[1:]:
                    merged_bg = merged_bg + out[:, ch, :, :, :]
                fg_probs = tuple(out[:, ch, :, :, :] for ch in fg_channels)
                output_p = torch.stack((merged_bg,) + fg_probs, dim=1)

                # Re-index the annotated labels to 1..K; everything else is 0.
                if data_type == 'smallbowel':
                    target_ce = (target == 4).long()
                else:
                    target_ce = torch.zeros(target.shape).long().cuda()
                    target_ce[target == 2] = 1
                    target_ce[target == 3] = 2

            target_dice = torch.cat(tuple(target_ce == k for k in range(n_classes)), dim=1).long()

            loss_dice = dice_loss_PL(output_p, target_dice)
            test_loss += loss_dice

            dice_score += np.array(dice_score_partial(out, target, data_type))

            pred = torch.argmax(output_p, dim=1, keepdim=True)
            correct = pred.eq(target_ce.data).cpu().sum()
            acc_all += correct / target_ce.numel()

            hit_pos = torch.logical_and((pred != 0), (target_ce != 0)).sum().item()
            sen_pos_all += round(hit_pos / ((target_ce != 0).sum().item() + 0.0001), 3)

            hit_neg = torch.logical_and((pred == 0), (target_ce == 0)).sum().item()
            sen_neg_all += round(hit_neg / ((target_ce == 0).sum().item() + 0.0001), 3)

    # Turn the running sums into per-batch averages.
    n_batches = len(testLoader)
    test_loss /= n_batches
    dice_score /= n_batches
    acc_all /= n_batches
    sen_pos_all /= n_batches
    sen_neg_all /= n_batches
    print('\nTest online: Data_type: {}, Dice loss: {:.4f}, Acc:{:.3f}, Sen_pos:{:.3f}, Sen_neg:{:.3f}, Dice: {}\n'.
          format(data_type, test_loss, acc_all, sen_pos_all, sen_neg_all, dice_score))

    csv_row = '{},{},{},{},{},{},{}\n'.format(
        epoch, data_type, test_loss, acc_all, sen_pos_all, sen_neg_all, dice_score)
    testF.write(csv_row)
    testF.flush()

    wandb.log({
        "Test Loss": test_loss,
        "Test Dice Score": dice_score.mean(),
    })

    return dice_score
361 |
362 |
# Script entry point: start training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
365 |
--------------------------------------------------------------------------------
/bowel_fineseg/train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | from __future__ import division
3 | import time
4 | import argparse
5 | import torch
6 | import random
7 | import numpy as np
8 | import torch.nn as nn
9 | import torch.optim as optim
10 | import torch.nn.functional as F
11 | from torch.autograd import Variable
12 | import torchvision.transforms as transforms
13 | from torch.utils.data import DataLoader
14 | from datasets.BowelDatasetFineSeg import BowelFineSeg
15 | from tools.loss import *
16 | import os, math
17 | import shutil
18 | import seg_model
19 |
20 | import wandb
21 |
22 |
# Restrict the CUDA devices visible to this process and echo the selection.
os.environ.update({'CUDA_VISIBLE_DEVICES': '1'})  # "0, 1"
print("GPU ID:", os.environ.get('CUDA_VISIBLE_DEVICES'))
25 |
26 |
def weights_init(m):
    """Initialise 3-D convolution layers; meant for ``model.apply(weights_init)``.

    Modules whose class name contains ``'Conv3d'`` get Xavier-normal weights
    (gain 0.02) and a zeroed bias. Every other module is left untouched.

    Args:
        m: A module visited by ``nn.Module.apply``.
    """
    if 'Conv3d' not in m.__class__.__name__:
        return
    # nn.init.kaiming_normal_(m.weight)
    nn.init.xavier_normal_(m.weight, gain=0.02)
    # Layers built with bias=False have ``bias is None``; skip them instead of
    # failing with AttributeError.
    if getattr(m, 'bias', None) is not None:
        m.bias.data.zero_()
34 |
def datestr():
    """Return the current UTC time as a ``YYYYMMDD_HHMM`` string."""
    # The first five struct_time fields are year, month, day, hour, minute.
    return '%d%02d%02d_%02d%02d' % tuple(time.gmtime()[:5])
39 |
def save_checkpoint(state, is_best, path, prefix, filename='checkpoint.pth.tar'):
    """Serialise ``state`` to ``<path>/<prefix>_<filename>``.

    When ``is_best`` is true, the file just written is also copied to
    ``<path>/<prefix>_model_best.pth.tar``.

    Args:
        state: Object handed to ``torch.save``.
        is_best: Whether to additionally store a "best model" copy.
        path: Directory receiving the files.
        prefix: File-name prefix placed before ``filename``.
        filename: Suffix of the regular checkpoint file.
    """
    stem = os.path.join(path, prefix)
    checkpoint_file = '_'.join((stem, filename))
    torch.save(state, checkpoint_file)
    if not is_best:
        return
    shutil.copyfile(checkpoint_file, stem + '_model_best.pth.tar')
47 |
def main():
    """Train the fine segmentation network in two stages.

    Stage 1 optimises every parameter whose name does not contain ``'meta'``
    (the base segmentors); stage 2 optimises only the parameters whose name
    contains ``'meta'`` (the meta segmentor). Each stage writes per-batch and
    per-evaluation CSV logs plus checkpoints into a fresh
    ``exp/BowelNet.<timestamp>`` directory, which also receives a snapshot of
    the source files used for the run.
    """

    def _str2bool(value):
        # ``type=bool`` would turn any non-empty string (including "False")
        # into True; recognise the usual spellings of "false" explicitly and
        # keep treating every other non-empty value as True.
        if isinstance(value, bool):
            return value
        return str(value).strip().lower() not in ('', '0', 'false', 'f', 'no', 'n', 'off')

    parser = argparse.ArgumentParser()
    parser.add_argument('--ngpu', type=int, default=1)
    parser.add_argument('--batch_size', type=int, default=4)  # 4 (single GPU)
    parser.add_argument('--start-epoch', default=0, type=int, help='manual epoch number (useful on restarts)')
    parser.add_argument('--nEpochs_base', type=int, default=501, help='total epoch number for base segmentor')
    parser.add_argument('--nEpochs_meta', type=int, default=301, help='total epoch number for meta segmentor')

    parser.add_argument('--lr', default=1e-3, type=float, help='learning rate')
    parser.add_argument('--weight_decay', default=1e-8, type=float, metavar='W', help='weight decay (default: 1e-8)')
    parser.add_argument('--eval_interval', default=5, type=int, help='evaluate interval on validation set')
    parser.add_argument('--temp', default=0.7, type=float, help='temperature for meta segmentor')
    parser.add_argument('--save_dir', type=str, default=None)
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--deterministic', type=_str2bool, default=True)
    args = parser.parse_args()

    # NOTE: any --save_dir given on the command line is overridden here.
    args.save_dir = 'exp/BowelNet.{}'.format(datestr())
    if os.path.exists(args.save_dir):
        shutil.rmtree(args.save_dir)
    os.makedirs(args.save_dir, exist_ok=True)

    # Keep a copy of the code that produced this run next to its results.
    for rel_path in ('train.py', 'infer.py', 'seg_model.py', 'tools/loss.py', 'tools/utils.py',
                     'datasets/BowelDatasetFineSeg.py'):
        shutil.copy(src=os.path.join(os.getcwd(), rel_path), dst=args.save_dir)

    # Seeding is currently disabled:
    # if args.seed is not None:
    #     random.seed(args.seed)
    #     np.random.seed(args.seed)
    #     torch.manual_seed(args.seed)
    #     torch.cuda.manual_seed(args.seed)
    #     torch.backends.cudnn.deterministic = args.deterministic

    print("build BowelNet")
    model = seg_model.BowelNet(elu=False)
    model.apply(weights_init)

    # model.load_state_dict(torch.load("./BowelNet.20230207_1744/rectum_meta_195.pth"))

    model = model.cuda()
    model = nn.parallel.DataParallel(model)

    print('  + Number of params: {}'.format(sum([p.data.nelement() for p in model.parameters()])))

    # Select the bowel segment to train on by (un)commenting one pair below.
    # data_dir = '/mnt/c/chong/data/Bowel/crop_stage1_ROI_small/'
    # bowel_name = 'small'

    # data_dir = '/mnt/c/chong/data/Bowel/crop_stage1_ROI_colon/'
    # bowel_name = 'colon'

    # data_dir = '/mnt/c/chong/data/Bowel/crop_stage1_ROI_sigmoid/'
    # bowel_name = 'sigmoid'

    # data_dir = '/mnt/c/chong/data/Bowel/crop_stage1_ROI_duodenum/'
    # bowel_name = 'duodenum'

    data_dir = '/mnt/c/chong/data/Bowel/crop_stage1_ROI_rectum/'
    bowel_name = 'rectum'

    # WandB – Initialize a new run
    wandb.init(project='BowelNet', mode='disabled')  # mode='disabled'
    wandb.run.name = bowel_name + '_' + wandb.run.id

    print("loading training set")
    trainSet = BowelFineSeg(root=data_dir, mode="train", transform=True, bowel_name=bowel_name, save_dir=args.save_dir)
    trainLoader = DataLoader(trainSet, batch_size=args.batch_size, shuffle=True, num_workers=8, pin_memory=False)  # 8

    print("loading test set")
    testSet = BowelFineSeg(root=data_dir, mode="test", transform=False, bowel_name=bowel_name, save_dir=args.save_dir)
    testLoader = DataLoader(testSet, batch_size=args.batch_size, shuffle=False, num_workers=0, pin_memory=False)  # 8

    # base segmentor training
    ##########################################################################
    params_base = [param for name, param in model.named_parameters() if 'meta' not in name]
    optimizer_base = optim.Adam(params_base, lr=args.lr, weight_decay=args.weight_decay)
    # The scheduler is created but its step() call below is disabled.
    base_scheduler = torch.optim.lr_scheduler.StepLR(optimizer_base, step_size=50, gamma=0.85)
    best_dice = 0.0
    # Context managers close the CSV logs even if training raises.
    with open(os.path.join(args.save_dir, 'train_base.csv'), 'w') as trainF_base, \
            open(os.path.join(args.save_dir, 'test_base.csv'), 'w') as testF_base:
        for epoch in range(1, args.nEpochs_base + 1):
            train_base(epoch, model, trainLoader, optimizer_base, trainF_base, train_segmentor='base')
            # base_scheduler.step()
            wandb.log({
                "Base LR": optimizer_base.param_groups[0]['lr'],
            })
            if epoch % args.eval_interval == 0:  # 5
                dice = test_base(epoch, model, testLoader, testF_base, train_segmentor='base')
                is_best = False
                if dice > best_dice:
                    is_best = True
                    best_dice = dice
                save_checkpoint({'epoch': epoch, 'state_dict': model.module.state_dict(), 'best_acc': best_dice},
                                is_best, args.save_dir, bowel_name + '_base')
                torch.save(model.module.state_dict(), args.save_dir + '/' + bowel_name + '_base_' + str(epoch) + '.pth')

    # best_base_model_path = os.path.join(args.save_dir, bowel_name + '_base' + '_model_best.pth.tar')
    # model.load_state_dict(torch.load(best_base_model_path)['state_dict'])

    # meta segmentor training
    ##########################################################################
    params_meta = [param for name, param in model.named_parameters() if 'meta' in name]
    optimizer_meta = optim.Adam(params_meta, lr=args.lr, weight_decay=args.weight_decay)
    # The scheduler is created but its step() call below is disabled.
    meta_scheduler = torch.optim.lr_scheduler.StepLR(optimizer_meta, step_size=30, gamma=0.85)
    best_dice = 0.0
    with open(os.path.join(args.save_dir, 'train_meta.csv'), 'w') as trainF_meta, \
            open(os.path.join(args.save_dir, 'test_meta.csv'), 'w') as testF_meta:
        for epoch in range(1, args.nEpochs_meta + 1):
            train_meta(epoch, model, trainLoader, optimizer_meta, trainF_meta, train_segmentor='meta',
                       temperature=args.temp)
            # meta_scheduler.step()
            wandb.log({
                "Meta LR": optimizer_meta.param_groups[0]['lr'],
            })
            if epoch % args.eval_interval == 0:  # 5
                dice = test_meta(epoch, model, testLoader, testF_meta, train_segmentor='meta')
                is_best = False
                if dice > best_dice:
                    is_best = True
                    best_dice = dice
                save_checkpoint({'epoch': epoch, 'state_dict': model.module.state_dict(), 'best_acc': best_dice},
                                is_best, args.save_dir, bowel_name + '_meta')
                torch.save(model.module.state_dict(), args.save_dir + '/' + bowel_name + '_meta_' + str(epoch) + '.pth')
196 |
197 |
198 |
def train_base(epoch, model, trainLoader, optimizer, trainF, train_segmentor='base'):
    """Train the base (boundary and skeleton) segmentors for one epoch.

    The model is called as ``model(data, train_segmentor)`` and must return
    ``(out_flux_reg, out_skele_mask_seg, out_edge_reg, out_edge_mask_seg)``.
    Each mask output is trained with cross-entropy plus Dice loss; the edge
    regression uses a class-balanced squared error and the flux regression
    uses MSE on the vectors plus MSE on their squared magnitudes.

    Args:
        epoch: 1-based epoch number (only used for logging).
        model: Network being optimised.
        trainLoader: Loader yielding
            ``(data, target, edge, skele_x, skele_y, skele_z)`` batches.
        optimizer: Optimiser stepped once per batch.
        trainF: Open text file receiving one CSV row per batch.
        train_segmentor: Mode flag forwarded to the model.
    """
    model.train()
    n_seen = 0
    n_samples = len(trainLoader.dataset)
    n_batches = len(trainLoader)

    def _squared_norm(vec):
        # Sum of the squared x, y and z components (channels 0..2).
        return vec[:, 0, ...] ** 2 + vec[:, 1, ...] ** 2 + vec[:, 2, ...] ** 2

    for batch_idx, batch in enumerate(trainLoader):
        data, target, edge, skele_x, skele_y, skele_z = [Variable(t.cuda()) for t in batch]
        optimizer.zero_grad()
        out_flux_reg, out_skele_mask_seg, out_edge_reg, out_edge_mask_seg = model(data, train_segmentor)

        ##### boundary segmentor loss
        loss_ce_edge_mask = F.cross_entropy(out_edge_mask_seg, target.squeeze(1).long())
        loss_dice_edge_mask = dice_loss(F.softmax(out_edge_mask_seg, dim=1), target)

        # Up-weight non-zero edge voxels by half the negative/positive ratio.
        edge_fg = (edge != 0).float()
        if edge_fg.sum() == 0:
            voxel_weight = torch.ones_like(edge_fg) * 1.0
        else:
            n_pos = (edge_fg != 0).sum()
            n_neg = (edge_fg == 0).sum()
            voxel_weight = (n_neg / n_pos * 0.5) * edge_fg + 1.0 * (1 - edge_fg)
        loss_edge = (voxel_weight * torch.square(out_edge_reg - edge)).mean()

        ##### skeleton segmentor loss
        loss_ce_skele_mask = F.cross_entropy(out_skele_mask_seg, target.squeeze(1).long())
        loss_dice_skele_mask = dice_loss(F.softmax(out_skele_mask_seg, dim=1), target)  # notice

        skele_xyz = torch.cat((skele_x, skele_y, skele_z), dim=1)
        loss_flux_vec = F.mse_loss(out_flux_reg, skele_xyz)
        loss_flux_mag = F.mse_loss(_squared_norm(out_flux_reg), _squared_norm(skele_xyz))
        loss_flux = loss_flux_vec + loss_flux_mag

        loss = (loss_dice_skele_mask + loss_ce_skele_mask) + \
               (loss_dice_edge_mask + loss_ce_edge_mask) + \
               0.8 * loss_flux + \
               0.8 * loss_edge

        loss.backward()
        optimizer.step()

        # Monitoring metrics; accuracy/sensitivity use the boundary mask output.
        dice_edge = dice_score_metric(out_edge_mask_seg, target)
        dice_skele = dice_score_metric(out_skele_mask_seg, target)
        pred = torch.argmax(out_edge_mask_seg, dim=1).unsqueeze(1)
        correct = pred.eq(target.data).cpu().sum()
        acc = correct / target.numel()

        hit_pos = torch.logical_and((pred == 1), (target == 1)).sum().item()
        sen_pos = round(hit_pos / ((target == 1).sum().item() + 0.0001), 3)

        hit_neg = torch.logical_and((pred == 0), (target == 0)).sum().item()
        sen_neg = round(hit_neg / ((target == 0).sum().item() + 0.0001), 3)

        n_seen += len(data)
        partialEpoch = epoch + batch_idx / n_batches - 1
        print('Base Train, Epoch: {:.2f} [{}/{} ({:.0f}%)]\t'
              'Loss: {:.5f}\t'
              'L_Dice_skele: {:.5f}\tL_CE_skele: {:.5f}\t L_flux_reg: {:.5f}\t'
              'L_Dice_edge: {:.5f}\tL_CE_edge: {:.5f}\t L_edge_seg: {:.5f}\t'
              'Acc: {:.3f}\tSen_pos: {:.3f}\tSen_neg: {:.3f}\tDice: {:.5f}\tDice_skele: {:.5f}'.format(
                  partialEpoch, n_seen, n_samples, 100. * batch_idx / n_batches,
                  loss.data,
                  loss_dice_skele_mask.data, loss_ce_skele_mask.data, loss_flux.data,
                  loss_dice_edge_mask.data, loss_ce_edge_mask.data, loss_edge.data,
                  acc, sen_pos, sen_neg, dice_edge, dice_skele))

        csv_row = '{},{},{},{},{},{},{},{},{}\n'.format(
            partialEpoch, loss.data, loss_dice_skele_mask.data, loss_dice_edge_mask.data,
            acc, sen_pos, sen_neg, dice_edge, dice_skele)
        trainF.write(csv_row)
        trainF.flush()

        wandb.log({
            "Train Dice Skele Mask Loss": loss_dice_skele_mask.item(),
            "Train Dice Edge Mask Loss": loss_dice_edge_mask.item(),
            "Train Skele Loss": loss_flux.item(),
            "Train Edge Loss": loss_edge.item(),
        })
282 |
283 |
def test_base(epoch, model, testLoader, testF, train_segmentor='base'):
    """Evaluate the base (boundary and skeleton) segmentors on the test loader.

    Per-batch Dice losses, regression errors, Dice scores, accuracy and the
    positive/negative hit rates are averaged over the batches, printed, written
    as one CSV row to ``testF`` and logged to wandb.

    Args:
        epoch: Epoch number written to the log.
        model: Network to evaluate; called as ``model(data, train_segmentor)``.
        testLoader: Loader yielding
            ``(data, target, edge, skele_x, skele_y, skele_z)`` batches.
        testF: Open text file receiving one CSV row per call.
        train_segmentor: Mode flag forwarded to the model.

    Returns:
        The mean of the averaged skeleton and boundary Dice scores.
    """
    model.eval()
    test_loss = edge_error = skele_error = 0
    dice_score_edge = dice_score_skele = 0
    acc_all = sen_pos_all = sen_neg_all = 0
    with torch.no_grad():
        for batch in testLoader:
            data, target, edge, skele_x, skele_y, skele_z = [Variable(t.cuda()) for t in batch]

            out_flux_reg, out_skele_mask_seg, out_edge_reg, out_edge_mask_seg = model(data, train_segmentor)

            loss_dice_skele = dice_loss(F.softmax(out_skele_mask_seg, dim=1), target).data
            loss_dice_edge = dice_loss(F.softmax(out_edge_mask_seg, dim=1), target).data
            test_loss += loss_dice_skele + loss_dice_edge

            # Plain (unweighted) MSE is used for both regression heads here.
            edge_error += F.mse_loss(out_edge_reg, edge)

            skele_xyz = torch.cat((skele_x, skele_y, skele_z), dim=1)
            skele_error += F.mse_loss(out_flux_reg, skele_xyz)

            dice_score_edge += dice_score_metric(out_edge_mask_seg, target)
            dice_score_skele += dice_score_metric(out_skele_mask_seg, target)

            # Accuracy and hit rates are measured on the boundary mask output.
            pred = torch.argmax(out_edge_mask_seg, dim=1).unsqueeze(1)
            correct = pred.eq(target.data).cpu().sum()
            acc_all += correct / target.numel()

            hit_pos = torch.logical_and((pred == 1), (target == 1)).sum().item()
            sen_pos_all += round(hit_pos / ((target == 1).sum().item() + 0.0001), 3)

            hit_neg = torch.logical_and((pred == 0), (target == 0)).sum().item()
            sen_neg_all += round(hit_neg / ((target == 0).sum().item() + 0.0001), 3)

    # Turn the running sums into per-batch averages.
    n_batches = len(testLoader)
    test_loss /= n_batches
    dice_score_edge /= n_batches
    dice_score_skele /= n_batches
    edge_error /= n_batches
    skele_error /= n_batches

    acc_all /= n_batches
    sen_pos_all /= n_batches
    sen_neg_all /= n_batches

    print('\nTest online: Dice loss: {:.6f}, edge_error: {:.6f}, flux_error: {:.6f}, '
          '\tAcc:{:.3f}, Sen_pos:{:.3f}, Sen_neg:{:.3f}, '
          '\tDice_edge: {:.6f}, Dice_skele: {:.6f}\n'.format(
              test_loss, edge_error, skele_error,
              acc_all, sen_pos_all, sen_neg_all,
              dice_score_edge, dice_score_skele))

    csv_row = '{},{},{},{},{},{},{},{},{}\n'.format(
        epoch, test_loss, edge_error, skele_error, acc_all, sen_pos_all,
        sen_neg_all, dice_score_edge, dice_score_skele)
    testF.write(csv_row)
    testF.flush()

    wandb.log({
        "Test Loss": test_loss,
        "Test Skele Error": skele_error,
        "Test Edge Error": edge_error,
        "Test Skele Dice Score": dice_score_skele,
        "Test Edge Dice Score": dice_score_edge,
    })

    return (dice_score_skele + dice_score_edge) * 0.5
358 |
359 |
def train_meta(epoch, model, trainLoader, optimizer, trainF, train_segmentor='meta', temperature=0.7):
    """Train the meta segmentor for one epoch against soft pseudo-labels.

    The model is called as ``model(data, train_segmentor)`` and must return
    ``(out_meta_mask_seg, out_edge, out_skele)``. The temperature-scaled
    softmax of the edge and skeleton outputs are blended with a random weight
    ``alpha`` to form a soft target; the meta output is trained against it with
    a soft cross-entropy plus Dice loss. The ground-truth ``target`` is used
    only for the monitoring metrics.

    Args:
        epoch: 1-based epoch number (only used for logging).
        model: ``DataParallel``-wrapped network (``model.module`` is accessed).
        trainLoader: Loader yielding
            ``(data, target, edge, skele_x, skele_y, skele_z)`` batches.
        optimizer: Optimiser stepped once per batch.
        trainF: Open text file receiving one CSV row per batch.
        train_segmentor: Mode flag forwarded to the model.
        temperature: Divisor applied to the logits before softmax.
    """
    model.train()

    # Only parameters with 'meta' in their name are trainable in this stage.
    for name, para in model.module.named_parameters():
        para.requires_grad = 'meta' in name

    n_batches = len(trainLoader)
    # The edge / skeleton ground truths are not used in this stage, so they are
    # left on the host instead of being copied to the GPU.
    for batch_idx, (data, target, _edge, _skele_x, _skele_y, _skele_z) in enumerate(trainLoader):
        data, target = Variable(data.cuda()), Variable(target.cuda())
        optimizer.zero_grad()

        out_meta_mask_seg, out_edge, out_skele = model(data, train_segmentor)

        prob_edge = F.softmax(out_edge / temperature, dim=1)
        prob_skele = F.softmax(out_skele / temperature, dim=1)

        # Random convex combination of the two base predictions.
        alpha = np.random.uniform(0.0, 1.0, 1)[0]
        soft_pseudo_target = alpha * prob_edge + (1 - alpha) * prob_skele

        meta_logits = out_meta_mask_seg / temperature
        loss_dice = dice_loss_PL(F.softmax(meta_logits, dim=1), soft_pseudo_target)
        loss_ce = torch.mean(-(soft_pseudo_target * F.log_softmax(meta_logits, dim=1)).sum(dim=1))

        loss = loss_ce + loss_dice

        loss.backward()
        optimizer.step()

        dice_score_meta = dice_score_metric(out_meta_mask_seg, target)

        pred = torch.argmax(out_meta_mask_seg, dim=1).unsqueeze(1)
        correct = pred.eq(target.data).cpu().sum()
        acc = correct / target.numel()

        correct_pos = torch.logical_and((pred == 1), (target == 1)).sum().item()
        sen_pos = round(correct_pos / ((target == 1).sum().item() + 0.0001), 3)

        correct_neg = torch.logical_and((pred == 0), (target == 0)).sum().item()
        sen_neg = round(correct_neg / ((target == 0).sum().item() + 0.0001), 3)

        partialEpoch = epoch + batch_idx / n_batches - 1

        print('Meta Train, Epoch: {:.2f}\tloss: {:.5f}\tLoss_dice: {:.5f}\tLoss_ce: {:.5f}\tAcc: {:.5f}\tSen: {:.5f}\tSpe: {:.5f}\tDice_score: {:.5f}'.
              format(partialEpoch, loss.item(), loss_dice.item(), loss_ce.item(), acc, sen_pos, sen_neg, dice_score_meta))

        trainF.write('{},{},{},{},{},{},{},{}\n'.
                     format(partialEpoch, loss.item(), loss_dice.item(), loss_ce.item(), acc, sen_pos, sen_neg, dice_score_meta))

        wandb.log({
            "Train Meta Mask Loss": loss.item(),
            "Train Meta CE Loss": loss_ce.item(),
            "Train Meta Dice Loss": loss_dice.item(),
        })

        trainF.flush()
420 |
421 |
def test_meta(epoch, model, testLoader, testF, train_segmentor='meta'):
    """Evaluate the meta segmentor on the test loader.

    Per-batch Dice loss, Dice score, accuracy and the positive/negative hit
    rates are averaged over the batches, printed, written as one CSV row to
    ``testF`` and logged to wandb.

    Args:
        epoch: Epoch number written to the log.
        model: Network to evaluate; called as ``model(data, train_segmentor)``
            and expected to return three outputs, of which only the first
            (the meta mask output) is used.
        testLoader: Loader yielding
            ``(data, target, edge, skele_x, skele_y, skele_z)`` batches.
        testF: Open text file receiving one CSV row per call.
        train_segmentor: Mode flag forwarded to the model.

    Returns:
        The Dice score of the meta output averaged over batches.
    """
    model.eval()
    test_loss = 0
    dice_score_meta = 0
    acc_all = 0
    sen_pos_all = 0
    sen_neg_all = 0
    with torch.no_grad():
        # The edge / skeleton ground truths are not needed for this evaluation,
        # so they are left on the host instead of being copied to the GPU.
        for data, target, _edge, _skele_x, _skele_y, _skele_z in testLoader:
            data, target = Variable(data.cuda()), Variable(target.cuda())

            out_meta_mask_seg, _prob_edge, _prob_skele = model(data, train_segmentor)

            loss_dice_meta = dice_loss(F.softmax(out_meta_mask_seg, dim=1), target).data
            test_loss += loss_dice_meta

            dice_score_meta += dice_score_metric(out_meta_mask_seg, target)

            pred = torch.argmax(out_meta_mask_seg, dim=1).unsqueeze(1)
            correct = pred.eq(target.data).cpu().sum()
            acc_all += correct / target.numel()

            correct_pos = torch.logical_and((pred == 1), (target == 1)).sum().item()
            sen_pos = round(correct_pos / ((target == 1).sum().item() + 0.0001), 3)
            sen_pos_all += sen_pos

            correct_neg = torch.logical_and((pred == 0), (target == 0)).sum().item()
            sen_neg = round(correct_neg / ((target == 0).sum().item() + 0.0001), 3)
            sen_neg_all += sen_neg

    # Turn the running sums into per-batch averages.
    n_batches = len(testLoader)
    test_loss /= n_batches
    dice_score_meta /= n_batches

    acc_all /= n_batches
    sen_pos_all /= n_batches
    sen_neg_all /= n_batches
    print('\nTest online: Dice loss: {:.5f}, Acc:{:.3f}, Sen:{:.3f}, Spe:{:.3f}, Dice_meta: {:.5f}\n'.format(
        test_loss, acc_all, sen_pos_all, sen_neg_all, dice_score_meta))

    testF.write(
        '{},{},{},{},{},{}\n'.format(epoch, test_loss, acc_all, sen_pos_all, sen_neg_all, dice_score_meta))

    wandb.log({
        "Test Meta Loss": test_loss,
        "Test Meta Dice Score": dice_score_meta,
    })

    testF.flush()

    return dice_score_meta
475 |
476 |
# Script entry point: start training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
479 |
--------------------------------------------------------------------------------
/bowel_fineseg/tools/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2013 Oskar Maier
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 | #
16 | # author Oskar Maier
17 | # version r0.1.1
18 | # since 2014-03-13
19 | # status Release
20 |
21 | # built-in modules
22 |
23 | # third-party modules
24 | import numpy
25 | from scipy.ndimage import _ni_support
26 | from scipy.ndimage.morphology import distance_transform_edt, binary_erosion, \
27 | generate_binary_structure
28 | from scipy.ndimage.measurements import label, find_objects
29 | from scipy.stats import pearsonr
30 |
31 |
32 | # own modules
33 |
34 | # code
def dc(result, reference):
    r"""
    Dice coefficient

    Computes the Dice coefficient (also known as Sorensen index) between the binary
    objects in two images.

    The metric is defined as

    .. math::

        DC=\frac{2|A\cap B|}{|A|+|B|}

    , where :math:`A` is the first and :math:`B` the second set of samples (here: binary objects).

    Parameters
    ----------
    result : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.
    reference : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.

    Returns
    -------
    dc : float
        The Dice coefficient between the object(s) in ``result`` and the
        object(s) in ``reference``. It ranges from 0 (no overlap) to 1 (perfect overlap).
        ``0.0`` is also returned when both inputs contain no object at all.

    Notes
    -----
    The measure is symmetric. The binary images can therefore be supplied in any order.
    """
    # ``numpy.bool`` does not exist in NumPy 1.24-1.26; ``numpy.bool_`` is available
    # in every release. ``numpy.asarray`` makes plain lists/tuples acceptable, as the
    # "array_like" contract above promises.
    result = numpy.atleast_1d(numpy.asarray(result).astype(numpy.bool_))
    reference = numpy.atleast_1d(numpy.asarray(reference).astype(numpy.bool_))

    intersection = numpy.count_nonzero(result & reference)
    total = numpy.count_nonzero(result) + numpy.count_nonzero(reference)

    # Both masks empty: the ratio is undefined, report "no overlap".
    if total == 0:
        return 0.0

    return 2. * intersection / float(total)
83 |
84 |
def jc(result, reference):
    """
    Jaccard coefficient

    Computes the Jaccard coefficient (intersection over union) between the binary
    objects in two images.

    Parameters
    ----------
    result: array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.
    reference: array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.

    Returns
    -------
    jc: float
        The Jaccard coefficient between the object(s) in `result` and the
        object(s) in `reference`. It ranges from 0 (no overlap) to 1 (perfect overlap).
        ``0.0`` is returned when both inputs contain no object at all, matching
        the empty-input convention of :func:`dc`.

    Notes
    -----
    The measure is symmetric. The binary images can therefore be supplied in any order.
    """
    # ``numpy.bool`` does not exist in NumPy 1.24-1.26; ``numpy.bool_`` is available
    # in every release. ``numpy.asarray`` makes plain lists/tuples acceptable.
    result = numpy.atleast_1d(numpy.asarray(result).astype(numpy.bool_))
    reference = numpy.atleast_1d(numpy.asarray(reference).astype(numpy.bool_))

    intersection = numpy.count_nonzero(result & reference)
    union = numpy.count_nonzero(result | reference)

    # Two empty masks have an empty union; avoid the ZeroDivisionError and
    # report "no overlap" like the sibling overlap metrics do.
    if union == 0:
        return 0.0

    return float(intersection) / float(union)
119 |
120 |
def precision(result, reference):
    """
    Precision.

    Parameters
    ----------
    result : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.
    reference : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.

    Returns
    -------
    precision : float
        The precision between two binary datasets, here mostly binary objects in images,
        which is defined as the fraction of retrieved instances that are relevant,
        i.e. ``TP / (TP + FP)``. ``0.0`` is returned when ``result`` is empty.

    See also
    --------
    :func:`recall`

    Notes
    -----
    Not symmetric. Swapping the two arguments yields :func:`recall`.
    High precision means that most of what the algorithm returned is relevant.

    References
    ----------
    .. [1] http://en.wikipedia.org/wiki/Precision_and_recall
    .. [2] http://en.wikipedia.org/wiki/Confusion_matrix#Table_of_confusion
    """
    # ``numpy.bool`` does not exist in NumPy 1.24-1.26; ``numpy.bool_`` is available
    # in every release. ``numpy.asarray`` makes plain lists/tuples acceptable.
    result = numpy.atleast_1d(numpy.asarray(result).astype(numpy.bool_))
    reference = numpy.atleast_1d(numpy.asarray(reference).astype(numpy.bool_))

    tp = numpy.count_nonzero(result & reference)
    fp = numpy.count_nonzero(result & ~reference)

    predicted = tp + fp
    # Nothing was predicted as object: the ratio is undefined, report 0.
    if predicted == 0:
        return 0.0

    return tp / float(predicted)
167 |
168 |
def recall(result, reference):
    """
    Recall.

    Parameters
    ----------
    result : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.
    reference : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.

    Returns
    -------
    recall : float
        The recall between two binary datasets, here mostly binary objects in images,
        which is defined as the fraction of relevant instances that are retrieved,
        i.e. ``TP / (TP + FN)``. ``0.0`` is returned when ``reference`` is empty.

    See also
    --------
    :func:`precision`

    Notes
    -----
    Not symmetric. Swapping the two arguments yields :func:`precision`.
    High recall means that an algorithm returned most of the relevant results.

    References
    ----------
    .. [1] http://en.wikipedia.org/wiki/Precision_and_recall
    .. [2] http://en.wikipedia.org/wiki/Confusion_matrix#Table_of_confusion
    """
    # ``numpy.bool`` does not exist in NumPy 1.24-1.26; ``numpy.bool_`` is available
    # in every release. ``numpy.asarray`` makes plain lists/tuples acceptable.
    result = numpy.atleast_1d(numpy.asarray(result).astype(numpy.bool_))
    reference = numpy.atleast_1d(numpy.asarray(reference).astype(numpy.bool_))

    tp = numpy.count_nonzero(result & reference)
    fn = numpy.count_nonzero(~result & reference)

    relevant = tp + fn
    # The reference holds no object: the ratio is undefined, report 0.
    if relevant == 0:
        return 0.0

    return tp / float(relevant)
215 |
216 |
def sensitivity(result, reference):
    """
    Sensitivity.

    Alias of :func:`recall`; both arguments are forwarded unchanged and the
    value computed there is returned. See :func:`recall` for the details.

    See also
    --------
    :func:`specificity`
    """
    value = recall(result, reference)
    return value
227 |
228 |
def specificity(result, reference):
    """
    Specificity.

    Parameters
    ----------
    result : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.
    reference : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.

    Returns
    -------
    specificity : float
        The specificity between two binary datasets, here mostly binary objects in images,
        which denotes the fraction of correctly returned negatives, i.e.
        ``TN / (TN + FP)``. ``0.0`` is returned when ``reference`` has no background.

    See also
    --------
    :func:`sensitivity`

    Notes
    -----
    Not symmetric. The counterpart of the specificity for the positive class is
    :func:`sensitivity`.
    High specificity means that most of the background in ``reference`` was also
    labelled as background in ``result``.

    References
    ----------
    .. [1] https://en.wikipedia.org/wiki/Sensitivity_and_specificity
    .. [2] http://en.wikipedia.org/wiki/Confusion_matrix#Table_of_confusion
    """
    # ``numpy.bool`` does not exist in NumPy 1.24-1.26; ``numpy.bool_`` is available
    # in every release. ``numpy.asarray`` makes plain lists/tuples acceptable.
    result = numpy.atleast_1d(numpy.asarray(result).astype(numpy.bool_))
    reference = numpy.atleast_1d(numpy.asarray(reference).astype(numpy.bool_))

    tn = numpy.count_nonzero(~result & ~reference)
    fp = numpy.count_nonzero(result & ~reference)

    negatives = tn + fp
    # The reference has no background at all: the ratio is undefined, report 0.
    if negatives == 0:
        return 0.0

    return tn / float(negatives)
275 |
276 |
def true_negative_rate(result, reference):
    """
    True negative rate.

    Alias of :func:`specificity`; both arguments are forwarded unchanged and
    the value computed there is returned. See :func:`specificity` for details.

    See also
    --------
    :func:`true_positive_rate`
    :func:`positive_predictive_value`
    """
    tnr = specificity(result, reference)
    return tnr
288 |
289 |
def true_positive_rate(result, reference):
    """
    True positive rate.

    Alias of :func:`recall` (and therefore of :func:`sensitivity`); both
    arguments are forwarded unchanged. See :func:`recall` for details.

    See also
    --------
    :func:`positive_predictive_value`
    :func:`true_negative_rate`
    """
    tpr = recall(result, reference)
    return tpr
301 |
302 |
def positive_predictive_value(result, reference):
    """
    Positive predictive value.

    Alias of :func:`precision`; both arguments are forwarded unchanged and the
    value computed there is returned. See :func:`precision` for details.

    See also
    --------
    :func:`true_positive_rate`
    :func:`true_negative_rate`
    """
    ppv = precision(result, reference)
    return ppv
314 |
315 |
def hd(result, reference, voxelspacing=None, connectivity=1):
    """
    Hausdorff Distance.

    Computes the (symmetric) Hausdorff Distance (HD) between the binary objects in two
    images: the larger of the two directed maximum surface distances.

    Parameters
    ----------
    result : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.
    reference : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.
    voxelspacing : float or sequence of floats, optional
        The voxelspacing in a distance unit i.e. spacing of elements
        along each dimension. If a sequence, must be of length equal to
        the input rank; if a single number, this is used for all axes. If
        not specified, a grid spacing of unity is implied.
    connectivity : int
        The neighbourhood/connectivity considered when determining the surface
        of the binary objects. This value is passed to
        `scipy.ndimage.morphology.generate_binary_structure` and should usually be :math:`> 1`.
        Note that the connectivity influences the result in the case of the Hausdorff distance.

    Returns
    -------
    hd : float
        The symmetric Hausdorff Distance between the object(s) in ``result`` and the
        object(s) in ``reference``. The distance unit is the same as for the spacing of
        elements along each dimension, which is usually given in mm.

    See also
    --------
    :func:`assd`
    :func:`asd`

    Notes
    -----
    The value is symmetric. The binary images can therefore be supplied in any order.
    """
    # Largest surface distance in each direction; the HD is the worse of the two.
    forward_max = __surface_distances(result, reference, voxelspacing, connectivity).max()
    backward_max = __surface_distances(reference, result, voxelspacing, connectivity).max()
    return max(forward_max, backward_max)
362 |
363 |
def hd95(result, reference, voxelspacing=None, connectivity=1):
    """
    95th percentile of the Hausdorff Distance.

    Pools the directed surface distances of both directions and returns their
    95th percentile. Compared to the plain Hausdorff Distance (:func:`hd`) this is
    slightly more stable to small outliers and is commonly used in biomedical
    segmentation challenges.

    Parameters
    ----------
    result : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.
    reference : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.
    voxelspacing : float or sequence of floats, optional
        The voxelspacing in a distance unit i.e. spacing of elements
        along each dimension. If a sequence, must be of length equal to
        the input rank; if a single number, this is used for all axes. If
        not specified, a grid spacing of unity is implied.
    connectivity : int
        The neighbourhood/connectivity considered when determining the surface
        of the binary objects. This value is passed to
        `scipy.ndimage.morphology.generate_binary_structure` and should usually be :math:`> 1`.
        Note that the connectivity influences the result in the case of the Hausdorff distance.

    Returns
    -------
    hd95 : float
        The 95th percentile of the symmetric surface distances between the object(s)
        in ``result`` and the object(s) in ``reference``. The distance unit is the same
        as for the spacing of elements along each dimension, which is usually given in mm.

    See also
    --------
    :func:`hd`

    Notes
    -----
    The value is symmetric. The binary images can therefore be supplied in any order.
    """
    forward = __surface_distances(result, reference, voxelspacing, connectivity)
    backward = __surface_distances(reference, result, voxelspacing, connectivity)
    # Percentile over the pooled distances of both directions.
    pooled = numpy.hstack((forward, backward))
    return numpy.percentile(pooled, 95)
410 |
411 |
def assd(result, reference, voxelspacing=None, connectivity=1):
    """
    Average symmetric surface distance.

    Computes the average symmetric surface distance (ASSD) between the binary objects in
    two images as the mean of the two directed :func:`asd` values.

    Parameters
    ----------
    result : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.
    reference : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.
    voxelspacing : float or sequence of floats, optional
        The voxelspacing in a distance unit i.e. spacing of elements
        along each dimension. If a sequence, must be of length equal to
        the input rank; if a single number, this is used for all axes. If
        not specified, a grid spacing of unity is implied.
    connectivity : int
        The neighbourhood/connectivity considered when determining the surface
        of the binary objects. This value is passed to
        `scipy.ndimage.morphology.generate_binary_structure` and should usually be :math:`> 1`.
        The decision on the connectivity is important, as it can influence the results
        strongly. If in doubt, leave it as it is.

    Returns
    -------
    assd : float
        The average symmetric surface distance between the object(s) in ``result`` and the
        object(s) in ``reference``. The distance unit is the same as for the spacing of
        elements along each dimension, which is usually given in mm.

    See also
    --------
    :func:`asd`
    :func:`hd`

    Notes
    -----
    The value is symmetric: it is the mean of

    >>> asd(result, reference)

    and

    >>> asd(reference, result)

    so the binary images can be supplied in any order.
    """
    forward = asd(result, reference, voxelspacing, connectivity)
    backward = asd(reference, result, voxelspacing, connectivity)
    return numpy.mean((forward, backward))
466 |
467 |
def asd(result, reference, voxelspacing=None, connectivity=1):
    """
    Average surface distance metric.

    Computes the average surface distance (ASD) between the binary objects in two
    images: the mean of the directed surface distances.

    Parameters
    ----------
    result : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.
    reference : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.
    voxelspacing : float or sequence of floats, optional
        The voxelspacing in a distance unit i.e. spacing of elements
        along each dimension. If a sequence, must be of length equal to
        the input rank; if a single number, this is used for all axes. If
        not specified, a grid spacing of unity is implied.
    connectivity : int
        The neighbourhood/connectivity considered when determining the surface
        of the binary objects. This value is passed to
        `scipy.ndimage.morphology.generate_binary_structure` and should usually be :math:`> 1`.
        The decision on the connectivity is important, as it can influence the results
        strongly. If in doubt, leave it as it is.

    Returns
    -------
    asd : float
        The average surface distance between the object(s) in ``result`` and the
        object(s) in ``reference``. The distance unit is the same as for the spacing
        of elements along each dimension, which is usually given in mm.

    See also
    --------
    :func:`assd`
    :func:`hd`

    Notes
    -----
    This value is directed, i.e. swapping the arguments generally changes it.
    See `assd` for the symmetric variant.

    The method is implemented making use of distance images and simple binary morphology
    to achieve high computational speed.

    Examples
    --------
    The `connectivity` determines what pixels/voxels are considered the surface of a
    binary object. Take the following binary image showing a cross

    >>> from scipy.ndimage.morphology import generate_binary_structure
    >>> cross = generate_binary_structure(2, 1)
    array([[0, 1, 0],
           [1, 1, 1],
           [0, 1, 0]])

    With `connectivity` set to `1` a 4-neighbourhood is considered when determining the
    object surface, resulting in the surface

    .. code-block:: python

        array([[0, 1, 0],
               [1, 0, 1],
               [0, 1, 0]])

    Changing `connectivity` to `2`, a 8-neighbourhood is considered and we get:

    .. code-block:: python

        array([[0, 1, 0],
               [1, 1, 1],
               [0, 1, 0]])

    , as a diagonal connection no longer qualifies as valid object surface.

    This influences the results `asd` returns. Imagine we want to compute the surface
    distance of our cross to a cube-like object:

    >>> cube = generate_binary_structure(2, 2)
    array([[1, 1, 1],
           [1, 1, 1],
           [1, 1, 1]])

    , whose surface is, independent of the `connectivity` value set, always

    .. code-block:: python

        array([[1, 1, 1],
               [1, 0, 1],
               [1, 1, 1]])

    Using a `connectivity` of `1` we get

    >>> asd(cross, cube, connectivity=1)
    0.0

    while a value of `2` returns us

    >>> asd(cross, cube, connectivity=2)
    0.20000000000000001

    due to the center of the cross being considered surface as well.

    """
    distances = __surface_distances(result, reference, voxelspacing, connectivity)
    return distances.mean()
576 |
577 |
def ravd(result, reference):
    """
    Relative absolute volume difference.

    Compute the relative volume difference between the (joined) binary objects
    in the two images, ``(|result| - |reference|) / |reference|``.

    Parameters
    ----------
    result : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.
    reference : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.

    Returns
    -------
    ravd : float
        The relative volume difference between the object(s) in ``result``
        and the object(s) in ``reference``. The value lies in the range
        :math:`[-1.0, +inf)`; :math:`0` denotes an ideal score. Despite the
        function's name the sign is kept, no absolute value is taken.

    Raises
    ------
    RuntimeError
        If the reference object is empty.

    See also
    --------
    :func:`dc`
    :func:`precision`
    :func:`recall`

    Notes
    -----
    This is a directed measure. Negative values denote a smaller
    and positive values a larger volume than the reference.
    This implementation does not check, whether the two supplied arrays are of the same
    size.

    Examples
    --------
    Considering the following inputs

    >>> import numpy
    >>> arr1 = numpy.asarray([[0,1,0],[1,1,1],[0,1,0]])
    >>> arr1
    array([[0, 1, 0],
           [1, 1, 1],
           [0, 1, 0]])
    >>> arr2 = numpy.asarray([[0,1,0],[1,0,1],[0,1,0]])
    >>> arr2
    array([[0, 1, 0],
           [1, 0, 1],
           [0, 1, 0]])

    comparing `arr1` (5 voxels) to `arr2` (4 voxels) we get

    >>> ravd(arr1, arr2)
    0.25

    and reversing the inputs the directedness of the measure becomes evident

    >>> ravd(arr2, arr1)
    -0.2

    It is important to keep in mind that a perfect score of `0` does not mean that the
    binary objects fit exactly, as only the volumes are compared:

    >>> arr1 = numpy.asarray([1,0,0])
    >>> arr2 = numpy.asarray([0,0,1])
    >>> ravd(arr1, arr2)
    0.0

    """
    # ``numpy.bool`` does not exist in NumPy 1.24-1.26; ``numpy.bool_`` is available
    # in every release. ``numpy.asarray`` makes plain lists/tuples acceptable.
    result = numpy.atleast_1d(numpy.asarray(result).astype(numpy.bool_))
    reference = numpy.atleast_1d(numpy.asarray(reference).astype(numpy.bool_))

    vol1 = numpy.count_nonzero(result)
    vol2 = numpy.count_nonzero(reference)

    # The reference volume is the denominator, so it must not be empty.
    if vol2 == 0:
        raise RuntimeError('The second supplied array does not contain any binary object.')

    return (vol1 - vol2) / float(vol2)
664 |
665 |
def volume_correlation(results, references):
    r"""
    Volume correlation.

    Computes the linear correlation in binary object volume between the
    contents of the successive binary images supplied. Measured through
    the Pearson product-moment correlation coefficient.

    Parameters
    ----------
    results : sequence of array_like
        Ordered list of input data containing objects. Each array_like will be
        converted into binary: background where 0, object everywhere else.
    references : sequence of array_like
        Ordered list of input data containing objects. Each array_like will be
        converted into binary: background where 0, object everywhere else.
        The order must be the same as for ``results``.

    Returns
    -------
    r : float
        The correlation coefficient between -1 and 1.
    p : float
        The two-side p value.

    """
    # ``numpy.bool`` does not exist in NumPy 1.24-1.26; ``numpy.bool_`` is available
    # in every release.
    results = numpy.atleast_2d(numpy.array(results).astype(numpy.bool_))
    references = numpy.atleast_2d(numpy.array(references).astype(numpy.bool_))

    # One volume (count of object voxels) per image along the first axis.
    results_volumes = [numpy.count_nonzero(r) for r in results]
    references_volumes = [numpy.count_nonzero(r) for r in references]

    # returns (Pearson's correlation coefficient, 2-tailed p-value)
    return pearsonr(results_volumes, references_volumes)
699 |
700 |
def volume_change_correlation(results, references):
    r"""
    Volume change correlation.

    Computes the linear correlation of change in binary object volume between
    the contents of the successive binary images supplied. Measured through
    the Pearson product-moment correlation coefficient.

    Parameters
    ----------
    results : sequence of array_like
        Ordered list of input data containing objects. Each array_like will be
        converted into binary: background where 0, object everywhere else.
    references : sequence of array_like
        Ordered list of input data containing objects. Each array_like will be
        converted into binary: background where 0, object everywhere else.
        The order must be the same as for ``results``.

    Returns
    -------
    r : float
        The correlation coefficient between -1 and 1.
    p : float
        The two-side p value.

    """
    # ``numpy.bool`` does not exist in NumPy 1.24-1.26; ``numpy.bool_`` is available
    # in every release.
    results = numpy.atleast_2d(numpy.array(results).astype(numpy.bool_))
    references = numpy.atleast_2d(numpy.array(references).astype(numpy.bool_))

    # One volume (count of object voxels) per image along the first axis.
    results_volumes = numpy.asarray([numpy.count_nonzero(r) for r in results])
    references_volumes = numpy.asarray([numpy.count_nonzero(r) for r in references])

    # Volume change between each image and its successor.
    results_volumes_changes = results_volumes[1:] - results_volumes[:-1]
    references_volumes_changes = references_volumes[1:] - references_volumes[:-1]

    # returns (Pearson's correlation coefficient, 2-tailed p-value)
    return pearsonr(results_volumes_changes, references_volumes_changes)
738 |
739 |
def obj_assd(result, reference, voxelspacing=None, connectivity=1):
    """
    Average symmetric surface distance between objects.

    Computes the average symmetric surface distance (ASSD) between the binary objects in
    two images as the mean of the two directed :func:`obj_asd` values.

    Parameters
    ----------
    result : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.
    reference : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.
    voxelspacing : float or sequence of floats, optional
        The voxelspacing in a distance unit i.e. spacing of elements
        along each dimension. If a sequence, must be of length equal to
        the input rank; if a single number, this is used for all axes. If
        not specified, a grid spacing of unity is implied.
    connectivity : int
        The neighbourhood/connectivity considered when determining what accounts
        for a distinct binary object as well as when determining the surface
        of the binary objects. This value is passed to
        `scipy.ndimage.morphology.generate_binary_structure` and should usually be :math:`> 1`.
        The decision on the connectivity is important, as it can influence the results
        strongly. If in doubt, leave it as it is.

    Returns
    -------
    assd : float
        The average symmetric surface distance between all mutually existing distinct
        binary object(s) in ``result`` and ``reference``. The distance unit is the same as for
        the spacing of elements along each dimension, which is usually given in mm.

    See also
    --------
    :func:`obj_asd`

    Notes
    -----
    The value is symmetric: it is the mean of

    >>> obj_asd(result, reference)

    and

    >>> obj_asd(reference, result)

    so the binary images can be supplied in any order.
    """
    forward = obj_asd(result, reference, voxelspacing, connectivity)
    backward = obj_asd(reference, result, voxelspacing, connectivity)
    return numpy.mean((forward, backward))
794 |
795 |
def obj_asd(result, reference, voxelspacing=None, connectivity=1):
    """
    Average surface distance between objects.

    First correspondences between distinct binary objects in reference and result are
    established. Then the average surface distance is only computed between corresponding
    objects. Correspondence is defined as unique and at least one voxel overlap.

    Parameters
    ----------
    result : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.
    reference : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.
    voxelspacing : float or sequence of floats, optional
        The voxelspacing in a distance unit i.e. spacing of elements
        along each dimension. If a sequence, must be of length equal to
        the input rank; if a single number, this is used for all axes. If
        not specified, a grid spacing of unity is implied.
    connectivity : int
        The neighbourhood/connectivity considered when determining what accounts
        for a distinct binary object as well as when determining the surface
        of the binary objects. This value is passed to
        `scipy.ndimage.morphology.generate_binary_structure` and should usually be :math:`> 1`.
        The decision on the connectivity is important, as it can influence the results
        strongly. If in doubt, leave it as it is.

    Returns
    -------
    asd : float
        The average surface distance between all mutually existing distinct binary
        object(s) in ``result`` and ``reference``. The distance unit is the same as for the
        spacing of elements along each dimension, which is usually given in mm.

    See also
    --------
    :func:`obj_assd`
    :func:`obj_tpr`
    :func:`obj_fpr`

    Notes
    -----
    This value is directed. See `obj_assd` for the symmetric variant.

    For the understanding of this metric, both the notions of connectedness and surface
    distance are essential. Please see :func:`obj_tpr` and :func:`obj_fpr` for more
    information on the first and :func:`asd` on the second.

    Examples
    --------
    >>> arr1 = numpy.asarray([[1,1,1],[1,1,1],[1,1,1]])
    >>> arr2 = numpy.asarray([[0,1,0],[0,1,0],[0,1,0]])
    >>> obj_asd(arr1, arr2)
    1.5
    >>> obj_asd(arr2, arr1)
    0.333333333333

    With the `voxelspacing` parameter, the distances between the voxels can be set for
    each dimension separately:

    >>> obj_asd(arr1, arr2, voxelspacing=(1,2))
    1.5
    >>> obj_asd(arr2, arr1, voxelspacing=(1,2))
    0.333333333333

    More examples depicting the notion of object connectedness:

    >>> arr1 = numpy.asarray([[1,0,1],[1,0,0],[0,0,0]])
    >>> arr2 = numpy.asarray([[1,0,1],[1,0,0],[0,0,1]])
    >>> obj_asd(arr1, arr2)
    0.0
    >>> obj_asd(arr2, arr1)
    0.0

    >>> arr1 = numpy.asarray([[1,0,1],[1,0,1],[0,0,1]])
    >>> arr2 = numpy.asarray([[1,0,1],[1,0,0],[0,0,1]])
    >>> obj_asd(arr1, arr2)
    0.6
    >>> obj_asd(arr2, arr1)
    0.0

    Influence of `connectivity` parameter can be seen in the following example, where
    with the (default) connectivity of `1` the first array is considered to contain two
    objects, while with an increase connectivity of `2`, just one large object is
    detected.

    >>> arr1 = numpy.asarray([[1,0,0],[0,1,1],[0,1,1]])
    >>> arr2 = numpy.asarray([[1,0,0],[0,0,0],[0,0,0]])
    >>> obj_asd(arr1, arr2)
    0.0
    >>> obj_asd(arr1, arr2, connectivity=2)
    1.742955328

    Note that the connectivity also influences the notion of what is considered an object
    surface voxel.
    """
    # Label both images and obtain the mapping {label in labelmap2: label in labelmap1}
    # of corresponding objects; the two object counts returned are not needed here.
    labels_a, labels_b, _count_a, _count_b, pairs = __distinct_binary_object_correspondences(
        result, reference, connectivity)
    boxes_a = find_objects(labels_a)
    boxes_b = find_objects(labels_b)

    distances = []
    for label_b, label_a in list(pairs.items()):
        # Restrict the computation to a window built from both objects' bounding slices
        # (labels are 1-based, the slice lists are 0-based).
        roi = __combine_windows(boxes_a[label_a - 1], boxes_b[label_b - 1])
        mask_a = labels_a[roi] == label_a
        mask_b = labels_b[roi] == label_b
        distances.extend(__surface_distances(mask_a, mask_b, voxelspacing, connectivity))

    return numpy.mean(distances)
937 |
938 |
def obj_fpr(result, reference, connectivity=1):
    """
    Share of distinct binary objects in ``reference`` left without a partner in ``result``.

    Both inputs are split into connected binary objects. Every object of ``reference``
    is then paired with at most one overlapping object of ``result``; an overlap of a
    single voxel is sufficient, and each object of ``result`` can serve as a partner
    only once. The returned value is the fraction of ``reference`` objects that did
    not receive a partner.

    Parameters
    ----------
    result : numpy.ndarray
        Input data containing objects. Converted into binary: background where 0,
        object everywhere else.
    reference : numpy.ndarray
        Input data containing objects. Converted into binary: background where 0,
        object everywhere else.
    connectivity : int
        The neighbourhood/connectivity considered when determining what accounts
        for a distinct binary object. This value is passed on to scipy's
        `generate_binary_structure`. The choice influences how many distinct
        objects are found and therefore the score.

    Returns
    -------
    fpr : float
        ``(n_reference - n_paired) / n_reference``, where ``n_reference`` is the
        number of distinct objects in ``reference`` and ``n_paired`` the number of
        them that were paired with an object of ``result``. The value lies in
        :math:`[0, 1]`; :math:`0` means every ``reference`` object found a partner.

    Raises
    ------
    ZeroDivisionError
        Presumably, if ``reference`` contains no object: the object count is used
        as divisor without any guard.

    See also
    --------
    :func:`obj_tpr`

    Notes
    -----
    The measure is directed, i.e. swapping the two arguments generally changes the
    outcome.

    NOTE(review): earlier documentation of this function described the score partly
    as "objects of ``result`` without a partner in ``reference``" and listed doctest
    outputs with the directed cases swapped. The description and the examples here
    follow what the code computes; confirm which definition callers expect.

    Examples
    --------
    >>> arr2 = numpy.asarray([[1,0,0],[1,0,1],[0,0,1]])
    >>> arr1 = numpy.asarray([[0,0,1],[1,0,1],[0,0,1]])
    >>> obj_fpr(arr1, arr2)
    0.0
    >>> obj_fpr(arr2, arr1)
    0.0

    Example of directedness:

    >>> arr2 = numpy.asarray([1,0,1,0,1])
    >>> arr1 = numpy.asarray([1,0,1,0,0])
    >>> obj_fpr(arr1, arr2)
    0.3333333333333333
    >>> obj_fpr(arr2, arr1)
    0.0
    """
    correspondences = __distinct_binary_object_correspondences(reference, result, connectivity)
    n_obj_reference = correspondences[3]  # number of distinct objects in ``reference``
    n_paired = len(correspondences[4])  # reference objects that received a partner
    n_unpaired = n_obj_reference - n_paired
    return n_unpaired / float(n_obj_reference)
1049 |
1050 |
def obj_tpr(result, reference, connectivity=1):
    """
    Share of distinct binary objects in ``result`` that are paired with an object of ``reference``.

    Both inputs are split into connected binary objects. Every object of ``reference``
    is then paired with at most one overlapping object of ``result``; an overlap of a
    single voxel is sufficient, and each object of ``result`` can serve as a partner
    only once. The returned value is the number of pairs divided by the number of
    distinct objects in ``result``.

    Parameters
    ----------
    result : numpy.ndarray
        Input data containing objects. Converted into binary: background where 0,
        object everywhere else.
    reference : numpy.ndarray
        Input data containing objects. Converted into binary: background where 0,
        object everywhere else.
    connectivity : int
        The neighbourhood/connectivity considered when determining what accounts
        for a distinct binary object. This value is passed on to scipy's
        `generate_binary_structure`. The choice influences how many distinct
        objects are found and therefore the score.

    Returns
    -------
    tpr : float
        ``n_paired / n_result``, where ``n_result`` is the number of distinct objects
        in ``result`` and ``n_paired`` the number of one-to-one pairs found. The value
        lies in :math:`[0, 1]`; :math:`1` means every ``result`` object is paired.

    Raises
    ------
    ZeroDivisionError
        Presumably, if ``result`` contains no object: the object count is used as
        divisor without any guard.

    See also
    --------
    :func:`obj_fpr`

    Notes
    -----
    The measure is directed, i.e. swapping the two arguments generally changes the
    outcome.

    NOTE(review): earlier documentation of this function described the score partly
    as "objects of ``reference`` that also exist in ``result``" and listed doctest
    outputs with the directed cases swapped. The description and the examples here
    follow what the code computes; confirm which definition callers expect.

    Examples
    --------
    >>> arr2 = numpy.asarray([[1,0,0],[1,0,1],[0,0,1]])
    >>> arr1 = numpy.asarray([[0,0,1],[1,0,1],[0,0,1]])
    >>> obj_tpr(arr1, arr2)
    1.0
    >>> obj_tpr(arr2, arr1)
    1.0

    Example of directedness:

    >>> arr2 = numpy.asarray([1,0,1,0,1])
    >>> arr1 = numpy.asarray([1,0,1,0,0])
    >>> obj_tpr(arr1, arr2)
    1.0
    >>> obj_tpr(arr2, arr1)
    0.6666666666666666
    """
    correspondences = __distinct_binary_object_correspondences(reference, result, connectivity)
    n_obj_result = correspondences[2]  # number of distinct objects in ``result``
    n_paired = len(correspondences[4])  # one-to-one pairs found
    return n_paired / float(n_obj_result)
1160 |
1161 |
def __distinct_binary_object_correspondences(reference, result, connectivity=1):
    """
    Label the distinct binary objects of both inputs and pair them one-to-one.

    Objects are the connected components under the structuring element produced by
    scipy's `generate_binary_structure` for the given ``connectivity``. An object of
    ``reference`` corresponds to an object of ``result`` as soon as they share a
    single voxel. Because this relationship is many-to-many, it is reduced to a
    one-to-one mapping: each ``reference`` object is assigned at most one ``result``
    object and each ``result`` object is used at most once.

    Parameters
    ----------
    reference : array_like
        Converted into binary: background where 0, object everywhere else.
    result : array_like
        Converted into binary: background where 0, object everywhere else.
    connectivity : int
        Passed to `generate_binary_structure`.

    Returns
    -------
    labelmap1 : ndarray
        Label map of ``result``.
    labelmap2 : ndarray
        Label map of ``reference``.
    n_obj_result : int
        Number of distinct objects in ``result``.
    n_obj_reference : int
        Number of distinct objects in ``reference``.
    mapping : dict
        Maps label ids of ``labelmap2`` (reference) to the label id of their
        assigned partner in ``labelmap1`` (result).
    """
    # Builtin ``bool`` is used because the ``numpy.bool`` alias is missing from some
    # NumPy releases; ``asarray`` additionally admits plain sequences as input.
    result = numpy.atleast_1d(numpy.asarray(result).astype(bool))
    reference = numpy.atleast_1d(numpy.asarray(reference).astype(bool))

    # binary structure
    footprint = generate_binary_structure(result.ndim, connectivity)

    # label distinct binary objects
    labelmap1, n_obj_result = label(result, footprint)
    labelmap2, n_obj_reference = label(reference, footprint)

    mapping = dict()  # labelmap2 id -> assigned labelmap1 id
    used_labels = set()  # labelmap1 ids that already serve as a partner
    one_to_many = list()  # (labelmap2 id, set of overlapping labelmap1 ids)

    # First pass: walk over the objects of labelmap2 and their bounding windows.
    # Unambiguous overlaps are fixed immediately, ambiguous ones are deferred.
    for l2id, slicer in enumerate(find_objects(labelmap2), start=1):  # label ids start at 1
        bobj = labelmap2[slicer] == l2id  # mask of the current object inside its window
        l1ids = numpy.unique(labelmap1[slicer][bobj])  # labelmap1 ids found under the mask
        l1ids = l1ids[l1ids != 0]  # drop background
        if len(l1ids) == 1:
            l1id = l1ids[0]
            if l1id not in used_labels:
                mapping[l2id] = l1id
                used_labels.add(l1id)
        elif len(l1ids) > 1:
            one_to_many.append((l2id, set(l1ids)))

    # Second pass: resolve the deferred objects, always serving the one with the
    # fewest remaining candidates first.
    while True:
        # remove already used ids from all candidate sets, then drop exhausted entries
        one_to_many = [(l2id, l1ids - used_labels) for l2id, l1ids in one_to_many]
        one_to_many = [entry for entry in one_to_many if entry[1]]
        one_to_many.sort(key=lambda entry: len(entry[1]))
        if not one_to_many:
            break
        l2id, candidates = one_to_many[0]
        l1id = candidates.pop()  # arbitrary candidate from the shortest set
        mapping[l2id] = l1id
        used_labels.add(l1id)
        one_to_many = one_to_many[1:]  # the served object is done

    return labelmap1, labelmap2, n_obj_result, n_obj_reference, mapping
1218 |
1219 |
def __surface_distances(result, reference, voxelspacing=None, connectivity=1):
    """
    Distances from the surface voxels of ``result`` to the surface of ``reference``.

    For every surface voxel of the binary object(s) in ``result``, the Euclidean
    distance to the nearest surface voxel of the binary object(s) in ``reference``
    is determined.

    Parameters
    ----------
    result : array_like
        Converted into binary: background where 0, object everywhere else.
    reference : array_like
        Converted into binary: background where 0, object everywhere else.
    voxelspacing : float or sequence of floats, optional
        Spacing of the elements along each dimension, forwarded as ``sampling`` to
        `distance_transform_edt`. A single number is expanded to all dimensions.
        With ``None``, `distance_transform_edt` falls back to its own default.
    connectivity : int
        Passed to `generate_binary_structure`; the resulting structuring element is
        used for the erosion that defines the one-voxel surface.

    Returns
    -------
    sds : ndarray
        One distance per surface voxel of ``result``.

    Raises
    ------
    RuntimeError
        If ``result`` or ``reference`` does not contain any binary object.
    """
    # Builtin ``bool`` is used because the ``numpy.bool`` alias is missing from some
    # NumPy releases; ``asarray`` additionally admits plain sequences as input.
    result = numpy.atleast_1d(numpy.asarray(result).astype(bool))
    reference = numpy.atleast_1d(numpy.asarray(reference).astype(bool))
    if voxelspacing is not None:
        # expand a scalar spacing to one value per dimension
        voxelspacing = _ni_support._normalize_sequence(voxelspacing, result.ndim)
        voxelspacing = numpy.asarray(voxelspacing, dtype=numpy.float64)
        if not voxelspacing.flags.contiguous:
            voxelspacing = voxelspacing.copy()

    # binary structure
    footprint = generate_binary_structure(result.ndim, connectivity)

    # test for emptiness
    if 0 == numpy.count_nonzero(result):
        raise RuntimeError('The first supplied array does not contain any binary object.')
    if 0 == numpy.count_nonzero(reference):
        raise RuntimeError('The second supplied array does not contain any binary object.')

    # extract only the 1-voxel border of the objects: object minus its erosion
    result_border = result ^ binary_erosion(result, structure=footprint, iterations=1)
    reference_border = reference ^ binary_erosion(reference, structure=footprint, iterations=1)

    # scipy's distance transform yields, for foreground elements, the distance to
    # the nearest background element; inverting the reference border therefore
    # gives every voxel its distance to the nearest reference surface voxel
    dt = distance_transform_edt(~reference_border, sampling=voxelspacing)

    # keep the distances at the surface voxels of ``result`` only
    return dt[result_border]
1253 |
1254 |
def __combine_windows(w1, w2):
    """
    Merge two windows (tuples of slices) into one window spanning both.

    Per dimension, the merged slice runs from the smaller of the two starts to the
    larger of the two stops. The windows are walked in lockstep, so the returned
    tuple has as many slices as the shorter of the two inputs.
    """
    return tuple(
        slice(min(first.start, second.start), max(first.stop, second.stop))
        for first, second in zip(w1, w2)
    )
--------------------------------------------------------------------------------