├── smart_tree ├── conf │ ├── __init__.py │ ├── pipeline.yaml │ ├── training.yaml │ └── training-split.json ├── model │ ├── __init__.py │ ├── weights │ │ ├── __init__.py │ │ ├── noble-elevator-58_model.pt │ │ ├── peach-forest-65_model.pt │ │ ├── peach-forest-65_model_weights.pt │ │ └── noble-elevator-58_model_weights.pt │ ├── render.py │ ├── tracker.py │ ├── helper.py │ ├── model.py │ ├── fp16.py │ ├── sparse.py │ ├── loss.py │ ├── model_inference.py │ ├── train.py │ └── model_blocks.py ├── util │ ├── __init__.py │ ├── misc.py │ ├── file.py │ ├── queries.py │ └── maths.py ├── data_types │ ├── __init__.py │ ├── graph.py │ ├── branch.py │ ├── tube.py │ ├── tree.py │ └── cloud.py ├── dataset │ ├── __init__.py │ ├── augmentations.py │ └── dataset.py ├── scripts │ ├── __init__.py │ ├── clean_up.sh │ ├── laz2ply.py │ ├── bench_dataloader.py │ ├── vis_dataloader.py │ ├── split-data.py │ └── view_npz.py ├── skeleton │ ├── __init__.py │ ├── filter.py │ ├── connection.py │ ├── shortest_path.py │ ├── skeletonize.py │ ├── graph.py │ └── path.py ├── o3d_abstractions │ ├── __init__.py │ ├── visualizer.py │ ├── camera.py │ └── geometries.py ├── __init__.py ├── tests │ ├── test-dataloader.yaml │ └── dataloader.py ├── cli.py └── pipeline.py ├── images ├── botanic-pcd.png ├── botanic-skeleton.png └── botanic-branch-mesh.png ├── tests └── type_checks.py ├── create-env.sh ├── pyproject.toml ├── LICENSE ├── .gitignore └── README.md /smart_tree/conf/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /smart_tree/model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /smart_tree/util/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /smart_tree/data_types/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /smart_tree/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /smart_tree/scripts/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /smart_tree/skeleton/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /smart_tree/model/weights/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /smart_tree/o3d_abstractions/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /smart_tree/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "1.0.0" 2 | -------------------------------------------------------------------------------- /smart_tree/scripts/clean_up.sh: -------------------------------------------------------------------------------- 1 | isort . 2 | pycln . 
-------------------------------------------------------------------------------- /images/botanic-pcd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uc-vision/smart-tree/HEAD/images/botanic-pcd.png -------------------------------------------------------------------------------- /images/botanic-skeleton.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uc-vision/smart-tree/HEAD/images/botanic-skeleton.png -------------------------------------------------------------------------------- /images/botanic-branch-mesh.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uc-vision/smart-tree/HEAD/images/botanic-branch-mesh.png -------------------------------------------------------------------------------- /smart_tree/model/weights/noble-elevator-58_model.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uc-vision/smart-tree/HEAD/smart_tree/model/weights/noble-elevator-58_model.pt -------------------------------------------------------------------------------- /smart_tree/model/weights/peach-forest-65_model.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uc-vision/smart-tree/HEAD/smart_tree/model/weights/peach-forest-65_model.pt -------------------------------------------------------------------------------- /smart_tree/model/weights/peach-forest-65_model_weights.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uc-vision/smart-tree/HEAD/smart_tree/model/weights/peach-forest-65_model_weights.pt -------------------------------------------------------------------------------- /smart_tree/model/weights/noble-elevator-58_model_weights.pt: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/uc-vision/smart-tree/HEAD/smart_tree/model/weights/noble-elevator-58_model_weights.pt
--------------------------------------------------------------------------------
/smart_tree/skeleton/filter.py:
--------------------------------------------------------------------------------
import torch

from .graph import knn


def outlier_removal(points, radii, nb_points=4):
    """Flag points that have ``nb_points`` valid neighbours inside their radius.

    Every point is queried against the whole cloud with ``knn`` (search radius
    = the largest value in ``radii``). A neighbour counts when its returned
    index is not -1 and its distance is below the point's radius.

    Args:
        points: Query and reference positions, passed straight to ``knn``.
        radii: Per-point radius; must broadcast against the ``dists`` returned
            by ``knn`` (presumably shape (N, 1) -- TODO confirm).
        nb_points: Number of neighbours requested and required.

    Returns:
        Boolean tensor, True where all ``nb_points`` neighbours passed the test.
    """
    idxs, dists, _ = knn(points, points, K=nb_points, r=torch.max(radii).item())

    # NOTE(review): if ``knn`` wraps FRNN, ``dists`` may be *squared* distances,
    # in which case this compares d^2 against r -- confirm against graph.knn.
    keep = (dists < radii) & (idxs != -1)

    return keep.sum(1) == nb_points

--------------------------------------------------------------------------------
/tests/type_checks.py:
--------------------------------------------------------------------------------
import torch
from torchtyping import patch_typeguard

# torchtyping must be patched before @typechecked is applied ("use before
# @typechecked"), and BranchSkeleton is decorated when its module is imported,
# so the patch has to run before that import.
patch_typeguard()

from smart_tree.data_types.branch import BranchSkeleton  # noqa: E402


def test_branch_skeleton_type_error():
    """BranchSkeleton must reject 1-D ``xyz``/``radii`` (annotated (N, 3) / (N, 1))."""
    xyz = torch.rand(100)
    radii = torch.rand(100)

    try:
        BranchSkeleton(0, 0, xyz, radii)
    except TypeError:
        return  # expected: the runtime shape check rejected the arguments

    raise AssertionError(
        "BranchSkeleton accepted 1-D xyz/radii; expected a TypeError from typeguard"
    )


if __name__ == "__main__":
    test_branch_skeleton_type_error()

--------------------------------------------------------------------------------
/smart_tree/scripts/laz2ply.py:
--------------------------------------------------------------------------------
import argparse

import laspy
import numpy as np
import open3d as o3d


def las_to_ply(input_las_file, output_ply_file):
    """Convert a LAS/LAZ point cloud into a PLY file holding only XYZ positions.

    Raises:
        IOError: if Open3D reports that the PLY file could not be written.
    """
    with laspy.open(input_las_file) as fh:
        las = fh.read()
    xyz = np.column_stack((las.x, las.y, las.z))

    cloud = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(xyz))
    # write_point_cloud signals failure through its bool return value only.
    if not o3d.io.write_point_cloud(output_ply_file, cloud):
        raise IOError(f"Failed to write point cloud to {output_ply_file}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert a .las/.laz file to .ply")
    parser.add_argument(
        "input_las_file",
        nargs="?",
        default="/csse/users/hdo27/Desktop/ChCh_Hovermap_tree.laz",
    )
    parser.add_argument("output_ply_file", nargs="?", default="output.ply")
    args = parser.parse_args()

    las_to_ply(args.input_las_file, args.output_ply_file)

--------------------------------------------------------------------------------
/smart_tree/tests/test-dataloader.yaml:
--------------------------------------------------------------------------------
dataset:
  _target_: smart_tree.dataset.dataset.TreeDataset
  voxel_size: 0.01
  # NOTE(review): conf/ in this repository ships training-split.json, not tree-split.json -- confirm this path.
  json_path: /local/uc-vision/smart-tree/smart_tree/conf/tree-split.json
  directory: /local/uc-vision/dataset/branches/
  blocking: True
  transform: True
  block_size: 4
  buffer_size: 0.4

data_loader:
  _target_: torch.utils.data.DataLoader
  batch_size: 1
  drop_last: True
  pin_memory: True
  num_workers: 0
  shuffle: True
  #prefetch_factor: None
  collate_fn:
    _target_: smart_tree.model.sparse.batch_collate
    _partial_: True

--------------------------------------------------------------------------------
/smart_tree/tests/dataloader.py:
--------------------------------------------------------------------------------
import time

import hydra
from hydra.utils import instantiate
from omegaconf import DictConfig
from tqdm import tqdm


@hydra.main(
    version_base=None,
    config_path=".",
    config_name="test-dataloader",
)
def main(cfg: DictConfig):
    """Iterate once over the training DataLoader and print the elapsed seconds."""
    train_dataloader = instantiate(
        cfg.data_loader, dataset=instantiate(cfg.dataset, mode="train")
    )

    # perf_counter is monotonic, so the measurement is immune to clock changes.
    start_time = time.perf_counter()
    for _ in tqdm(train_dataloader):
        pass

    print(time.perf_counter() - start_time)


if __name__ == "__main__":
    main()

--------------------------------------------------------------------------------
/create-env.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Create the "smart-tree" conda env: CUDA 11.8 toolkit, PyTorch, this package
# (editable), FRNN and RAPIDS cudf/cugraph.
eval "$(conda shell.bash hook)"

# Abort on the first failing step. `a && b &&` chains alone do not stop the
# script: a failure only skips the rest of its own chain and later steps run.
set -e

conda create -n smart-tree python=3.10
conda run -n smart-tree conda install -c "nvidia/label/cuda-11.8.0" cuda-toolkit
conda run -n smart-tree conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia
conda run -n smart-tree pip install -e .

echo Installing FRNN
# Reuse an existing checkout so the script can be re-run.
if [ ! -d FRNN ]; then
    git clone --recursive https://github.com/lxxue/FRNN.git
fi
conda run -n smart-tree pip install FRNN/external/prefix_sum/.
conda run -n smart-tree pip install -e FRNN/.

conda run -n smart-tree conda install -c rapidsai -c conda-forge -c nvidia cudf=23.02 cugraph=23.02 python=3.10 cuda-version=11.8 --solver=libmamba

--------------------------------------------------------------------------------
/smart_tree/cli.py:
--------------------------------------------------------------------------------
from pathlib import Path

import hydra
from hydra.utils import instantiate
from omegaconf import DictConfig


@hydra.main(
    version_base=None,
    config_path="conf",
    config_name="pipeline",
)
def main(cfg: DictConfig):
    """Run the pipeline on ``path`` (one cloud) or on every file in ``directory``.

    Neither key is in conf/pipeline.yaml, so they are supplied on the command
    line (e.g. ``+path=...``). ``path`` takes precedence when both are given.
    """
    pipeline = instantiate(cfg.pipeline)
    options = dict(cfg)

    if "path" in options:
        pipeline.process_cloud(Path(cfg.path))

    elif "directory" in options:
        # Sorted for a reproducible processing order; sub-directories are
        # skipped so only files are handed to process_cloud.
        for cloud_path in sorted(Path(cfg.directory).iterdir()):
            if cloud_path.is_file():
                pipeline.process_cloud(cloud_path)

    else:
        print("Please supply a path or directory to point clouds.")


if __name__ == "__main__":
    main()

--------------------------------------------------------------------------------
/smart_tree/model/render.py:
--------------------------------------------------------------------------------
import numpy as np

import wandb


def render_cloud(
    renderer,
    labelled_cloud,
    camera_position=None,
    camera_up=None,
):
    """Capture three images of ``labelled_cloud`` from the same camera.

    Args:
        renderer: Object exposing ``capture(geometries, camera_position, camera_up)``.
        labelled_cloud: Cloud providing ``to_o3d_cld``, ``to_o3d_seg_cld`` and
            ``to_o3d_medial_vectors``.
        camera_position: Camera position; defaults to ``[1, 0, 0]``.
        camera_up: Camera up vector; defaults to ``[0, 1, 0]``.

    Returns:
        ``[cloud image, segmentation image, medial-vector image]`` as numpy arrays.
    """
    # Fresh lists per call: list defaults in the signature would be shared
    # between calls (and by anything the renderer does to them).
    if camera_position is None:
        camera_position = [1, 0, 0]
    if camera_up is None:
        camera_up = [0, 1, 0]

    segmented_img = renderer.capture(
        [labelled_cloud.to_o3d_seg_cld()],
        camera_position,
        camera_up,
    )

    cld_img = renderer.capture(
        [labelled_cloud.to_o3d_cld()],
        camera_position,
        camera_up,
    )

    projected_img = renderer.capture(
        [labelled_cloud.to_o3d_medial_vectors()],
        camera_position,
        camera_up,
    )

    return [
        np.asarray(cld_img),
        np.asarray(segmented_img),
        np.asarray(projected_img),
    ]


def log_images(wandb_run, name, images, epoch):
    """Log ``images`` to ``wandb_run`` under key ``name`` at step ``epoch``.

    The step matches Tracker.log, which also uses the epoch as the wandb step.
    """
    wandb_run.log({name: [wandb.Image(img) for img in images]}, step=epoch)

--------------------------------------------------------------------------------
/smart_tree/scripts/bench_dataloader.py:
--------------------------------------------------------------------------------
import logging
import time

import hydra
import torch
from hydra.utils import instantiate
from omegaconf import DictConfig
from tqdm import tqdm

from smart_tree.model.helper import get_batch


# NOTE(review): conf/ in this repository has no vine-dataset.yaml; supply it
# (e.g. via --config-dir) or point config_name at an existing config.
@hydra.main(
    version_base=None,
    config_path="../conf",
    config_name="vine-dataset",
)
def main(cfg: DictConfig):
    """Benchmark the training loader: print seconds per full CPU pass, forever."""
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)
    log = logging.getLogger(__name__)

    cfg = cfg.training
    torch.multiprocessing.set_start_method("spawn")

    train_loader = instantiate(cfg.train_data_loader)
    log.info(f"Train Dataset Size: {len(train_loader.dataset)}")

    while True:  # runs until interrupted
        start = time.perf_counter()

        batches = get_batch(train_loader, device="cpu")
        for _ in tqdm(batches):
            pass

        print(time.perf_counter() - start)


if __name__ == "__main__":
    main()

--------------------------------------------------------------------------------
/smart_tree/o3d_abstractions/visualizer.py:
--------------------------------------------------------------------------------
from dataclasses import dataclass
from typing import List, Sequence, Union

import open3d as o3d


@dataclass
class ViewerItem:
    """A named geometry for ``o3d_viewer`` plus its initial visibility."""

    name: str
    geometry: o3d.geometry.Geometry
    is_visible: bool = True


def o3d_viewer(
    items: Union[Sequence[ViewerItem], List[o3d.geometry.Geometry]], line_width=1
):
    """Open an Open3D viewer showing ``items``.

    Bare geometries are wrapped in a ViewerItem named after their index.
    LineSets get an unlit line material of width ``line_width``; all other
    geometries get the default lit material.
    """
    mat = o3d.visualization.rendering.MaterialRecord()
    mat.shader = "defaultLit"

    line_mat = o3d.visualization.rendering.MaterialRecord()
    line_mat.shader = "unlitLine"
    line_mat.line_width = line_width

    # Wrap element by element so empty input and lists mixing ViewerItems
    # with bare geometries both work.
    items = [
        item if isinstance(item, ViewerItem) else ViewerItem(f"{i}", item)
        for i, item in enumerate(items)
    ]

    def material(item):
        return line_mat if isinstance(item.geometry, o3d.geometry.LineSet) else mat

    # Built by hand: dataclasses.asdict would deep-copy every geometry.
    geometries = [
        dict(
            name=item.name,
            geometry=item.geometry,
            is_visible=item.is_visible,
            material=material(item),
        )
        for item in items
    ]

    o3d.visualization.draw(geometries, line_width=line_width)

--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
[build-system]
requires = ["setuptools", "setuptools-scm"]
build-backend = "setuptools.build_meta"

[project]
name = "smart-tree"
version = "1.0.0"
authors = [
    {name = "Harry Dobbs", email = "harrydobbs87@gmail.com"},
]
description = "Neural Network Point Cloud Tree Skeletonization"
readme = "README.md"
requires-python = ">=3.10"
license = {text = "MIT"}
dependencies = [
    'numpy',
    'open3d',
    'hydra-core>=1.2.0',
    #'click',
    #'oauthlib',
    # NOTE(review): this wheel targets CUDA 11.7, while create-env.sh installs CUDA 11.8 -- confirm.
    'spconv-cu117',
    'wandb',
    'cmapy',
    #'plyfile',
    #'torch',
    'py_structs',
    'torchtyping',
    'beartype',
    'typeguard==2.11.1'
]

[tool.setuptools.packages]
find = {}  # Scan the project directory with the default parameters

[project.scripts]
run-smart-tree = "smart_tree.cli:main"
train-smart-tree = "smart_tree.model.train:main"
view-npz = "smart_tree.scripts.view_npz:main"
# NOTE(review): there is no smart_tree/scripts/view_pcd.py in this repository,
# so this entry point fails with ModuleNotFoundError when invoked.
view-pcd = "smart_tree.scripts.view_pcd:main"

#[project.scripts]

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023 UC Vision

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/smart_tree/conf/pipeline.yaml:
--------------------------------------------------------------------------------
pipeline:
  _target_: smart_tree.pipeline.Pipeline

  preprocessing:
    _target_: smart_tree.dataset.augmentations.AugmentationPipeline
    augmentations:
      - _target_: smart_tree.dataset.augmentations.CentreCloud

  model_inference:
    _target_: smart_tree.model.model_inference.ModelInference
    model_path: smart_tree/model/weights/noble-elevator-58_model.pt
    weights_path: smart_tree/model/weights/noble-elevator-58_model_weights.pt
    voxel_size: 0.01
    block_size: 4
    buffer_size: 0.4
    num_workers : 8
    batch_size : 4

  skeletonizer:
    _target_: smart_tree.skeleton.skeletonize.Skeletonizer
    K: 16
    min_connection_length: 0.02
    minimum_graph_vertices: 32

  view_model_output : False
  view_skeletons : True
  # NOTE(review): "/" is the filesystem root; set a real directory before enabling save_outputs.
  save_path: /
  save_outputs : False

  branch_classes: [0]
  cmap:
    - [0.450, 0.325, 0.164] # Trunk
    - [0.541, 0.670, 0.164] # Foliage

  repair_skeletons : True
  smooth_skeletons : True
  smooth_kernel_size: 11 # Needs to be odd
  prune_skeletons : True
  min_skeleton_radius : 0.01
  min_skeleton_length : 0.02

--------------------------------------------------------------------------------
/smart_tree/model/tracker.py:
--------------------------------------------------------------------------------
import numpy as np

import wandb


class Tracker:
    """Collect per-batch loss values during an epoch and log their means to wandb."""

    def __init__(self):
        self.running_epoch_radius_loss = []
        self.running_epoch_direction_loss = []
        self.running_epoch_class_loss = []

    def update(self, loss_dict: dict):
        """Record one batch; needs one-element tensors at "radius", "direction", "class_l"."""
        self.running_epoch_radius_loss.append(loss_dict["radius"].item())
        self.running_epoch_direction_loss.append(loss_dict["direction"].item())
        self.running_epoch_class_loss.append(loss_dict["class_l"].item())

    @property
    def radius_loss(self):
        """Mean radius loss so far (NaN, with a numpy warning, before any update)."""
        return np.mean(self.running_epoch_radius_loss)

    @property
    def direction_loss(self):
        """Mean direction loss so far."""
        return np.mean(self.running_epoch_direction_loss)

    @property
    def class_loss(self):
        """Mean classification loss so far."""
        return np.mean(self.running_epoch_class_loss)

    @property
    def total_loss(self):
        """Sum of the three mean losses."""
        return self.radius_loss + self.direction_loss + self.class_loss

    def log(self, name, epoch):
        """Log total and per-head mean losses under ``"{name} ..."`` keys at step ``epoch``."""
        wandb.log(
            {
                f"{name} Total Loss": self.total_loss,
                f"{name} Radius Loss": self.radius_loss,
                f"{name} Direction Loss": self.direction_loss,
                f"{name} Class Loss": self.class_loss,
            },
            step=epoch,
        )

--------------------------------------------------------------------------------
/smart_tree/skeleton/connection.py:
--------------------------------------------------------------------------------
import sys

from smart_tree.data_types.tree import (DisjointTreeSkeleton, TreeSkeleton,
                                        connect_skeletons)
from smart_tree.o3d_abstractions.visualizer import o3d_viewer

if __name__ == "__main__":
    # Load a pickled DisjointTreeSkeleton, join its two longest skeletons
    # (needs at least two) and show the result.
    disjoint_skeleton = DisjointTreeSkeleton.from_pickle(
        "/local/smart-tree/test_data/disjoint_skeleton.pkl"
    )

    # disjoint_skeleton.prune(min_radius=0.01, min_length=0.08).smooth(kernel_size=11)
    # Sort skeletons by total length
    skeletons_sorted = sorted(
        disjoint_skeleton.skeletons,
        key=lambda x: x.length,
        reverse=True,
    )

    # skeletons_sorted[0].view()

    skel = connect_skeletons(skeletons_sorted[0], 0, 0, skeletons_sorted[1], 0, 0)

    skel.view()

    # ``quit`` only exists when the ``site`` module injects it; sys.exit is
    # always available. Everything below is an unfinished experiment that
    # does not run.
    sys.exit()

    final_skeleton = TreeSkeleton(0, skeletons_sorted[0].branches)

    for skeleton in skeletons_sorted[1:]:
        branch = skeleton.branches[skeleton.key_branch_with_biggest_radius]

        # get the point that has the biggest radius ....
        # get the closest point on the skeleton to that point ...
        # connect the two points ...

        print(skeleton.length)

    o3d_viewer(
        [final_skeleton.to_o3d_tube(), branch.to_o3d_tube(), skeleton.to_o3d_tube()]
    )

--------------------------------------------------------------------------------
/smart_tree/scripts/vis_dataloader.py:
--------------------------------------------------------------------------------
import logging

import hydra
import numpy as np
import open3d as o3d
import torch
from hydra.utils import instantiate
from omegaconf import DictConfig
from open3d_vis import render
from tqdm import tqdm

from smart_tree.model.helper import get_batch


# NOTE(review): conf/ in this repository has no vine-dataset.yaml; supply it
# (e.g. via --config-dir) or point config_name at an existing config.
@hydra.main(
    version_base=None,
    config_path="../conf",
    config_name="vine-dataset",
)
def main(cfg: DictConfig):
    """Show each training sample as voxel boxes coloured by its target class."""
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)
    log = logging.getLogger(__name__)

    cfg = cfg.training
    torch.multiprocessing.set_start_method("spawn")

    train_loader = instantiate(cfg.train_data_loader)
    log.info(f"Train Dataset Size: {len(train_loader.dataset)}")

    batches = get_batch(train_loader, device="cpu")
    cmap = torch.from_numpy(np.array(cfg.cmap))

    for sparse_input, targets, _, filenames in tqdm(batches):
        cloud_ids = sparse_input.indices[:, 0]  # which cloud in the batch each voxel belongs to
        coords = sparse_input.indices[:, 1:4]  # presumably integer voxel coordinates

        # The class label is read from the last target column.
        class_l = targets[:, -1].to(dtype=torch.long)

        for i, filename in enumerate(filenames):
            print("Filename:", filename)
            # Separate name so the batch-level loop variables are not shadowed.
            cloud_mask = cloud_ids == i
            labels = class_l[cloud_mask]

            xyz = coords[cloud_mask]
            class_colors = cmap[labels]

            # One unit box per voxel, spanning coord to coord + 1.
            boxes = render.boxes(xyz, xyz + 1, class_colors)
            o3d.visualization.draw([boxes])


if __name__ == "__main__":
    main()

--------------------------------------------------------------------------------
/smart_tree/data_types/graph.py:
--------------------------------------------------------------------------------
from dataclasses import dataclass
from typing import List

import cugraph
import cupy
import open3d as o3d
import torch
from cudf import DataFrame
from torchtyping import TensorType
from tqdm import tqdm

from ..o3d_abstractions.geometries import o3d_line_set


@dataclass
class Graph:
    """Point graph: vertex positions, an edge list and one weight per edge."""

    vertices: TensorType["N", 3]
    edges: TensorType["N", 2]
    edge_weights: TensorType["N", 1]

    def to_o3d_lineset(self, colour=(1, 0, 0)) -> o3d.geometry.LineSet:
        """Draw the edges as an Open3D LineSet (tensors are copied to the CPU first)."""
        graph_cpu = self.to_device(torch.device("cpu"))
        return o3d_line_set(graph_cpu.vertices, graph_cpu.edges, colour=colour)

    def to_device(self, device: torch.device):
        """Return a new Graph with every tensor moved to ``device``."""
        return Graph(
            self.vertices.to(device),
            self.edges.to(device),
            self.edge_weights.to(device),
        )

    def connected_cugraph_components(self, minimum_vertices=10) -> List[cugraph.Graph]:
        """Connected components with at least ``minimum_vertices`` vertices, largest first."""
        g = cuda_graph(self.edges, self.edge_weights)

        df = cugraph.connected_components(g)

        # Count component sizes once up front instead of filtering the whole
        # frame for every label, which cost O(V) per (usually tiny) component.
        sizes = df["labels"].value_counts()
        large_labels = sizes[sizes >= minimum_vertices].index.to_pandas()

        components = []
        for label in tqdm(
            large_labels,
            desc="Finding Connected Components",
            leave=False,
        ):
            subgraph_vertices = df[df["labels"] == label]["vertex"]
            components.append(cugraph.subgraph(g, subgraph_vertices))

        return sorted(components, key=lambda graph: len(graph.nodes()), reverse=True)


def cuda_graph(edges, edge_weights, renumber=False):
    """Build an undirected cuGraph graph from an (E, 2) edge array and per-edge weights."""
    edges = cupy.asarray(edges)
    edge_weights = cupy.asarray(edge_weights)

    d = DataFrame()
    d["source"] = edges[:, 0]
    d["destination"] = edges[:, 1]
    d["weights"] = edge_weights
    g = cugraph.Graph(directed=False)
    g.from_cudf_edgelist(d, edge_attr="weights", renumber=renumber)

    return g
--------------------------------------------------------------------------------
/smart_tree/data_types/branch.py:
--------------------------------------------------------------------------------
from __future__ import annotations

from dataclasses import dataclass
from typing import List, Optional

import numpy as np
import open3d as o3d
import torch
from torchtyping import TensorType
from typeguard import typechecked

from smart_tree.o3d_abstractions.geometries import o3d_path, o3d_tube_mesh

from .tube import Tube


@typechecked
@dataclass
class BranchSkeleton:
    """One branch: an ordered polyline of points with a radius per point.

    Arguments are checked at runtime by typeguard (tensor shapes only when
    torchtyping's ``patch_typeguard()`` is active).
    """

    _id: int
    parent_id: int
    xyz: TensorType["N", 3]
    radii: TensorType["N", 1]
    child_id: Optional[int] = None

    def __post_init__(self) -> None:
        # Random display colour, used by to_o3d_tube.
        self.colour = np.random.rand(3)

    def __len__(self) -> int:
        """Number of points on the branch."""
        return self.xyz.shape[0]

    def __str__(self):
        # Implicit concatenation: backslash continuations inside the literal
        # embedded the next lines' indentation in the output.
        return (
            f" ID: {self._id} "
            f"Points: {self.xyz} "
            f"Radii {self.radii}"
        )

    def to_o3d_lineset(self, colour=(0, 0, 0)) -> o3d.geometry.LineSet:
        """Polyline through ``xyz`` as an Open3D LineSet."""
        return o3d_path(self.xyz, colour)

    def to_o3d_tube(self) -> o3d.geometry.TriangleMesh:
        """Tube mesh along the branch using the per-point radii and ``self.colour``."""
        # .numpy() only accepts CPU tensors without grad, so detach/move first.
        return o3d_tube_mesh(
            self.xyz.detach().cpu().numpy(),
            self.radii.detach().cpu().numpy(),
            self.colour,
        )

    def to_tubes(self, colour=(1, 0, 0)) -> List[Tube]:
        """Split the branch into one Tube per pair of consecutive points.

        ``colour`` is kept for interface compatibility but is not used.
        """
        segments = zip(self.xyz[:-1], self.xyz[1:], self.radii[:-1], self.radii[1:])
        return [Tube(a, b, r1, r2) for a, b, r1, r2 in segments]

    def filter(self, mask) -> BranchSkeleton:
        """New branch keeping only the points selected by ``mask`` (ids unchanged)."""
        return BranchSkeleton(
            self._id,
            self.parent_id,
            self.xyz[mask],
            self.radii[mask],
            self.child_id,
        )

    @property
    def length(self) -> TensorType[1]:
        """Sum of the Euclidean lengths of the polyline segments."""
        return (self.xyz[1:] - self.xyz[:-1]).norm(dim=1).sum()

    @property
    def initial_radius(self) -> TensorType[1]:
        """The larger of the first and last point radii."""
        return torch.max(self.radii[0], self.radii[-1])

    @property
    def biggest_radius_idx(self) -> TensorType[1]:
        """Index of the largest radius (argmax over the flattened ``radii``)."""
        return torch.argmax(self.radii)

    @property
    def biggest_radius(self) -> TensorType[1]:
        """Largest radius on the branch."""
        return torch.max(self.radii)

--------------------------------------------------------------------------------
/smart_tree/skeleton/shortest_path.py:
--------------------------------------------------------------------------------
import cugraph
import cupy
import numpy as np
import torch
from cudf import DataFrame


def cudf_edgelist_to_numpy(edge_list):
    """Convert a cuDF edge list with "src"/"dst" columns into an (E, 2) numpy array."""
    return np.vstack((edge_list["src"].to_numpy(), edge_list["dst"].to_numpy())).T


def shortest_paths(root, edges, edge_weights, renumber=True):
    """Single-source shortest paths from ``root`` over an undirected weighted edge list.

    Returns:
        (vertex ids, predecessor ids, distances) as tensors on ``edges.device``.
    """
    device = edges.device
    g = edge_graph(edges, edge_weights, renumber=renumber)
    r = cugraph.sssp(g, source=root)

    return (
        torch.as_tensor(r["vertex"], device=device).long(),
        torch.as_tensor(r["predecessor"], device=device).long(),
        torch.as_tensor(r["distance"], device=device),
    )


def graph_shortest_paths(root, graph, device):
    """SSSP from ``root`` on an existing cuGraph graph; returns (predecessors, distances)."""
    r = cugraph.sssp(graph, source=root)
    return (
        torch.as_tensor(r["predecessor"], device=device).long(),
        torch.as_tensor(r["distance"], device=device),
    )


def _pred_graph_from_preds(preds, points):
    """Graph with one edge i -> preds[i] per vertex, weighted by Euclidean length.

    Vertices with ``preds[i] < 0`` (no predecessor) get a zero-weight self-loop.
    """
    n = preds.shape[0]
    valid = preds >= 0

    dists = torch.norm(points - points[torch.clamp(preds, 0)], dim=1)
    dists[~valid] = 0.0

    edges = torch.stack([torch.arange(0, n, device=preds.device), preds], dim=1)
    edges[~valid, 1] = edges[~valid, 0]
    return edge_graph(edges, dists)


def pred_graph(verts, preds, points):
    """Predecessor graph for SSSP output given as parallel ``verts``/``preds`` arrays.

    Edge weights are ``|points[verts[i]] - points[preds[i]]|``; rows with
    ``preds[i] < 0`` become zero-weight self-loops.
    """
    n = preds.shape[0]
    valid = preds >= 0

    dists = torch.norm(points[verts] - points[torch.clamp(preds, 0)], dim=1)
    dists[~valid] = 0.0

    # NOTE(review): the edge source is the row index i while the weight is
    # measured from points[verts[i]]; the two agree only when verts == arange(n).
    edges = torch.stack([torch.arange(0, n, device=preds.device), preds], dim=1)
    edges[~valid, 1] = edges[~valid, 0]
    return edge_graph(edges, dists)


def euclidean_distances(root, points, preds):
    """Path length from ``root`` along the predecessor tree ``preds``, via cuGraph SSSP."""
    # This 2-argument builder used to share the name ``pred_graph`` with the
    # 3-argument one, which shadowed it, so this call always raised TypeError.
    g = _pred_graph_from_preds(preds, points)
    r = cugraph.sssp(g, source=root)
    return torch.as_tensor(r["distance"], device=points.device)


def edge_graph(edges, edge_weights, renumber=False):
    """Undirected cuGraph graph from an (E, 2) edge array/tensor and per-edge weights."""
    d = DataFrame()
    edges = cupy.asarray(edges)

    d["source"] = edges[:, 0]
    d["destination"] = edges[:, 1]
    d["weights"] = cupy.asarray(edge_weights)

    g = cugraph.Graph(directed=False)
    g.from_cudf_edgelist(d, edge_attr="weights", renumber=renumber)
    return g

--------------------------------------------------------------------------------
/smart_tree/model/helper.py:
--------------------------------------------------------------------------------
from functools import partial
from pathlib import Path
from typing import List

import torch
from py_structs.torch import map_tensors

from smart_tree.data_types.cloud import Cloud

from .sparse import sparse_from_batch


def get_batch(dataloader, device, fp_16=False):
    """Yield ``(sparse_input, targets, mask, filenames)`` for each batch of ``dataloader``.

    Features and coords are packed with ``sparse_from_batch`` on ``device``;
    the tensors in ``target_feats`` are moved to ``device`` via ``map_tensors``.
    With ``fp_16`` the features, targets and coords are cast to half first.
    """
    for (feats, target_feats), coords, mask, filenames in dataloader:
        if fp_16:
            feats = feats.half()
            target_feats = target_feats.half()
            # NOTE(review): float16 holds integers exactly only up to 2048; if
            # coords are voxel indices, larger extents get rounded -- confirm.
            coords = coords.half()

        sparse_input = sparse_from_batch(
            feats,
            coords,
            device=device,
        )
        targets = map_tensors(
            target_feats,
            partial(
                torch.Tensor.to,
                device=device,
            ),
        )

        yield sparse_input, targets, mask, filenames


def model_output_to_labelled_clds(
    sparse_input,
    model_output,
    cmap,
    filenames,
) -> List[Cloud]:
    """Split a batched model output into one labelled Cloud per input cloud.

    Batch ids come from ``sparse_input.indices[:, 0]``; feature columns 0-2 are
    used as coordinates and 3-5 as colour.
    """
    return to_labelled_clds(
        sparse_input.indices[:, 0],
        sparse_input.features[:, :3],
        sparse_input.features[:, 3:6],
        model_output,
        cmap,
        filenames,
    )


def split_outputs(features, mask):
    """Masked rows of each head: ``(exp(radius), direction, argmax class)``."""
    radii = torch.exp(features["radius"][mask])
    direction = features["direction"][mask]
    class_l = torch.argmax(features["class_l"], dim=1)[mask]

    return radii, direction, class_l


def to_labelled_clds(
    cloud_ids,
    coords,
    rgb,
    model_output,
    cmap,
    filenames,
) -> List[Cloud]:
    """Build one CPU ``Cloud`` per batch id from coordinates and model outputs.

    ``cmap`` is accepted for interface compatibility but is not used here.
    """
    num_clouds = int(cloud_ids.max()) + 1
    clouds = []

    # assert rgb.shape[1] > 0

    for i in range(num_clouds):
        mask = cloud_ids == i
        xyz = coords[mask]
        # The ``rgb`` argument is currently not used (random colours instead);
        # a separate name keeps the argument itself from being overwritten.
        cloud_rgb = torch.rand(xyz.shape)  # rgb[mask]

        radii, direction, class_l = split_outputs(model_output, mask)

        labelled_cloud = Cloud(
            xyz=xyz,
            rgb=cloud_rgb,
            medial_vector=radii * direction,
            class_l=class_l,
            filename=Path(filenames[i]),
        )

        clouds.append(labelled_cloud.to_device(torch.device("cpu")))

    return clouds

--------------------------------------------------------------------------------
/smart_tree/model/model.py:
--------------------------------------------------------------------------------
import spconv.pytorch as spconv
import torch.nn as nn
import torch.nn.functional as F

from smart_tree.model.model_blocks import MLP, SubMConvBlock, UBlock


class Smart_Tree(nn.Module):
    """Sparse-conv U-Net backbone with three heads: radius, direction and class.

    Args:
        input_channels: Number of input feature channels.
        unet_planes: Channel widths of the U-Net levels (``unet_planes[0]`` is
            also the width of the input convolution).
        radius_fc_planes: Layer widths of the radius MLP head.
        direction_fc_planes: Layer widths of the direction MLP head.
        class_fc_planes: Layer widths of the class MLP head.
        bias: Currently unused (every head is built with ``bias=True``).
        algo: spconv convolution algorithm handed to the U-Net.
    """

    def __init__(
        self,
        input_channels,
        unet_planes,
        radius_fc_planes,
        direction_fc_planes,
        class_fc_planes,
        bias=False,
        algo=spconv.ConvAlgo.Native,
    ):
        super().__init__()

        norm_fn = nn.BatchNorm1d
        # functools.partial(
        #     nn.BatchNorm1d,
        #     eps=1e-4,
        #     momentum=0.1,
        # )
        activation_fn = nn.ReLU

        self.input_conv = SubMConvBlock(
            input_channels=input_channels,
            output_channels=unet_planes[0],
            kernel_size=1,
            padding=1,
            norm_fn=norm_fn,
            activation_fn=activation_fn,
        )

        self.UNet = UBlock(
            unet_planes,
            norm_fn,
            activation_fn,
            key_id=1,
            algo=algo,
        )

        # Three Heads...
        self.radius_head = MLP(
            radius_fc_planes,
            norm_fn,
            activation_fn,
            bias=True,
        )
        self.direction_head = MLP(
            direction_fc_planes,
            norm_fn,
            activation_fn,
            bias=True,
        )
        self.class_head = MLP(
            class_fc_planes,
            norm_fn,
            activation_fn,
            bias=True,
        )

        self.apply(self.set_bn_init)

    @staticmethod
    def set_bn_init(m):
        """Set BatchNorm affine parameters to weight=1, bias=0."""
        # Non-affine BatchNorm layers have weight=None; skip them instead of crashing.
        if "BatchNorm" in m.__class__.__name__ and getattr(m, "weight", None) is not None:
            m.weight.data.fill_(1.0)
            m.bias.data.fill_(0.0)

    def forward(self, input):
        """Return per-voxel head outputs keyed "radius", "direction" (unit-normalised), "class_l"."""
        x = self.input_conv(input)
        unet_out = self.UNet(x)

        return {
            "radius": self.radius_head(unet_out).features,
            "direction": F.normalize(self.direction_head(unet_out).features),
            "class_l": self.class_head(unet_out).features,
        }

--------------------------------------------------------------------------------
/smart_tree/data_types/tube.py:
--------------------------------------------------------------------------------
from dataclasses import dataclass
from typing import List

import numpy as np
import torch
from torchtyping import TensorType


def _tensor_to_numpy(value):
    """numpy array for a tensor (detached, on CPU); any other value is returned unchanged."""
    if isinstance(value, torch.Tensor):
        return value.cpu().detach().numpy()
    return value


@dataclass
class Tube:
    """A tapered segment from ``a`` to ``b`` with radius ``r1`` at ``a`` and ``r2`` at ``b``."""

    a: TensorType["N", 3]  # Start Point 3
    b: TensorType["N", 3]  # End Point 3
    r1: float  # Start Radius
    r2: float  # End Radius

    def to_torch(self, device=torch.device("cuda:0")):
        """Convert every field in place to a float32 tensor on ``device``.

        ``torch.as_tensor`` accepts numpy arrays (as ``torch.from_numpy`` did)
        and also the plain floats that r1/r2 are annotated as.
        """
        self.a = torch.as_tensor(self.a, dtype=torch.float32, device=device)
        self.b = torch.as_tensor(self.b, dtype=torch.float32, device=device)
        self.r1 = torch.as_tensor(self.r1, dtype=torch.float32, device=device)
        self.r2 = torch.as_tensor(self.r2, dtype=torch.float32, device=device)

    def to_numpy(self):
        """Convert tensor fields in place to numpy arrays; plain floats are left as-is."""
        self.a = _tensor_to_numpy(self.a)
        self.b = _tensor_to_numpy(self.b)
        self.r1 = _tensor_to_numpy(self.r1)
        self.r2 = _tensor_to_numpy(self.r2)


@dataclass
class CollatedTube:
    """Many tubes stacked into batched tensors."""

    a: TensorType["N", 3]  # Nx3
    b: TensorType["N", 3]  # Nx3
    r1: TensorType["N", 1]  # N
    r2: TensorType["N", 1]  # N

    def to_gpu(self, device=torch.device("cuda")):
        """Move every field in place to ``device``."""
        self.a = self.a.to(device)
        self.b = self.b.to(device)
        self.r1 = self.r1.to(device)
        self.r2 = self.r2.to(device)


def collate_tubes(tubes: List[Tube]) -> CollatedTube:
    """Concatenate tensor-valued tubes into one CollatedTube (``a``/``b`` as (N, 3))."""
    a = torch.cat([tube.a for tube in tubes]).reshape(-1, 3)
    b = torch.cat([tube.b for tube in tubes]).reshape(-1, 3)

    # NOTE(review): radii are reshaped to (1, N) although CollatedTube annotates
    # them as (N, 1); left as-is in case consumers rely on this layout.
    r1 = torch.cat([tube.r1 for tube in tubes]).reshape(1, -1)
    r2 = torch.cat([tube.r2 for tube in tubes]).reshape(1, -1)

    return CollatedTube(a, b, r1, r2)


def sample_tubes(tubes: List[Tube], spacing):
    """Sample points along numpy-valued tubes roughly every ``spacing`` units.

    A tube of length L contributes ``ceil(L / spacing)`` evenly spaced points
    starting at ``a`` (``b`` itself is not sampled), with radii linearly
    interpolated from ``r1`` to ``r2``. Zero-length tubes contribute nothing.

    Returns:
        (points, radii): all samples concatenated; both arrays are empty when
        no tube yields a sample.
    """
    pts, radius = [], []

    for tube in tubes:
        v = tube.b - tube.a
        length = np.linalg.norm(v)
        num_points = int(np.ceil(length / spacing))

        if num_points == 0:
            continue  # degenerate tube; also avoids dividing by a zero length

        direction = v / length
        # An integer range scaled by the step always gives exactly num_points
        # samples; a float-step np.arange can round to num_points + 1.
        spaced_points = (np.arange(num_points) * (length / num_points)).reshape(-1, 1)
        lin_radius = np.linspace(tube.r1, tube.r2, num_points, dtype=float)

        pts.append(tube.a + direction * spaced_points)
        radius.append(lin_radius)

    if not pts:
        return np.empty((0, 3)), np.empty((0,))

    return np.concatenate(pts, axis=0), np.concatenate(radius, axis=0)

-------------------------------------------------------------------------------- /smart_tree/scripts/split-data.py: -------------------------------------------------------------------------------- 1 | """ Script to split dataset into train and test sets 2 | 3 | usage: 4 | 5 | python smart_tree/scripts/split-data.py --read_directory=/speed-tree/speed-tree-outputs/processed_vines/ --json_save_path=/smart-tree/smart_tree/conf/vine-split.json --sample_type=random 6 | 7 | 8 | """ 9 | 10 | import json 11 | import os 12 | import random 13 | from pathlib
import Path 14 | 15 | import click 16 | 17 | 18 | def flatten_list(l): 19 | return [item for sublist in l for item in sublist] 20 | 21 | 22 | def random_sample(read_dir, json_save_path): 23 | items = [str(path.name) for path in Path(read_dir).rglob("*.npz")] 24 | random.shuffle(items) 25 | 26 | data = {} 27 | 28 | data["train"] = sorted(items[: int(0.8 * len(items))]) 29 | data["test"] = sorted(items[int(0.8 * len(items)) : int(0.9 * len(items))]) 30 | data["validation"] = sorted(items[int(0.9 * len(items)) :]) 31 | 32 | with open(json_save_path, "w") as outfile: 33 | json.dump(data, outfile, indent=4, sort_keys=False) 34 | 35 | 36 | def strattified_sample(read_dir, json_save_path): 37 | dirs = os.listdir(read_dir) 38 | 39 | train_paths = [] 40 | test_paths = [] 41 | val_paths = [] 42 | 43 | for directory in dirs: 44 | items = [ 45 | str(path.resolve()) 46 | for path in Path(f"{read_dir}/{directory}").rglob("*.npz") 47 | ] 48 | random.shuffle(items) 49 | 50 | train_paths.append(items[: int(0.8 * len(items))]) 51 | test_paths.append( 52 | items[int(0.8 * len(items)) : int(0.8 * len(items) + int(0.1 * len(items)))] 53 | ) 54 | val_paths.append(items[int(0.8 * len(items)) + int(0.1 * len(items)) :]) 55 | 56 | data = {} 57 | 58 | data["train"] = sorted(flatten_list(train_paths)) 59 | data["test"] = sorted(flatten_list(test_paths)) 60 | data["validation"] = sorted(flatten_list(val_paths)) 61 | 62 | with open(json_save_path, "w") as outfile: 63 | json.dump(data, outfile, indent=4, sort_keys=False) 64 | 65 | 66 | @click.command() 67 | @click.option( 68 | "--read_directory", type=click.Path(exists=True), prompt="read directory?" 
69 | ) 70 | @click.option("--json_save_path", prompt="json path?") 71 | @click.option("--sample_type", type=str, default=False, required=False) 72 | def main(read_directory, json_save_path, sample_type): 73 | if sample_type == "random": 74 | random_sample(read_directory, json_save_path) 75 | 76 | if sample_type == "strattified": 77 | strattified_sample(read_directory, json_save_path) 78 | 79 | 80 | if __name__ == "__main__": 81 | main() 82 | -------------------------------------------------------------------------------- /smart_tree/util/misc.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import List, Union 3 | 4 | import cmapy 5 | import numpy as np 6 | import torch 7 | 8 | 9 | def flatten_list(l): 10 | return [item for sublist in l for item in sublist] 11 | 12 | 13 | def at_least_2d(tensors: Union[List[torch.tensor], torch.tensor], expand_dim=1): 14 | if type(tensors) is list: 15 | return [at_least_2d(tensor) for tensor in tensors] 16 | else: 17 | if len(tensors.shape) == 1: 18 | return tensors.unsqueeze(expand_dim) 19 | else: 20 | return tensors 21 | 22 | 23 | def to_torch(numpy_arrays: List[np.array], device=torch.device("cpu")): 24 | return [torch.from_numpy(np_arr).float().to(device) for np_arr in numpy_arrays] 25 | 26 | 27 | def to_numpy(_torch: Union[List[torch.tensor], torch.tensor]): 28 | if type(_torch) is list: 29 | return [torch_arr.cpu().detach().numpy() for torch_arr in _torch] 30 | else: 31 | return _torch.cpu().detach().numpy() 32 | 33 | 34 | def concate_dict_of_tensors(tensors: dict, device=torch.device("cpu")): 35 | for key, values in tensors.items(): 36 | tensors[key] = torch.concatenate(values).to(device) 37 | return tensors 38 | 39 | 40 | def unique_n_colours(num_colours, cmap="hsv"): 41 | return ( 42 | np.asarray( 43 | [cmapy.color(cmap, i) for i in range(0, 255, math.ceil(255 / num_colours))] 44 | ).reshape(-1, 3) 45 | / 255 46 | ) 47 | 48 | 49 | def 
unique_n_random_colours(num_colours): 50 | return np.asarray([np.random.rand(3) for i in range(num_colours)]).reshape(-1, 3) 51 | 52 | 53 | def points_to_edges(points): 54 | points = points.reshape(-1, 3) 55 | parents = torch.arange(points.shape[0] - 1) 56 | children = torch.arange(1, points.shape[0]) 57 | 58 | return torch.column_stack((parents, children)) 59 | 60 | 61 | def voxel_downsample(xyz, voxel_size): 62 | xyz_quantized = ( 63 | xyz // voxel_size 64 | ) # torch.div(xyz + (voxel_size / 2), voxel_size, rounding_mode="floor") 65 | 66 | unique, idx, counts = torch.unique( 67 | xyz_quantized, 68 | dim=0, 69 | sorted=True, 70 | return_counts=True, 71 | return_inverse=True, 72 | ) 73 | 74 | _, ind_sorted = torch.sort(idx, stable=True) 75 | cum_sum = counts.cumsum(0) 76 | cum_sum = torch.cat((torch.tensor([0], device=cum_sum.device), cum_sum[:-1])) 77 | first_indicies = ind_sorted[cum_sum[1:]] 78 | 79 | return first_indicies 80 | 81 | 82 | def merge_dictionaries(dict1, dict2): 83 | merged_dict = {} 84 | 85 | for key, value in dict1.items(): 86 | merged_dict[key] = value 87 | 88 | i = 1 89 | for key, value in dict2.items(): 90 | new_key = key 91 | while new_key in merged_dict: 92 | new_key = i 93 | i += 1 94 | merged_dict[new_key] = value 95 | 96 | return merged_dict 97 | -------------------------------------------------------------------------------- /smart_tree/scripts/view_npz.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | from typing import List, Tuple 4 | 5 | 6 | from smart_tree.data_types.cloud import Cloud 7 | from smart_tree.data_types.tree import TreeSkeleton 8 | from smart_tree.o3d_abstractions.visualizer import ViewerItem, o3d_viewer 9 | from smart_tree.util.file import load_data_npz 10 | 11 | 12 | def view_synthetic_data(data: List[Tuple[Cloud, TreeSkeleton]], line_width=1): 13 | geometries = [] 14 | for i, item in enumerate(data): 15 | (cloud, skeleton), path = 
item 16 | 17 | tree_name = path.stem 18 | visible = i == 0 19 | 20 | geometries = [ 21 | ViewerItem( 22 | f"{tree_name}_cloud", 23 | cloud.to_o3d_cld(), 24 | is_visible=visible, 25 | ), 26 | ViewerItem( 27 | f"{tree_name}_labelled_cloud", 28 | cloud.to_o3d_seg_cld(), 29 | is_visible=visible, 30 | ), 31 | ViewerItem( 32 | f"{tree_name}_medial_vectors", 33 | cloud.to_o3d_medial_vectors(), 34 | is_visible=visible, 35 | ), 36 | # ViewerItem( 37 | # f"{tree_name}_skeleton", 38 | # skeleton.to_o3d_lineset(), 39 | # is_visible=visible, 40 | # ), 41 | # ViewerItem( 42 | # f"{tree_name}_skeleton_mesh", 43 | # skeleton.to_o3d_tubes(), 44 | # is_visible=visible, 45 | # ), 46 | ] 47 | 48 | o3d_viewer(geometries, line_width=line_width) 49 | 50 | 51 | def parse_args(): 52 | parser = argparse.ArgumentParser(description="Visualizer Arguments") 53 | 54 | parser.add_argument( 55 | "file_path", 56 | help="File Path of tree.npz", 57 | default=None, 58 | type=Path, 59 | ) 60 | 61 | parser.add_argument( 62 | "-lw", 63 | "--line_width", 64 | help="Width of visualizer lines", 65 | default=1, 66 | type=int, 67 | ) 68 | return parser.parse_args() 69 | 70 | 71 | def paths_from_args(args, glob="*.npz"): 72 | if not args.file_path.exists(): 73 | raise ValueError(f"File {args.file_path} does not exist") 74 | 75 | if args.file_path.is_file(): 76 | print(f"Loading data from file: {args.file_path}") 77 | return [args.file_path] 78 | 79 | if args.file_path.is_dir(): 80 | print(f"Loading data from directory: {args.file_path}") 81 | files = args.file_path.glob(glob) 82 | if files == []: 83 | raise ValueError(f"No npz files found in {args.file_path}") 84 | return files 85 | 86 | 87 | def main(): 88 | args = parse_args() 89 | 90 | data = [(load_data_npz(filename), filename) for filename in paths_from_args(args)] 91 | view_synthetic_data(data, args.line_width) 92 | 93 | 94 | if __name__ == "__main__": 95 | main() 96 | -------------------------------------------------------------------------------- 
/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | 132 | .vscode/ 133 | 134 | outputs/ 135 | data/ 136 | multirun/ 137 | wandb/ 138 | *.out 139 | *.sl 140 | speed-tree-outputs/ 141 | smart_tree/conf/tree-split-test.json 142 | smart_tree/conf/apple-trellis-split-test.json 143 | smart_tree/conf/apple-trellis.yaml 144 | smart_tree/conf/apple-trellis-split.json 145 | FRNN/ 146 | 147 | 148 | smart_tree/conf 149 | test_data/ 150 | 151 | !smart_tree/conf/pipeline.yaml 152 | !smart_tree/conf/training.yaml 153 | !smart_tree/conf/training-split.json 154 | 155 | *.ply 156 | *.obj -------------------------------------------------------------------------------- /smart_tree/model/fp16.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from collections import abc 3 | from inspect import getfullargspec 4 | 5 | import spconv.pytorch as spconv 6 | import torch 7 | 8 | 9 | def cast_tensor_type(inputs, src_type, dst_type): 10 | if isinstance(inputs, torch.Tensor): 11 | return inputs.to(dst_type) if inputs.dtype == src_type else inputs 12 | elif isinstance(inputs, spconv.SparseConvTensor): 13 | if inputs.features.dtype == src_type: 14 | features = inputs.features.to(dst_type) 15 | inputs = inputs.replace_feature(features) 16 | return inputs 17 | elif isinstance(inputs, abc.Mapping): 18 | return type(inputs)( 19 | {k: cast_tensor_type(v, src_type, dst_type) for k, v in inputs.items()} 20 | ) 
21 | elif isinstance(inputs, abc.Iterable): 22 | return type(inputs)( 23 | cast_tensor_type(item, src_type, dst_type) for item in inputs 24 | ) 25 | else: 26 | return inputs 27 | 28 | 29 | def force_fp32(apply_to=None, out_fp16=False): 30 | def force_fp32_wrapper(old_func): 31 | @functools.wraps(old_func) 32 | def new_func(*args, **kwargs): 33 | if not isinstance(args[0], torch.nn.Module): 34 | raise TypeError( 35 | "@force_fp32 can only be used to decorate the " 36 | "method of nn.Module" 37 | ) 38 | # get the arg spec of the decorated method 39 | args_info = getfullargspec(old_func) 40 | # get the argument names to be casted 41 | args_to_cast = args_info.args if apply_to is None else apply_to 42 | # convert the args that need to be processed 43 | new_args = [] 44 | if args: 45 | arg_names = args_info.args[: len(args)] 46 | for i, arg_name in enumerate(arg_names): 47 | if arg_name in args_to_cast: 48 | new_args.append( 49 | cast_tensor_type(args[i], torch.half, torch.float) 50 | ) 51 | else: 52 | new_args.append(args[i]) 53 | # convert the kwargs that need to be processed 54 | new_kwargs = dict() 55 | if kwargs: 56 | for arg_name, arg_value in kwargs.items(): 57 | if arg_name in args_to_cast: 58 | new_kwargs[arg_name] = cast_tensor_type( 59 | arg_value, torch.half, torch.float 60 | ) 61 | else: 62 | new_kwargs[arg_name] = arg_value 63 | with torch.cuda.amp.autocast(enabled=False): 64 | output = old_func(*new_args, **new_kwargs) 65 | # cast the results back to fp32 if necessary 66 | if out_fp16: 67 | output = cast_tensor_type(output, torch.float, torch.half) 68 | return output 69 | 70 | return new_func 71 | 72 | return force_fp32_wrapper 73 | -------------------------------------------------------------------------------- /smart_tree/model/sparse.py: -------------------------------------------------------------------------------- 1 | from itertools import repeat 2 | from typing import List, Tuple, Union 3 | 4 | import numpy as np 5 | import spconv.pytorch as spconv 
6 | import torch 7 | 8 | 9 | def sparse_from_batch(features, coordinates, device): 10 | batch_size = features.shape[0] 11 | 12 | features = features.to(device) 13 | coordinates = coordinates.to(device) 14 | 15 | values, _ = torch.max(coordinates, 0) # BXYZ -> XYZ (Biggest Spatial Size) 16 | 17 | return spconv.SparseConvTensor( 18 | features, coordinates.int(), values[1:], batch_size=batch_size 19 | ) 20 | 21 | 22 | def split_sparse_list(indices, features): 23 | cloud_ids = indices[:, 0] 24 | 25 | num_clouds = cloud_ids.max() + 1 26 | return [ 27 | (indices[cloud_ids == i], features[cloud_ids == i]) for i in range(num_clouds) 28 | ] 29 | 30 | 31 | def split_sparse(sparse_tensor): 32 | cloud_ids = sparse_tensor.indices[:, 0] 33 | num_clouds = cloud_ids.max() + 1 34 | return [ 35 | (sparse_tensor.indices[cloud_ids == i], sparse_tensor.features[cloud_ids == i]) 36 | for i in range(num_clouds) 37 | ] 38 | 39 | 40 | def batch_collate(batch): 41 | """Custom Batch Collate Function for Sparse Data...""" 42 | 43 | batch_feats, batch_coords, batch_mask, fn = zip(*batch) 44 | 45 | for i, coords in enumerate(batch_coords): 46 | coords[:, 0] = torch.tensor([i], dtype=torch.float32) 47 | 48 | if isinstance(batch_feats[0], tuple): 49 | input_feats, target_feats = tuple(zip(*batch_feats)) 50 | 51 | input_feats, target_feats, coords, mask = [ 52 | torch.cat(x) for x in [input_feats, target_feats, batch_coords, batch_mask] 53 | ] 54 | 55 | return [(input_feats, target_feats), coords, mask, fn] 56 | 57 | feats, coords, mask = [ 58 | torch.cat(x) for x in [batch_feats, batch_coords, batch_mask] 59 | ] 60 | 61 | return [feats, coords, mask, fn] 62 | 63 | 64 | def ravel_hash(x: np.ndarray) -> np.ndarray: 65 | assert x.ndim == 2, x.shape 66 | 67 | x = x - np.min(x, axis=0) 68 | x = x.astype(np.uint64, copy=False) 69 | xmax = np.max(x, axis=0).astype(np.uint64) + 1 70 | 71 | h = np.zeros(x.shape[0], dtype=np.uint64) 72 | for k in range(x.shape[1] - 1): 73 | h += x[:, k] 74 | h *= xmax[k + 
1] 75 | h += x[:, -1] 76 | return h 77 | 78 | 79 | def sparse_quantize( 80 | coords, 81 | voxel_size: Union[float, Tuple[float, ...]] = 1, 82 | *, 83 | return_index: bool = False, 84 | return_inverse: bool = False, 85 | ) -> List[np.ndarray]: 86 | if isinstance(voxel_size, (float, int)): 87 | voxel_size = tuple(repeat(voxel_size, 3)) 88 | assert isinstance(voxel_size, tuple) and len(voxel_size) == 3 89 | 90 | voxel_size = np.array(voxel_size) 91 | coords = np.floor(coords / voxel_size).astype(np.int32) 92 | 93 | _, indices, inverse_indices = np.unique( 94 | ravel_hash(coords), return_index=True, return_inverse=True 95 | ) 96 | coords = coords[indices] 97 | 98 | outputs = [coords] 99 | if return_index: 100 | outputs += [indices] 101 | if return_inverse: 102 | outputs += [inverse_indices] 103 | return outputs[0] if len(outputs) == 1 else outputs 104 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | #
💡🧠🤔 Smart-Tree 🌳🌲🌴
2 | 3 | ## 📝 Description: 4 | 5 | This repository contains code from the paper "Smart-Tree: Neural Medial Axis Approximation of Point Clouds for 3D Tree Skeletonization".
6 | The code provided is a deep-learning-based skeletonization method for point clouds. 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 |
Input point cloud.Mesh output.Skeleton output.
20 | 21 | 22 | ## 💾 Data: 23 | 24 | Please follow instructions to download data from this link.
25 | 26 | ## 🔧 Installation: 27 | 28 | First, make sure you have Conda installed, as well as mamba. 29 | Mamba ensures the environment is created within a reasonable timeframe. 30 | 31 | To install smart-tree please use
`bash create-env.sh`
32 | Then activate the environment using:
`conda activate smart-tree` 33 | 34 | 35 | ## 📈 Training: 36 | 37 | To train the model open smart_tree/conf/training.yaml. 38 | 39 | You will need to update (alternatively these can be overwritten with hydra): 40 | 41 | - training.dataset.json_path to the location of where your smart_tree/conf/tree-split.json is stored. 42 | - training.dataset.directory to the location of where you downloaded the data (you can choose whether to train on the data with foliage or without based on the directory you supply). 43 | 44 | You can experiment with/adjust hyper-parameter settings too. 45 | 46 | The model will then train using the following: 47 | 48 | `train-smart-tree` 49 | 50 | The best model weights and model will be stored in the generated outputs directory. 51 | 52 | ## ▶️ Inference / ☠️ Skeletonization: 53 | 54 | We supply two different models with weights: 55 | * `noble-elevator-58` contains branch/foliage segmentation.
56 | * `peach-forest-65` is only trained on points from the branching structure.
57 | 58 | If you wish to run smart-tree using your own weights you will need to update the model paths in `smart_tree/conf/pipeline.yaml`.
59 | 60 | To run smart-tree use:
61 | `run-smart-tree +path=cloud_path`
62 | where `cloud_path` is the path of the point cloud you want to skeletonize.
63 | Skeletonization-specific parameters can be adjusted within the `smart_tree/conf/pipeline.yaml` config. 64 | 65 | ## 📜 Citation: 66 | Please use the following BibTeX entry to cite our work:
67 | 68 | ``` 69 | @inproceedings{dobbs2023smart, 70 | title={Smart-Tree: Neural Medial Axis Approximation of Point Clouds for 3D Tree Skeletonization}, 71 | author={Dobbs, Harry and Batchelor, Oliver and Green, Richard and Atlas, James}, 72 | booktitle={Iberian Conference on Pattern Recognition and Image Analysis}, 73 | pages={351--362}, 74 | year={2023}, 75 | organization={Springer} 76 | } 77 | ``` 78 | ## Star History 79 | 80 | [![Star History Chart](https://api.star-history.com/svg?repos=uc-vision/smart-tree&type=Date)](https://star-history.com/#uc-vision/smart-tree&Date) 81 | 82 | 83 | ## 📥 Contact 84 | 85 | Should you have any questions, comments or suggestions please use the following contact details: 86 | harry.dobbs@pg.canterbury.ac.nz 87 | -------------------------------------------------------------------------------- /smart_tree/model/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | 7 | def compute_loss( 8 | preds, 9 | targets, 10 | mask=None, 11 | radius_loss_fn=None, 12 | direction_loss_fn=None, 13 | class_loss_fn=None, 14 | target_radius_log=True, 15 | vector_class=None, 16 | ): 17 | predicted_radius = preds["radius"] # 18 | predicted_direction = preds["direction"] 19 | predicted_class = preds["class_l"] 20 | target_class = targets[:, [-1]].long() 21 | target_direction = targets[:, 1:-1] 22 | target_radius = targets[:, [0]] 23 | 24 | if mask is not None: 25 | predicted_radius = predicted_radius[mask] 26 | predicted_direction = predicted_direction[mask] 27 | predicted_class = predicted_class[mask] 28 | target_radius = target_radius[mask] 29 | target_direction = target_direction[mask] 30 | target_class = target_class[mask] 31 | 32 | # Only compute vector loss on branch points... 
33 | if vector_class is not None: 34 | vector_mask = target_class == vector_class 35 | vector_mask = vector_mask.view(-1) 36 | predicted_radius = predicted_radius[vector_mask] 37 | predicted_direction = predicted_direction[vector_mask] 38 | target_radius = target_radius[vector_mask] 39 | target_direction = target_direction[vector_mask] 40 | 41 | if target_radius_log: 42 | target_radius = torch.log(target_radius) 43 | 44 | losses = {} 45 | 46 | losses["radius"] = radius_loss_fn(predicted_radius.view(-1), target_radius.view(-1)) 47 | losses["direction"] = direction_loss_fn(predicted_direction, target_direction) 48 | losses["class_l"] = class_loss_fn(predicted_class, target_class) 49 | 50 | return losses 51 | 52 | 53 | def L1Loss(outputs, targets): 54 | loss = nn.L1Loss() 55 | return loss(outputs, targets) 56 | 57 | 58 | def cosine_similarity_loss(outputs, targets): 59 | loss = nn.CosineSimilarity() 60 | return torch.mean(1 - loss(outputs, targets)) 61 | 62 | 63 | def dice_loss(outputs, targets): 64 | # https://gist.github.com/jeremyjordan/9ea3032a32909f71dd2ab35fe3bacc08 65 | smooth = 1 66 | outputs = F.softmax(outputs, dim=1) 67 | targets = F.one_hot(targets) 68 | 69 | outputs = outputs.view(-1) 70 | targets = targets.view(-1) 71 | 72 | intersection = (outputs * targets).sum() 73 | 74 | return 1 - ( 75 | (2.0 * intersection + smooth) / (outputs.sum() + targets.sum() + smooth) 76 | ) 77 | 78 | 79 | def focal_loss(outputs, targets): 80 | # https://github.com/torrvision/focal_calibration/blob/main/Losses/focal_loss.py 81 | gamma = 2 82 | input = outputs 83 | 84 | if input.dim() > 2: 85 | input = input.view(input.size(0), input.size(1), -1) # N,C,H,W => N,C,H*W 86 | input = input.transpose(1, 2) # N,C,H*W => N,H*W,C 87 | input = input.contiguous().view(-1, input.size(2)) # N,H*W,C => N*H*W,C 88 | targets = targets.view(-1, 1) 89 | 90 | logpt = F.log_softmax(input, dim=1) 91 | logpt = logpt.gather(1, targets) 92 | logpt = logpt.view(-1) 93 | pt = logpt.exp() 94 | loss = 
-1 * (1 - pt) ** gamma * logpt 95 | # return loss.sum() 96 | return loss.mean() 97 | 98 | 99 | def nll_loss(outputs, targets): 100 | return torch.tensor([0]).cuda() 101 | weights = targets.shape[0] / (torch.bincount(targets)) # Balance class weights 102 | return F.nll_loss(F.log_softmax(outputs, dim=1), targets, weight=weights) 103 | -------------------------------------------------------------------------------- /smart_tree/model/model_inference.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import torch 4 | from tqdm import tqdm 5 | 6 | from smart_tree.data_types.cloud import Cloud 7 | from smart_tree.dataset.dataset import load_dataloader 8 | from smart_tree.model.sparse import sparse_from_batch 9 | 10 | 11 | def load_model(model_path, weights_path, device=torch.device("cuda:0")): 12 | model = torch.load(f"{model_path}", map_location=device) 13 | model.load_state_dict(torch.load(f"{weights_path}")) 14 | model.eval() 15 | 16 | return model 17 | 18 | 19 | """ Loads model and model weights, then returns the input, outputs and mask """ 20 | 21 | 22 | class ModelInference: 23 | def __init__( 24 | self, 25 | model_path: Path, 26 | weights_path: Path, 27 | voxel_size: float, 28 | block_size: float, 29 | buffer_size: float, 30 | num_workers=8, 31 | batch_size=4, 32 | device=torch.device("cuda:0"), 33 | verbose=False, 34 | ): 35 | self.device = device 36 | self.verbose = verbose 37 | self.voxel_size = voxel_size 38 | self.block_size = block_size 39 | self.buffer_size = buffer_size 40 | 41 | self.num_workers = num_workers 42 | self.batch_size = batch_size 43 | 44 | self.model = load_model(model_path, weights_path, self.device) 45 | 46 | if self.verbose: 47 | print("Model Loaded Succesfully") 48 | 49 | def forward(self, cloud: Cloud, return_masked=True): 50 | inputs, masks = [], [] 51 | radius, direction, class_l = [], [], [] 52 | 53 | dataloader = load_dataloader( 54 | cloud, 55 | self.voxel_size, 
56 | self.block_size, 57 | self.buffer_size, 58 | self.num_workers, 59 | self.batch_size, 60 | ) 61 | 62 | for features, coordinates, mask, filename in tqdm( 63 | dataloader, desc="Inferring", leave=False 64 | ): 65 | sparse_input = sparse_from_batch( 66 | features[:, :3], 67 | coordinates, 68 | device=self.device, 69 | ) 70 | 71 | preds = self.model.forward(sparse_input) 72 | 73 | radius.append(preds["radius"].detach().cpu()) 74 | direction.append(preds["direction"].detach().cpu()) 75 | class_l.append(preds["class_l"].detach().cpu()) 76 | 77 | inputs.append(features.detach().cpu()) 78 | masks.append(mask.detach().cpu()) 79 | 80 | radius = torch.cat(radius) 81 | direction = torch.cat(direction) 82 | class_l = torch.cat(class_l) 83 | 84 | inputs = torch.cat(inputs) 85 | masks = torch.cat(masks) 86 | 87 | medial_vector = torch.exp(radius) * direction 88 | class_l = torch.argmax(class_l, dim=1, keepdim=True) 89 | 90 | lc = Cloud( 91 | xyz=inputs[:, :3], 92 | rgb=inputs[:, 3:6], 93 | medial_vector=medial_vector, 94 | class_l=class_l, 95 | ) 96 | 97 | if return_masked: 98 | return lc.filter(masks) 99 | 100 | return lc 101 | 102 | @staticmethod 103 | def from_cfg(cfg): 104 | return ModelInference( 105 | model_path=cfg.model_path, 106 | weights_path=cfg.weights_path, 107 | voxel_size=cfg.voxel_size, 108 | block_size=cfg.block_size, 109 | buffer_size=cfg.buffer_size, 110 | num_workers=cfg.num_workers, 111 | batch_size=cfg.batch_size, 112 | ) 113 | -------------------------------------------------------------------------------- /smart_tree/skeleton/skeletonize.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import cugraph 4 | import cupy 5 | import torch 6 | from cugraph import sssp 7 | from tqdm import tqdm 8 | 9 | from ..data_types.cloud import Cloud 10 | from ..data_types.graph import Graph 11 | from ..data_types.tree import DisjointTreeSkeleton, TreeSkeleton 12 | from .filter import outlier_removal 
13 | from .graph import decompose_cuda_graph, nn_graph 14 | from .path import sample_tree 15 | from .shortest_path import pred_graph, shortest_paths 16 | 17 | 18 | class Skeletonizer: 19 | def __init__( 20 | self, 21 | K: int, 22 | min_connection_length: float, 23 | minimum_graph_vertices: int, 24 | device: torch.device = torch.device("cuda:0"), 25 | ): 26 | self.K = K 27 | self.min_connection_length = min_connection_length 28 | self.minimum_graph_vertices = minimum_graph_vertices 29 | self.device = device 30 | 31 | def forward(self, cloud: Cloud) -> DisjointTreeSkeleton: 32 | cloud.to_device(self.device) 33 | 34 | mask = outlier_removal(cloud.medial_pts, cloud.radius.unsqueeze(1), nb_points=8) 35 | cloud = cloud.filter(mask) 36 | 37 | graph: Graph = nn_graph( 38 | cloud.medial_pts, 39 | cloud.radius.clamp(min=self.min_connection_length), 40 | K=self.K, 41 | ) 42 | 43 | subgraphs: List[cugraph.Graph] = graph.connected_cugraph_components( 44 | minimum_vertices=self.minimum_graph_vertices 45 | ) 46 | 47 | skeletons = [] 48 | for subgraph_id, subgraph in enumerate( 49 | tqdm(subgraphs, desc="Processing Connected Components", leave=False) 50 | ): 51 | skeletons.append( 52 | self.process_subgraph(cloud, subgraph, skeleton_id=subgraph_id) 53 | ) 54 | 55 | return DisjointTreeSkeleton(skeletons) 56 | 57 | def process_subgraph(self, cloud, subgraph, skeleton_id=0) -> TreeSkeleton: 58 | """Extract skeleton for connected component""" 59 | 60 | subgraph_vertice_idx = torch.tensor( 61 | cupy.unique(subgraph.edges().values), 62 | device=self.device, 63 | ) 64 | 65 | subgraph_cloud = cloud.filter(subgraph_vertice_idx) 66 | 67 | edges, edge_weights = decompose_cuda_graph( 68 | subgraph, 69 | renumber_edges=True, 70 | device=self.device, 71 | ) 72 | 73 | verts, preds, distance = shortest_paths( 74 | subgraph_cloud.root_idx, 75 | edges, 76 | edge_weights, 77 | renumber=False, 78 | ) 79 | 80 | predecessor_graph = pred_graph(verts, preds, subgraph_cloud.medial_pts) 81 | 82 | distances 
= torch.as_tensor( 83 | sssp(predecessor_graph, source=subgraph_cloud.root_idx)["distance"], 84 | device=self.device, 85 | ) 86 | 87 | branches = sample_tree( 88 | subgraph_cloud.medial_pts, 89 | subgraph_cloud.radius.unsqueeze(1), 90 | preds, 91 | distances, 92 | subgraph_cloud.xyz, 93 | ) 94 | 95 | return TreeSkeleton(skeleton_id, branches) 96 | 97 | @staticmethod 98 | def from_cfg(cfg): 99 | return Skeletonizer( 100 | K=cfg.K, 101 | min_connection_length=cfg.min_connection_length, 102 | minimum_graph_vertices=cfg.minimum_graph_vertices, 103 | edge_non_linear=cfg.edge_non_linear, 104 | ) 105 | 106 | 107 | if __name__ == "__main__": 108 | main() 109 | -------------------------------------------------------------------------------- /smart_tree/dataset/augmentations.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Sequence 3 | 4 | import torch 5 | from beartype import beartype 6 | 7 | from smart_tree.data_types.cloud import Cloud 8 | from smart_tree.util.maths import euler_angles_to_rotation 9 | 10 | 11 | class Augmentation(ABC): 12 | @abstractmethod 13 | def __call__(self, cloud: Cloud) -> Cloud: 14 | pass 15 | 16 | 17 | class Scale(Augmentation): 18 | def __init__(self, min_scale=0.9, max_scale=1.1): 19 | self.min_scale = min_scale 20 | self.max_scale = max_scale 21 | 22 | def __call__(self, cloud: Cloud) -> Cloud: 23 | t = torch.rand(1, device=cloud.xyz.device) * (self.max_scale - self.min_scale) 24 | return cloud.scale(t + self.min_scale) 25 | 26 | 27 | class FixedRotate(Augmentation): 28 | def __init__(self, xyz): 29 | self.xyz = xyz 30 | 31 | def __call__(self, cloud: Cloud) -> Cloud: 32 | self.rot_mat = euler_angles_to_rotation( 33 | torch.tensor(self.xyz, device=cloud.xyz.device) 34 | ).float() 35 | return cloud.rotate(self.rot_mat) 36 | 37 | 38 | class CentreCloud(Augmentation): 39 | def __call__(self, cloud: Cloud) -> Cloud: 40 | centre, (x, y, z) = 
cloud.bbox 41 | return cloud.translate(-centre + torch.tensor([0, y, 0], device=centre.device)) 42 | 43 | 44 | class VoxelDownsample(Augmentation): 45 | def __init__(self, voxel_size): 46 | self.voxel_size = voxel_size 47 | 48 | def __call__(self, cloud: Cloud) -> Cloud: 49 | return cloud.voxel_down_sample(self.voxel_size) 50 | 51 | 52 | class FixedTranslate(Augmentation): 53 | def __init__(self, xyz): 54 | self.xyz = torch.tensor(xyz) 55 | 56 | def __call__(self, cloud: Cloud) -> Cloud: 57 | return cloud.translate(self.xyz) 58 | 59 | 60 | class RandomCrop(Augmentation): 61 | def __init__(self, max_x, max_y, max_z): 62 | self.max_translation = torch.tensor([max_x, max_y, max_z]) 63 | 64 | def __call__(self, cloud): 65 | offset = ( 66 | torch.rand(3, device=cloud.xyz.device) - 0.5 67 | ) * self.max_translation.to(device=cloud.xyz.device) 68 | 69 | p = cloud.xyz + offset 70 | mask = torch.logical_and(p >= cloud.min_xyz, p <= cloud.max_xyz).all(dim=1) 71 | 72 | return cloud.filter(mask) 73 | 74 | 75 | class RandomCubicCrop(Augmentation): 76 | def __init__(self, size): 77 | self.size = size 78 | 79 | def __call__(self, cloud): 80 | random_pt = cloud.xyz[torch.randint(0, cloud.xyz.shape[0], (1,))] 81 | min_corner = random_pt - self.size / 2 82 | max_corner = random_pt + self.size / 2 83 | 84 | mask = torch.logical_and( 85 | cloud.xyz >= min_corner, 86 | cloud.xyz <= max_corner, 87 | ).all(dim=1) 88 | 89 | return cloud.filter(mask) 90 | 91 | 92 | class RandomDropout(Augmentation): 93 | def __init__(self, max_drop_out): 94 | self.max_drop_out = max_drop_out 95 | 96 | def __call__(self, cloud: Cloud) -> Cloud: 97 | num_indices = int( 98 | (1.0 - (self.max_drop_out * torch.rand(1, device=cloud.xyz.device))) 99 | * cloud.xyz.shape[0] 100 | ) 101 | 102 | indices = torch.randint( 103 | high=cloud.xyz.shape[0], size=(num_indices, 1), device=cloud.xyz.device 104 | ).squeeze(1) 105 | return cloud.filter(indices) 106 | 107 | 108 | class AugmentationPipeline(Augmentation): 109 | 
@beartype 110 | def __init__(self, augmentations: Sequence[Augmentation]): 111 | self.augmentations = augmentations 112 | 113 | def __call__(self, cloud): 114 | for augmentation in self.augmentations: 115 | cloud = augmentation(cloud) 116 | return cloud 117 | -------------------------------------------------------------------------------- /smart_tree/o3d_abstractions/camera.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import open3d as o3d 3 | import open3d.visualization.rendering as rendering 4 | 5 | 6 | def create_camera(width, height, fx=575, fy=575): 7 | camera_parameters = o3d.camera.PinholeCameraParameters() 8 | camera_parameters.extrinsic = np.eye(4) 9 | camera_parameters.intrinsic.set_intrinsics( 10 | width=width, 11 | height=height, 12 | fx=fx, 13 | fy=fy, 14 | cx=(width / 2.0 - 0.5), 15 | cy=(height / 2.0 - 0.5), 16 | ) 17 | return camera_parameters 18 | 19 | 20 | def update_camera_position( 21 | camera, camera_position, camera_target=[0, 0, 0], up=np.asarray([0, 1, 0]) 22 | ): 23 | camera_direction = (camera_target - camera_position) / np.linalg.norm( 24 | camera_target - camera_position 25 | ) 26 | camera_right = np.cross(camera_direction, up) 27 | camera_right = camera_right / np.linalg.norm(camera_right) 28 | camera_up = np.cross(camera_direction, camera_right) 29 | position_matrix = np.eye(4) 30 | position_matrix[:3, 3] = -camera_position 31 | camera_look_at = np.eye(4) 32 | camera_look_at[:3, :3] = np.vstack((camera_right, camera_up, camera_direction)) 33 | camera_look_at = np.matmul(camera_look_at, position_matrix) 34 | camera.extrinsic = camera_look_at 35 | return camera 36 | 37 | 38 | def o3d_headless_render(geoms, camera_position, camera_up): 39 | width, height = 1920, 1080 40 | camera = create_camera(width, height) 41 | 42 | # Setup a Offscreen Renderer 43 | render = rendering.OffscreenRenderer(width=width, height=height) 44 | render.scene.set_background([1.0, 1.0, 1.0, 1.0]) # 
RGBA 45 | render.setup_camera(camera.intrinsic, camera.extrinsic) 46 | render.scene.scene.set_sun_light( 47 | geoms[0].get_center() + np.asarray(camera_position), [1.0, 1.0, 1.0], 75000 48 | ) 49 | render.scene.scene.enable_sun_light(True) 50 | render.scene.scene.enable_indirect_light(True) 51 | render.scene.scene.set_indirect_light_intensity(0.3) 52 | # render.scene.set_lighting(render.scene.LightingProfile.NO_SHADOWS, (0, 0, 0)) 53 | 54 | mtl = o3d.visualization.rendering.MaterialRecord() 55 | mtl.shader = "defaultUnlit" 56 | 57 | for i, item in enumerate(geoms): 58 | render.scene.add_geometry(f"{i}", item, mtl) 59 | 60 | camera = update_camera_position( 61 | camera, 62 | geoms[0].get_center() + np.asarray(camera_position), 63 | camera_target=geoms[0].get_center(), 64 | up=np.asarray(camera_up), 65 | ) 66 | render.setup_camera(camera.intrinsic, camera.extrinsic) 67 | 68 | return render.render_to_image() 69 | 70 | 71 | class Renderer: 72 | def __init__(self, width, height): 73 | self.camera = create_camera(width, height) 74 | self.render = rendering.OffscreenRenderer(width=width, height=height) 75 | self.render.scene.set_background([1.0, 1.0, 1.0, 1.0]) # RGBA 76 | self.render.setup_camera(self.camera.intrinsic, self.camera.extrinsic) 77 | 78 | self.render.scene.scene.enable_sun_light(True) 79 | self.render.scene.scene.enable_indirect_light(True) 80 | self.render.scene.scene.set_indirect_light_intensity(0.3) 81 | self.mtl = o3d.visualization.rendering.MaterialRecord() 82 | self.mtl.shader = "defaultUnlit" 83 | 84 | def capture(self, geoms, camera_position, camera_up): 85 | self.render.scene.scene.set_sun_light( 86 | geoms[0].get_center() + np.asarray(camera_position), [1.0, 1.0, 1.0], 75000 87 | ) 88 | for i, item in enumerate(geoms): 89 | self.render.scene.add_geometry(f"{i}", item, self.mtl) 90 | 91 | camera = update_camera_position( 92 | self.camera, 93 | geoms[0].get_center() + np.asarray(camera_position), 94 | camera_target=geoms[0].get_center(), 95 | 
up=np.asarray(camera_up), 96 | ) 97 | self.render.setup_camera(camera.intrinsic, camera.extrinsic) 98 | img = self.render.render_to_image() 99 | self.render.scene.clear_geometry() 100 | 101 | return img 102 | -------------------------------------------------------------------------------- /smart_tree/conf/training.yaml: -------------------------------------------------------------------------------- 1 | # @package: _global_ 2 | 3 | wandb: 4 | project: tree 5 | entity: harry1576 6 | mode: online 7 | 8 | fp16: True 9 | num_epoch: 1 10 | lr_decay: True 11 | lr: 0.1 12 | early_stop_epoch: 20 13 | early_stop: True 14 | 15 | batch_size: 8 16 | directory: /local/smart-tree/data/branches 17 | json_path: smart_tree/conf/training-split.json 18 | voxel_size: 0.01 19 | 20 | cmap: 21 | - [0.450, 0.325, 0.164] # Trunk 22 | - [0.541, 0.670, 0.164] # Foliage 23 | 24 | 25 | capture_output: 1 26 | 27 | input_features: 28 | - xyz 29 | 30 | target_features: 31 | - radius 32 | - direction 33 | - class_l 34 | 35 | train_dataset: 36 | _target_: smart_tree.dataset.dataset.TreeDataset 37 | mode: train 38 | voxel_size: ${voxel_size} 39 | directory: ${directory} 40 | json_path: ${json_path} 41 | input_features: ${input_features} 42 | target_features: ${target_features} 43 | augmentation: 44 | _target_: smart_tree.dataset.augmentations.AugmentationPipeline 45 | augmentations: 46 | - _target_: smart_tree.dataset.augmentations.RandomCubicCrop 47 | size: 4.0 48 | 49 | test_dataset: 50 | _target_: smart_tree.dataset.dataset.TreeDataset 51 | mode: test 52 | voxel_size: ${voxel_size} 53 | directory: ${directory} 54 | json_path: ${json_path} 55 | input_features: ${input_features} 56 | target_features: ${target_features} 57 | augmentation: 58 | _target_: smart_tree.dataset.augmentations.AugmentationPipeline 59 | augmentations: 60 | - _target_: smart_tree.dataset.augmentations.RandomCubicCrop 61 | size: 4.0 62 | 63 | validation_dataset: 64 | _target_: smart_tree.dataset.dataset.TreeDataset 65 | mode: 
validation 66 | voxel_size: ${voxel_size} 67 | directory: ${directory} 68 | json_path: ${json_path} 69 | input_features: ${input_features} 70 | target_features: ${target_features} 71 | cache: True 72 | augmentation: 73 | _target_: smart_tree.dataset.augmentations.AugmentationPipeline 74 | augmentations: 75 | - _target_: smart_tree.dataset.augmentations.RandomCubicCrop 76 | size: 4.0 77 | 78 | 79 | train_data_loader: 80 | _target_: torch.utils.data.DataLoader 81 | batch_size: ${batch_size} 82 | drop_last: False 83 | pin_memory: False 84 | num_workers: 0 85 | shuffle: False 86 | # sampler: 87 | # _target_: torch.utils.data.RandomSampler 88 | # replacement: True 89 | # num_samples: 8 90 | # data_source: ${train_dataset} 91 | collate_fn: 92 | _target_: smart_tree.model.sparse.batch_collate 93 | _partial_: True 94 | dataset: ${train_dataset} 95 | 96 | validation_data_loader: 97 | _target_: torch.utils.data.DataLoader 98 | batch_size: ${batch_size} 99 | drop_last: False 100 | pin_memory: False 101 | num_workers: 0 102 | shuffle: False 103 | collate_fn: 104 | _target_: smart_tree.model.sparse.batch_collate 105 | _partial_: True 106 | dataset: ${validation_dataset} 107 | 108 | test_data_loader: 109 | _target_: torch.utils.data.DataLoader 110 | batch_size: ${batch_size} 111 | drop_last: False 112 | pin_memory: False 113 | num_workers: 0 114 | shuffle: False 115 | collate_fn: 116 | _target_: smart_tree.model.sparse.batch_collate 117 | _partial_: True 118 | dataset: ${test_dataset} 119 | 120 | 121 | model: 122 | _target_: smart_tree.model.model.Smart_Tree 123 | input_channels: 3 124 | unet_planes: [8, 16, 32] 125 | radius_fc_planes: [8, 8, 4, 1] 126 | direction_fc_planes: [8, 8, 4, 3] 127 | class_fc_planes: [8, 8, 4, 2] 128 | 129 | optimizer: 130 | _target_: torch.optim.Adam 131 | lr: ${lr} 132 | 133 | scheduler: 134 | _target_: torch.optim.lr_scheduler.ReduceLROnPlateau 135 | mode: "min" 136 | 137 | loss_fn: 138 | _target_: smart_tree.model.loss.compute_loss 139 | _partial_: 
True 140 | vector_class: 0 141 | target_radius_log: True 142 | 143 | radius_loss_fn: 144 | _target_ : smart_tree.model.loss.L1Loss 145 | _partial_: True 146 | 147 | direction_loss_fn: 148 | _target_: smart_tree.model.loss.cosine_similarity_loss 149 | _partial_: True 150 | 151 | class_loss_fn: 152 | _target_: smart_tree.model.loss.focal_loss 153 | _partial_: True 154 | -------------------------------------------------------------------------------- /smart_tree/skeleton/graph.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import cugraph 4 | import frnn 5 | import numpy as np 6 | import torch 7 | from tqdm import tqdm 8 | 9 | from ..data_types.graph import Graph 10 | 11 | 12 | def knn(src, dest, K=50, r=1.0, grid=None): 13 | src_lengths = src.new_tensor([src.shape[0]], dtype=torch.long) 14 | dest_lengths = src.new_tensor([dest.shape[0]], dtype=torch.long) 15 | dists, idxs, grid, _ = frnn.frnn_grid_points( 16 | src.unsqueeze(0), 17 | dest.unsqueeze(0), 18 | src_lengths, 19 | dest_lengths, 20 | K, 21 | r, 22 | return_nn=False, 23 | return_sorted=True, 24 | ) 25 | 26 | return idxs.squeeze(0), dists.sqrt().squeeze(0), grid 27 | 28 | 29 | def nn(src, dest, r=1.0, grid=None): 30 | idx, dist, grid = knn(src, dest, K=1, r=r, grid=grid) 31 | idx, dist = idx.squeeze(1), dist.squeeze(1) 32 | 33 | return idx, dist, grid 34 | 35 | 36 | def nn_graph(points: torch.Tensor, radii, K=40): 37 | idxs, dists, _ = knn(points, points, K=K, r=radii.max().item()) 38 | idxs[dists > radii.unsqueeze(1)] = -1 39 | edges, edge_weights = make_edges(dists, idxs) 40 | return Graph(points, edges, edge_weights) 41 | 42 | 43 | def medial_nn_graph(points: torch.Tensor, radii, medial_dist, K=40): 44 | # edges weighted based on distance to medial axis 45 | idxs, dists, _ = knn(points, points, K=K, r=radii.max().item() / 4.0) 46 | dists_ = dists + medial_dist[idxs] # Add medial distance to distance graph... 
47 | idxs[dists > radii.unsqueeze(1)] = -1 48 | 49 | return make_edges(dists_, idxs) 50 | 51 | 52 | def make_edges(dists, idxs): 53 | n = dists.shape[0] 54 | K = dists.shape[1] 55 | 56 | parent = torch.arange(n, device=dists.device).unsqueeze(1).expand(n, K) 57 | edges = torch.stack([parent, idxs], dim=2) 58 | 59 | valid = idxs.view(-1) > 0 60 | return edges.view(-1, 2)[valid], dists.view(-1)[valid] 61 | 62 | 63 | def nn_flat(points, K=50, r=1.0, device=torch.device("cuda")): 64 | idxs, dists = nn(points, K, r, device) 65 | return idxs.reshape(-1), dists.reshape(-1) 66 | 67 | 68 | def pcd_nn(points, radii, K=20): 69 | idxs, dists = nn_flat(np.asarray(points), K=K, r=float(radii.max() / 4.0)) 70 | 71 | parent_vertex = np.repeat(np.arange(points.shape[0]), K) 72 | parent_vertex = parent_vertex.reshape(-1) 73 | 74 | edges = np.vstack((parent_vertex, idxs.reshape(-1))).T 75 | 76 | valid = idxs != -1 77 | return edges[valid], dists[valid] 78 | 79 | 80 | def decompose_cuda_graph(cuda_graph, renumber_edges=False, device=torch.device("cuda")): 81 | pdf = cugraph.to_pandas_edgelist(cuda_graph) 82 | 83 | edges = torch.stack((torch.tensor(pdf["src"]), torch.tensor(pdf["dst"])), dim=1) 84 | edge_weights = torch.tensor(pdf["weights"]) 85 | 86 | edges, edge_weights = edges.long().to(device), edge_weights.to(device) 87 | 88 | if renumber_edges: 89 | edges = remap_edges(edges) 90 | 91 | return edges, edge_weights 92 | 93 | 94 | def remap_edges(edges): 95 | # Find unique node IDs and their corresponding indices 96 | unique_nodes, node_indices = torch.unique(edges, return_inverse=True) 97 | 98 | # Create a mapping from old node IDs to new node IDs 99 | mapping = torch.arange(unique_nodes.size(0), device=edges.device) 100 | 101 | # Map the old node IDs to new node IDs using the indices 102 | renumbered_edges = mapping[node_indices].reshape(edges.shape) 103 | 104 | return renumbered_edges 105 | 106 | 107 | def connected_components(edges, edge_weights, minimum_vertices=0, 
max_components=10): 108 | g = cuda_graph(edges, edge_weights) 109 | 110 | connected_components = cugraph.connected_components(g) 111 | 112 | num_labels = connected_components["labels"].to_pandas().value_counts() 113 | valid_labels = num_labels[num_labels > minimum_vertices].index 114 | 115 | graphs = [] 116 | 117 | for label in tqdm( 118 | valid_labels[:max_components], desc="Getting Connected Componenets", leave=False 119 | ): 120 | graphs.append( 121 | cugraph.subgraph( 122 | g, connected_components.query(f"labels == {label}")["vertex"] 123 | ) 124 | ) 125 | 126 | return graphs 127 | -------------------------------------------------------------------------------- /smart_tree/skeleton/path.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from tqdm import tqdm 4 | 5 | from ..data_types.branch import BranchSkeleton 6 | from .graph import nn 7 | 8 | 9 | def trace_route(preds, idx, termination_pts): 10 | path = [] 11 | 12 | while idx >= 0 and idx not in termination_pts: 13 | path.append(idx) 14 | idx = preds[idx] 15 | 16 | return preds.new_tensor(path, dtype=torch.long).flip(0), idx 17 | 18 | 19 | def select_path_points( 20 | points: torch.tensor, path_verts: torch.tensor, radii: torch.tensor 21 | ): 22 | """ 23 | Finds points nearest to a path (specified by points with radii). 
24 | points: (N, 3) 3d points of point cloud 25 | path_verts: (M, 3) 3d points of path 26 | radii: (M, 1) radii of path 27 | returns: (X, 2) index tuples of (path, point) for X points falling within the path, ordered by path index 28 | """ 29 | 30 | point_path, dists, _ = nn( 31 | points, 32 | path_verts, 33 | r=radii.max().item(), 34 | ) # nearest path idx for each point 35 | valid = dists[point_path >= 0] < radii[point_path[point_path >= 0]].squeeze( 36 | 1 37 | ) # where the path idx is less than the distance to the point 38 | 39 | on_path = point_path.new_zeros(point_path.shape, dtype=torch.bool) 40 | on_path[point_path >= 0] = valid # points that are on the path. 41 | 42 | idx_point = on_path.nonzero().squeeze(1) 43 | idx_path = point_path[idx_point] 44 | 45 | order = torch.argsort(idx_path) 46 | return idx_point[order], idx_path[order] 47 | 48 | 49 | def sample_tree( 50 | medial_pts, 51 | medial_radii, 52 | preds, 53 | distances, 54 | all_points, 55 | root_idx=0, 56 | visualize=False, 57 | pbar=None, 58 | ): 59 | """ 60 | Medial Points: NN estimated medial points 61 | Medial Radii: NN estimated radii of points 62 | Preds: Predecessor of each medial point (on path to root node) 63 | Distance: Distance from root node to medial points 64 | Surface Points: The point the medial pts got projected from.. 
65 | """ 66 | 67 | branch_id = 0 68 | 69 | branches = {} 70 | 71 | selection_mask = preds > 0 72 | distances[~selection_mask] = -1 73 | 74 | termination_pts = torch.tensor([], device=torch.device("cuda")) 75 | branch_ids = torch.full( 76 | (medial_pts.shape[0],), 77 | -1, 78 | device=torch.device("cuda"), 79 | dtype=int, 80 | ) 81 | 82 | pbar = tqdm( 83 | total=distances.shape[0], 84 | leave=False, 85 | desc="Allocating Points", 86 | ) 87 | 88 | while True: 89 | pbar.update(n=((distances < 0).sum().item() - pbar.n)) 90 | pbar.refresh() 91 | 92 | farthest = distances.argmax().item() 93 | 94 | if distances[farthest] <= 0: 95 | break 96 | 97 | """ Traces the path of the futhrest point until it converges with allocated points """ 98 | path_vertices_idx, termination_idx = trace_route( 99 | preds, 100 | farthest, 101 | termination_pts, 102 | ) 103 | 104 | """ Gets the points around that path (and which path indexs they are close to) """ 105 | idx_points, idx_path = select_path_points( 106 | medial_pts, 107 | medial_pts[path_vertices_idx], 108 | medial_radii[path_vertices_idx], 109 | ) 110 | 111 | """ Mark this points as allocated and as termination points """ 112 | distances[idx_points] = -1 113 | distances[path_vertices_idx] = -1 114 | termination_pts = torch.unique( 115 | torch.cat( 116 | ( 117 | termination_pts, 118 | idx_points, 119 | path_vertices_idx, 120 | ) 121 | ) 122 | ) 123 | 124 | """ If the path has at least two points, save it as a branch """ 125 | if len(path_vertices_idx) < 2: 126 | continue 127 | 128 | branches[branch_id] = BranchSkeleton( 129 | branch_id, 130 | xyz=medial_pts[path_vertices_idx].cpu(), 131 | radii=medial_radii[path_vertices_idx].cpu(), 132 | parent_id=int(branch_ids[termination_idx].item()), 133 | ) 134 | 135 | branch_ids[path_vertices_idx] = branch_id 136 | branch_ids[idx_points] = branch_id 137 | 138 | branch_id += 1 139 | 140 | return branches 141 | -------------------------------------------------------------------------------- 
/smart_tree/pipeline.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from .data_types.cloud import Cloud 7 | from .data_types.tree import DisjointTreeSkeleton 8 | from .o3d_abstractions.visualizer import o3d_viewer 9 | from .util.file import (load_cloud, save_o3d_cloud, 10 | save_o3d_lineset, save_o3d_mesh) 11 | 12 | 13 | class Pipeline: 14 | def __init__( 15 | self, 16 | preprocessing, 17 | model_inference, 18 | skeletonizer, 19 | repair_skeletons=False, 20 | smooth_skeletons=False, 21 | smooth_kernel_size=0, 22 | prune_skeletons=False, 23 | min_skeleton_radius=0.0, 24 | min_skeleton_length=1000, 25 | view_model_output=False, 26 | view_skeletons=False, 27 | save_outputs=False, 28 | save_path="/", 29 | branch_classes=[0], 30 | cmap=[[1, 0, 0], [0, 1, 0]], 31 | device=torch.device("cuda:0"), 32 | ): 33 | self.preprocessing = preprocessing 34 | self.model_inference = model_inference 35 | self.skeletonizer = skeletonizer 36 | 37 | self.repair_skeletons = repair_skeletons 38 | self.smooth_skeletons = smooth_skeletons 39 | self.smooth_kernel_size = smooth_kernel_size 40 | self.prune_skeletons = prune_skeletons 41 | 42 | self.min_skeleton_radius = min_skeleton_radius 43 | self.min_skeleton_length = min_skeleton_length 44 | 45 | self.view_model_output = view_model_output 46 | self.view_skeletons = view_skeletons 47 | 48 | self.save_outputs = save_outputs 49 | self.save_path = save_path 50 | 51 | self.branch_classes = branch_classes 52 | self.cmap = np.asarray(cmap) 53 | self.device = device 54 | 55 | def process_cloud(self, path: Path =None, cloud: Cloud=None): 56 | # Load point cloud 57 | cloud: Cloud = load_cloud(path) if path != None else cloud 58 | 59 | cloud = cloud.to_device(self.device) 60 | cloud = self.preprocessing(cloud) 61 | 62 | # Run point cloud through model to predict class, radius, direction 63 | lc: Cloud = 
self.model_inference.forward(cloud).to_device(self.device) 64 | if self.view_model_output: 65 | lc.view(self.cmap) 66 | 67 | # Filter only the branch points for skeletonizaiton 68 | branch_cloud: Cloud = lc.filter_by_class(self.branch_classes) 69 | 70 | # Run the branch cloud through skeletonization algorithm, then post process 71 | skeleton: DisjointTreeSkeleton = self.skeletonizer.forward(branch_cloud) 72 | 73 | self.post_process(skeleton) 74 | 75 | # View skeletonization results 76 | if self.view_skeletons: 77 | o3d_viewer( 78 | [ 79 | skeleton.to_o3d_tube(), 80 | skeleton.to_o3d_lineset(), 81 | skeleton.to_o3d_tube(colour=False), 82 | cloud.to_o3d_cld(), 83 | ], 84 | line_width=5, 85 | ) 86 | 87 | if self.save_outputs: 88 | print("Saving Outputs") 89 | sp = self.save_path 90 | save_o3d_lineset(f"{sp}/skeleton.ply", skeleton.to_o3d_lineset()) 91 | save_o3d_mesh(f"{sp}/mesh.ply", skeleton.to_o3d_tube()) 92 | save_o3d_cloud(f"{sp}/cloud.ply", lc.to_o3d_cld()) 93 | save_o3d_cloud(f"{sp}/seg_cld.ply", lc.to_o3d_seg_cld(self.cmap)) 94 | 95 | def post_process(self, skeleton: DisjointTreeSkeleton): 96 | if self.prune_skeletons: 97 | skeleton.prune( 98 | min_length=self.min_skeleton_length, 99 | min_radius=self.min_skeleton_radius, 100 | ) 101 | 102 | if self.repair_skeletons: 103 | skeleton.repair() 104 | 105 | if self.smooth_skeletons: 106 | skeleton.smooth(self.smooth_kernel_size) 107 | 108 | @staticmethod 109 | def from_cfg(inferer, skeletonizer, cfg): 110 | return Pipeline( 111 | inferer, 112 | skeletonizer, 113 | preprocessing_cfg=cfg.preprocessing, 114 | repair_skeletons=cfg.repair_skeletons, 115 | smooth_skeletons=cfg.smooth_skeletons, 116 | smooth_kernel_size=cfg.smooth_kernel_size, 117 | prune_skeletons=cfg.prune_skeletons, 118 | min_skeleton_radius=cfg.min_skeleton_radius, 119 | min_skeleton_length=cfg.min_skeleton_length, 120 | view_model_output=cfg.view_model_output, 121 | view_skeletons=cfg.view_skeletons, 122 | save_outputs=cfg.save_outputs, 123 | 
branch_classes=cfg.branch_classes, 124 | cmap=cfg.cmap, 125 | ) 126 | -------------------------------------------------------------------------------- /smart_tree/util/file.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | from typing import Tuple 4 | 5 | import numpy as np 6 | import open3d as o3d 7 | import yaml 8 | 9 | from smart_tree.data_types.branch import BranchSkeleton 10 | from smart_tree.data_types.cloud import Cloud 11 | from smart_tree.data_types.tree import TreeSkeleton 12 | 13 | 14 | def unpackage_data(data: dict) -> Tuple[Cloud, TreeSkeleton]: 15 | tree_id = data["tree_id"] 16 | branch_id = data["branch_id"] 17 | branch_parent_id = data["branch_parent_id"] 18 | skeleton_xyz = data["skeleton_xyz"] 19 | skeleton_radii = data["skeleton_radii"] 20 | sizes = data["branch_num_elements"] 21 | 22 | cld = Cloud.from_numpy( 23 | xyz=data["xyz"], 24 | rgb=data["rgb"], 25 | vector=data["vector"], 26 | class_l=data["class_l"], 27 | ) 28 | 29 | return cld, None 30 | 31 | offsets = np.cumsum(np.append([0], sizes)) 32 | 33 | branch_idx = [np.arange(size) + offset for size, offset in zip(sizes, offsets)] 34 | branches = {} 35 | 36 | for idx, _id, parent_id in zip(branch_idx, branch_id, branch_parent_id): 37 | branches[_id] = BranchSkeleton( 38 | _id, parent_id, skeleton_xyz[idx], skeleton_radii[idx] 39 | ) 40 | 41 | return cld, TreeSkeleton(tree_id, branches) 42 | 43 | 44 | def package_data(skeleton: TreeSkeleton, pointcloud: Cloud) -> dict: 45 | data = {} 46 | 47 | data["tree_id"] = skeleton._id 48 | 49 | data["xyz"] = pointcloud.xyz 50 | data["rgb"] = pointcloud.rgb 51 | data["vector"] = pointcloud.vector 52 | data["class_l"] = pointcloud.class_l 53 | 54 | data["skeleton_xyz"] = np.concatenate( 55 | [branch.xyz for branch in skeleton.branches.values()] 56 | ) 57 | data["skeleton_radii"] = np.concatenate( 58 | [branch.radii for branch in skeleton.branches.values()] 59 | )[..., 
np.newaxis] 60 | data["branch_id"] = np.asarray( 61 | [branch._id for branch in skeleton.branches.values()] 62 | ) 63 | data["branch_parent_id"] = np.asarray( 64 | [branch.parent_id for branch in skeleton.branches.values()] 65 | ) 66 | data["branch_num_elements"] = np.asarray( 67 | [len(branch) for branch in skeleton.branches.values()] 68 | ) 69 | 70 | return data 71 | 72 | 73 | def save_skeleton(skeleton: TreeSkeleton, save_location): 74 | data = {} 75 | data["tree_id"] = skeleton._id 76 | 77 | data["skeleton_xyz"] = np.concatenate( 78 | [branch.xyz for branch in skeleton.branches.values()] 79 | ) 80 | data["skeleton_radii"] = np.concatenate( 81 | [branch.radii for branch in skeleton.branches.values()] 82 | )[..., np.newaxis] 83 | data["branch_id"] = np.asarray( 84 | [branch._id for branch in skeleton.branches.values()] 85 | ) 86 | data["branch_parent_id"] = np.asarray( 87 | [branch.parent_id for branch in skeleton.branches.values()] 88 | ) 89 | data["branch_num_elements"] = np.asarray( 90 | [len(branch) for branch in skeleton.branches.values()] 91 | ) 92 | 93 | np.savez(save_location, **data) 94 | 95 | 96 | def load_skeleton(path): 97 | data = np.load(path) 98 | 99 | # tree_id = data["tree_id"] 100 | branch_id = data["branch_id"] 101 | branch_parent_id = data["branch_parent_id"] 102 | skeleton_xyz = data["skeleton_xyz"] 103 | skeleton_radii = data["skeleton_radii"] 104 | sizes = data["branch_num_elements"] 105 | 106 | offsets = np.cumsum(np.append([0], sizes)) 107 | 108 | branch_idx = [np.arange(size) + offset for size, offset in zip(sizes, offsets)] 109 | branches = {} 110 | 111 | for idx, _id, parent_id in zip(branch_idx, branch_id, branch_parent_id): 112 | branches[_id] = BranchSkeleton( 113 | _id, parent_id, skeleton_xyz[idx], skeleton_radii[idx] 114 | ) 115 | 116 | return TreeSkeleton(0, branches) 117 | 118 | 119 | def save_data_npz(path: Path, skeleton: TreeSkeleton, pointcloud: Cloud): 120 | np.savez(path, **package_data(skeleton, pointcloud)) 121 | 122 | 
123 | def load_data_npz(path: Path) -> Tuple[Cloud, TreeSkeleton]: 124 | with np.load(path) as data: 125 | return unpackage_data(data) 126 | 127 | 128 | def load_json(json_path): 129 | return json.load(open(json_path)) 130 | 131 | 132 | def save_o3d_cloud(filename: Path, cld: Cloud): 133 | o3d.io.write_point_cloud(str(filename), cld) 134 | 135 | 136 | def save_o3d_lineset(path: Path, ls): 137 | return o3d.io.write_line_set(path, ls) 138 | 139 | 140 | def save_o3d_mesh(path: Path, mesh): 141 | return o3d.io.write_triangle_mesh(path, mesh) 142 | 143 | 144 | def load_o3d_cloud(path: Path): 145 | return o3d.io.read_point_cloud(path) 146 | 147 | 148 | def load_o3d_lineset(path: Path): 149 | return o3d.io.read_line_set(path) 150 | 151 | 152 | def load_o3d_mesh(path: Path): 153 | return o3d.io.read_triangle_model(path) 154 | 155 | 156 | def load_cloud(path: Path): 157 | if path.suffix == ".npz": 158 | return Cloud.from_numpy(**np.load(path)) 159 | 160 | data = o3d.io.read_point_cloud(str(path)) 161 | xyz = np.asarray(data.points) 162 | rgb = ( 163 | np.asarray(data.colors) 164 | if np.asarray(data.colors).shape[0] != 0 165 | else np.zeros_like(xyz) 166 | ) 167 | return Cloud.from_numpy(xyz=xyz, rgb=rgb) 168 | 169 | 170 | def load_yaml(path: Path): 171 | with open(f"{path}") as f: 172 | config = yaml.load(f, Loader=yaml.FullLoader) 173 | return config 174 | -------------------------------------------------------------------------------- /smart_tree/util/queries.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import List 4 | 5 | import numpy as np 6 | import torch 7 | from tqdm import tqdm 8 | 9 | from smart_tree.data_types.tube import CollatedTube, Tube, collate_tubes 10 | 11 | """ 12 | For the following : 13 | N : number of pts 14 | M : number of tubes 15 | """ 16 | 17 | 18 | def points_to_collated_tube_projections( 19 | pts: np.array, collated_tube: CollatedTube, eps=1e-12 20 
| ): # N x 3, M x 2 21 | ab = collated_tube.b - collated_tube.a # M x 3 22 | 23 | ap = pts[:, np.newaxis] - collated_tube.a[np.newaxis, ...] # N x M x 3 24 | 25 | t = np.clip( 26 | np.einsum("nmd,md->nm", ap, ab) / (np.einsum("md,md->m", ab, ab) + eps), 27 | 0.0, 28 | 1.0, 29 | ) # N x M 30 | proj = collated_tube.a[np.newaxis, ...] + np.einsum( 31 | "nm,md->nmd", t, ab 32 | ) # N x M x 3 33 | return proj, t 34 | 35 | 36 | def projection_to_distance_matrix(projections, pts): # N x M x 3 37 | return np.sqrt(np.sum(np.square(projections - pts[:, np.newaxis, :]), 2)) # N x M 38 | 39 | 40 | def pts_to_nearest_tube(pts: np.array, tubes: List[Tube]): 41 | """Vectors from pt to the nearest tube""" 42 | 43 | collated_tube = collate_tubes(tubes) 44 | projections, t = points_to_collated_tube_projections( 45 | pts, collated_tube 46 | ) # N x M x 3 47 | 48 | 49 | def pts_to_nearest_tube_keops(pts: np.array, tubes: List[Tube]): 50 | """Vectors from pt to the nearest tube""" 51 | 52 | distances, idx, r = points_to_tube_distance_keops(pts, tubes) # N x M x 3 53 | 54 | return distances.reshape(-1), idx, rce_matrix(projections, pts) # N x M 55 | 56 | distances = distances - r 57 | idx = np.argmin(distances, 1) # N 58 | 59 | # assert idx.shape[0] == pts.shape[0] 60 | 61 | return ( 62 | projections[np.arange(pts.shape[0]), idx] - pts, 63 | idx, 64 | r[np.arange(pts.shape[0]), idx], 65 | ) # vector, idx , radius 66 | 67 | 68 | def pairwise_pts_to_nearest_tube(pts: np.array, tubes: List[Tube]): 69 | collated_tube = collate_tubes(tubes) 70 | 71 | ab = collated_tube.b - collated_tube.a # M x 3 72 | ap = ( 73 | pts - collated_tube.a 74 | ) # N def pts_to_nearest_tube_keops(pts: np.array, tubes: List[Tube]): 75 | """Vectors from pt to the nearest tube""" 76 | 77 | distances, idx, r = points_to_tube_distance_keops(pts, tubes) # N x M x 3 78 | 79 | return distances.reshape(-1), idx, r + t * collated_tube.r2 80 | 81 | distances = np.sqrt(np.sum(np.square(proj - pts), 1)) 82 | 83 | distances = 
distances - r 84 | 85 | return distances, r # vector, idx , radius 86 | 87 | 88 | # GPU 89 | def points_to_collated_tube_projections_gpu( 90 | pts: np.array, collated_tube: CollatedTube, device=torch.device("cuda") 91 | ): 92 | ab = collated_tube.b - collated_tube.a # M x 3 93 | 94 | ap = pts.unsqueeze(1) - collated_tube.a.unsqueeze(0) # N x M x 3 95 | 96 | t = (torch.einsum("nmd,md->nm", ap, ab) / torch.einsum("md,md->m", ab, ab)).clip( 97 | 0.0, 1.0 98 | ) # N x M 99 | proj = collated_tube.a.unsqueeze(0) + torch.einsum("nm,md->nmd", t, ab) # N x M x 3 100 | return proj, t 101 | 102 | 103 | def projection_to_distance_matrix_gpu(projections, pts): # N x M x 3 104 | return (projections - pts.unsqueeze(1)).square().sum(2).sqrt() 105 | 106 | 107 | def pts_to_nearest_tube_gpu( 108 | pts: torch.tensor, tubes: List[Tube], device=torch.device("cuda") 109 | ): 110 | """Vectors from pt to the nearest tube""" 111 | 112 | collated_tube_gpu = collate_tubes(tubes) 113 | collated_tube_gpu.to_gpu() 114 | 115 | pts = pts.float().to(device) 116 | 117 | projections, t = points_to_collated_tube_projections_gpu( 118 | pts, collated_tube_gpu, device=torch.device("cuda") 119 | ) # N x M x 3 120 | r = (1 - t) * collated_tube_gpu.r1 + t * collated_tube_gpu.r2 121 | 122 | distances = projection_to_distance_matrix_gpu(projections, pts) # N x M 123 | 124 | distances = torch.abs(distances - r) 125 | idx = torch.argmin(distances, 1) # N 126 | 127 | assert idx.shape[0] == pts.shape[0] 128 | 129 | return ( 130 | projections[torch.arange(pts.shape[0]), idx] - pts, 131 | idx, 132 | r[torch.arange(pts.shape[0]), idx], 133 | ) 134 | 135 | 136 | def projection_to_distance_matrix_keops(projections, pts): # N x M x 3 137 | return np.sqrt(np.sum(np.square(projections - pts[:, np.newaxis, :]), 2)) # N x M 138 | 139 | 140 | def skeleton_to_points(pcd, skeleton, chunk_size=4096, device="gpu"): 141 | distances = [] 142 | radii = [] 143 | vectors_ = [] 144 | 145 | tubes = skeleton.to_tubes() 146 | pts_chunks 
= np.array_split(pcd.xyz, np.ceil(pcd.xyz.shape[0] / chunk_size)) 147 | 148 | for pts in tqdm(pts_chunks, desc="Labelling Chunks", leave=False): 149 | if device == "gpu": 150 | vectors, idxs, radiuses = pts_to_nearest_tube_gpu( 151 | pts, tubes 152 | ) # vector to nearest skeleton... 153 | else: 154 | vectors, idxs, radiuses = pts_to_nearest_tube( 155 | pts, tubes 156 | ) # vector to nearest skeleton... 157 | 158 | distances.append( 159 | np.sqrt(np.einsum("ij,ij->i", vectors, vectors)) 160 | ) # could do on gpu but meh 161 | radii.append([radius for radius in radiuses]) 162 | vectors_.append([v for v in vectors]) 163 | 164 | distances = np.concatenate(distances) 165 | radii = np.concatenate(radii) 166 | vectors_ = np.concatenate(vectors_) 167 | 168 | return distances, radii, vectors_ 169 | -------------------------------------------------------------------------------- /smart_tree/util/maths.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | 8 | def np_normalized(a: np.array, axis=-1, order=2) -> np.array: 9 | """Normalizes a numpy array of points""" 10 | l2 = np.atleast_1d(np.linalg.norm(a, order, axis)) 11 | l2[l2 == 0] = 1e-13 12 | return a / np.expand_dims(l2, axis), l2 13 | 14 | 15 | def torch_normalized(v): 16 | return F.normalize(v), v.pow(2).sum(1).sqrt().unsqueeze(1) 17 | 18 | 19 | def euler_angles_to_rotation(xyz: List) -> torch.tensor: 20 | x, y, z = xyz 21 | 22 | R_X = torch.tensor( 23 | [ 24 | [1.0, 0.0, 0.0], 25 | [0.0, torch.cos(x), -torch.sin(x)], 26 | [0.0, torch.sin(x), torch.cos(x)], 27 | ] 28 | ) 29 | 30 | R_Y = torch.tensor( 31 | [ 32 | [torch.cos(y), 0.0, torch.sin(y)], 33 | [0.0, 1.0, 0.0], 34 | [-torch.sin(y), 0.0, torch.cos(y)], 35 | ] 36 | ) 37 | 38 | R_Z = torch.tensor( 39 | [ 40 | [torch.cos(z), -torch.sin(z), 0.0], 41 | [torch.sin(z), torch.cos(z), 0.0], 42 | [0.0, 0.0, 1.0], 43 | ] 44 | 
) 45 | 46 | return torch.mm(R_Z, torch.mm(R_Y, R_X)) 47 | 48 | 49 | def rotation_matrix_from_vectors_np(vec1: np.array, vec2: np.array) -> np.array: 50 | """Find the rotation matrix that aligns vec1 to vec2 51 | :param vec1: A 3d "source" vector 52 | :param vec2: A 3d "destination" vector 53 | :return mat: A transform matrix (3x3) which when applied to vec1, aligns it with vec2. 54 | """ 55 | a, b = (vec1 / np.linalg.norm(vec1)).reshape(3), ( 56 | vec2 / np.linalg.norm(vec2) 57 | ).reshape(3) 58 | v = np.cross(a, b) 59 | c = np.dot(a, b) 60 | s = np.linalg.norm(v) 61 | 62 | kmat = np.array([[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]]) 63 | rotation_matrix = np.eye(3) + kmat + kmat.dot(kmat) * ((1 - c) / (s**2)) 64 | return rotation_matrix 65 | 66 | 67 | def rotation_matrix_from_vectors_torch(vec1, vec2): 68 | a, b = F.normalize(vec1, dim=0), F.normalize(vec2, dim=0) 69 | v = torch.cross(a, b) 70 | c = torch.dot(a, b) 71 | s = torch.linalg.norm(v) 72 | 73 | kmat = torch.tensor([[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]]) 74 | rotation_matrix = torch.eye(3) + kmat + kmat.matmul(kmat) * ((1 - c) / (s**2)) 75 | 76 | return rotation_matrix 77 | 78 | 79 | def make_transformation_matrix(rotation, translation): 80 | return torch.vstack( 81 | (torch.hstack((rotation, translation)), torch.tensor([0.0, 0.0, 0.0, 1.0])) 82 | ) 83 | 84 | 85 | def bb_filter( 86 | points, 87 | min_x=-np.inf, 88 | max_x=np.inf, 89 | min_y=-np.inf, 90 | max_y=np.inf, 91 | min_z=-np.inf, 92 | max_z=np.inf, 93 | ): 94 | bound_x = np.logical_and(points[:, 0] >= min_x, points[:, 0] < max_x) 95 | bound_y = np.logical_and(points[:, 1] >= min_y, points[:, 1] < max_y) 96 | bound_z = np.logical_and(points[:, 2] >= min_z, points[:, 2] < max_z) 97 | 98 | bb_filter = np.logical_and(np.logical_and(bound_x, bound_y), bound_z) 99 | 100 | return bb_filter 101 | 102 | 103 | """ 104 | def cube_filter(points, center, cube_size): 105 | 106 | min_x = center[0] - cube_size / 2 107 | max_x = 
center[0] + cube_size / 2 108 | min_y = center[1] - cube_size / 2 109 | max_y = center[1] + cube_size / 2 110 | min_z = center[2] - cube_size / 2 111 | max_z = center[2] + cube_size / 2 112 | 113 | return bb_filter(points, min_x, max_x, min_y, max_y, min_z, max_z) 114 | """ 115 | 116 | 117 | def np_bb_filter( 118 | points, 119 | min_x=-np.inf, 120 | max_x=np.inf, 121 | min_y=-np.inf, 122 | max_y=np.inf, 123 | min_z=-np.inf, 124 | max_z=np.inf, 125 | ): 126 | bound_x = np.logical_and(points[:, 0] >= min_x, points[:, 0] < max_x) 127 | bound_y = np.logical_and(points[:, 1] >= min_y, points[:, 1] < max_y) 128 | bound_z = np.logical_and(points[:, 2] >= min_z, points[:, 2] < max_z) 129 | 130 | bb_filter = np.logical_and(np.logical_and(bound_x, bound_y), bound_z) 131 | 132 | return bb_filter 133 | 134 | 135 | def torch_bb_filter(points, min_x, max_x, min_y, max_y, min_z, max_z): 136 | bound_x = torch.logical_and(points[:, 0] >= min_x, points[:, 0] < max_x) 137 | bound_y = torch.logical_and(points[:, 1] >= min_y, points[:, 1] < max_y) 138 | bound_z = torch.logical_and(points[:, 2] >= min_z, points[:, 2] < max_z) 139 | 140 | bb_filter = torch.logical_and(torch.logical_and(bound_x, bound_y), bound_z) 141 | 142 | return bb_filter 143 | 144 | 145 | def cube_filter(points, center, cube_size): 146 | min = center - (cube_size / 2) 147 | max = center + (cube_size / 2) 148 | 149 | if type(center) == np.array: 150 | return np_bb_filter(points, min[0], max[0], min[1], max[1], min[2], max[2]) 151 | 152 | max = max.to(points.device) 153 | min = min.to(points.device) 154 | 155 | return torch_bb_filter(points, min[0], max[0], min[1], max[1], min[2], max[2]) 156 | 157 | 158 | def vertex_dirs(points): 159 | d = points[1:] - points[:-1] 160 | d = d / np.linalg.norm(d) 161 | 162 | smooth = (d[1:] + d[:-1]) * 0.5 163 | dirs = np.concatenate([np.array(d[0:1]), smooth, np.array(d[-2:-1])]) 164 | 165 | return dirs / np.linalg.norm(dirs, axis=1, keepdims=True) 166 | 167 | 168 | def 
random_unit(dtype=np.float32): 169 | x = np.random.randn(3).astype(dtype) 170 | return x / np.linalg.norm(x) 171 | 172 | 173 | def make_tangent(d, n): 174 | t = np.cross(d, n) 175 | t /= np.linalg.norm(t, axis=-1, keepdims=True) 176 | return np.cross(t, d) 177 | 178 | 179 | def gen_tangents(dirs, t): 180 | tangents = [] 181 | 182 | for dir in dirs: 183 | t = make_tangent(dir, t) 184 | tangents.append(t) 185 | 186 | return np.stack(tangents) 187 | -------------------------------------------------------------------------------- /smart_tree/o3d_abstractions/geometries.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import open3d as o3d 3 | import torch 4 | 5 | # from o3d_abstractions.viewimport o3d_view_geometries 6 | from smart_tree.util.maths import (gen_tangents, random_unit, 7 | vertex_dirs) 8 | 9 | 10 | def o3d_mesh(verts, tris): 11 | return o3d.geometry.TriangleMesh( 12 | o3d.utility.Vector3dVector(verts), o3d.utility.Vector3iVector(tris) 13 | ).compute_triangle_normals() 14 | 15 | 16 | def o3d_merge_meshes(meshes, colourize=False): 17 | if colourize: 18 | for m in meshes: 19 | m.paint_uniform_color(np.random.rand(3)) 20 | 21 | mesh = meshes[0] 22 | for m in meshes[1:]: 23 | mesh += m 24 | return mesh 25 | 26 | 27 | def o3d_merge_linesets(line_sets, colour=[0, 0, 0]): 28 | sizes = [np.asarray(ls.points).shape[0] for ls in line_sets] 29 | offsets = np.cumsum([0] + sizes) 30 | 31 | points = np.concatenate([ls.points for ls in line_sets]) 32 | idxs = np.concatenate([ls.lines + offset for ls, offset in zip(line_sets, offsets)]) 33 | 34 | return o3d.geometry.LineSet( 35 | o3d.utility.Vector3dVector(points), o3d.utility.Vector2iVector(idxs) 36 | ).paint_uniform_color(colour) 37 | 38 | 39 | def points_to_edge_idx(points): 40 | idx = torch.arange(points.shape[0] - 1) 41 | return torch.column_stack((idx, idx + 1)) 42 | 43 | 44 | def o3d_sphere(xyz, radius, colour=(1, 0, 0)): 45 | return ( 46 | 
o3d.geometry.TriangleMesh.create_sphere(radius) 47 | .translate(np.asarray(xyz)) 48 | .paint_uniform_color(colour) 49 | ) 50 | 51 | 52 | def o3d_spheres(xyzs, radii, colour=None, colours=None): 53 | spheres = [o3d_sphere(xyz, r) for xyz, r in zip(xyzs, radii)] 54 | return paint_o3d_geoms(spheres, colour, colours) 55 | 56 | 57 | def o3d_line_set(vertices, edges, colour=None): 58 | line_set = o3d.geometry.LineSet( 59 | o3d.utility.Vector3dVector(vertices), o3d.utility.Vector2iVector(edges) 60 | ) 61 | if colour is not None: 62 | return line_set.paint_uniform_color(colour) 63 | return line_set 64 | 65 | 66 | def o3d_line_sets(vertices, edges): 67 | return [o3d_line_set(v, e) for v, e in zip(vertices, edges)] 68 | 69 | 70 | def o3d_path(vertices, colour=None): 71 | edge_idx = points_to_edge_idx(vertices) 72 | if colour is not None: 73 | return o3d_line_set(vertices, edge_idx).paint_uniform_color(colour) 74 | return o3d_line_set(vertices, edge_idx) 75 | 76 | 77 | def o3d_paths(vertices): 78 | return [o3d_path(v) for v in vertices] 79 | 80 | 81 | def o3d_merge_clouds(points_clds): 82 | points = np.concatenate([np.asarray(pcd.points) for pcd in points_clds]) 83 | colors = np.concatenate([np.asarray(pcd.colors) for pcd in points_clds]) 84 | 85 | return o3d_cloud(points, colours=colors) 86 | 87 | 88 | def o3d_cloud(points, colour=None, colours=None, normals=None): 89 | cloud = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(points)) 90 | if normals is not None: 91 | cloud.normals = o3d.utility.Vector3dVector(normals) 92 | if colour is not None: 93 | return cloud.paint_uniform_color(colour) 94 | elif colours is not None: 95 | cloud.colors = o3d.utility.Vector3dVector(colours) 96 | return cloud 97 | 98 | return cloud.paint_uniform_color([1, 0, 0]) 99 | 100 | 101 | def class_label_o3d_cloud(cloud, class_labels, cmap=[]): 102 | cloud.colors = o3d.utility.Vector3dVector(np.asarray(cmap)[class_labels]) 103 | return cloud 104 | 105 | 106 | def o3d_clouds(batch_points, 
colour=None, colours=None, p_colours=None): 107 | # KN# 108 | # colour -> paint all the same colour 109 | # colours -> colour for each point cloud 110 | # p_colours -> points for each point in each cloud 111 | if colour is not None: 112 | return [o3d_cloud(p, colour) for p in zip(batch_points)] 113 | 114 | if colours is not None: 115 | return [o3d_cloud(p, c) for p, c in zip(batch_points, colours)] 116 | 117 | if p_colours is not None: 118 | return [o3d_cloud(p, colours=c) for p, c in zip(batch_points, p_colours)] 119 | 120 | return [o3d_cloud(p, n) for p, n in zip(batch_points)] 121 | 122 | 123 | def o3d_voxel_grid( 124 | width: float, 125 | depth: float, 126 | height: float, 127 | voxel_size: float, 128 | origin=np.asarray([0, 0, 0]), 129 | colour=np.asarray([1, 1, 0]), 130 | ): 131 | return o3d.geometry.VoxelGrid.create_dense( 132 | origin, colour, voxel_size, width, depth, height 133 | ) 134 | 135 | 136 | def o3d_cylinder(radius, length, colour=(1, 0, 0)): 137 | return o3d.geometry.TriangleMesh.create_cylinder( 138 | radius, length 139 | ).paint_uniform_color(colour) 140 | 141 | 142 | def o3d_cylinders(radii, length, colour=None, colours=None): 143 | cylinders = [o3d_cylinder(r, l, colour) for r, l in zip(radii, length)] 144 | return paint_o3d_geoms(cylinders, colour, colours) 145 | 146 | 147 | def paint_o3d_geoms(geometries, colour=None, colours=None): 148 | if colours is not None: 149 | return [ 150 | geom.paint_uniform_color(colours[i]) for i, geom in enumerate(geometries) 151 | ] 152 | elif colour is not None: 153 | return [geom.paint_uniform_color(colour) for geom in geometries] 154 | return geometries 155 | 156 | 157 | def unit_circle(n): 158 | a = np.linspace(0, 2 * np.pi, n + 1)[:-1] 159 | return np.stack([np.sin(a), np.cos(a)], axis=1) 160 | 161 | 162 | def cylinder_triangles(m, n): 163 | tri1 = np.array([0, 1, 2]) 164 | tri2 = np.array([2, 3, 0]) 165 | 166 | v0 = np.arange(m) 167 | v1 = (v0 + 1) % m 168 | v2 = v1 + m 169 | v3 = v0 + m 170 | 171 | 
edges = np.stack([v0, v1, v2, v3], axis=1) 172 | 173 | segments = np.arange(n - 1) * m 174 | edges = edges.reshape(1, *edges.shape) + segments.reshape(n - 1, 1, 1) 175 | 176 | edges = edges.reshape(-1, 4) 177 | return np.concatenate([edges[:, tri1], edges[:, tri2]]) 178 | 179 | 180 | def tube_vertices(points, radii, n=10): 181 | circle = unit_circle(n).astype(np.float32) 182 | 183 | dirs = vertex_dirs(points) 184 | t = gen_tangents(dirs, random_unit()) 185 | 186 | b = np.stack([t, np.cross(t, dirs)], axis=1) 187 | b = b * radii.reshape(-1, 1, 1) 188 | 189 | return np.einsum("bdx,md->bmx", b, circle) + points.reshape(points.shape[0], 1, 3) 190 | 191 | 192 | def o3d_lines_between_clouds(cld1, cld2): 193 | pts1 = np.asarray(cld1.points) 194 | pts2 = np.asarray(cld2.points) 195 | 196 | interweaved = np.hstack((pts1, pts2)).reshape(-1, 3) 197 | return o3d_line_set( 198 | interweaved, np.arange(0, min(pts1.shape[0], pts2.shape[0]) * 2).reshape(-1, 2) 199 | ) 200 | 201 | 202 | def o3d_tube_mesh(points, radii, colour=(1, 0, 0), n=10): 203 | points = tube_vertices(points, radii, n) 204 | 205 | n, m, _ = points.shape 206 | indexes = cylinder_triangles(m, n) 207 | 208 | mesh = o3d_mesh(points.reshape(-1, 3), indexes) 209 | mesh.compute_vertex_normals() 210 | 211 | return mesh.paint_uniform_color(colour) 212 | 213 | 214 | def sample_o3d_lineset(lineset, sample_rate): 215 | edges = np.asarray(lineset.lines) 216 | xyz = np.asarray(lineset.points) 217 | 218 | pts, radius = [], [] 219 | 220 | for i, edge in enumerate(edges): 221 | start = xyz[edge[0]] 222 | end = xyz[edge[1]] 223 | 224 | v = end - start 225 | length = np.linalg.norm(v) 226 | direction = v / length 227 | num_points = np.ceil(length / sample_rate) 228 | 229 | if int(num_points) > 0.0: 230 | spaced_points = np.arange( 231 | 0, float(length), step=float(length / num_points) 232 | ).reshape(-1, 1) 233 | pts.append(start + direction * spaced_points) 234 | 235 | return np.concatenate(pts, axis=0) 236 | 
-------------------------------------------------------------------------------- /smart_tree/model/train.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import logging 3 | import math 4 | import os 5 | from pathlib import Path 6 | from typing import List 7 | 8 | import hydra 9 | import numpy as np 10 | import torch 11 | from hydra.core.hydra_config import HydraConfig 12 | from hydra.utils import instantiate 13 | from omegaconf import DictConfig, OmegaConf 14 | from tqdm import tqdm 15 | 16 | import wandb 17 | from smart_tree.data_types.cloud import Cloud 18 | from smart_tree.o3d_abstractions.camera import Renderer 19 | 20 | from .helper import get_batch, model_output_to_labelled_clds 21 | from .tracker import Tracker 22 | 23 | 24 | def train_epoch( 25 | data_loader, 26 | model, 27 | optimizer, 28 | loss_fn, 29 | fp16=False, 30 | scaler=None, 31 | device=torch.device("cuda"), 32 | ): 33 | tracker = Tracker() 34 | 35 | for sp_input, targets, mask, fn in tqdm( 36 | get_batch(data_loader, device, fp16), 37 | desc="Batch", 38 | leave=False, 39 | total=math.ceil(len(data_loader) / data_loader.batch_size), 40 | ): 41 | preds = model.forward(sp_input) 42 | 43 | loss = loss_fn(preds, targets, mask) 44 | 45 | if fp16: 46 | assert sum(loss.values()).dtype is torch.float32 47 | scaler.scale(sum(loss.values())).backward() 48 | scaler.step(optimizer) 49 | scaler.update() 50 | scale = scaler.get_scale() 51 | else: 52 | (sum(loss.values())).backward() 53 | optimizer.step() 54 | 55 | optimizer.zero_grad() 56 | tracker.update(loss) 57 | 58 | return tracker, scaler 59 | 60 | 61 | @torch.no_grad() 62 | def eval_epoch( 63 | data_loader, 64 | model, 65 | loss_fn, 66 | fp16=False, 67 | device=torch.device("cuda"), 68 | ): 69 | tracker = Tracker() 70 | model.eval() 71 | 72 | for sp_input, targets, mask, fn in tqdm( 73 | get_batch(data_loader, device, fp16), 74 | desc="Evaluating", 75 | leave=False, 76 | 
total=math.ceil(len(data_loader) / data_loader.batch_size), 77 | ): 78 | preds = model.forward(sp_input) 79 | loss = loss_fn(preds, targets, mask) 80 | tracker.update(loss) 81 | model.train() 82 | 83 | return tracker 84 | 85 | 86 | @torch.no_grad() 87 | def capture_epoch( 88 | data_loader, 89 | model, 90 | cmap, 91 | fp16=False, 92 | device=torch.device("cuda"), 93 | ): 94 | model.eval() 95 | captures = [] 96 | 97 | for sp_input, targets, mask, fn in tqdm( 98 | get_batch(data_loader, device, fp16), 99 | desc="Capturing Outputs", 100 | leave=False, 101 | total=math.ceil(len(data_loader) / data_loader.batch_size), 102 | ): 103 | model_output = model.forward(sp_input) 104 | 105 | labelled_clouds = model_output_to_labelled_clds( 106 | sp_input, 107 | model_output, 108 | cmap, 109 | fn, 110 | ) 111 | 112 | model.train() 113 | return labelled_clouds 114 | 115 | 116 | @torch.no_grad() 117 | def capture_clouds( 118 | data_loader, 119 | model, 120 | cmap, 121 | fp16=False, 122 | device=torch.device("cuda"), 123 | ) -> List[Cloud]: 124 | model.eval() 125 | clouds = [] 126 | 127 | for sp_input, targets, mask, filenames in tqdm( 128 | get_batch(data_loader, device, fp16), 129 | desc="Capturing Outputs", 130 | leave=False, 131 | total=math.ceil(len(data_loader) / data_loader.batch_size), 132 | ): 133 | model_output = model.forward(sp_input) 134 | clouds.extend( 135 | model_output_to_labelled_clds( 136 | sp_input, 137 | model_output, 138 | cmap, 139 | filenames, 140 | ) 141 | ) 142 | 143 | model.train() 144 | return clouds 145 | 146 | 147 | def capture_and_log(loader, model, epoch, wandb_run, cfg): 148 | clouds = capture_clouds( 149 | loader, 150 | model, 151 | cfg.cmap, 152 | fp16=cfg.fp16, 153 | ) 154 | 155 | for cloud in tqdm(clouds, desc="Uploading Clouds", leave=False): 156 | seg_cloud = cloud.to_o3d_seg_cld(np.asarray(cfg.cmap)) 157 | xyz_rgb = np.concatenate( 158 | (np.asarray(seg_cloud.points), np.asarray(seg_cloud.colors) * 255), -1 159 | ) 160 | wandb_run.log( 161 | 
{f"{Path(cloud.filename).stem}": wandb.Object3D(xyz_rgb)}, 162 | step=epoch, 163 | ) 164 | 165 | 166 | @hydra.main( 167 | version_base=None, 168 | config_path="../conf", 169 | config_name="training.yaml", 170 | ) 171 | def main(cfg: DictConfig): 172 | torch.manual_seed(42) 173 | torch.cuda.manual_seed_all(42) 174 | log = logging.getLogger(__name__) 175 | 176 | wandb.init( 177 | project=cfg.wandb.project, 178 | entity=cfg.wandb.entity, 179 | mode=cfg.wandb.mode, 180 | config=OmegaConf.to_container(cfg, resolve=[True | False]), 181 | ) 182 | run_dir = HydraConfig.get().runtime.output_dir 183 | run_name = wandb.run.name 184 | log.info(f"Directory : {run_dir}") 185 | log.info(f"Machine: {os.uname()[1]}") 186 | 187 | renderer = Renderer(960, 540) 188 | 189 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 190 | 191 | train_loader = instantiate(cfg.train_data_loader) 192 | val_loader = instantiate(cfg.validation_data_loader) 193 | test_loader = instantiate(cfg.test_data_loader) 194 | 195 | log.info(f"Train Dataset Size: {len(train_loader.dataset)}") 196 | log.info(f"Validation Dataset Size: {len(val_loader.dataset)}") 197 | log.info(f"Test Dataset Size: {len(test_loader.dataset)}") 198 | 199 | # Model 200 | model = instantiate(cfg.model).to(device).train() 201 | torch.save(model, f"{run_dir}/{run_name}_model.pt") 202 | 203 | # Optimizer / Scheduler 204 | optimizer = instantiate(cfg.optimizer, params=model.parameters()) 205 | scheduler = instantiate(cfg.scheduler, optimizer=optimizer) 206 | loss_fn = instantiate(cfg.loss_fn) 207 | 208 | # FP-16 209 | amp_ctx = torch.cuda.amp.autocast() if cfg.fp16 else contextlib.nullcontext() 210 | scaler = torch.cuda.amp.grad_scaler.GradScaler() 211 | 212 | epochs_no_improve = 0 213 | best_val_loss = torch.inf 214 | 215 | # Epochs 216 | for epoch in tqdm(range(0, cfg.num_epoch), leave=True, desc="Epoch"): 217 | with amp_ctx: 218 | training_tracker, scaler = train_epoch( 219 | train_loader, 220 | model, 221 | 
optimizer, 222 | loss_fn, 223 | scaler=scaler, 224 | fp16=cfg.fp16, 225 | ) 226 | 227 | val_tracker = eval_epoch( 228 | val_loader, 229 | model, 230 | loss_fn, 231 | fp16=cfg.fp16, 232 | ) 233 | 234 | test_tracker = eval_epoch( 235 | test_loader, 236 | model, 237 | loss_fn, 238 | fp16=cfg.fp16, 239 | ) 240 | 241 | if (epoch + 1) % cfg.capture_output == 0: 242 | capture_and_log(test_loader, model, epoch, wandb.run, cfg) 243 | capture_and_log(val_loader, model, epoch, wandb.run, cfg) 244 | 245 | scheduler.step(val_tracker.total_loss) if cfg.lr_decay else None 246 | 247 | # Save Best Model 248 | if val_tracker.total_loss < best_val_loss: 249 | epochs_no_improve = 0 250 | best_val_loss = val_tracker.total_loss 251 | wandb.run.summary["Best Test Loss"] = best_val_loss 252 | torch.save(model.state_dict(), f"{run_dir}/{run_name}_model_weights.pt") 253 | log.info(f"Weights Saved at epoch: {epoch}") 254 | else: 255 | epochs_no_improve += 1 256 | 257 | if epochs_no_improve == cfg.early_stop_epoch and cfg.early_stop: 258 | log.info("Training Ended (Evaluation Test Score Not Improving)") 259 | break 260 | 261 | # log onto wandb... 
262 | training_tracker.log("Training", epoch) 263 | val_tracker.log("Validation", epoch) 264 | 265 | 266 | if __name__ == "__main__": 267 | main() 268 | -------------------------------------------------------------------------------- /smart_tree/dataset/dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | from typing import List 4 | 5 | import torch 6 | import torch.utils.data 7 | from spconv.pytorch.utils import PointToVoxel 8 | from torch.utils.data import DataLoader 9 | from tqdm import tqdm 10 | 11 | from ..data_types.cloud import Cloud 12 | from ..model.sparse import batch_collate 13 | from ..util.file import load_cloud 14 | from ..util.maths import cube_filter 15 | from ..util.misc import at_least_2d 16 | 17 | 18 | class TreeDataset: 19 | def __init__( 20 | self, 21 | voxel_size: int, 22 | json_path: Path, 23 | directory: Path, 24 | mode: str, 25 | input_features: List[str], 26 | target_features: List[str], 27 | augmentation=None, 28 | cache: bool = False, 29 | device=torch.device("cuda:0"), 30 | ): 31 | self.voxel_size = voxel_size 32 | self.mode = mode 33 | self.augmentation = augmentation 34 | self.directory = directory 35 | self.device = device 36 | 37 | self.input_features = input_features 38 | self.target_features = target_features 39 | 40 | assert Path( 41 | json_path 42 | ).is_file(), f"json metadata does not exist at '{json_path}'" 43 | json_data = json.load(open(json_path)) 44 | 45 | if self.mode == "train": 46 | self.tree_paths = json_data["train"] 47 | elif self.mode == "validation": 48 | self.tree_paths = json_data["validation"] 49 | elif self.mode == "test": 50 | self.tree_paths = json_data["test"] 51 | 52 | missing = [ 53 | path 54 | for path in self.tree_paths 55 | if not Path(f"{self.directory}/{path}").is_file() 56 | ] 57 | 58 | assert len(missing) == 0, f"Missing {len(missing)} files: {missing}" 59 | 60 | self.cache = {} if cache else None 61 | 
self.load_cloud = load_cloud if self.mode != "test" else load_cloud 62 | 63 | def load(self, filename) -> Cloud: 64 | if self.cache is None: 65 | return self.load_cloud(filename) 66 | 67 | if filename not in self.cache: 68 | self.cache[filename] = self.load_cloud(filename).pin_memory() 69 | 70 | return self.cache[filename] 71 | 72 | def __getitem__(self, idx): 73 | filename = Path(f"{self.directory}/{self.tree_paths[idx]}") 74 | cld = self.load(filename) 75 | 76 | try: 77 | return self.process_cloud(cld, self.tree_paths[idx]) 78 | except Exception: 79 | print(f"Exception processing {filename} with {len(cld)} points") 80 | raise 81 | 82 | def process_cloud(self, cld: Cloud, filename: str): 83 | cld = cld.to_device(self.device) 84 | 85 | if self.augmentation != None: 86 | cld = self.augmentation(cld) 87 | 88 | xyzmin, _ = torch.min(cld.xyz, axis=0) 89 | xyzmax, _ = torch.max(cld.xyz, axis=0) 90 | 91 | # data = cld.cat() 92 | input_features = torch.cat( 93 | [at_least_2d(getattr(cld, attr)) for attr in self.input_features], dim=1 94 | ) 95 | 96 | target_features = torch.cat( 97 | [at_least_2d(getattr(cld, attr)) for attr in self.target_features], dim=1 98 | ) 99 | 100 | data = torch.cat([input_features, target_features], dim=1) 101 | 102 | assert ( 103 | data.shape[0] > 0 104 | ), f"Empty cloud after augmentation: {self.tree_paths[idx]}" 105 | 106 | surface_voxel_generator = PointToVoxel( 107 | vsize_xyz=[self.voxel_size] * 3, 108 | coors_range_xyz=[ 109 | xyzmin[0], 110 | xyzmin[1], 111 | xyzmin[2], 112 | xyzmax[0], 113 | xyzmax[1], 114 | xyzmax[2], 115 | ], 116 | num_point_features=data.shape[1], 117 | max_num_voxels=data.shape[0], 118 | max_num_points_per_voxel=1, 119 | device=data.device, 120 | ) 121 | 122 | feats, coords, _, _ = surface_voxel_generator.generate_voxel_with_id(data) 123 | 124 | indice = torch.zeros( 125 | (coords.shape[0], 1), 126 | dtype=coords.dtype, 127 | device=feats.device, 128 | ) 129 | coords = torch.cat((indice, coords), dim=1) 130 | 131 | 
feats = feats.squeeze(1) 132 | coords = coords.squeeze(1) 133 | loss_mask = torch.ones(feats.shape[0], dtype=torch.bool, device=feats.device) 134 | 135 | input_feats = feats[:, : input_features.shape[1]] 136 | target_feats = feats[:, input_features.shape[1] :] 137 | 138 | return (input_feats, target_feats), coords, loss_mask, filename 139 | 140 | def __len__(self): 141 | return len(self.tree_paths) 142 | 143 | 144 | class SingleTreeInference: 145 | def __init__( 146 | self, 147 | cloud: Cloud, 148 | voxel_size: float, 149 | block_size: float = 4, 150 | buffer_size: float = 0.4, 151 | min_points=20, 152 | file_name=None, 153 | device=torch.device("cuda:0"), 154 | ): 155 | self.cloud = cloud 156 | 157 | self.voxel_size = voxel_size 158 | self.block_size = block_size 159 | self.buffer_size = buffer_size 160 | self.min_points = min_points 161 | self.device = device 162 | self.file_name = file_name 163 | 164 | self.compute_blocks() 165 | 166 | def compute_blocks(self): 167 | self.xyz_quantized = torch.div( 168 | self.cloud.xyz, self.block_size, rounding_mode="floor" 169 | ) 170 | self.block_ids, pnt_counts = torch.unique( 171 | self.xyz_quantized, return_counts=True, dim=0 172 | ) 173 | 174 | # Remove blocks that have less than specified amount of points... 
175 | self.block_ids = self.block_ids[pnt_counts > self.min_points] 176 | self.block_centres = (self.block_ids * self.block_size) + (self.block_size / 2) 177 | 178 | self.clouds: List[Cloud] = [] 179 | 180 | for centre in tqdm(self.block_centres, desc="Computing blocks...", leave=False): 181 | mask = cube_filter( 182 | self.cloud.xyz, 183 | centre, 184 | self.block_size + (self.buffer_size * 2), 185 | ) 186 | block_cloud = self.cloud.filter(mask).to_device(torch.device("cpu")) 187 | 188 | self.clouds.append(block_cloud) 189 | 190 | self.block_centres = self.block_centres.to(torch.device("cpu")) 191 | 192 | def __getitem__(self, idx): 193 | block_centre = self.block_centres[idx] 194 | cloud: Cloud = self.clouds[idx] 195 | 196 | xyzmin, _ = torch.min(cloud.xyz, axis=0) 197 | xyzmax, _ = torch.max(cloud.xyz, axis=0) 198 | 199 | surface_voxel_generator = PointToVoxel( 200 | vsize_xyz=[self.voxel_size] * 3, 201 | coors_range_xyz=[ 202 | xyzmin[0], 203 | xyzmin[1], 204 | xyzmin[2], 205 | xyzmax[0], 206 | xyzmax[1], 207 | xyzmax[2], 208 | ], 209 | num_point_features=6, 210 | max_num_voxels=len(cloud), 211 | max_num_points_per_voxel=1, 212 | ) 213 | 214 | feats, coords, _, voxel_id_tv = surface_voxel_generator.generate_voxel_with_id( 215 | torch.cat((cloud.xyz, cloud.rgb), dim=1).contiguous() 216 | ) # 217 | 218 | indice = torch.zeros((coords.shape[0], 1), dtype=torch.int32) 219 | coords = torch.cat((indice, coords), dim=1) 220 | 221 | feats = feats.squeeze(1) 222 | coords = coords.squeeze(1) 223 | 224 | mask = cube_filter(feats[:, :3], block_centre, self.block_size) 225 | 226 | return feats, coords, mask, self.file_name 227 | 228 | def __len__(self): 229 | return len(self.clouds) 230 | 231 | 232 | def load_dataloader( 233 | cloud: Cloud, 234 | voxel_size: float, 235 | block_size: float, 236 | buffer_size: float, 237 | num_workers: float, 238 | batch_size: float, 239 | ): 240 | dataset = SingleTreeInference(cloud, voxel_size, block_size, buffer_size) 241 | 242 | return 
DataLoader(dataset, batch_size, num_workers, collate_fn=batch_collate) 243 | -------------------------------------------------------------------------------- /smart_tree/data_types/tree.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from dataclasses import dataclass 3 | from typing import Dict, List 4 | 5 | import open3d as o3d 6 | import torch 7 | import torch.nn.functional as F 8 | from torchtyping import TensorType 9 | from tqdm import tqdm 10 | 11 | from ..o3d_abstractions.geometries import (o3d_merge_linesets, 12 | o3d_merge_meshes) 13 | from ..o3d_abstractions.visualizer import o3d_viewer 14 | from ..util.misc import flatten_list, merge_dictionaries 15 | from ..util.queries import pts_to_nearest_tube_gpu 16 | from .branch import BranchSkeleton 17 | from .tube import Tube 18 | 19 | 20 | @dataclass 21 | class TreeSkeleton: 22 | _id: int 23 | branches: Dict[int, BranchSkeleton] 24 | 25 | def __post_init__(self): 26 | self.colour = torch.rand(3) 27 | 28 | def __len__(self) -> int: 29 | return len(self.branches) 30 | 31 | def __str__(self) -> str: 32 | return f"Tree Skeleton ({self._id}) has {len(self)} branches..." 
33 | 34 | def to_o3d_linesets(self) -> List[o3d.geometry.LineSet]: 35 | return [b.to_o3d_lineset() for b in self.branches.values()] 36 | 37 | def to_o3d_lineset(self, colour=(0, 0, 0)) -> o3d.geometry.LineSet: 38 | return o3d_merge_linesets(self.to_o3d_linesets(), colour=colour) 39 | 40 | def to_o3d_tubes(self) -> List[o3d.geometry.TriangleMesh]: 41 | return [b.to_o3d_tube() for b in self.branches.values()] 42 | 43 | def to_o3d_tube(self) -> o3d.geometry.TriangleMesh: 44 | return o3d_merge_meshes(self.to_o3d_tubes()) 45 | 46 | def view(self): 47 | o3d_viewer([self.to_o3d_lineset(), self.to_o3d_tube()]) 48 | 49 | def to_tubes(self) -> List[Tube]: 50 | return flatten_list([b.to_tubes() for b in self.branches.values()]) 51 | 52 | def sample_skeleton(self, spacing): 53 | return sample_tubes(self.to_tubes(), spacing) 54 | 55 | def to_device(self, device): 56 | for branch_id, branch in self.branches(): 57 | return BranchSkeleton( 58 | _id, 59 | parent_id, 60 | self.xyz.to(device), 61 | self.radii.to(device), 62 | child_id, 63 | ) 64 | 65 | def point_to_skeleton(self): 66 | tubes = self.to_tubes() 67 | 68 | def point_to_tube(pt): 69 | return pts_to_nearest_tube_gpu(pt.reshape(-1, 3), tubes) 70 | 71 | return point_to_tube # v, idx, _ 72 | 73 | def repair(self): 74 | """skeletons are not connected between branches. 
75 | this function connects the branches to their parent branches by finding 76 | the nearest point on the parent branch.""" 77 | 78 | branch_ids = [branch._id for branch in self.branches.values()] 79 | 80 | for branch in self.branches.values(): 81 | if branch.parent_id not in branch_ids: 82 | continue 83 | 84 | parent_branch = self.branches[branch.parent_id] 85 | tubes = parent_branch.to_tubes() 86 | 87 | v, idx, _ = pts_to_nearest_tube_gpu(branch.xyz[0].reshape(-1, 3), tubes) 88 | 89 | connection_pt = branch.xyz[0].reshape(-1, 3).cpu() + v[0].cpu() 90 | 91 | branch.xyz = torch.cat((connection_pt, branch.xyz)) 92 | branch.radii = torch.cat((branch.radii[[0]], branch.radii)) 93 | 94 | def prune(self, min_radius: float, min_length: float, root_id=None): 95 | """ 96 | If a branch doesn't meet the initial radius threshold or length threshold we want to remove it and all 97 | it's predecessors... 98 | minimum_radius: some point of the branch must be above this to not remove the branch 99 | length: the total lenght of the branch must be greater than this point 100 | """ 101 | root_id = min(self.branches.keys()) if root_id == None else root_id 102 | 103 | keep = {root_id: self.branches[root_id]} 104 | remove = {} 105 | 106 | for branch_id, branch in tqdm( 107 | self.branches.items(), 108 | leave=False, 109 | desc="Pruning Branches", 110 | ): 111 | if branch.parent_id not in keep and branch._id != root_id: 112 | remove[branch_id] = branch 113 | elif branch.length < min_length: 114 | remove[branch_id] = branch 115 | elif branch.initial_radius < min_radius: 116 | remove[branch_id] = branch 117 | else: 118 | keep[branch_id] = branch 119 | 120 | self.branches = keep 121 | return TreeSkeleton(0, remove) 122 | 123 | def smooth(self, kernel_size=5): 124 | """ 125 | Smooths the skeleton radius. 
126 | """ 127 | kernel = torch.ones(1, 1, kernel_size) / kernel_size 128 | for branch in self.branches.values(): 129 | if branch.radii.shape[0] > kernel_size: 130 | branch.radii = F.conv1d( 131 | branch.radii.reshape(1, 1, -1), 132 | kernel, 133 | padding="same", 134 | ).reshape(-1) 135 | 136 | @property 137 | def length(self) -> TensorType[1]: 138 | return torch.sum(torch.tensor([b.length for b in self.branches.values()])) 139 | 140 | @property 141 | def biggest_radius_idx(self) -> TensorType[1]: 142 | return torch.argmax(self.radii) 143 | 144 | @property 145 | def key_branch_with_biggest_radius(self) -> TensorType[1]: 146 | """Returns the key of the branch with the biggest radius""" 147 | biggest_branch_radius = 0 148 | for key, branch in self.branches.items(): 149 | if branch.biggest_radius > biggest_branch_radius: 150 | biggest_branch_radius = branch.biggest_radius 151 | biggest_branch_radius_key = key 152 | 153 | return biggest_branch_radius_key 154 | 155 | @property 156 | def max_branch_id(self): 157 | return max(self.branches.keys()) 158 | 159 | 160 | @dataclass 161 | class DisjointTreeSkeleton: 162 | skeletons: List[TreeSkeleton] 163 | 164 | def prune(self, min_radius, min_length): 165 | self.skeletons[0].prune( 166 | min_radius=min_radius, 167 | min_length=min_length, 168 | ) # Can only prune the first skeleton as we don't know the root points for all the other skeletons... 
169 | 170 | def repair(self): 171 | for skeleton in self.skeletons: 172 | skeleton.repair() 173 | 174 | def smooth(self, kernel_size=7): 175 | for skeleton in self.skeletons: 176 | skeleton.smooth(kernel_size=kernel_size) 177 | 178 | def to_o3d_lineset(self): 179 | return o3d_merge_linesets( 180 | [s.to_o3d_lineset().paint_uniform_color(s.colour) for s in self.skeletons] 181 | ) 182 | 183 | def to_o3d_tube(self, colour=True): 184 | if colour: 185 | skeleton_tubes = [ 186 | skel.to_o3d_tube().paint_uniform_color(skel.colour) 187 | for skel in self.skeletons 188 | ] 189 | else: 190 | skeleton_tubes = [s.to_o3d_tube() for s in self.skeletons] 191 | 192 | return o3d_merge_meshes(skeleton_tubes) 193 | 194 | def view(self): 195 | o3d_viewer([self.to_o3d_lineset(), self.to_o3d_tube()]) 196 | 197 | def to_pickle(self, path): 198 | with open("disjoint_skeleton.pkl", "wb") as pickle_file: 199 | pickle.dump(self, pickle_file) 200 | 201 | @staticmethod 202 | def from_pickle(path): 203 | with open(f"{path}", "rb") as pickle_file: 204 | return pickle.load(pickle_file) 205 | 206 | 207 | def connect( 208 | skeleton_1: TreeSkeleton, 209 | skeleton_1_child_branch_key: int, 210 | skeleton_1_child_vert_idx: int, 211 | skeleton_2: TreeSkeleton, 212 | skeleton_2_child_branch_key: int, 213 | skeleton_2_child_vert_idx: int, 214 | ) -> TreeSkeleton: 215 | # This is bundy, only visually gives appearance of connection... 216 | # Need to do some more processing to actually connect the skeletons... 
217 | 218 | parent_branch = skeleton_1.branches[skeleton_1_parent_branch_key] 219 | child_branch = skeleton_2.branches[skeleton_2_child_branch_key] 220 | 221 | child_branch.parent_id = skeleton_1_parent_branch_key 222 | connection_pt = parent_branch.xyz[skeleton_1_parent_vert_idx] 223 | 224 | child_branch.xyz = torch.cat((connection_pt.unsqueeze(0), child_branch.xyz)) 225 | child_branch.radii = torch.cat((child_branch.radii[[0]], child_branch.radii)) 226 | 227 | for key, branch in skeleton_2.branches.items(): 228 | branch._id += skeleton_1.max_branch_id 229 | 230 | if branch.parent_id != -1: 231 | branch.parent_id += skeleton_1.max_branch_id 232 | 233 | return TreeSkeleton(0, merge_dictionaries(skeleton_1.branches, skeleton_2.branches)) 234 | -------------------------------------------------------------------------------- /smart_tree/model/model_blocks.py: -------------------------------------------------------------------------------- 1 | import spconv.pytorch as spconv 2 | import torch 3 | import torch.cuda.amp 4 | import torch.nn as nn 5 | from spconv.pytorch import SparseModule 6 | 7 | 8 | class SubMConvBlock(SparseModule): 9 | def __init__( 10 | self, 11 | input_channels, 12 | output_channels, 13 | kernel_size, 14 | norm_fn, 15 | activation_fn, 16 | stride=1, 17 | padding=1, 18 | algo=spconv.ConvAlgo.Native, 19 | bias=False, 20 | ): 21 | super().__init__() 22 | 23 | self.sequence = spconv.SparseSequential( 24 | spconv.SubMConv3d( 25 | in_channels=input_channels, 26 | out_channels=output_channels, 27 | kernel_size=kernel_size, 28 | stride=stride, 29 | padding=padding, 30 | bias=bias, 31 | algo=algo, 32 | ), 33 | norm_fn(output_channels), 34 | activation_fn(), 35 | ) 36 | 37 | def forward(self, input): 38 | return self.sequence(input) 39 | 40 | 41 | class EncoderBlock(SparseModule): 42 | def __init__( 43 | self, 44 | input_channels, 45 | output_channels, 46 | kernel_size, 47 | norm_fn, 48 | activation_fn, 49 | stride=2, 50 | padding=1, 51 | key=None, 52 | 
algo=spconv.ConvAlgo.Native, 53 | bias=False, 54 | ): 55 | super().__init__() 56 | 57 | self.sequence = spconv.SparseSequential( 58 | spconv.SparseConv3d( 59 | input_channels, 60 | output_channels, 61 | kernel_size=kernel_size, 62 | stride=stride, 63 | indice_key=key, 64 | algo=algo, 65 | bias=bias, 66 | padding=padding, 67 | ), 68 | norm_fn(output_channels), 69 | activation_fn(), 70 | ) 71 | 72 | def forward(self, input): 73 | return self.sequence(input) 74 | 75 | 76 | class DecoderBlock(SparseModule): 77 | def __init__( 78 | self, 79 | input_channels, 80 | output_channels, 81 | kernel_size, 82 | norm_fn, 83 | activation_fn, 84 | key=None, 85 | algo=spconv.ConvAlgo.Native, 86 | bias=False, 87 | ): 88 | super().__init__() 89 | 90 | self.sequence = spconv.SparseSequential( 91 | spconv.SparseInverseConv3d( 92 | input_channels, 93 | output_channels, 94 | kernel_size, 95 | indice_key=key, 96 | algo=algo, 97 | bias=bias, 98 | ), 99 | norm_fn(output_channels), 100 | activation_fn(), 101 | ) 102 | 103 | def forward(self, input): 104 | return self.sequence(input) 105 | 106 | 107 | class ResBlock(nn.Module): 108 | def __init__( 109 | self, 110 | input_channels, 111 | output_channels, 112 | kernel_size, 113 | norm_fn, 114 | activation_fn, 115 | algo=spconv.ConvAlgo.Native, 116 | bias=False, 117 | ): 118 | super().__init__() 119 | 120 | if input_channels == output_channels: 121 | self.identity = spconv.SparseSequential(nn.Identity()) 122 | else: 123 | self.identity = spconv.SparseSequential( 124 | spconv.SubMConv3d( 125 | input_channels, 126 | output_channels, 127 | kernel_size=1, 128 | padding=1, 129 | bias=False, 130 | algo=algo, 131 | ) 132 | ) 133 | 134 | self.sequence = spconv.SparseSequential( 135 | spconv.SubMConv3d( 136 | input_channels, output_channels, kernel_size, bias=False, algo=algo 137 | ), 138 | norm_fn(output_channels), 139 | activation_fn(), 140 | spconv.SubMConv3d( 141 | output_channels, output_channels, kernel_size, bias=False, algo=algo 142 | ), 143 | 
norm_fn(output_channels), 144 | ) 145 | 146 | self.activation_fn = spconv.SparseSequential(activation_fn()) 147 | 148 | def forward(self, input): 149 | identity = spconv.SparseConvTensor( 150 | input.features, input.indices, input.spatial_shape, input.batch_size 151 | ) 152 | output = self.sequence(input) 153 | output = output.replace_feature( 154 | output.features + self.identity(identity).features 155 | ) 156 | return self.activation_fn(output) 157 | 158 | 159 | class UBlock(nn.Module): 160 | def __init__( 161 | self, 162 | n_planes, 163 | norm_fn, 164 | activation_fn, 165 | kernel_size=3, 166 | key_id=1, 167 | algo=spconv.ConvAlgo.Native, 168 | bias=False, 169 | ): 170 | super().__init__() 171 | 172 | self.n_planes = n_planes 173 | self.Head = ResBlock( 174 | n_planes[0], 175 | n_planes[0], 176 | kernel_size, 177 | norm_fn, 178 | activation_fn, 179 | algo=algo, 180 | bias=bias, 181 | ) 182 | 183 | if len(n_planes) > 1: 184 | self.Encode = EncoderBlock( 185 | n_planes[0], 186 | n_planes[1], 187 | kernel_size, 188 | norm_fn, 189 | activation_fn, 190 | stride=2, 191 | key=key_id, 192 | algo=algo, 193 | bias=bias, 194 | ) 195 | self.U = UBlock( 196 | n_planes[1:], 197 | norm_fn, 198 | activation_fn, 199 | kernel_size, 200 | key_id + 1, 201 | algo, 202 | bias, 203 | ) 204 | self.Decode = DecoderBlock( 205 | n_planes[1], 206 | n_planes[0], 207 | kernel_size, 208 | norm_fn, 209 | activation_fn, 210 | key=key_id, 211 | algo=algo, 212 | bias=bias, 213 | ) 214 | self.Tail = ResBlock( 215 | n_planes[0] * 2, 216 | n_planes[0], 217 | kernel_size, 218 | norm_fn, 219 | activation_fn, 220 | algo=algo, 221 | bias=bias, 222 | ) 223 | 224 | def forward(self, input): 225 | output = self.Head(input) 226 | 227 | identity = spconv.SparseConvTensor( 228 | output.features, 229 | output.indices, 230 | output.spatial_shape, 231 | output.batch_size, 232 | ) 233 | 234 | if len(self.n_planes) > 1: 235 | output = self.Encode(output) 236 | output = self.U(output) 237 | output = 
self.Decode(output) 238 | output = output.replace_feature( 239 | torch.cat((identity.features, output.features), dim=1) 240 | ) 241 | output = self.Tail(output) 242 | 243 | return output 244 | 245 | 246 | class SparseFC(nn.Module): 247 | def __init__( 248 | self, 249 | n_planes, 250 | norm_fn, 251 | activation_fn=None, 252 | kernel_size=1, 253 | algo=spconv.ConvAlgo.Native, 254 | bias=False, 255 | ): 256 | super().__init__() 257 | 258 | self.sequence = spconv.SparseSequential() 259 | for i in range(len(n_planes) - 2): 260 | self.sequence.add( 261 | spconv.SubMConv3d( 262 | n_planes[i], 263 | n_planes[i + 1], 264 | kernel_size=kernel_size, 265 | bias=False, 266 | algo=algo, 267 | padding=0, 268 | ) 269 | ) 270 | self.sequence.add(norm_fn(n_planes[i + 1])) 271 | self.sequence.add(activation_fn()) 272 | 273 | self.sequence.add( 274 | spconv.SubMConv3d( 275 | n_planes[-2], 276 | n_planes[-1], 277 | kernel_size=kernel_size, 278 | bias=False, 279 | algo=algo, 280 | padding=0, 281 | ) 282 | ) 283 | 284 | def forward(self, input): 285 | return self.sequence(input) 286 | 287 | 288 | class MLP(nn.Module): 289 | def __init__( 290 | self, 291 | n_planes, 292 | norm_fn, 293 | activation_fn=None, 294 | bias=False, 295 | ): 296 | super().__init__() 297 | 298 | self.sequence = spconv.SparseSequential() 299 | 300 | for i in range(len(n_planes) - 2): 301 | self.sequence.add( 302 | nn.Linear( 303 | n_planes[i], 304 | n_planes[i + 1], 305 | bias=bias, 306 | ) 307 | ) 308 | self.sequence.add(norm_fn(n_planes[i + 1])) 309 | self.sequence.add(activation_fn()) 310 | 311 | self.sequence.add( 312 | nn.Linear( 313 | n_planes[-2], 314 | n_planes[-1], 315 | bias=bias, 316 | ) 317 | ) 318 | 319 | def forward(self, input): 320 | return self.sequence(input) 321 | -------------------------------------------------------------------------------- /smart_tree/data_types/cloud.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 
3 | from dataclasses import dataclass 4 | from pathlib import Path 5 | from typing import Optional 6 | 7 | import numpy as np 8 | import torch 9 | import torch.nn.functional as F 10 | from torchtyping import TensorType 11 | from typeguard import typechecked 12 | 13 | from ..o3d_abstractions.geometries import o3d_cloud, o3d_lines_between_clouds 14 | from ..o3d_abstractions.visualizer import o3d_viewer 15 | from ..util.misc import voxel_downsample 16 | from ..util.queries import skeleton_to_points 17 | 18 | 19 | @typechecked 20 | @dataclass 21 | class Cloud: 22 | xyz: TensorType["N", 3] 23 | rgb: Optional[TensorType["N", 3]] = None 24 | medial_vector: Optional[TensorType["N", 3]] = None 25 | branch_direction: Optional[TensorType["N", 3]] = None 26 | branch_ids: Optional[TensorType["N", 1]] = None 27 | class_l: Optional[TensorType["N", 1]] = None 28 | filename: Optional[Path] = None 29 | 30 | def __len__(self): 31 | return self.xyz.shape[0] 32 | 33 | def __str__(self): 34 | return f"{'*' * 80}\nCloud with {self.xyz.shape[0]} Points.\nMin: {torch.min(self.xyz, 0)[0]}\nMax: {torch.max(self.xyz, 0)[0]}\nDevice:{self.xyz.device}\n{'*' * 80}\n" 35 | 36 | def paint(self, colour=[1, 0, 0]): 37 | self.rgb = torch.tensor([colour]).expand(self.__len__, -1) 38 | 39 | def to_o3d_cld(self): 40 | cpu_cld = self.cpu() 41 | if not hasattr(cpu_cld, "rgb"): 42 | cpu_cld.paint() 43 | return o3d_cloud(cpu_cld.xyz, colours=cpu_cld.rgb) 44 | 45 | def to_o3d_seg_cld(self, cmap: np.ndarray = np.array([[1, 0, 0], [0, 1, 0]])): 46 | cpu_cld = self.cpu() 47 | colours = cmap[cpu_cld.class_l.view(-1).int()] 48 | return o3d_cloud(cpu_cld.xyz, colours=colours) 49 | 50 | def to_o3d_trunk_cld(self): 51 | cpu_cld = self.cpu() 52 | min_branch_id = cpu_cld.branch_ids[0] 53 | return cpu_cld.filter(cpu_cld.branch_ids == min_branch_id).to_o3d_cld() 54 | 55 | def to_o3d_branch_cld(self): 56 | cpu_cld = self.cpu() 57 | min_branch_id = cpu_cld.branch_ids[0] 58 | return cpu_cld.filter(cpu_cld.branch_ids != 
min_branch_id).to_o3d_cld() 59 | 60 | def to_o3d_medial_vectors(self, cmap=None): 61 | cpu_cld = self.cpu() 62 | medial_cloud = o3d_cloud(cpu_cld.xyz + cpu_cld.medial_vector) 63 | return o3d_lines_between_clouds(cpu_cld.to_o3d_cld(), medial_cloud) 64 | 65 | def to_o3d_branch_directions(self, scale=0.1, cmap=None): 66 | cpu_cld = self.cpu() 67 | branch_dir_cloud = o3d_cloud(cpu_cld.xyz + (cpu_cld.branch_direction * scale)) 68 | return o3d_lines_between_clouds(cpu_cld.to_o3d_cld(), branch_dir_cloud) 69 | 70 | # def to_wandb_seg_cld(self, cmap: np.ndarray = np.array([[1, 0, 0], [0, 1, 0]])): 71 | 72 | def filter(cloud: Cloud, mask) -> Cloud: 73 | mask = mask.to(cloud.xyz.device) 74 | xyz = cloud.xyz[mask] 75 | rgb = cloud.rgb[mask] if cloud.rgb is not None else None 76 | 77 | medial_vector = ( 78 | cloud.medial_vector[mask] if cloud.medial_vector is not None else None 79 | ) 80 | branch_direction = ( 81 | cloud.branch_direction[mask] if cloud.branch_direction is not None else None 82 | ) 83 | class_l = cloud.class_l[mask] if cloud.class_l is not None else None 84 | branch_ids = cloud.branch_ids[mask] if cloud.branch_ids is not None else None 85 | filename = cloud.filename if cloud.filename is not None else None 86 | 87 | return Cloud( 88 | xyz=xyz, 89 | rgb=rgb, 90 | medial_vector=medial_vector, 91 | branch_direction=branch_direction, 92 | class_l=class_l, 93 | branch_ids=branch_ids, 94 | filename=filename, 95 | ) 96 | 97 | def filter_by_class(self, classes): 98 | classes = torch.tensor(classes, device=self.class_l.device) 99 | mask = torch.isin( 100 | self.class_l, 101 | classes, 102 | ) 103 | return self.filter(mask.view(-1)) 104 | 105 | def filter_by_skeleton(skeleton: TreeSkeleton, cloud: Cloud, threshold=1.1): 106 | distances, radii, vectors_ = skeleton_to_points(cloud, skeleton, chunk_size=512) 107 | mask = distances < radii * threshold 108 | return filter(cloud, mask) 109 | 110 | def pin_memory(self): 111 | xyz = self.xyz.pin_memory() 112 | rgb = 
self.rgb.pin_memory() if self.rgb is not None else None 113 | medial_vector = ( 114 | self.medial_vector.pin_memory() if self.medial_vector is not None else None 115 | ) 116 | branch_direction = ( 117 | self.branch_direction.pin_memory() 118 | if self.branch_direction is not None 119 | else None 120 | ) 121 | class_l = self.class_l.pin_memory() if self.class_l is not None else None 122 | branch_ids = ( 123 | self.branch_ids.pin_memory() if self.branch_ids is not None else None 124 | ) 125 | 126 | return Cloud( 127 | xyz=xyz, 128 | rgb=rgb, 129 | medial_vector=medial_vector, 130 | branch_direction=branch_direction, 131 | class_l=class_l, 132 | branch_ids=branch_ids, 133 | ) 134 | 135 | def cpu(self): 136 | return self.to_device(torch.device("cpu")) 137 | 138 | def to_device(self, device): 139 | xyz = self.xyz.to(device) 140 | rgb = self.rgb.to(device) if self.rgb is not None else None 141 | medial_vector = ( 142 | self.medial_vector.to(device) if self.medial_vector is not None else None 143 | ) 144 | branch_direction = ( 145 | self.branch_direction.to(device) 146 | if self.branch_direction is not None 147 | else None 148 | ) 149 | class_l = self.class_l.to(device) if self.class_l is not None else None 150 | branch_ids = self.branch_ids.to(device) if self.branch_ids is not None else None 151 | filename = self.filename if self.filename is not None else None 152 | 153 | return Cloud( 154 | xyz=xyz, 155 | rgb=rgb, 156 | medial_vector=medial_vector, 157 | branch_direction=branch_direction, 158 | class_l=class_l, 159 | branch_ids=branch_ids, 160 | filename=filename, 161 | ) 162 | 163 | def cat(self): 164 | return torch.cat( 165 | ( 166 | self.xyz, 167 | self.rgb, 168 | ), 169 | 1, 170 | ) 171 | 172 | def view(self, cmap=[]): 173 | if cmap == []: 174 | cmap = np.random.rand(self.number_classes, 3) 175 | 176 | cpu_cld = self.cpu() 177 | geoms = [] 178 | 179 | geoms.append(cpu_cld.to_o3d_cld()) 180 | if cpu_cld.class_l != None: 181 | 
geoms.append(cpu_cld.to_o3d_seg_cld(cmap)) 182 | 183 | if cpu_cld.medial_vector != None: 184 | projected = o3d_cloud(cpu_cld.xyz + cpu_cld.medial_vector, colour=(1, 0, 0)) 185 | geoms.append(projected) 186 | geoms.append(o3d_lines_between_clouds(cpu_cld.to_o3d_cld(), projected)) 187 | 188 | o3d_viewer(geoms) 189 | 190 | def voxel_down_sample(self, voxel_size): 191 | idx = voxel_downsample(self.xyz, voxel_size) 192 | return self.filter(idx) 193 | 194 | def scale(self, factor): 195 | return Cloud(self.xyz * factor, self.rgb) 196 | 197 | def translate(self, xyz): 198 | return Cloud(self.xyz + xyz.to(self.xyz.device), self.rgb) 199 | 200 | def rotate(self, rot_mat): 201 | rot_mat = rot_mat.to(self.xyz.dtype) 202 | return Cloud(torch.matmul(self.xyz, rot_mat.to(self.xyz.device)), self.rgb) 203 | 204 | @property 205 | def root_idx(self) -> int: 206 | return torch.argmin(self.xyz[:, 1]).item() 207 | 208 | @property 209 | def number_classes(self) -> int: 210 | if not hasattr(self, "class_l"): 211 | return 1 212 | return torch.max(self.class_l).item() + 1 213 | 214 | @property 215 | def max_xyz(self) -> torch.Tensor: 216 | return torch.max(self.xyz, 0)[0] 217 | 218 | @property 219 | def min_xyz(self) -> torch.Tensor: 220 | return torch.min(self.xyz, 0)[0] 221 | 222 | @property 223 | def bbox(self) -> tuple[torch.Tensor, torch.Tensor]: 224 | # defined by centre coordinate, x/2, y/2, z/2 225 | dimensions = (self.max_xyz - self.min_xyz) / 2 226 | centre = self.min_xyz + dimensions 227 | return centre, dimensions 228 | 229 | @property 230 | def medial_pts(self) -> torch.Tensor: 231 | return self.xyz + self.medial_vector 232 | 233 | @staticmethod 234 | def from_numpy(**kwargs) -> "Cloud": 235 | torch_kwargs = {} 236 | 237 | for key, value in kwargs.items(): 238 | if key in [ 239 | "xyz", 240 | "rgb", 241 | "medial_vector", 242 | "branch_direction", 243 | "branch_ids", 244 | "class_l", 245 | ]: 246 | torch_kwargs[key] = torch.tensor(value).float() 247 | 248 | """ SUPPORT LEGACY 
NPZ -> Remove in Future...""" 249 | if key in ["vector"]: 250 | torch_kwargs["medial_vector"] = torch.tensor(value) 251 | 252 | return Cloud(**torch_kwargs) 253 | 254 | @property 255 | def radius(self) -> torch.Tensor: 256 | return self.medial_vector.pow(2).sum(1).sqrt() 257 | 258 | @property 259 | def direction(self) -> torch.Tensor: 260 | return F.normalize(self.medial_vector) 261 | 262 | @staticmethod 263 | def from_o3d_cld(cld) -> Cloud: 264 | return Cloud.from_numpy(xyz=np.asarray(cld.points), rgb=np.asarray(cld.colors)) 265 | -------------------------------------------------------------------------------- /smart_tree/conf/training-split.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": [ 3 | "/cherry/cherry_21.npz", 4 | "/cherry/cherry_41.npz", 5 | "/cherry/cherry_67.npz", 6 | "/cherry/cherry_12.npz", 7 | "/cherry/cherry_92.npz", 8 | "/cherry/cherry_33.npz", 9 | "/cherry/cherry_29.npz", 10 | "/cherry/cherry_11.npz", 11 | "/cherry/cherry_84.npz", 12 | "/cherry/cherry_24.npz", 13 | "/cherry/cherry_5.npz", 14 | "/cherry/cherry_10.npz", 15 | "/cherry/cherry_66.npz", 16 | "/cherry/cherry_38.npz", 17 | "/cherry/cherry_61.npz", 18 | "/cherry/cherry_48.npz", 19 | "/cherry/cherry_26.npz", 20 | "/cherry/cherry_1.npz", 21 | "/cherry/cherry_40.npz", 22 | "/cherry/cherry_51.npz", 23 | "/cherry/cherry_70.npz", 24 | "/cherry/cherry_31.npz", 25 | "/cherry/cherry_81.npz", 26 | "/cherry/cherry_97.npz", 27 | "/cherry/cherry_62.npz", 28 | "/cherry/cherry_72.npz", 29 | "/cherry/cherry_86.npz", 30 | "/cherry/cherry_39.npz", 31 | "/cherry/cherry_79.npz", 32 | "/cherry/cherry_100.npz", 33 | "/cherry/cherry_83.npz", 34 | "/cherry/cherry_35.npz", 35 | "/cherry/cherry_87.npz", 36 | "/cherry/cherry_17.npz", 37 | "/cherry/cherry_57.npz", 38 | "/cherry/cherry_68.npz", 39 | "/cherry/cherry_77.npz", 40 | "/cherry/cherry_60.npz", 41 | "/cherry/cherry_3.npz", 42 | "/cherry/cherry_22.npz", 43 | "/cherry/cherry_80.npz", 44 | 
"/cherry/cherry_14.npz", 45 | "/cherry/cherry_23.npz", 46 | "/cherry/cherry_2.npz", 47 | "/cherry/cherry_53.npz", 48 | "/cherry/cherry_8.npz", 49 | "/cherry/cherry_32.npz", 50 | "/cherry/cherry_54.npz", 51 | "/cherry/cherry_63.npz", 52 | "/cherry/cherry_52.npz", 53 | "/cherry/cherry_16.npz", 54 | "/cherry/cherry_71.npz", 55 | "/cherry/cherry_44.npz", 56 | "/cherry/cherry_43.npz", 57 | "/cherry/cherry_59.npz", 58 | "/cherry/cherry_47.npz", 59 | "/cherry/cherry_69.npz", 60 | "/cherry/cherry_9.npz", 61 | "/cherry/cherry_95.npz", 62 | "/cherry/cherry_76.npz", 63 | "/cherry/cherry_93.npz", 64 | "/cherry/cherry_78.npz", 65 | "/cherry/cherry_99.npz", 66 | "/cherry/cherry_36.npz", 67 | "/cherry/cherry_27.npz", 68 | "/cherry/cherry_55.npz", 69 | "/cherry/cherry_45.npz", 70 | "/cherry/cherry_91.npz", 71 | "/cherry/cherry_46.npz", 72 | "/cherry/cherry_65.npz", 73 | "/cherry/cherry_85.npz", 74 | "/cherry/cherry_88.npz", 75 | "/cherry/cherry_58.npz", 76 | "/cherry/cherry_30.npz", 77 | "/cherry/cherry_37.npz", 78 | "/cherry/cherry_50.npz", 79 | "/cherry/cherry_25.npz", 80 | "/cherry/cherry_20.npz", 81 | "/cherry/cherry_64.npz", 82 | "/cherry/cherry_74.npz", 83 | "/apple/apple_57.npz", 84 | "/apple/apple_82.npz", 85 | "/apple/apple_38.npz", 86 | "/apple/apple_94.npz", 87 | "/apple/apple_49.npz", 88 | "/apple/apple_70.npz", 89 | "/apple/apple_71.npz", 90 | "/apple/apple_90.npz", 91 | "/apple/apple_41.npz", 92 | "/apple/apple_8.npz", 93 | "/apple/apple_75.npz", 94 | "/apple/apple_55.npz", 95 | "/apple/apple_19.npz", 96 | "/apple/apple_30.npz", 97 | "/apple/apple_88.npz", 98 | "/apple/apple_15.npz", 99 | "/apple/apple_65.npz", 100 | "/apple/apple_45.npz", 101 | "/apple/apple_78.npz", 102 | "/apple/apple_18.npz", 103 | "/apple/apple_26.npz", 104 | "/apple/apple_25.npz", 105 | "/apple/apple_80.npz", 106 | "/apple/apple_54.npz", 107 | "/apple/apple_72.npz", 108 | "/apple/apple_50.npz", 109 | "/apple/apple_1.npz", 110 | "/apple/apple_93.npz", 111 | "/apple/apple_99.npz", 112 | 
"/apple/apple_74.npz", 113 | "/apple/apple_33.npz", 114 | "/apple/apple_13.npz", 115 | "/apple/apple_34.npz", 116 | "/apple/apple_21.npz", 117 | "/apple/apple_59.npz", 118 | "/apple/apple_7.npz", 119 | "/apple/apple_84.npz", 120 | "/apple/apple_81.npz", 121 | "/apple/apple_85.npz", 122 | "/apple/apple_79.npz", 123 | "/apple/apple_29.npz", 124 | "/apple/apple_43.npz", 125 | "/apple/apple_40.npz", 126 | "/apple/apple_61.npz", 127 | "/apple/apple_68.npz", 128 | "/apple/apple_3.npz", 129 | "/apple/apple_24.npz", 130 | "/apple/apple_67.npz", 131 | "/apple/apple_32.npz", 132 | "/apple/apple_95.npz", 133 | "/apple/apple_4.npz", 134 | "/apple/apple_91.npz", 135 | "/apple/apple_46.npz", 136 | "/apple/apple_20.npz", 137 | "/apple/apple_89.npz", 138 | "/apple/apple_51.npz", 139 | "/apple/apple_10.npz", 140 | "/apple/apple_98.npz", 141 | "/apple/apple_47.npz", 142 | "/apple/apple_62.npz", 143 | "/apple/apple_52.npz", 144 | "/apple/apple_2.npz", 145 | "/apple/apple_31.npz", 146 | "/apple/apple_58.npz", 147 | "/apple/apple_22.npz", 148 | "/apple/apple_37.npz", 149 | "/apple/apple_9.npz", 150 | "/apple/apple_28.npz", 151 | "/apple/apple_69.npz", 152 | "/apple/apple_76.npz", 153 | "/apple/apple_100.npz", 154 | "/apple/apple_14.npz", 155 | "/apple/apple_44.npz", 156 | "/apple/apple_6.npz", 157 | "/apple/apple_66.npz", 158 | "/apple/apple_35.npz", 159 | "/apple/apple_27.npz", 160 | "/apple/apple_87.npz", 161 | "/apple/apple_48.npz", 162 | "/apple/apple_60.npz", 163 | "/ginkgo/ginkgo_75.npz", 164 | "/ginkgo/ginkgo_44.npz", 165 | "/ginkgo/ginkgo_65.npz", 166 | "/ginkgo/ginkgo_32.npz", 167 | "/ginkgo/ginkgo_82.npz", 168 | "/ginkgo/ginkgo_39.npz", 169 | "/ginkgo/ginkgo_38.npz", 170 | "/ginkgo/ginkgo_95.npz", 171 | "/ginkgo/ginkgo_4.npz", 172 | "/ginkgo/ginkgo_70.npz", 173 | "/ginkgo/ginkgo_1.npz", 174 | "/ginkgo/ginkgo_29.npz", 175 | "/ginkgo/ginkgo_90.npz", 176 | "/ginkgo/ginkgo_13.npz", 177 | "/ginkgo/ginkgo_2.npz", 178 | "/ginkgo/ginkgo_55.npz", 179 | "/ginkgo/ginkgo_68.npz", 180 | 
"/ginkgo/ginkgo_49.npz", 181 | "/ginkgo/ginkgo_63.npz", 182 | "/ginkgo/ginkgo_47.npz", 183 | "/ginkgo/ginkgo_66.npz", 184 | "/ginkgo/ginkgo_40.npz", 185 | "/ginkgo/ginkgo_97.npz", 186 | "/ginkgo/ginkgo_81.npz", 187 | "/ginkgo/ginkgo_41.npz", 188 | "/ginkgo/ginkgo_26.npz", 189 | "/ginkgo/ginkgo_42.npz", 190 | "/ginkgo/ginkgo_35.npz", 191 | "/ginkgo/ginkgo_10.npz", 192 | "/ginkgo/ginkgo_53.npz", 193 | "/ginkgo/ginkgo_24.npz", 194 | "/ginkgo/ginkgo_31.npz", 195 | "/ginkgo/ginkgo_45.npz", 196 | "/ginkgo/ginkgo_8.npz", 197 | "/ginkgo/ginkgo_87.npz", 198 | "/ginkgo/ginkgo_12.npz", 199 | "/ginkgo/ginkgo_92.npz", 200 | "/ginkgo/ginkgo_59.npz", 201 | "/ginkgo/ginkgo_22.npz", 202 | "/ginkgo/ginkgo_48.npz", 203 | "/ginkgo/ginkgo_76.npz", 204 | "/ginkgo/ginkgo_50.npz", 205 | "/ginkgo/ginkgo_84.npz", 206 | "/ginkgo/ginkgo_30.npz", 207 | "/ginkgo/ginkgo_80.npz", 208 | "/ginkgo/ginkgo_5.npz", 209 | "/ginkgo/ginkgo_28.npz", 210 | "/ginkgo/ginkgo_57.npz", 211 | "/ginkgo/ginkgo_79.npz", 212 | "/ginkgo/ginkgo_89.npz", 213 | "/ginkgo/ginkgo_17.npz", 214 | "/ginkgo/ginkgo_19.npz", 215 | "/ginkgo/ginkgo_98.npz", 216 | "/ginkgo/ginkgo_67.npz", 217 | "/ginkgo/ginkgo_69.npz", 218 | "/ginkgo/ginkgo_58.npz", 219 | "/ginkgo/ginkgo_33.npz", 220 | "/ginkgo/ginkgo_85.npz", 221 | "/ginkgo/ginkgo_3.npz", 222 | "/ginkgo/ginkgo_14.npz", 223 | "/ginkgo/ginkgo_51.npz", 224 | "/ginkgo/ginkgo_15.npz", 225 | "/ginkgo/ginkgo_72.npz", 226 | "/ginkgo/ginkgo_46.npz", 227 | "/ginkgo/ginkgo_16.npz", 228 | "/ginkgo/ginkgo_52.npz", 229 | "/ginkgo/ginkgo_34.npz", 230 | "/ginkgo/ginkgo_100.npz", 231 | "/ginkgo/ginkgo_74.npz", 232 | "/ginkgo/ginkgo_91.npz", 233 | "/ginkgo/ginkgo_78.npz", 234 | "/ginkgo/ginkgo_83.npz", 235 | "/ginkgo/ginkgo_6.npz", 236 | "/ginkgo/ginkgo_73.npz", 237 | "/ginkgo/ginkgo_20.npz", 238 | "/ginkgo/ginkgo_27.npz", 239 | "/ginkgo/ginkgo_25.npz", 240 | "/ginkgo/ginkgo_96.npz", 241 | "/ginkgo/ginkgo_9.npz", 242 | "/ginkgo/ginkgo_60.npz", 243 | "/walnut/walnut_59.npz", 244 | 
"/walnut/walnut_32.npz", 245 | "/walnut/walnut_10.npz", 246 | "/walnut/walnut_16.npz", 247 | "/walnut/walnut_48.npz", 248 | "/walnut/walnut_60.npz", 249 | "/walnut/walnut_23.npz", 250 | "/walnut/walnut_55.npz", 251 | "/walnut/walnut_1.npz", 252 | "/walnut/walnut_19.npz", 253 | "/walnut/walnut_35.npz", 254 | "/walnut/walnut_6.npz", 255 | "/walnut/walnut_49.npz", 256 | "/walnut/walnut_42.npz", 257 | "/walnut/walnut_62.npz", 258 | "/walnut/walnut_76.npz", 259 | "/walnut/walnut_21.npz", 260 | "/walnut/walnut_73.npz", 261 | "/walnut/walnut_58.npz", 262 | "/walnut/walnut_63.npz", 263 | "/walnut/walnut_85.npz", 264 | "/walnut/walnut_72.npz", 265 | "/walnut/walnut_39.npz", 266 | "/walnut/walnut_69.npz", 267 | "/walnut/walnut_78.npz", 268 | "/walnut/walnut_41.npz", 269 | "/walnut/walnut_3.npz", 270 | "/walnut/walnut_18.npz", 271 | "/walnut/walnut_99.npz", 272 | "/walnut/walnut_29.npz", 273 | "/walnut/walnut_31.npz", 274 | "/walnut/walnut_33.npz", 275 | "/walnut/walnut_89.npz", 276 | "/walnut/walnut_90.npz", 277 | "/walnut/walnut_83.npz", 278 | "/walnut/walnut_4.npz", 279 | "/walnut/walnut_24.npz", 280 | "/walnut/walnut_52.npz", 281 | "/walnut/walnut_38.npz", 282 | "/walnut/walnut_28.npz", 283 | "/walnut/walnut_51.npz", 284 | "/walnut/walnut_71.npz", 285 | "/walnut/walnut_100.npz", 286 | "/walnut/walnut_30.npz", 287 | "/walnut/walnut_64.npz", 288 | "/walnut/walnut_57.npz", 289 | "/walnut/walnut_75.npz", 290 | "/walnut/walnut_77.npz", 291 | "/walnut/walnut_81.npz", 292 | "/walnut/walnut_22.npz", 293 | "/walnut/walnut_13.npz", 294 | "/walnut/walnut_53.npz", 295 | "/walnut/walnut_56.npz", 296 | "/walnut/walnut_11.npz", 297 | "/walnut/walnut_8.npz", 298 | "/walnut/walnut_20.npz", 299 | "/walnut/walnut_65.npz", 300 | "/walnut/walnut_88.npz", 301 | "/walnut/walnut_54.npz", 302 | "/walnut/walnut_40.npz", 303 | "/walnut/walnut_95.npz", 304 | "/walnut/walnut_70.npz", 305 | "/walnut/walnut_37.npz", 306 | "/walnut/walnut_17.npz", 307 | "/walnut/walnut_44.npz", 308 | 
"/walnut/walnut_5.npz", 309 | "/walnut/walnut_36.npz", 310 | "/walnut/walnut_9.npz", 311 | "/walnut/walnut_86.npz", 312 | "/walnut/walnut_50.npz", 313 | "/walnut/walnut_46.npz", 314 | "/walnut/walnut_34.npz", 315 | "/walnut/walnut_82.npz", 316 | "/walnut/walnut_61.npz", 317 | "/walnut/walnut_96.npz", 318 | "/walnut/walnut_25.npz", 319 | "/walnut/walnut_80.npz", 320 | "/walnut/walnut_45.npz", 321 | "/walnut/walnut_94.npz", 322 | "/walnut/walnut_87.npz", 323 | "/pine/pine_47.npz", 324 | "/pine/pine_88.npz", 325 | "/pine/pine_58.npz", 326 | "/pine/pine_63.npz", 327 | "/pine/pine_32.npz", 328 | "/pine/pine_61.npz", 329 | "/pine/pine_81.npz", 330 | "/pine/pine_16.npz", 331 | "/pine/pine_50.npz", 332 | "/pine/pine_48.npz", 333 | "/pine/pine_44.npz", 334 | "/pine/pine_98.npz", 335 | "/pine/pine_17.npz", 336 | "/pine/pine_1.npz", 337 | "/pine/pine_36.npz", 338 | "/pine/pine_31.npz", 339 | "/pine/pine_100.npz", 340 | "/pine/pine_69.npz", 341 | "/pine/pine_49.npz", 342 | "/pine/pine_24.npz", 343 | "/pine/pine_78.npz", 344 | "/pine/pine_90.npz", 345 | "/pine/pine_77.npz", 346 | "/pine/pine_57.npz", 347 | "/pine/pine_41.npz", 348 | "/pine/pine_19.npz", 349 | "/pine/pine_56.npz", 350 | "/pine/pine_53.npz", 351 | "/pine/pine_60.npz", 352 | "/pine/pine_38.npz", 353 | "/pine/pine_25.npz", 354 | "/pine/pine_20.npz", 355 | "/pine/pine_89.npz", 356 | "/pine/pine_6.npz", 357 | "/pine/pine_82.npz", 358 | "/pine/pine_3.npz", 359 | "/pine/pine_97.npz", 360 | "/pine/pine_34.npz", 361 | "/pine/pine_73.npz", 362 | "/pine/pine_26.npz", 363 | "/pine/pine_51.npz", 364 | "/pine/pine_5.npz", 365 | "/pine/pine_91.npz", 366 | "/pine/pine_75.npz", 367 | "/pine/pine_66.npz", 368 | "/pine/pine_12.npz", 369 | "/pine/pine_96.npz", 370 | "/pine/pine_85.npz", 371 | "/pine/pine_86.npz", 372 | "/pine/pine_23.npz", 373 | "/pine/pine_64.npz", 374 | "/pine/pine_79.npz", 375 | "/pine/pine_59.npz", 376 | "/pine/pine_39.npz", 377 | "/pine/pine_94.npz", 378 | "/pine/pine_42.npz", 379 | "/pine/pine_54.npz", 380 | 
"/pine/pine_15.npz", 381 | "/pine/pine_72.npz", 382 | "/pine/pine_4.npz", 383 | "/pine/pine_28.npz", 384 | "/pine/pine_93.npz", 385 | "/pine/pine_87.npz", 386 | "/pine/pine_74.npz", 387 | "/pine/pine_18.npz", 388 | "/pine/pine_71.npz", 389 | "/pine/pine_92.npz", 390 | "/pine/pine_55.npz", 391 | "/pine/pine_62.npz", 392 | "/pine/pine_80.npz", 393 | "/pine/pine_68.npz", 394 | "/pine/pine_70.npz", 395 | "/pine/pine_14.npz", 396 | "/pine/pine_22.npz", 397 | "/pine/pine_95.npz", 398 | "/pine/pine_40.npz", 399 | "/pine/pine_8.npz", 400 | "/pine/pine_65.npz", 401 | "/pine/pine_37.npz", 402 | "/pine/pine_29.npz", 403 | "/eucalyptus/eucalyptus_79.npz", 404 | "/eucalyptus/eucalyptus_90.npz", 405 | "/eucalyptus/eucalyptus_23.npz", 406 | "/eucalyptus/eucalyptus_2.npz", 407 | "/eucalyptus/eucalyptus_17.npz", 408 | "/eucalyptus/eucalyptus_74.npz", 409 | "/eucalyptus/eucalyptus_93.npz", 410 | "/eucalyptus/eucalyptus_85.npz", 411 | "/eucalyptus/eucalyptus_24.npz", 412 | "/eucalyptus/eucalyptus_86.npz", 413 | "/eucalyptus/eucalyptus_82.npz", 414 | "/eucalyptus/eucalyptus_15.npz", 415 | "/eucalyptus/eucalyptus_64.npz", 416 | "/eucalyptus/eucalyptus_97.npz", 417 | "/eucalyptus/eucalyptus_43.npz", 418 | "/eucalyptus/eucalyptus_6.npz", 419 | "/eucalyptus/eucalyptus_78.npz", 420 | "/eucalyptus/eucalyptus_31.npz", 421 | "/eucalyptus/eucalyptus_29.npz", 422 | "/eucalyptus/eucalyptus_75.npz", 423 | "/eucalyptus/eucalyptus_84.npz", 424 | "/eucalyptus/eucalyptus_37.npz", 425 | "/eucalyptus/eucalyptus_18.npz", 426 | "/eucalyptus/eucalyptus_100.npz", 427 | "/eucalyptus/eucalyptus_20.npz", 428 | "/eucalyptus/eucalyptus_99.npz", 429 | "/eucalyptus/eucalyptus_62.npz", 430 | "/eucalyptus/eucalyptus_4.npz", 431 | "/eucalyptus/eucalyptus_38.npz", 432 | "/eucalyptus/eucalyptus_42.npz", 433 | "/eucalyptus/eucalyptus_72.npz", 434 | "/eucalyptus/eucalyptus_1.npz", 435 | "/eucalyptus/eucalyptus_51.npz", 436 | "/eucalyptus/eucalyptus_87.npz", 437 | "/eucalyptus/eucalyptus_80.npz", 438 | 
"/eucalyptus/eucalyptus_98.npz", 439 | "/eucalyptus/eucalyptus_91.npz", 440 | "/eucalyptus/eucalyptus_27.npz", 441 | "/eucalyptus/eucalyptus_11.npz", 442 | "/eucalyptus/eucalyptus_66.npz", 443 | "/eucalyptus/eucalyptus_63.npz", 444 | "/eucalyptus/eucalyptus_69.npz", 445 | "/eucalyptus/eucalyptus_59.npz", 446 | "/eucalyptus/eucalyptus_61.npz", 447 | "/eucalyptus/eucalyptus_13.npz", 448 | "/eucalyptus/eucalyptus_7.npz", 449 | "/eucalyptus/eucalyptus_41.npz", 450 | "/eucalyptus/eucalyptus_67.npz", 451 | "/eucalyptus/eucalyptus_77.npz", 452 | "/eucalyptus/eucalyptus_92.npz", 453 | "/eucalyptus/eucalyptus_26.npz", 454 | "/eucalyptus/eucalyptus_83.npz", 455 | "/eucalyptus/eucalyptus_95.npz", 456 | "/eucalyptus/eucalyptus_45.npz", 457 | "/eucalyptus/eucalyptus_60.npz", 458 | "/eucalyptus/eucalyptus_19.npz", 459 | "/eucalyptus/eucalyptus_49.npz", 460 | "/eucalyptus/eucalyptus_9.npz", 461 | "/eucalyptus/eucalyptus_46.npz", 462 | "/eucalyptus/eucalyptus_3.npz", 463 | "/eucalyptus/eucalyptus_22.npz", 464 | "/eucalyptus/eucalyptus_10.npz", 465 | "/eucalyptus/eucalyptus_39.npz", 466 | "/eucalyptus/eucalyptus_56.npz", 467 | "/eucalyptus/eucalyptus_71.npz", 468 | "/eucalyptus/eucalyptus_48.npz", 469 | "/eucalyptus/eucalyptus_36.npz", 470 | "/eucalyptus/eucalyptus_68.npz", 471 | "/eucalyptus/eucalyptus_30.npz", 472 | "/eucalyptus/eucalyptus_44.npz", 473 | "/eucalyptus/eucalyptus_34.npz", 474 | "/eucalyptus/eucalyptus_55.npz", 475 | "/eucalyptus/eucalyptus_8.npz", 476 | "/eucalyptus/eucalyptus_25.npz", 477 | "/eucalyptus/eucalyptus_65.npz", 478 | "/eucalyptus/eucalyptus_81.npz", 479 | "/eucalyptus/eucalyptus_32.npz", 480 | "/eucalyptus/eucalyptus_16.npz", 481 | "/eucalyptus/eucalyptus_96.npz", 482 | "/eucalyptus/eucalyptus_40.npz" 483 | ], 484 | "test": [ 485 | "/cherry/cherry_15.npz", 486 | "/cherry/cherry_90.npz", 487 | "/cherry/cherry_75.npz", 488 | "/cherry/cherry_82.npz", 489 | "/cherry/cherry_34.npz", 490 | "/cherry/cherry_4.npz", 491 | "/cherry/cherry_49.npz", 492 | 
"/cherry/cherry_28.npz", 493 | "/cherry/cherry_19.npz", 494 | "/cherry/cherry_56.npz", 495 | "/apple/apple_97.npz", 496 | "/apple/apple_53.npz", 497 | "/apple/apple_63.npz", 498 | "/apple/apple_23.npz", 499 | "/apple/apple_83.npz", 500 | "/apple/apple_92.npz", 501 | "/apple/apple_16.npz", 502 | "/apple/apple_86.npz", 503 | "/apple/apple_36.npz", 504 | "/apple/apple_12.npz", 505 | "/ginkgo/ginkgo_21.npz", 506 | "/ginkgo/ginkgo_99.npz", 507 | "/ginkgo/ginkgo_88.npz", 508 | "/ginkgo/ginkgo_77.npz", 509 | "/ginkgo/ginkgo_11.npz", 510 | "/ginkgo/ginkgo_37.npz", 511 | "/ginkgo/ginkgo_18.npz", 512 | "/ginkgo/ginkgo_7.npz", 513 | "/ginkgo/ginkgo_71.npz", 514 | "/ginkgo/ginkgo_93.npz", 515 | "/walnut/walnut_67.npz", 516 | "/walnut/walnut_7.npz", 517 | "/walnut/walnut_92.npz", 518 | "/walnut/walnut_2.npz", 519 | "/walnut/walnut_12.npz", 520 | "/walnut/walnut_27.npz", 521 | "/walnut/walnut_43.npz", 522 | "/walnut/walnut_93.npz", 523 | "/walnut/walnut_66.npz", 524 | "/walnut/walnut_26.npz", 525 | "/pine/pine_13.npz", 526 | "/pine/pine_52.npz", 527 | "/pine/pine_2.npz", 528 | "/pine/pine_27.npz", 529 | "/pine/pine_46.npz", 530 | "/pine/pine_76.npz", 531 | "/pine/pine_30.npz", 532 | "/pine/pine_67.npz", 533 | "/pine/pine_21.npz", 534 | "/pine/pine_11.npz", 535 | "/eucalyptus/eucalyptus_28.npz", 536 | "/eucalyptus/eucalyptus_5.npz", 537 | "/eucalyptus/eucalyptus_94.npz", 538 | "/eucalyptus/eucalyptus_89.npz", 539 | "/eucalyptus/eucalyptus_54.npz", 540 | "/eucalyptus/eucalyptus_53.npz", 541 | "/eucalyptus/eucalyptus_58.npz", 542 | "/eucalyptus/eucalyptus_73.npz", 543 | "/eucalyptus/eucalyptus_35.npz", 544 | "/eucalyptus/eucalyptus_70.npz" 545 | ], 546 | "validation": [ 547 | "/cherry/cherry_42.npz", 548 | "/cherry/cherry_13.npz", 549 | "/cherry/cherry_89.npz", 550 | "/cherry/cherry_18.npz", 551 | "/cherry/cherry_7.npz", 552 | "/cherry/cherry_73.npz", 553 | "/cherry/cherry_98.npz", 554 | "/cherry/cherry_96.npz", 555 | "/cherry/cherry_6.npz", 556 | "/cherry/cherry_94.npz", 557 | 
"/apple/apple_39.npz", 558 | "/apple/apple_56.npz", 559 | "/apple/apple_96.npz", 560 | "/apple/apple_17.npz", 561 | "/apple/apple_11.npz", 562 | "/apple/apple_5.npz", 563 | "/apple/apple_64.npz", 564 | "/apple/apple_42.npz", 565 | "/apple/apple_77.npz", 566 | "/apple/apple_73.npz", 567 | "/ginkgo/ginkgo_64.npz", 568 | "/ginkgo/ginkgo_54.npz", 569 | "/ginkgo/ginkgo_23.npz", 570 | "/ginkgo/ginkgo_94.npz", 571 | "/ginkgo/ginkgo_36.npz", 572 | "/ginkgo/ginkgo_56.npz", 573 | "/ginkgo/ginkgo_86.npz", 574 | "/ginkgo/ginkgo_61.npz", 575 | "/ginkgo/ginkgo_43.npz", 576 | "/ginkgo/ginkgo_62.npz", 577 | "/walnut/walnut_14.npz", 578 | "/walnut/walnut_97.npz", 579 | "/walnut/walnut_74.npz", 580 | "/walnut/walnut_84.npz", 581 | "/walnut/walnut_47.npz", 582 | "/walnut/walnut_15.npz", 583 | "/walnut/walnut_79.npz", 584 | "/walnut/walnut_68.npz", 585 | "/walnut/walnut_98.npz", 586 | "/walnut/walnut_91.npz", 587 | "/pine/pine_84.npz", 588 | "/pine/pine_7.npz", 589 | "/pine/pine_83.npz", 590 | "/pine/pine_43.npz", 591 | "/pine/pine_45.npz", 592 | "/pine/pine_35.npz", 593 | "/pine/pine_99.npz", 594 | "/pine/pine_9.npz", 595 | "/pine/pine_33.npz", 596 | "/pine/pine_10.npz", 597 | "/eucalyptus/eucalyptus_12.npz", 598 | "/eucalyptus/eucalyptus_14.npz", 599 | "/eucalyptus/eucalyptus_76.npz", 600 | "/eucalyptus/eucalyptus_33.npz", 601 | "/eucalyptus/eucalyptus_50.npz", 602 | "/eucalyptus/eucalyptus_47.npz", 603 | "/eucalyptus/eucalyptus_57.npz", 604 | "/eucalyptus/eucalyptus_21.npz", 605 | "/eucalyptus/eucalyptus_88.npz", 606 | "/eucalyptus/eucalyptus_52.npz" 607 | ] 608 | } --------------------------------------------------------------------------------