├── README.md
├── data
│   ├── image_drive.png
│   ├── image_drive_added_noise.png
│   ├── image_vessap.npy
│   ├── skeleton_drive.png
│   ├── skeleton_drive_added_noise.png
│   ├── skeleton_drive_added_noise_20steps.png
│   └── skeleton_vessap.npy
├── demo.py
├── graphical_abstract.png
├── requirements.txt
└── skeletonize.py

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# A skeletonization algorithm for gradient-based optimization

Accompanying code for the paper *A skeletonization algorithm for gradient-based optimization*, presented at the International Conference on Computer Vision (ICCV) 2023 [1].

## Introduction

The skeleton of a digital image is a compact representation of its topology, geometry, and scale. It has utility in many computer vision applications, such as image description, segmentation, and registration. However, skeletonization has seen only limited use in contemporary deep learning solutions. Most existing skeletonization algorithms are not differentiable, making it impossible to integrate them with gradient-based optimization. Compatible algorithms based on morphological operations and neural networks have been proposed, but their results often deviate from the geometry and topology of the true medial axis. Our work introduces the first three-dimensional skeletonization algorithm that is both compatible with gradient-based optimization and preserves an object's topology (see figure) [1].

![](./graphical_abstract.png)

At its core, our method is an iterative boundary-peeling algorithm that repeatedly removes simple points until only the skeleton remains. Simple points are identified using one of two strategies: one relies on the calculation of the Euler characteristic [2], while the other is based on a set of Boolean rules that evaluate a point's 26-neighborhood [3]. Additionally, we adopt a scheme to safely delete multiple simple points at once, enabling the parallelization of our algorithm, and we introduce a strategy to apply the algorithm to non-binary inputs by employing the reparametrization trick and a straight-through estimator. The resulting method is based exclusively on matrix additions and multiplications, convolutional operations, basic non-linear functions, and sampling from a uniform probability distribution, allowing it to be easily implemented in PyTorch or any other major deep learning library.


## Getting started

The full skeletonization algorithm is contained in `skeletonize.py`. It is implemented as a [PyTorch Module](https://pytorch.org/docs/stable/generated/torch.nn.Module.html) and inherits all of its functionality, including compatibility with PyTorch's automatic differentiation engine and the ability to run on graphics processing units. As such, the skeletonization algorithm can be integrated into neural network architectures and loss functions just like a convolutional or pooling layer. `demo.py` demonstrates the functionality on a two- and a three-dimensional input from the DRIVE [4] and VesSAP [5] datasets, respectively. A minimal usage sketch is shown below.
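
The following sketch illustrates how the module can be embedded in a differentiable pipeline. The tensors `prediction` and `target_skeleton` are hypothetical stand-ins for a network's soft segmentation output and a reference skeleton:

```
import torch
import torch.nn.functional as F
from skeletonize import Skeletonize

# Hypothetical stand-ins: a soft segmentation map with values in [0, 1]
# and a binary reference skeleton, both with batch and channel dimensions
prediction = torch.rand(1, 1, 64, 64, requires_grad=True)
target_skeleton = (torch.rand(1, 1, 64, 64) > 0.95).float()

skeletonization_module = Skeletonize(probabilistic=True, beta=0.33, tau=1.0, simple_point_detection='Boolean')
skeleton = skeletonization_module(prediction)

# Any differentiable loss on the skeleton propagates gradients back to the prediction
loss = F.mse_loss(skeleton, target_skeleton)
loss.backward()
```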

## Citation

To cite this work, please use the following BibTeX entry:

```
@inproceedings{Menten2023skeletonization,
    title={A skeletonization algorithm for gradient-based optimization},
    author={Menten, Martin J. and Paetzold, Johannes C. and Zimmer, Veronika A. and Shit, Suprosanna and Ezhov, Ivan and Holland, Robbie and Probst, Monika and Schnabel, Julia A. and Rueckert, Daniel},
    booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
    month={October},
    year={2023}
}
```

## References

[1] Martin J. Menten et al. A skeletonization algorithm for gradient-based optimization. Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV), 2023.

[2] Steven Lobregt et al. Three-dimensional skeletonization: principle and algorithm. IEEE Transactions on Pattern Analysis and Machine Intelligence, 2(1):75-77, 1980.

[3] Gilles Bertrand. A Boolean characterization of three-dimensional simple points. Pattern Recognition Letters, 17(2):115-124, 1996.

[4] https://drive.grand-challenge.org

[5] https://www.discotechnologies.org/VesSAP

--------------------------------------------------------------------------------
/data/image_drive.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/martinmenten/skeletonization-for-gradient-based-optimization/43c3d3e419598d42d57c0fd39107e0dc26f0982b/data/image_drive.png
--------------------------------------------------------------------------------
/data/image_drive_added_noise.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/martinmenten/skeletonization-for-gradient-based-optimization/43c3d3e419598d42d57c0fd39107e0dc26f0982b/data/image_drive_added_noise.png
--------------------------------------------------------------------------------
/data/image_vessap.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/martinmenten/skeletonization-for-gradient-based-optimization/43c3d3e419598d42d57c0fd39107e0dc26f0982b/data/image_vessap.npy
--------------------------------------------------------------------------------
/data/skeleton_drive.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/martinmenten/skeletonization-for-gradient-based-optimization/43c3d3e419598d42d57c0fd39107e0dc26f0982b/data/skeleton_drive.png
--------------------------------------------------------------------------------
/data/skeleton_drive_added_noise.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/martinmenten/skeletonization-for-gradient-based-optimization/43c3d3e419598d42d57c0fd39107e0dc26f0982b/data/skeleton_drive_added_noise.png
--------------------------------------------------------------------------------
/data/skeleton_drive_added_noise_20steps.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/martinmenten/skeletonization-for-gradient-based-optimization/43c3d3e419598d42d57c0fd39107e0dc26f0982b/data/skeleton_drive_added_noise_20steps.png
--------------------------------------------------------------------------------
/data/skeleton_vessap.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/martinmenten/skeletonization-for-gradient-based-optimization/43c3d3e419598d42d57c0fd39107e0dc26f0982b/data/skeleton_vessap.npy
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
import imageio
import numpy as np
import torch

from skeletonize import Skeletonize


# Two-dimensional example from the DRIVE dataset
img = imageio.imread('data/image_drive.png') / 255.
img = torch.tensor(img, dtype=torch.float32).unsqueeze(0).unsqueeze(0)

skeletonization_module = Skeletonize(probabilistic=False, simple_point_detection='Boolean')
skeleton = skeletonization_module(img)

skeleton = skeleton.numpy().squeeze() * 255
imageio.imwrite('data/skeleton_drive.png', skeleton.astype(np.uint8))


# Same example with added uniform noise to demonstrate skeletonization of a non-binary input
img = imageio.imread('data/image_drive_added_noise.png') / 255.
img = torch.tensor(img, dtype=torch.float32).unsqueeze(0).unsqueeze(0)

skeletonization_module = Skeletonize(probabilistic=True, beta=0.33, tau=1.0, simple_point_detection='Boolean')
skeleton = skeletonization_module(img)

skeleton = skeleton.numpy().squeeze() * 255
imageio.imwrite('data/skeleton_drive_added_noise.png', skeleton.astype(np.uint8))


# Apply the skeletonization module multiple times (as is commonly done during gradient-based optimization)
# so that the averaged output converges towards the true skeleton
img = imageio.imread('data/image_drive_added_noise.png') / 255.
img = torch.tensor(img, dtype=torch.float32).unsqueeze(0).unsqueeze(0)

skeletonization_module = Skeletonize(probabilistic=True, beta=0.33, tau=1.0, simple_point_detection='Boolean')
skeleton_stack = np.zeros_like(img.numpy().squeeze())
for step in range(20):
    skeleton_stack = skeleton_stack + skeletonization_module(img).numpy().squeeze()
skeleton = (skeleton_stack / 20).round()

skeleton = skeleton * 255
imageio.imwrite('data/skeleton_drive_added_noise_20steps.png', skeleton.astype(np.uint8))


# Three-dimensional example from the VesSAP dataset
img = np.load('data/image_vessap.npy')
img = torch.tensor(img, dtype=torch.float32).unsqueeze(0).unsqueeze(0)

skeletonization_module = Skeletonize(probabilistic=False, simple_point_detection='Boolean', num_iter=10)
skeleton = skeletonization_module(img)

skeleton = skeleton.numpy().squeeze()
np.save('data/skeleton_vessap.npy', skeleton)
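

# Illustrative sketch (an addition, not part of the original demo): since the module
# is differentiable, gradients with respect to a non-binary input can be inspected.
img = imageio.imread('data/image_drive_added_noise.png') / 255.
img = torch.tensor(img, dtype=torch.float32, requires_grad=True)

skeletonization_module = Skeletonize(probabilistic=True, beta=0.33, tau=1.0, simple_point_detection='Boolean')
skeleton = skeletonization_module(img.unsqueeze(0).unsqueeze(0))

# Backpropagate an arbitrary scalar function of the skeleton; in a real application
# this would be a loss term, e.g. a similarity to a target skeleton
skeleton.sum().backward()
print('Gradient norm with respect to the input:', img.grad.norm().item())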
--------------------------------------------------------------------------------
/graphical_abstract.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/martinmenten/skeletonization-for-gradient-based-optimization/43c3d3e419598d42d57c0fd39107e0dc26f0982b/graphical_abstract.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
imageio==2.9.0
numpy==1.23.4
torch==1.9.0
--------------------------------------------------------------------------------
/skeletonize.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


class Skeletonize(torch.nn.Module):
    """
    Class based on PyTorch's Module class to skeletonize two- or three-dimensional input images
    while being fully compatible with PyTorch's autograd automatic differentiation engine as proposed in [1].

    Attributes:
        probabilistic: a Boolean that indicates whether the input image should be binarized using
                       the reparametrization trick and a straight-through estimator.
                       It should always be set to True if non-binary inputs are being provided.
        beta: scale of the logistic noise added during the reparametrization trick. If too small, there will not be
              any learning via gradient-based optimization; if too large, the learning is very slow.
        tau: Boltzmann temperature for the reparametrization trick.
        simple_point_detection: decides whether simple points should be identified using the Boolean characterization of their 26-neighborhood (Boolean) [2]
                                or by checking whether the Euler characteristic changes under their deletion (EulerCharacteristic) [3].
        num_iter: number of iterations, each of which includes one end-point check, eight checks for simple points and eight subsequent deletions.
                  The number of iterations should be tuned to the type of input image.

    [1] Martin J. Menten et al. A skeletonization algorithm for gradient-based optimization.
        Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV), 2023.
    [2] Gilles Bertrand. A Boolean characterization of three-dimensional simple points.
        Pattern Recognition Letters, 17(2):115-124, 1996.
    [3] Steven Lobregt et al. Three-dimensional skeletonization: principle and algorithm.
        IEEE Transactions on Pattern Analysis and Machine Intelligence, 2(1):75-77, 1980.
    """

    def __init__(self, probabilistic=True, beta=0.33, tau=1.0, simple_point_detection='Boolean', num_iter=5):

        super(Skeletonize, self).__init__()

        self.probabilistic = probabilistic
        self.tau = tau
        self.beta = beta

        self.num_iter = num_iter
        self.endpoint_check = self._single_neighbor_check
        if simple_point_detection == 'Boolean':
            self.simple_check = self._boolean_simple_check
        elif simple_point_detection == 'EulerCharacteristic':
            self.simple_check = self._euler_characteristic_simple_check
        else:
            raise Exception("simple_point_detection must be either 'Boolean' or 'EulerCharacteristic'.")


    def forward(self, img):

        img = self._prepare_input(img)

        if self.probabilistic:
            img = self._stochastic_discretization(img)

        for current_iter in range(self.num_iter):

            # At each iteration create a new map of the end-points
            is_endpoint = self.endpoint_check(img)

            # Sub-iterate through eight different subfields so that multiple simple points
            # can be deleted at once without altering the object's topology
            x_offsets = [0, 1, 0, 1, 0, 1, 0, 1]
            y_offsets = [0, 0, 1, 1, 0, 0, 1, 1]
            z_offsets = [0, 0, 0, 0, 1, 1, 1, 1]

            for x_offset, y_offset, z_offset in zip(x_offsets, y_offsets, z_offsets):

                # At each sub-iteration detect all simple points and delete all simple points that are not end-points
                is_simple = self.simple_check(img[:, :, x_offset:, y_offset:, z_offset:])
                deletion_candidates = is_simple * (1 - is_endpoint[:, :, x_offset::2, y_offset::2, z_offset::2])
                img[:, :, x_offset::2, y_offset::2, z_offset::2] = torch.min(img[:, :, x_offset::2, y_offset::2, z_offset::2].clone(), 1 - deletion_candidates)

        img = self._prepare_output(img)

        return img


    def _prepare_input(self, img):
        """
        Function to check that the input image is compatible with the subsequent calculations.
        Only two- and three-dimensional images with values between 0 and 1 are supported.
        If the input image is two-dimensional, it is converted into a three-dimensional one for further processing.
        """

        if img.dim() == 5:
            self.expanded_dims = False
        elif img.dim() == 4:
            self.expanded_dims = True
            img = img.unsqueeze(2)
        else:
            raise Exception("Only two- or three-dimensional images (tensor dimensionality of 4 or 5) are supported as input.")

        if img.shape[2] == 2 or img.shape[3] == 2 or img.shape[4] == 2 or img.shape[3] == 1 or img.shape[4] == 1:
            raise Exception("All spatial dimensions must measure at least 3, except the first, which may also be 1 for two-dimensional inputs.")

        if img.min() < 0.0 or img.max() > 1.0:
            raise Exception("Image values must lie between 0 and 1.")

        img = F.pad(img, (1, 1, 1, 1, 1, 1), value=0)

        return img


    def _stochastic_discretization(self, img):
        """
        Function to binarize the image so that it can be processed by our skeletonization method.
        In order to remain compatible with backpropagation, we utilize the reparametrization trick and a straight-through estimator.
        """

        alpha = (img + 1e-8) / (1.0 - img + 1e-8)

        uniform_noise = torch.empty_like(img).uniform_(1e-8, 1 - 1e-8)
        logistic_noise = torch.log(uniform_noise) - torch.log(1 - uniform_noise)

        img = torch.sigmoid((torch.log(alpha) + logistic_noise * self.beta) / self.tau)
        img = (img.detach() > 0.5).float() - img.detach() + img

        return img
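
    # Note on the straight-through estimator above (comment added for clarity): the expression
    # (img.detach() > 0.5).float() - img.detach() + img evaluates to the hard, binarized image
    # in the forward pass, while its gradient is that of the identity function, so during
    # backpropagation the soft image receives the gradients of the binarized one.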


    def _single_neighbor_check(self, img):
        """
        Function that characterizes points as end-points if they have a single neighbor or no neighbor at all.
        """

        img = F.pad(img, (1, 1, 1, 1, 1, 1))

        # Check that the number of ones in the 26-neighborhood is exactly 0 or 1
        K = torch.tensor([[[1.0, 1.0, 1.0],
                           [1.0, 1.0, 1.0],
                           [1.0, 1.0, 1.0]],
                          [[1.0, 1.0, 1.0],
                           [1.0, 0.0, 1.0],
                           [1.0, 1.0, 1.0]],
                          [[1.0, 1.0, 1.0],
                           [1.0, 1.0, 1.0],
                           [1.0, 1.0, 1.0]]], device=img.device).view(1, 1, 3, 3, 3)

        num_twentysix_neighbors = F.conv3d(img, K)
        condition1 = F.hardtanh(-(num_twentysix_neighbors - 2), min_val=0, max_val=1)  # 1 or fewer neighbors

        return condition1


    def _boolean_simple_check(self, img):
        """
        Function that identifies simple points using the Boolean conditions introduced by Bertrand [1].
        Each Boolean condition can be assessed via convolutions with a limited number of pre-defined kernels.
        In total, four conditions are checked. If any one of them is fulfilled, the point is deemed simple.

        [1] Gilles Bertrand. A Boolean characterization of three-dimensional simple points.
            Pattern Recognition Letters, 17(2):115-124, 1996.
        """

        img = F.pad(img, (1, 1, 1, 1, 1, 1), value=0)

        # Condition 1: the number of zeros in the six-neighborhood is exactly 1
        K_N6 = torch.tensor([[[0.0, 0.0, 0.0],
                              [0.0, 1.0, 0.0],
                              [0.0, 0.0, 0.0]],
                             [[0.0, 1.0, 0.0],
                              [1.0, 0.0, 1.0],
                              [0.0, 1.0, 0.0]],
                             [[0.0, 0.0, 0.0],
                              [0.0, 1.0, 0.0],
                              [0.0, 0.0, 0.0]]], device=img.device).view(1, 1, 3, 3, 3)

        num_six_neighbors = F.conv3d(1 - img, K_N6, stride=2)

        subcondition1a = F.hardtanh(num_six_neighbors, min_val=0, max_val=1)  # 1 or more neighbors
        subcondition1b = F.hardtanh(-(num_six_neighbors - 2), min_val=0, max_val=1)  # 1 or fewer neighbors

        condition1 = subcondition1a * subcondition1b


        # Condition 2: the number of ones in the 26-neighborhood is exactly 1
        K_N26 = torch.tensor([[[1.0, 1.0, 1.0],
                               [1.0, 1.0, 1.0],
                               [1.0, 1.0, 1.0]],
                              [[1.0, 1.0, 1.0],
                               [1.0, 0.0, 1.0],
                               [1.0, 1.0, 1.0]],
                              [[1.0, 1.0, 1.0],
                               [1.0, 1.0, 1.0],
                               [1.0, 1.0, 1.0]]], device=img.device).view(1, 1, 3, 3, 3)

        num_twentysix_neighbors = F.conv3d(img, K_N26, stride=2)

        subcondition2a = F.hardtanh(num_twentysix_neighbors, min_val=0, max_val=1)  # 1 or more neighbors
        subcondition2b = F.hardtanh(-(num_twentysix_neighbors - 2), min_val=0, max_val=1)  # 1 or fewer neighbors

        condition2 = subcondition2a * subcondition2b
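

        # Note on the cell-configuration checks below (comment added for clarity): each
        # configuration is encoded as a kernel with entries in {+1, -1, 0}. Convolving the
        # rescaled image 2 * img - 1 (foreground = +1, background = -1) with such a kernel
        # reaches its maximum, the number S of non-zero kernel entries, exactly where the
        # pattern matches, so that F.relu(conv - (S - 1)) is 1 on a match and 0 otherwise.
        # Flipped and rotated copies of each kernel cover all orientations of the configuration.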

        # Condition 3: the number of ones in the eighteen-neighborhood is exactly 1...
        K_N18 = torch.tensor([[[0.0, 1.0, 0.0],
                               [1.0, 1.0, 1.0],
                               [0.0, 1.0, 0.0]],
                              [[1.0, 1.0, 1.0],
                               [1.0, 0.0, 1.0],
                               [1.0, 1.0, 1.0]],
                              [[0.0, 1.0, 0.0],
                               [1.0, 1.0, 1.0],
                               [0.0, 1.0, 0.0]]], device=img.device).view(1, 1, 3, 3, 3)

        num_eighteen_neighbors = F.conv3d(img, K_N18, stride=2)

        subcondition3a = F.hardtanh(num_eighteen_neighbors, min_val=0, max_val=1)  # 1 or more neighbors
        subcondition3b = F.hardtanh(-(num_eighteen_neighbors - 2), min_val=0, max_val=1)  # 1 or fewer neighbors

        # ... and cell configuration B26 does not exist
        K_B26 = torch.tensor([[[1.0, -1.0, 0.0],
                               [-1.0, -1.0, 0.0],
                               [0.0, 0.0, 0.0]],
                              [[-1.0, -1.0, 0.0],
                               [-1.0, 0.0, 0.0],
                               [0.0, 0.0, 0.0]],
                              [[0.0, 0.0, 0.0],
                               [0.0, 0.0, 0.0],
                               [0.0, 0.0, 0.0]]], device=img.device).view(1, 1, 3, 3, 3)

        B26_1_present = F.relu(F.conv3d(2.0 * img - 1.0, K_B26, stride=2) - 6)
        B26_2_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.flip(K_B26, dims=[2]), stride=2) - 6)
        B26_3_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.flip(K_B26, dims=[3]), stride=2) - 6)
        B26_4_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.flip(K_B26, dims=[4]), stride=2) - 6)
        B26_5_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.flip(K_B26, dims=[2, 3]), stride=2) - 6)
        B26_6_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.flip(K_B26, dims=[2, 4]), stride=2) - 6)
        B26_7_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.flip(K_B26, dims=[3, 4]), stride=2) - 6)
        B26_8_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.flip(K_B26, dims=[2, 3, 4]), stride=2) - 6)
        num_B26_cells = B26_1_present + B26_2_present + B26_3_present + B26_4_present + B26_5_present + B26_6_present + B26_7_present + B26_8_present

        subcondition3c = F.hardtanh(-(num_B26_cells - 1), min_val=0, max_val=1)

        condition3 = subcondition3a * subcondition3b * subcondition3c


        # Condition 4: cell configuration A6 does not exist...
        K_A6 = torch.tensor([[[0.0, 1.0, 0.0],
                              [1.0, -1.0, 1.0],
                              [0.0, 1.0, 0.0]],
                             [[0.0, 0.0, 0.0],
                              [0.0, 0.0, 0.0],
                              [0.0, 0.0, 0.0]],
                             [[0.0, 0.0, 0.0],
                              [0.0, 0.0, 0.0],
                              [0.0, 0.0, 0.0]]], device=img.device).view(1, 1, 3, 3, 3)

        A6_1_present = F.relu(F.conv3d(2.0 * img - 1.0, K_A6, stride=2) - 4)
        A6_2_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(K_A6, dims=[2, 3]), stride=2) - 4)
        A6_3_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(K_A6, dims=[2, 4]), stride=2) - 4)
        A6_4_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.flip(K_A6, dims=[2]), stride=2) - 4)
        A6_5_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(torch.flip(K_A6, dims=[2]), dims=[2, 3]), stride=2) - 4)
        A6_6_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(torch.flip(K_A6, dims=[2]), dims=[2, 4]), stride=2) - 4)
        num_A6_cells = A6_1_present + A6_2_present + A6_3_present + A6_4_present + A6_5_present + A6_6_present

        subcondition4a = F.hardtanh(-(num_A6_cells - 1), min_val=0, max_val=1)

        # ... and cell configuration B26 does not exist (the same check as in condition 3)...
        K_B26 = torch.tensor([[[1.0, -1.0, 0.0],
                               [-1.0, -1.0, 0.0],
                               [0.0, 0.0, 0.0]],
                              [[-1.0, -1.0, 0.0],
                               [-1.0, 0.0, 0.0],
                               [0.0, 0.0, 0.0]],
                              [[0.0, 0.0, 0.0],
                               [0.0, 0.0, 0.0],
                               [0.0, 0.0, 0.0]]], device=img.device).view(1, 1, 3, 3, 3)

        B26_1_present = F.relu(F.conv3d(2.0 * img - 1.0, K_B26, stride=2) - 6)
        B26_2_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.flip(K_B26, dims=[2]), stride=2) - 6)
        B26_3_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.flip(K_B26, dims=[3]), stride=2) - 6)
        B26_4_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.flip(K_B26, dims=[4]), stride=2) - 6)
        B26_5_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.flip(K_B26, dims=[2, 3]), stride=2) - 6)
        B26_6_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.flip(K_B26, dims=[2, 4]), stride=2) - 6)
        B26_7_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.flip(K_B26, dims=[3, 4]), stride=2) - 6)
        B26_8_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.flip(K_B26, dims=[2, 3, 4]), stride=2) - 6)
        num_B26_cells = B26_1_present + B26_2_present + B26_3_present + B26_4_present + B26_5_present + B26_6_present + B26_7_present + B26_8_present

        subcondition4b = F.hardtanh(-(num_B26_cells - 1), min_val=0, max_val=1)

        # ... and cell configuration B18 does not exist...
        K_B18 = torch.tensor([[[0.0, 1.0, 0.0],
                               [-1.0, -1.0, -1.0],
                               [0.0, 0.0, 0.0]],
                              [[-1.0, -1.0, -1.0],
                               [-1.0, 0.0, -1.0],
                               [0.0, 0.0, 0.0]],
                              [[0.0, 0.0, 0.0],
                               [0.0, 0.0, 0.0],
                               [0.0, 0.0, 0.0]]], device=img.device).view(1, 1, 3, 3, 3)

        B18_1_present = F.relu(F.conv3d(2.0 * img - 1.0, K_B18, stride=2) - 8)
        B18_2_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(K_B18, dims=[2, 4]), stride=2) - 8)
        B18_3_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(K_B18, dims=[2, 4], k=2), stride=2) - 8)
        B18_4_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(K_B18, dims=[2, 4], k=3), stride=2) - 8)
        B18_5_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(K_B18, dims=[3, 4]), stride=2) - 8)
        B18_6_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(torch.rot90(K_B18, dims=[3, 4]), dims=[2, 4]), stride=2) - 8)
        B18_7_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(torch.rot90(K_B18, dims=[3, 4]), dims=[2, 4], k=2), stride=2) - 8)
        B18_8_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(torch.rot90(K_B18, dims=[3, 4]), dims=[2, 4], k=3), stride=2) - 8)
        B18_9_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(K_B18, dims=[3, 4], k=2), stride=2) - 8)
        B18_10_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(torch.rot90(K_B18, dims=[3, 4], k=2), dims=[2, 4]), stride=2) - 8)
        B18_11_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(torch.rot90(K_B18, dims=[3, 4], k=2), dims=[2, 4], k=2), stride=2) - 8)
        B18_12_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(torch.rot90(K_B18, dims=[3, 4], k=2), dims=[2, 4], k=3), stride=2) - 8)
        num_B18_cells = B18_1_present + B18_2_present + B18_3_present + B18_4_present + B18_5_present + B18_6_present + B18_7_present + B18_8_present + B18_9_present + B18_10_present + B18_11_present + B18_12_present

        subcondition4c = F.hardtanh(-(num_B18_cells - 1), min_val=0, max_val=1)

        # ... and the number of zeros in the six-neighborhood minus the number of A18 cell configurations
        # plus the number of A26 cell configurations is exactly 1
        K_N6 = torch.tensor([[[0.0, 0.0, 0.0],
                              [0.0, 1.0, 0.0],
                              [0.0, 0.0, 0.0]],
                             [[0.0, 1.0, 0.0],
                              [1.0, 0.0, 1.0],
                              [0.0, 1.0, 0.0]],
                             [[0.0, 0.0, 0.0],
                              [0.0, 1.0, 0.0],
                              [0.0, 0.0, 0.0]]], device=img.device).view(1, 1, 3, 3, 3)

        num_six_neighbors = F.conv3d(1 - img, K_N6, stride=2)

        K_A18 = torch.tensor([[[0.0, -1.0, 0.0],
                               [0.0, -1.0, 0.0],
                               [0.0, 0.0, 0.0]],
                              [[0.0, -1.0, 0.0],
                               [0.0, 0.0, 0.0],
                               [0.0, 0.0, 0.0]],
                              [[0.0, 0.0, 0.0],
                               [0.0, 0.0, 0.0],
                               [0.0, 0.0, 0.0]]], device=img.device).view(1, 1, 3, 3, 3)

        A18_1_present = F.relu(F.conv3d(2.0 * img - 1.0, K_A18, stride=2) - 2)
        A18_2_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(K_A18, dims=[2, 4]), stride=2) - 2)
        A18_3_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(K_A18, dims=[2, 4], k=2), stride=2) - 2)
        A18_4_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(K_A18, dims=[2, 4], k=3), stride=2) - 2)
        A18_5_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(K_A18, dims=[3, 4]), stride=2) - 2)
        A18_6_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(torch.rot90(K_A18, dims=[3, 4]), dims=[2, 4]), stride=2) - 2)
        A18_7_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(torch.rot90(K_A18, dims=[3, 4]), dims=[2, 4], k=2), stride=2) - 2)
        A18_8_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(torch.rot90(K_A18, dims=[3, 4]), dims=[2, 4], k=3), stride=2) - 2)
        A18_9_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(K_A18, dims=[3, 4], k=2), stride=2) - 2)
        A18_10_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(torch.rot90(K_A18, dims=[3, 4], k=2), dims=[2, 4]), stride=2) - 2)
        A18_11_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(torch.rot90(K_A18, dims=[3, 4], k=2), dims=[2, 4], k=2), stride=2) - 2)
        A18_12_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(torch.rot90(K_A18, dims=[3, 4], k=2), dims=[2, 4], k=3), stride=2) - 2)
        num_A18_cells = A18_1_present + A18_2_present + A18_3_present + A18_4_present + A18_5_present + A18_6_present + A18_7_present + A18_8_present + A18_9_present + A18_10_present + A18_11_present + A18_12_present

        K_A26 = torch.tensor([[[-1.0, -1.0, 0.0],
                               [-1.0, -1.0, 0.0],
                               [0.0, 0.0, 0.0]],
                              [[-1.0, -1.0, 0.0],
                               [-1.0, 0.0, 0.0],
                               [0.0, 0.0, 0.0]],
                              [[0.0, 0.0, 0.0],
                               [0.0, 0.0, 0.0],
                               [0.0, 0.0, 0.0]]], device=img.device).view(1, 1, 3, 3, 3)

        A26_1_present = F.relu(F.conv3d(2.0 * img - 1.0, K_A26, stride=2) - 6)
        A26_2_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.flip(K_A26, dims=[2]), stride=2) - 6)
        A26_3_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.flip(K_A26, dims=[3]), stride=2) - 6)
        A26_4_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.flip(K_A26, dims=[4]), stride=2) - 6)
        A26_5_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.flip(K_A26, dims=[2, 3]), stride=2) - 6)
        A26_6_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.flip(K_A26, dims=[2, 4]), stride=2) - 6)
        A26_7_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.flip(K_A26, dims=[3, 4]), stride=2) - 6)
        A26_8_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.flip(K_A26, dims=[2, 3, 4]), stride=2) - 6)
        num_A26_cells = A26_1_present + A26_2_present + A26_3_present + A26_4_present + A26_5_present + A26_6_present + A26_7_present + A26_8_present

        subcondition4d = F.hardtanh(num_six_neighbors - num_A18_cells + num_A26_cells, min_val=0, max_val=1)  # 1 or more configurations
        subcondition4e = F.hardtanh(-(num_six_neighbors - num_A18_cells + num_A26_cells - 2), min_val=0, max_val=1)  # 1 or fewer configurations

        condition4 = subcondition4a * subcondition4b * subcondition4c * subcondition4d * subcondition4e

        # If any of the four conditions is fulfilled, the point is simple
        combined = torch.cat([condition1, condition2, condition3, condition4], dim=1)
        is_simple = torch.amax(combined, dim=1, keepdim=True)

        return is_simple


    # Specifically designed to be used with the eight-subfield iterative scheme from above.
    def _euler_characteristic_simple_check(self, img):
        """
        Function that identifies simple points by assessing whether the Euler characteristic changes when the point is deleted [1].
        In order to calculate the Euler characteristic, the numbers of vertices, edges, faces and octants are counted using convolutions with pre-defined kernels.
        The function is meant to be used in combination with the subfield-based iterative scheme employed in the forward function.

        [1] Steven Lobregt et al. Three-dimensional skeletonization: principle and algorithm.
            IEEE Transactions on Pattern Analysis and Machine Intelligence, 2(1):75-77, 1980.
        """

        img = F.pad(img, (1, 1, 1, 1, 1, 1), value=0)

        # Create a masked version of the image in which the center of each 26-neighborhood is set to zero
        mask = torch.ones_like(img)
        mask[:, :, 1::2, 1::2, 1::2] = 0
        masked_img = img.clone() * mask
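
        # Note (comment added for clarity): below, vertices, edges, faces and octants are
        # counted once for the original image and once for the masked image in which every
        # candidate point has been switched off. Summation over a local window is implemented
        # as average pooling multiplied by the window size, e.g.
        # F.avg_pool3d(x, (3, 3, 3), stride=2) * 27 equals the sum of x over each
        # 3x3x3 neighborhood centered on a subfield point.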

        # Count vertices
        vertices = F.relu(-(2.0 * img - 1.0))
        num_vertices = F.avg_pool3d(vertices, (3, 3, 3), stride=2) * 27

        masked_vertices = F.relu(-(2.0 * masked_img - 1.0))
        num_masked_vertices = F.avg_pool3d(masked_vertices, (3, 3, 3), stride=2) * 27

        # Count edges
        K_ud_edge = torch.tensor([0.5, 0.5], device=img.device).view(1, 1, 2, 1, 1)
        K_ns_edge = torch.tensor([0.5, 0.5], device=img.device).view(1, 1, 1, 2, 1)
        K_we_edge = torch.tensor([0.5, 0.5], device=img.device).view(1, 1, 1, 1, 2)

        ud_edges = F.relu(F.conv3d(-(2.0 * img - 1.0), K_ud_edge))
        num_ud_edges = F.avg_pool3d(ud_edges, (2, 3, 3), stride=2) * 18
        ns_edges = F.relu(F.conv3d(-(2.0 * img - 1.0), K_ns_edge))
        num_ns_edges = F.avg_pool3d(ns_edges, (3, 2, 3), stride=2) * 18
        we_edges = F.relu(F.conv3d(-(2.0 * img - 1.0), K_we_edge))
        num_we_edges = F.avg_pool3d(we_edges, (3, 3, 2), stride=2) * 18
        num_edges = num_ud_edges + num_ns_edges + num_we_edges

        masked_ud_edges = F.relu(F.conv3d(-(2.0 * masked_img - 1.0), K_ud_edge))
        num_masked_ud_edges = F.avg_pool3d(masked_ud_edges, (2, 3, 3), stride=2) * 18
        masked_ns_edges = F.relu(F.conv3d(-(2.0 * masked_img - 1.0), K_ns_edge))
        num_masked_ns_edges = F.avg_pool3d(masked_ns_edges, (3, 2, 3), stride=2) * 18
        masked_we_edges = F.relu(F.conv3d(-(2.0 * masked_img - 1.0), K_we_edge))
        num_masked_we_edges = F.avg_pool3d(masked_we_edges, (3, 3, 2), stride=2) * 18
        num_masked_edges = num_masked_ud_edges + num_masked_ns_edges + num_masked_we_edges

        # Count faces
        K_ud_face = torch.tensor([[0.25, 0.25], [0.25, 0.25]], device=img.device).view(1, 1, 1, 2, 2)
        K_ns_face = torch.tensor([[0.25, 0.25], [0.25, 0.25]], device=img.device).view(1, 1, 2, 1, 2)
        K_we_face = torch.tensor([[0.25, 0.25], [0.25, 0.25]], device=img.device).view(1, 1, 2, 2, 1)

        ud_faces = F.relu(F.conv3d(-(2.0 * img - 1.0), K_ud_face) - 0.5) * 2
        num_ud_faces = F.avg_pool3d(ud_faces, (3, 2, 2), stride=2) * 12
        ns_faces = F.relu(F.conv3d(-(2.0 * img - 1.0), K_ns_face) - 0.5) * 2
        num_ns_faces = F.avg_pool3d(ns_faces, (2, 3, 2), stride=2) * 12
        we_faces = F.relu(F.conv3d(-(2.0 * img - 1.0), K_we_face) - 0.5) * 2
        num_we_faces = F.avg_pool3d(we_faces, (2, 2, 3), stride=2) * 12
        num_faces = num_ud_faces + num_ns_faces + num_we_faces

        masked_ud_faces = F.relu(F.conv3d(-(2.0 * masked_img - 1.0), K_ud_face) - 0.5) * 2
        num_masked_ud_faces = F.avg_pool3d(masked_ud_faces, (3, 2, 2), stride=2) * 12
        masked_ns_faces = F.relu(F.conv3d(-(2.0 * masked_img - 1.0), K_ns_face) - 0.5) * 2
        num_masked_ns_faces = F.avg_pool3d(masked_ns_faces, (2, 3, 2), stride=2) * 12
        masked_we_faces = F.relu(F.conv3d(-(2.0 * masked_img - 1.0), K_we_face) - 0.5) * 2
        num_masked_we_faces = F.avg_pool3d(masked_we_faces, (2, 2, 3), stride=2) * 12
        num_masked_faces = num_masked_ud_faces + num_masked_ns_faces + num_masked_we_faces

        # Count octants
        K_octants = torch.tensor([[[0.125, 0.125], [0.125, 0.125]], [[0.125, 0.125], [0.125, 0.125]]], device=img.device).view(1, 1, 2, 2, 2)

        octants = F.relu(F.conv3d(-(2.0 * img - 1.0), K_octants) - 0.75) * 4
        num_octants = F.avg_pool3d(octants, (2, 2, 2), stride=2) * 8

        masked_octants = F.relu(F.conv3d(-(2.0 * masked_img - 1.0), K_octants) - 0.75) * 4
        num_masked_octants = F.avg_pool3d(masked_octants, (2, 2, 2), stride=2) * 8

        # Combine the numbers of vertices, edges, faces and octants to calculate the Euler characteristic
        euler_characteristic = num_vertices - num_edges + num_faces - num_octants
        masked_euler_characteristic = num_masked_vertices - num_masked_edges + num_masked_faces - num_masked_octants

        # If the Euler characteristic is unchanged after switching a point from 1 to 0, the point is simple
        euler_change = F.hardtanh(torch.abs(masked_euler_characteristic - euler_characteristic), min_val=0, max_val=1)
        is_simple = 1 - euler_change
        is_simple = (is_simple.detach() > 0.5).float() - is_simple.detach() + is_simple

        return is_simple


    def _prepare_output(self, img):
        """
        Function that removes the padding and the extra dimension added by the _prepare_input function.
        """

        img = img[:, :, 1:-1, 1:-1, 1:-1]

        if self.expanded_dims:
            img = torch.squeeze(img, dim=2)

        return img

--------------------------------------------------------------------------------