├── .gitignore ├── .idea ├── .gitignore ├── inspectionProfiles │ ├── Project_Default.xml │ └── profiles_settings.xml ├── l1skeleton_py.iml ├── misc.xml ├── modules.xml ├── other.xml └── vcs.xml ├── LICENSE ├── README.md ├── main.py ├── renderer.py └── skeleton ├── center.py ├── center_type.py ├── debug.py ├── fit └── ellipse.py ├── params.py ├── recentering.py ├── skeletonization.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.ply 2 | *.npy -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Editor-based HTTP Client requests 5 | /httpRequests/ 6 | # Datasource local storage ignored files 7 | /dataSources/ 8 | /dataSources.local.xml 9 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 23 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/l1skeleton_py.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 
-------------------------------------------------------------------------------- /.idea/other.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2022 SmartPolarBear 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. 
# --------------------------------------------------------------------------------
# /README.md
# --------------------------------------------------------------------------------
# # l1skeleton_py
#
# A Python implementation of the paper
# [L1-Medial Skeleton of Point Cloud](https://vcc.tech/research/2013/L1skeleton),
# largely based on
# [MarcSchotman/skeletons-from-poincloud](https://github.com/MarcSchotman/skeletons-from-poincloud).
#
# All features are implemented, including an experimental recentering step.
#
# --------------------------------------------------------------------------------
# /main.py
# --------------------------------------------------------------------------------
import random

import numpy as np
import open3d as o3d

from skeleton.skeletonization import skeletonize

if __name__ == "__main__":
    # Demo: skeletonize a stored point cloud and show the skeleton next to the input.
    points = np.load("data/default_original.npy")
    # points = np.load("data/simple_tree.npy")

    # pcd = o3d.io.read_point_cloud("data/4_Mimosa.ply", format='ply')
    # points = np.asarray(pcd.points)

    # mimosa: dh=8
    myCenters = skeletonize(points, n_centers=1000, downsampling_rate=1, dh=2.0, recenter_knn=200)

    # Only the displayed copy of the input is subsampled; skeletonize() above got every point.
    max_display_points = 5000
    if len(points) > max_display_points:
        random_indices = random.sample(range(0, len(points)), max_display_points)
        points = points[random_indices, :]

    original = o3d.geometry.PointCloud()
    original.points = o3d.utility.Vector3dVector(points)
    original.paint_uniform_color([0, 0.9, 0])

    # All centers, including removed ones (exclude=[]). The normals are read from the same
    # myCenters list so points and normals always pair up one-to-one; previously the normals
    # were filtered with `c.label != 4` (a bare int) while the points were not filtered at all.
    all_centers = o3d.geometry.PointCloud()
    all_centers.points = o3d.utility.Vector3dVector(myCenters.get_all_centers(exclude=[]))
    all_centers.normals = o3d.utility.Vector3dVector([c.normal_vector() for c in myCenters.myCenters])
    all_centers.paint_uniform_color([0.0, 0.0, 0.9])

    skeleton = o3d.geometry.PointCloud()
    skeleton.points = o3d.utility.Vector3dVector(myCenters.get_skeleton_points())
    skeleton.paint_uniform_color([0.9, 0.0, 0.0])

    # o3d.visualization.draw_geometries([original, all_centers, skeleton], point_show_normal=True)
    o3d.visualization.draw_geometries([original, skeleton])
    # o3d.visualization.draw_geometries([original, all_centers])

# --------------------------------------------------------------------------------
# /renderer.py
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/SmartPolarBear/l1skeleton_py/d242cf649cb4907636564655760c7312643f2efd/renderer.py

# --------------------------------------------------------------------------------
# /skeleton/center.py
# --------------------------------------------------------------------------------
import random

import numpy as np
from scipy.spatial import distance
import time

from skeleton.params import get_term1, get_sigma, get_term2
from skeleton.utils import unit_vector, plane_dist, get_local_points, get_local_points_fast

from skeleton.recentering import recenter_around
import open3d as o3d

from skeleton.center_type import CenterType

from typing import Iterable, Final


class Center:
    """A single skeleton center and its per-center state.

    Instances are created and mutated by ``Centers``. Most setters are guarded
    by the current label and silently do nothing when the label forbids the change.
    """

    def __init__(self, center, h, index, sigma=0.5):
        # Position of the center; replaced through set_center() during contraction.
        self.center = center
        # Neighbourhood radius of this center.
        self.h = h
        self.label = CenterType.NON_BRANCH
        # Position of this center in Centers.myCenters / Centers.centers.
        self.index = index
        # Indices of the centers this one is linked to along its branch.
        self.connections = []
        # For a bridge point: index of the branch end (head/tail) it is attached to.
        self.bridge_connections = None
        # Indices of the other centers closer than h, nearest first.
        self.closest_neighbours = np.array([])
        # True when this center is an open end (head or tail) of its branch.
        self.head_tail = False
        # Key into Centers.skeleton of the branch this center belongs to.
        self.branch_number = None
        self.eigen_vectors = np.zeros((3, 3))
        self.sigma = sigma

    def set_non_branch(self):
        """Reset to a plain non-branch center, unless it is a branch point or removed."""
        if self.label != CenterType.BRANCH and self.label != CenterType.REMOVED:
            self.set_label(CenterType.NON_BRANCH)
            self.connections = []
            self.bridge_connections = None
            self.head_tail = False
            self.branch_number = None

    def set_as_bridge_point(self, key, connection):
        """Mark as the bridge point of branch ``key``, attached to branch end ``connection``.

        Removed centers are left untouched.
        """
        if self.label != CenterType.REMOVED:
            self.set_non_branch()
            self.set_label(CenterType.BRIDGE)
            self.bridge_connections = connection
            self.branch_number = key

    def set_as_branch_point(self, key):
        """Mark as a point of branch ``key`` and clear connections, bridge link and head/tail flag.

        The label is assigned directly (not through set_label), so this applies even to
        removed centers.
        """
        self.connections = []
        self.bridge_connections = None
        self.head_tail = False
        self.label = CenterType.BRANCH
        self.branch_number = key

    def set_eigen_vectors(self, eigen_vectors):
        """Store eigenvectors; only non-branch centers are updated."""
        if self.label == CenterType.NON_BRANCH:
            self.eigen_vectors = eigen_vectors

    def set_sigma(self, sigma):
        """Store sigma unless this is a branch point."""
        if self.label != CenterType.BRANCH:
            self.sigma = sigma

    def set_closest_neighbours(self, closest_neighbours):
        self.closest_neighbours = closest_neighbours

    def set_label(self, label: CenterType):
        """Change the label; a removed center keeps the REMOVED label."""
        if self.label != CenterType.REMOVED:
            self.label = label

    def set_center(self, center):
        """Move the center unless it is a branch point (branch points are frozen)."""
        if self.label != CenterType.BRANCH:
            self.center = center

    def set_h(self, h):
        """Update the neighbourhood radius unless this is a branch point."""
        if self.label != CenterType.BRANCH:
            self.h = h

    def normal_vector(self):
        """Return the first column of ``eigen_vectors``.

        NOTE(review): presumably the eigenvector of the smallest eigenvalue; this depends on
        the ordering produced by get_sigma() - confirm.
        """
        return self.eigen_vectors[:, 0]


class Centers:
    """All centers of a point cloud plus the skeleton (set of branches) grown from them.

    ``self.skeleton`` maps a branch key to
    ``{'branch': [center indices, head to tail],
       'head_bridge_connection': [is_bridge, index or None],
       'tail_bridge_connection': [is_bridge, index or None]}``.
    """

    def set_my_non_branch_centers(self):
        """Cache the centers currently labelled NON_BRANCH or BRIDGE."""
        self.my_non_branch_centers = [
            center for center in self.myCenters
            if center.label == CenterType.NON_BRANCH or center.label == CenterType.BRIDGE
        ]

    def get_nearest_neighbours(self):
        """Recompute every center's neighbours strictly closer than its own radius ``h``.

        ``self.closest`` receives the row-wise argsort of the pairwise distance matrix.
        Each Center gets its in-radius neighbours, nearest first, without itself.
        """
        distances = distance.squareform(distance.pdist(self.centers))
        self.closest = np.argsort(distances, axis=1)

        for center in self.myCenters:
            closest = self.closest[center.index, :].copy()
            sorted_sq_distances = distances[center.index, closest] ** 2

            outside = sorted_sq_distances >= center.h ** 2
            # np.argmax returns 0 when no element is True. In that case every center lies
            # inside the radius, so keep them all (the old `-1` sentinel silently dropped
            # the farthest one).
            cutoff = int(np.argmax(outside)) if outside.any() else len(closest)

            # Position 0 is normally the center itself (distance 0), so it is skipped.
            center.set_closest_neighbours(closest[1:cutoff])

    def __init__(self, points, center_count=2000, smoothing_k=5):
        """Sample the initial centers from ``points`` and set up their neighbourhoods and sigmas.

        Args:
            points: input point cloud, handed to Open3D as (N, 3) coordinates and
                indexed with integer index lists elsewhere in this class.
            center_count: maximum number of centers to sample.
            smoothing_k: ``k`` forwarded to get_sigma() during contract().
        """
        self.smoothing_k = smoothing_k
        self.points = points

        self.pcd = o3d.geometry.PointCloud()
        self.pcd.points = o3d.utility.Vector3dVector(self.points)
        self.pcd.paint_uniform_color([0.5, 0.5, 0.5])

        # KD-tree over the input points (not the centers); queried by recenter().
        self.kdt = o3d.geometry.KDTreeFlann(self.pcd)

        self.h = self.h0 = self._compute_h0()

        centers = self._sample_centers(count=center_count)

        # Meant to keep centers from coinciding exactly with input points.
        # NOTE(review): 1e-20 is below float64 resolution for coordinates of ordinary
        # magnitude (1.0 + 1e-20 == 1.0), so this offset is usually a no-op.
        self.centers = centers + 10 ** -20

        self.myCenters = [Center(center, self.h0, index) for index, center in enumerate(centers)]
        self.my_non_branch_centers = []
        self.skeleton = {}
        self.closest = []
        self.sigmas = np.array([None] * len(centers))
        self.eigen_vectors = [None] * len(centers)
        self.branch_points = [None] * len(centers)
        self.non_branch_points = [None] * len(centers)
        self.get_nearest_neighbours()
        self.set_my_non_branch_centers()
        self.Nremoved = 0

        # From the official code
        self.search_distance = .4
        self.too_close_threshold = 0.01
        self.allowed_branch_length = 5

        self._initialize_sigmas()

    def _initialize_sigmas(self):
        """Compute an initial sigma and eigenvectors for every center from its in-radius neighbours."""
        for myCenter in self.myCenters:
            # Neighbouring centers (closer than the center's radius) and their current sigmas.
            centers_indices = myCenter.closest_neighbours
            centers_in = np.array(self.centers[centers_indices])
            sigmas_in = np.array([self.myCenters[i].sigma for i in centers_indices])

            # NOTE(review): k=-1 is passed on to get_sigma(); its meaning is defined there.
            sigma, vecs = get_sigma(myCenter.center, centers_in, sigmas_in, self.h, k=-1)

            myCenter.set_eigen_vectors(vecs)
            myCenter.set_sigma(sigma)

        print("Set initial sigmas.")

    def _compute_h0(self):
        """Initial radius: 2 * (oriented bounding-box diagonal) / cbrt(number of points)."""
        bbox = self.pcd.get_oriented_bounding_box(robust=False)

        points = np.asarray(bbox.get_box_points())
        print("Bounding box:", points)
        # The largest distance between two box corners is the box diagonal.
        diag = np.max(distance.pdist(points))
        return (2.0 * diag) / (len(self.points) ** (1.0 / 3.0))

    def _sample_centers(self, count):
        """Voxel-downsample the cloud to at least ``count`` points, then randomly keep ``count``.

        The voxel size starts at h0 and is halved until enough voxels are occupied. The
        target is capped at the number of distinct input points: voxel downsampling can
        never yield more than that, so an uncapped target looped forever on small clouds.
        """
        assert self.h0

        target = min(count, len(np.unique(np.asarray(self.points), axis=0)))

        voxel_size = self.h0
        center_pcd = self.pcd.voxel_down_sample(voxel_size)
        while len(center_pcd.points) < target:
            voxel_size /= 2.0
            center_pcd = self.pcd.voxel_down_sample(voxel_size)

        print("Voxel down sampling to {}".format(len(center_pcd.points)))
        centers = np.asarray(center_pcd.points)

        if len(centers) > count:
            random_centers = random.sample(range(0, len(centers)), count)
            centers = centers[random_centers, :]

        return centers

    def get_h0(self):
        return self.h0

    def get_h(self):
        return self.h

    def get_all_centers(self, copy: bool = False, exclude=None) -> Iterable[np.ndarray]:
        """Return center positions whose label is not in ``exclude`` (default: removed centers)."""
        if exclude is None:
            exclude = [CenterType.REMOVED]

        return [c.center.copy() if copy else c.center for c in self.myCenters if c.label not in exclude]

    def get_skeleton_points(self, copy: bool = False) -> Iterable[np.ndarray]:
        """Return the positions of all non-removed centers listed in the branches, branch by branch."""
        ret = []
        for branch in self.skeleton.values():
            for k in branch['branch']:
                c = self.myCenters[k]
                if c.label != CenterType.REMOVED:
                    ret.append(c.center.copy() if copy else c.center)
        return ret

    def recenter(self, downsampling_rate: float = 0.5, knn: int = 200) -> None:
        """Downsample every branch, then recenter each remaining branch center.

        1. Randomly keep ``int(downsampling_rate * len(branch))`` centers per branch,
           preserving their head-to-tail order.
        2. A center whose normal is (numerically) zero is labelled REMOVED. Otherwise its
           ``knn`` nearest input points are fetched, those with ``plane_dist`` <= h0 / 16
           (presumably distance to the plane through the center with that normal) are
           kept, and the center is replaced by ``recenter_around(...)``.
        """
        # 1) Sample positions and sort them so each branch keeps its order; random.sample
        # returns items in selection order, which used to shuffle the branch (even at rate 1).
        for branch in self.skeleton.values():
            branch_list = branch['branch']
            keep = sorted(random.sample(range(len(branch_list)), k=int(downsampling_rate * len(branch_list))))
            branch['branch'] = [branch_list[pos] for pos in keep]

        threshold = self.h0 / 16.0

        enough = 0
        zero_normals = 0

        for branch in self.skeleton.values():
            for i in branch['branch']:
                p = self.myCenters[i]
                n = p.normal_vector()

                if np.allclose(n, np.zeros_like(n)):
                    self.myCenters[i].set_label(CenterType.REMOVED)
                    zero_normals += 1
                    continue

                _, nn_idx, _ = self.kdt.search_knn_vector_3d(p.center, knn=knn)
                # NOTE(review): the nearest input point is skipped, but the KD-tree is built from
                # self.points only (the center is not in it) - confirm this is intended.
                pts = self.points[list(nn_idx[1:])]

                # Compare as an array: `list <= float` raised TypeError.
                dists = np.asarray([plane_dist(q, p.center, n) for q in pts])
                neighbors = pts[dists <= threshold]

                enough += 1

                self.myCenters[i] = recenter_around(p, neighbors, max_dist_move=self.h0)

        print("E/0N", enough, zero_normals)
def remove_centers(self, indices): 277 | """ 278 | Removes a center completely 279 | """ 280 | if not isinstance(indices, list): 281 | indices = list([indices]) 282 | 283 | for index in sorted(indices, reverse=True): 284 | center = self.myCenters[index] 285 | center.set_label(CenterType.REMOVED) 286 | self.centers[center.index] = [9999, 9999, 9999] 287 | 288 | self.set_my_non_branch_centers() 289 | self.Nremoved += len(indices) 290 | 291 | def get_non_branch_points(self): 292 | 293 | non_branch_points = [] 294 | for center in self.myCenters: 295 | if center.label != CenterType.BRANCH and center.label != CenterType.REMOVED: 296 | non_branch_points.append(center.index) 297 | 298 | return non_branch_points 299 | 300 | def get_bridge_points(self): 301 | 302 | bridge_points = [] 303 | for key in self.skeleton: 304 | head = self.skeleton[key]['head_bridge_connection'] 305 | tail = self.skeleton[key]['tail_bridge_connection'] 306 | 307 | if head[0] and head[1] is not None: 308 | if not head[1] in bridge_points: 309 | bridge_points.append(head[1]) 310 | if tail[0] and tail[1] is not None: 311 | if not tail[1] in bridge_points: 312 | bridge_points.append(tail[1]) 313 | 314 | return bridge_points 315 | 316 | def update_sigmas(self): 317 | 318 | k = 5 319 | 320 | new_sigmas = [] 321 | 322 | for center in self.my_non_branch_centers: 323 | index = center.index 324 | 325 | indices = np.array(self.closest[index, :k]).astype(int) 326 | 327 | sigma_nearest_k_neighbours = self.sigmas[indices] 328 | 329 | mean_sigma = np.mean(sigma_nearest_k_neighbours) 330 | 331 | new_sigmas.append(mean_sigma) 332 | 333 | index = 0 334 | for center in self.my_non_branch_centers: 335 | center.set_sigma(new_sigmas[index]) 336 | 337 | self.sigmas[center.index] = new_sigmas[index] 338 | index += 1 339 | 340 | def update_properties(self): 341 | 342 | self.set_my_non_branch_centers() 343 | 344 | for center in self.myCenters: 345 | index = center.index 346 | self.centers[index] = center.center 347 | 
self.eigen_vectors[index] = center.eigen_vectors 348 | self.sigmas[index] = center.sigma 349 | 350 | self.get_nearest_neighbours() 351 | self.update_sigmas() 352 | 353 | def update_labels_connections(self): 354 | """ 355 | Update all the labels of all the centers 356 | 1) goes through all the branches and checks if the head has a bridge connection or a branch connection 357 | - If bridge connection this is still the head/tail of the branch 358 | - If it has a branch connection it is simply connected to another branch --> It is no head/tail anymore 359 | 2) Checks if bridges are still bridges 360 | 3) Sets all other points to simple non_branch_points 361 | """ 362 | 363 | updated_centers = [] 364 | for key in self.skeleton: 365 | 366 | branch = self.skeleton[key] 367 | 368 | head = self.myCenters[branch['branch'][0]] 369 | tail = self.myCenters[branch['branch'][-1]] 370 | 371 | # This is either a None value (for not having found a bridge point / connected branch) 372 | # or this is an integer index 373 | head_connection = branch['head_bridge_connection'][1] 374 | tail_connection = branch['tail_bridge_connection'][1] 375 | 376 | if head_connection is not None: 377 | 378 | head_connection = self.myCenters[head_connection] 379 | 380 | if branch['head_bridge_connection'][0]: 381 | head_connection.set_as_bridge_point(key, head.index) 382 | head.head_tail = True 383 | else: 384 | head_connection.set_as_branch_point(key) 385 | head.head_tail = False 386 | 387 | head.set_as_branch_point(key) 388 | head.connections = [head_connection.index, branch['branch'][1]] 389 | 390 | updated_centers.append(head_connection.index) 391 | updated_centers.append(head.index) 392 | else: 393 | head.set_as_branch_point(key) 394 | head.head_tail = True 395 | 396 | if tail_connection is not None: 397 | 398 | tail_connection = self.myCenters[tail_connection] 399 | 400 | if branch['tail_bridge_connection'][0]: 401 | tail.head_tail = True 402 | tail_connection.set_as_bridge_point(key, tail.index) 
403 | else: 404 | tail.head_tail = False 405 | tail_connection.set_as_branch_point(key) 406 | 407 | tail.set_as_branch_point(key) 408 | tail.connections = [tail_connection.index, branch['branch'][-2]] 409 | updated_centers.append(tail_connection.index) 410 | updated_centers.append(tail.index) 411 | else: 412 | tail.set_as_branch_point(key) 413 | tail.head_tail = True 414 | 415 | # 1) Go through the branch list and set each center t branch_point and set the head_tail value appropriately 416 | # 2) Set the connections 417 | index = 1 418 | for center in branch['branch'][1:-1]: # [1:-1] to remove head and tail 419 | center = self.myCenters[center] 420 | 421 | center.set_as_branch_point(key) 422 | 423 | center.connections.append(branch['branch'][index - 1]) 424 | center.connections.append(branch['branch'][index + 1]) 425 | center.head_tail = False 426 | 427 | updated_centers.append(center.index) 428 | index += 1 429 | 430 | for center in self.myCenters: 431 | 432 | if center.index in updated_centers: 433 | continue 434 | center.set_non_branch() 435 | 436 | for key in self.skeleton: 437 | branch = self.skeleton[key] 438 | 439 | for index in branch['branch']: 440 | if branch['branch'].count(index) > 1: 441 | del branch['branch'][branch['branch'].index(index)] 442 | # print("ERROR: branch {} has multiple counts of index {}...".format(branch['branch'], index)) 443 | break 444 | 445 | def contract(self, h, density_weights, mu=0.35): 446 | """ 447 | Updates the centers by the algorithm suggested in "L1-medial skeleton of Point Cloud 2010" 448 | 449 | INPUT: 450 | - Centers 451 | - points belonging to centers 452 | - local neighbourhood h0 453 | - mu factor for force between centers (preventing them from clustering) 454 | OUTPUT: 455 | - New centers 456 | - Sigmas (indicator for the strength of dominant direction) 457 | - The eigenvectors of the points belonging to the centers 458 | """ 459 | self.h = h 460 | 461 | error_center = 0 462 | 463 | too_small_sigma = 0; 464 | 465 | 
N = 0 466 | 467 | # local_indices = get_local_points(self.kdt, centers=self.centers, h=h) 468 | local_indices = get_local_points_fast(self.points, centers=self.centers, h=h) 469 | 470 | # center_pcd = o3d.geometry.PointCloud() 471 | # center_pcd.points = o3d.utility.Vector3dVector([c.center for c in self.myCenters]) 472 | # 473 | # sum_pcd: o3d.geometry.PointCloud = self.pcd + center_pcd 474 | # sum_pcd.estimate_normals() 475 | # base_idx = len(self.pcd.points) 476 | 477 | for idx, myCenter in enumerate(self.myCenters): 478 | 479 | # Get the closest 50 centers to do calculations with 480 | centers_indices = myCenter.closest_neighbours 481 | 482 | # local centers and local sigmas for sigma smoothing 483 | centers_in = np.array(self.centers[centers_indices]) 484 | sigmas_in = np.array([self.myCenters[i].sigma for i in centers_indices]) 485 | 486 | my_local_indices = local_indices[myCenter.index] 487 | local_points = self.points[my_local_indices] 488 | 489 | # Check if we have enough points and centers 490 | shape = local_points.shape 491 | if len(shape) == 1: 492 | continue 493 | elif shape[0] > 2 and len(centers_in) > 1: 494 | 495 | density_weights_points = density_weights[my_local_indices] 496 | 497 | term1 = get_term1(myCenter.center, local_points, h, density_weights_points) 498 | term2 = get_term2(myCenter.center, centers_in, h) 499 | 500 | if term1.any() and term2.any(): 501 | sigma, vecs = get_sigma(myCenter.center, centers_in, sigmas_in, h, k=self.smoothing_k) 502 | 503 | # sigma = np.clip(sigma, 0 ,1.) 
504 | 505 | # DIFFERS FROM PAPER 506 | # mu = mu_length/sigma_length * (sigma - min_sigma) 507 | # if mu < 0: 508 | # continue 509 | 510 | # mu_average +=mu 511 | 512 | new_center = term1 + mu * sigma * term2 513 | 514 | error_center += np.linalg.norm(myCenter.center - new_center) 515 | N += 1 516 | 517 | # Update this center object 518 | myCenter.set_center(new_center) 519 | myCenter.set_eigen_vectors(vecs) 520 | myCenter.set_sigma(sigma) 521 | myCenter.set_h(h) 522 | 523 | if N == 0: 524 | return 0 525 | 526 | return error_center / N 527 | 528 | def bridge_2_branch(self, bridge_point, requesting_branch_number): 529 | """ 530 | Change a bridge to a branch. 531 | 532 | 1) finds a branch with this bridge_point 533 | 2) changes the boolean indicating bridge/branch to False 534 | 3) Changes the head/tail label of the head/tail of this branch 535 | 4) When the whole skeleton is checked it changes the bridge_label to branch_label 536 | """ 537 | 538 | for key in self.skeleton: 539 | head_bridge_connection = self.skeleton[key]['head_bridge_connection'] 540 | tail_bridge_connection = self.skeleton[key]['tail_bridge_connection'] 541 | 542 | # 1) 543 | if bridge_point == head_bridge_connection[1]: 544 | # 2) 545 | self.skeleton[key]['head_bridge_connection'][0] = False 546 | # 3) 547 | head = self.skeleton[key]['branch'][0] 548 | self.myCenters[head].head_tail = False 549 | 550 | if bridge_point == tail_bridge_connection[1]: 551 | self.skeleton[key]['tail_bridge_connection'][0] = False 552 | tail = self.skeleton[key]['branch'][-1] 553 | self.myCenters[tail].head_tail = False 554 | 555 | # 4) 556 | self.myCenters[bridge_point].set_as_branch_point(requesting_branch_number) 557 | 558 | def find_bridge_point(self, index, connection_vector): 559 | """ 560 | Finds the bridging points of a branch 561 | These briding points are used to couple different branches at places where we have conjunctions 562 | INPUT:v 563 | - Index of the tail/head of the branch 564 | - the vector 
connecting this head/tail point to the branch 565 | OUTPUT: 566 | - If bridge_point found: 567 | index of bridge_point 568 | else: 569 | none 570 | ACTIONS: 571 | 1) find points in the neighboorhood of this point 572 | 2) Check if they are non_branching_points (i.e. not already in a branch) 573 | 3) Are they 'close'? We defined close as 5*(distance_to_closest_neighbour) 574 | 5) Angle of line end_of_branch to point and connection_vector < 90? 575 | 6) return branch_point_index 576 | """ 577 | 578 | myCenter = self.myCenters[index] 579 | 580 | success = False 581 | bridge_point = None 582 | for neighbour in myCenter.closest_neighbours: 583 | 584 | neighbour = self.myCenters[neighbour] 585 | 586 | if neighbour.label == CenterType.BRANCH or neighbour.label == CenterType.REMOVED: 587 | continue 588 | 589 | # If current neighbour is too far away we break 590 | if sum((neighbour.center - myCenter.center) ** 2) > self.h ** 2: 591 | break 592 | 593 | branch_2_bridge_u = unit_vector(neighbour.center - myCenter.center) 594 | connection_vector_u = unit_vector(connection_vector) 595 | 596 | cos_theta = np.dot(branch_2_bridge_u, connection_vector_u) 597 | # cos_theta >0 --> theta < 90 degrees 598 | if cos_theta > 0: 599 | bridge_point = neighbour.index 600 | success = True 601 | break 602 | 603 | return bridge_point, success 604 | 605 | def connect_bridge_points_in_h(self): 606 | # Connects bridge points which are within the same neighboorhood 607 | for center in self.myCenters: 608 | if center.label != CenterType.BRIDGE: 609 | continue 610 | # Check the local neighbourhood for any other bridge_points 611 | for neighbour in center.closest_neighbours: 612 | 613 | neighbour = self.myCenters[neighbour] 614 | 615 | # Is it a bridge point? 616 | if neighbour.label != CenterType.BRIDGE: 617 | continue 618 | 619 | # Is it still in the local neighbourhood? 
620 | if sum((neighbour.center - center.center) ** 2) > (2 * center.h) ** 2: 621 | break 622 | 623 | # If here we have two bridge points in 1 local nneighboorhood: 624 | # So we merge them: 625 | branch1 = center.branch_number 626 | branch2 = neighbour.branch_number 627 | 628 | # Check if we are connected to the head or tail of the branch 629 | if self.skeleton[branch1]['head_bridge_connection'][1] == center.index: 630 | index_branch1_connection = 0 631 | elif self.skeleton[branch1]['tail_bridge_connection'][1] == center.index: 632 | index_branch1_connection = -1 633 | else: 634 | print("fuck!") 635 | continue 636 | # else: 637 | # raise Exception( 638 | # "ERROR in 'merge_bridge_points': COULDNT FIND THE BRIDGE INDEX IN THE BRIDGE_CONNECTIONS OF THE SPECIFIED BRANCH") 639 | if self.skeleton[branch2]['head_bridge_connection'][1] == neighbour.index: 640 | index_branch2_connection = 0 641 | elif self.skeleton[branch2]['tail_bridge_connection'][1] == neighbour.index: 642 | index_branch2_connection = -1 643 | else: 644 | print("fuck!!") 645 | continue 646 | 647 | # else: 648 | # raise Exception( 649 | # "ERROR in 'merge_bridge_points': COULDNT FIND THE BRIDGE INDEX IN THE BRIDGE_CONNECTIONS OF THE SPECIFIED BRANCH") 650 | 651 | # Change the conenctions and boolenas accordingly: 652 | if index_branch1_connection == 0: 653 | # Add the bridge point to the branch 654 | self.skeleton[branch1]['branch'].insert(0, center.index) 655 | # Update the head_conenction such that it does not have any bridge connection anymore 656 | self.skeleton[branch1]['head_bridge_connection'][0] = False 657 | # And connect it to the otehr branch, i.e. 
the neighboor 658 | self.skeleton[branch1]['head_bridge_connection'][1] = neighbour.index 659 | else: 660 | self.skeleton[branch1]['branch'].extend([center.index]) 661 | self.skeleton[branch1]['tail_bridge_connection'][0] = False 662 | self.skeleton[branch1]['tail_bridge_connection'][1] = neighbour.index 663 | 664 | if index_branch2_connection == 0: 665 | self.skeleton[branch2]['branch'].insert(0, neighbour.index) 666 | self.skeleton[branch2]['head_bridge_connection'][0] = False 667 | self.skeleton[branch2]['head_bridge_connection'][1] = center.index 668 | else: 669 | self.skeleton[branch2]['branch'].extend([neighbour.index]) 670 | self.skeleton[branch2]['tail_bridge_connection'][0] = False 671 | self.skeleton[branch2]['tail_bridge_connection'][1] = center.index 672 | 673 | # Now they are branch points: 674 | center.set_as_branch_point(branch1) 675 | neighbour.set_as_branch_point(branch2) 676 | 677 | def connect_identical_bridge_points(self): 678 | """ 679 | Connects branches which are connected to an identical bridge point 680 | 1) Makes a list with the connection values of all the heads and tails. 
The value is None if it is connected to another branch 681 | 2) Finds a similar index 682 | 3) Connects these branches 683 | 4) Replaces the value by None in the list and start at (2) again 684 | """ 685 | 686 | # 1) 687 | bridge_points = [] 688 | for key in self.skeleton: 689 | branch = self.skeleton[key] 690 | 691 | bridges_of_branch = [] 692 | if branch['head_bridge_connection'][0]: 693 | bridges_of_branch.append(branch['head_bridge_connection'][1]) 694 | else: 695 | bridges_of_branch.append(None) 696 | 697 | if branch['tail_bridge_connection'][0]: 698 | bridges_of_branch.append(branch['tail_bridge_connection'][1]) 699 | else: 700 | bridges_of_branch.append(None) 701 | 702 | bridge_points.append(bridges_of_branch) 703 | 704 | bridge_points = np.array(bridge_points, dtype=object) 705 | success = True 706 | while success: 707 | success = False 708 | for points in bridge_points: 709 | bridge_head = points[0] 710 | bridge_tail = points[1] 711 | 712 | # If not None check how man y instances we have of this bridge point 713 | if bridge_head is not None: 714 | # 2) 715 | count_head = len(np.argwhere(bridge_points == bridge_head)) 716 | if count_head > 1: 717 | # 3) #If mroe then 1 we get all the indices (row, column wise) where the rows are branch numbers and the columns indicate if its at the head or tail 718 | indices = np.where(bridge_points == bridge_head) 719 | # We choose the first banch as the 'parent' it will adopt this bridge_point as branch point. All other branches with this bridge_point will simply loose it. 
720 | branch1 = indices[0][0] 721 | # Set these values to False as after this we do not have a bridge point anymore 722 | self.skeleton[branch1]['head_bridge_connection'][0] = False 723 | self.skeleton[branch1]['head_bridge_connection'][1] = bridge_head 724 | # Sets all branches with this bridge_point to False as well 725 | self.bridge_2_branch(bridge_head, branch1) 726 | # 4) Set all the indices with this bridge_point to None and start over 727 | bridge_points[indices] = None 728 | 729 | success = True 730 | break 731 | 732 | if bridge_tail is not None: 733 | count_tail = len(np.argwhere(bridge_points == bridge_tail)) 734 | if count_tail > 1: 735 | indices = np.where(bridge_points == bridge_tail) 736 | branch1 = indices[0][0] # Becomes part of the branch 737 | self.skeleton[branch1]['tail_bridge_connection'][0] = False 738 | self.skeleton[branch1]['tail_bridge_connection'][1] = bridge_tail 739 | 740 | self.bridge_2_branch(bridge_tail, branch1) 741 | bridge_points[indices] = None 742 | 743 | success = True 744 | break 745 | 746 | def merge_bridge_points(self): 747 | """ 748 | 1) Connects bridge points which are within the same neighboorhood 749 | 2) Connects branches which are connected to an identical bridge point 750 | """ 751 | 752 | # 1) 753 | self.connect_bridge_points_in_h() 754 | # 2) 755 | self.connect_identical_bridge_points() 756 | 757 | def set_bridge_points(self, key, branch): 758 | """ 759 | First finds then sets bridge_points of this branch 760 | 1) checks if head/tail is connected to a branch 761 | 2) Checks if we can find a bridge point 762 | 3) If we find bridge, set the old bridge(if we had it) to non_branch_point and set new bridge label to bridge_point and update the branch 763 | """ 764 | 765 | # 1) 766 | if branch['head_bridge_connection'][0]: 767 | head = branch['branch'][0] 768 | head_1 = branch['branch'][1] 769 | head_bridge_vector = self.centers[head] - self.centers[head_1] 770 | # 2) 771 | bridge_point, success = 
self.find_bridge_point(head, head_bridge_vector) 772 | # 3) Update old bridge_point 773 | if success: 774 | old_bridge_point = branch['head_bridge_connection'][1] 775 | if old_bridge_point is not None: 776 | old_bridge_point = self.myCenters[old_bridge_point] 777 | old_bridge_point.set_non_branch() 778 | 779 | branch['head_bridge_connection'][1] = bridge_point 780 | self.myCenters[bridge_point].set_as_bridge_point(key, head) 781 | 782 | if branch['tail_bridge_connection'][0]: 783 | tail = branch['branch'][-1] 784 | tail_1 = branch['branch'][-2] 785 | tail_bridge_vector = self.centers[tail] - self.centers[tail_1] 786 | 787 | bridge_point, success = self.find_bridge_point(tail, tail_bridge_vector) 788 | 789 | if success: 790 | # Update old bridge_point 791 | old_bridge_point = branch['tail_bridge_connection'][1] 792 | if old_bridge_point is not None: 793 | old_bridge_point = self.myCenters[old_bridge_point] 794 | old_bridge_point.set_non_branch() 795 | 796 | branch['tail_bridge_connection'][1] = bridge_point 797 | self.myCenters[bridge_point].set_as_bridge_point(key, tail) 798 | 799 | self.skeleton[key] = branch 800 | 801 | def add_new_branch(self, branch_list): 802 | """ 803 | A branch: {'branch': [list of branch points], 'head connection':[Bool denoting if its a bridge/branch point True/False, index of conenction], tail_bridge_connection:[same stuff]} 804 | For each new branch a few checks: 805 | 1) were there bridge points? 
If so they need to be connected 806 | - If they are and the head / tail of the branch this mean the branch is connected to another branch 807 | 2) Finds the potential bridge points 808 | 3) sets the labels of the centers 809 | 4) adds the branch to the skeleton list of branches 810 | """ 811 | head_bridge_connection = [True, None] 812 | tail_bridge_connection = [True, None] 813 | 814 | key = len(self.skeleton) + 1 815 | 816 | # Check for bridge points: 817 | for index in branch_list: 818 | 819 | center = self.myCenters[index] 820 | # Do we have a bridge point? 821 | if center.label != CenterType.BRIDGE: 822 | continue 823 | 824 | # Our head is connected to a bridge point of another branch. 825 | # Thus, our head has NO bridge point, 826 | # and we need to change this in the branch from which this is the bridge_point 827 | if index == branch_list[0]: 828 | head_bridge_connection[0] = False 829 | head_bridge_connection[1] = center.bridge_connections 830 | 831 | # same stuff 832 | elif index == branch_list[-1]: 833 | tail_bridge_connection[0] = False 834 | tail_bridge_connection[1] = center.bridge_connections 835 | 836 | # Now make this bridge_point a branch 837 | self.bridge_2_branch(center.index, requesting_branch_number=key) 838 | 839 | branch = {'branch': branch_list, 'head_bridge_connection': head_bridge_connection, 840 | 'tail_bridge_connection': tail_bridge_connection} 841 | 842 | # Set labels 843 | for index in branch_list: 844 | 845 | self.myCenters[index].set_as_branch_point(key) 846 | 847 | if (index == branch_list[0] and head_bridge_connection[0]) or ( 848 | index == branch_list[-1] and tail_bridge_connection[0]): 849 | self.myCenters[index].head_tail = True 850 | else: 851 | self.myCenters[index].head_tail = False 852 | 853 | self.skeleton[key] = branch 854 | 855 | def update_branch(self, key, new_branch): 856 | """ 857 | Checks wheter the updated branch contains a bridge point of another branch. 
If so it updates the label of the bridge point and the branch head/tail connection values 858 | INPUTS: 859 | - key of the branch 860 | - the new branch 861 | """ 862 | 863 | # Go through the new_branch list 864 | for index in [new_branch['branch'][0], new_branch['branch'][-1]]: 865 | # Check if this point is a bridge_point not from this branch 866 | center = self.myCenters[index] 867 | # Set the label of this center to branch_point 868 | center.set_as_branch_point(key) 869 | 870 | # Set head/tail label 871 | if index == new_branch['branch'][0] and new_branch['head_bridge_connection'][0]: 872 | center.head_tail = True 873 | elif index == new_branch['branch'][-1] and new_branch['tail_bridge_connection'][0]: 874 | center.head_tail = True 875 | else: 876 | center.head_tail = False 877 | 878 | # Actually update branch 879 | self.skeleton[key] = new_branch 880 | 881 | def find_extension_point(self, center_index, vector): 882 | """ 883 | INPUT: 884 | - The neighbours 885 | - The center which this is about 886 | - the connection vector of this center to the skeleton 887 | OUTPUT: 888 | - Boolean indicating connection yes or 889 | - index of connection 890 | ACTIONS: 891 | 1) Check if neighbour is too far 892 | 2) Check if too close 893 | 3) check if in the right direcion 894 | 4) if checks 1,2,3 we check if we meet the requirement. 
Then we stop 895 | """ 896 | myCenter = self.myCenters[center_index] 897 | vector_u = unit_vector(vector) 898 | 899 | connection = False 900 | neighbour = None 901 | for n in myCenter.closest_neighbours: 902 | 903 | neighbour = self.myCenters[n] 904 | 905 | if neighbour.label == CenterType.BRANCH or neighbour.label == CenterType.REMOVED: 906 | continue 907 | 908 | # #Check if inside local neighbourhood 909 | r = neighbour.center - myCenter.center 910 | r2 = np.einsum('i,i->i', r, r) 911 | r2_sum = np.einsum('i->', r2) 912 | 913 | # 1) 914 | if r2_sum > self.search_distance ** 2: 915 | break 916 | # 2) 917 | elif r2_sum <= self.too_close_threshold ** 2: 918 | self.remove_centers(neighbour.index) 919 | continue 920 | 921 | # make unit vector: 922 | center_2_neighbour_u = unit_vector( 923 | neighbour.center - myCenter.center) # From front of skeleton TOWARDS the new direction 924 | 925 | # Check skeleton angle condition 926 | # cos(theta) = dot(u,v)/(norm(u)*norm(v)) <= -0.9 927 | cos_theta = np.dot(center_2_neighbour_u, vector_u) 928 | 929 | # 3) 930 | if cos_theta > 0: 931 | continue 932 | 933 | # 4) 934 | if cos_theta <= -0.9: 935 | connection = True 936 | break 937 | 938 | if neighbour is None: 939 | # FIXME 940 | return False, -1 941 | 942 | return connection, neighbour.index 943 | 944 | def try_extend_branch(self, branch, head_bridge_connection=True, tail_bridge_connection=True): 945 | """ 946 | Tries to extend this branch from the head and the tail onwards. 947 | 948 | INPUTS: 949 | - the branch as list of idnices 950 | - head/tail conenction boolean indicating if the head/tail is connected to a bridge point (T) or branch point (F) 951 | OUTPUTS: 952 | - The extended branch ( as far as possible) 953 | - Boolean indicating if a branch was etended in any way 954 | """ 955 | 956 | found_connection = False 957 | 958 | # head =! 
tail --> which mean skeleton is a full circle AND head is nto conencted to another branch 959 | if head_bridge_connection: 960 | # Get index of head connection 961 | head = branch[0] 962 | # Get vector conencted head to the rest of the skeleton 963 | head_bridge_connection_vector = self.centers[branch[1]] - self.centers[head] 964 | if not np.allclose(head_bridge_connection_vector, np.zeros_like(head_bridge_connection_vector)): 965 | # find a possible extensions of this connection 966 | connection, index = self.find_extension_point(head, head_bridge_connection_vector) 967 | # Inserts it 968 | if connection: 969 | if not connection == branch[-1]: 970 | branch.insert(0, index) 971 | found_connection = True 972 | 973 | if tail_bridge_connection: 974 | tail = branch[-1] 975 | tail_bridge_connection_vector = self.centers[branch[-2]] - self.centers[tail] 976 | 977 | if not np.allclose(tail_bridge_connection_vector, np.zeros_like(tail_bridge_connection_vector)): 978 | connection, index = self.find_extension_point(tail, tail_bridge_connection_vector) 979 | if connection: 980 | if not connection == branch[0]: 981 | branch.extend([index]) 982 | found_connection = True 983 | 984 | return branch, found_connection 985 | 986 | def try_to_make_new_branch(self, myCenter): 987 | """ 988 | Tries to form a new branch 989 | """ 990 | 991 | found_branch = False 992 | for neighbour in myCenter.closest_neighbours: 993 | neighbour_center = self.centers[neighbour] 994 | # Check if inside local neighbourhood 995 | if sum((neighbour_center - myCenter.center) ** 2) > (self.search_distance) ** 2: 996 | break 997 | 998 | center_2_neighbour_u = unit_vector(neighbour_center - myCenter.center) 999 | 1000 | # Check if this neighbour is in the direction of the dominant eigen_vector: 1001 | # So is the angle > 155 or < 25 1002 | 1003 | if abs(np.dot(myCenter.eigen_vectors[:, 0], center_2_neighbour_u)) < 0.9: 1004 | continue 1005 | 1006 | branch = [myCenter.index, neighbour] 1007 | found_connection = 
True 1008 | while found_connection: 1009 | branch, found_connection = self.try_extend_branch(branch) 1010 | 1011 | # We ackknowledge new branches if they exist of 5 or more centers: 1012 | branch_length = self.allowed_branch_length * int(self.h / self.h0) 1013 | if len(branch) > 5: 1014 | # If inside this branch are centers which were bridge points we will connect them up 1015 | self.add_new_branch(branch) 1016 | found_branch = True 1017 | break 1018 | 1019 | return found_branch 1020 | 1021 | def try_extend_skeleton(self): 1022 | """ 1023 | Tries to extend each already existing branch from the head and the tail onwards 1024 | - If other bridge points are encountered the branch will stop extending this way and connect to the branch of this aprticular bridge point 1025 | """ 1026 | 1027 | for key in self.skeleton: 1028 | 1029 | branch = self.skeleton[key] 1030 | branch_list = branch['branch'] 1031 | 1032 | success = True 1033 | had_one_success = False 1034 | # Tries to extend the head and tail by one center 1035 | while success: 1036 | 1037 | # Go through the new_branch list, skip the already known indices 1038 | branch_list, success = self.try_extend_branch(branch_list, branch['head_bridge_connection'][0], 1039 | branch['tail_bridge_connection'][0]) 1040 | 1041 | # If newly found tail/head are a bridge point from another branch we stop extending this way 1042 | if success: 1043 | new_head = self.myCenters[branch_list[0]] 1044 | new_tail = self.myCenters[branch_list[-1]] 1045 | 1046 | # If we encounter a bridge from a different branch we are connected to this branch and thus do not have a bridge point anymore and need to adjust that particular branch as well 1047 | if new_head.label == CenterType.BRIDGE and new_head.index != branch['head_bridge_connection'][1]: 1048 | # Update this branch head bridge connection 1049 | branch['head_bridge_connection'][0] = False 1050 | branch['head_bridge_connection'][1] = new_head.bridge_connections 1051 | # Update the branch from 
which this bridge_point originated 1052 | self.bridge_2_branch(new_head.index, key) 1053 | 1054 | elif new_tail.label == CenterType.BRIDGE and new_tail.index != branch['tail_bridge_connection'][1]: 1055 | 1056 | branch['tail_bridge_connection'][0] = False 1057 | branch['tail_bridge_connection'][1] = new_tail.bridge_connections 1058 | # Update the branch from which this bridge_point originated 1059 | self.bridge_2_branch(new_tail.index, key) 1060 | 1061 | # If we extended the branch we update it. 1062 | if success: 1063 | had_one_success = True 1064 | branch['branch'] = branch_list 1065 | self.update_branch(key, branch) 1066 | 1067 | self.set_bridge_points(key, branch) 1068 | self.merge_bridge_points() 1069 | self.clean_points_around_branch(branch) 1070 | 1071 | def find_connections(self): 1072 | """ 1073 | 1) Tries to extend the existing skeleton: 1074 | 2) Tries to find new skeletons 1075 | 3) Merges skeletons if possible 1076 | """ 1077 | 1078 | self.try_extend_skeleton() 1079 | 1080 | non_branch_points = np.array(self.get_non_branch_points()) 1081 | if non_branch_points.any(): 1082 | sigma_candidates = self.sigmas[non_branch_points] 1083 | 1084 | seed_points_to_check = non_branch_points[np.where(sigma_candidates > 0.9)] 1085 | 1086 | bridge_points = self.get_bridge_points() 1087 | 1088 | seed_points_to_check = list(bridge_points) + list(seed_points_to_check) 1089 | print("top 5 sigmas:", sorted(sigma_candidates, reverse=True)[:5]) 1090 | for seed_point_index in seed_points_to_check: 1091 | 1092 | myCenter = self.myCenters[seed_point_index] 1093 | 1094 | # old_skeleton = len(self.skeleton) 1095 | if self.try_to_make_new_branch(myCenter): 1096 | new_branch_number = len(self.skeleton) 1097 | new_branch = self.skeleton[new_branch_number] 1098 | self.set_bridge_points(new_branch_number, new_branch) 1099 | self.merge_bridge_points() 1100 | self.clean_points_around_branch(new_branch) 1101 | 1102 | self.update_labels_connections() 1103 | self.clean_points() 1104 | 1105 
| def clean_points(self): 1106 | """ 1107 | Cleans points which 1108 | 1) have no points in their neighborhood 1109 | 2) Are bridge_points, with no non_branch_points around. Makes them part of the branch 1110 | 3) More than half your neighbours are branch_points 1111 | AFTER the first 2 branches are formed 1112 | """ 1113 | 1114 | if len(self.skeleton) > 1: 1115 | remove_centers = [] 1116 | for center in self.myCenters: 1117 | if center.label == CenterType.REMOVED or center.label == CenterType.BRANCH: 1118 | continue 1119 | 1120 | # 1) If no neighbours: 1121 | # if not center.closest_neighbours.any(): 1122 | if len(center.closest_neighbours) <= 5: 1123 | # If a bridge point we make it a branch 1124 | if center.label == CenterType.BRIDGE: 1125 | self.bridge_2_branch(center.index, center.branch_number) 1126 | elif center.label == CenterType.NON_BRANCH: 1127 | remove_centers.append(center.index) 1128 | # Skip the other checks 1129 | continue 1130 | # 2) 1131 | if center.label == CenterType.BRIDGE: 1132 | has_non_branch_point_neighbours = False 1133 | # Check if all close neighbours are branch_points: 1134 | for neighbour in center.closest_neighbours: 1135 | neighbour = self.myCenters[neighbour] 1136 | 1137 | # Check till we leave the neighborhood 1138 | if sum((neighbour.center - center.center) ** 2) > (2 * center.h) ** 2: 1139 | break 1140 | 1141 | if neighbour.label != CenterType.BRANCH and neighbour.label != CenterType.REMOVED: 1142 | has_non_branch_point_neighbours = True 1143 | break 1144 | if not has_non_branch_point_neighbours: 1145 | self.bridge_2_branch(center.index, center.branch_number) 1146 | 1147 | # 3) 1148 | if center.label == CenterType.NON_BRANCH: 1149 | N_branch_points = 0 1150 | N_non_branch_points = 0 1151 | # Check if all close neighbours are branch_points: 1152 | for neighbour in center.closest_neighbours: 1153 | neighbour = self.myCenters[neighbour] 1154 | 1155 | # Check till we leave the neighbourhood 1156 | if np.sum((neighbour.center - 
center.center) ** 2) > center.h ** 2: 1157 | break 1158 | 1159 | if neighbour.label == CenterType.BRANCH: 1160 | N_branch_points += 1 1161 | elif neighbour.label == CenterType.NON_BRANCH or neighbour.label == CenterType.BRIDGE: 1162 | N_non_branch_points += 1 1163 | 1164 | if N_branch_points > N_non_branch_points: 1165 | remove_centers.append(center.index) 1166 | 1167 | # Remove all the centers 1168 | if remove_centers: 1169 | self.remove_centers(remove_centers) 1170 | print("removed", len(remove_centers), 'points!') 1171 | 1172 | def clean_points_around_branch(self, branch): 1173 | """ 1174 | Removes points which: 1175 | 1) are within h of any point in the branch_list 1176 | 2) Excludes the head and tail IF they are connected to a bridge 1177 | 3) are non_branch_points 1178 | 1179 | """ 1180 | 1181 | remove_centers = [] 1182 | 1183 | for center in branch['branch']: 1184 | 1185 | # 2) If the head/tail is connected to a bridge do not remove any points 1186 | if center == branch['branch'][0] and branch['head_bridge_connection'][0]: 1187 | continue 1188 | elif center == branch['branch'][-1] and branch['tail_bridge_connection'][0]: 1189 | continue 1190 | 1191 | center = self.myCenters[center] 1192 | 1193 | for neighbour in center.closest_neighbours: 1194 | 1195 | neighbour = self.myCenters[neighbour] 1196 | 1197 | if sum((neighbour.center - center.center) ** 2) > self.too_close_threshold ** 2: 1198 | break 1199 | 1200 | if neighbour.label != CenterType.NON_BRANCH: 1201 | continue 1202 | 1203 | remove_centers.append(neighbour.index) 1204 | 1205 | if remove_centers: 1206 | self.remove_centers(remove_centers) 1207 | print("removed", len(remove_centers), 'points!') 1208 | -------------------------------------------------------------------------------- /skeleton/center_type.py: -------------------------------------------------------------------------------- 1 | from enum import Enum, unique 2 | 3 | 4 | @unique 5 | class CenterType(Enum): 6 | NON_BRANCH = 1, 7 | BRANCH = 2, 
8 | BRIDGE = 3, 9 | REMOVED = 4, 10 | -------------------------------------------------------------------------------- /skeleton/debug.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import random 4 | import sys 5 | import time 6 | 7 | import skeleton.center as sct 8 | 9 | from skeleton.center_type import CenterType 10 | 11 | from skeleton.params import get_density_weights 12 | from skeleton.utils import get_local_points 13 | 14 | import open3d as o3d 15 | from tqdm import tqdm 16 | 17 | import timeit 18 | 19 | 20 | class SkeletonBeforeAfterVisualizer: 21 | def __init__(self, skl: sct.Centers, enable=True): 22 | self.skl = skl 23 | self.enable = enable 24 | 25 | def __enter__(self): 26 | if not self.enable: 27 | return 28 | 29 | self.before_pts = self.skl.get_skeleton_points(copy=True) 30 | 31 | def __exit__(self, exc_type, exc_val, exc_tb): 32 | if not self.enable: 33 | return 34 | 35 | self.after_cts = self.skl.get_skeleton_points(copy=True) 36 | self._visualize_result() 37 | 38 | def _visualize_result(self): 39 | before_pcd = o3d.geometry.PointCloud() 40 | before_pcd.points = o3d.utility.Vector3dVector(self.before_pts) 41 | before_pcd.colors = o3d.utility.Vector3dVector([[0, 0.9, 0] for p in self.before_pts]) 42 | 43 | after_pcd = o3d.geometry.PointCloud() 44 | after_pcd.points = o3d.utility.Vector3dVector([p for p in self.after_cts]) 45 | after_pcd.colors = o3d.utility.Vector3dVector([[0, 0, 0.9] for p in self.after_cts]) 46 | 47 | o3d.visualization.draw_geometries([before_pcd, after_pcd]) 48 | 49 | 50 | class CodeTimer: 51 | def __init__(self, desc=None): 52 | self.t_start = 0 53 | self.t_end = 0 54 | self.desc = desc 55 | 56 | def __enter__(self): 57 | self.t_start = timeit.default_timer() 58 | 59 | def __exit__(self, exc_type, exc_val, exc_tb): 60 | self.t_end = timeit.default_timer() 61 | desc = self.desc 62 | if desc is None: 63 | desc = "Time: " 64 | 65 | print(desc, self.t_end - 
self.t_start) 66 | -------------------------------------------------------------------------------- /skeleton/fit/ellipse.py: -------------------------------------------------------------------------------- 1 | from numpy.linalg import eig, inv, svd 2 | from math import atan2 3 | import numpy as np 4 | 5 | 6 | def __fit_ellipse(x, y): 7 | x, y = x[:, np.newaxis], y[:, np.newaxis] 8 | D = np.hstack((x * x, x * y, y * y, x, y, np.ones_like(x))) 9 | S, C = np.dot(D.T, D), np.zeros([6, 6]) 10 | C[0, 2], C[2, 0], C[1, 1] = 2, 2, -1 11 | U, s, V = svd(np.dot(inv(S), C)) 12 | a = U[:, 0] 13 | return a 14 | 15 | 16 | def ellipse_center(a): 17 | b, c, d, f, g, a = a[1] / 2, a[2], a[3] / 2, a[4] / 2, a[5], a[0] 18 | num = b * b - a * c 19 | x0 = (c * d - b * f) / num 20 | y0 = (a * f - b * d) / num 21 | return np.array([x0, y0]) 22 | 23 | 24 | def ellipse_axis_length(a): 25 | b, c, d, f, g, a = a[1] / 2, a[2], a[3] / 2, a[4] / 2, a[5], a[0] 26 | up = 2 * (a * f * f + c * d * d + g * b * b - 2 * b * d * f - a * c * g) 27 | down1 = (b * b - a * c) * ( 28 | (c - a) * np.sqrt(1 + 4 * b * b / ((a - c) * (a - c))) - (c + a) 29 | ) 30 | down2 = (b * b - a * c) * ( 31 | (a - c) * np.sqrt(1 + 4 * b * b / ((a - c) * (a - c))) - (c + a) 32 | ) 33 | res1 = np.sqrt(up / down1) 34 | res2 = np.sqrt(up / down2) 35 | return np.array([res1, res2]) 36 | 37 | 38 | def ellipse_angle_of_rotation(a): 39 | b, c, d, f, g, a = a[1] / 2, a[2], a[3] / 2, a[4] / 2, a[5], a[0] 40 | return atan2(2 * b, (a - c)) / 2 41 | 42 | 43 | def fit_ellipse(x, y): 44 | """@brief fit an ellipse to supplied data points: the 5 params 45 | returned are: 46 | M - major axis length 47 | m - minor axis length 48 | cx - ellipse centre (x coord.) 49 | cy - ellipse centre (y coord.) 50 | phi - rotation angle of ellipse bounding box 51 | @param x first coordinate of points to fit (array) 52 | @param y second coord. 
of points to fit (array) 53 | """ 54 | a = __fit_ellipse(x, y) 55 | centre = ellipse_center(a) 56 | phi = ellipse_angle_of_rotation(a) 57 | M, m = ellipse_axis_length(a) 58 | # assert that the major axix M > minor axis m 59 | if m > M: 60 | M, m = m, M 61 | # ensure the angle is betwen 0 and 2*pi 62 | phi -= 2 * np.pi * int(phi / (2 * np.pi)) 63 | return [M, m, centre[0], centre[1], phi] 64 | -------------------------------------------------------------------------------- /skeleton/params.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time 3 | 4 | from typing import Final 5 | 6 | SMALL_THRESHOLD: Final[float] = 1e-20 7 | 8 | 9 | def get_density_weights(points, hd, for_center=False, center=None): 10 | """ 11 | INPUTS: 12 | x: 1x3 center we of interest, np.ndarray 13 | points: Nx3 array of all the points, np.ndarray 14 | h: size of local neighboorhood, float 15 | RETURNS: 16 | - np.array Nx1 of density weights assoiscated to each point 17 | """ 18 | if center is None: 19 | center = [0, 0, 0] 20 | 21 | density_weights = [] 22 | 23 | if for_center: 24 | r = points - center 25 | r2 = np.einsum('ij,ij->i', r, r) 26 | density_weights = 1 + np.einsum('i->', np.exp((-r2) / ((hd / 2) ** 2))) 27 | else: 28 | for point in points: 29 | r = point - points 30 | r2 = np.einsum('ij,ij->i', r, r) 31 | r2 = r2[r2 > SMALL_THRESHOLD] 32 | 33 | density_weight = 1 + np.einsum('i->', np.exp((-r2) / ((hd / 2.0) ** 2))) 34 | density_weights.append(density_weight) 35 | 36 | return np.array(density_weights) 37 | 38 | 39 | def get_term1(center: np.ndarray, points: np.ndarray, h: float, density_weights: np.ndarray): 40 | """ 41 | :param center: 1x3 center we of interest, np.ndarray 42 | :param points: Nx3 array of all the points, np.ndarray 43 | :param h: size of local neighborhood, float 44 | :param density_weights: 45 | :return: term1 of the equation as float 46 | """ 47 | 48 | r = points - center 49 | r2 = 
np.einsum('ij,ij->i', r, r) 50 | 51 | thetas = np.exp(-r2 / ((h / 2) ** 2)) 52 | 53 | r2[r2 <= SMALL_THRESHOLD] = 1 54 | alphas = thetas # / np.sqrt(r2) 55 | alphas /= density_weights 56 | 57 | denom = np.einsum('i->', alphas) 58 | if denom > SMALL_THRESHOLD: 59 | # term1 = np.sum((points.T*alphas).T, axis = 0)/denom 60 | term1 = np.einsum('j,jk->k', alphas, points) / denom 61 | else: 62 | term1 = np.array(False) 63 | 64 | return term1 65 | 66 | 67 | def get_term2(center: np.ndarray, centers: np.ndarray, h: float): 68 | """ 69 | :param center: 1x3 center we of interest, np.ndarray 70 | :param centers: Nx3 array of all the centers (excluding the current center), np.ndarray 71 | :param h: size of local neighborhood, float 72 | :return: term2 of the equation as float 73 | """ 74 | 75 | x = center - centers 76 | 77 | r2 = np.einsum('ij,ij->i', x, x) 78 | 79 | indexes = r2 > SMALL_THRESHOLD 80 | r2 = r2[indexes] 81 | x = x[indexes] 82 | 83 | thetas = np.exp((-r2) / ((h / 2) ** 2)) 84 | 85 | betas = thetas / np.sqrt(r2) 86 | 87 | denom = np.einsum('i->', betas) 88 | 89 | if denom > SMALL_THRESHOLD: 90 | num = np.einsum('j,jk->k', betas, x) 91 | term2 = num / denom 92 | else: 93 | term2 = np.array(False) 94 | 95 | return term2 96 | 97 | 98 | def get_sigma(center, centers, local_sigmas, h, k=5): 99 | # These are the weights 100 | r = centers - center 101 | r2 = np.einsum('ij,ij->i', r, r) 102 | 103 | indexes = r2 > SMALL_THRESHOLD 104 | r = r[indexes] 105 | r2 = r2[indexes] 106 | 107 | thetas = np.exp((-r2) / ((h / 2) ** 2)) 108 | 109 | cov = np.einsum('j,jk,jl->kl', thetas, r, r) 110 | 111 | # Get eigenvalues and eigenvectors 112 | values, vectors = np.linalg.eig(cov) 113 | 114 | if np.iscomplex(values).any(): 115 | values = np.real(values) 116 | 117 | vectors = np.real(vectors) 118 | vectors_norm = np.sqrt(np.einsum('ij,ij->i', vectors, vectors)) 119 | vectors = vectors / vectors_norm 120 | 121 | # argsort always works from low --> to high so taking the negative values 
will give us high --> low indices 122 | sorted_indices = np.argsort(-values) 123 | 124 | sigma = np.max(values) / np.sum(values) 125 | 126 | if k > 0: 127 | knn = np.argsort(r2) 128 | knn_sigmas = np.sum(local_sigmas[knn[1:k]]) 129 | sigma = (sigma + knn_sigmas) / k # smoothing sigma 130 | 131 | vectors_sorted = vectors[:, sorted_indices] 132 | 133 | if not np.isfinite(sigma): 134 | sigma = 0.9 135 | 136 | return sigma, vectors_sorted 137 | -------------------------------------------------------------------------------- /skeleton/recentering.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import skeleton.utils as utils 4 | from skeleton.center_type import CenterType 5 | 6 | from skimage.measure import EllipseModel 7 | 8 | from skeleton.fit.ellipse import fit_ellipse 9 | 10 | import open3d as o3d 11 | 12 | 13 | def ellipse_center_skimage(projected): 14 | xy = projected[:, [0, 1]] 15 | 16 | ell = EllipseModel() 17 | if not ell.estimate(xy): 18 | return False, None 19 | 20 | xc, yc, _, _, _ = ell.params 21 | return True, np.array([xc, yc]) 22 | 23 | 24 | def ellipse_center_svd(projected): 25 | x = projected[:, 0] 26 | y = projected[:, 1] 27 | 28 | _, _, xc, yc, _ = fit_ellipse(x, y) 29 | return True, np.array([xc, yc]) 30 | 31 | 32 | def ellipse_center(projected, algorithm='svd'): 33 | if algorithm == 'svd': 34 | return ellipse_center_svd(projected) 35 | else: 36 | return ellipse_center_skimage(projected) 37 | 38 | 39 | def ellipse_center(projected): 40 | xy = projected[:, [0, 1]] 41 | 42 | ell = EllipseModel() 43 | if not ell.estimate(xy): 44 | return False, None 45 | 46 | xc, yc, _, _, _ = ell.params 47 | return True, np.array([xc, yc]) 48 | 49 | 50 | def visualize_result(projected, neighbors, p): 51 | prj = o3d.geometry.PointCloud() 52 | prj.points = o3d.utility.Vector3dVector(projected) 53 | prj.colors = o3d.utility.Vector3dVector([[0, 0.9, 0] for p in projected]) 54 | 55 | original = 
o3d.geometry.PointCloud() 56 | original.points = o3d.utility.Vector3dVector([p for p in neighbors]) 57 | original.colors = o3d.utility.Vector3dVector([[0, 0, 0.9] for p in neighbors]) 58 | 59 | cloud = o3d.geometry.PointCloud() 60 | cts = [p] 61 | cloud.points = o3d.utility.Vector3dVector(cts) 62 | cloud.colors = o3d.utility.Vector3dVector([[0.9, 0.0, 0.0] for _ in cts]) 63 | 64 | o3d.visualization.draw_geometries([prj, original, cloud]) 65 | 66 | 67 | def recenter_around(center, neighbors, max_dist_move): 68 | normal = center.normal_vector() 69 | # normal = utils.unit_vector(normal) 70 | 71 | if np.allclose(normal, np.zeros_like(normal)): 72 | return center 73 | 74 | if not np.isfinite(normal).all(): 75 | return center 76 | 77 | p = center.center.copy() 78 | 79 | projected = np.array( 80 | [utils.project_one_point(q, p, normal) for q in neighbors if np.isfinite(q).all()]) 81 | 82 | # visualize_result(projected, neighbors, p) 83 | 84 | success, cp = ellipse_center(projected) 85 | if not success: 86 | # center.set_label(CenterType.REMOVED) 87 | return center 88 | 89 | nxy = normal[[0, 1]] 90 | diff = p[[0, 1]] - cp 91 | pz = -np.dot(diff, nxy) / normal[2] + p[2] 92 | 93 | cp = np.append(cp, pz) 94 | 95 | move = cp - center.center 96 | l_move = np.linalg.norm(move) 97 | if l_move > max_dist_move: 98 | return center 99 | 100 | center.center = cp 101 | return center 102 | -------------------------------------------------------------------------------- /skeleton/skeletonization.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import random 4 | import sys 5 | import time 6 | 7 | import skeleton.center as sct 8 | 9 | from skeleton.center_type import CenterType 10 | 11 | from skeleton.params import get_density_weights 12 | from skeleton.utils import get_local_points 13 | 14 | from skeleton.debug import SkeletonBeforeAfterVisualizer 15 | from typing import Final 16 | 17 | 18 | def skeletonize(points, 
n_centers=1000, 19 | max_iterations=50, 20 | dh=2.0, 21 | sigma_smoothing_k=5, 22 | error_tolerance=1e-5, 23 | downsampling_rate=0.5, 24 | try_make_skeleton=True, 25 | recenter_knn=200, 26 | max_points=None): 27 | assert len(points) > n_centers 28 | assert len(points) > recenter_knn 29 | 30 | if max_points is not None and len(points) > max_points: 31 | print("Down sampling the original point cloud") 32 | random_indices = random.sample(range(0, len(points)), max_points) 33 | points = points[random_indices, :] 34 | 35 | # random.seed(int(time.time())) 36 | random.seed(3074) 37 | 38 | skl_centers = sct.Centers(points=points, center_count=n_centers, smoothing_k=sigma_smoothing_k) 39 | 40 | # for i in range(len(skl_centers.myCenters)): 41 | # skl_centers.myCenters[i].set_label(CenterType.BRANCH) 42 | # return skl_centers 43 | 44 | h = h0 = skl_centers.get_h0() 45 | print("h0:", h0) 46 | 47 | hd: Final[float] = h0 / 2 48 | density_weights = get_density_weights(points, hd) 49 | 50 | print("Max iterations: {}, Number points: {}, Number centers: {}".format(max_iterations, len(points), 51 | len(skl_centers.centers))) 52 | 53 | last_non_branch = len(skl_centers.centers) 54 | non_change_iters = 0 55 | for i in range(max_iterations): 56 | bridge_points = len([1 for c in skl_centers.myCenters if c.label == CenterType.BRIDGE]) 57 | non_branch_points = len([1 for c in skl_centers.myCenters if c.label == CenterType.NON_BRANCH]) 58 | 59 | print("\n\nIteration:{}, h:{}, bridge_points:{}\n\n".format(i, round(h, 3), bridge_points)) 60 | 61 | last_error = 0 62 | for j in range(50): # magic number. 
do contracting at most 30 times 63 | error = skl_centers.contract(h, density_weights) 64 | skl_centers.update_properties() 65 | 66 | if np.abs(error - last_error) < error_tolerance: 67 | break 68 | 69 | last_error = error 70 | 71 | if try_make_skeleton: 72 | skl_centers.find_connections() 73 | 74 | print("Non-branch:", non_branch_points) 75 | 76 | if non_branch_points == last_non_branch: 77 | non_change_iters += 1 78 | elif non_branch_points < last_non_branch: 79 | non_change_iters = 0 80 | 81 | if non_change_iters >= 5: 82 | print("Cannot make more branch points") 83 | break 84 | 85 | if non_branch_points == 0: 86 | print("Found whole skeleton!") 87 | break 88 | 89 | last_non_branch = non_branch_points 90 | 91 | h = h + h0 / dh 92 | 93 | with SkeletonBeforeAfterVisualizer(skl_centers, enable=True): 94 | if recenter_knn > 0: 95 | skl_centers.recenter(downsampling_rate=downsampling_rate, knn=recenter_knn) 96 | 97 | return skl_centers 98 | -------------------------------------------------------------------------------- /skeleton/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | 4 | from skspatial.objects import Plane, Point 5 | 6 | 7 | def unit_vector(vector): 8 | return vector / np.linalg.norm(vector) 9 | 10 | 11 | def get_local_points_fast(points, centers, h, max_local_points=50000): 12 | # Get local_points points around this center point 13 | local_indices = [] 14 | for center in centers: 15 | x, y, z = center 16 | 17 | # 1) first get the square around the center 18 | where_square = ((points[:, 0] >= (x - h)) & (points[:, 0] <= (x + h)) & (points[:, 1] >= (y - h)) & 19 | (points[:, 1] <= (y + h)) & (points[:, 2] >= (z - h)) & (points[:, 2] <= (z + h))) 20 | 21 | square = points[where_square] 22 | indices_square = np.where(where_square == True)[0] 23 | 24 | # Get points which comply to x^2, y^2, z^2 <= r^2 25 | square_squared = np.square(square - center) 26 | where_sphere = 
np.sum(square_squared, axis=1) <= h ** 2 27 | local_sphere_indices = indices_square[where_sphere] 28 | 29 | local_indices.append(local_sphere_indices) 30 | 31 | if len(local_indices) > max_local_points: 32 | return random.sample(local_indices, max_local_points) 33 | 34 | return local_indices 35 | 36 | 37 | def get_local_points(kdt, centers, h, max_local_points=50000): 38 | # Get local_points points around this center point 39 | local_indices = [] 40 | for center in centers: 41 | k, idx, _ = kdt.search_radius_vector_3d(center, radius=h) 42 | 43 | indices = list(idx[1:]) 44 | if len(indices) > max_local_points: 45 | return random.sample(indices, max_local_points) 46 | 47 | local_indices.append(indices) 48 | 49 | if len(local_indices) > max_local_points: 50 | return random.sample(local_indices, max_local_points) 51 | return local_indices 52 | 53 | 54 | def project_one_point(q, p, n): 55 | """ 56 | :param q: a point 57 | :param p: the point on the plane 58 | :param n: the normal vector of the plane 59 | :return: the projected point 60 | """ 61 | plane = Plane(point=p, normal=n) 62 | pt = Point(q) 63 | return plane.project_point(pt) 64 | 65 | 66 | def plane_dist(q, p, n): 67 | """ 68 | :param q: a point 69 | :param p: a point on plane 70 | :param n: the normal vector of the plane 71 | :return: 72 | """ 73 | pq = q - p 74 | dot = np.dot(pq, n) 75 | ret = dot / np.linalg.norm(n) 76 | if not np.isfinite([ret]).all(): 77 | return 0x7fffffff 78 | return ret 79 | --------------------------------------------------------------------------------