├── .github └── workflows │ └── python-app.yml ├── .gitignore ├── CHANGELOG.rst ├── LICENSE ├── README.md ├── images ├── apple-pcd.png ├── apple-skeleton.png ├── cherry-pcd.png ├── cherry-skeleton.png ├── pine-pcd.png └── pine-skeleton.png ├── pyproject.toml ├── synthetic_trees ├── .gitignore ├── __init__.py ├── data_types │ ├── __init__.py │ ├── branch.py │ ├── cloud.py │ ├── tree.py │ └── tube.py ├── download.py ├── evaluate.py ├── evaluation │ ├── __init__.py │ ├── metrics.py │ └── results.py ├── extract_cloud.py ├── process_results.py ├── scripts │ └── batch_rotate.py ├── test_contraction.py ├── test_contraction2.py ├── util │ ├── __init__.py │ ├── file.py │ ├── math.py │ ├── misc.py │ ├── o3d_abstractions.py │ ├── operations.py │ └── queries.py ├── view.py └── view_clouds.py └── tests └── test_skeleton.py /.github/workflows/python-app.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python 3 | 4 | name: Python application 5 | 6 | on: 7 | push: 8 | branches: [ "main" ] 9 | pull_request: 10 | branches: [ "main" ] 11 | 12 | permissions: 13 | contents: read 14 | 15 | jobs: 16 | build: 17 | 18 | runs-on: ubuntu-latest 19 | 20 | steps: 21 | - uses: actions/checkout@v3 22 | - name: Set up Python 3.8.16 23 | uses: actions/setup-python@v3 24 | with: 25 | python-version: "3.8.16" 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | pip install . 
30 | 31 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 
92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | dataset/ 132 | outputs/ 133 | 134 | *.csv 135 | dataset.zip 136 | data/ -------------------------------------------------------------------------------- /CHANGELOG.rst: -------------------------------------------------------------------------------- 1 | ========= 2 | Changelog 3 | ========= 4 | 5 | Version 0.1 6 | =========== 7 | 8 | - 9 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 UC Vision 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | #
🌳🌲🌴 Synthetic-Trees 🌴🌲🌳
2 | 3 | ## 📝 Description 4 | 5 | This repository offers a synthetic point cloud dataset with ground truth skeletons of multiple species. Our library enables you to open, visualize, and assess the accuracy of the skeletons. To gain insight into how we created the data and the evaluation metrics employed, please refer to our published paper, available at this link. Our dataset consists of two parts: one with point clouds that feature foliage, which is particularly useful for training models that can handle real-world data that includes leaves; the other contains only the branching structure and is less affected by occlusion. 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 |
Sapling Cherry Point Cloud.Apple Tree Point Cloud.Pine Tree Point Cloud.
Sapling Cherry Ground Truth Skeleton.Apple Tree Ground Truth Skeleton.Pine Tree Ground Truth Skeleton.
32 | 33 | ## 🔍 Usage 34 | 35 | #### 💾 Downloading 36 | 37 | You can download the data by following this link. The dataset includes synthetic point clouds and ground truth skeletons, along with a JSON file that specifies the training, validation, and test sets. For evaluation purposes, we have provided "cleaned" point clouds and skeletons in the evaluation folder, which are suitable for assessment. 38 | 39 | 40 | #### 🖥️ Installation 41 | To install: 42 | Create a conda environment: 43 | 44 | `conda create -n synthetic-trees python=3.8` 45 | 46 | then activate it (`conda activate synthetic-trees`) and, from the repository root, install the package: 47 | 48 | `pip install .` 49 | 50 | #### 🕵️‍♂️ Visualizing 51 | To visualize the data, use the `view.py` script (installed as the `view-synthetic-trees` command). You can call it using either: 52 | 53 | ``` view-synthetic-trees -p=file_path -lw=linewidth ``` 54 | or 55 | ``` view-synthetic-trees -d=directory -lw=linewidth ``` 56 | 57 | where: 58 | - `file_path` is the path of the `.npz` file of a single tree. 59 | - `directory` is the path of the folder containing `.npz` files. 60 | - `linewidth` is the width of the skeleton lines in the visualizer. 61 | 62 | #### 📊 Evaluation 63 | To evaluate your method against the ground truth data, use the `evaluate.py` script. You can call it using: 64 | 65 | ``` evaluate-synthetic-trees -d_gt=ground_truth_directory -d_o=output_directory -r_o=results_save_path ``` 66 | 67 | where: 68 | - `ground_truth_directory` is the directory of the ground truth `.npz` files. 69 | - `output_directory` is the folder containing your skeleton outputs (in `.ply` format); each output is matched to the ground truth `.npz` file with the same file name stem, and unmatched files are ignored. 70 | - `results_save_path` is the path of the `.csv` file to save your results to. 71 | 72 | #### 📋 Processing Results 73 | After running the evaluation, you can use the `process_results.py` script to post-process the raw results and obtain metrics across the dataset. Call it using: 74 | 75 | ``` process-synthetic-trees-results -p=path ``` 76 | 77 | where: `path` is the path of the results `.csv` file from the evaluation step.
78 | 79 | ## 📜 Citation 80 | Please use the following BibTeX entry to cite our work:
81 | ``` 82 | @inproceedings{dobbs2023smart, 83 | title={Smart-Tree: Neural Medial Axis Approximation of Point Clouds for 3D Tree Skeletonization}, 84 | author={Dobbs, Harry and Batchelor, Oliver and Green, Richard and Atlas, James}, 85 | booktitle={Iberian Conference on Pattern Recognition and Image Analysis}, 86 | pages={351--362}, 87 | year={2023}, 88 | organization={Springer} 89 | } 90 | 91 | ``` 92 | 93 | ## 📥 Contact 94 | 95 | Should you have any questions, comments or suggestions please use the following contact details: 96 | harry.dobbs@pg.canterbury.ac.nz 97 | -------------------------------------------------------------------------------- /images/apple-pcd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uc-vision/synthetic-trees/2e3277b942bbb789b4056eee26e07ae6c6215af8/images/apple-pcd.png -------------------------------------------------------------------------------- /images/apple-skeleton.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uc-vision/synthetic-trees/2e3277b942bbb789b4056eee26e07ae6c6215af8/images/apple-skeleton.png -------------------------------------------------------------------------------- /images/cherry-pcd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uc-vision/synthetic-trees/2e3277b942bbb789b4056eee26e07ae6c6215af8/images/cherry-pcd.png -------------------------------------------------------------------------------- /images/cherry-skeleton.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uc-vision/synthetic-trees/2e3277b942bbb789b4056eee26e07ae6c6215af8/images/cherry-skeleton.png -------------------------------------------------------------------------------- /images/pine-pcd.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/uc-vision/synthetic-trees/2e3277b942bbb789b4056eee26e07ae6c6215af8/images/pine-pcd.png -------------------------------------------------------------------------------- /images/pine-skeleton.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uc-vision/synthetic-trees/2e3277b942bbb789b4056eee26e07ae6c6215af8/images/pine-skeleton.png -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "setuptools-scm"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "synthetic_trees" 7 | authors = [ 8 | {name = "Harry Dobbs", email = "harrydobbs87@gmail.com"}, 9 | ] 10 | description = "Tools for synthetic tree point cloud dataset." 
11 | readme = "README.rst" 12 | requires-python = ">=3.7" 13 | license = {text = "MIT"} 14 | dependencies = [ 15 | 'numpy', 16 | 'open3d==0.17', 17 | 'pykeops', 18 | 'torch'] 19 | dynamic = ["version"] 20 | 21 | [tool.setuptools.packages] 22 | find = {} # Scan the project directory with the default parameters 23 | 24 | [project.scripts] 25 | view-synthetic-trees = "synthetic_trees.view:main" 26 | test-contraction = "synthetic_trees.test_contraction:main" 27 | 28 | download-synthetic-trees = "synthetic_trees.download:main" 29 | evaluate-synthetic-trees = "synthetic_trees.evaluate:main" 30 | process-synthetic-trees-results = "synthetic_trees.process_results:main" 31 | view-pointclouds = "synthetic_trees.view_clouds:main" 32 | 33 | 34 | -------------------------------------------------------------------------------- /synthetic_trees/.gitignore: -------------------------------------------------------------------------------- 1 | # Temporary and binary files 2 | *~ 3 | *.py[cod] 4 | *.so 5 | *.cfg 6 | !.isort.cfg 7 | !setup.cfg 8 | *.orig 9 | *.log 10 | *.pot 11 | __pycache__/* 12 | .cache/* 13 | .*.swp 14 | */.ipynb_checkpoints/* 15 | .DS_Store 16 | 17 | # Project files 18 | .ropeproject 19 | .project 20 | .pydevproject 21 | .settings 22 | .idea 23 | .vscode 24 | tags 25 | 26 | # Package files 27 | *.egg 28 | *.eggs/ 29 | .installed.cfg 30 | *.egg-info 31 | 32 | # Unittest and coverage 33 | htmlcov/* 34 | .coverage 35 | .coverage.* 36 | .tox 37 | junit*.xml 38 | coverage.xml 39 | .pytest_cache/ 40 | 41 | # Build and docs folder/files 42 | build/* 43 | dist/* 44 | sdist/* 45 | docs/api/* 46 | docs/_rst/* 47 | docs/_build/* 48 | cover/* 49 | MANIFEST 50 | 51 | # Per-project virtualenvs 52 | .venv*/ 53 | .conda*/ 54 | .python-version 55 | -------------------------------------------------------------------------------- /synthetic_trees/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | if 
sys.version_info[:2] >= (3, 8): 4 | # TODO: Import directly (no need for conditional) when `python_requires = >= 3.8` 5 | from importlib.metadata import PackageNotFoundError, version # pragma: no cover 6 | else: 7 | from importlib_metadata import PackageNotFoundError, version # pragma: no cover 8 | 9 | try: 10 | # Change here if project is renamed and does not equal the package name 11 | dist_name = __name__ 12 | __version__ = version(dist_name) 13 | except PackageNotFoundError: # pragma: no cover 14 | __version__ = "unknown" 15 | finally: 16 | del version, PackageNotFoundError 17 | -------------------------------------------------------------------------------- /synthetic_trees/data_types/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | if sys.version_info[:2] >= (3, 8): 4 | # TODO: Import directly (no need for conditional) when `python_requires = >= 3.8` 5 | from importlib.metadata import PackageNotFoundError, version # pragma: no cover 6 | else: 7 | from importlib_metadata import PackageNotFoundError, version # pragma: no cover 8 | 9 | try: 10 | # Change here if project is renamed and does not equal the package name 11 | dist_name = __name__ 12 | __version__ = version(dist_name) 13 | except PackageNotFoundError: # pragma: no cover 14 | __version__ = "unknown" 15 | finally: 16 | del version, PackageNotFoundError 17 | -------------------------------------------------------------------------------- /synthetic_trees/data_types/branch.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import open3d as o3d 4 | 5 | from dataclasses import dataclass 6 | from typing import List, Dict 7 | 8 | from .tube import Tube 9 | from ..util.o3d_abstractions import o3d_path, o3d_tube_mesh 10 | 11 | from ..util.queries import pts_on_nearest_tube 12 | 13 | 14 | @dataclass 15 | class BranchSkeleton: 16 | _id: int 17 | parent_id: int 18 | xyz: np.array 19 
| radii: np.array 20 | child_id: int = -1 21 | 22 | @property 23 | def length(self): 24 | return np.sum(np.sqrt(np.sum(np.diff(self.xyz, axis=0)**2, axis=1))) 25 | 26 | def __len__(self): 27 | return self.xyz.shape[0] 28 | 29 | def __str__(self): 30 | return f"Branch {self._id} with {self.xyz} points. \ 31 | and {self.radii} radii" 32 | 33 | def to_tubes(self) -> List[Tube]: 34 | a_, b_, r1_, r2_ = ( 35 | self.xyz[:-1], self.xyz[1:], self.radii[:-1], self.radii[1:]) 36 | 37 | return [Tube(a, b, r1, r2) for a, b, r1, r2 in zip(a_, b_, r1_, r2_)] 38 | 39 | def closest_pt(self, pt: np.array): # closest point on skeleton to query point 40 | return pts_on_nearest_tube(pt, self.to_tubes()) 41 | 42 | def to_o3d_lineset(self, colour=(0, 0, 0)) -> o3d.cuda.pybind.geometry.LineSet: 43 | return o3d_path(self.xyz, colour) 44 | 45 | def to_o3d_tube(self): 46 | return o3d_tube_mesh(self.xyz, self.radii) 47 | -------------------------------------------------------------------------------- /synthetic_trees/data_types/cloud.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | 5 | from dataclasses import dataclass 6 | from ..util.o3d_abstractions import o3d_cloud, o3d_lines_between_clouds 7 | 8 | 9 | @dataclass 10 | class Cloud: 11 | xyz: np.array 12 | rgb: np.array 13 | class_l: np.array = None 14 | medial_vector: np.array = None 15 | 16 | def __len__(self): 17 | return self.xyz.shape[0] 18 | 19 | def __str__(self): 20 | return f"Cloud with {self.xyz.shape[0]} points " 21 | 22 | def to_o3d_cloud(self): 23 | return o3d_cloud(self.xyz, colours=self.rgb) 24 | 25 | def to_o3d_cloud_labelled(self, cmap=None): 26 | if cmap is None: 27 | cmap = np.random.rand(self.number_classes, 3) 28 | 29 | return o3d_cloud(self.xyz, colours=cmap[self.class_l]) 30 | 31 | def to_o3d_medial_vectors(self, cmap=None): 32 | medial_cloud = o3d_cloud(self.xyz + self.medial_vector) 33 | return 
o3d_lines_between_clouds(self.to_o3d_cloud(), medial_cloud) 34 | 35 | def to_device(self, device): 36 | if self.xyz is not None: 37 | self.xyz = ( 38 | torch.from_numpy(self.xyz).to(device) 39 | if isinstance(self.xyz, np.ndarray) 40 | else self.xyz.to(device) 41 | ) 42 | 43 | if self.rgb is not None: 44 | self.rgb = ( 45 | torch.from_numpy(self.rgb).to(device) 46 | if isinstance(self.rgb, np.ndarray) 47 | else self.rgb.to(device) 48 | ) 49 | 50 | if self.class_l is not None: 51 | self.class_l = ( 52 | torch.from_numpy(self.class_l).to(device) 53 | if isinstance(self.class_l, np.ndarray) 54 | else self.class_l.to(device) 55 | ) 56 | 57 | if self.medial_vector is not None: 58 | self.medial_vector = ( 59 | torch.from_numpy(self.medial_vector).to(device) 60 | if isinstance(self.medial_vector, np.ndarray) 61 | else self.medial_vector.to(device) 62 | ) 63 | 64 | @property 65 | def number_classes(self): 66 | return torch.max(self.class_l) + 1 67 | -------------------------------------------------------------------------------- /synthetic_trees/data_types/tree.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import open3d as o3d 3 | 4 | from copy import deepcopy 5 | from dataclasses import dataclass 6 | from typing import List, Dict 7 | 8 | from ..util.o3d_abstractions import o3d_merge_linesets, o3d_merge_meshes 9 | from ..util.misc import flatten_list 10 | 11 | from ..util.operations import sample_tubes 12 | from .branch import BranchSkeleton 13 | from .tube import Tube 14 | 15 | 16 | @dataclass 17 | class TreeSkeleton: 18 | _id: int 19 | branches: Dict[int, BranchSkeleton] 20 | 21 | def __len__(self): 22 | return len(self.branches) 23 | 24 | def __str__(self): 25 | return (f"Tree Skeleton ({self._id}) has {len(self)} branches...") 26 | 27 | def to_tubes(self) -> List[Tube]: 28 | return flatten_list([branch.to_tubes() for branch in self.branches.values()]) 29 | 30 | def to_o3d_tubes(self) -> 
o3d.cuda.pybind.geometry.TriangleMesh: 31 | return o3d_merge_meshes([branch.to_o3d_tube() for branch in self.branches.values()]) 32 | 33 | def to_o3d_lineset(self) -> o3d.cuda.pybind.geometry.LineSet: 34 | return o3d_merge_linesets([branch.to_o3d_lineset() for branch in self.branches.values()]) 35 | 36 | def point_sample(self, sample_rate=0.01) -> o3d.cuda.pybind.geometry.PointCloud: 37 | return sample_tubes(self.to_tubes(), sample_rate) 38 | 39 | 40 | def repair_skeleton(skeleton: TreeSkeleton): 41 | """ By default the skeletons are not connected between branches. 42 | this function connects the branches to their parent branches by finding 43 | the nearest point on the parent branch - relative to radius. 44 | It returns a new skeleton with no reference to the original. 45 | """ 46 | skeleton = deepcopy(skeleton) 47 | 48 | for branch in list(skeleton.branches.values()): 49 | 50 | if branch.parent_id == -1 or branch.parent_id == 0: 51 | continue 52 | 53 | parent_branch = skeleton.branches[branch.parent_id] 54 | 55 | connection_pt, connection_rad = parent_branch.closest_pt( 56 | pt=branch.xyz[[0]]) 57 | 58 | branch.xyz = np.insert(branch.xyz, 0, connection_pt, axis=0) 59 | branch.radii = np.insert(branch.radii, 0, connection_rad, axis=0) 60 | 61 | return skeleton 62 | 63 | 64 | def prune_skeleton(skeleton: TreeSkeleton, min_radius_threshold=0.01, length_threshold=0.02, root_id=1): 65 | """ In the skeleton format we are using each branch only knows it's parent 66 | but not it's child (could work this out by doing a traversal). If a branch doesn't 67 | meet the initial radius threshold or length threshold we want to remove it and all 68 | it's predecessors... 69 | Because of the way the skeleton is initalized however we know that earlier branches 70 | are guaranteed to be of lower order. 
71 | minimum_radius_threshold: some point of the branch must be above this to not remove the branch 72 | length_threshold: the total length of the branch must be greater than this point 73 | """ 74 | branches_to_keep = {root_id: skeleton.branches[root_id]} 75 | 76 | for branch_id, branch in skeleton.branches.items(): 77 | 78 | if branch.parent_id == -1: 79 | continue 80 | 81 | if branch.parent_id in branches_to_keep: 82 | if branch.length > length_threshold and branch.radii[0] > min_radius_threshold: 83 | branches_to_keep[branch_id] = branch 84 | 85 | return TreeSkeleton(skeleton._id, branches_to_keep) 86 | -------------------------------------------------------------------------------- /synthetic_trees/data_types/tube.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from typing import List 4 | from dataclasses import dataclass 5 | 6 | 7 | @dataclass 8 | class Tube: 9 | a: np.array # Start Point 3 10 | b: np.array # End Point 3 11 | r1: float # Start Radius 12 | r2: float # End Radius 13 | 14 | 15 | @dataclass 16 | class CollatedTube: 17 | a: np.array # Nx3 18 | b: np.array # Nx3 19 | r1: np.array # N 20 | r2: np.array # N 21 | 22 | 23 | def collate_tubes(tubes: List[Tube]) -> CollatedTube: 24 | 25 | a = np.concatenate([tube.a for tube in tubes]).reshape(-1, 3) 26 | b = np.concatenate([tube.b for tube in tubes]).reshape(-1, 3) 27 | 28 | r1 = np.asarray([tube.r1 for tube in tubes]).reshape(1, -1) 29 | r2 = np.asarray([tube.r2 for tube in tubes]).reshape(1, -1) 30 | 31 | return CollatedTube(a, b, r1, r2) 32 | -------------------------------------------------------------------------------- /synthetic_trees/download.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import argparse 4 | 5 | 6 | def parse_args(): 7 | 8 | parser = argparse.ArgumentParser(description="Downloader Arguments") 9 | 10 | parser.add_argument("-d", "--directory", 11 | 
help="Tree download directory.", 12 | required=False, 13 | default="dataset", type=str) 14 | 15 | return parser.parse_args() 16 | 17 | 18 | def main(): 19 | 20 | args = parse_args() 21 | 22 | if not os.path.isdir(args.directory): 23 | os.mkdir(args.directory) 24 | 25 | # download to that directory 26 | 27 | 28 | 29 | 30 | 31 | 32 | if __name__ == "__main__": 33 | main() 34 | -------------------------------------------------------------------------------- /synthetic_trees/evaluate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | from typing import List, Tuple 5 | from pathlib import Path 6 | import argparse 7 | 8 | import torch 9 | import open3d as o3d 10 | 11 | from tqdm import tqdm 12 | 13 | from data_types.tree import TreeSkeleton, repair_skeleton 14 | from data_types.cloud import Cloud 15 | 16 | from util.file import load_data_npz 17 | from util.o3d_abstractions import o3d_load_lineset 18 | from evaluation.results import save_results 19 | 20 | from util.operations import sample_o3d_lineset 21 | from util.misc import to_torch 22 | from evaluation.metrics import recall, precision 23 | 24 | 25 | def evaluate_one(gt_skeleton: TreeSkeleton, output_skeleton: o3d.cuda.pybind.geometry.LineSet, thresholds=np.linspace(0,1,100), sample_rate=0.001): 26 | 27 | results = {} 28 | 29 | skeleton = repair_skeleton(gt_skeleton) 30 | 31 | gt_xyzs, gt_radii = skeleton.point_sample(sample_rate) 32 | 33 | output_pts = sample_o3d_lineset(output_skeleton, sample_rate) 34 | 35 | #o3d_viewer([o3d_cloud(gt_xyzs, colour=(0,1,0)), o3d_cloud(output_pts, colour=(1,0,0)), skeleton.to_o3d_lineset()]) 36 | 37 | gt_xyzs_c, gt_radii_c, output_pts_c = to_torch([gt_xyzs, gt_radii, output_pts], device=torch.device("cuda")) 38 | 39 | results["recall"] = recall(gt_xyzs_c, output_pts_c, gt_radii_c.reshape(-1), thresholds=thresholds) 40 | results["precision"] = precision(gt_xyzs_c, output_pts_c, gt_radii_c.reshape(-1), 
thresholds=thresholds) 41 | results['thresholds'] = thresholds 42 | 43 | return results 44 | 45 | 46 | def gt_skeleton_generator(paths: List[Path]) -> Tuple[Cloud, TreeSkeleton]: 47 | for path in paths: 48 | yield load_data_npz(path)[1] 49 | 50 | def output_skeleton_generator(paths: List[Path]) -> o3d.cuda.pybind.geometry.LineSet: 51 | for path in paths: 52 | yield o3d_load_lineset(str(path)) 53 | 54 | 55 | def parse_args(): 56 | 57 | parser = argparse.ArgumentParser(description="Visualizer Arguments") 58 | 59 | parser.add_argument("-d_gt", "--ground_truth_dir", 60 | help="Directory of folder of tree.npz(s) *.npz", default=None, type=str) 61 | 62 | parser.add_argument("-d_o", "--output_dir", 63 | help="Directory of folder of skeleton outputs *.ply", default=None, type=str) 64 | 65 | parser.add_argument("-r_o", "--results_save_path", 66 | help="Path to save results csv to", default="results.csv", type=str, required=False) 67 | 68 | return parser.parse_args() 69 | 70 | 71 | def main(): 72 | 73 | args = parse_args() 74 | 75 | ground_truth_paths = list(Path(args.ground_truth_dir).glob("*.npz")) 76 | output_paths = list(Path(args.output_dir).glob("*.ply")) 77 | 78 | ground_truth_tree_names = [path.stem for path in ground_truth_paths] 79 | output_tree_names = [path.stem for path in output_paths] 80 | 81 | tree_names = list(set(ground_truth_tree_names).intersection(set(output_tree_names))) 82 | 83 | gt_paths = sorted([p for p in ground_truth_paths if p.stem in tree_names]) #[30:] 84 | output_paths = sorted([p for p in output_paths if p.stem in tree_names]) #[30:] 85 | 86 | results = {} 87 | 88 | for gt_skeleton, output_skeleton, tree_name in tqdm(zip(gt_skeleton_generator(gt_paths), output_skeleton_generator(output_paths), tree_names)): 89 | 90 | results[f"{tree_name}"] = evaluate_one(gt_skeleton, output_skeleton, sample_rate=0.001) 91 | 92 | save_results(results, args.results_save_path) 93 | 94 | if __name__ == "__main__": 95 | main() 
-------------------------------------------------------------------------------- /synthetic_trees/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | if sys.version_info[:2] >= (3, 8): 4 | # TODO: Import directly (no need for conditional) when `python_requires = >= 3.8` 5 | from importlib.metadata import PackageNotFoundError, version # pragma: no cover 6 | else: 7 | from importlib_metadata import PackageNotFoundError, version # pragma: no cover 8 | 9 | try: 10 | # Change here if project is renamed and does not equal the package name 11 | dist_name = __name__ 12 | __version__ = version(dist_name) 13 | except PackageNotFoundError: # pragma: no cover 14 | __version__ = "unknown" 15 | finally: 16 | del version, PackageNotFoundError 17 | -------------------------------------------------------------------------------- /synthetic_trees/evaluation/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import frnn 3 | 4 | from util.queries import nn_frnn, nn_keops 5 | 6 | 7 | def recall(gt_points, test_points, gt_radii, thresholds=[0.1]): # recall (completeness) 8 | 9 | results = [] 10 | dist, idx = nn_keops(gt_points, test_points) 11 | idx = idx.reshape(-1) 12 | dist = dist.reshape(-1) 13 | 14 | for t in thresholds: 15 | 16 | mask = dist < (gt_radii * t) 17 | 18 | valid_percentage = (torch.sum(mask) / gt_points.shape[0]) * 100 19 | 20 | results.append(valid_percentage.item()) 21 | 22 | return results 23 | 24 | 25 | def precision(gt_points, test_points, gt_radii, thresholds=[0.1]): # precision (how close) 26 | 27 | results = [] 28 | dist, idx = nn_keops(test_points, gt_points) 29 | idx = idx.reshape(-1) 30 | dist = dist.reshape(-1) 31 | 32 | for t in thresholds: 33 | 34 | mask = dist < (gt_radii[idx] * t) 35 | 36 | valid_percentage = (torch.sum(mask) / test_points.shape[0]) * 100 37 | 38 | results.append(valid_percentage.item()) 39 | 40 | 
return results 41 | 42 | 43 | # def recall(gt_points, test_points, gt_radii, threshold=1.0): 44 | # idxs, dists, _ = nn(gt_points, test_points, r=gt_radii.max().item()) 45 | # valid_idx = idxs[idxs != -1] 46 | # valid = (dists[valid_idx] < gt_radii[valid_idx] * threshold) 47 | # return (torch.sum(valid).cpu().item() / gt_points.shape[0]) * 100 48 | 49 | 50 | # def precision(test_points, gt_points, gt_radii, threshold=1.0): 51 | # idxs, dists, _ = nn(test_points, gt_points, r=gt_radii.max().item()) 52 | # valid_idx = idxs[idxs != -1] 53 | # valid = dists[valid_idx] < gt_radii[valid_idx] * threshold 54 | 55 | # print(torch.sum(torch.isnan(valid.cpu()))) 56 | # return (torch.sum(valid.cpu()).item() / test_points.shape[0]) * 100 -------------------------------------------------------------------------------- /synthetic_trees/evaluation/results.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | from typing import Dict 4 | 5 | 6 | def save_results(results_dict: Dict, save_path: str ="results.csv"): 7 | 8 | rows = [] 9 | 10 | for tree_name, results in results_dict.items(): 11 | 12 | species = tree_name.split("_")[0] 13 | 14 | for threshold, recall, precision in zip(results["thresholds"], results["recall"], results["precision"]): 15 | 16 | rows.append([tree_name, species, threshold, recall, precision]) 17 | 18 | df = pd.DataFrame(rows, columns=["tree_name", "species", "threshold", "recall", "precision"]) 19 | df.to_csv(f"{save_path}", index=False) 20 | -------------------------------------------------------------------------------- /synthetic_trees/extract_cloud.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | from synthetic_trees.util.file import load_data_npz 5 | import open3d as o3d 6 | 7 | 8 | 9 | 10 | def parse_args(): 11 | parser = argparse.ArgumentParser(description="Visualizer Arguments") 12 | 13 | 
parser.add_argument("file_path", 14 | help="File Path of tree.npz",type=Path) 15 | parser.add_argument("output_file", 16 | help="File path to write cloud", type=Path) 17 | return parser.parse_args() 18 | 19 | 20 | 21 | def main(): 22 | 23 | args = parse_args() 24 | assert args.file_path.exists(), f"File {args.file_path} does not exist" 25 | 26 | cloud, skeleton = load_data_npz(args.file_path) 27 | o3d.t.io.write_point_cloud(str(args.output_file), cloud.to_tensor_cloud()) 28 | 29 | 30 | 31 | if __name__ == "__main__": 32 | main() 33 | -------------------------------------------------------------------------------- /synthetic_trees/process_results.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | 4 | 5 | import argparse 6 | 7 | import matplotlib.pyplot as plt 8 | 9 | from util.math import calculate_AuC 10 | 11 | 12 | def parse_args(): 13 | 14 | parser = argparse.ArgumentParser(description="Results directory.") 15 | 16 | parser.add_argument("-p", "--path", 17 | help="Results path.", 18 | required=True, 19 | default="dataset", 20 | type=str) 21 | 22 | return parser.parse_args() 23 | 24 | 25 | def main(): 26 | 27 | args = parse_args() 28 | 29 | df = pd.read_csv(args.path) 30 | 31 | df["f1"] = (2 * ((df["recall"] * df["precision"]) / (df["recall"] + df["precision"]))).fillna(0) 32 | 33 | df = df.round(2) 34 | 35 | df = df.groupby("threshold").agg("mean", numeric_only=True) 36 | 37 | print(f"F1 AUC: {calculate_AuC(df['f1'])}") 38 | print(f"Recall AUC: {calculate_AuC(df['recall'])}") 39 | print(f"Precision AUC: {calculate_AuC(df['precision'])}") 40 | 41 | #download to that directory 42 | ax = df.plot() 43 | ax.set_ylim(0, 100) 44 | ax.set_xlim(0, 1) 45 | 46 | plt.show() 47 | 48 | if __name__ == "__main__": 49 | main() -------------------------------------------------------------------------------- /synthetic_trees/scripts/batch_rotate.py: 
# ------------------------------------------------------------------------------
# /synthetic_trees/scripts/batch_rotate.py was not inlined in this dump; see
# https://raw.githubusercontent.com/uc-vision/synthetic-trees/2e3277b942bbb789b4056eee26e07ae6c6215af8/synthetic_trees/scripts/batch_rotate.py
# ------------------------------------------------------------------------------
# /synthetic_trees/test_contraction.py
# ------------------------------------------------------------------------------
"""Interactive point-contraction experiment on a synthetic tree.

Loads a ``tree.npz`` file, builds spatial grids over the skeleton tubes and the
cloud points, and opens an Open3D window. Each press of the space bar applies one
attraction step to the points and redraws them.
"""
import argparse
from dataclasses import asdict, dataclass

from pathlib import Path
import torch
from synthetic_trees.data_types.tree import repair_skeleton, prune_skeleton

from synthetic_trees.data_types.tube import collate_tubes
from synthetic_trees.util.file import load_data_npz

import geometry_grid.torch_geometry as torch_geom
from geometry_grid.taichi_geometry.grid import Grid, morton_sort
from geometry_grid.taichi_geometry import Tube

from geometry_grid.taichi_geometry.dynamic_grid import DynamicGrid
from geometry_grid.taichi_geometry.counted_grid import CountedGrid


from geometry_grid.functional.distance import batch_point_distances
from geometry_grid.taichi_geometry.attract_query import attract_query
from geometry_grid.taichi_geometry.min_query import min_query


from geometry_grid.render_util import display_distances
from open3d_vis import render
import open3d as o3d
import time

import taichi as ti
import taichi.math as tm


def parse_args():
    """Parse the command line: an input ``.npz`` path plus debug/device options."""
    parser = argparse.ArgumentParser(description="Visualizer Arguments")

    parser.add_argument("file_path", help="File Path of tree.npz", type=Path)
    parser.add_argument("--debug", action="store_true", help="Enable taichi debug mode")

    parser.add_argument("--device", default="cuda:0", help="Device to run on")

    return parser.parse_args()


def display_vectors(points, v, point_size=3):
    """Draw ``points`` in blue with a red segment from each point ``p`` to ``p + v``."""
    o3d.visualization.draw(
        [
            render.segments(points, points + v, color=(1, 0, 0)),
            render.point_cloud(points, color=(0, 0, 1)),
        ],
        point_size=point_size,
    )


@ti.func
def relative_distance(tube: Tube, p: tm.vec3, query_radius: ti.f32):
    """Distance from ``p`` to the tube axis, divided by the tube radius there.

    Returns ``inf`` when ``p`` lies more than ``query_radius`` outside the tube
    surface, so that tube is effectively ignored by a minimum query such as
    ``min_query``.
    """
    t, dist_sq = tube.segment.point_dist_sq(p)
    dist = ti.sqrt(dist_sq)
    r = tube.radius_at(t)

    # Reuse ``dist`` instead of taking the square root a second time.
    return ti.select((dist - r) < query_radius, dist / r, torch.inf)


def nearest_branch(grid, points: torch.Tensor, query_radius: float):
    """Run ``min_query`` over ``grid`` with :func:`relative_distance` as the metric."""
    return min_query(grid, points, query_radius, relative_distance)


def to_medial_axis(segments: torch.Tensor, points: torch.Tensor):
    """Return ``-d/dpoints`` of ``0.5 * sum(dist**2)``.

    ``dist`` comes from ``batch_point_distances(segments, points)``, so stepping
    along the result moves points towards the segments. The gradient is taken
    on a clone, which leaves the caller's tensor untouched.
    """
    points = points.clone().requires_grad_(True)
    dist = batch_point_distances(segments, points)

    err = dist.pow(2).sum() * 0.5
    err.backward()

    return -points.grad


def main():
    """Load the tree, build the grids and run the interactive contraction loop."""
    args = parse_args()

    ti.init(arch=ti.gpu, debug=args.debug, log_level=ti.INFO)

    cloud, skeleton = load_data_npz(args.file_path)
    # view_synthetic_data([(data, args.file_path)])

    skeleton = repair_skeleton(skeleton)

    device = torch.device(args.device)
    np_tubes = collate_tubes(skeleton.to_tubes())

    tubes = {
        k: torch.from_numpy(x).to(dtype=torch.float32, device=device)
        for k, x in asdict(np_tubes).items()
    }

    segments = torch_geom.Segment(tubes["a"], tubes["b"])
    radii = torch.stack((tubes["r1"], tubes["r2"]), -1).squeeze(0)

    tubes = torch_geom.Tube(segments, radii)
    bounds = tubes.bounds.union_all()

    points = torch.from_numpy(cloud.xyz).to(dtype=torch.float32, device=device)
    points = morton_sort(points, n=256)

    print("Generate grid...")
    start_time = time.time()
    tube_grid = CountedGrid.from_torch(Grid.fixed_size(bounds, (512, 512, 512)), tubes)
    print(f"Grid Made... {time.time() - start_time}")

    point_grid = DynamicGrid.from_torch(
        Grid.fixed_size(bounds, (512, 512, 512)), torch_geom.Point(points)
    )
    print(f"Grid Made... {time.time() - start_time}")

    vis = o3d.visualization.VisualizerWithKeyCallback()
    vis.create_window("test", width=800, height=600)

    pcd = render.point_cloud(points, color=(0, 0, 1))
    vis.add_geometry(pcd)

    def update_points(vis):
        """Space-bar callback: apply one attraction step and refresh the cloud."""
        _, idx = min_query(tube_grid, points, 0.2, relative_distance)

        point_grid.update_objects(torch_geom.Point(points))
        # forces to regularize points and make them spread out
        forces = attract_query(point_grid.index, points, sigma=0.01, query_radius=0.05)

        # # project forces along segment direction only
        # dirs = segments.unit_dir[idx]
        # forces = torch_geom.dot(dirs, forces).unsqueeze(1) * dirs

        points.add_(forces * 0.5)

        # to_axis = to_medial_axis(segments[idx], points)
        # points.add_(to_axis * 0.2 + torch.randn_like(points) * 0.0001)

        pcd.points = o3d.utility.Vector3dVector(points.cpu().numpy())

        vis.update_geometry(pcd)
        # Key callbacks return a bool; True asks Open3D to refresh the view.
        return True

    vis.register_key_callback(ord(" "), update_points)

    # poll_events() returns False once the window has been closed: stop then
    # rather than spinning forever, and redraw every iteration so geometry
    # changed by the callback becomes visible.
    while vis.poll_events():
        vis.update_renderer()

    vis.destroy_window()

    # display_vectors(points, -points.grad)
    # o3d.visualization.draw(
    #     [
    #         render.point_cloud(points, color=(0, 0, 1)),
    #         # render.point_cloud(points - points.grad, color=(0, 1, 0))
    #     ],
    #     point_size=6,
    # )

    # print("Grid size: ", seg_grid.grid.size)
    # cells, counts = seg_grid.active_cells()

    # max_dist = dist[torch.isfinite(dist)].max()

    # display_distances(tubes, seg_grid.grid.get_boxes(cells),
    #                   points, dist / max_dist )


if __name__ == "__main__":
    main()

# ------------------------------------------------------------------------------
# /synthetic_trees/test_contraction2.py
# ------------------------------------------------------------------------------
"""Self-contained GPU benchmark: vectors from cloud points to their nearest tube.

Duplicates the data types and the ``.npz`` loader so it runs on its own, then
times ``pts_to_nearest_tube_gpu`` over a hard-coded tree in batches of 8000 points.
"""
from __future__ import annotations

import numpy as np
import torch
import time

from dataclasses import dataclass
from typing import List, Dict, Tuple
from pathlib import Path


def flatten_list(l):
    """Flatten one level of nesting: ``[[a, b], [c]] -> [a, b, c]``."""
    return [item for sublist in l for item in sublist]


@dataclass
class Tube:
    """One skeleton segment from ``a`` to ``b`` with end radii ``r1`` / ``r2``."""

    a: np.array  # start point (one row of a branch's xyz)
    b: np.array  # end point (the next row of that branch's xyz)
    r1: float  # radius at ``a``
    r2: float  # radius at ``b``


@dataclass
class CollatedTube:
    """Many tubes stacked into arrays (tensors after ``_to_torch``)."""

    a: np.array  # N x 3
    b: np.array  # N x 3
    r1: np.array  # 1 x N when built by collate_tubes
    r2: np.array  # 1 x N when built by collate_tubes

    def _to_torch(self, device=torch.device("cuda")):
        """Return a copy whose fields are tensors on ``device``."""
        a = torch.tensor(self.a, device=device)
        b = torch.tensor(self.b, device=device)
        r1 = torch.tensor(self.r1, device=device)
        r2 = torch.tensor(self.r2, device=device)
        return CollatedTube(a, b, r1, r2)


def collate_tubes(tubes: List[Tube]) -> CollatedTube:
    """Stack tubes: end points become N x 3, radii become 1 x N row vectors.

    The 1 x N radii broadcast against the (points x tubes) matrices used below.
    """
    a = np.concatenate([tube.a for tube in tubes]).reshape(-1, 3)
    b = np.concatenate([tube.b for tube in tubes]).reshape(-1, 3)

    r1 = np.asarray([tube.r1 for tube in tubes]).reshape(1, -1)
    r2 = np.asarray([tube.r2 for tube in tubes]).reshape(1, -1)

    return CollatedTube(a, b, r1, r2)


@dataclass
class TreeSkeleton:
    """A tree skeleton: branches keyed by branch id."""

    _id: int
    branches: Dict[int, BranchSkeleton]

    def to_tubes(self) -> List[Tube]:
        """All tubes of all branches, in the dict's branch order."""
        return flatten_list([branch.to_tubes() for branch in self.branches.values()])


@dataclass
class Cloud:
    """Point cloud with colours and optional per-point labels / vectors."""

    xyz: np.array
    rgb: np.array
    class_l: np.array = None
    vector: np.array = None


@dataclass
class BranchSkeleton:
    """One branch: a polyline of points with a radius per point."""

    _id: int
    parent_id: int
    xyz: np.array
    radii: np.array
    child_id: int = -1

    def to_tubes(self) -> List[Tube]:
        """One :class:`Tube` per pair of consecutive skeleton points."""
        starts, ends = self.xyz[:-1], self.xyz[1:]
        start_radii, end_radii = self.radii[:-1], self.radii[1:]

        return [
            Tube(a, b, r1, r2)
            for a, b, r1, r2 in zip(starts, ends, start_radii, end_radii)
        ]


def points_to_collated_tube_projections_gpu(
    pts: torch.Tensor,
    collated_tube: CollatedTube,
    device=torch.device("cuda"),
    eps=1e-12,
):
    """Project every point onto every tube axis.

    Args:
        pts: N x 3 tensor of query points.
        collated_tube: M tubes whose fields are tensors on the same device as ``pts``.
        device: unused; kept so existing callers keep working.
        eps: guards the division for zero-length segments.

    Returns:
        ``(proj, t)``: ``proj`` is N x M x 3 (closest point on each segment) and
        ``t`` is N x M, the segment parameter clamped to [0, 1].
    """
    ab = collated_tube.b - collated_tube.a  # M x 3 -> tube direction

    ap = pts.unsqueeze(1) - collated_tube.a.unsqueeze(0)  # N x M x 3

    t = (
        torch.einsum("nmd,md->nm", ap, ab) / (torch.einsum("md,md->m", ab, ab) + eps)
    ).clip(0.0, 1.0)  # N x M
    proj = collated_tube.a.unsqueeze(0) + torch.einsum("nm,md->nmd", t, ab)  # N x M x 3

    return proj, t


def projection_to_distance_matrix_gpu(projections, pts):  # N x M x 3
    """Euclidean distance (N x M) from each point to each of its projections."""
    return (projections - pts.unsqueeze(1)).square().sum(2).sqrt()


def pts_to_nearest_tube_gpu(
    pts: torch.Tensor, collated_tube: CollatedTube, device=torch.device("cuda")
):
    """Vectors from each point to the axis of its nearest tube.

    "Nearest" minimises ``distance_to_axis - radius``, i.e. distance to the
    tube surface.

    Returns:
        ``(vectors, idx, radius)``: N x 3 vectors from each point to its
        projection on the chosen tube, the N chosen tube indices, and the
        interpolated radius at each projection.
    """
    projections, t = points_to_collated_tube_projections_gpu(
        pts, collated_tube, device=device
    )  # N x M x 3
    r = (1 - t) * collated_tube.r1 + t * collated_tube.r2  # N x M

    distances = projection_to_distance_matrix_gpu(projections, pts) - r  # N x M
    idx = torch.argmin(distances, 1)  # N

    rows = torch.arange(pts.shape[0], device=pts.device)
    return projections[rows, idx] - pts, idx, r[rows, idx]


def unpackage_data(data: dict) -> Tuple[Cloud, TreeSkeleton]:
    """Build a :class:`Cloud` and :class:`TreeSkeleton` from a tree archive.

    ``branch_num_elements`` gives how many consecutive entries of
    ``skeleton_xyz`` / ``skeleton_radii`` belong to each branch.
    """
    tree_id = data["tree_id"]
    branch_id = data["branch_id"]
    branch_parent_id = data["branch_parent_id"]
    skeleton_xyz = data["skeleton_xyz"]
    skeleton_radii = data["skeleton_radii"]
    sizes = data["branch_num_elements"]

    cld = Cloud(
        xyz=data["xyz"],
        rgb=data["rgb"],
        class_l=data["class_l"],
        vector=data["vector"],
    )

    offsets = np.cumsum(np.append([0], sizes))

    branch_idx = [np.arange(size) + offset for size, offset in zip(sizes, offsets)]
    branches = {}

    for idx, _id, parent_id in zip(branch_idx, branch_id, branch_parent_id):
        branches[_id] = BranchSkeleton(
            _id, parent_id, skeleton_xyz[idx], skeleton_radii[idx]
        )

    return cld, TreeSkeleton(tree_id, branches)


def load_data_npz(path: Path) -> Tuple[Cloud, TreeSkeleton]:
    """Load a tree ``.npz``; the archive is closed once its arrays have been read."""
    with np.load(str(path)) as data:
        return unpackage_data(data)


if __name__ == "__main__":
    cloud, skeleton = load_data_npz(
        "/local/point_cloud_datasets/synthetic-trees/tree_dataset/branches/pine/pine_15.npz"
    )

    collated_tube = collate_tubes(skeleton.to_tubes())._to_torch()
    pts = torch.from_numpy(cloud.xyz).float().to(torch.device("cuda"))

    # Let the host-to-device copies above finish before starting the timer so
    # only the batched queries are measured.
    torch.cuda.synchronize()
    start_time = time.time()

    for i in range(0, len(pts), 8000):
        batch_pts = pts[i : i + 8000]
        vector, idx, radius = pts_to_nearest_tube_gpu(batch_pts, collated_tube)

    torch.cuda.synchronize()

    print(f"Done {time.time() - start_time}")

    print(vector.shape)

# ------------------------------------------------------------------------------
# /synthetic_trees/util/__init__.py
# ------------------------------------------------------------------------------
import sys

if sys.version_info[:2] >= (3, 8):
    # TODO: Import directly (no need for conditional) when `python_requires = >= 3.8`
    from importlib.metadata import PackageNotFoundError, version  # pragma: no cover
else:
    from importlib_metadata import PackageNotFoundError, version  # pragma: no cover

try:
    # Change here if project is renamed and does not equal the package name
    # NOTE(review): ``__name__`` is this subpackage's dotted module name, which
    # is unlikely to match an installed distribution name, so this probably
    # always falls back to "unknown" -- confirm against pyproject.toml.
    dist_name = __name__
    __version__ = version(dist_name)
except PackageNotFoundError:  # pragma: no cover
    __version__ = "unknown"
finally:
    del version, PackageNotFoundError

# ------------------------------------------------------------------------------
# /synthetic_trees/util/file.py
# ------------------------------------------------------------------------------
import numpy as np
import open3d as o3d

from pathlib import Path
from typing import Tuple


from ..data_types.cloud import Cloud
from ..data_types.tree import TreeSkeleton
from ..data_types.branch import BranchSkeleton


def unpackage_data(data: dict) -> Tuple[Cloud, TreeSkeleton]:
    """Build a :class:`Cloud` and :class:`TreeSkeleton` from a tree archive.

    ``branch_num_elements`` gives how many consecutive entries of
    ``skeleton_xyz`` / ``skeleton_radii`` belong to each branch. The medial
    vectors are read from ``medial_vector`` or, if absent, ``vector``.
    """
    tree_id = data["tree_id"]
    branch_id = data["branch_id"]
    branch_parent_id = data["branch_parent_id"]
    skeleton_xyz = data["skeleton_xyz"]
    skeleton_radii = data["skeleton_radii"]
    sizes = data["branch_num_elements"]

    # Only touch the fallback key when the preferred one is missing, so the
    # fallback array is not read needlessly.
    if "medial_vector" in data:
        medial_vector = data["medial_vector"]
    else:
        medial_vector = data.get("vector", None)

    cld = Cloud(
        xyz=data["xyz"],
        rgb=data["rgb"],
        class_l=data["class_l"],
        medial_vector=medial_vector,
    )

    offsets = np.cumsum(np.append([0], sizes))

    branch_idx = [np.arange(size) + offset for size, offset in zip(sizes, offsets)]
    branches = {}

    for idx, _id, parent_id in zip(branch_idx, branch_id, branch_parent_id):
        branches[_id] = BranchSkeleton(
            _id, parent_id, skeleton_xyz[idx], skeleton_radii[idx]
        )

    return cld, TreeSkeleton(tree_id, branches)


def load_data_npz(path: Path) -> Tuple[Cloud, TreeSkeleton]:
    """Load a tree ``.npz``; the archive is closed once its arrays have been read."""
    with np.load(str(path)) as data:
        return unpackage_data(data)


def load_o3d_cloud(path: Path):
    """Read a point cloud file with Open3D."""
    return o3d.io.read_point_cloud(str(path))

# ------------------------------------------------------------------------------
# /synthetic_trees/util/math.py
# ------------------------------------------------------------------------------
import numpy as np

try:  # NumPy >= 2.0 renamed ``trapz`` to ``trapezoid`` and deprecated the old name
    from numpy import trapezoid as trapz
except ImportError:  # older NumPy
    from numpy import trapz


def make_tangent(d, n):
    """Direction of the component of ``n`` orthogonal to ``d``.

    The result has unit length when ``d`` is a unit vector.
    """
    t = np.cross(d, n)
    t /= np.linalg.norm(t, axis=-1, keepdims=True)
    return np.cross(t, d)


def unit_circle(n):
    """``n`` evenly spaced points on the unit circle as ``(sin, cos)`` rows, shape (n, 2)."""
    a = np.linspace(0, 2 * np.pi, n + 1)[:-1]
    return np.stack([np.sin(a), np.cos(a)], axis=1)


def vertex_dirs(points):
    """Unit direction for every vertex of a polyline.

    Interior vertices average the directions of their two adjacent segments;
    the first and last vertices take the first and last segment directions.

    Args:
        points: (P, 3) polyline vertices, P >= 2.

    Returns:
        (P, 3) array of unit vectors.
    """
    d = points[1:] - points[:-1]
    # Normalise each segment on its own (not by the norm of the whole array) so
    # long segments do not dominate the averaged interior directions. A
    # zero-length segment stays a zero vector instead of becoming NaN.
    lengths = np.linalg.norm(d, axis=1, keepdims=True)
    d = d / np.where(lengths > 0, lengths, 1)

    smooth = (d[1:] + d[:-1]) * 0.5
    # The last vertex follows the last segment, d[-1].
    dirs = np.concatenate([d[:1], smooth, d[-1:]])

    return dirs / np.linalg.norm(dirs, axis=1, keepdims=True)


def gen_tangents(dirs, t):
    """Propagate a tangent along ``dirs``, re-orthogonalising it at each step."""
    tangents = []

    for direction in dirs:
        t = make_tangent(direction, t)
        tangents.append(t)

    return np.stack(tangents)


def random_unit(dtype=np.float32):
    """Random unit 3-vector (normalised standard-normal sample)."""
    x = np.random.randn(3).astype(dtype)
    return x / np.linalg.norm(x)


def calculate_AuC(y, dx=0.01):
    """Area under ``y`` by the trapezoidal rule with uniform spacing ``dx``."""
    return trapz(y=y, dx=dx)

# ------------------------------------------------------------------------------
# /synthetic_trees/util/misc.py
# ------------------------------------------------------------------------------
import numpy as np
import torch

from typing import List


def flatten_list(l):
    """Flatten one level of nesting: ``[[a, b], [c]] -> [a, b, c]``."""
    return [item for sublist in l for item in sublist]


def to_torch(numpy_arrays: List[np.array], device=torch.device("cpu")):
    """Convert each array to a float32 tensor on ``device``."""
    return [torch.from_numpy(np_arr).float().to(device) for np_arr in numpy_arrays]

# ------------------------------------------------------------------------------
# /synthetic_trees/util/o3d_abstractions.py
# ------------------------------------------------------------------------------
from dataclasses import dataclass, asdict
from typing import Sequence

import open3d as o3d

import numpy as np

from .math import unit_circle, vertex_dirs, gen_tangents, random_unit


def o3d_cloud(points, colour=None, colours=None, normals=None):
    """Build an Open3D point cloud.

    ``colour`` (one colour for all points) takes precedence over ``colours``
    (per-point colours). ``normals`` are attached when given.
    """
    cloud = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(points))

    if normals is not None:
        cloud.normals = o3d.utility.Vector3dVector(normals)
    if colour is not None:
        return cloud.paint_uniform_color(colour)
    elif colours is not None:
        cloud.colors = o3d.utility.Vector3dVector(colours)

    return cloud


def o3d_merge_linesets(line_sets, colour=(0, 0, 0)):
    """Merge several LineSets into one, painted uniformly with ``colour``."""
    sizes = [np.asarray(ls.points).shape[0] for ls in line_sets]
    offsets = np.cumsum([0] + sizes)

    points = np.concatenate([np.asarray(ls.points) for ls in line_sets])
    # ``ls.lines`` is an Open3D vector type; convert it to an array before
    # shifting its vertex indices by the running point offset.
    idxs = np.concatenate(
        [np.asarray(ls.lines) + offset for ls, offset in zip(line_sets, offsets)]
    )

    return o3d_line_set(points, idxs).paint_uniform_color(colour)


def o3d_line_set(vertices, edges, colour=None):
    """LineSet from (V, 3) vertices and (E, 2) vertex-index pairs, optionally painted."""
    ls = o3d.geometry.LineSet(
        o3d.utility.Vector3dVector(vertices), o3d.utility.Vector2iVector(edges)
    )
    if colour is not None:
        return ls.paint_uniform_color(colour)
    return ls


def o3d_path(vertices, colour=None):
    """LineSet connecting consecutive ``vertices`` into an open polyline."""
    idx = np.arange(vertices.shape[0] - 1)
    edge_idx = np.column_stack((idx, idx + 1))
    # o3d_line_set already leaves the set unpainted when colour is None.
    return o3d_line_set(vertices, edge_idx, colour)


def o3d_merge_meshes(meshes):
    """Merge triangle meshes into one, keeping each mesh's vertex colours."""
    sizes = [np.asarray(m.vertices).shape[0] for m in meshes]
    offsets = np.cumsum([0] + sizes)

    # ``m.triangles`` is an Open3D vector type; convert before offsetting.
    triangles = np.concatenate(
        [np.asarray(m.triangles) + offset for offset, m in zip(offsets, meshes)]
    )
    vertices = np.concatenate([np.asarray(m.vertices) for m in meshes])

    mesh = o3d_mesh(vertices, triangles)
    mesh.vertex_colors = o3d.utility.Vector3dVector(
        np.concatenate([np.asarray(m.vertex_colors) for m in meshes])
    )
    return mesh


def o3d_mesh(verts, tris):
    """Triangle mesh from vertices and triangle indices, with triangle normals computed."""
    return o3d.geometry.TriangleMesh(
        o3d.utility.Vector3dVector(verts), o3d.utility.Vector3iVector(tris)
    ).compute_triangle_normals()
def o3d_lines_between_clouds(cld1, cld2):
    """LineSet joining the i-th point of ``cld1`` to the i-th point of ``cld2``.

    Only the first ``min(len(cld1), len(cld2))`` pairs are joined, so the two
    clouds may differ in size.
    """
    pts1 = np.asarray(cld1.points)
    pts2 = np.asarray(cld2.points)
    count = min(pts1.shape[0], pts2.shape[0])

    # Interleave as [p1_0, p2_0, p1_1, p2_1, ...] so each consecutive vertex
    # pair forms one line. Truncate first: hstack needs equal row counts.
    interweaved = np.hstack((pts1[:count], pts2[:count])).reshape(-1, 3)
    return o3d_line_set(interweaved, np.arange(0, count * 2).reshape(-1, 2))


def cylinder_triangles(m, n):
    """Triangle indices for a tube of ``n`` rings with ``m`` vertices per ring.

    Vertex ``j`` of ring ``k`` is expected at index ``k * m + j``. Each quad
    between adjacent rings is split into two triangles, wrapping around the ring.
    """
    tri1 = np.array([0, 1, 2])
    tri2 = np.array([2, 3, 0])

    v0 = np.arange(m)
    v1 = (v0 + 1) % m
    v2 = v1 + m
    v3 = v0 + m

    edges = np.stack([v0, v1, v2, v3], axis=1)

    segments = np.arange(n - 1) * m
    edges = edges.reshape(1, *edges.shape) + segments.reshape(n - 1, 1, 1)

    edges = edges.reshape(-1, 4)
    return np.concatenate([edges[:, tri1], edges[:, tri2]])


def tube_vertices(points, radii, n=10):
    """Ring vertices around a polyline: returns shape (len(points), n, 3).

    Each ring lies in the plane spanned by a propagated tangent and
    ``cross(tangent, dir)`` and is scaled by that point's radius. The initial
    tangent is random.
    """
    circle = unit_circle(n).astype(np.float32)

    dirs = vertex_dirs(points)
    t = gen_tangents(dirs, random_unit())

    b = np.stack([t, np.cross(t, dirs)], axis=1)
    b = b * radii.reshape(-1, 1, 1)

    return np.einsum("bdx,md->bmx", b, circle) + points.reshape(points.shape[0], 1, 3)


def o3d_tube_mesh(points, radii, colour=(1, 0, 0), n=10):
    """Closed-side tube mesh along ``points`` with per-point ``radii``, ``n`` vertices per ring."""
    vertices = tube_vertices(points, radii, n)

    num_rings, ring_size, _ = vertices.shape
    indexes = cylinder_triangles(ring_size, num_rings)

    mesh = o3d_mesh(vertices.reshape(-1, 3), indexes)
    mesh.compute_vertex_normals()

    return mesh.paint_uniform_color(colour)


def o3d_load_lineset(path, colour=(0, 0, 0)):
    """Read a LineSet file (``str`` or ``Path``) and paint it uniformly."""
    return o3d.io.read_line_set(str(path)).paint_uniform_color(colour)


@dataclass
class ViewerItem:
    """A named geometry for :func:`o3d_viewer`."""

    name: str
    geometry: o3d.geometry.Geometry
    is_visible: bool = True  # initial visibility in the viewer


def o3d_viewer(items: Sequence[ViewerItem], line_width=1):
    """Show ``items`` in an Open3D viewer; LineSets get an unlit line material."""
    mat = o3d.visualization.rendering.MaterialRecord()
    mat.shader = "defaultLit"

    line_mat = o3d.visualization.rendering.MaterialRecord()
    line_mat.shader = "unlitLine"
    line_mat.line_width = line_width

    def material(item):
        return line_mat if isinstance(item.geometry, o3d.geometry.LineSet) else mat

    # Build the dicts by hand: dataclasses.asdict would deep-copy every
    # (potentially large) geometry before handing it to the viewer.
    geometries = [
        dict(
            name=item.name,
            geometry=item.geometry,
            is_visible=item.is_visible,
            material=material(item),
        )
        for item in items
    ]
    o3d.visualization.draw(geometries, line_width=line_width)

# ------------------------------------------------------------------------------
# /synthetic_trees/util/operations.py
# ------------------------------------------------------------------------------
import numpy as np

from typing import List

from ..data_types.tube import Tube


def sample_tubes(tubes: List[Tube], sample_rate):
    """Sample points (and linearly interpolated radii) along each tube axis.

    Each tube gets ``round(length / sample_rate)`` evenly spaced samples,
    including both ends; tubes too short for one sample are skipped.

    Returns:
        ``(points, radii)`` with shapes (K, 3) and (K,); both empty when no
        tube produced a sample.
    """
    pts, radius = [], []

    for tube in tubes:
        v = tube.b - tube.a
        length = np.linalg.norm(v)
        num_pts = int(np.round(length / sample_rate))

        if num_pts > 0:
            # Only divide once we know the tube has non-zero length.
            direction = v / length

            # np.arange(0, float(length), step=float(sample_rate)).reshape(-1,1)
            spaced_points = np.linspace(0, length, num_pts).reshape(-1, 1)

            lin_radius = np.linspace(
                tube.r1, tube.r2, spaced_points.shape[0], dtype=float
            )

            pts.append(tube.a + direction * spaced_points)
            radius.append(lin_radius)

    if not pts:
        return np.empty((0, 3)), np.empty((0,))

    return np.concatenate(pts, axis=0), np.concatenate(radius, axis=0)


def sample_o3d_lineset(ls, sample_rate):
    """Sample points along each line of an Open3D LineSet.

    Each line gets ``round(length / sample_rate)`` evenly spaced samples,
    including both ends; lines too short for one sample are skipped.

    Returns:
        (K, 3) array of samples, empty when no line produced a sample.
    """
    edges = np.asarray(ls.lines)
    xyz = np.asarray(ls.points)

    pts = []

    for edge in edges:
        start = xyz[edge[0]]
        end = xyz[edge[1]]

        v = end - start
        length = np.linalg.norm(v)
        num_pts = int(np.round(length / sample_rate))

        if num_pts > 0:
            direction = v / length

            # np.arange(0, float(length), step=float(sample_rate)).reshape(-1,1)
            spaced_points = np.linspace(0, length, num_pts).reshape(-1, 1)
            pts.append(start + direction * spaced_points)

    if not pts:
        return np.empty((0, 3))

    return np.concatenate(pts, axis=0)

# ------------------------------------------------------------------------------
# /synthetic_trees/util/queries.py
# ------------------------------------------------------------------------------
import numpy as np
import torch

from typing import List

from ..data_types.tube import Tube, CollatedTube, collate_tubes


from pykeops.torch import LazyTensor


# Shape conventions used below:
#   N : number of pts
#   M : number of tubes


def points_to_collated_tube_projections(pts: np.array, collated_tube: CollatedTube, eps=1e-12):
    """Project every point onto every tube axis (NumPy version).

    Args:
        pts: N x 3 points.
        collated_tube: M tubes with ``a`` / ``b`` of shape M x 3.
        eps: guards the division for zero-length segments.

    Returns:
        ``(proj, t)``: N x M x 3 closest points on each segment and the N x M
        segment parameter clamped to [0, 1].
    """
    ab = collated_tube.b - collated_tube.a  # M x 3

    ap = pts[:, np.newaxis] - collated_tube.a[np.newaxis, ...]  # N x M x 3

    t = np.clip(
        np.einsum("nmd,md->nm", ap, ab) / (np.einsum("md,md->m", ab, ab) + eps),
        0.0,
        1.0,
    )  # N x M
    proj = collated_tube.a[np.newaxis, ...] + np.einsum("nm,md->nmd", t, ab)  # N x M x 3
    return proj, t


def projection_to_distance_matrix(projections, pts):  # N x M x 3
    """Euclidean distance (N x M) from each point to each of its projections."""
    return np.sqrt(np.sum(np.square(projections - pts[:, np.newaxis, :]), 2))


def pts_to_nearest_tube(pts: np.array, tubes: List[Tube]):
    """Vectors from each point to the axis of its nearest tube.

    "Nearest" minimises ``distance_to_axis - radius``.

    Returns:
        ``(vectors, idx, radius)``: N x 3 vectors to the projection on the
        chosen tube, the N chosen tube indices, and the radius there.
    """
    collated_tube = collate_tubes(tubes)
    projections, t = points_to_collated_tube_projections(pts, collated_tube)  # N x M x 3

    r = (1 - t) * collated_tube.r1 + t * collated_tube.r2

    distances = projection_to_distance_matrix(projections, pts) - r  # N x M
    idx = np.argmin(distances, 1)  # N

    rows = np.arange(pts.shape[0])
    # vector, idx , radius
    return projections[rows, idx] - pts, idx, r[rows, idx]


def pts_on_nearest_tube(pts: np.array, tubes: List[Tube]):
    """Move each point onto its nearest tube's axis; also return the radius there."""
    vectors, index, radius = pts_to_nearest_tube(pts, tubes)
    return pts + vectors, radius


# def knn(src, dest, K=50, r=1.0, grid=None):
#     src_lengths = src.new_tensor([src.shape[0]], dtype=torch.long)
#     dest_lengths = src.new_tensor([dest.shape[0]], dtype=torch.long)
#     dists, idxs, grid, _ = frnn.frnn_grid_points(
#         src.unsqueeze(0), dest.unsqueeze(0),
#         src_lengths, dest_lengths,
#         K, r,return_nn=False, return_sorted=True)
#     return idxs.squeeze(0), dists.sqrt().squeeze(0), grid


# def nn_frnn(src, dest, r=1.0, grid=None):
#     idx, dist, grid = knn(src, dest, K=1, r=r, grid=grid)
#     idx, dist = idx.squeeze(1), dist.squeeze(1)
#     return idx, dist, grid


def distance_matrix_keops(pts1, pts2, device=torch.device("cuda")):
    """Symbolic (KeOps LazyTensor) matrix of distances between two point sets."""
    pts1 = pts1.to(device=device, dtype=torch.float32)
    pts2 = pts2.to(device=device, dtype=torch.float32)

    x_i = LazyTensor(pts1.reshape(-1, 1, 3))
    y_j = LazyTensor(pts2.reshape(1, -1, 3))

    return (x_i - y_j).square().sum(dim=2).sqrt()


def nn_keops(pts1, pts2):
    """Nearest neighbour in ``pts2`` for every point of ``pts1``: ``(distance, idx)``."""
    D_ij = distance_matrix_keops(pts1, pts2)

    return D_ij.min(1), D_ij.argmin(1).flatten()  # distance, idx

# ------------------------------------------------------------------------------
# /synthetic_trees/view.py
# ------------------------------------------------------------------------------
import os
import argparse
import numpy as np

from typing import List, Tuple

from pathlib import Path

from synthetic_trees.data_types.tree import TreeSkeleton
from synthetic_trees.data_types.cloud import Cloud
from synthetic_trees.util.file import load_data_npz
from synthetic_trees.util.o3d_abstractions import ViewerItem, o3d_viewer


def view_synthetic_data(
    data: List[Tuple[Tuple[Cloud, TreeSkeleton], Path]], line_width=1
):
    """Open one viewer containing every tree in ``data``.

    Each item is ``((cloud, skeleton), path)``; the path stem prefixes the
    geometry names. Only the first tree's geometries start visible.
    """
    geometries = []
    for i, ((cloud, skeleton), path) in enumerate(data):
        tree_name = path.stem
        visible = i == 0

        # Accumulate across trees; reassigning the list here would keep only
        # the last tree (whose items are hidden when there is more than one).
        geometries.extend(
            [
                ViewerItem(
                    f"{tree_name}_cloud",
                    cloud.to_o3d_cloud(),
                    is_visible=visible,
                ),
                ViewerItem(
                    f"{tree_name}_labelled_cloud",
                    cloud.to_o3d_cloud_labelled(),
                    is_visible=visible,
                ),
                ViewerItem(
                    f"{tree_name}_medial_vectors",
                    cloud.to_o3d_medial_vectors(),
                    is_visible=visible,
                ),
                ViewerItem(
                    f"{tree_name}_skeleton",
                    skeleton.to_o3d_lineset(),
                    is_visible=visible,
                ),
                ViewerItem(
                    f"{tree_name}_skeleton_mesh",
                    skeleton.to_o3d_tubes(),
                    is_visible=visible,
                ),
            ]
        )

    o3d_viewer(geometries, line_width=line_width)


def parse_args():
    """Parse the command line: a file or directory path and the line width."""
    parser = argparse.ArgumentParser(description="Visualizer Arguments")

    parser.add_argument(
        "file_path",
        help="File Path of tree.npz",
        default=None,
        type=Path,
    )

    parser.add_argument(
        "-lw",
        "--line_width",
        help="Width of visualizer lines",
        default=1,
        type=int,
    )
    return parser.parse_args()


def paths_from_args(args, glob="*.npz"):
    """Resolve ``args.file_path`` to a list of input files.

    A file is returned as a one-element list; a directory is expanded with
    ``glob`` and sorted for a stable order.

    Raises:
        ValueError: if the path does not exist, is neither a file nor a
            directory, or the directory contains no files matching ``glob``.
    """
    if not args.file_path.exists():
        raise ValueError(f"File {args.file_path} does not exist")

    if args.file_path.is_file():
        print(f"Loading data from file: {args.file_path}")
        return [args.file_path]

    if args.file_path.is_dir():
        print(f"Loading data from directory: {args.file_path}")
        # Path.glob returns a lazy generator that never equals []; materialise
        # it so the emptiness check can actually fire.
        files = sorted(args.file_path.glob(glob))
        if not files:
            raise ValueError(f"No {glob} files found in {args.file_path}")
        return files

    raise ValueError(f"{args.file_path} is neither a file nor a directory")


def main():
    """Load every selected tree and show them together."""
    args = parse_args()

    data = [(load_data_npz(filename), filename) for filename in paths_from_args(args)]
    view_synthetic_data(data, args.line_width)


if __name__ == "__main__":
    main()

# ------------------------------------------------------------------------------
# /synthetic_trees/view_clouds.py
# ------------------------------------------------------------------------------
import argparse
from pathlib import Path
from synthetic_trees.view import parse_args


from synthetic_trees.util.file import load_o3d_cloud
from synthetic_trees.util.o3d_abstractions import ViewerItem, o3d_viewer
from synthetic_trees.view import paths_from_args


def view_clouds(data: list):
    """Show ``(cloud, path)`` pairs in one viewer; only the first starts visible."""
    # ViewerItem's field is ``is_visible``; ``visible=`` raised a TypeError.
    geometries = [
        ViewerItem(f"{path.stem}_cloud", cloud, is_visible=i == 0)
        for i, (cloud, path) in enumerate(data)
    ]

    o3d_viewer(geometries)


def main():
    """Load ``.ply`` clouds from the given file or directory and display them."""
    args = parse_args()
    data = [
        (load_o3d_cloud(filename), filename)
        for filename in paths_from_args(args, "*.ply")
    ]

    view_clouds(data)


if __name__ == "__main__":
    main()

# ------------------------------------------------------------------------------
# /tests/test_skeleton.py
# ------------------------------------------------------------------------------
# ------------------------------------------------------------------------------
"""Tests for the ``synthetic_trees`` package."""
import pytest

from synthetic_trees import download

__author__ = "Harry Dobbs"
__copyright__ = "Harry Dobbs"
__license__ = "MIT"


def test_download_module_imports():
    """Smoke test: the ``download`` module imports cleanly.

    Also gives pytest at least one test to collect, because pytest exits with
    status 5 when it collects none.
    """
    assert download is not None


# ------------------------------------------------------------------------------