├── .gitignore
├── .idea
│   ├── .gitignore
│   ├── inspectionProfiles
│   │   ├── Project_Default.xml
│   │   └── profiles_settings.xml
│   ├── l1skeleton_py.iml
│   ├── misc.xml
│   ├── modules.xml
│   ├── other.xml
│   └── vcs.xml
├── LICENSE
├── README.md
├── main.py
├── renderer.py
└── skeleton
    ├── center.py
    ├── center_type.py
    ├── debug.py
    ├── fit
    │   └── ellipse.py
    ├── params.py
    ├── recentering.py
    ├── skeletonization.py
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.ply
2 | *.npy
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Editor-based HTTP Client requests
5 | /httpRequests/
6 | # Datasource local storage ignored files
7 | /dataSources/
8 | /dataSources.local.xml
9 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
14 |
15 |
16 |
21 |
22 |
23 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/l1skeleton_py.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/other.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2022 SmartPolarBear
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy
4 | of this software and associated documentation files (the "Software"), to deal
5 | in the Software without restriction, including without limitation the rights
6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | copies of the Software, and to permit persons to whom the Software is
8 | furnished to do so, subject to the following conditions:
9 |
10 | The above copyright notice and this permission notice shall be included in all
11 | copies or substantial portions of the Software.
12 |
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # l1skeleton_py
2 |
3 | A Python implementation of the paper [L1-Medial Skeleton of Point Cloud](https://vcc.tech/research/2013/L1skeleton),
4 | largely based on [MarcSchotman/skeletons-from-poincloud](https://github.com/MarcSchotman/skeletons-from-poincloud).
5 |
6 | All features of the paper are implemented, plus an experimental recentering step.
7 |
8 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import numpy as np
4 | import open3d as o3d
5 |
6 | from skeleton.skeletonization import skeletonize
7 |
if __name__ == "__main__":
    # Demo: skeletonize a stored point cloud and show it next to the (subsampled) input.
    points = np.load("data/default_original.npy")
    # points = np.load("data/simple_tree.npy")

    # pcd = o3d.io.read_point_cloud("data/4_Mimosa.ply", format='ply')
    # points = np.asarray(pcd.points)

    # mimosa: dh=8
    myCenters = skeletonize(points, n_centers=1000, downsampling_rate=1, dh=2.0, recenter_knn=200)

    # Subsample the input for display only; skeletonization above already used every point.
    if len(points) > 5000:
        random_indices = random.sample(range(0, len(points)), 5000)
        points = points[random_indices, :]

    original = o3d.geometry.PointCloud()
    original.points = o3d.utility.Vector3dVector(points)
    original.colors = o3d.utility.Vector3dVector([[0, 0.9, 0]] * len(points))

    # exclude=[] keeps every center, so the normals must come from the same full list;
    # filtering them on a raw int label could leave fewer normals than points.
    all_centers = o3d.geometry.PointCloud()
    cts = myCenters.get_all_centers(exclude=[])
    all_centers.points = o3d.utility.Vector3dVector(cts)
    all_centers.normals = o3d.utility.Vector3dVector([c.normal_vector() for c in myCenters.myCenters])
    all_centers.colors = o3d.utility.Vector3dVector([[0.0, 0.0, 0.9]] * len(cts))

    skeleton = o3d.geometry.PointCloud()
    cts = myCenters.get_skeleton_points()
    skeleton.points = o3d.utility.Vector3dVector(cts)
    skeleton.colors = o3d.utility.Vector3dVector([[0.9, 0.0, 0.0]] * len(cts))

    # o3d.visualization.draw_geometries([original, all_centers, skeleton], point_show_normal=True)
    o3d.visualization.draw_geometries([original, skeleton])
    # o3d.visualization.draw_geometries([original, all_centers])
40 |
--------------------------------------------------------------------------------
/renderer.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SmartPolarBear/l1skeleton_py/d242cf649cb4907636564655760c7312643f2efd/renderer.py
--------------------------------------------------------------------------------
/skeleton/center.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import numpy as np
4 | from scipy.spatial import distance
5 | import time
6 |
7 | from skeleton.params import get_term1, get_sigma, get_term2
8 | from skeleton.utils import unit_vector, plane_dist, get_local_points, get_local_points_fast
9 |
10 | from skeleton.recentering import recenter_around
11 | import open3d as o3d
12 |
13 | from skeleton.center_type import CenterType
14 |
15 | from typing import Iterable, Final
16 |
17 |
class Center:
    """One skeleton center: position, radius ``h``, label and branch bookkeeping.

    State changes go through guarded setters:
    - ``set_label`` ignores changes once the center is REMOVED.
    - BRANCH centers keep their position (``set_center``), radius (``set_h``) and
      sigma (``set_sigma``).
    - Eigenvectors are only accepted while the center is NON_BRANCH.
    """

    def __init__(self, center, h, index, sigma=0.5):
        self.center = center
        self.h = h
        self.index = index
        self.sigma = sigma
        self.label = CenterType.NON_BRANCH
        # Branch bookkeeping, reset by set_non_branch / set_as_branch_point.
        self.connections = []
        self.bridge_connections = None
        self.head_tail = False
        self.branch_number = None
        self.closest_neighbours = np.array([])
        self.eigen_vectors = np.zeros((3, 3))

    def set_non_branch(self):
        """Reset to an unconnected NON_BRANCH center; no-op for BRANCH or REMOVED centers."""
        if self.label == CenterType.BRANCH or self.label == CenterType.REMOVED:
            return
        self.set_label(CenterType.NON_BRANCH)
        self.connections = []
        self.bridge_connections = None
        self.head_tail = False
        self.branch_number = None

    def set_as_bridge_point(self, key, connection):
        """Mark as the BRIDGE of branch ``key``, linked to center ``connection``; REMOVED stays removed."""
        if self.label == CenterType.REMOVED:
            return
        self.set_non_branch()
        self.set_label(CenterType.BRIDGE)
        self.bridge_connections = connection
        self.branch_number = key

    def set_as_branch_point(self, key):
        """Make this a BRANCH point of branch ``key`` and clear its connections and head/tail flag.

        The label is assigned directly (not through ``set_label``), so this also
        overrides a REMOVED label.
        """
        self.connections = []
        self.bridge_connections = None
        self.head_tail = False
        self.label = CenterType.BRANCH
        self.branch_number = key

    def set_eigen_vectors(self, eigen_vectors):
        """Store ``eigen_vectors``; only NON_BRANCH centers accept the update."""
        if self.label != CenterType.NON_BRANCH:
            return
        self.eigen_vectors = eigen_vectors

    def set_sigma(self, sigma):
        """Store ``sigma`` unless this is a BRANCH point."""
        if self.label == CenterType.BRANCH:
            return
        self.sigma = sigma

    def set_closest_neighbours(self, closest_neighbours):
        """Store the indices of the neighbouring centers."""
        self.closest_neighbours = closest_neighbours

    def set_label(self, label: CenterType):
        """Change the label; a REMOVED center keeps its label forever."""
        if self.label == CenterType.REMOVED:
            return
        self.label = label

    def set_center(self, center):
        """Move the center unless it is a BRANCH point."""
        if self.label == CenterType.BRANCH:
            return
        self.center = center

    def set_h(self, h):
        """Update the neighbourhood radius unless this is a BRANCH point."""
        if self.label == CenterType.BRANCH:
            return
        self.h = h

    def normal_vector(self):
        """Return column 0 of ``eigen_vectors``, used as this center's normal.

        NOTE(review): presumably the least-variance direction returned by ``get_sigma``;
        confirm the eigenvector ordering there.
        """
        first_column = self.eigen_vectors[:, 0]
        return first_column
80 |
81 |
82 | class Centers:
83 |
84 | def set_my_non_branch_centers(self):
85 |
86 | my_non_branch_centers = []
87 |
88 | for center in self.myCenters:
89 | if center.label == CenterType.NON_BRANCH or center.label == CenterType.BRIDGE:
90 | my_non_branch_centers.append(center)
91 | self.my_non_branch_centers = my_non_branch_centers
92 |
93 | def get_nearest_neighbours(self):
94 |
95 | distances = distance.squareform(distance.pdist(self.centers))
96 | self.closest = np.argsort(distances, axis=1)
97 |
98 | for center in self.myCenters:
99 | # center.set_closest_neighbours(self.closest[center.index,1:])
100 | closest = self.closest[center.index, :].copy()
101 | sorted_local_distances = distances[center.index, closest] ** 2
102 |
103 | # Returns zero if ALL values are within the range
104 | in_neighboorhood = np.argmax(sorted_local_distances >= center.h ** 2)
105 | if in_neighboorhood == 0:
106 | in_neighboorhood = -1
107 |
108 | center.set_closest_neighbours(closest[1:in_neighboorhood])
109 |
110 | def __init__(self, points, center_count=2000, smoothing_k=5):
111 | self.smoothing_k = smoothing_k
112 | self.points = points
113 |
114 | self.pcd = o3d.geometry.PointCloud()
115 | self.pcd.points = o3d.utility.Vector3dVector(self.points)
116 | self.pcd.paint_uniform_color([0.5, 0.5, 0.5])
117 |
118 | self.kdt = o3d.geometry.KDTreeFlann(self.pcd)
119 |
120 | self.h = self.h0 = self._compute_h0()
121 |
122 | centers = self._sample_centers(count=center_count)
123 |
124 | # Making sure centers are never the same as the actual points which can lead to bad things
125 | self.centers = centers + 10 ** -20
126 |
127 | self.myCenters = []
128 | self.my_non_branch_centers = []
129 | index = 0
130 | for center in centers:
131 | self.myCenters.append(Center(center, self.h0, index))
132 | index += 1
133 | self.skeleton = {}
134 | self.closest = []
135 | self.sigmas = np.array([None] * len(centers))
136 | self.eigen_vectors = [None] * len(centers)
137 | self.branch_points = [None] * len(centers)
138 | self.non_branch_points = [None] * len(centers)
139 | self.get_nearest_neighbours()
140 | self.set_my_non_branch_centers()
141 | self.Nremoved = 0
142 |
143 | # From the official code
144 | self.search_distance = .4
145 | self.too_close_threshold = 0.01
146 | self.allowed_branch_length = 5
147 |
148 | self._initialize_sigmas()
149 |
150 | def _initialize_sigmas(self):
151 | for idx, myCenter in enumerate(self.myCenters):
152 | # Get the closest 50 centers to do calculations with
153 | centers_indices = myCenter.closest_neighbours
154 |
155 | # local centers and local sigmas for sigma smoothing
156 | centers_in = np.array(self.centers[centers_indices])
157 | sigmas_in = np.array([self.myCenters[i].sigma for i in centers_indices])
158 |
159 | sigma, vecs = get_sigma(myCenter.center, centers_in, sigmas_in, self.h, k=-1)
160 |
161 | myCenter.set_eigen_vectors(vecs)
162 | myCenter.set_sigma(sigma)
163 |
164 | print("Set initial sigmas.")
165 |
166 | def _compute_h0(self):
167 |
168 | bbox = self.pcd.get_oriented_bounding_box(robust=False)
169 |
170 | # bbox_pcd = o3d.geometry.PointCloud()
171 | # bbox_pcd.points = o3d.utility.Vector3dVector(
172 | # np.asarray(self.pcd.get_axis_aligned_bounding_box().get_box_points()))
173 | # bbox_pcd.colors = o3d.utility.Vector3dVector([[0, 0, 0.9]] * 8)
174 | # o3d.visualization.draw_geometries([self.pcd, bbox_pcd])
175 |
176 | points = np.asarray(bbox.get_box_points())
177 | print("Bounding box:", points)
178 | dists = distance.squareform(distance.pdist(points))
179 | diag = np.max(dists)
180 | return (2.0 * diag) / (len(self.points) ** (1.0 / 3.0))
181 |
182 | def _sample_centers(self, count):
183 | assert self.h0
184 |
185 | # gradually decrease voxel size to get the closest voxel count to the given count
186 | voxel_size = self.h0
187 |
188 | center_pcd = self.pcd.voxel_down_sample(voxel_size)
189 | while len(center_pcd.points) < count:
190 | voxel_size /= 2.0
191 | center_pcd = self.pcd.voxel_down_sample(voxel_size)
192 |
193 | print("Voxel down sampling to {}".format(len(center_pcd.points)))
194 | centers = np.asarray(center_pcd.points)
195 |
196 | # centers = np.asarray(self.pcd.points)
197 |
198 | if len(centers) > count:
199 | random_centers = random.sample(range(0, len(centers)), count)
200 | centers = centers[random_centers, :]
201 |
202 | return centers
203 |
204 | def get_h0(self):
205 | return self.h0
206 |
207 | def get_h(self):
208 | return self.h
209 |
210 | def get_all_centers(self, copy: bool = False, exclude=None) -> Iterable[np.ndarray]:
211 | if exclude is None:
212 | exclude = [CenterType.REMOVED]
213 |
214 | if copy:
215 | return [c.center.copy() for c in self.myCenters if c.label not in exclude]
216 | else:
217 | return [c.center for c in self.myCenters if c.label not in exclude]
218 |
219 | def get_skeleton_points(self, copy: bool = False) -> Iterable[np.ndarray]:
220 | ret = []
221 | for key in self.skeleton:
222 | if copy:
223 | ret += [self.myCenters[k].center.copy() for k in self.skeleton[key]['branch'] if
224 | self.myCenters[k].label != CenterType.REMOVED]
225 | else:
226 | ret += [self.myCenters[k].center for k in self.skeleton[key]['branch'] if
227 | self.myCenters[k].label != CenterType.REMOVED]
228 | return ret
229 |
230 | def recenter(self, downsampling_rate: float = 0.5, knn: int = 200) -> None:
231 | # down sample each branch
232 | for key in self.skeleton:
233 | # branch = self.skeleton[key]
234 | # branch_list = branch['branch']
235 | self.skeleton[key]['branch'] = random.sample(self.skeleton[key]['branch'],
236 | k=int(downsampling_rate * len(self.skeleton[key]['branch'])))
237 |
238 | THRESHOLD: Final = self.h0 / 16.0
239 |
240 | enough = 0
241 | zero_normals = 0
242 |
243 | for key in self.skeleton:
244 | # skl_pcd = o3d.geometry.PointCloud()
245 | # skl_pcd.points = o3d.utility.Vector3dVector(
246 | # [self.myCenters[i].center for i in self.skeleton[key]['branch']])
247 | #
248 | # sum_pcd: o3d.geometry.PointCloud = self.pcd + skl_pcd
249 | # sum_pcd.estimate_normals()
250 |
251 | for idx, i in enumerate(self.skeleton[key]['branch']):
252 | p = self.myCenters[i]
253 |
254 | n = p.normal_vector()
255 | # n = sum_pcd.normals[idx + len(self.pcd.points)]
256 |
257 | if np.allclose(n, np.zeros_like(n)):
258 | self.myCenters[i].set_label(CenterType.REMOVED)
259 | zero_normals += 1
260 | continue
261 |
262 | k, idx, _ = self.kdt.search_knn_vector_3d(p.center, knn=knn)
263 | pts = self.points[list(idx[1:])]
264 | # neighbors = pts
265 |
266 | dists = [plane_dist(q, p.center, n) for q in pts]
267 |
268 | neighbors = pts[dists <= THRESHOLD]
269 |
270 | enough += 1
271 |
272 | self.myCenters[i] = recenter_around(p, neighbors, max_dist_move=self.h0)
273 |
274 | print("E/0N", enough, zero_normals)
275 |
276 | def remove_centers(self, indices):
277 | """
278 | Removes a center completely
279 | """
280 | if not isinstance(indices, list):
281 | indices = list([indices])
282 |
283 | for index in sorted(indices, reverse=True):
284 | center = self.myCenters[index]
285 | center.set_label(CenterType.REMOVED)
286 | self.centers[center.index] = [9999, 9999, 9999]
287 |
288 | self.set_my_non_branch_centers()
289 | self.Nremoved += len(indices)
290 |
291 | def get_non_branch_points(self):
292 |
293 | non_branch_points = []
294 | for center in self.myCenters:
295 | if center.label != CenterType.BRANCH and center.label != CenterType.REMOVED:
296 | non_branch_points.append(center.index)
297 |
298 | return non_branch_points
299 |
300 | def get_bridge_points(self):
301 |
302 | bridge_points = []
303 | for key in self.skeleton:
304 | head = self.skeleton[key]['head_bridge_connection']
305 | tail = self.skeleton[key]['tail_bridge_connection']
306 |
307 | if head[0] and head[1] is not None:
308 | if not head[1] in bridge_points:
309 | bridge_points.append(head[1])
310 | if tail[0] and tail[1] is not None:
311 | if not tail[1] in bridge_points:
312 | bridge_points.append(tail[1])
313 |
314 | return bridge_points
315 |
316 | def update_sigmas(self):
317 |
318 | k = 5
319 |
320 | new_sigmas = []
321 |
322 | for center in self.my_non_branch_centers:
323 | index = center.index
324 |
325 | indices = np.array(self.closest[index, :k]).astype(int)
326 |
327 | sigma_nearest_k_neighbours = self.sigmas[indices]
328 |
329 | mean_sigma = np.mean(sigma_nearest_k_neighbours)
330 |
331 | new_sigmas.append(mean_sigma)
332 |
333 | index = 0
334 | for center in self.my_non_branch_centers:
335 | center.set_sigma(new_sigmas[index])
336 |
337 | self.sigmas[center.index] = new_sigmas[index]
338 | index += 1
339 |
340 | def update_properties(self):
341 |
342 | self.set_my_non_branch_centers()
343 |
344 | for center in self.myCenters:
345 | index = center.index
346 | self.centers[index] = center.center
347 | self.eigen_vectors[index] = center.eigen_vectors
348 | self.sigmas[index] = center.sigma
349 |
350 | self.get_nearest_neighbours()
351 | self.update_sigmas()
352 |
353 | def update_labels_connections(self):
354 | """
355 | Update all the labels of all the centers
356 | 1) goes through all the branches and checks if the head has a bridge connection or a branch connection
357 | - If bridge connection this is still the head/tail of the branch
358 | - If it has a branch connection it is simply connected to another branch --> It is no head/tail anymore
359 | 2) Checks if bridges are still bridges
360 | 3) Sets all other points to simple non_branch_points
361 | """
362 |
363 | updated_centers = []
364 | for key in self.skeleton:
365 |
366 | branch = self.skeleton[key]
367 |
368 | head = self.myCenters[branch['branch'][0]]
369 | tail = self.myCenters[branch['branch'][-1]]
370 |
371 | # This is either a None value (for not having found a bridge point / connected branch)
372 | # or this is an integer index
373 | head_connection = branch['head_bridge_connection'][1]
374 | tail_connection = branch['tail_bridge_connection'][1]
375 |
376 | if head_connection is not None:
377 |
378 | head_connection = self.myCenters[head_connection]
379 |
380 | if branch['head_bridge_connection'][0]:
381 | head_connection.set_as_bridge_point(key, head.index)
382 | head.head_tail = True
383 | else:
384 | head_connection.set_as_branch_point(key)
385 | head.head_tail = False
386 |
387 | head.set_as_branch_point(key)
388 | head.connections = [head_connection.index, branch['branch'][1]]
389 |
390 | updated_centers.append(head_connection.index)
391 | updated_centers.append(head.index)
392 | else:
393 | head.set_as_branch_point(key)
394 | head.head_tail = True
395 |
396 | if tail_connection is not None:
397 |
398 | tail_connection = self.myCenters[tail_connection]
399 |
400 | if branch['tail_bridge_connection'][0]:
401 | tail.head_tail = True
402 | tail_connection.set_as_bridge_point(key, tail.index)
403 | else:
404 | tail.head_tail = False
405 | tail_connection.set_as_branch_point(key)
406 |
407 | tail.set_as_branch_point(key)
408 | tail.connections = [tail_connection.index, branch['branch'][-2]]
409 | updated_centers.append(tail_connection.index)
410 | updated_centers.append(tail.index)
411 | else:
412 | tail.set_as_branch_point(key)
413 | tail.head_tail = True
414 |
415 | # 1) Go through the branch list and set each center t branch_point and set the head_tail value appropriately
416 | # 2) Set the connections
417 | index = 1
418 | for center in branch['branch'][1:-1]: # [1:-1] to remove head and tail
419 | center = self.myCenters[center]
420 |
421 | center.set_as_branch_point(key)
422 |
423 | center.connections.append(branch['branch'][index - 1])
424 | center.connections.append(branch['branch'][index + 1])
425 | center.head_tail = False
426 |
427 | updated_centers.append(center.index)
428 | index += 1
429 |
430 | for center in self.myCenters:
431 |
432 | if center.index in updated_centers:
433 | continue
434 | center.set_non_branch()
435 |
436 | for key in self.skeleton:
437 | branch = self.skeleton[key]
438 |
439 | for index in branch['branch']:
440 | if branch['branch'].count(index) > 1:
441 | del branch['branch'][branch['branch'].index(index)]
442 | # print("ERROR: branch {} has multiple counts of index {}...".format(branch['branch'], index))
443 | break
444 |
445 | def contract(self, h, density_weights, mu=0.35):
446 | """
447 | Updates the centers by the algorithm suggested in "L1-medial skeleton of Point Cloud 2010"
448 |
449 | INPUT:
450 | - Centers
451 | - points belonging to centers
452 | - local neighbourhood h0
453 | - mu factor for force between centers (preventing them from clustering)
454 | OUTPUT:
455 | - New centers
456 | - Sigmas (indicator for the strength of dominant direction)
457 | - The eigenvectors of the points belonging to the centers
458 | """
459 | self.h = h
460 |
461 | error_center = 0
462 |
463 | too_small_sigma = 0;
464 |
465 | N = 0
466 |
467 | # local_indices = get_local_points(self.kdt, centers=self.centers, h=h)
468 | local_indices = get_local_points_fast(self.points, centers=self.centers, h=h)
469 |
470 | # center_pcd = o3d.geometry.PointCloud()
471 | # center_pcd.points = o3d.utility.Vector3dVector([c.center for c in self.myCenters])
472 | #
473 | # sum_pcd: o3d.geometry.PointCloud = self.pcd + center_pcd
474 | # sum_pcd.estimate_normals()
475 | # base_idx = len(self.pcd.points)
476 |
477 | for idx, myCenter in enumerate(self.myCenters):
478 |
479 | # Get the closest 50 centers to do calculations with
480 | centers_indices = myCenter.closest_neighbours
481 |
482 | # local centers and local sigmas for sigma smoothing
483 | centers_in = np.array(self.centers[centers_indices])
484 | sigmas_in = np.array([self.myCenters[i].sigma for i in centers_indices])
485 |
486 | my_local_indices = local_indices[myCenter.index]
487 | local_points = self.points[my_local_indices]
488 |
489 | # Check if we have enough points and centers
490 | shape = local_points.shape
491 | if len(shape) == 1:
492 | continue
493 | elif shape[0] > 2 and len(centers_in) > 1:
494 |
495 | density_weights_points = density_weights[my_local_indices]
496 |
497 | term1 = get_term1(myCenter.center, local_points, h, density_weights_points)
498 | term2 = get_term2(myCenter.center, centers_in, h)
499 |
500 | if term1.any() and term2.any():
501 | sigma, vecs = get_sigma(myCenter.center, centers_in, sigmas_in, h, k=self.smoothing_k)
502 |
503 | # sigma = np.clip(sigma, 0 ,1.)
504 |
505 | # DIFFERS FROM PAPER
506 | # mu = mu_length/sigma_length * (sigma - min_sigma)
507 | # if mu < 0:
508 | # continue
509 |
510 | # mu_average +=mu
511 |
512 | new_center = term1 + mu * sigma * term2
513 |
514 | error_center += np.linalg.norm(myCenter.center - new_center)
515 | N += 1
516 |
517 | # Update this center object
518 | myCenter.set_center(new_center)
519 | myCenter.set_eigen_vectors(vecs)
520 | myCenter.set_sigma(sigma)
521 | myCenter.set_h(h)
522 |
523 | if N == 0:
524 | return 0
525 |
526 | return error_center / N
527 |
528 | def bridge_2_branch(self, bridge_point, requesting_branch_number):
529 | """
530 | Change a bridge to a branch.
531 |
532 | 1) finds a branch with this bridge_point
533 | 2) changes the boolean indicating bridge/branch to False
534 | 3) Changes the head/tail label of the head/tail of this branch
535 | 4) When the whole skeleton is checked it changes the bridge_label to branch_label
536 | """
537 |
538 | for key in self.skeleton:
539 | head_bridge_connection = self.skeleton[key]['head_bridge_connection']
540 | tail_bridge_connection = self.skeleton[key]['tail_bridge_connection']
541 |
542 | # 1)
543 | if bridge_point == head_bridge_connection[1]:
544 | # 2)
545 | self.skeleton[key]['head_bridge_connection'][0] = False
546 | # 3)
547 | head = self.skeleton[key]['branch'][0]
548 | self.myCenters[head].head_tail = False
549 |
550 | if bridge_point == tail_bridge_connection[1]:
551 | self.skeleton[key]['tail_bridge_connection'][0] = False
552 | tail = self.skeleton[key]['branch'][-1]
553 | self.myCenters[tail].head_tail = False
554 |
555 | # 4)
556 | self.myCenters[bridge_point].set_as_branch_point(requesting_branch_number)
557 |
558 | def find_bridge_point(self, index, connection_vector):
559 | """
560 | Finds the bridging points of a branch
561 | These briding points are used to couple different branches at places where we have conjunctions
562 | INPUT:v
563 | - Index of the tail/head of the branch
564 | - the vector connecting this head/tail point to the branch
565 | OUTPUT:
566 | - If bridge_point found:
567 | index of bridge_point
568 | else:
569 | none
570 | ACTIONS:
571 | 1) find points in the neighboorhood of this point
572 | 2) Check if they are non_branching_points (i.e. not already in a branch)
573 | 3) Are they 'close'? We defined close as 5*(distance_to_closest_neighbour)
574 | 5) Angle of line end_of_branch to point and connection_vector < 90?
575 | 6) return branch_point_index
576 | """
577 |
578 | myCenter = self.myCenters[index]
579 |
580 | success = False
581 | bridge_point = None
582 | for neighbour in myCenter.closest_neighbours:
583 |
584 | neighbour = self.myCenters[neighbour]
585 |
586 | if neighbour.label == CenterType.BRANCH or neighbour.label == CenterType.REMOVED:
587 | continue
588 |
589 | # If current neighbour is too far away we break
590 | if sum((neighbour.center - myCenter.center) ** 2) > self.h ** 2:
591 | break
592 |
593 | branch_2_bridge_u = unit_vector(neighbour.center - myCenter.center)
594 | connection_vector_u = unit_vector(connection_vector)
595 |
596 | cos_theta = np.dot(branch_2_bridge_u, connection_vector_u)
597 | # cos_theta >0 --> theta < 90 degrees
598 | if cos_theta > 0:
599 | bridge_point = neighbour.index
600 | success = True
601 | break
602 |
603 | return bridge_point, success
604 |
605 | def connect_bridge_points_in_h(self):
606 | # Connects bridge points which are within the same neighboorhood
607 | for center in self.myCenters:
608 | if center.label != CenterType.BRIDGE:
609 | continue
610 | # Check the local neighbourhood for any other bridge_points
611 | for neighbour in center.closest_neighbours:
612 |
613 | neighbour = self.myCenters[neighbour]
614 |
615 | # Is it a bridge point?
616 | if neighbour.label != CenterType.BRIDGE:
617 | continue
618 |
619 | # Is it still in the local neighbourhood?
620 | if sum((neighbour.center - center.center) ** 2) > (2 * center.h) ** 2:
621 | break
622 |
623 | # If here we have two bridge points in 1 local nneighboorhood:
624 | # So we merge them:
625 | branch1 = center.branch_number
626 | branch2 = neighbour.branch_number
627 |
628 | # Check if we are connected to the head or tail of the branch
629 | if self.skeleton[branch1]['head_bridge_connection'][1] == center.index:
630 | index_branch1_connection = 0
631 | elif self.skeleton[branch1]['tail_bridge_connection'][1] == center.index:
632 | index_branch1_connection = -1
633 | else:
634 | print("fuck!")
635 | continue
636 | # else:
637 | # raise Exception(
638 | # "ERROR in 'merge_bridge_points': COULDNT FIND THE BRIDGE INDEX IN THE BRIDGE_CONNECTIONS OF THE SPECIFIED BRANCH")
639 | if self.skeleton[branch2]['head_bridge_connection'][1] == neighbour.index:
640 | index_branch2_connection = 0
641 | elif self.skeleton[branch2]['tail_bridge_connection'][1] == neighbour.index:
642 | index_branch2_connection = -1
643 | else:
644 | print("fuck!!")
645 | continue
646 |
647 | # else:
648 | # raise Exception(
649 | # "ERROR in 'merge_bridge_points': COULDNT FIND THE BRIDGE INDEX IN THE BRIDGE_CONNECTIONS OF THE SPECIFIED BRANCH")
650 |
651 | # Change the conenctions and boolenas accordingly:
652 | if index_branch1_connection == 0:
653 | # Add the bridge point to the branch
654 | self.skeleton[branch1]['branch'].insert(0, center.index)
655 | # Update the head_conenction such that it does not have any bridge connection anymore
656 | self.skeleton[branch1]['head_bridge_connection'][0] = False
657 | # And connect it to the otehr branch, i.e. the neighboor
658 | self.skeleton[branch1]['head_bridge_connection'][1] = neighbour.index
659 | else:
660 | self.skeleton[branch1]['branch'].extend([center.index])
661 | self.skeleton[branch1]['tail_bridge_connection'][0] = False
662 | self.skeleton[branch1]['tail_bridge_connection'][1] = neighbour.index
663 |
664 | if index_branch2_connection == 0:
665 | self.skeleton[branch2]['branch'].insert(0, neighbour.index)
666 | self.skeleton[branch2]['head_bridge_connection'][0] = False
667 | self.skeleton[branch2]['head_bridge_connection'][1] = center.index
668 | else:
669 | self.skeleton[branch2]['branch'].extend([neighbour.index])
670 | self.skeleton[branch2]['tail_bridge_connection'][0] = False
671 | self.skeleton[branch2]['tail_bridge_connection'][1] = center.index
672 |
673 | # Now they are branch points:
674 | center.set_as_branch_point(branch1)
675 | neighbour.set_as_branch_point(branch2)
676 |
677 | def connect_identical_bridge_points(self):
678 | """
679 | Connects branches which are connected to an identical bridge point
680 | 1) Makes a list with the connection values of all the heads and tails. The value is None if it is connected to another branch
681 | 2) Finds a similar index
682 | 3) Connects these branches
683 | 4) Replaces the value by None in the list and start at (2) again
684 | """
685 |
686 | # 1)
687 | bridge_points = []
688 | for key in self.skeleton:
689 | branch = self.skeleton[key]
690 |
691 | bridges_of_branch = []
692 | if branch['head_bridge_connection'][0]:
693 | bridges_of_branch.append(branch['head_bridge_connection'][1])
694 | else:
695 | bridges_of_branch.append(None)
696 |
697 | if branch['tail_bridge_connection'][0]:
698 | bridges_of_branch.append(branch['tail_bridge_connection'][1])
699 | else:
700 | bridges_of_branch.append(None)
701 |
702 | bridge_points.append(bridges_of_branch)
703 |
704 | bridge_points = np.array(bridge_points, dtype=object)
705 | success = True
706 | while success:
707 | success = False
708 | for points in bridge_points:
709 | bridge_head = points[0]
710 | bridge_tail = points[1]
711 |
712 | # If not None check how man y instances we have of this bridge point
713 | if bridge_head is not None:
714 | # 2)
715 | count_head = len(np.argwhere(bridge_points == bridge_head))
716 | if count_head > 1:
717 | # 3) #If mroe then 1 we get all the indices (row, column wise) where the rows are branch numbers and the columns indicate if its at the head or tail
718 | indices = np.where(bridge_points == bridge_head)
719 | # We choose the first banch as the 'parent' it will adopt this bridge_point as branch point. All other branches with this bridge_point will simply loose it.
720 | branch1 = indices[0][0]
721 | # Set these values to False as after this we do not have a bridge point anymore
722 | self.skeleton[branch1]['head_bridge_connection'][0] = False
723 | self.skeleton[branch1]['head_bridge_connection'][1] = bridge_head
724 | # Sets all branches with this bridge_point to False as well
725 | self.bridge_2_branch(bridge_head, branch1)
726 | # 4) Set all the indices with this bridge_point to None and start over
727 | bridge_points[indices] = None
728 |
729 | success = True
730 | break
731 |
732 | if bridge_tail is not None:
733 | count_tail = len(np.argwhere(bridge_points == bridge_tail))
734 | if count_tail > 1:
735 | indices = np.where(bridge_points == bridge_tail)
736 | branch1 = indices[0][0] # Becomes part of the branch
737 | self.skeleton[branch1]['tail_bridge_connection'][0] = False
738 | self.skeleton[branch1]['tail_bridge_connection'][1] = bridge_tail
739 |
740 | self.bridge_2_branch(bridge_tail, branch1)
741 | bridge_points[indices] = None
742 |
743 | success = True
744 | break
745 |
def merge_bridge_points(self):
    """
    Run the two bridge-point merging passes, in this order:

    1) ``connect_bridge_points_in_h``: connect bridge points that lie in the same neighbourhood.
    2) ``connect_identical_bridge_points``: connect branches that share an identical bridge point.
    """
    for pass_name in ("connect_bridge_points_in_h", "connect_identical_bridge_points"):
        getattr(self, pass_name)()
756 |
def set_bridge_points(self, key, branch):
    """
    Look for new bridge points at the open ends of a branch and record them.

    The head is handled first, then the tail.  For each end:
    1) Skip it unless its ``*_bridge_connection[0]`` flag is True.
    2) Call ``find_bridge_point`` with the end index and the outward vector
       (from the second / second-to-last center towards the head / tail).
    3) On success, call ``set_non_branch()`` on the previously stored bridge
       point of this end (if any), store the new index in
       ``*_bridge_connection[1]`` and label it with
       ``set_as_bridge_point(key, end)``.

    Finally the (possibly modified) branch is stored as ``self.skeleton[key]``.

    INPUTS:
        - key: key of the branch in ``self.skeleton``
        - branch: dict with 'branch', 'head_bridge_connection' and 'tail_bridge_connection'
    """
    # (connection entry, position of the end point, position of its inner neighbour)
    branch_ends = (
        ('head_bridge_connection', 0, 1),
        ('tail_bridge_connection', -1, -2),
    )

    for connection_key, end_pos, inner_pos in branch_ends:
        connection = branch[connection_key]
        if not connection[0]:
            continue

        members = branch['branch']
        end_point = members[end_pos]
        outward = self.centers[end_point] - self.centers[members[inner_pos]]

        candidate, found = self.find_bridge_point(end_point, outward)
        if not found:
            continue

        # Release the bridge point this end had before.
        previous = connection[1]
        if previous is not None:
            self.myCenters[previous].set_non_branch()

        connection[1] = candidate
        self.myCenters[candidate].set_as_bridge_point(key, end_point)

    self.skeleton[key] = branch
800 |
def add_new_branch(self, branch_list):
    """
    Register ``branch_list`` (a list of center indices) as a new branch.

    A branch is stored as::

        {'branch': [center indices],
         'head_bridge_connection': [open_flag, connection],
         'tail_bridge_connection': [open_flag, connection]}

    Steps:
    1) Every member that is currently labelled BRIDGE is handed to
       ``bridge_2_branch`` for the new key.  If that member is the head (tail),
       the head (tail) flag becomes False and its connection is set to the
       member's ``bridge_connections``.
    2) Every member is labelled with ``set_as_branch_point(new key)``.  Its
       ``head_tail`` flag is True only for an end whose flag is still True.
    3) The branch is stored under key ``len(self.skeleton) + 1``.
    """
    new_key = len(self.skeleton) + 1
    head_connection = [True, None]
    tail_connection = [True, None]

    # 1) Absorb bridge points lying on the new branch.
    for index in branch_list:
        center = self.myCenters[index]
        if center.label == CenterType.BRIDGE:
            if index == branch_list[0]:
                # The head touches another branch, so it gets no bridge point of its own.
                head_connection[0] = False
                head_connection[1] = center.bridge_connections
            elif index == branch_list[-1]:
                tail_connection[0] = False
                tail_connection[1] = center.bridge_connections
            self.bridge_2_branch(center.index, requesting_branch_number=new_key)

    # 2) Label every member; only still-open ends are flagged as head/tail.
    for index in branch_list:
        member = self.myCenters[index]
        member.set_as_branch_point(new_key)
        is_open_head = index == branch_list[0] and head_connection[0]
        is_open_tail = index == branch_list[-1] and tail_connection[0]
        member.head_tail = bool(is_open_head or is_open_tail)

    # 3) Store the branch.
    self.skeleton[new_key] = {
        'branch': branch_list,
        'head_bridge_connection': head_connection,
        'tail_bridge_connection': tail_connection,
    }
854 |
def update_branch(self, key, new_branch):
    """
    Store ``new_branch`` under ``key`` after relabelling its two end points.

    The head and the tail are labelled with ``set_as_branch_point(key)``.  Each
    end's ``head_tail`` flag is True only while the matching
    ``*_bridge_connection[0]`` is True.  Interior centers are not touched.

    INPUTS:
        - key: key of the branch in ``self.skeleton``
        - new_branch: dict with 'branch', 'head_bridge_connection' and 'tail_bridge_connection'
    """
    members = new_branch['branch']
    for end_point in [members[0], members[-1]]:
        end_center = self.myCenters[end_point]
        end_center.set_as_branch_point(key)

        if (end_point == members[0] and new_branch['head_bridge_connection'][0]) or \
                (end_point == members[-1] and new_branch['tail_bridge_connection'][0]):
            end_center.head_tail = True
        else:
            end_center.head_tail = False

    self.skeleton[key] = new_branch
880 |
def find_extension_point(self, center_index, vector):
    """
    Search the neighbours of a branch end for a center that extends the branch.

    INPUTS:
        - center_index: index of the branch end (head or tail)
        - vector: vector from this end towards the rest of the branch
    OUTPUTS:
        - connection: True if a suitable neighbour was found
        - index of the last neighbour examined (the new end when ``connection`` is True);
          ``(False, -1)`` when the center has no neighbours at all
    ACTIONS (neighbours labelled BRANCH or REMOVED are skipped):
        1) stop at the first neighbour farther than ``self.search_distance``
           (assumes ``closest_neighbours`` is sorted by distance -- TODO confirm)
        2) remove neighbours within ``self.too_close_threshold`` and go on
        3) skip neighbours on the branch side of the end (cos(theta) > 0)
        4) accept the first neighbour with cos(theta) <= -0.9 (pointing away from the branch)
    """
    origin = self.myCenters[center_index]
    towards_branch_u = unit_vector(vector)

    found = False
    candidate = None
    for neighbour_index in origin.closest_neighbours:
        candidate = self.myCenters[neighbour_index]

        if candidate.label in (CenterType.BRANCH, CenterType.REMOVED):
            continue

        offset = candidate.center - origin.center
        dist_sq = np.einsum('i->', np.einsum('i,i->i', offset, offset))

        # 1) Left the search radius.
        if dist_sq > self.search_distance ** 2:
            break
        # 2) Too close to be useful: drop it.
        if dist_sq <= self.too_close_threshold ** 2:
            self.remove_centers(candidate.index)
            continue

        # Direction from the end towards the candidate.
        cos_theta = np.dot(unit_vector(candidate.center - origin.center), towards_branch_u)

        # 3) Candidate lies on the branch side.
        if cos_theta > 0:
            continue
        # 4) Candidate continues the branch almost straight on.
        if cos_theta <= -0.9:
            found = True
            break

    if candidate is None:
        # FIXME: no neighbours at all; -1 is a sentinel index.
        return False, -1

    return found, candidate.index
943 |
def try_extend_branch(self, branch, head_bridge_connection=True, tail_bridge_connection=True):
    """
    Try to extend ``branch`` by one center at the head and/or the tail.

    INPUTS:
        - branch: list of center indices; it is modified in place
        - head_bridge_connection / tail_bridge_connection: True if that end is
          still open (connected to a bridge point), False if it is connected to
          another branch.  Closed ends are not extended.
    OUTPUTS:
        - the (possibly) extended branch (the same list object)
        - True if at least one end was extended
    """
    found_connection = False

    if head_bridge_connection:
        head = branch[0]
        # Vector from the head towards the rest of the branch.
        head_vector = self.centers[branch[1]] - self.centers[head]
        # A zero vector (duplicate centers) has no direction to extend along.
        if not np.allclose(head_vector, np.zeros_like(head_vector)):
            connection, index = self.find_extension_point(head, head_vector)
            # Never add a center that is already on the branch.  This covers the
            # tail, which would close the branch into a circle.  The old check
            # compared the boolean `connection` with an index, so it had no effect,
            # and it wrongly blocked extension when the tail index was 1.
            if connection and index not in branch:
                branch.insert(0, index)
                found_connection = True

    if tail_bridge_connection:
        tail = branch[-1]
        # Vector from the tail towards the rest of the branch.
        tail_vector = self.centers[branch[-2]] - self.centers[tail]
        if not np.allclose(tail_vector, np.zeros_like(tail_vector)):
            connection, index = self.find_extension_point(tail, tail_vector)
            if connection and index not in branch:
                branch.append(index)
                found_connection = True

    return branch, found_connection
985 |
def try_to_make_new_branch(self, myCenter):
    """
    Try to grow a new branch starting from the seed ``myCenter``.

    Neighbours are visited in ``closest_neighbours`` order.  The loop stops at
    the first neighbour outside ``self.search_distance``; this presumably
    assumes the list is sorted by distance.  A neighbour whose direction from
    the seed has |cos| >= 0.9 with the dominant eigenvector
    ``eigen_vectors[:, 0]`` (within about 26 degrees of that axis, either way)
    starts the candidate ``[seed, neighbour]``.  The candidate is extended with
    ``try_extend_branch`` until it stops growing.  The first candidate with
    more than 5 centers is added with ``add_new_branch``.

    OUTPUTS:
        - True if a new branch was added, False otherwise
    """
    found_branch = False
    for neighbour in myCenter.closest_neighbours:
        neighbour_center = self.centers[neighbour]
        # Check if inside local neighbourhood
        if sum((neighbour_center - myCenter.center) ** 2) > (self.search_distance) ** 2:
            break

        center_2_neighbour_u = unit_vector(neighbour_center - myCenter.center)

        # Only neighbours along the dominant direction (angle < ~26 or > ~154 degrees).
        if abs(np.dot(myCenter.eigen_vectors[:, 0], center_2_neighbour_u)) < 0.9:
            continue

        branch = [myCenter.index, neighbour]
        found_connection = True
        while found_connection:
            branch, found_connection = self.try_extend_branch(branch)

        # New branches are only accepted with MORE than 5 centers.
        if len(branch) > 5:
            # Bridge points inside this branch are connected up by add_new_branch.
            self.add_new_branch(branch)
            found_branch = True
            break

    return found_branch
1020 |
def try_extend_skeleton(self):
    """
    Try to extend every existing branch from its head and tail onwards.

    Each branch is extended by at most one center per end per step until
    neither end grows any more.  Suppose a newly added head (tail) is a
    BRIDGE center that is not this end's own stored bridge point.  Then that
    end is closed: its flag becomes False, its connection is set to the
    center's ``bridge_connections``, and the center is handed to
    ``bridge_2_branch``.  Head and tail are checked independently, so both
    ends can close in the same step.

    Afterwards the branch's bridge points are recomputed, bridge points are
    merged, and non-branch points around the branch are cleaned up.
    """
    # Iterate over a snapshot of the keys: the helpers below write to self.skeleton.
    for key in list(self.skeleton):
        branch = self.skeleton[key]
        branch_list = branch['branch']

        success = True
        while success:
            # Extend the still-open ends by one center each.
            branch_list, success = self.try_extend_branch(
                branch_list,
                branch['head_bridge_connection'][0],
                branch['tail_bridge_connection'][0],
            )
            if not success:
                break

            new_head = self.myCenters[branch_list[0]]
            new_tail = self.myCenters[branch_list[-1]]

            # Reaching another branch's bridge point connects us to that branch:
            # this end no longer has a bridge point of its own.
            if new_head.label == CenterType.BRIDGE and new_head.index != branch['head_bridge_connection'][1]:
                branch['head_bridge_connection'][0] = False
                branch['head_bridge_connection'][1] = new_head.bridge_connections
                # Update the branch from which this bridge_point originated
                self.bridge_2_branch(new_head.index, key)

            # Independent of the head check.  With the former `elif`, a tail
            # bridge reached in the same step was ignored and the tail kept growing.
            if new_tail.label == CenterType.BRIDGE and new_tail.index != branch['tail_bridge_connection'][1]:
                branch['tail_bridge_connection'][0] = False
                branch['tail_bridge_connection'][1] = new_tail.bridge_connections
                self.bridge_2_branch(new_tail.index, key)

            branch['branch'] = branch_list
            self.update_branch(key, branch)

        self.set_bridge_points(key, branch)
        self.merge_bridge_points()
        self.clean_points_around_branch(branch)
1070 |
def find_connections(self):
    """
    Run one skeleton-growing pass.

    1) Extend the existing branches (``try_extend_skeleton``).
    2) Try to start new branches.  Current bridge points are used as seeds
       first, then non-branch points whose sigma is above 0.9.  Every new
       branch gets its bridge points set, bridge points are merged, and
       nearby non-branch points are cleaned.
    3) Update labels/connections and clean up the remaining points.
    """
    self.try_extend_skeleton()

    non_branch_points = np.array(self.get_non_branch_points())
    # Test the size rather than `.any()`: an array holding only index 0 is
    # non-empty but all-falsy, so that seed would never have been tried.
    if non_branch_points.size > 0:
        sigma_candidates = self.sigmas[non_branch_points]

        seed_points_to_check = non_branch_points[np.where(sigma_candidates > 0.9)]

        bridge_points = self.get_bridge_points()

        seed_points_to_check = list(bridge_points) + list(seed_points_to_check)
        print("top 5 sigmas:", sorted(sigma_candidates, reverse=True)[:5])
        for seed_point_index in seed_points_to_check:

            myCenter = self.myCenters[seed_point_index]

            if self.try_to_make_new_branch(myCenter):
                # New branches are stored under key len(self.skeleton) (see add_new_branch).
                new_branch_number = len(self.skeleton)
                new_branch = self.skeleton[new_branch_number]
                self.set_bridge_points(new_branch_number, new_branch)
                self.merge_bridge_points()
                self.clean_points_around_branch(new_branch)

    self.update_labels_connections()
    self.clean_points()
1104 |
def clean_points(self):
    """
    Clean up centers; only runs once the skeleton holds more than one branch.

    Centers labelled REMOVED or BRANCH are skipped.  For every other center:
    1) Sparse neighbourhood (five or fewer entries in ``closest_neighbours``):
       a BRIDGE center is passed to ``bridge_2_branch`` with its own
       ``branch_number``; a NON_BRANCH center is scheduled for removal.  The
       remaining checks are skipped for this center.
    2) A BRIDGE center is passed to ``bridge_2_branch`` with its own
       ``branch_number`` if all its neighbours within ``2 * center.h`` are
       labelled BRANCH or REMOVED.
    3) A NON_BRANCH center is scheduled for removal if, within ``center.h``,
       it has more BRANCH neighbours than NON_BRANCH/BRIDGE neighbours.
    All scheduled centers are removed with a single ``remove_centers`` call.
    """

    if len(self.skeleton) > 1:
        remove_centers = []
        for center in self.myCenters:
            if center.label == CenterType.REMOVED or center.label == CenterType.BRANCH:
                continue

            # 1) Too few neighbours (the test is "<= 5", not "no neighbours at all").
            if len(center.closest_neighbours) <= 5:
                # If a bridge point we make it a branch
                if center.label == CenterType.BRIDGE:
                    self.bridge_2_branch(center.index, center.branch_number)
                elif center.label == CenterType.NON_BRANCH:
                    remove_centers.append(center.index)
                # Skip the other checks
                continue
            # 2) Bridge point with only branch/removed neighbours within 2h.
            if center.label == CenterType.BRIDGE:
                has_non_branch_point_neighbours = False
                # Check if all close neighbours are branch_points:
                for neighbour in center.closest_neighbours:
                    neighbour = self.myCenters[neighbour]

                    # Check till we leave the neighborhood (presumably closest_neighbours is sorted by distance)
                    if sum((neighbour.center - center.center) ** 2) > (2 * center.h) ** 2:
                        break

                    if neighbour.label != CenterType.BRANCH and neighbour.label != CenterType.REMOVED:
                        has_non_branch_point_neighbours = True
                        break
                if not has_non_branch_point_neighbours:
                    self.bridge_2_branch(center.index, center.branch_number)

            # 3) Non-branch point dominated by branch points within h.
            if center.label == CenterType.NON_BRANCH:
                N_branch_points = 0
                N_non_branch_points = 0
                # Count branch vs. non-branch/bridge neighbours inside the neighbourhood:
                for neighbour in center.closest_neighbours:
                    neighbour = self.myCenters[neighbour]

                    # Check till we leave the neighbourhood
                    if np.sum((neighbour.center - center.center) ** 2) > center.h ** 2:
                        break

                    if neighbour.label == CenterType.BRANCH:
                        N_branch_points += 1
                    elif neighbour.label == CenterType.NON_BRANCH or neighbour.label == CenterType.BRIDGE:
                        N_non_branch_points += 1

                if N_branch_points > N_non_branch_points:
                    remove_centers.append(center.index)

        # Remove all the centers
        if remove_centers:
            self.remove_centers(remove_centers)
            print("removed", len(remove_centers), 'points!')
1171 |
def clean_points_around_branch(self, branch):
    """
    Remove NON_BRANCH centers that are too close to a branch.

    Every center of ``branch['branch']`` is visited, except an end whose
    ``*_bridge_connection[0]`` flag is True.  An open end keeps its
    surroundings, where its bridge point lives.  Neighbours of a visited
    center within ``self.too_close_threshold`` that are labelled NON_BRANCH
    are collected and removed together.  Each center is removed, and counted,
    only once, even when it is close to several branch centers.
    """
    members = branch['branch']
    max_dist_sq = self.too_close_threshold ** 2
    # dict used as an insertion-ordered set of center indices
    to_remove = {}

    for member_index in members:
        # Skip ends that are still connected to a bridge point.
        if member_index == members[0] and branch['head_bridge_connection'][0]:
            continue
        if member_index == members[-1] and branch['tail_bridge_connection'][0]:
            continue

        member = self.myCenters[member_index]
        for neighbour_index in member.closest_neighbours:
            neighbour = self.myCenters[neighbour_index]

            # Stop at the first neighbour outside the threshold
            # (presumably closest_neighbours is sorted by distance).
            if sum((neighbour.center - member.center) ** 2) > max_dist_sq:
                break

            # Only non-branch points are removed.
            if neighbour.label == CenterType.NON_BRANCH:
                to_remove[neighbour.index] = None

    if to_remove:
        removed = list(to_remove)
        self.remove_centers(removed)
        print("removed", len(removed), 'points!')
1208 |
--------------------------------------------------------------------------------
/skeleton/center_type.py:
--------------------------------------------------------------------------------
1 | from enum import Enum, unique
2 |
3 |
@unique
class CenterType(Enum):
    """
    Label of a skeleton center.

    The values are plain integers.  Previously, trailing commas made every
    value a one-element tuple such as ``(1,)``, so ``CenterType(1)`` failed.
    """

    NON_BRANCH = 1  # not part of any branch (yet)
    BRANCH = 2  # member of a branch
    BRIDGE = 3  # bridge point at an open end of a branch
    REMOVED = 4  # removed from further processing
10 |
--------------------------------------------------------------------------------
/skeleton/debug.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import random
4 | import sys
5 | import time
6 |
7 | import skeleton.center as sct
8 |
9 | from skeleton.center_type import CenterType
10 |
11 | from skeleton.params import get_density_weights
12 | from skeleton.utils import get_local_points
13 |
14 | import open3d as o3d
15 | from tqdm import tqdm
16 |
17 | import timeit
18 |
19 |
class SkeletonBeforeAfterVisualizer:
    """
    Context manager that shows the skeleton points before and after its ``with`` block.

    On entry, ``skl.get_skeleton_points(copy=True)`` is recorded.  On exit the
    points are read again and both sets are drawn with Open3D (before: green,
    after: blue).  With ``enable=False`` nothing is recorded or drawn.
    """

    def __init__(self, skl: sct.Centers, enable=True):
        self.skl = skl
        self.enable = enable

    def __enter__(self):
        if self.enable:
            self.before_pts = self.skl.get_skeleton_points(copy=True)
        # Return self so that ``with ... as vis`` gives access to the recorded points.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if not self.enable:
            return

        self.after_cts = self.skl.get_skeleton_points(copy=True)
        self._visualize_result()

    def _visualize_result(self):
        # Points before the block in green.
        before_pcd = o3d.geometry.PointCloud()
        before_pcd.points = o3d.utility.Vector3dVector(self.before_pts)
        before_pcd.colors = o3d.utility.Vector3dVector([[0, 0.9, 0] for p in self.before_pts])

        # Points after the block in blue.
        after_pcd = o3d.geometry.PointCloud()
        after_pcd.points = o3d.utility.Vector3dVector([p for p in self.after_cts])
        after_pcd.colors = o3d.utility.Vector3dVector([[0, 0, 0.9] for p in self.after_cts])

        o3d.visualization.draw_geometries([before_pcd, after_pcd])
48 |
49 |
class CodeTimer:
    """
    Context manager that prints the wall-clock time spent inside its ``with`` block.

    Usage::

        with CodeTimer("contract:") as timer:
            ...
        timer.elapsed  # seconds; also printed on exit
    """

    def __init__(self, desc=None):
        self.t_start = 0
        self.t_end = 0
        self.desc = desc

    @property
    def elapsed(self):
        """Seconds between entering and leaving the block (meaningful after exit)."""
        return self.t_end - self.t_start

    def __enter__(self):
        self.t_start = timeit.default_timer()
        # Return self so the timing can be read after the block.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.t_end = timeit.default_timer()
        desc = self.desc
        if desc is None:
            desc = "Time: "

        print(desc, self.elapsed)
66 |
--------------------------------------------------------------------------------
/skeleton/fit/ellipse.py:
--------------------------------------------------------------------------------
1 | from numpy.linalg import eig, inv, svd
2 | from math import atan2
3 | import numpy as np
4 |
5 |
def __fit_ellipse(x, y):
    """
    Algebraic conic fit.

    Builds the design matrix with columns ``x^2, xy, y^2, x, y, 1`` and the
    constraint matrix ``C`` (``a^T C a = 4 a0 a2 - a1^2``).  Returns the first
    left singular vector of ``inv(D^T D) @ C`` as the 6 conic coefficients
    (unit norm, arbitrary sign).
    """
    xs = x[:, np.newaxis]
    ys = y[:, np.newaxis]
    design = np.hstack((xs * xs, xs * ys, ys * ys, xs, ys, np.ones_like(xs)))
    scatter = np.dot(design.T, design)

    constraint = np.zeros([6, 6])
    constraint[0, 2] = 2
    constraint[2, 0] = 2
    constraint[1, 1] = -1

    left_vectors, _singular_values, _right_vectors = svd(np.dot(inv(scatter), constraint))
    return left_vectors[:, 0]
14 |
15 |
def ellipse_center(a):
    """
    Centre ``[x0, y0]`` of the conic ``a0 x^2 + a1 xy + a2 y^2 + a3 x + a4 y + a5 = 0``.

    Uses the half coefficients ``a1/2``, ``a3/2``, ``a4/2``.  A degenerate conic
    (``(a1/2)^2 == a0 * a2``) divides by zero.
    """
    A, C = a[0], a[2]
    B, D, F = a[1] / 2, a[3] / 2, a[4] / 2
    det = B * B - A * C
    return np.array([(C * D - B * F) / det, (A * F - B * D) / det])
22 |
23 |
def ellipse_axis_length(a):
    """
    Semi-axis lengths ``[res1, res2]`` of the conic ``a0 x^2 + a1 xy + a2 y^2 + a3 x + a4 y + a5 = 0``.

    Which of the two is the major axis depends on the coefficients
    (``fit_ellipse`` sorts them).  The textbook factor
    ``(c - a) * sqrt(1 + 4 b^2 / (a - c)^2)`` is evaluated as the equivalent
    ``sign(c - a) * sqrt((a - c)^2 + 4 b^2)``.  The old form divided by zero,
    giving NaN, for circles and 45-degree rotated ellipses (``a == c``).  The
    result order is unchanged for ``a != c``.
    """
    b, c, d, f, g, a = a[1] / 2, a[2], a[3] / 2, a[4] / 2, a[5], a[0]
    up = 2 * (a * f * f + c * d * d + g * b * b - 2 * b * d * f - a * c * g)
    root = np.sqrt((a - c) * (a - c) + 4 * b * b)
    # sign(c - a); for a == c either sign is valid (it only swaps res1/res2).
    sign = 1.0 if c >= a else -1.0
    det = b * b - a * c
    down1 = det * (sign * root - (c + a))
    down2 = det * (-sign * root - (c + a))
    res1 = np.sqrt(up / down1)
    res2 = np.sqrt(up / down2)
    return np.array([res1, res2])
36 |
37 |
def ellipse_angle_of_rotation(a):
    """
    Rotation angle of the conic ``a0 x^2 + a1 xy + a2 y^2 + a3 x + a4 y + a5 = 0``.

    Computed as ``atan2(a1, a0 - a2) / 2``, so it lies in [-pi/2, pi/2].
    """
    half_b = a[1] / 2
    return atan2(2 * half_b, a[0] - a[2]) / 2
41 |
42 |
def fit_ellipse(x, y):
    """@brief fit an ellipse to supplied data points: the 5 params
        returned are:
            M - major axis length
            m - minor axis length
            cx - ellipse centre (x coord.)
            cy - ellipse centre (y coord.)
            phi - rotation angle of ellipse bounding box, in [0, 2*pi)
        @param x first coordinate of points to fit (array)
        @param y second coord. of points to fit (array)
    """
    a = __fit_ellipse(x, y)
    centre = ellipse_center(a)
    phi = ellipse_angle_of_rotation(a)
    M, m = ellipse_axis_length(a)
    # make sure that the major axis M >= minor axis m
    if m > M:
        M, m = m, M
    # Bring the angle into [0, 2*pi).  The former
    # `phi -= 2*pi*int(phi / (2*pi))` truncates towards zero, so negative
    # angles stayed negative.
    phi = phi % (2 * np.pi)
    return [M, m, centre[0], centre[1], phi]
64 |
--------------------------------------------------------------------------------
/skeleton/params.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import time
3 |
4 | from typing import Final
5 |
# Squared distances and weight sums at or below this value are treated as zero.
# It is used to skip a point's (zero) distance to itself and to guard divisions.
SMALL_THRESHOLD: Final[float] = 1e-20
7 |
8 |
def get_density_weights(points, hd, for_center=False, center=None):
    """
    Gaussian density weights with bandwidth ``hd / 2``.

    INPUTS:
        points: Nx3 array of points, np.ndarray
        hd: density neighbourhood size, float
        for_center: if True, compute the single weight of ``center`` with respect
            to all ``points`` (a point coinciding with ``center`` is included)
        center: position used when ``for_center`` is True (default: origin)
    RETURNS:
        - for_center=False: np.array of N weights,
          ``1 + sum_j exp(-|p_i - p_j|^2 / (hd/2)^2)``, where pairs with squared
          distance <= SMALL_THRESHOLD (including p_i itself) are skipped
        - for_center=True: 0-d np.array ``1 + sum_j exp(-|center - p_j|^2 / (hd/2)^2)``
    """
    if center is None:
        center = [0, 0, 0]

    if for_center:
        offsets = points - center
        sq_dist = np.einsum('ij,ij->i', offsets, offsets)
        return np.array(1 + np.einsum('i->', np.exp((-sq_dist) / ((hd / 2) ** 2))))

    weights = []
    for point in points:
        offsets = point - points
        sq_dist = np.einsum('ij,ij->i', offsets, offsets)
        # Ignore the point itself (and exact duplicates).
        sq_dist = sq_dist[sq_dist > SMALL_THRESHOLD]
        weights.append(1 + np.einsum('i->', np.exp((-sq_dist) / ((hd / 2.0) ** 2))))

    return np.array(weights)
37 |
38 |
def get_term1(center: np.ndarray, points: np.ndarray, h: float, density_weights: np.ndarray):
    """
    Density-normalised, Gaussian-weighted mean of the points around ``center`` (term1 of the update equation).

    Each point gets the weight ``exp(-|p - center|^2 / (h/2)^2) / density_weight``.

    :param center: 1x3 center of interest, np.ndarray
    :param points: Nx3 array of all the points, np.ndarray
    :param h: size of local neighborhood, float
    :param density_weights: N density weights, one per point
    :return: the weighted mean as a length-3 array, or ``np.array(False)`` when
        the weights sum to (numerically) zero
    """
    r = points - center
    r2 = np.einsum('ij,ij->i', r, r)

    # Gaussian weights divided by the density weights.  Unlike term2 they are
    # not divided by the distance, so r2 needs no zero-guard here.
    alphas = np.exp(-r2 / ((h / 2) ** 2)) / density_weights

    denom = np.einsum('i->', alphas)
    if denom > SMALL_THRESHOLD:
        return np.einsum('j,jk->k', alphas, points) / denom

    return np.array(False)
65 |
66 |
def get_term2(center: np.ndarray, centers: np.ndarray, h: float):
    """
    Repulsion term of the update equation (term2).

    Computes the average of ``center - c_j`` weighted by
    ``exp(-|center - c_j|^2 / (h/2)^2) / |center - c_j|``, over all centers at
    non-zero distance.

    :param center: 1x3 center of interest, np.ndarray
    :param centers: Nx3 array of centers (zero-distance entries are ignored), np.ndarray
    :param h: size of local neighborhood, float
    :return: the weighted average as a length-3 array, or ``np.array(False)``
        when the weights sum to (numerically) zero
    """
    diffs = center - centers
    sq_dist = np.einsum('ij,ij->i', diffs, diffs)

    # Drop the center itself (and exact duplicates): they have no direction.
    keep = sq_dist > SMALL_THRESHOLD
    diffs = diffs[keep]
    sq_dist = sq_dist[keep]

    betas = np.exp((-sq_dist) / ((h / 2) ** 2)) / np.sqrt(sq_dist)
    total = np.einsum('i->', betas)

    if not total > SMALL_THRESHOLD:
        return np.array(False)
    return np.einsum('j,jk->k', betas, diffs) / total
96 |
97 |
def get_sigma(center, centers, local_sigmas, h, k=5):
    """
    Directionality ``sigma`` of ``center`` and the principal directions of its neighbourhood.

    A Gaussian-weighted (bandwidth ``h / 2``) covariance of the offsets to all
    other centers is eigen-decomposed, and ``sigma = largest eigenvalue / sum of
    eigenvalues``.  With ``k > 0`` the raw sigma is smoothed: it is added to the
    ``local_sigmas`` of the ``k - 1`` nearest other centers and the sum is
    divided by ``k``.

    :param center: 1x3 center of interest
    :param centers: Nx3 array of centers; entries at (numerically) zero distance,
        such as ``center`` itself, are ignored
    :param local_sigmas: N sigmas; assumed aligned with ``centers``
        (``local_sigmas[i]`` belongs to ``centers[i]``) -- TODO confirm with callers
    :param h: size of local neighborhood
    :param k: smoothing neighbourhood size; 0 disables smoothing
    :return: ``(sigma, vectors)``.  The columns of ``vectors`` are unit
        eigenvectors sorted by decreasing eigenvalue.  A non-finite sigma is
        replaced by 0.9.
    """
    r = centers - center
    r2 = np.einsum('ij,ij->i', r, r)

    # Ignore the center itself (and duplicates), remembering the original indices.
    kept = np.flatnonzero(r2 > SMALL_THRESHOLD)
    r = r[kept]
    r2 = r2[kept]

    thetas = np.exp((-r2) / ((h / 2) ** 2))

    cov = np.einsum('j,jk,jl->kl', thetas, r, r)

    values, vectors = np.linalg.eig(cov)
    # cov is real symmetric: any imaginary parts are round-off.
    values = np.real(values)
    vectors = np.real(vectors)
    # Eigenvectors are the COLUMNS, so normalise each column (the old code divided by row norms).
    vectors = vectors / np.linalg.norm(vectors, axis=0)

    # Descending eigenvalue order.
    sorted_indices = np.argsort(-values)

    sigma = np.max(values) / np.sum(values)

    if k > 0:
        # Zero-distance entries are already filtered out, so the first k-1 entries
        # are real neighbours.  Map them back to indices into centers/local_sigmas.
        # The old code indexed local_sigmas with filtered positions and also
        # skipped the nearest neighbour.
        nearest = kept[np.argsort(r2)[:k - 1]]
        knn_sigmas = np.sum(local_sigmas[nearest])
        sigma = (sigma + knn_sigmas) / k  # smoothing sigma

    vectors_sorted = vectors[:, sorted_indices]

    if not np.isfinite(sigma):
        sigma = 0.9

    return sigma, vectors_sorted
137 |
--------------------------------------------------------------------------------
/skeleton/recentering.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import skeleton.utils as utils
4 | from skeleton.center_type import CenterType
5 |
6 | from skimage.measure import EllipseModel
7 |
8 | from skeleton.fit.ellipse import fit_ellipse
9 |
10 | import open3d as o3d
11 |
12 |
def ellipse_center_skimage(projected):
    """
    Fit an ellipse to the x/y columns of ``projected`` with scikit-image's ``EllipseModel``.

    :return: ``(True, array([xc, yc]))`` on success, ``(False, None)`` when the estimation fails
    """
    planar = projected[:, [0, 1]]
    model = EllipseModel()
    fitted = model.estimate(planar)
    if not fitted:
        return False, None

    xc, yc, _axis_a, _axis_b, _theta = model.params
    return True, np.array([xc, yc])
22 |
23 |
def ellipse_center_svd(projected):
    """
    Fit an ellipse to the x/y columns of ``projected`` with the SVD-based ``fit_ellipse``.

    Always reports success.  Only the centre of the fit is returned.
    """
    _major, _minor, xc, yc, _phi = fit_ellipse(projected[:, 0], projected[:, 1])
    return True, np.array([xc, yc])
30 |
31 |
def ellipse_center(projected, algorithm='skimage'):
    """
    Return the centre of an ellipse fitted to the x/y columns of ``projected``.

    This file used to define ``ellipse_center`` twice.  The second, skimage-only
    definition silently replaced this dispatcher, so it is merged here.  The
    default stays 'skimage', the behaviour callers actually got.

    :param projected: array of points; only columns 0 and 1 are used
    :param algorithm: ``'svd'`` uses ``ellipse_center_svd``; any other value
        uses ``ellipse_center_skimage``
    :return: ``(success, array([xc, yc]))``; the skimage fit returns
        ``(False, None)`` when the estimation fails
    """
    if algorithm == 'svd':
        return ellipse_center_svd(projected)
    return ellipse_center_skimage(projected)
48 |
49 |
def visualize_result(projected, neighbors, p):
    """
    Debug view of one recentering step with Open3D.

    Shows the projected points in green, the original neighbours in blue and
    the center ``p`` in red.
    """

    def _colored_cloud(xyz, rgb):
        # One point cloud with a uniform color for every point.
        cloud = o3d.geometry.PointCloud()
        cloud.points = o3d.utility.Vector3dVector(xyz)
        cloud.colors = o3d.utility.Vector3dVector([list(rgb) for _ in xyz])
        return cloud

    geometries = [
        _colored_cloud(projected, (0, 0.9, 0)),
        _colored_cloud(list(neighbors), (0, 0, 0.9)),
        _colored_cloud([p], (0.9, 0.0, 0.0)),
    ]
    o3d.visualization.draw_geometries(geometries)
65 |
66 |
def recenter_around(center, neighbors, max_dist_move):
    """
    Move ``center`` to the centre of an ellipse fitted to its projected neighbours.

    The neighbours are projected onto the plane through ``center.center`` with
    normal ``center.normal_vector()``.  An ellipse is fitted to the x/y
    coordinates of the projected points.  Its centre is lifted back onto that
    plane by solving the plane equation for z.  The center is moved there
    unless the move is longer than ``max_dist_move`` or cannot be computed.

    :param center: object with a ``center`` position (xyz) and ``normal_vector()``
    :param neighbors: iterable of xyz points around the center
    :param max_dist_move: maximal allowed displacement
    :return: the same center object, possibly with an updated ``center``
    """
    normal = center.normal_vector()

    if np.allclose(normal, np.zeros_like(normal)):
        return center

    if not np.isfinite(normal).all():
        return center

    # The plane is parameterised over x/y, so it cannot be lifted back when it
    # is (nearly) parallel to the z axis (this would divide by normal[2] ~ 0).
    if np.isclose(normal[2], 0.0):
        return center

    p = center.center.copy()

    projected = np.array(
        [utils.project_one_point(q, p, normal) for q in neighbors if np.isfinite(q).all()])

    # A conic needs at least five points; fewer (or none) cannot be fitted.
    if len(projected) < 5:
        return center

    success, cp = ellipse_center(projected)
    if not success:
        return center

    # Lift the 2D centre onto the plane n . (X - p) = 0:
    #   z = p_z - (n_x (c_x - p_x) + n_y (c_y - p_y)) / n_z
    # (the previous code had the sign of the correction term flipped).
    nxy = normal[[0, 1]]
    pz = p[2] - np.dot(cp - p[[0, 1]], nxy) / normal[2]

    cp = np.append(cp, pz)

    move = cp - center.center
    l_move = np.linalg.norm(move)
    # The finiteness test also rejects NaN/inf from a degenerate fit (NaN > x is False).
    if not np.isfinite(l_move) or l_move > max_dist_move:
        return center

    center.center = cp
    return center
102 |
--------------------------------------------------------------------------------
/skeleton/skeletonization.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import random
4 | import sys
5 | import time
6 |
7 | import skeleton.center as sct
8 |
9 | from skeleton.center_type import CenterType
10 |
11 | from skeleton.params import get_density_weights
12 | from skeleton.utils import get_local_points
13 |
14 | from skeleton.debug import SkeletonBeforeAfterVisualizer
15 | from typing import Final
16 |
17 |
18 | def skeletonize(points, n_centers=1000,
19 | max_iterations=50,
20 | dh=2.0,
21 | sigma_smoothing_k=5,
22 | error_tolerance=1e-5,
23 | downsampling_rate=0.5,
24 | try_make_skeleton=True,
25 | recenter_knn=200,
26 | max_points=None):
27 | assert len(points) > n_centers
28 | assert len(points) > recenter_knn
29 |
30 | if max_points is not None and len(points) > max_points:
31 | print("Down sampling the original point cloud")
32 | random_indices = random.sample(range(0, len(points)), max_points)
33 | points = points[random_indices, :]
34 |
35 | # random.seed(int(time.time()))
36 | random.seed(3074)
37 |
38 | skl_centers = sct.Centers(points=points, center_count=n_centers, smoothing_k=sigma_smoothing_k)
39 |
40 | # for i in range(len(skl_centers.myCenters)):
41 | # skl_centers.myCenters[i].set_label(CenterType.BRANCH)
42 | # return skl_centers
43 |
44 | h = h0 = skl_centers.get_h0()
45 | print("h0:", h0)
46 |
47 | hd: Final[float] = h0 / 2
48 | density_weights = get_density_weights(points, hd)
49 |
50 | print("Max iterations: {}, Number points: {}, Number centers: {}".format(max_iterations, len(points),
51 | len(skl_centers.centers)))
52 |
53 | last_non_branch = len(skl_centers.centers)
54 | non_change_iters = 0
55 | for i in range(max_iterations):
56 | bridge_points = len([1 for c in skl_centers.myCenters if c.label == CenterType.BRIDGE])
57 | non_branch_points = len([1 for c in skl_centers.myCenters if c.label == CenterType.NON_BRANCH])
58 |
59 | print("\n\nIteration:{}, h:{}, bridge_points:{}\n\n".format(i, round(h, 3), bridge_points))
60 |
61 | last_error = 0
62 | for j in range(50): # magic number. do contracting at most 30 times
63 | error = skl_centers.contract(h, density_weights)
64 | skl_centers.update_properties()
65 |
66 | if np.abs(error - last_error) < error_tolerance:
67 | break
68 |
69 | last_error = error
70 |
71 | if try_make_skeleton:
72 | skl_centers.find_connections()
73 |
74 | print("Non-branch:", non_branch_points)
75 |
76 | if non_branch_points == last_non_branch:
77 | non_change_iters += 1
78 | elif non_branch_points < last_non_branch:
79 | non_change_iters = 0
80 |
81 | if non_change_iters >= 5:
82 | print("Cannot make more branch points")
83 | break
84 |
85 | if non_branch_points == 0:
86 | print("Found whole skeleton!")
87 | break
88 |
89 | last_non_branch = non_branch_points
90 |
91 | h = h + h0 / dh
92 |
93 | with SkeletonBeforeAfterVisualizer(skl_centers, enable=True):
94 | if recenter_knn > 0:
95 | skl_centers.recenter(downsampling_rate=downsampling_rate, knn=recenter_knn)
96 |
97 | return skl_centers
98 |
--------------------------------------------------------------------------------
/skeleton/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import random
3 |
4 | from skspatial.objects import Plane, Point
5 |
6 |
def unit_vector(vector):
    """Return ``vector`` scaled to unit Euclidean length.

    No special handling for a zero-length input: the division is performed as-is.
    """
    length = np.linalg.norm(vector)
    return vector / length
9 |
10 |
def get_local_points_fast(points, centers, h, max_local_points=50000):
    """Collect, for every center, the indices of ``points`` within distance ``h``.

    A cheap axis-aligned box test (half-width ``h``) is applied first; the exact
    Euclidean test ``|p - center|^2 <= h^2`` is only evaluated on the points that
    survive it.

    :param points: point cloud indexed as ``points[:, 0..2]`` (presumably an
        ``(N, 3)`` NumPy array -- confirm against callers)
    :param centers: iterable of 3-component center coordinates
    :param h: neighborhood radius
    :param max_local_points: maximum number of indices kept per center; larger
        neighborhoods are randomly subsampled with the ``random`` module
    :return: list with one NumPy index array per center, in the same order as
        ``centers``
    """
    local_indices = []
    for center in centers:
        x, y, z = center

        # 1) Cheap axis-aligned box prefilter around the center.
        in_box = ((points[:, 0] >= (x - h)) & (points[:, 0] <= (x + h)) &
                  (points[:, 1] >= (y - h)) & (points[:, 1] <= (y + h)) &
                  (points[:, 2] >= (z - h)) & (points[:, 2] <= (z + h)))
        box_indices = np.flatnonzero(in_box)

        # 2) Exact sphere test, evaluated only on the box survivors.
        sq_dist = np.sum(np.square(points[box_indices] - center), axis=1)
        sphere_indices = box_indices[sq_dist <= h ** 2]

        # Cap each neighborhood on its own so the output stays aligned
        # one-to-one with ``centers`` (sampling the list of neighborhoods
        # would drop and reorder centers).
        if len(sphere_indices) > max_local_points:
            keep = sorted(random.sample(range(len(sphere_indices)), max_local_points))
            sphere_indices = sphere_indices[keep]

        local_indices.append(sphere_indices)

    return local_indices
35 |
36 |
def get_local_points(kdt, centers, h, max_local_points=50000):
    """Collect, for every center, neighbor indices from a KD-tree radius search.

    :param kdt: KD-tree exposing ``search_radius_vector_3d(query, radius=...)``
        that returns ``(count, indices, distances)`` (presumably an Open3D
        ``KDTreeFlann`` -- confirm against callers)
    :param centers: iterable of query positions
    :param h: search radius
    :param max_local_points: maximum number of indices kept per center; larger
        neighborhoods are randomly subsampled with the ``random`` module
    :return: list with one list of point indices per center, in the same order
        as ``centers``
    """
    local_indices = []
    for center in centers:
        _, idx, _ = kdt.search_radius_vector_3d(center, radius=h)

        # NOTE(review): the first hit is dropped, presumably because it is
        # expected to be the query point itself; if a center does not coincide
        # with a cloud point this discards a genuine neighbor -- confirm.
        indices = list(idx[1:])

        # Cap each neighborhood on its own and keep iterating, so every center
        # gets exactly one entry and the output stays aligned with ``centers``.
        if len(indices) > max_local_points:
            indices = random.sample(indices, max_local_points)

        local_indices.append(indices)

    return local_indices
52 |
53 |
def project_one_point(q, p, n):
    """Project point ``q`` onto the plane through ``p`` with normal ``n``.

    :param q: the point to project
    :param p: any point lying on the plane
    :param n: the plane's normal vector
    :return: the result of ``skspatial.objects.Plane.project_point`` for ``q``
    """
    target_plane = Plane(point=p, normal=n)
    return target_plane.project_point(Point(q))
64 |
65 |
def plane_dist(q, p, n):
    """Signed distance from point ``q`` to the plane through ``p`` with normal ``n``.

    The sign is positive on the side ``n`` points towards. ``n`` need not be
    unit length, because the dot product is divided by ``|n|``.

    :param q: the query point
    :param p: any point lying on the plane
    :param n: the plane's normal vector
    :return: the signed distance, or the sentinel ``0x7fffffff`` when the result
        is not finite (e.g. a zero-length normal)
    """
    signed = np.dot(q - p, n) / np.linalg.norm(n)
    if np.all(np.isfinite(signed)):
        return signed
    return 0x7fffffff
79 |
--------------------------------------------------------------------------------