├── .github
└── workflows
│ └── python-app.yml
├── .gitignore
├── CHANGELOG.rst
├── LICENSE
├── README.md
├── images
├── apple-pcd.png
├── apple-skeleton.png
├── cherry-pcd.png
├── cherry-skeleton.png
├── pine-pcd.png
└── pine-skeleton.png
├── pyproject.toml
├── synthetic_trees
├── .gitignore
├── __init__.py
├── data_types
│ ├── __init__.py
│ ├── branch.py
│ ├── cloud.py
│ ├── tree.py
│ └── tube.py
├── download.py
├── evaluate.py
├── evaluation
│ ├── __init__.py
│ ├── metrics.py
│ └── results.py
├── extract_cloud.py
├── process_results.py
├── scripts
│ └── batch_rotate.py
├── test_contraction.py
├── test_contraction2.py
├── util
│ ├── __init__.py
│ ├── file.py
│ ├── math.py
│ ├── misc.py
│ ├── o3d_abstractions.py
│ ├── operations.py
│ └── queries.py
├── view.py
└── view_clouds.py
└── tests
└── test_skeleton.py
/.github/workflows/python-app.yml:
--------------------------------------------------------------------------------
# This workflow installs the package and its Python dependencies with a single version of Python
2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
3 |
4 | name: Python application
5 |
6 | on:
7 | push:
8 | branches: [ "main" ]
9 | pull_request:
10 | branches: [ "main" ]
11 |
12 | permissions:
13 | contents: read
14 |
15 | jobs:
16 | build:
17 |
18 | runs-on: ubuntu-latest
19 |
20 | steps:
21 | - uses: actions/checkout@v3
22 | - name: Set up Python 3.8.16
23 | uses: actions/setup-python@v3
24 | with:
25 | python-version: "3.8.16"
26 | - name: Install dependencies
27 | run: |
28 | python -m pip install --upgrade pip
29 | pip install .
30 |
31 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | dataset/
132 | outputs/
133 |
134 | *.csv
135 | dataset.zip
136 | data/
--------------------------------------------------------------------------------
/CHANGELOG.rst:
--------------------------------------------------------------------------------
1 | =========
2 | Changelog
3 | =========
4 |
5 | Version 0.1
6 | ===========
7 |
8 | -
9 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 UC Vision
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | #
🌳🌲🌴 Synthetic-Trees 🌴🌲🌳
2 |
3 | ## 📝 Description
4 |
This repository offers a synthetic point cloud dataset with ground truth skeletons of multiple species. Our library enables you to open, visualize, and assess the accuracy of the skeletons. To gain insight into how we created the data and the evaluation metrics employed, please refer to our published paper, available at this link. Our dataset consists of two parts: one with point clouds that feature foliage, which is particularly useful for training models that can handle real-world data that includes leaves; the other contains only the branching structure and is less affected by occlusion.
6 |
7 |
8 |
9 |
10 |  |
11 |  |
12 |  |
13 |
14 |
15 | Sapling Cherry Point Cloud. |
16 | Apple Tree Point Cloud. |
17 | Pine Tree Point Cloud. |
18 |
19 |
20 |
21 |  |
22 |  |
23 |  |
24 |
25 |
26 | Sapling Cherry Ground Truth Skeleton. |
27 | Apple Tree Ground Truth Skeleton. |
28 | Pine Tree Ground Truth Skeleton. |
29 |
30 |
31 |
32 |
33 | ## 🔍 Usage
34 |
35 | #### 💾 Downloading
36 |
37 | You can download the data by following this link. The dataset includes synthetic point clouds and ground truth skeletons, along with a JSON file that specifies the training, validation, and test sets. For evaluation purposes, we have provided "cleaned" point clouds and skeletons in the evaluation folder, which are suitable for assessment.
38 |
39 |
40 | #### 🖥️ Installation
41 | To install:
42 | Create a conda environment:
43 |
44 | `conda create -n synthetic-trees python=3.8`
45 |
46 | then:
47 |
48 | `pip install .`
49 |
50 | #### 🕵️♂️ Visualizing
To visualize the data, use the `view.py` script. You can call it using either:
52 |
53 | ``` view-synthetic-trees -p=file_path -lw=linewidth ```
54 | or
55 | ``` view-synthetic-trees -d=directory -lw=linewidth ```
56 |
57 | where:
58 | - `file_path` is the path of the `.npz` file of a single tree.
59 | - `directory` is the path of the folder containing `.npz` files.
60 | - `linewidth` is the width of the skeleton lines in the visualizer.
61 |
62 | #### 📊 Evaluation
63 | To evaluate your method against the ground truth data, use the `evaluate.py` script. You can call it using:
64 |
65 | ``` evaluate-synthetic-trees -d_gt=ground_truth_directory -d_o=output_directory -r_o=results_save_path ```
66 |
67 | where:
68 | - `ground_truth_directory` is the directory of the ground truth `.npz` files.
69 | - `output_directory` is the directory of the folder containing your skeleton outputs (in `.ply` format).
70 | - `results_save_path` is the path of the `.csv` file to save your results to.
71 |
72 | #### 📋 Processing Results
73 | After running the evaluation, you can use the `process_results.py` script to post-process the raw results and obtain metrics across the dataset. Call it using:
74 |
75 | ``` process-synthetic-trees-results -p=path ```
76 |
77 | where: `path` is the path of the results `.csv` file from the evaluation step.
78 |
79 | ## 📜 Citation
80 | Please use the following BibTeX entry to cite our work:
81 | ```
82 | @inproceedings{dobbs2023smart,
83 | title={Smart-Tree: Neural Medial Axis Approximation of Point Clouds for 3D Tree Skeletonization},
84 | author={Dobbs, Harry and Batchelor, Oliver and Green, Richard and Atlas, James},
85 | booktitle={Iberian Conference on Pattern Recognition and Image Analysis},
86 | pages={351--362},
87 | year={2023},
88 | organization={Springer}
89 | }
90 |
91 | ```
92 |
93 | ## 📥 Contact
94 |
95 | Should you have any questions, comments or suggestions please use the following contact details:
96 | harry.dobbs@pg.canterbury.ac.nz
97 |
--------------------------------------------------------------------------------
/images/apple-pcd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uc-vision/synthetic-trees/2e3277b942bbb789b4056eee26e07ae6c6215af8/images/apple-pcd.png
--------------------------------------------------------------------------------
/images/apple-skeleton.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uc-vision/synthetic-trees/2e3277b942bbb789b4056eee26e07ae6c6215af8/images/apple-skeleton.png
--------------------------------------------------------------------------------
/images/cherry-pcd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uc-vision/synthetic-trees/2e3277b942bbb789b4056eee26e07ae6c6215af8/images/cherry-pcd.png
--------------------------------------------------------------------------------
/images/cherry-skeleton.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uc-vision/synthetic-trees/2e3277b942bbb789b4056eee26e07ae6c6215af8/images/cherry-skeleton.png
--------------------------------------------------------------------------------
/images/pine-pcd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uc-vision/synthetic-trees/2e3277b942bbb789b4056eee26e07ae6c6215af8/images/pine-pcd.png
--------------------------------------------------------------------------------
/images/pine-skeleton.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uc-vision/synthetic-trees/2e3277b942bbb789b4056eee26e07ae6c6215af8/images/pine-skeleton.png
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools", "setuptools-scm"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "synthetic_trees"
7 | authors = [
8 | {name = "Harry Dobbs", email = "harrydobbs87@gmail.com"},
9 | ]
10 | description = "Tools for synthetic tree point cloud dataset."
readme = "README.md"
12 | requires-python = ">=3.7"
13 | license = {text = "MIT"}
14 | dependencies = [
15 | 'numpy',
16 | 'open3d==0.17',
17 | 'pykeops',
18 | 'torch']
19 | dynamic = ["version"]
20 |
21 | [tool.setuptools.packages]
22 | find = {} # Scan the project directory with the default parameters
23 |
24 | [project.scripts]
25 | view-synthetic-trees = "synthetic_trees.view:main"
26 | test-contraction = "synthetic_trees.test_contraction:main"
27 |
28 | download-synthetic-trees = "synthetic_trees.download:main"
29 | evaluate-synthetic-trees = "synthetic_trees.evaluate:main"
30 | process-synthetic-trees-results = "synthetic_trees.process_results:main"
31 | view-pointclouds = "synthetic_trees.view_clouds:main"
32 |
33 |
34 |
--------------------------------------------------------------------------------
/synthetic_trees/.gitignore:
--------------------------------------------------------------------------------
1 | # Temporary and binary files
2 | *~
3 | *.py[cod]
4 | *.so
5 | *.cfg
6 | !.isort.cfg
7 | !setup.cfg
8 | *.orig
9 | *.log
10 | *.pot
11 | __pycache__/*
12 | .cache/*
13 | .*.swp
14 | */.ipynb_checkpoints/*
15 | .DS_Store
16 |
17 | # Project files
18 | .ropeproject
19 | .project
20 | .pydevproject
21 | .settings
22 | .idea
23 | .vscode
24 | tags
25 |
26 | # Package files
27 | *.egg
28 | *.eggs/
29 | .installed.cfg
30 | *.egg-info
31 |
32 | # Unittest and coverage
33 | htmlcov/*
34 | .coverage
35 | .coverage.*
36 | .tox
37 | junit*.xml
38 | coverage.xml
39 | .pytest_cache/
40 |
41 | # Build and docs folder/files
42 | build/*
43 | dist/*
44 | sdist/*
45 | docs/api/*
46 | docs/_rst/*
47 | docs/_build/*
48 | cover/*
49 | MANIFEST
50 |
51 | # Per-project virtualenvs
52 | .venv*/
53 | .conda*/
54 | .python-version
55 |
--------------------------------------------------------------------------------
/synthetic_trees/__init__.py:
--------------------------------------------------------------------------------
import sys

# importlib.metadata ships with Python 3.8+; older interpreters need the
# importlib_metadata backport package instead.
if sys.version_info >= (3, 8):
    from importlib.metadata import PackageNotFoundError, version  # pragma: no cover
else:
    from importlib_metadata import PackageNotFoundError, version  # pragma: no cover

# Distribution name is assumed to equal the package name; change `dist_name`
# here if the project is ever renamed.
dist_name = __name__
try:
    __version__ = version(dist_name)
except PackageNotFoundError:  # pragma: no cover
    # Package is not installed (e.g. running from a source checkout).
    __version__ = "unknown"
finally:
    # Keep the module namespace clean.
    del version, PackageNotFoundError
17 |
--------------------------------------------------------------------------------
/synthetic_trees/data_types/__init__.py:
--------------------------------------------------------------------------------
import sys

# importlib.metadata ships with Python 3.8+; older interpreters need the
# importlib_metadata backport package instead.
if sys.version_info >= (3, 8):
    from importlib.metadata import PackageNotFoundError, version  # pragma: no cover
else:
    from importlib_metadata import PackageNotFoundError, version  # pragma: no cover

# Distribution name is assumed to equal the package name; change `dist_name`
# here if the project is ever renamed.
dist_name = __name__
try:
    __version__ = version(dist_name)
except PackageNotFoundError:  # pragma: no cover
    # No distribution with this (sub)package name -- fall back gracefully.
    __version__ = "unknown"
finally:
    # Keep the module namespace clean.
    del version, PackageNotFoundError
17 |
--------------------------------------------------------------------------------
/synthetic_trees/data_types/branch.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | import open3d as o3d
4 |
5 | from dataclasses import dataclass
6 | from typing import List, Dict
7 |
8 | from .tube import Tube
9 | from ..util.o3d_abstractions import o3d_path, o3d_tube_mesh
10 |
11 | from ..util.queries import pts_on_nearest_tube
12 |
13 |
@dataclass
class BranchSkeleton:
    """Polyline skeleton of a single branch: an ordered run of points with a
    per-point radius, linked to parent/child branches by id."""
    _id: int
    parent_id: int       # id of the parent branch (-1 when there is none)
    xyz: np.ndarray      # (N, 3) ordered points along the branch
    radii: np.ndarray    # (N,) radius at each point
    child_id: int = -1

    @property
    def length(self):
        """Total polyline length: sum of consecutive point-to-point distances."""
        return np.sum(np.sqrt(np.sum(np.diff(self.xyz, axis=0)**2, axis=1)))

    def __len__(self):
        return self.xyz.shape[0]

    def __str__(self):
        # BUG FIX: previously interpolated the raw xyz/radii arrays into the
        # message instead of their element counts.
        return (f"Branch {self._id} with {self.xyz.shape[0]} points "
                f"and {self.radii.shape[0]} radii")

    def to_tubes(self) -> List[Tube]:
        """Split the polyline into conical tube segments between consecutive points."""
        a_, b_, r1_, r2_ = (
            self.xyz[:-1], self.xyz[1:], self.radii[:-1], self.radii[1:])

        return [Tube(a, b, r1, r2) for a, b, r1, r2 in zip(a_, b_, r1_, r2_)]

    def closest_pt(self, pt: np.ndarray):
        """Closest point (and interpolated radius) on this branch to the query `pt`."""
        return pts_on_nearest_tube(pt, self.to_tubes())

    def to_o3d_lineset(self, colour=(0, 0, 0)) -> "o3d.geometry.LineSet":
        # Quoted annotation: `o3d.cuda.pybind.geometry.LineSet` only exists on
        # CUDA builds of open3d, and evaluating it at class-creation time
        # breaks CPU-only installs.
        return o3d_path(self.xyz, colour)

    def to_o3d_tube(self):
        """Triangle-mesh tube swept along the branch polyline."""
        return o3d_tube_mesh(self.xyz, self.radii)
47 |
--------------------------------------------------------------------------------
/synthetic_trees/data_types/cloud.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import torch
4 |
5 | from dataclasses import dataclass
6 | from ..util.o3d_abstractions import o3d_cloud, o3d_lines_between_clouds
7 |
8 |
@dataclass
class Cloud:
    """Point cloud with optional per-point class labels and medial-axis vectors.

    Arrays start out as numpy; `to_device` converts them in place to torch
    tensors on the requested device.
    """
    xyz: np.ndarray                   # (N, 3) point positions
    rgb: np.ndarray                   # (N, 3) point colours
    class_l: np.ndarray = None        # (N,) integer class labels, optional
    medial_vector: np.ndarray = None  # (N, 3) offset to the medial axis, optional

    def __len__(self):
        return self.xyz.shape[0]

    def __str__(self):
        return f"Cloud with {self.xyz.shape[0]} points "

    def to_o3d_cloud(self):
        """Open3D cloud coloured with `rgb`."""
        return o3d_cloud(self.xyz, colours=self.rgb)

    def to_o3d_cloud_labelled(self, cmap=None):
        """Open3D cloud coloured by class label via `cmap` (random colours if omitted)."""
        if cmap is None:
            cmap = np.random.rand(self.number_classes, 3)

        return o3d_cloud(self.xyz, colours=cmap[self.class_l])

    def to_o3d_medial_vectors(self, cmap=None):
        """Line set joining each point to its medial-axis projection.

        NOTE(review): `cmap` is accepted but unused -- kept for interface
        compatibility with the other `to_o3d_*` methods.
        """
        medial_cloud = o3d_cloud(self.xyz + self.medial_vector)
        return o3d_lines_between_clouds(self.to_o3d_cloud(), medial_cloud)

    def to_device(self, device):
        """Move every non-None array attribute to `device`, converting numpy
        arrays to torch tensors on first use. Mutates the instance in place."""

        def _move(arr):
            # None means "attribute not present" -- leave it alone.
            if arr is None:
                return None
            if isinstance(arr, np.ndarray):
                return torch.from_numpy(arr).to(device)
            return arr.to(device)

        self.xyz = _move(self.xyz)
        self.rgb = _move(self.rgb)
        self.class_l = _move(self.class_l)
        self.medial_vector = _move(self.medial_vector)

    @property
    def number_classes(self):
        """Number of distinct classes, assuming labels run 0..max contiguously.
        Requires `class_l` to already be a torch tensor."""
        return torch.max(self.class_l) + 1
67 |
--------------------------------------------------------------------------------
/synthetic_trees/data_types/tree.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import open3d as o3d
3 |
4 | from copy import deepcopy
5 | from dataclasses import dataclass
6 | from typing import List, Dict
7 |
8 | from ..util.o3d_abstractions import o3d_merge_linesets, o3d_merge_meshes
9 | from ..util.misc import flatten_list
10 |
11 | from ..util.operations import sample_tubes
12 | from .branch import BranchSkeleton
13 | from .tube import Tube
14 |
15 |
@dataclass
class TreeSkeleton:
    """A whole-tree skeleton: `BranchSkeleton`s indexed by branch id."""
    _id: int
    branches: Dict[int, BranchSkeleton]

    def __len__(self):
        return len(self.branches)

    def __str__(self):
        return (f"Tree Skeleton ({self._id}) has {len(self)} branches...")

    def to_tubes(self) -> List[Tube]:
        """Flatten every branch into a single list of conical tube segments."""
        return flatten_list([branch.to_tubes() for branch in self.branches.values()])

    # Quoted open3d annotations below: `o3d.cuda.pybind.geometry.*` only
    # exists on CUDA builds of open3d, and evaluating it at function-definition
    # time breaks CPU-only installs.
    def to_o3d_tubes(self) -> "o3d.geometry.TriangleMesh":
        """Merged triangle mesh of all branch tubes."""
        return o3d_merge_meshes([branch.to_o3d_tube() for branch in self.branches.values()])

    def to_o3d_lineset(self) -> "o3d.geometry.LineSet":
        """Merged line set of all branch polylines."""
        return o3d_merge_linesets([branch.to_o3d_lineset() for branch in self.branches.values()])

    def point_sample(self, sample_rate=0.01) -> "o3d.geometry.PointCloud":
        """Sample points (and radii) along every tube at `sample_rate` spacing."""
        return sample_tubes(self.to_tubes(), sample_rate)
38 |
39 |
def repair_skeleton(skeleton: TreeSkeleton):
    """ By default the skeletons are not connected between branches. This
    connects each branch to its parent by prepending the closest point
    (radius-relative) on the parent branch. Works on a deep copy, so the
    input skeleton is left untouched.
    """
    fixed = deepcopy(skeleton)

    for branch in list(fixed.branches.values()):

        # Roots (and branches whose parent is the placeholder id 0) have
        # nothing to attach to.
        if branch.parent_id == -1 or branch.parent_id == 0:
            continue

        parent = fixed.branches[branch.parent_id]

        # Closest point on the parent to this branch's first vertex.
        join_pt, join_rad = parent.closest_pt(pt=branch.xyz[[0]])

        # Prepend the join point so the polylines physically connect.
        branch.xyz = np.insert(branch.xyz, 0, join_pt, axis=0)
        branch.radii = np.insert(branch.radii, 0, join_rad, axis=0)

    return fixed
62 |
63 |
def prune_skeleton(skeleton: TreeSkeleton, min_radius_threshold=0.01, length_threshold=0.02, root_id=1):
    """ Drop branches that are too short or too thin, together with all their
    descendants.

    Each branch knows only its parent, not its children, so descendants are
    removed implicitly: a branch is kept only if its parent was kept. Branches
    are assumed to be ordered so that parents always precede children (lower
    order first), which makes a single forward pass sufficient.

    min_radius_threshold: the branch's first radius must exceed this.
        NOTE(review): only radii[0] is checked, not "some point of the
        branch" as originally documented -- confirm intent.
    length_threshold: the branch's total polyline length must exceed this.
    root_id: id of the root branch; it is always kept regardless of thresholds.
    """
    branches_to_keep = {root_id: skeleton.branches[root_id]}

    for branch_id, branch in skeleton.branches.items():

        # Root-level branches have no parent to attach to; the root itself is
        # already in `branches_to_keep`.
        if branch.parent_id == -1:
            continue

        # Keep only if the parent survived AND this branch passes both tests;
        # a pruned parent therefore prunes the whole subtree.
        if branch.parent_id in branches_to_keep:
            if branch.length > length_threshold and branch.radii[0] > min_radius_threshold:
                branches_to_keep[branch_id] = branch

    return TreeSkeleton(skeleton._id, branches_to_keep)
86 |
--------------------------------------------------------------------------------
/synthetic_trees/data_types/tube.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from typing import List
4 | from dataclasses import dataclass
5 |
6 |
@dataclass
class Tube:
    """A conical frustum segment: the line from `a` to `b` whose radius
    tapers from `r1` to `r2`."""
    a: np.ndarray  # Start Point (3,)
    b: np.ndarray  # End Point (3,)
    r1: float  # Start Radius
    r2: float  # End Radius
13 |
14 |
@dataclass
class CollatedTube:
    """A batch of N tubes packed into arrays (built by `collate_tubes`)."""
    a: np.ndarray  # Nx3 start points
    b: np.ndarray  # Nx3 end points
    r1: np.ndarray  # 1xN start radii
    r2: np.ndarray  # 1xN end radii
21 |
22 |
def collate_tubes(tubes: List[Tube]) -> CollatedTube:
    """Pack a list of tubes into one `CollatedTube` of batched arrays."""

    # Endpoints become (N, 3) arrays.
    starts = np.concatenate([tube.a for tube in tubes]).reshape(-1, 3)
    ends = np.concatenate([tube.b for tube in tubes]).reshape(-1, 3)

    # Radii become (1, N) row vectors.
    start_radii = np.asarray([tube.r1 for tube in tubes]).reshape(1, -1)
    end_radii = np.asarray([tube.r2 for tube in tubes]).reshape(1, -1)

    return CollatedTube(starts, ends, start_radii, end_radii)
32 |
--------------------------------------------------------------------------------
/synthetic_trees/download.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import argparse
4 |
5 |
def parse_args():
    """Command-line options for the dataset downloader."""

    parser = argparse.ArgumentParser(description="Downloader Arguments")

    parser.add_argument(
        "-d",
        "--directory",
        help="Tree download directory.",
        required=False,
        default="dataset",
        type=str,
    )

    return parser.parse_args()
16 |
17 |
def main():
    """Ensure the download destination directory exists."""

    args = parse_args()

    # Create the destination if it is missing.
    if not os.path.isdir(args.directory):
        os.mkdir(args.directory)

    # TODO: actually download the dataset into args.directory.
26 |
27 |
28 |
29 |
30 |
31 |
32 | if __name__ == "__main__":
33 | main()
34 |
--------------------------------------------------------------------------------
/synthetic_trees/evaluate.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 |
4 | from typing import List, Tuple
5 | from pathlib import Path
6 | import argparse
7 |
8 | import torch
9 | import open3d as o3d
10 |
11 | from tqdm import tqdm
12 |
13 | from data_types.tree import TreeSkeleton, repair_skeleton
14 | from data_types.cloud import Cloud
15 |
16 | from util.file import load_data_npz
17 | from util.o3d_abstractions import o3d_load_lineset
18 | from evaluation.results import save_results
19 |
20 | from util.operations import sample_o3d_lineset
21 | from util.misc import to_torch
22 | from evaluation.metrics import recall, precision
23 |
24 |
def evaluate_one(gt_skeleton: TreeSkeleton, output_skeleton: o3d.cuda.pybind.geometry.LineSet, thresholds=np.linspace(0,1,100), sample_rate=0.001):
    """Score one output skeleton against its ground truth.

    Both skeletons are densely point-sampled at `sample_rate`, moved to the
    GPU, and compared with radius-scaled distance thresholds. Returns a dict
    with parallel "recall", "precision" and "thresholds" entries.
    """
    # Connect branches to their parents before sampling the ground truth.
    repaired = repair_skeleton(gt_skeleton)
    gt_xyzs, gt_radii = repaired.point_sample(sample_rate)

    output_pts = sample_o3d_lineset(output_skeleton, sample_rate)

    gt_xyzs_c, gt_radii_c, output_pts_c = to_torch(
        [gt_xyzs, gt_radii, output_pts], device=torch.device("cuda"))
    radii_flat = gt_radii_c.reshape(-1)

    return {
        "recall": recall(gt_xyzs_c, output_pts_c, radii_flat, thresholds=thresholds),
        "precision": precision(gt_xyzs_c, output_pts_c, radii_flat, thresholds=thresholds),
        "thresholds": thresholds,
    }
44 |
45 |
def gt_skeleton_generator(paths: List[Path]) -> Tuple[Cloud, TreeSkeleton]:
    """Yield the ground-truth skeleton (second element of the npz pair) per file."""
    yield from (load_data_npz(path)[1] for path in paths)
49 |
def output_skeleton_generator(paths: List[Path]) -> o3d.cuda.pybind.geometry.LineSet:
    """Yield each output skeleton loaded as an Open3D line set."""
    yield from (o3d_load_lineset(str(path)) for path in paths)
53 |
54 |
def parse_args():
    """Command-line options for the evaluation script."""

    parser = argparse.ArgumentParser(description="Visualizer Arguments")

    parser.add_argument("-d_gt", "--ground_truth_dir", default=None, type=str,
                        help="Directory of folder of tree.npz(s) *.npz")

    parser.add_argument("-d_o", "--output_dir", default=None, type=str,
                        help="Directory of folder of skeleton outputs *.ply")

    parser.add_argument("-r_o", "--results_save_path", default="results.csv",
                        type=str, required=False,
                        help="Path to save results csv to")

    return parser.parse_args()
69 |
70 |
def main():
    """Evaluate every output skeleton against its ground-truth counterpart.

    Ground-truth .npz files and output .ply files are matched by filename
    stem; per-tree recall/precision curves are written to a csv.
    """
    args = parse_args()

    ground_truth_paths = list(Path(args.ground_truth_dir).glob("*.npz"))
    output_paths = list(Path(args.output_dir).glob("*.ply"))

    common_names = set(p.stem for p in ground_truth_paths) & set(p.stem for p in output_paths)

    gt_paths = sorted([p for p in ground_truth_paths if p.stem in common_names])
    output_paths = sorted([p for p in output_paths if p.stem in common_names])

    # BUG FIX: tree_names previously came from an unordered set intersection
    # while both path lists were sorted, so zip() below could record results
    # under the wrong tree name. Derive the names from the sorted gt paths so
    # all three sequences stay aligned.
    tree_names = [p.stem for p in gt_paths]

    results = {}

    pairs = zip(gt_skeleton_generator(gt_paths), output_skeleton_generator(output_paths), tree_names)
    for gt_skeleton, output_skeleton, tree_name in tqdm(pairs):

        results[f"{tree_name}"] = evaluate_one(gt_skeleton, output_skeleton, sample_rate=0.001)

    save_results(results, args.results_save_path)
93 |
94 | if __name__ == "__main__":
95 | main()
--------------------------------------------------------------------------------
/synthetic_trees/evaluation/__init__.py:
--------------------------------------------------------------------------------
import sys

# importlib.metadata ships with Python 3.8+; older interpreters need the
# importlib_metadata backport package instead.
if sys.version_info >= (3, 8):
    from importlib.metadata import PackageNotFoundError, version  # pragma: no cover
else:
    from importlib_metadata import PackageNotFoundError, version  # pragma: no cover

# Distribution name is assumed to equal the package name; change `dist_name`
# here if the project is ever renamed.
dist_name = __name__
try:
    __version__ = version(dist_name)
except PackageNotFoundError:  # pragma: no cover
    # No distribution with this (sub)package name -- fall back gracefully.
    __version__ = "unknown"
finally:
    # Keep the module namespace clean.
    del version, PackageNotFoundError
17 |
--------------------------------------------------------------------------------
/synthetic_trees/evaluation/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import frnn
3 |
4 | from util.queries import nn_frnn, nn_keops
5 |
6 |
def recall(gt_points, test_points, gt_radii, thresholds=[0.1]):  # recall (completeness)
    """Recall (completeness): percentage of ground-truth samples that have an
    output point within `t * radius`, for each threshold `t`.

    gt_points: (N, 3) ground-truth samples; gt_radii: (N,) matching radii.
    test_points: (M, 3) output-skeleton samples.
    Returns a list of percentages, one per threshold.
    """
    # Distance from each gt point to its nearest output point. The neighbour
    # index was previously computed and reshaped but never used -- dropped.
    dist, _ = nn_keops(gt_points, test_points)
    dist = dist.reshape(-1)

    results = []
    for t in thresholds:

        # A gt sample counts as recovered if an output point lies within
        # t times its local radius.
        mask = dist < (gt_radii * t)

        valid_percentage = (torch.sum(mask) / gt_points.shape[0]) * 100

        results.append(valid_percentage.item())

    return results
23 |
24 |
def precision(gt_points, test_points, gt_radii, thresholds=[0.1]):  # precision (how close)
    """Precision: percentage of output samples lying within `t * radius` of
    their nearest ground-truth sample, for each threshold `t`.

    Returns a list of percentages, one per threshold.
    """
    # Nearest ground-truth sample (and its index, used to look up the local
    # radius) for every output point.
    dist, idx = nn_keops(test_points, gt_points)
    dist, idx = dist.reshape(-1), idx.reshape(-1)

    percentages = []
    for t in thresholds:

        within = dist < (gt_radii[idx] * t)

        pct = (torch.sum(within) / test_points.shape[0]) * 100
        percentages.append(pct.item())

    return percentages
41 |
42 |
43 | # def recall(gt_points, test_points, gt_radii, threshold=1.0):
44 | # idxs, dists, _ = nn(gt_points, test_points, r=gt_radii.max().item())
45 | # valid_idx = idxs[idxs != -1]
46 | # valid = (dists[valid_idx] < gt_radii[valid_idx] * threshold)
47 | # return (torch.sum(valid).cpu().item() / gt_points.shape[0]) * 100
48 |
49 |
50 | # def precision(test_points, gt_points, gt_radii, threshold=1.0):
51 | # idxs, dists, _ = nn(test_points, gt_points, r=gt_radii.max().item())
52 | # valid_idx = idxs[idxs != -1]
53 | # valid = dists[valid_idx] < gt_radii[valid_idx] * threshold
54 |
55 | # print(torch.sum(torch.isnan(valid.cpu())))
56 | # return (torch.sum(valid.cpu()).item() / test_points.shape[0]) * 100
--------------------------------------------------------------------------------
/synthetic_trees/evaluation/results.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 |
3 | from typing import Dict
4 |
5 |
def save_results(results_dict: Dict, save_path: str = "results.csv"):
    """Flatten per-tree threshold curves into a csv at `save_path`.

    results_dict maps a tree name (e.g. "apple_3") to a dict with parallel
    "thresholds", "recall" and "precision" sequences. The species column is
    the part of the tree name before the first underscore.
    """
    rows = []

    for tree_name, results in results_dict.items():

        species = tree_name.split("_")[0]

        for threshold, recall, precision in zip(results["thresholds"], results["recall"], results["precision"]):

            rows.append([tree_name, species, threshold, recall, precision])

    df = pd.DataFrame(rows, columns=["tree_name", "species", "threshold", "recall", "precision"])
    # `save_path` was previously wrapped in a redundant f-string.
    df.to_csv(save_path, index=False)
20 |
--------------------------------------------------------------------------------
/synthetic_trees/extract_cloud.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from pathlib import Path
3 |
4 | from synthetic_trees.util.file import load_data_npz
5 | import open3d as o3d
6 |
7 |
8 |
9 |
def parse_args():
    """Command-line options: input .npz path and output cloud path."""
    parser = argparse.ArgumentParser(description="Visualizer Arguments")

    parser.add_argument("file_path", type=Path,
                        help="File Path of tree.npz")
    parser.add_argument("output_file", type=Path,
                        help="File path to write cloud")

    return parser.parse_args()
18 |
19 |
20 |
def main():
    """Export the point cloud of a tree .npz file to a cloud file on disk."""
    args = parse_args()
    assert args.file_path.exists(), f"File {args.file_path} does not exist"

    cloud, _skeleton = load_data_npz(args.file_path)
    o3d.t.io.write_point_cloud(str(args.output_file), cloud.to_tensor_cloud())


if __name__ == "__main__":
    main()
33 |
--------------------------------------------------------------------------------
/synthetic_trees/process_results.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pandas as pd
3 |
4 |
5 | import argparse
6 |
7 | import matplotlib.pyplot as plt
8 |
9 | from util.math import calculate_AuC
10 |
11 |
def parse_args():
    """Parse the command-line arguments: path to a results CSV."""
    parser = argparse.ArgumentParser(description="Results directory.")

    # NOTE(review): the default is dead code because required=True — confirm intent.
    parser.add_argument(
        "-p",
        "--path",
        help="Results path.",
        required=True,
        default="dataset",
        type=str,
    )

    return parser.parse_args()
23 |
24 |
def main():
    """Summarise an evaluation CSV: print per-metric AUCs and plot the means."""
    args = parse_args()

    df = pd.read_csv(args.path)

    # F1 is the harmonic mean of precision and recall; 0/0 becomes 0, not NaN.
    df["f1"] = (2 * ((df["recall"] * df["precision"]) / (df["recall"] + df["precision"]))).fillna(0)

    df = df.round(2)
    df = df.groupby("threshold").agg("mean", numeric_only=True)

    print(f"F1 AUC: {calculate_AuC(df['f1'])}")
    print(f"Recall AUC: {calculate_AuC(df['recall'])}")
    print(f"Precision AUC: {calculate_AuC(df['precision'])}")

    # Plot the mean metrics against the threshold axis.
    axes = df.plot()
    axes.set_ylim(0, 100)
    axes.set_xlim(0, 1)

    plt.show()


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/synthetic_trees/scripts/batch_rotate.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uc-vision/synthetic-trees/2e3277b942bbb789b4056eee26e07ae6c6215af8/synthetic_trees/scripts/batch_rotate.py
--------------------------------------------------------------------------------
/synthetic_trees/test_contraction.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from dataclasses import asdict, dataclass
3 |
4 | from pathlib import Path
5 | import torch
6 | from synthetic_trees.data_types.tree import repair_skeleton, prune_skeleton
7 |
8 | from synthetic_trees.data_types.tube import collate_tubes
9 | from synthetic_trees.util.file import load_data_npz
10 |
11 | import geometry_grid.torch_geometry as torch_geom
12 | from geometry_grid.taichi_geometry.grid import Grid, morton_sort
13 | from geometry_grid.taichi_geometry import Tube
14 |
15 | from geometry_grid.taichi_geometry.dynamic_grid import DynamicGrid
16 | from geometry_grid.taichi_geometry.counted_grid import CountedGrid
17 |
18 |
19 | from geometry_grid.functional.distance import batch_point_distances
20 | from geometry_grid.taichi_geometry.attract_query import attract_query
21 | from geometry_grid.taichi_geometry.min_query import min_query
22 |
23 |
24 | from geometry_grid.render_util import display_distances
25 | from open3d_vis import render
26 | import open3d as o3d
27 | import time
28 |
29 | import taichi as ti
30 | import taichi.math as tm
31 |
32 |
def parse_args():
    """CLI arguments: the tree file, plus taichi debug and device options."""
    parser = argparse.ArgumentParser(description="Visualizer Arguments")

    parser.add_argument("file_path", type=Path, help="File Path of tree.npz")
    parser.add_argument("--debug", action="store_true", help="Enable taichi debug mode")
    parser.add_argument("--device", default="cuda:0", help="Device to run on")

    return parser.parse_args()
42 |
43 |
def display_vectors(points, v, point_size=3):
    """Draw the points (blue) plus a red segment showing vector v at each point."""
    arrows = render.segments(points, points + v, color=(1, 0, 0))
    cloud = render.point_cloud(points, color=(0, 0, 1))
    o3d.visualization.draw([arrows, cloud], point_size=point_size)
52 |
53 |
# Distance from point p to the tube surface, expressed relative to the local
# tube radius (dist / r), so selection is scale-independent across branches.
@ti.func
def relative_distance(tube: Tube, p: tm.vec3, query_radius: ti.f32):
    # t: parametric position of the closest axis point; dist_sq: squared distance.
    t, dist_sq = tube.segment.point_dist_sq(p)
    dist = ti.sqrt(dist_sq)
    r = tube.radius_at(t)

    # Tubes whose surface is farther than query_radius are excluded (inf).
    return ti.select((dist - r) < query_radius, ti.sqrt(dist_sq) / r, torch.inf)
61 |
62 |
def nearest_branch(grid, points: torch.Tensor, query_radius: float):
    """Nearest tube per point within query_radius, by radius-relative distance."""
    return min_query(grid, points, query_radius, relative_distance)
65 |
66 |
def to_medial_axis(segments: torch.Tensor, points: torch.Tensor):
    """Displacement moving each point towards its assigned segment.

    The gradient of 0.5 * sum(d^2) w.r.t. the points is the vector away from
    the segments, so its negation points onto the medial axis.
    """
    pts = points.clone().requires_grad_(True)
    distances = batch_point_distances(segments, pts)

    (0.5 * distances.pow(2).sum()).backward()

    return -pts.grad
75 |
76 |
def main():
    """Interactively contract a point cloud towards its tree skeleton.

    Builds spatial grids over the skeleton tubes and the cloud points, then
    opens a viewer where each space-bar press applies one regularisation step
    to the points and refreshes the display.
    """
    args = parse_args()

    ti.init(arch=ti.gpu, debug=args.debug, log_level=ti.INFO)

    cloud, skeleton = load_data_npz(args.file_path)
    # view_synthetic_data([(data, args.file_path)])

    skeleton = repair_skeleton(skeleton)

    device = torch.device(args.device)
    np_tubes = collate_tubes(skeleton.to_tubes())

    # Move every collated tube field (a, b, r1, r2) onto the chosen device.
    tubes = {
        k: torch.from_numpy(x).to(dtype=torch.float32, device=device)
        for k, x in asdict(np_tubes).items()
    }

    segments = torch_geom.Segment(tubes["a"], tubes["b"])
    radii = torch.stack((tubes["r1"], tubes["r2"]), -1).squeeze(0)

    tubes = torch_geom.Tube(segments, radii)
    bounds = tubes.bounds.union_all()

    points = torch.from_numpy(cloud.xyz).to(dtype=torch.float32, device=device)
    # Morton (Z-order) sorting improves spatial locality for the grids below.
    points = morton_sort(points, n=256)

    print("Generate grid...")
    start_time = time.time()
    tube_grid = CountedGrid.from_torch(Grid.fixed_size(bounds, (512, 512, 512)), tubes)
    print(f"Grid Made... {time.time() - start_time}")

    point_grid = DynamicGrid.from_torch(
        Grid.fixed_size(bounds, (512, 512, 512)), torch_geom.Point(points)
    )
    print(f"Grid Made... {time.time() - start_time}")

    vis = o3d.visualization.VisualizerWithKeyCallback()
    vis.create_window("test", width=800, height=600)

    pcd = render.point_cloud(points, color=(0, 0, 1))
    vis.add_geometry(pcd)

    def update_points(vis):
        # One contraction/regularisation step; bound to the space bar below.
        _, idx = min_query(tube_grid, points, 0.2, relative_distance)

        point_grid.update_objects(torch_geom.Point(points))
        # forces to regularize points and make them spread out
        forces = attract_query(point_grid.index, points, sigma=0.01, query_radius=0.05)

        # # project forces along segment direction only
        # dirs = segments.unit_dir[idx]
        # forces = torch_geom.dot(dirs, forces).unsqueeze(1) * dirs

        points.add_(forces * 0.5)

        # to_axis = to_medial_axis(segments[idx], points)
        # points.add_(to_axis * 0.2 + torch.randn_like(points) * 0.0001)

        pcd.points = o3d.utility.Vector3dVector(points.cpu().numpy())

        vis.update_geometry(pcd)

    vis.register_key_callback(ord(" "), update_points)

    # Manual event loop; terminate the process to exit.
    while True:
        vis.poll_events()

    # display_vectors(points, -points.grad)
    # o3d.visualization.draw(
    #     [
    #         render.point_cloud(points, color=(0, 0, 1)),
    #         # render.point_cloud(points - points.grad, color=(0, 1, 0))
    #     ],
    #     point_size=6,
    # )

    # print("Grid size: ", seg_grid.grid.size)
    # cells, counts = seg_grid.active_cells()

    # max_dist = dist[torch.isfinite(dist)].max()

    # display_distances(tubes, seg_grid.grid.get_boxes(cells),
    #                   points, dist / max_dist )


if __name__ == "__main__":
    main()
165 |
--------------------------------------------------------------------------------
/synthetic_trees/test_contraction2.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import numpy as np
4 | import torch
5 | import time
6 |
7 | from dataclasses import dataclass
8 | from typing import List, Dict, Tuple
9 | from pathlib import Path
10 |
11 |
def flatten_list(l):
    """Concatenate a list of lists into a single flat list."""
    flat = []
    for sub in l:
        flat.extend(sub)
    return flat
14 |
15 |
@dataclass
class Tube:
    """A truncated-cone segment between two adjacent skeleton vertices."""

    a: np.array  # start point
    b: np.array  # end point
    r1: float  # radius at a
    r2: float  # radius at b
22 |
23 |
@dataclass
class CollatedTube:
    """Batched storage for many tubes: endpoints and per-end radii."""

    a: np.array  # Nx3 start points
    b: np.array  # Nx3 end points
    r1: np.array  # N radii at the start
    r2: np.array  # N radii at the end

    def _to_torch(self, device=torch.device("cuda")):
        """Return a copy of this CollatedTube with torch tensors on `device`."""
        arrays = (self.a, self.b, self.r1, self.r2)
        return CollatedTube(*(torch.tensor(x, device=device) for x in arrays))
37 |
38 |
def collate_tubes(tubes: List[Tube]) -> CollatedTube:
    """Stack per-tube endpoints and radii into batched arrays."""
    starts = np.concatenate([t.a for t in tubes]).reshape(-1, 3)
    ends = np.concatenate([t.b for t in tubes]).reshape(-1, 3)

    # Radii are kept as 1 x M rows for broadcasting against N x M parameters.
    start_radii = np.asarray([t.r1 for t in tubes]).reshape(1, -1)
    end_radii = np.asarray([t.r2 for t in tubes]).reshape(1, -1)

    return CollatedTube(starts, ends, start_radii, end_radii)
47 |
48 |
@dataclass
class TreeSkeleton:
    """A whole tree: an id plus its branches keyed by branch id."""

    _id: int
    branches: Dict[int, BranchSkeleton]

    def to_tubes(self) -> List[Tube]:
        """Collect the tubes of every branch into one flat list."""
        per_branch = [branch.to_tubes() for branch in self.branches.values()]
        return flatten_list(per_branch)
56 |
57 |
@dataclass
class Cloud:
    """Point cloud with optional per-point class labels and medial vectors."""

    xyz: np.array  # point positions
    rgb: np.array  # point colours
    class_l: np.array = None  # optional per-point class labels
    vector: np.array = None  # optional per-point medial vectors
64 |
65 |
@dataclass
class BranchSkeleton:
    """A single branch: a polyline of vertices with per-vertex radii."""

    _id: int
    parent_id: int
    xyz: np.array  # branch polyline vertices
    radii: np.array  # radius at each vertex
    child_id: int = -1

    def to_tubes(self) -> List[Tube]:
        """Build one tube for every consecutive pair of skeleton vertices."""
        pairs = zip(self.xyz[:-1], self.xyz[1:], self.radii[:-1], self.radii[1:])
        return [Tube(a, b, r1, r2) for a, b, r1, r2 in pairs]
83 |
84 |
def points_to_collated_tube_projections_gpu(
    pts: np.array,
    collated_tube: CollatedTube,
    device=torch.device("cuda"),
    eps=1e-12,
):
    """Project each point onto each tube axis, clamped to the segment.

    Returns:
        proj: N x M x 3 closest point on each segment's axis.
        t: N x M clamped interpolation parameter along each segment.

    NOTE(review): `device` is unused; inputs are assumed to already live on
    the right device — confirm before relying on it.
    """
    ab = collated_tube.b - collated_tube.a  # M x 3 -> tube direction

    ap = pts.unsqueeze(1) - collated_tube.a.unsqueeze(0)  # N x M x 3

    # Clamped parametric position of each point along each segment; eps
    # guards against division by zero for degenerate (zero-length) tubes.
    numer = torch.einsum("nmd,md->nm", ap, ab)
    denom = torch.einsum("md,md->m", ab, ab) + eps
    t = (numer / denom).clip(0.0, 1.0)  # N x M

    proj = collated_tube.a.unsqueeze(0) + torch.einsum("nm,md->nmd", t, ab)  # N x M x 3

    return proj, t
103 |
104 |
def projection_to_distance_matrix_gpu(projections, pts):  # N x M x 3
    """Euclidean distance from every point to every projected point (N x M)."""
    offsets = projections - pts.unsqueeze(1)
    return offsets.square().sum(2).sqrt()
107 |
108 |
def pts_to_nearest_tube_gpu(
    pts: np.array, collated_tube: CollatedTube, device=torch.device("cuda")
):
    """Vector to, index of, and local radius at the nearest tube per point."""

    projections, t = points_to_collated_tube_projections_gpu(
        pts, collated_tube, device=torch.device("cuda")
    )  # N x M x 3

    # Interpolate each tube's radius at the projection parameter.
    r = (1 - t) * collated_tube.r1 + t * collated_tube.r2

    # Distance to the tube *surface*: axis distance minus local radius.
    surface = projection_to_distance_matrix_gpu(projections, pts) - r
    idx = torch.argmin(surface, 1)  # N

    rows = torch.arange(pts.shape[0])
    return projections[rows, idx] - pts, idx, r[rows, idx]
131 |
132 |
def unpackage_data(data: dict) -> Tuple[Cloud, TreeSkeleton]:
    """Unpack the flat arrays of a tree .npz archive into (Cloud, TreeSkeleton)."""
    tree_id = data["tree_id"]
    branch_id = data["branch_id"]
    branch_parent_id = data["branch_parent_id"]
    skeleton_xyz = data["skeleton_xyz"]
    skeleton_radii = data["skeleton_radii"]
    sizes = data["branch_num_elements"]

    cld = Cloud(
        xyz=data["xyz"],
        rgb=data["rgb"],
        class_l=data["class_l"],
        vector=data["vector"],
    )

    # Branch vertices are stored back-to-back; offsets mark each branch start.
    offsets = np.cumsum(np.append([0], sizes))

    # Per-branch index arrays into the flat skeleton arrays.
    branch_idx = [np.arange(size) + offset for size, offset in zip(sizes, offsets)]
    branches = {}

    for idx, _id, parent_id in zip(branch_idx, branch_id, branch_parent_id):
        branches[_id] = BranchSkeleton(
            _id, parent_id, skeleton_xyz[idx], skeleton_radii[idx]
        )

    return cld, TreeSkeleton(tree_id, branches)
159 |
160 |
def load_data_npz(path: Path) -> Tuple[Cloud, TreeSkeleton]:
    """Read a tree .npz archive and unpack it into (Cloud, TreeSkeleton)."""
    archive = np.load(str(path))
    return unpackage_data(archive)
163 |
164 |
if __name__ == "__main__":
    # Benchmark: time nearest-tube queries for a full cloud on the GPU.
    # NOTE(review): hard-coded local dataset path — adjust before running.
    cloud, skeleton = load_data_npz(
        "/local/point_cloud_datasets/synthetic-trees/tree_dataset/branches/pine/pine_15.npz"
    )

    collated_tube = collate_tubes(skeleton.to_tubes())._to_torch()
    pts = torch.from_numpy(cloud.xyz).float().to(torch.device("cuda"))

    start_time = time.time()
    torch.cuda.synchronize()

    # Query in chunks of 8000 points to bound the N x M distance matrix size.
    for i in range(0, len(pts), 8000):
        batch_pts = pts[i : i + 8000]
        vector, idx, radius = pts_to_nearest_tube_gpu(batch_pts, collated_tube)

    # Wait for all kernels so the timing below is meaningful.
    torch.cuda.synchronize()

    print(f"Done {time.time() - start_time}")

    print(vector.shape)
185 |
--------------------------------------------------------------------------------
/synthetic_trees/util/__init__.py:
--------------------------------------------------------------------------------
import sys

# Package version discovery (pyscaffold boilerplate): read the installed
# distribution's version via importlib.metadata, with a backport fallback.
if sys.version_info[:2] >= (3, 8):
    # TODO: Import directly (no need for conditional) when `python_requires = >= 3.8`
    from importlib.metadata import PackageNotFoundError, version  # pragma: no cover
else:
    from importlib_metadata import PackageNotFoundError, version  # pragma: no cover

try:
    # Change here if project is renamed and does not equal the package name
    dist_name = __name__
    __version__ = version(dist_name)
except PackageNotFoundError:  # pragma: no cover
    # Not installed as a distribution (e.g. running from a source checkout).
    __version__ = "unknown"
finally:
    # Don't leak the helper names into the package namespace.
    del version, PackageNotFoundError
--------------------------------------------------------------------------------
/synthetic_trees/util/file.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import open3d as o3d
3 |
4 | from pathlib import Path
5 | from typing import Tuple
6 |
7 |
8 | from ..data_types.cloud import Cloud
9 | from ..data_types.tree import TreeSkeleton
10 | from ..data_types.branch import BranchSkeleton
11 |
12 |
def unpackage_data(data: dict) -> Tuple[Cloud, TreeSkeleton]:
    """Unpack the flat arrays of a tree .npz archive into (Cloud, TreeSkeleton)."""
    sizes = data["branch_num_elements"]

    # Older archives stored the medial vectors under the legacy key "vector".
    medial_vector = data.get("medial_vector", data.get("vector", None))

    cld = Cloud(
        xyz=data["xyz"],
        rgb=data["rgb"],
        class_l=data["class_l"],
        medial_vector=medial_vector,
    )

    # Branch vertices are stored back-to-back; offsets mark each branch start.
    offsets = np.cumsum(np.append([0], sizes))
    slices = [np.arange(size) + offset for size, offset in zip(sizes, offsets)]

    branches = {
        _id: BranchSkeleton(
            _id, parent_id, data["skeleton_xyz"][idx], data["skeleton_radii"][idx]
        )
        for idx, _id, parent_id in zip(
            slices, data["branch_id"], data["branch_parent_id"]
        )
    }

    return cld, TreeSkeleton(data["tree_id"], branches)
41 |
42 |
def load_data_npz(path: Path) -> Tuple[Cloud, TreeSkeleton]:
    """Load a tree .npz archive from disk into (Cloud, TreeSkeleton)."""
    archive = np.load(str(path))
    return unpackage_data(archive)
45 |
46 |
def load_o3d_cloud(path: Path):
    """Read a point cloud file (e.g. .ply) into an open3d PointCloud."""
    return o3d.io.read_point_cloud(str(path))
49 |
--------------------------------------------------------------------------------
/synthetic_trees/util/math.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from numpy import trapz
3 |
4 |
def make_tangent(d, n):
    """Tangent orthogonal to direction d, lying in the plane spanned by d and n."""
    side = np.cross(d, n)
    side /= np.linalg.norm(side, axis=-1, keepdims=True)
    return np.cross(side, d)
9 |
10 |
def unit_circle(n):
    """n points evenly spaced on the unit circle, as (sin, cos) pairs."""
    angles = np.linspace(0, 2 * np.pi, n + 1)[:-1]
    return np.stack([np.sin(angles), np.cos(angles)], axis=1)
14 |
15 |
def vertex_dirs(points):
    """Per-vertex unit direction along a polyline.

    Interior vertices average their two adjacent (unit) segment directions;
    end vertices take their own segment's direction.

    Args:
        points: (N, 3) polyline vertices.

    Returns:
        (N, 3) array of unit direction vectors, one per vertex.
    """
    d = points[1:] - points[:-1]
    # Normalise each segment direction individually; the original divided by
    # the scalar Frobenius norm, biasing the averages towards long segments.
    d = d / np.linalg.norm(d, axis=1, keepdims=True)

    smooth = (d[1:] + d[:-1]) * 0.5
    # End vertices: first and LAST segment (the original used d[-2:-1], the
    # second-to-last segment, for the final vertex).
    dirs = np.concatenate([
        np.array(d[0:1]), smooth, np.array(d[-1:])
    ])

    return dirs / np.linalg.norm(dirs, axis=1, keepdims=True)
26 |
27 |
def gen_tangents(dirs, t):
    """Propagate an initial tangent t along a sequence of directions.

    Each step re-projects the previous tangent against the next direction,
    yielding a smoothly rotating frame.
    """
    tangents = []
    for direction in dirs:
        t = make_tangent(direction, t)
        tangents.append(t)
    return np.stack(tangents)
36 |
37 |
def random_unit(dtype=np.float32):
    """Draw a uniformly random unit vector in R^3."""
    v = np.random.randn(3).astype(dtype)
    return v / np.linalg.norm(v)
41 |
42 |
def calculate_AuC(y, dx=0.01):
    """Area under the curve of sampled values via the trapezoidal rule.

    Args:
        y: sequence of metric samples (e.g. recall at each threshold).
        dx: spacing between consecutive samples.

    Returns:
        The approximate integral of y.
    """
    # numpy 2.0 renamed `trapz` to `trapezoid`; support both versions.
    integrate = getattr(np, "trapezoid", None) or np.trapz
    return integrate(y=y, dx=dx)
--------------------------------------------------------------------------------
/synthetic_trees/util/misc.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | from typing import List
5 |
def flatten_list(l):
    """Flatten one level of nesting: a list of lists becomes a single list."""
    result = []
    for sublist in l:
        result += sublist
    return result
8 |
def to_torch(numpy_arrays: List[np.array], device=torch.device("cpu")):
    """Convert numpy arrays to float32 torch tensors on the given device."""
    tensors = []
    for array in numpy_arrays:
        tensors.append(torch.from_numpy(array).float().to(device))
    return tensors
--------------------------------------------------------------------------------
/synthetic_trees/util/o3d_abstractions.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, asdict
2 | from typing import Sequence
3 |
4 | import open3d as o3d
5 |
6 | import numpy as np
7 |
8 | from .math import unit_circle, vertex_dirs, gen_tangents, random_unit
9 |
10 |
def o3d_cloud(points, colour=None, colours=None, normals=None):
    """Build an open3d point cloud, optionally with normals and colour(s).

    A single uniform `colour` takes precedence over per-point `colours`.
    """
    cloud = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(points))

    if normals is not None:
        cloud.normals = o3d.utility.Vector3dVector(normals)

    if colour is not None:
        return cloud.paint_uniform_color(colour)
    if colours is not None:
        cloud.colors = o3d.utility.Vector3dVector(colours)
    return cloud
22 |
23 |
def o3d_merge_linesets(line_sets, colour=(0, 0, 0)):
    """Concatenate line sets into one, shifting indices, painted `colour`."""
    counts = [np.asarray(ls.points).shape[0] for ls in line_sets]
    # Each set's line indices must be offset by the points that precede it.
    offsets = np.cumsum([0] + counts)

    all_points = np.concatenate([ls.points for ls in line_sets])
    all_lines = np.concatenate(
        [ls.lines + offset for ls, offset in zip(line_sets, offsets)]
    )

    return o3d_line_set(all_points, all_lines).paint_uniform_color(colour)
32 |
33 |
def o3d_line_set(vertices, edges, colour=None):
    """LineSet from vertices and edge index pairs, optionally painted."""
    ls = o3d.geometry.LineSet(
        o3d.utility.Vector3dVector(vertices), o3d.utility.Vector2iVector(edges)
    )
    return ls if colour is None else ls.paint_uniform_color(colour)
41 |
42 |
def o3d_path(vertices, colour=None):
    """Poly-line LineSet connecting consecutive vertices."""
    idx = np.arange(vertices.shape[0] - 1)
    edges = np.column_stack((idx, idx + 1))
    if colour is None:
        return o3d_line_set(vertices, edges)
    return o3d_line_set(vertices, edges, colour)
49 |
50 |
def o3d_merge_meshes(meshes):
    """Merge triangle meshes into one, preserving per-vertex colours.

    Triangle indices are shifted by each mesh's vertex offset so they keep
    referring to their own vertices in the concatenated vertex array.
    """
    sizes = [np.asarray(m.vertices).shape[0] for m in meshes]
    offsets = np.cumsum([0] + sizes)

    triangles = np.concatenate(
        [m.triangles + offset for offset, m in zip(offsets, meshes)]
    )
    vertices = np.concatenate([m.vertices for m in meshes])

    merged = o3d_mesh(vertices, triangles)
    merged.vertex_colors = o3d.utility.Vector3dVector(
        np.concatenate([np.asarray(m.vertex_colors) for m in meshes])
    )
    return merged
67 |
68 |
def o3d_mesh(verts, tris):
    """TriangleMesh from vertex and triangle arrays, with face normals computed."""
    mesh = o3d.geometry.TriangleMesh(
        o3d.utility.Vector3dVector(verts), o3d.utility.Vector3iVector(tris)
    )
    return mesh.compute_triangle_normals()
73 |
74 |
def o3d_lines_between_clouds(cld1, cld2):
    """LineSet connecting each point of cld1 to the same-index point of cld2.

    Clouds of unequal size are truncated to the shorter one (the original
    `np.hstack` raised for unequal sizes even though `min` suggested support).
    """
    pts1 = np.asarray(cld1.points)
    pts2 = np.asarray(cld2.points)

    n = min(pts1.shape[0], pts2.shape[0])

    # Interleave p1_i, p2_i rows so consecutive index pairs form the lines.
    interweaved = np.hstack((pts1[:n], pts2[:n])).reshape(-1, 3)
    return o3d_line_set(interweaved, np.arange(0, n * 2).reshape(-1, 2))
83 |
84 |
def cylinder_triangles(m, n):
    """Triangle indices for a tube of n rings with m vertices per ring.

    Each quad between ring k and ring k+1 is split into two triangles.
    """
    tri_a = np.array([0, 1, 2])
    tri_b = np.array([2, 3, 0])

    lower = np.arange(m)
    # Quad corners: two on the lower ring (wrapping), two directly above.
    quads = np.stack([lower, (lower + 1) % m, (lower + 1) % m + m, lower + m], axis=1)

    ring_offsets = (np.arange(n - 1) * m).reshape(n - 1, 1, 1)
    quads = quads.reshape(1, m, 4) + ring_offsets
    quads = quads.reshape(-1, 4)

    return np.concatenate([quads[:, tri_a], quads[:, tri_b]])
101 |
102 |
def tube_vertices(points, radii, n=10):
    """Cross-section ring vertices for a tube swept along a polyline.

    Args:
        points: (N, 3) polyline vertices forming the tube's axis.
        radii: (N,) radius at each vertex.
        n: number of vertices per cross-section ring.

    Returns:
        (N, n, 3) array: one n-gon ring centred on each polyline vertex.
    """
    circle = unit_circle(n).astype(np.float32)

    dirs = vertex_dirs(points)
    # Propagate a stable tangent frame along the axis, seeded randomly.
    t = gen_tangents(dirs, random_unit())

    # Per-vertex basis (tangent, bitangent) scaled by the local radius.
    b = np.stack([t, np.cross(t, dirs)], axis=1)
    b = b * radii.reshape(-1, 1, 1)

    # Map each ring's 2D circle coordinates through the local basis.
    return np.einsum("bdx,md->bmx", b, circle) + points.reshape(points.shape[0], 1, 3)
113 |
114 |
def o3d_tube_mesh(points, radii, colour=(1, 0, 0), n=10):
    """Uniformly coloured triangle mesh of a tube following `points`."""
    rings = tube_vertices(points, radii, n)

    num_rings, ring_size, _ = rings.shape
    indexes = cylinder_triangles(ring_size, num_rings)

    mesh = o3d_mesh(rings.reshape(-1, 3), indexes)
    mesh.compute_vertex_normals()

    return mesh.paint_uniform_color(colour)
125 |
126 |
def o3d_load_lineset(path, colour=(0, 0, 0)):
    """Read a LineSet from disk and paint it a uniform colour.

    The default colour is an immutable tuple; the original used a mutable
    list default (shared-mutable-default anti-pattern).
    """
    return o3d.io.read_line_set(path).paint_uniform_color(colour)
129 |
130 |
@dataclass
class ViewerItem:
    """A named geometry entry for the o3d viewer."""

    name: str  # label shown in the viewer's geometry list
    geometry: o3d.geometry.Geometry
    is_visible: bool = True  # whether the geometry starts visible
136 |
137 |
def o3d_viewer(items: Sequence[ViewerItem], line_width=1):
    """Open the o3d draw window for ViewerItems, picking materials per type."""
    solid_mat = o3d.visualization.rendering.MaterialRecord()
    solid_mat.shader = "defaultLit"

    line_mat = o3d.visualization.rendering.MaterialRecord()
    line_mat.shader = "unlitLine"
    line_mat.line_width = line_width

    def material(item):
        # Line sets need the unlit line shader to honour line_width.
        if isinstance(item.geometry, o3d.geometry.LineSet):
            return line_mat
        return solid_mat

    geometries = [dict(**asdict(item), material=material(item)) for item in items]
    o3d.visualization.draw(geometries, line_width=line_width)
151 |
--------------------------------------------------------------------------------
/synthetic_trees/util/operations.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from typing import List
4 |
5 | from ..data_types.tube import Tube
6 |
7 |
def sample_tubes(tubes: List[Tube], sample_rate):
    """Sample points (and interpolated radii) along each tube's axis.

    Args:
        tubes: tubes with endpoints `a`, `b` and end radii `r1`, `r2`.
        sample_rate: approximate spacing between consecutive samples.

    Returns:
        Tuple (points, radii): (N, 3) sampled positions and (N,) linearly
        interpolated radii. Empty arrays if no tube yields a sample
        (the original crashed on `np.concatenate` of an empty list).
    """
    pts, radius = [], []

    for tube in tubes:
        start, end = tube.a, tube.b
        start_rad, end_rad = tube.r1, tube.r2

        v = end - start
        length = np.linalg.norm(v)
        if length == 0:
            # Degenerate tube: no direction to sample along.
            continue
        direction = v / length
        num_pts = int(np.round(length / sample_rate))

        if num_pts > 0:
            spaced_points = np.linspace(0, length, num_pts).reshape(-1, 1)
            lin_radius = np.linspace(
                start_rad, end_rad, spaced_points.shape[0], dtype=float)

            pts.append(start + direction * spaced_points)
            radius.append(lin_radius)

    if not pts:
        return np.empty((0, 3)), np.empty((0,))

    return np.concatenate(pts, axis=0), np.concatenate(radius, axis=0)
37 |
38 |
def sample_o3d_lineset(ls, sample_rate):
    """Sample points at roughly `sample_rate` spacing along lineset edges.

    Args:
        ls: a lineset with `lines` (edge index pairs) and `points` arrays.
        sample_rate: approximate spacing between consecutive samples.

    Returns:
        (N, 3) array of sampled points; empty if no edge yields a sample
        (the original crashed on `np.concatenate` of an empty list).
    """
    edges = np.asarray(ls.lines)
    xyz = np.asarray(ls.points)

    pts = []

    for edge in edges:
        start = xyz[edge[0]]
        end = xyz[edge[1]]

        v = end - start
        length = np.linalg.norm(v)
        if length == 0:
            # Degenerate edge: both endpoints coincide.
            continue
        direction = v / length
        num_pts = int(np.round(length / sample_rate))

        if num_pts > 0:
            spaced_points = np.linspace(0, length, num_pts).reshape(-1, 1)
            pts.append(start + direction * spaced_points)

    if not pts:
        return np.empty((0, 3))

    return np.concatenate(pts, axis=0)
63 |
--------------------------------------------------------------------------------
/synthetic_trees/util/queries.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | from typing import List
5 |
6 | from ..data_types.tube import Tube, CollatedTube, collate_tubes
7 |
8 |
9 | from pykeops.torch import LazyTensor
10 |
11 |
12 | """
13 | For the following :
14 | N : number of pts
15 | M : number of tubes
16 | """
17 |
18 |
# pts: N x 3, collated tube endpoints: M x 3
def points_to_collated_tube_projections(pts: np.array, collated_tube: CollatedTube, eps=1e-12):
    """Project every point onto every tube axis, clamped to the segment.

    Returns:
        proj: (N, M, 3) closest point on each segment's axis.
        t: (N, M) clamped interpolation parameter along each segment.
    """
    ab = collated_tube.b - collated_tube.a  # M x 3 axis of each tube

    ap = pts[:, np.newaxis] - collated_tube.a[np.newaxis, ...]  # N x M x 3

    # Clamped parametric position; eps guards against zero-length segments.
    numer = np.einsum('nmd,md->nm', ap, ab)
    denom = np.einsum('md,md->m', ab, ab) + eps
    t = np.clip(numer / denom, 0.0, 1.0)  # N x M

    proj = collated_tube.a[np.newaxis, ...] + np.einsum('nm,md->nmd', t, ab)  # N x M x 3
    return proj, t
31 |
32 |
def projection_to_distance_matrix(projections, pts):  # N x M x 3
    """Euclidean distance from every point to every projection (N x M)."""
    offsets = projections - pts[:, np.newaxis, :]
    return np.sqrt((offsets ** 2).sum(axis=2))
36 |
37 |
def pts_to_nearest_tube(pts: np.array, tubes: List[Tube]):
    """Vector to, index of, and local radius at the nearest tube per point."""

    collated_tube = collate_tubes(tubes)
    projections, t = points_to_collated_tube_projections(
        pts, collated_tube)  # N x M x 3

    # Interpolate each tube's radius at the projection parameter.
    r = (1 - t) * collated_tube.r1 + t * collated_tube.r2

    # Distance to the tube *surface*: axis distance minus local radius.
    surface_dist = projection_to_distance_matrix(projections, pts) - r
    idx = np.argmin(surface_dist, 1)  # N

    rows = np.arange(pts.shape[0])
    # vector, idx, radius
    return projections[rows, idx] - pts, idx, r[rows, idx]
56 |
57 |
def pts_on_nearest_tube(pts: np.array, tubes: List[Tube]):
    """Snap each point onto its nearest tube; returns (snapped points, radii)."""
    vectors, _index, radius = pts_to_nearest_tube(pts, tubes)
    return pts + vectors, radius
62 |
63 |
64 | # def knn(src, dest, K=50, r=1.0, grid=None):
65 | # src_lengths = src.new_tensor([src.shape[0]], dtype=torch.long)
66 | # dest_lengths = src.new_tensor([dest.shape[0]], dtype=torch.long)
67 | # dists, idxs, grid, _ = frnn.frnn_grid_points(
68 | # src.unsqueeze(0), dest.unsqueeze(0),
69 | # src_lengths, dest_lengths,
70 | # K, r,return_nn=False, return_sorted=True)
71 | # return idxs.squeeze(0), dists.sqrt().squeeze(0), grid
72 |
73 |
74 | # def nn_frnn(src, dest, r=1.0, grid=None):
75 | # idx, dist, grid = knn(src, dest, K=1, r=r, grid=grid)
76 | # idx, dist = idx.squeeze(1), dist.squeeze(1)
77 | # return idx, dist, grid
78 |
79 |
def distance_matrix_keops(pts1, pts2, device=torch.device("cuda")):
    """Lazy pairwise Euclidean distance matrix between two point sets.

    Returns a pykeops LazyTensor of shape (N, M); reductions on it
    (min/argmin) evaluate on `device` without materialising the full matrix.
    """

    pts1 = pts1.clone().to(device).float()
    pts2 = pts2.clone().to(device).float()

    # Symbolic N x 1 x 3 and 1 x M x 3 operands; broadcasting yields N x M.
    x_i = LazyTensor(pts1.reshape(-1, 1, 3).float())
    y_j = LazyTensor(pts2.view(1, -1, 3).float())

    return (x_i - y_j).square().sum(dim=2).sqrt()
89 |
90 |
def nn_keops(pts1, pts2):
    """Nearest neighbour of each pts1 point within pts2: (distance, index)."""
    distances = distance_matrix_keops(pts1, pts2)
    return distances.min(1), distances.argmin(1).flatten()
96 |
--------------------------------------------------------------------------------
/synthetic_trees/view.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import numpy as np
4 |
5 | from typing import List, Tuple
6 |
7 | from pathlib import Path
8 |
9 | from synthetic_trees.data_types.tree import TreeSkeleton
10 | from synthetic_trees.data_types.cloud import Cloud
11 | from synthetic_trees.util.file import load_data_npz
12 | from synthetic_trees.util.o3d_abstractions import ViewerItem, o3d_viewer
13 |
14 |
def view_synthetic_data(data: List[Tuple[Cloud, TreeSkeleton]], line_width=1):
    """Open the viewer with cloud/skeleton geometries for every tree.

    Fixes a bug where `geometries` was reassigned on each loop iteration,
    so only the last tree's geometry was ever displayed.
    """
    geometries = []
    for i, item in enumerate(data):
        (cloud, skeleton), path = item

        tree_name = path.stem
        # Only the first tree starts visible; the rest can be toggled on.
        visible = i == 0

        geometries += [
            ViewerItem(
                f"{tree_name}_cloud",
                cloud.to_o3d_cloud(),
                is_visible=visible,
            ),
            ViewerItem(
                f"{tree_name}_labelled_cloud",
                cloud.to_o3d_cloud_labelled(),
                is_visible=visible,
            ),
            ViewerItem(
                f"{tree_name}_medial_vectors",
                cloud.to_o3d_medial_vectors(),
                is_visible=visible,
            ),
            ViewerItem(
                f"{tree_name}_skeleton",
                skeleton.to_o3d_lineset(),
                is_visible=visible,
            ),
            ViewerItem(
                f"{tree_name}_skeleton_mesh",
                skeleton.to_o3d_tubes(),
                is_visible=visible,
            ),
        ]

    o3d_viewer(geometries, line_width=line_width)
52 |
53 |
def parse_args():
    """CLI arguments: a .npz file (or directory) and the viewer line width."""
    parser = argparse.ArgumentParser(description="Visualizer Arguments")

    parser.add_argument(
        "file_path",
        type=Path,
        default=None,
        help="File Path of tree.npz",
    )
    parser.add_argument(
        "-lw",
        "--line_width",
        type=int,
        default=1,
        help="Width of visualizer lines",
    )
    return parser.parse_args()
72 |
73 |
def paths_from_args(args, glob="*.npz"):
    """Resolve args.file_path into a list of data file paths.

    Accepts either a single file or a directory, in which case every file
    matching `glob` is returned.

    Raises:
        ValueError: if the path does not exist, or a directory contains no
            files matching `glob`.
    """
    if not args.file_path.exists():
        raise ValueError(f"File {args.file_path} does not exist")

    if args.file_path.is_file():
        print(f"Loading data from file: {args.file_path}")
        return [args.file_path]

    if args.file_path.is_dir():
        print(f"Loading data from directory: {args.file_path}")
        # glob() returns a lazy generator; materialise it so the emptiness
        # check works (the original compared a generator to [], always False).
        files = list(args.file_path.glob(glob))
        if not files:
            raise ValueError(f"No npz files found in {args.file_path}")
        return files

    raise ValueError(f"File {args.file_path} is neither a file nor a directory")
88 |
89 |
def main():
    """Load every tree referenced on the command line and open the viewer."""
    args = parse_args()

    data = [(load_data_npz(p), p) for p in paths_from_args(args)]
    view_synthetic_data(data, args.line_width)


if __name__ == "__main__":
    main()
99 |
--------------------------------------------------------------------------------
/synthetic_trees/view_clouds.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from pathlib import Path
3 | from synthetic_trees.view import parse_args
4 |
5 |
6 | from synthetic_trees.util.file import load_o3d_cloud
7 | from synthetic_trees.util.o3d_abstractions import ViewerItem, o3d_viewer
8 | from synthetic_trees.view import paths_from_args
9 |
10 |
def view_clouds(data: list):
    """Display (cloud, path) pairs in the viewer; the first starts visible.

    Fixes a TypeError: the ViewerItem dataclass field is `is_visible`, but
    the original passed the non-existent keyword `visible=`.
    """
    geometries = [ViewerItem(f"{path.stem}_cloud", cloud, is_visible=i == 0)
                  for i, (cloud, path) in enumerate(data)]

    o3d_viewer(geometries)
16 |
17 |
def main():
    """View every .ply cloud referenced by the command-line arguments."""
    args = parse_args()

    clouds = [(load_o3d_cloud(p), p) for p in paths_from_args(args, '*.ply')]
    view_clouds(clouds)


if __name__ == "__main__":
    main()
28 |
--------------------------------------------------------------------------------
/tests/test_skeleton.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from synthetic_trees import download
4 |
5 | __author__ = "Harry Dobbs"
6 | __copyright__ = "Harry Dobbs"
7 | __license__ = "MIT"
8 |
9 |
10 |
--------------------------------------------------------------------------------