├── smart_tree
│ ├── conf
│ │ ├── __init__.py
│ │ ├── pipeline.yaml
│ │ ├── training.yaml
│ │ └── training-split.json
│ ├── model
│ │ ├── __init__.py
│ │ ├── weights
│ │ │ ├── __init__.py
│ │ │ ├── noble-elevator-58_model.pt
│ │ │ ├── peach-forest-65_model.pt
│ │ │ ├── peach-forest-65_model_weights.pt
│ │ │ └── noble-elevator-58_model_weights.pt
│ │ ├── render.py
│ │ ├── tracker.py
│ │ ├── helper.py
│ │ ├── model.py
│ │ ├── fp16.py
│ │ ├── sparse.py
│ │ ├── loss.py
│ │ ├── model_inference.py
│ │ ├── train.py
│ │ └── model_blocks.py
│ ├── util
│ │ ├── __init__.py
│ │ ├── misc.py
│ │ ├── file.py
│ │ ├── queries.py
│ │ └── maths.py
│ ├── data_types
│ │ ├── __init__.py
│ │ ├── graph.py
│ │ ├── branch.py
│ │ ├── tube.py
│ │ ├── tree.py
│ │ └── cloud.py
│ ├── dataset
│ │ ├── __init__.py
│ │ ├── augmentations.py
│ │ └── dataset.py
│ ├── scripts
│ │ ├── __init__.py
│ │ ├── clean_up.sh
│ │ ├── laz2ply.py
│ │ ├── bench_dataloader.py
│ │ ├── vis_dataloader.py
│ │ ├── split-data.py
│ │ └── view_npz.py
│ ├── skeleton
│ │ ├── __init__.py
│ │ ├── filter.py
│ │ ├── connection.py
│ │ ├── shortest_path.py
│ │ ├── skeletonize.py
│ │ ├── graph.py
│ │ └── path.py
│ ├── o3d_abstractions
│ │ ├── __init__.py
│ │ ├── visualizer.py
│ │ ├── camera.py
│ │ └── geometries.py
│ ├── __init__.py
│ ├── tests
│ │ ├── test-dataloader.yaml
│ │ └── dataloader.py
│ ├── cli.py
│ └── pipeline.py
├── images
│ ├── botanic-pcd.png
│ ├── botanic-skeleton.png
│ └── botanic-branch-mesh.png
├── tests
│ └── type_checks.py
├── create-env.sh
├── pyproject.toml
├── LICENSE
├── .gitignore
└── README.md
/smart_tree/conf/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/smart_tree/model/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/smart_tree/util/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/smart_tree/data_types/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/smart_tree/dataset/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/smart_tree/scripts/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/smart_tree/skeleton/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/smart_tree/model/weights/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/smart_tree/o3d_abstractions/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/smart_tree/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "1.0.0"
2 |
--------------------------------------------------------------------------------
/smart_tree/scripts/clean_up.sh:
--------------------------------------------------------------------------------
1 | isort .
2 | pycln .
--------------------------------------------------------------------------------
/images/botanic-pcd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uc-vision/smart-tree/HEAD/images/botanic-pcd.png
--------------------------------------------------------------------------------
/images/botanic-skeleton.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uc-vision/smart-tree/HEAD/images/botanic-skeleton.png
--------------------------------------------------------------------------------
/images/botanic-branch-mesh.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uc-vision/smart-tree/HEAD/images/botanic-branch-mesh.png
--------------------------------------------------------------------------------
/smart_tree/model/weights/noble-elevator-58_model.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uc-vision/smart-tree/HEAD/smart_tree/model/weights/noble-elevator-58_model.pt
--------------------------------------------------------------------------------
/smart_tree/model/weights/peach-forest-65_model.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uc-vision/smart-tree/HEAD/smart_tree/model/weights/peach-forest-65_model.pt
--------------------------------------------------------------------------------
/smart_tree/model/weights/peach-forest-65_model_weights.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uc-vision/smart-tree/HEAD/smart_tree/model/weights/peach-forest-65_model_weights.pt
--------------------------------------------------------------------------------
/smart_tree/model/weights/noble-elevator-58_model_weights.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uc-vision/smart-tree/HEAD/smart_tree/model/weights/noble-elevator-58_model_weights.pt
--------------------------------------------------------------------------------
/smart_tree/skeleton/filter.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from .graph import knn
4 |
5 |
6 | def outlier_removal(points, radii, nb_points=4):
7 | idxs, dists, _ = knn(points, points, K=nb_points, r=torch.max(radii).item())
8 |
9 | keep = (dists < radii) & (idxs != -1)
10 |
11 | return keep.sum(1) == nb_points
12 |
--------------------------------------------------------------------------------
/tests/type_checks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchtyping import patch_typeguard
3 |
4 | from smart_tree.data_types.branch import BranchSkeleton
5 |
6 | patch_typeguard() # use before @typechecked
7 |
8 |
9 | def test_branch_skeleton_type_error():
10 | xyz = torch.rand(100)
11 | radii = torch.rand(100)
12 |
13 | BranchSkeleton(0, 0, xyz, radii)
14 |
15 |
16 | if __name__ == "__main__":
17 | test_branch_skeleton_type_error()
18 |
--------------------------------------------------------------------------------
/smart_tree/scripts/laz2ply.py:
--------------------------------------------------------------------------------
1 | import laspy
2 | import numpy as np
3 | import open3d as o3d
4 |
5 |
6 | def las_to_ply(input_las_file, output_ply_file):
7 | with laspy.open(input_las_file) as fh:
8 | las = fh.read()
9 | xyz = np.column_stack((las.x, las.y, las.z))
10 |
11 | cloud = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(xyz))
12 | o3d.io.write_point_cloud(output_ply_file, cloud)
13 |
14 |
15 | if __name__ == "__main__":
16 | input_las_file = "/csse/users/hdo27/Desktop/ChCh_Hovermap_tree.laz"
17 | output_ply_file = "output.ply"
18 |
19 | las_to_ply(input_las_file, output_ply_file)
20 |
--------------------------------------------------------------------------------
/smart_tree/tests/test-dataloader.yaml:
--------------------------------------------------------------------------------
1 | dataset:
2 | _target_: smart_tree.dataset.dataset.TreeDataset
3 | voxel_size: 0.01
4 | json_path: /local/uc-vision/smart-tree/smart_tree/conf/tree-split.json
5 | directory: /local/uc-vision/dataset/branches/
6 | blocking: True
7 | transform: True
8 | block_size: 4
9 | buffer_size: 0.4
10 |
11 | data_loader:
12 | _target_: torch.utils.data.DataLoader
13 | batch_size: 1
14 | drop_last: True
15 | pin_memory: True
16 | num_workers: 0
17 | shuffle: True
18 | #prefetch_factor: None
19 | collate_fn:
20 | _target_: smart_tree.model.sparse.batch_collate
21 | _partial_: True
--------------------------------------------------------------------------------
/smart_tree/tests/dataloader.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import hydra
4 | from hydra.utils import instantiate
5 | from omegaconf import DictConfig
6 | from tqdm import tqdm
7 |
8 |
9 |
10 | @hydra.main(
11 | version_base=None,
12 | config_path=".",
13 | config_name="test-dataloader",
14 | )
15 | def main(cfg: DictConfig):
16 | train_dataloader = instantiate(
17 | cfg.data_loader, dataset=instantiate(cfg.dataset, mode="train")
18 | )
19 |
20 | start_time = time.time()
21 | for data in tqdm(train_dataloader):
22 | pass
23 |
24 | print(time.time() - start_time)
25 |
26 |
27 | if __name__ == "__main__":
28 | main()
29 |
--------------------------------------------------------------------------------
/create-env.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | eval "$(conda shell.bash hook)"
3 |
4 | conda create -n smart-tree python=3.10 &&
5 | conda run -n smart-tree conda install -c "nvidia/label/cuda-11.8.0" cuda-toolkit &&
6 | conda run -n smart-tree conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia &&
7 | conda run -n smart-tree pip install -e . &&
8 |
9 | echo Installing FRNN
10 | git clone --recursive https://github.com/lxxue/FRNN.git
11 | conda run -n smart-tree pip install FRNN/external/prefix_sum/. &&
12 | conda run -n smart-tree pip install -e FRNN/. &&
13 |
14 | conda run -n smart-tree conda install -c rapidsai -c conda-forge -c nvidia cudf=23.02 cugraph=23.02 python=3.10 cuda-version=11.8 --solver=libmamba
15 |
--------------------------------------------------------------------------------
/smart_tree/cli.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 |
4 | import hydra
5 | from hydra.utils import instantiate
6 | from omegaconf import DictConfig
7 |
8 |
9 |
10 | @hydra.main(
11 | version_base=None,
12 | config_path="conf",
13 | config_name="pipeline",
14 | )
15 | def main(cfg: DictConfig):
16 | pipeline = instantiate(cfg.pipeline)
17 |
18 | if "path" in dict(cfg):
19 | pipeline.process_cloud(Path(cfg.path))
20 |
21 | elif "directory" in dict(cfg):
22 | for p in os.listdir(cfg.directory):
23 | pipeline.process_cloud(Path(f"{cfg.directory}/{p}"))
24 |
25 | else:
26 | print("Please supply a path or directory to point clouds.")
27 |
28 |
29 | if __name__ == "__main__":
30 | main()
31 |
--------------------------------------------------------------------------------
/smart_tree/model/render.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 |
4 | import wandb
5 |
6 |
7 | def render_cloud(
8 | renderer,
9 | labelled_cloud,
10 | camera_position=[1, 0, 0],
11 | camera_up=[0, 1, 0],
12 | ):
13 | segmented_img = renderer.capture(
14 | [labelled_cloud.to_o3d_seg_cld()],
15 | camera_position,
16 | camera_up,
17 | )
18 |
19 | cld_img = renderer.capture(
20 | [labelled_cloud.to_o3d_cld()],
21 | camera_position,
22 | camera_up,
23 | )
24 |
25 | projected_img = renderer.capture(
26 | [labelled_cloud.to_o3d_medial_vectors()],
27 | camera_position,
28 | camera_up,
29 | )
30 |
31 | return [
32 | np.asarray(cld_img),
33 | np.asarray(segmented_img),
34 | np.asarray(projected_img),
35 | ]
36 |
37 |
38 | def log_images(wandb_run, name, images, epoch):
39 | wandb_run.log({f"{name}": [wandb.Image(img) for img in images]})
40 |
--------------------------------------------------------------------------------
/smart_tree/scripts/bench_dataloader.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import time
3 |
4 | import hydra
5 | import torch
6 | from hydra.utils import instantiate
7 | from omegaconf import DictConfig
8 | from tqdm import tqdm
9 |
10 | from smart_tree.model.helper import get_batch
11 |
12 |
13 | @hydra.main(
14 | version_base=None,
15 | config_path="../conf",
16 | config_name="vine-dataset",
17 | )
18 | def main(cfg: DictConfig):
19 | torch.manual_seed(42)
20 | torch.cuda.manual_seed_all(42)
21 | log = logging.getLogger(__name__)
22 |
23 | cfg = cfg.training
24 | torch.multiprocessing.set_start_method("spawn")
25 |
26 | train_loader = instantiate(cfg.train_data_loader)
27 | log.info(f"Train Dataset Size: {len(train_loader.dataset)}")
28 |
29 | while True:
30 | start = time.time()
31 |
32 | batches = get_batch(train_loader, device="cpu")
33 | for sparse_input, targets, mask, filenames in tqdm(batches):
34 | pass
35 |
36 | print(time.time() - start)
37 |
38 |
39 | if __name__ == "__main__":
40 | main()
41 |
--------------------------------------------------------------------------------
/smart_tree/o3d_abstractions/visualizer.py:
--------------------------------------------------------------------------------
1 | from dataclasses import asdict, dataclass
2 | from typing import List, Sequence, Union
3 |
4 | import open3d as o3d
5 |
6 |
7 |
8 | @dataclass
9 | class ViewerItem:
10 | name: str
11 | geometry: o3d.geometry.Geometry
12 | is_visible: bool = True
13 |
14 |
15 | def o3d_viewer(
16 | items: Union[Sequence[ViewerItem], List[o3d.geometry.Geometry]], line_width=1
17 | ):
18 | mat = o3d.visualization.rendering.MaterialRecord()
19 | mat.shader = "defaultLit"
20 |
21 | line_mat = o3d.visualization.rendering.MaterialRecord()
22 | line_mat.shader = "unlitLine"
23 | line_mat.line_width = line_width
24 |
25 | if isinstance(items[0], o3d.geometry.Geometry):
26 | items = [ViewerItem(f"{i}", item) for i, item in enumerate(items)]
27 |
28 | def material(item):
29 | return line_mat if isinstance(item.geometry, o3d.geometry.LineSet) else mat
30 |
31 | geometries = [dict(**asdict(item), material=material(item)) for item in items]
32 |
33 | o3d.visualization.draw(geometries, line_width=line_width)
34 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools", "setuptools-scm"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "smart-tree"
7 | version = "1.0.0"
8 | authors = [
9 | {name = "Harry Dobbs", email = "harrydobbs87@gmail.com"},
10 | ]
11 | description = "Neural Network Point Cloud Tree Skeletonization"
12 | readme = "README.md"
13 | requires-python = ">=3.10"
14 | license = {text = "MIT"}
15 | dependencies = [
16 | 'numpy',
17 | 'open3d',
18 | 'hydra-core>=1.2.0',
19 | #'click',
20 | #'oauthlib',
21 | 'spconv-cu117',
22 | 'wandb',
23 | 'cmapy',
24 | #'plyfile',
25 | #'torch',
26 | 'py_structs',
27 | 'torchtyping',
28 | 'beartype',
29 | 'typeguard==2.11.1'
30 | ]
31 |
32 | [tool.setuptools.packages]
33 | find = {} # Scan the project directory with the default parameters
34 |
35 | [project.scripts]
36 | run-smart-tree = "smart_tree.cli:main"
37 | train-smart-tree = "smart_tree.model.train:main"
38 | view-npz = "smart_tree.scripts.view_npz:main"
39 | view-pcd = "smart_tree.scripts.view_pcd:main"
40 |
41 | #[project.scripts]
42 |
43 |
44 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 UC Vision
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/smart_tree/conf/pipeline.yaml:
--------------------------------------------------------------------------------
1 | pipeline:
2 | _target_: smart_tree.pipeline.Pipeline
3 |
4 | preprocessing:
5 | _target_: smart_tree.dataset.augmentations.AugmentationPipeline
6 | augmentations:
7 | - _target_: smart_tree.dataset.augmentations.CentreCloud
8 |
9 | model_inference:
10 | _target_: smart_tree.model.model_inference.ModelInference
11 | model_path: smart_tree/model/weights/noble-elevator-58_model.pt
12 | weights_path: smart_tree/model/weights/noble-elevator-58_model_weights.pt
13 | voxel_size: 0.01
14 | block_size: 4
15 | buffer_size: 0.4
16 | num_workers : 8
17 | batch_size : 4
18 |
19 | skeletonizer:
20 | _target_: smart_tree.skeleton.skeletonize.Skeletonizer
21 | K: 16
22 | min_connection_length: 0.02
23 | minimum_graph_vertices: 32
24 |
25 | view_model_output : False
26 | view_skeletons : True
27 | save_path: /
28 | save_outputs : False
29 |
30 | branch_classes: [0]
31 | cmap:
32 | - [0.450, 0.325, 0.164] # Trunk
33 | - [0.541, 0.670, 0.164] # Foliage
34 |
35 | repair_skeletons : True
36 | smooth_skeletons : True
37 | smooth_kernel_size: 11 # Needs to be odd
38 | prune_skeletons : True
39 | min_skeleton_radius : 0.01
40 | min_skeleton_length : 0.02
--------------------------------------------------------------------------------
/smart_tree/model/tracker.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import wandb
4 |
5 |
6 | class Tracker:
7 | def __init__(self):
8 | self.running_epoch_radius_loss = []
9 | self.running_epoch_direction_loss = []
10 | self.running_epoch_class_loss = []
11 |
12 | def update(self, loss_dict: dict):
13 | self.running_epoch_radius_loss.append(loss_dict["radius"].item())
14 | self.running_epoch_direction_loss.append(loss_dict["direction"].item())
15 | self.running_epoch_class_loss.append(loss_dict["class_l"].item())
16 |
17 | @property
18 | def radius_loss(self):
19 | return np.mean(self.running_epoch_radius_loss)
20 |
21 | @property
22 | def direction_loss(self):
23 | return np.mean(self.running_epoch_direction_loss)
24 |
25 | @property
26 | def class_loss(self):
27 | return np.mean(self.running_epoch_class_loss)
28 |
29 | @property
30 | def total_loss(self):
31 | return self.radius_loss + self.direction_loss + self.class_loss
32 |
33 | def log(self, name, epoch):
34 | wandb.log(
35 | {
36 | f"{name} Total Loss": self.total_loss,
37 | f"{name} Radius Loss": self.radius_loss,
38 | f"{name} Direction Loss": self.direction_loss,
39 | f"{name} Class Loss": self.class_loss,
40 | },
41 | epoch,
42 | )
43 |
--------------------------------------------------------------------------------
/smart_tree/skeleton/connection.py:
--------------------------------------------------------------------------------
1 | from smart_tree.data_types.tree import (DisjointTreeSkeleton, TreeSkeleton,
2 | connect_skeletons)
3 | from smart_tree.o3d_abstractions.visualizer import o3d_viewer
4 |
5 | if __name__ == "__main__":
6 | disjoint_skeleton = DisjointTreeSkeleton.from_pickle(
7 | "/local/smart-tree/test_data/disjoint_skeleton.pkl"
8 | )
9 |
10 | # disjoint_skeleton.prune(min_radius=0.01, min_length=0.08).smooth(kernel_size=11)
11 | # Sort skeletons by total length
12 | skeletons_sorted = sorted(
13 | disjoint_skeleton.skeletons,
14 | key=lambda x: x.length,
15 | reverse=True,
16 | )
17 |
18 | # skeletons_sorted[0].view()
19 |
20 | skel = connect_skeletons(skeletons_sorted[0], 0, 0, skeletons_sorted[1], 0, 0)
21 |
22 | skel.view()
23 |
24 | quit()
25 |
26 | final_skeleton = TreeSkeleton(0, skeletons_sorted[0].branches)
27 |
28 | for skeleton in skeletons_sorted[1:]:
29 | branch = skeleton.branches[skeleton.key_branch_with_biggest_radius]
30 |
31 | # get the point that has the biggest radius ....
32 | # get the closest point on the skeleton to that point ...
33 | # connect the two points ...
34 |
35 | print(skeleton.length)
36 |
37 | o3d_viewer(
38 | [final_skeleton.to_o3d_tube(), branch.to_o3d_tube(), skeleton.to_o3d_tube()]
39 | )
40 |
--------------------------------------------------------------------------------
/smart_tree/scripts/vis_dataloader.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import hydra
4 | import numpy as np
5 | import open3d as o3d
6 | import torch
7 | from hydra.utils import instantiate
8 | from omegaconf import DictConfig
9 | from open3d_vis import render
10 | from tqdm import tqdm
11 |
12 | from smart_tree.model.helper import get_batch
13 |
14 |
15 | @hydra.main(
16 | version_base=None,
17 | config_path="../conf",
18 | config_name="vine-dataset",
19 | )
20 | def main(cfg: DictConfig):
21 | torch.manual_seed(42)
22 | torch.cuda.manual_seed_all(42)
23 | log = logging.getLogger(__name__)
24 |
25 | cfg = cfg.training
26 | torch.multiprocessing.set_start_method("spawn")
27 |
28 | train_loader = instantiate(cfg.train_data_loader)
29 | log.info(f"Train Dataset Size: {len(train_loader.dataset)}")
30 |
31 | batches = get_batch(train_loader, device="cpu")
32 | cmap = torch.from_numpy(np.array(cfg.cmap))
33 |
34 | for sparse_input, targets, mask, filenames in tqdm(batches):
35 | cloud_ids = sparse_input.indices[:, 0]
36 | coords = sparse_input.indices[:, 1:4]
37 |
38 | class_l = targets[:, -1].to(dtype=torch.long)
39 |
40 | for i, filename in enumerate(filenames):
41 | print("Filename:", filename)
42 | mask = cloud_ids == i
43 | labels = class_l[mask]
44 |
45 | xyz = coords[mask]
46 | class_colors = cmap[labels]
47 |
48 | boxes = render.boxes(xyz, xyz + 1, class_colors)
49 | o3d.visualization.draw([boxes])
50 |
51 |
52 | if __name__ == "__main__":
53 | main()
54 |
--------------------------------------------------------------------------------
/smart_tree/data_types/graph.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import List
3 |
4 | import cugraph
5 | import cupy
6 | import open3d as o3d
7 | import torch
8 | from cudf import DataFrame
9 | from torchtyping import TensorType
10 | from tqdm import tqdm
11 |
12 | from ..o3d_abstractions.geometries import o3d_line_set
13 |
14 |
15 | @dataclass
16 | class Graph:
17 | vertices: TensorType["N", 3]
18 | edges: TensorType["N", 2]
19 | edge_weights: TensorType["N", 1]
20 |
21 | def to_o3d_lineset(self, colour=(1, 0, 0)) -> o3d.geometry.LineSet:
22 | graph_cpu = self.to_device(torch.device("cpu"))
23 | return o3d_line_set(graph_cpu.vertices, graph_cpu.edges, colour=colour)
24 |
25 | def to_device(self, device: torch.device):
26 | return Graph(
27 | self.vertices.to(device),
28 | self.edges.to(device),
29 | self.edge_weights.to(device),
30 | )
31 |
32 | def connected_cugraph_components(self, minimum_vertices=10) -> List[cugraph.Graph]:
33 |
34 | g = cuda_graph(self.edges, self.edge_weights)
35 |
36 | df = cugraph.connected_components(g)
37 |
38 | components = []
39 | for label in tqdm(
40 | df["labels"].unique().to_pandas(),
41 | desc="Finding Connected Components",
42 | leave=False,
43 | ):
44 | subgraph_vertices = df[df["labels"] == label]["vertex"]
45 |
46 | if subgraph_vertices.count() < minimum_vertices:
47 | continue
48 |
49 | components.append(cugraph.subgraph(g, subgraph_vertices))
50 |
51 | return sorted(components, key=lambda graph: len(graph.nodes()), reverse=True)
52 |
53 |
54 | def cuda_graph(edges, edge_weights, renumber=False):
55 |
56 | edges = cupy.asarray(edges)
57 | edge_weights = cupy.asarray(edge_weights)
58 |
59 | d = DataFrame()
60 | d["source"] = edges[:, 0]
61 | d["destination"] = edges[:, 1]
62 | d["weights"] = edge_weights
63 | g = cugraph.Graph(directed=False)
64 | g.from_cudf_edgelist(d, edge_attr="weights", renumber=renumber)
65 |
66 | return g
67 |
--------------------------------------------------------------------------------
/smart_tree/data_types/branch.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from dataclasses import dataclass
4 | from typing import List, Optional
5 |
6 | import numpy as np
7 | import open3d as o3d
8 | import torch
9 | from torchtyping import TensorType
10 | from typeguard import typechecked
11 |
12 | from smart_tree.o3d_abstractions.geometries import o3d_path, o3d_tube_mesh
13 |
14 | from .tube import Tube
15 |
16 |
17 | @typechecked
18 | @dataclass
19 | class BranchSkeleton:
20 | _id: int
21 | parent_id: int
22 | xyz: TensorType["N", 3]
23 | radii: TensorType["N", 1]
24 | child_id: Optional[int] = None
25 |
26 | def __post_init__(self) -> None:
27 | self.colour = np.random.rand(3)
28 |
29 | def __len__(self) -> np.array:
30 | return self.xyz.shape[0]
31 |
32 | def __str__(self):
33 | return f" ID: {self._id} \
34 | Points: {self.xyz} \
35 | Radii {self.radii}"
36 |
37 | def to_o3d_lineset(self, colour=(0, 0, 0)) -> o3d.geometry.LineSet:
38 | return o3d_path(self.xyz, colour)
39 |
40 | def to_o3d_tube(self) -> o3d.geometry.TriangleMesh:
41 | return o3d_tube_mesh(self.xyz.numpy(), self.radii.numpy(), self.colour)
42 |
43 | def to_tubes(self, colour=(1, 0, 0)) -> List[Tube]:
44 | a_, b_, r1_, r2_ = (
45 | self.xyz[:-1],
46 | self.xyz[1:],
47 | self.radii[:-1],
48 | self.radii[1:],
49 | )
50 | return [Tube(a, b, r1, r2) for a, b, r1, r2 in zip(a_, b_, r1_, r2_)]
51 |
52 | def filter(self, mask) -> BranchSkeleton:
53 | return BranchSkeleton(
54 | self._id,
55 | self.parent_id,
56 | self.xyz[mask],
57 | self.radii[mask],
58 | self.child_id,
59 | )
60 |
61 | @property
62 | def length(self) -> TensorType[1]:
63 | return (self.xyz[1:] - self.xyz[:-1]).norm(dim=1).sum()
64 |
65 | @property
66 | def initial_radius(self) -> TensorType[1]:
67 | return torch.max(self.radii[0], self.radii[-1])
68 |
69 | @property
70 | def biggest_radius_idx(self) -> TensorType[1]:
71 | return torch.argmax(self.radii)
72 |
73 | @property
74 | def biggest_radius(self) -> TensorType[1]:
75 | return torch.max(self.radii)
76 |
--------------------------------------------------------------------------------
/smart_tree/skeleton/shortest_path.py:
--------------------------------------------------------------------------------
1 | import cugraph
2 | import cupy
3 | import numpy as np
4 | import torch
5 | from cudf import DataFrame
6 |
7 |
8 | def cudf_edgelist_to_numpy(edge_list):
9 | return np.vstack((edge_list["src"].to_numpy(), edge_list["dst"].to_numpy())).T
10 |
11 |
12 | def shortest_paths(root, edges, edge_weights, renumber=True):
13 | device = edges.device
14 | g = edge_graph(edges, edge_weights, renumber=renumber)
15 | r = cugraph.sssp(g, source=root)
16 |
17 | return (
18 | torch.as_tensor(r["vertex"], device=device).long(),
19 | torch.as_tensor(r["predecessor"], device=device).long(),
20 | torch.as_tensor(r["distance"], device=device),
21 | )
22 |
23 |
24 | def graph_shortest_paths(root, graph, device):
25 | # device = edges.device
26 |
27 | r = cugraph.sssp(graph, source=root)
28 | return (
29 | torch.as_tensor(r["predecessor"], device=device).long(),
30 | torch.as_tensor(r["distance"], device=device),
31 | )
32 |
33 |
34 | def pred_graph(preds, points):
35 | n = preds.shape[0]
36 | valid = preds >= 0
37 |
38 | dists = torch.norm(points - points[torch.clamp(preds, 0)], dim=1)
39 | dists[~valid] = 0.0
40 |
41 | edges = torch.stack([torch.arange(0, n, device=preds.device), preds], dim=1)
42 | edges[~valid, 1] = edges[~valid, 0]
43 | return edge_graph(edges, dists)
44 |
45 |
46 | def pred_graph(verts, preds, points):
47 | n = preds.shape[0]
48 | valid = preds >= 0
49 |
50 | dists = torch.norm(points[verts] - points[torch.clamp(preds, 0)], dim=1)
51 | dists[~valid] = 0.0
52 |
53 | edges = torch.stack([torch.arange(0, n, device=preds.device), preds], dim=1)
54 | edges[~valid, 1] = edges[~valid, 0]
55 | return edge_graph(edges, dists)
56 |
57 |
58 | def euclidean_distances(root, points, preds):
59 | g = pred_graph(preds, points)
60 | r = cugraph.sssp(g, source=root)
61 | return torch.as_tensor(r["distance"], device=points.device)
62 |
63 |
64 | def edge_graph(edges, edge_weights, renumber=False):
65 | d = DataFrame()
66 | edges = cupy.asarray(edges)
67 |
68 | d["source"] = edges[:, 0]
69 | d["destination"] = edges[:, 1]
70 | d["weights"] = cupy.asarray(edge_weights)
71 |
72 | g = cugraph.Graph(directed=False)
73 | g.from_cudf_edgelist(d, edge_attr="weights", renumber=renumber)
74 | return g
75 |
--------------------------------------------------------------------------------
/smart_tree/model/helper.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | from pathlib import Path
3 | from typing import List
4 |
5 | import torch
6 | from py_structs.torch import map_tensors
7 |
8 | from smart_tree.data_types.cloud import Cloud
9 |
10 | from .sparse import sparse_from_batch
11 |
12 |
13 | def get_batch(dataloader, device, fp_16=False):
14 | for (feats, target_feats), coords, mask, filenames in dataloader:
15 | if fp_16:
16 | feats = feats.half()
17 | target_feats = target_feats.half()
18 | coords = coords.half()
19 |
20 | sparse_input = sparse_from_batch(
21 | feats,
22 | coords,
23 | device=device,
24 | )
25 | targets = map_tensors(
26 | target_feats,
27 | partial(
28 | torch.Tensor.to,
29 | device=device,
30 | ),
31 | )
32 |
33 | yield sparse_input, targets, mask, filenames
34 |
35 |
36 | def model_output_to_labelled_clds(
37 | sparse_input,
38 | model_output,
39 | cmap,
40 | filenames,
41 | ) -> List[Cloud]:
42 | return to_labelled_clds(
43 | sparse_input.indices[:, 0],
44 | sparse_input.features[:, :3],
45 | sparse_input.features[:, 3:6],
46 | model_output,
47 | cmap,
48 | filenames,
49 | )
50 |
51 |
52 | def split_outputs(features, mask):
53 | radii = torch.exp(features["radius"][mask])
54 | direction = features["direction"][mask]
55 | class_l = torch.argmax(features["class_l"], dim=1)[mask]
56 |
57 | return radii, direction, class_l
58 |
59 |
60 | def to_labelled_clds(
61 | cloud_ids,
62 | coords,
63 | rgb,
64 | model_output,
65 | cmap,
66 | filenames,
67 | ) -> List[Cloud]:
68 | num_clouds = cloud_ids.max() + 1
69 | clouds = []
70 |
71 | # assert rgb.shape[1] > 0
72 |
73 | for i in range(num_clouds):
74 | mask = cloud_ids == i
75 | xyz = coords[mask]
76 | rgb = torch.rand(xyz.shape) # rgb[mask]
77 |
78 | radii, direction, class_l = split_outputs(model_output, mask)
79 |
80 | labelled_cloud = Cloud(
81 | xyz=xyz,
82 | rgb=rgb,
83 | medial_vector=radii * direction,
84 | class_l=class_l,
85 | filename=Path(filenames[i]),
86 | )
87 |
88 | clouds.append(labelled_cloud.to_device(torch.device("cpu")))
89 |
90 | return clouds
91 |
--------------------------------------------------------------------------------
/smart_tree/model/model.py:
--------------------------------------------------------------------------------
1 |
2 | import spconv.pytorch as spconv
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from smart_tree.model.model_blocks import MLP, SubMConvBlock, UBlock
7 |
8 |
9 |
10 | class Smart_Tree(nn.Module):
11 | def __init__(
12 | self,
13 | input_channels,
14 | unet_planes,
15 | radius_fc_planes,
16 | direction_fc_planes,
17 | class_fc_planes,
18 | bias=False,
19 | algo=spconv.ConvAlgo.Native,
20 | ):
21 | super().__init__()
22 |
23 | norm_fn = nn.BatchNorm1d
24 | # functools.partial(
25 | # nn.BatchNorm1d,
26 | # eps=1e-4,
27 | # momentum=0.1,
28 | # )
29 | activation_fn = nn.ReLU
30 |
31 | self.input_conv = SubMConvBlock(
32 | input_channels=input_channels,
33 | output_channels=unet_planes[0],
34 | kernel_size=1,
35 | padding=1,
36 | norm_fn=norm_fn,
37 | activation_fn=activation_fn,
38 | )
39 |
40 | self.UNet = UBlock(
41 | unet_planes,
42 | norm_fn,
43 | activation_fn,
44 | key_id=1,
45 | algo=algo,
46 | )
47 |
48 | # Three Heads...
49 | self.radius_head = MLP(
50 | radius_fc_planes,
51 | norm_fn,
52 | activation_fn,
53 | bias=True,
54 | )
55 | self.direction_head = MLP(
56 | direction_fc_planes,
57 | norm_fn,
58 | activation_fn,
59 | bias=True,
60 | )
61 | self.class_head = MLP(
62 | class_fc_planes,
63 | norm_fn,
64 | activation_fn,
65 | bias=True,
66 | )
67 |
68 | self.apply(self.set_bn_init)
69 |
70 | @staticmethod
71 | def set_bn_init(m):
72 | classname = m.__class__.__name__
73 | if classname.find("BatchNorm") != -1:
74 | m.weight.data.fill_(1.0)
75 | m.bias.data.fill_(0.0)
76 |
77 | def forward(self, input):
78 | predictions = {}
79 |
80 | x = self.input_conv(input)
81 | unet_out = self.UNet(x)
82 |
83 | predictions["radius"] = self.radius_head(unet_out).features
84 | predictions["direction"] = F.normalize(self.direction_head(unet_out).features)
85 | predictions["class_l"] = self.class_head(unet_out).features
86 |
87 | return predictions
88 |
--------------------------------------------------------------------------------
/smart_tree/data_types/tube.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import List
3 |
4 | import numpy as np
5 | import torch
6 | from torchtyping import TensorType
7 |
8 |
9 | @dataclass
10 | class Tube:
11 | a: TensorType["N", 3] # Start Point 3
12 | b: TensorType["N", 3] # End Point 3
13 | r1: float # Start Radius
14 | r2: float # End Radius
15 |
16 | def to_torch(self, device=torch.device("cuda:0")):
17 | self.a = torch.from_numpy(self.a).float().to(device)
18 | self.b = torch.from_numpy(self.b).float().to(device)
19 | self.r1 = torch.from_numpy(self.r1).float().to(device)
20 | self.r2 = torch.from_numpy(self.r2).float().to(device)
21 |
22 | def to_numpy(self):
23 | self.a = self.a.cpu().detach().numpy()
24 | self.b = self.b.cpu().detach().numpy()
25 | self.r1 = self.r1.cpu().detach().numpy()
26 | self.r2 = self.r2.cpu().detach().numpy()
27 |
28 |
29 | @dataclass
30 | class CollatedTube:
31 | a: TensorType["N", 3] # Nx3
32 | b: TensorType["N", 3] # Nx3
33 | r1: TensorType["N", 1] # N
34 | r2: TensorType["N", 1] # N
35 |
36 | def to_gpu(self, device=torch.device("cuda")):
37 | self.a = self.a.to(device)
38 | self.b = self.b.to(device)
39 | self.r1 = self.r1.to(device)
40 | self.r2 = self.r2.to(device)
41 |
42 |
43 | def collate_tubes(tubes: List[Tube]) -> CollatedTube:
44 | a = torch.cat([tube.a for tube in tubes]).reshape(-1, 3)
45 | b = torch.cat([tube.b for tube in tubes]).reshape(-1, 3)
46 |
47 | r1 = torch.cat([tube.r1 for tube in tubes]).reshape(1, -1)
48 | r2 = torch.cat([tube.r2 for tube in tubes]).reshape(1, -1)
49 |
50 | return CollatedTube(a, b, r1, r2)
51 |
52 |
53 | def sample_tubes(tubes: List[Tube], spacing):
54 | pts, radius = [], []
55 |
56 | for i, tube in enumerate(tubes):
57 | v = tube.b - tube.a
58 | length = np.linalg.norm(v)
59 |
60 | direction = v / length
61 | num_points = np.ceil(length / spacing)
62 |
63 | if int(num_points) > 0.0:
64 | spaced_points = np.arange(
65 | 0, float(length), step=float(length / num_points)
66 | ).reshape(-1, 1)
67 | lin_radius = np.linspace(
68 | tube.r1, tube.r2, spaced_points.shape[0], dtype=float
69 | )
70 |
71 | pts.append(tube.a + direction * spaced_points)
72 | radius.append(lin_radius)
73 |
74 | return np.concatenate(pts, axis=0), np.concatenate(radius, axis=0)
75 |
--------------------------------------------------------------------------------
/smart_tree/scripts/split-data.py:
--------------------------------------------------------------------------------
1 | """ Script to split dataset into train and test sets
2 |
3 | usage:
4 |
5 | python smart_tree/scripts/split-data.py --read_directory=/speed-tree/speed-tree-outputs/processed_vines/ --json_save_path=/smart-tree/smart_tree/conf/vine-split.json --sample_type=random
6 |
7 |
8 | """
9 |
10 | import json
11 | import os
12 | import random
13 | from pathlib import Path
14 |
15 | import click
16 |
17 |
18 | def flatten_list(l):
19 | return [item for sublist in l for item in sublist]
20 |
21 |
22 | def random_sample(read_dir, json_save_path):
23 | items = [str(path.name) for path in Path(read_dir).rglob("*.npz")]
24 | random.shuffle(items)
25 |
26 | data = {}
27 |
28 | data["train"] = sorted(items[: int(0.8 * len(items))])
29 | data["test"] = sorted(items[int(0.8 * len(items)) : int(0.9 * len(items))])
30 | data["validation"] = sorted(items[int(0.9 * len(items)) :])
31 |
32 | with open(json_save_path, "w") as outfile:
33 | json.dump(data, outfile, indent=4, sort_keys=False)
34 |
35 |
36 | def strattified_sample(read_dir, json_save_path):
37 | dirs = os.listdir(read_dir)
38 |
39 | train_paths = []
40 | test_paths = []
41 | val_paths = []
42 |
43 | for directory in dirs:
44 | items = [
45 | str(path.resolve())
46 | for path in Path(f"{read_dir}/{directory}").rglob("*.npz")
47 | ]
48 | random.shuffle(items)
49 |
50 | train_paths.append(items[: int(0.8 * len(items))])
51 | test_paths.append(
52 | items[int(0.8 * len(items)) : int(0.8 * len(items) + int(0.1 * len(items)))]
53 | )
54 | val_paths.append(items[int(0.8 * len(items)) + int(0.1 * len(items)) :])
55 |
56 | data = {}
57 |
58 | data["train"] = sorted(flatten_list(train_paths))
59 | data["test"] = sorted(flatten_list(test_paths))
60 | data["validation"] = sorted(flatten_list(val_paths))
61 |
62 | with open(json_save_path, "w") as outfile:
63 | json.dump(data, outfile, indent=4, sort_keys=False)
64 |
65 |
66 | @click.command()
67 | @click.option(
68 | "--read_directory", type=click.Path(exists=True), prompt="read directory?"
69 | )
70 | @click.option("--json_save_path", prompt="json path?")
71 | @click.option("--sample_type", type=str, default=False, required=False)
72 | def main(read_directory, json_save_path, sample_type):
73 | if sample_type == "random":
74 | random_sample(read_directory, json_save_path)
75 |
76 | if sample_type == "strattified":
77 | strattified_sample(read_directory, json_save_path)
78 |
79 |
80 | if __name__ == "__main__":
81 | main()
82 |
--------------------------------------------------------------------------------
/smart_tree/util/misc.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import List, Union
3 |
4 | import cmapy
5 | import numpy as np
6 | import torch
7 |
8 |
9 | def flatten_list(l):
10 | return [item for sublist in l for item in sublist]
11 |
12 |
13 | def at_least_2d(tensors: Union[List[torch.tensor], torch.tensor], expand_dim=1):
14 | if type(tensors) is list:
15 |         return [at_least_2d(tensor, expand_dim) for tensor in tensors]
16 | else:
17 | if len(tensors.shape) == 1:
18 | return tensors.unsqueeze(expand_dim)
19 | else:
20 | return tensors
21 |
22 |
23 | def to_torch(numpy_arrays: List[np.array], device=torch.device("cpu")):
24 | return [torch.from_numpy(np_arr).float().to(device) for np_arr in numpy_arrays]
25 |
26 |
27 | def to_numpy(_torch: Union[List[torch.tensor], torch.tensor]):
28 | if type(_torch) is list:
29 | return [torch_arr.cpu().detach().numpy() for torch_arr in _torch]
30 | else:
31 | return _torch.cpu().detach().numpy()
32 |
33 |
34 | def concate_dict_of_tensors(tensors: dict, device=torch.device("cpu")):
35 | for key, values in tensors.items():
36 | tensors[key] = torch.concatenate(values).to(device)
37 | return tensors
38 |
39 |
40 | def unique_n_colours(num_colours, cmap="hsv"):
41 | return (
42 | np.asarray(
43 | [cmapy.color(cmap, i) for i in range(0, 255, math.ceil(255 / num_colours))]
44 | ).reshape(-1, 3)
45 | / 255
46 | )
47 |
48 |
49 | def unique_n_random_colours(num_colours):
50 | return np.asarray([np.random.rand(3) for i in range(num_colours)]).reshape(-1, 3)
51 |
52 |
53 | def points_to_edges(points):
54 | points = points.reshape(-1, 3)
55 | parents = torch.arange(points.shape[0] - 1)
56 | children = torch.arange(1, points.shape[0])
57 |
58 | return torch.column_stack((parents, children))
59 |
60 |
61 | def voxel_downsample(xyz, voxel_size):
62 | xyz_quantized = (
63 | xyz // voxel_size
64 | ) # torch.div(xyz + (voxel_size / 2), voxel_size, rounding_mode="floor")
65 |
66 | unique, idx, counts = torch.unique(
67 | xyz_quantized,
68 | dim=0,
69 | sorted=True,
70 | return_counts=True,
71 | return_inverse=True,
72 | )
73 |
74 | _, ind_sorted = torch.sort(idx, stable=True)
75 | cum_sum = counts.cumsum(0)
76 | cum_sum = torch.cat((torch.tensor([0], device=cum_sum.device), cum_sum[:-1]))
77 | first_indicies = ind_sorted[cum_sum[1:]]
78 |
79 | return first_indicies
80 |
81 |
82 | def merge_dictionaries(dict1, dict2):
83 | merged_dict = {}
84 |
85 | for key, value in dict1.items():
86 | merged_dict[key] = value
87 |
88 | i = 1
89 | for key, value in dict2.items():
90 | new_key = key
91 | while new_key in merged_dict:
92 | new_key = i
93 | i += 1
94 | merged_dict[new_key] = value
95 |
96 | return merged_dict
97 |
--------------------------------------------------------------------------------
/smart_tree/scripts/view_npz.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from pathlib import Path
3 | from typing import List, Tuple
4 |
5 |
6 | from smart_tree.data_types.cloud import Cloud
7 | from smart_tree.data_types.tree import TreeSkeleton
8 | from smart_tree.o3d_abstractions.visualizer import ViewerItem, o3d_viewer
9 | from smart_tree.util.file import load_data_npz
10 |
11 |
12 | def view_synthetic_data(data: List[Tuple[Cloud, TreeSkeleton]], line_width=1):
13 | geometries = []
14 | for i, item in enumerate(data):
15 | (cloud, skeleton), path = item
16 |
17 | tree_name = path.stem
18 | visible = i == 0
19 |
20 |         geometries += [
21 | ViewerItem(
22 | f"{tree_name}_cloud",
23 | cloud.to_o3d_cld(),
24 | is_visible=visible,
25 | ),
26 | ViewerItem(
27 | f"{tree_name}_labelled_cloud",
28 | cloud.to_o3d_seg_cld(),
29 | is_visible=visible,
30 | ),
31 | ViewerItem(
32 | f"{tree_name}_medial_vectors",
33 | cloud.to_o3d_medial_vectors(),
34 | is_visible=visible,
35 | ),
36 | # ViewerItem(
37 | # f"{tree_name}_skeleton",
38 | # skeleton.to_o3d_lineset(),
39 | # is_visible=visible,
40 | # ),
41 | # ViewerItem(
42 | # f"{tree_name}_skeleton_mesh",
43 | # skeleton.to_o3d_tubes(),
44 | # is_visible=visible,
45 | # ),
46 | ]
47 |
48 | o3d_viewer(geometries, line_width=line_width)
49 |
50 |
51 | def parse_args():
52 | parser = argparse.ArgumentParser(description="Visualizer Arguments")
53 |
54 | parser.add_argument(
55 | "file_path",
56 | help="File Path of tree.npz",
57 | default=None,
58 | type=Path,
59 | )
60 |
61 | parser.add_argument(
62 | "-lw",
63 | "--line_width",
64 | help="Width of visualizer lines",
65 | default=1,
66 | type=int,
67 | )
68 | return parser.parse_args()
69 |
70 |
71 | def paths_from_args(args, glob="*.npz"):
72 | if not args.file_path.exists():
73 | raise ValueError(f"File {args.file_path} does not exist")
74 |
75 | if args.file_path.is_file():
76 | print(f"Loading data from file: {args.file_path}")
77 | return [args.file_path]
78 |
79 | if args.file_path.is_dir():
80 | print(f"Loading data from directory: {args.file_path}")
81 |         files = list(args.file_path.glob(glob))
82 |         if not files:
83 | raise ValueError(f"No npz files found in {args.file_path}")
84 | return files
85 |
86 |
87 | def main():
88 | args = parse_args()
89 |
90 | data = [(load_data_npz(filename), filename) for filename in paths_from_args(args)]
91 | view_synthetic_data(data, args.line_width)
92 |
93 |
94 | if __name__ == "__main__":
95 | main()
96 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 |
132 | .vscode/
133 |
134 | outputs/
135 | data/
136 | multirun/
137 | wandb/
138 | *.out
139 | *.sl
140 | speed-tree-outputs/
141 | smart_tree/conf/tree-split-test.json
142 | smart_tree/conf/apple-trellis-split-test.json
143 | smart_tree/conf/apple-trellis.yaml
144 | smart_tree/conf/apple-trellis-split.json
145 | FRNN/
146 |
147 |
148 | smart_tree/conf
149 | test_data/
150 |
151 | !smart_tree/conf/pipeline.yaml
152 | !smart_tree/conf/training.yaml
153 | !smart_tree/conf/training-split.json
154 |
155 | *.ply
156 | *.obj
--------------------------------------------------------------------------------
/smart_tree/model/fp16.py:
--------------------------------------------------------------------------------
1 | import functools
2 | from collections import abc
3 | from inspect import getfullargspec
4 |
5 | import spconv.pytorch as spconv
6 | import torch
7 |
8 |
9 | def cast_tensor_type(inputs, src_type, dst_type):
10 | if isinstance(inputs, torch.Tensor):
11 | return inputs.to(dst_type) if inputs.dtype == src_type else inputs
12 | elif isinstance(inputs, spconv.SparseConvTensor):
13 | if inputs.features.dtype == src_type:
14 | features = inputs.features.to(dst_type)
15 | inputs = inputs.replace_feature(features)
16 | return inputs
17 | elif isinstance(inputs, abc.Mapping):
18 | return type(inputs)(
19 | {k: cast_tensor_type(v, src_type, dst_type) for k, v in inputs.items()}
20 | )
21 | elif isinstance(inputs, abc.Iterable):
22 | return type(inputs)(
23 | cast_tensor_type(item, src_type, dst_type) for item in inputs
24 | )
25 | else:
26 | return inputs
27 |
28 |
29 | def force_fp32(apply_to=None, out_fp16=False):
30 | def force_fp32_wrapper(old_func):
31 | @functools.wraps(old_func)
32 | def new_func(*args, **kwargs):
33 | if not isinstance(args[0], torch.nn.Module):
34 | raise TypeError(
35 | "@force_fp32 can only be used to decorate the "
36 | "method of nn.Module"
37 | )
38 | # get the arg spec of the decorated method
39 | args_info = getfullargspec(old_func)
40 | # get the argument names to be casted
41 | args_to_cast = args_info.args if apply_to is None else apply_to
42 | # convert the args that need to be processed
43 | new_args = []
44 | if args:
45 | arg_names = args_info.args[: len(args)]
46 | for i, arg_name in enumerate(arg_names):
47 | if arg_name in args_to_cast:
48 | new_args.append(
49 | cast_tensor_type(args[i], torch.half, torch.float)
50 | )
51 | else:
52 | new_args.append(args[i])
53 | # convert the kwargs that need to be processed
54 | new_kwargs = dict()
55 | if kwargs:
56 | for arg_name, arg_value in kwargs.items():
57 | if arg_name in args_to_cast:
58 | new_kwargs[arg_name] = cast_tensor_type(
59 | arg_value, torch.half, torch.float
60 | )
61 | else:
62 | new_kwargs[arg_name] = arg_value
63 | with torch.cuda.amp.autocast(enabled=False):
64 | output = old_func(*new_args, **new_kwargs)
65 | # cast the results back to fp32 if necessary
66 | if out_fp16:
67 | output = cast_tensor_type(output, torch.float, torch.half)
68 | return output
69 |
70 | return new_func
71 |
72 | return force_fp32_wrapper
73 |
--------------------------------------------------------------------------------
/smart_tree/model/sparse.py:
--------------------------------------------------------------------------------
1 | from itertools import repeat
2 | from typing import List, Tuple, Union
3 |
4 | import numpy as np
5 | import spconv.pytorch as spconv
6 | import torch
7 |
8 |
9 | def sparse_from_batch(features, coordinates, device):
10 | batch_size = features.shape[0]
11 |
12 | features = features.to(device)
13 | coordinates = coordinates.to(device)
14 |
15 | values, _ = torch.max(coordinates, 0) # BXYZ -> XYZ (Biggest Spatial Size)
16 |
17 | return spconv.SparseConvTensor(
18 | features, coordinates.int(), values[1:], batch_size=batch_size
19 | )
20 |
21 |
22 | def split_sparse_list(indices, features):
23 | cloud_ids = indices[:, 0]
24 |
25 | num_clouds = cloud_ids.max() + 1
26 | return [
27 | (indices[cloud_ids == i], features[cloud_ids == i]) for i in range(num_clouds)
28 | ]
29 |
30 |
31 | def split_sparse(sparse_tensor):
32 | cloud_ids = sparse_tensor.indices[:, 0]
33 | num_clouds = cloud_ids.max() + 1
34 | return [
35 | (sparse_tensor.indices[cloud_ids == i], sparse_tensor.features[cloud_ids == i])
36 | for i in range(num_clouds)
37 | ]
38 |
39 |
40 | def batch_collate(batch):
41 | """Custom Batch Collate Function for Sparse Data..."""
42 |
43 | batch_feats, batch_coords, batch_mask, fn = zip(*batch)
44 |
45 | for i, coords in enumerate(batch_coords):
46 | coords[:, 0] = torch.tensor([i], dtype=torch.float32)
47 |
48 | if isinstance(batch_feats[0], tuple):
49 | input_feats, target_feats = tuple(zip(*batch_feats))
50 |
51 | input_feats, target_feats, coords, mask = [
52 | torch.cat(x) for x in [input_feats, target_feats, batch_coords, batch_mask]
53 | ]
54 |
55 | return [(input_feats, target_feats), coords, mask, fn]
56 |
57 | feats, coords, mask = [
58 | torch.cat(x) for x in [batch_feats, batch_coords, batch_mask]
59 | ]
60 |
61 | return [feats, coords, mask, fn]
62 |
63 |
64 | def ravel_hash(x: np.ndarray) -> np.ndarray:
65 | assert x.ndim == 2, x.shape
66 |
67 | x = x - np.min(x, axis=0)
68 | x = x.astype(np.uint64, copy=False)
69 | xmax = np.max(x, axis=0).astype(np.uint64) + 1
70 |
71 | h = np.zeros(x.shape[0], dtype=np.uint64)
72 | for k in range(x.shape[1] - 1):
73 | h += x[:, k]
74 | h *= xmax[k + 1]
75 | h += x[:, -1]
76 | return h
77 |
78 |
79 | def sparse_quantize(
80 | coords,
81 | voxel_size: Union[float, Tuple[float, ...]] = 1,
82 | *,
83 | return_index: bool = False,
84 | return_inverse: bool = False,
85 | ) -> List[np.ndarray]:
86 | if isinstance(voxel_size, (float, int)):
87 | voxel_size = tuple(repeat(voxel_size, 3))
88 | assert isinstance(voxel_size, tuple) and len(voxel_size) == 3
89 |
90 | voxel_size = np.array(voxel_size)
91 | coords = np.floor(coords / voxel_size).astype(np.int32)
92 |
93 | _, indices, inverse_indices = np.unique(
94 | ravel_hash(coords), return_index=True, return_inverse=True
95 | )
96 | coords = coords[indices]
97 |
98 | outputs = [coords]
99 | if return_index:
100 | outputs += [indices]
101 | if return_inverse:
102 | outputs += [inverse_indices]
103 | return outputs[0] if len(outputs) == 1 else outputs
104 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 💡🧠🤔 Smart-Tree 🌳🌲🌴
2 |
3 | ## 📝 Description:
4 |
5 | This repository contains code from the paper "Smart-Tree: Neural Medial Axis Approximation of Point Clouds for 3D Tree Skeletonization".
6 | The code provided is a deep-learning-based skeletonization method for point clouds.
7 |
8 |
9 |
10 | | ![Input point cloud](images/botanic-pcd.png) | ![Mesh output](images/botanic-branch-mesh.png) | ![Skeleton output](images/botanic-skeleton.png) |
11 | | :---: | :---: | :---: |
12 | | Input point cloud. | Mesh output. | Skeleton output. |
13 |
22 | ## 💾 Data:
23 |
24 | Please follow instructions to download data from this link.
25 |
26 | ## 🔧 Installation:
27 |
28 | First, make sure you have Conda installed, as well as mamba.
29 | This will ensure the environment is created within a reasonable timeframe.
30 |
31 | To install smart-tree please use: `bash create-env.sh`
32 | Then activate the environment using: `conda activate smart-tree`
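Once the environment is active, you can sanity-check the install; this simply imports the package and prints the version defined in `smart_tree/__init__.py`:

```
python -c "import smart_tree; print(smart_tree.__version__)"  # should print 1.0.0
```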
33 |
34 |
35 | ## 📈 Training:
36 |
37 | To train the model, open `smart_tree/conf/training.yaml`.
38 |
39 | You will need to update (alternatively these can be overwritten with hydra):
40 |
41 | - `training.dataset.json_path` to the location where your `smart_tree/conf/training-split.json` is stored.
42 | - `training.dataset.directory` to the location where you downloaded the data (you can choose whether to train on data with or without foliage based on the directory you supply).
43 |
44 | You can also experiment with and adjust the hyper-parameter settings.
45 |
46 | The model can then be trained by running:
47 |
48 | `train-smart-tree`
49 |
50 | The best model weights and model will be stored in the generated outputs directory.
51 |
52 | ## ▶️ Inference / ☠️ Skeletonization:
53 |
54 | We supply two different models with weights:
55 | * `noble-elevator-58` contains branch/foliage segmentation.
56 | * `peach-forest-65` is only trained on points from the branching structure.
57 |
58 | If you wish to run smart-tree using your own weights, you will need to update the model paths in `smart_tree/conf/pipeline.yaml`.
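The default config points at the `noble-elevator-58` files. As a sketch, to use `peach-forest-65` instead you can either edit those paths or override them at launch with the `run-smart-tree` command described below (here `cloud.ply` is a placeholder):

```
run-smart-tree +path=cloud.ply \
  pipeline.model_inference.model_path=smart_tree/model/weights/peach-forest-65_model.pt \
  pipeline.model_inference.weights_path=smart_tree/model/weights/peach-forest-65_model_weights.pt
```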
59 |
60 | To run smart-tree use:
61 | `run-smart-tree +path=cloud_path`
62 | where `cloud_path` is the path of the point cloud you want to skeletonize.
63 | Skeletonization-specific parameters can be adjusted within the `smart_tree/conf/pipeline.yaml` config.
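`smart_tree/cli.py` also accepts a directory and will process every cloud inside it in turn (the path below is a placeholder):

```
run-smart-tree +directory=/path/to/clouds/
```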
64 |
65 | ## 📜 Citation:
66 | Please use the following BibTeX entry to cite our work:
67 |
68 | ```
69 | @inproceedings{dobbs2023smart,
70 | title={Smart-Tree: Neural Medial Axis Approximation of Point Clouds for 3D Tree Skeletonization},
71 | author={Dobbs, Harry and Batchelor, Oliver and Green, Richard and Atlas, James},
72 | booktitle={Iberian Conference on Pattern Recognition and Image Analysis},
73 | pages={351--362},
74 | year={2023},
75 | organization={Springer}
76 | }
77 | ```
78 | ## Star History
79 |
80 | [](https://star-history.com/#uc-vision/smart-tree&Date)
81 |
82 |
83 | ## 📥 Contact
84 |
85 | Should you have any questions, comments or suggestions please use the following contact details:
86 | harry.dobbs@pg.canterbury.ac.nz
87 |
--------------------------------------------------------------------------------
/smart_tree/model/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 |
7 | def compute_loss(
8 | preds,
9 | targets,
10 | mask=None,
11 | radius_loss_fn=None,
12 | direction_loss_fn=None,
13 | class_loss_fn=None,
14 | target_radius_log=True,
15 | vector_class=None,
16 | ):
17 | predicted_radius = preds["radius"] #
18 | predicted_direction = preds["direction"]
19 | predicted_class = preds["class_l"]
20 | target_class = targets[:, [-1]].long()
21 | target_direction = targets[:, 1:-1]
22 | target_radius = targets[:, [0]]
23 |
24 | if mask is not None:
25 | predicted_radius = predicted_radius[mask]
26 | predicted_direction = predicted_direction[mask]
27 | predicted_class = predicted_class[mask]
28 | target_radius = target_radius[mask]
29 | target_direction = target_direction[mask]
30 | target_class = target_class[mask]
31 |
32 | # Only compute vector loss on branch points...
33 | if vector_class is not None:
34 | vector_mask = target_class == vector_class
35 | vector_mask = vector_mask.view(-1)
36 | predicted_radius = predicted_radius[vector_mask]
37 | predicted_direction = predicted_direction[vector_mask]
38 | target_radius = target_radius[vector_mask]
39 | target_direction = target_direction[vector_mask]
40 |
41 | if target_radius_log:
42 | target_radius = torch.log(target_radius)
43 |
44 | losses = {}
45 |
46 | losses["radius"] = radius_loss_fn(predicted_radius.view(-1), target_radius.view(-1))
47 | losses["direction"] = direction_loss_fn(predicted_direction, target_direction)
48 | losses["class_l"] = class_loss_fn(predicted_class, target_class)
49 |
50 | return losses
51 |
52 |
53 | def L1Loss(outputs, targets):
54 | loss = nn.L1Loss()
55 | return loss(outputs, targets)
56 |
57 |
58 | def cosine_similarity_loss(outputs, targets):
59 | loss = nn.CosineSimilarity()
60 | return torch.mean(1 - loss(outputs, targets))
61 |
62 |
63 | def dice_loss(outputs, targets):
64 | # https://gist.github.com/jeremyjordan/9ea3032a32909f71dd2ab35fe3bacc08
65 | smooth = 1
66 | outputs = F.softmax(outputs, dim=1)
67 | targets = F.one_hot(targets)
68 |
69 | outputs = outputs.view(-1)
70 | targets = targets.view(-1)
71 |
72 | intersection = (outputs * targets).sum()
73 |
74 | return 1 - (
75 | (2.0 * intersection + smooth) / (outputs.sum() + targets.sum() + smooth)
76 | )
77 |
78 |
79 | def focal_loss(outputs, targets):
80 | # https://github.com/torrvision/focal_calibration/blob/main/Losses/focal_loss.py
81 | gamma = 2
82 | input = outputs
83 |
84 | if input.dim() > 2:
85 | input = input.view(input.size(0), input.size(1), -1) # N,C,H,W => N,C,H*W
86 | input = input.transpose(1, 2) # N,C,H*W => N,H*W,C
87 | input = input.contiguous().view(-1, input.size(2)) # N,H*W,C => N*H*W,C
88 | targets = targets.view(-1, 1)
89 |
90 | logpt = F.log_softmax(input, dim=1)
91 | logpt = logpt.gather(1, targets)
92 | logpt = logpt.view(-1)
93 | pt = logpt.exp()
94 | loss = -1 * (1 - pt) ** gamma * logpt
95 | # return loss.sum()
96 | return loss.mean()
97 |
98 |
99 | def nll_loss(outputs, targets):
100 |     # Balance class weights by inverse class frequency
101 |     weights = targets.shape[0] / (torch.bincount(targets))
102 |     return F.nll_loss(F.log_softmax(outputs, dim=1), targets, weight=weights)
103 |
--------------------------------------------------------------------------------
/smart_tree/model/model_inference.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import torch
4 | from tqdm import tqdm
5 |
6 | from smart_tree.data_types.cloud import Cloud
7 | from smart_tree.dataset.dataset import load_dataloader
8 | from smart_tree.model.sparse import sparse_from_batch
9 |
10 |
11 | def load_model(model_path, weights_path, device=torch.device("cuda:0")):
12 | model = torch.load(f"{model_path}", map_location=device)
13 | model.load_state_dict(torch.load(f"{weights_path}"))
14 | model.eval()
15 |
16 | return model
17 |
18 |
19 | """ Loads model and model weights, then returns the input, outputs and mask """
20 |
21 |
22 | class ModelInference:
23 | def __init__(
24 | self,
25 | model_path: Path,
26 | weights_path: Path,
27 | voxel_size: float,
28 | block_size: float,
29 | buffer_size: float,
30 | num_workers=8,
31 | batch_size=4,
32 | device=torch.device("cuda:0"),
33 | verbose=False,
34 | ):
35 | self.device = device
36 | self.verbose = verbose
37 | self.voxel_size = voxel_size
38 | self.block_size = block_size
39 | self.buffer_size = buffer_size
40 |
41 | self.num_workers = num_workers
42 | self.batch_size = batch_size
43 |
44 | self.model = load_model(model_path, weights_path, self.device)
45 |
46 | if self.verbose:
47 | print("Model Loaded Succesfully")
48 |
49 | def forward(self, cloud: Cloud, return_masked=True):
50 | inputs, masks = [], []
51 | radius, direction, class_l = [], [], []
52 |
53 | dataloader = load_dataloader(
54 | cloud,
55 | self.voxel_size,
56 | self.block_size,
57 | self.buffer_size,
58 | self.num_workers,
59 | self.batch_size,
60 | )
61 |
62 | for features, coordinates, mask, filename in tqdm(
63 | dataloader, desc="Inferring", leave=False
64 | ):
65 | sparse_input = sparse_from_batch(
66 | features[:, :3],
67 | coordinates,
68 | device=self.device,
69 | )
70 |
71 | preds = self.model.forward(sparse_input)
72 |
73 | radius.append(preds["radius"].detach().cpu())
74 | direction.append(preds["direction"].detach().cpu())
75 | class_l.append(preds["class_l"].detach().cpu())
76 |
77 | inputs.append(features.detach().cpu())
78 | masks.append(mask.detach().cpu())
79 |
80 | radius = torch.cat(radius)
81 | direction = torch.cat(direction)
82 | class_l = torch.cat(class_l)
83 |
84 | inputs = torch.cat(inputs)
85 | masks = torch.cat(masks)
86 |
87 | medial_vector = torch.exp(radius) * direction
88 | class_l = torch.argmax(class_l, dim=1, keepdim=True)
89 |
90 | lc = Cloud(
91 | xyz=inputs[:, :3],
92 | rgb=inputs[:, 3:6],
93 | medial_vector=medial_vector,
94 | class_l=class_l,
95 | )
96 |
97 | if return_masked:
98 | return lc.filter(masks)
99 |
100 | return lc
101 |
102 | @staticmethod
103 | def from_cfg(cfg):
104 | return ModelInference(
105 | model_path=cfg.model_path,
106 | weights_path=cfg.weights_path,
107 | voxel_size=cfg.voxel_size,
108 | block_size=cfg.block_size,
109 | buffer_size=cfg.buffer_size,
110 | num_workers=cfg.num_workers,
111 | batch_size=cfg.batch_size,
112 | )
113 |
--------------------------------------------------------------------------------
/smart_tree/skeleton/skeletonize.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import cugraph
4 | import cupy
5 | import torch
6 | from cugraph import sssp
7 | from tqdm import tqdm
8 |
9 | from ..data_types.cloud import Cloud
10 | from ..data_types.graph import Graph
11 | from ..data_types.tree import DisjointTreeSkeleton, TreeSkeleton
12 | from .filter import outlier_removal
13 | from .graph import decompose_cuda_graph, nn_graph
14 | from .path import sample_tree
15 | from .shortest_path import pred_graph, shortest_paths
16 |
17 |
18 | class Skeletonizer:
19 | def __init__(
20 | self,
21 | K: int,
22 | min_connection_length: float,
23 | minimum_graph_vertices: int,
24 | device: torch.device = torch.device("cuda:0"),
25 | ):
26 | self.K = K
27 | self.min_connection_length = min_connection_length
28 | self.minimum_graph_vertices = minimum_graph_vertices
29 | self.device = device
30 |
31 | def forward(self, cloud: Cloud) -> DisjointTreeSkeleton:
32 | cloud.to_device(self.device)
33 |
34 | mask = outlier_removal(cloud.medial_pts, cloud.radius.unsqueeze(1), nb_points=8)
35 | cloud = cloud.filter(mask)
36 |
37 | graph: Graph = nn_graph(
38 | cloud.medial_pts,
39 | cloud.radius.clamp(min=self.min_connection_length),
40 | K=self.K,
41 | )
42 |
43 | subgraphs: List[cugraph.Graph] = graph.connected_cugraph_components(
44 | minimum_vertices=self.minimum_graph_vertices
45 | )
46 |
47 | skeletons = []
48 | for subgraph_id, subgraph in enumerate(
49 | tqdm(subgraphs, desc="Processing Connected Components", leave=False)
50 | ):
51 | skeletons.append(
52 | self.process_subgraph(cloud, subgraph, skeleton_id=subgraph_id)
53 | )
54 |
55 | return DisjointTreeSkeleton(skeletons)
56 |
57 | def process_subgraph(self, cloud, subgraph, skeleton_id=0) -> TreeSkeleton:
58 | """Extract skeleton for connected component"""
59 |
60 | subgraph_vertice_idx = torch.tensor(
61 | cupy.unique(subgraph.edges().values),
62 | device=self.device,
63 | )
64 |
65 | subgraph_cloud = cloud.filter(subgraph_vertice_idx)
66 |
67 | edges, edge_weights = decompose_cuda_graph(
68 | subgraph,
69 | renumber_edges=True,
70 | device=self.device,
71 | )
72 |
73 | verts, preds, distance = shortest_paths(
74 | subgraph_cloud.root_idx,
75 | edges,
76 | edge_weights,
77 | renumber=False,
78 | )
79 |
80 | predecessor_graph = pred_graph(verts, preds, subgraph_cloud.medial_pts)
81 |
82 | distances = torch.as_tensor(
83 | sssp(predecessor_graph, source=subgraph_cloud.root_idx)["distance"],
84 | device=self.device,
85 | )
86 |
87 | branches = sample_tree(
88 | subgraph_cloud.medial_pts,
89 | subgraph_cloud.radius.unsqueeze(1),
90 | preds,
91 | distances,
92 | subgraph_cloud.xyz,
93 | )
94 |
95 | return TreeSkeleton(skeleton_id, branches)
96 |
97 | @staticmethod
98 | def from_cfg(cfg):
99 | return Skeletonizer(
100 | K=cfg.K,
101 | min_connection_length=cfg.min_connection_length,
102 | minimum_graph_vertices=cfg.minimum_graph_vertices,
103 |             # NOTE: a stray edge_non_linear kwarg was dropped here; Skeletonizer.__init__ does not accept it.
104 | )
105 |
106 |
107 | # No standalone entry point is defined in this module (`main` does not exist here);
108 | # use the smart-tree CLI (run-smart-tree) to invoke skeletonization.
109 |
--------------------------------------------------------------------------------
/smart_tree/dataset/augmentations.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Sequence
3 |
4 | import torch
5 | from beartype import beartype
6 |
7 | from smart_tree.data_types.cloud import Cloud
8 | from smart_tree.util.maths import euler_angles_to_rotation
9 |
10 |
11 | class Augmentation(ABC):
12 | @abstractmethod
13 | def __call__(self, cloud: Cloud) -> Cloud:
14 | pass
15 |
16 |
17 | class Scale(Augmentation):
18 | def __init__(self, min_scale=0.9, max_scale=1.1):
19 | self.min_scale = min_scale
20 | self.max_scale = max_scale
21 |
22 | def __call__(self, cloud: Cloud) -> Cloud:
23 | t = torch.rand(1, device=cloud.xyz.device) * (self.max_scale - self.min_scale)
24 | return cloud.scale(t + self.min_scale)
25 |
26 |
27 | class FixedRotate(Augmentation):
28 | def __init__(self, xyz):
29 | self.xyz = xyz
30 |
31 | def __call__(self, cloud: Cloud) -> Cloud:
32 | self.rot_mat = euler_angles_to_rotation(
33 | torch.tensor(self.xyz, device=cloud.xyz.device)
34 | ).float()
35 | return cloud.rotate(self.rot_mat)
36 |
37 |
38 | class CentreCloud(Augmentation):
39 | def __call__(self, cloud: Cloud) -> Cloud:
40 | centre, (x, y, z) = cloud.bbox
41 | return cloud.translate(-centre + torch.tensor([0, y, 0], device=centre.device))
42 |
43 |
44 | class VoxelDownsample(Augmentation):
45 | def __init__(self, voxel_size):
46 | self.voxel_size = voxel_size
47 |
48 | def __call__(self, cloud: Cloud) -> Cloud:
49 | return cloud.voxel_down_sample(self.voxel_size)
50 |
51 |
52 | class FixedTranslate(Augmentation):
53 | def __init__(self, xyz):
54 | self.xyz = torch.tensor(xyz)
55 |
56 | def __call__(self, cloud: Cloud) -> Cloud:
57 | return cloud.translate(self.xyz)
58 |
59 |
60 | class RandomCrop(Augmentation):
61 | def __init__(self, max_x, max_y, max_z):
62 | self.max_translation = torch.tensor([max_x, max_y, max_z])
63 |
64 | def __call__(self, cloud):
65 | offset = (
66 | torch.rand(3, device=cloud.xyz.device) - 0.5
67 | ) * self.max_translation.to(device=cloud.xyz.device)
68 |
69 | p = cloud.xyz + offset
70 | mask = torch.logical_and(p >= cloud.min_xyz, p <= cloud.max_xyz).all(dim=1)
71 |
72 | return cloud.filter(mask)
73 |
74 |
75 | class RandomCubicCrop(Augmentation):
76 | def __init__(self, size):
77 | self.size = size
78 |
79 | def __call__(self, cloud):
80 | random_pt = cloud.xyz[torch.randint(0, cloud.xyz.shape[0], (1,))]
81 | min_corner = random_pt - self.size / 2
82 | max_corner = random_pt + self.size / 2
83 |
84 | mask = torch.logical_and(
85 | cloud.xyz >= min_corner,
86 | cloud.xyz <= max_corner,
87 | ).all(dim=1)
88 |
89 | return cloud.filter(mask)
90 |
91 |
92 | class RandomDropout(Augmentation):
93 | def __init__(self, max_drop_out):
94 | self.max_drop_out = max_drop_out
95 |
96 | def __call__(self, cloud: Cloud) -> Cloud:
97 | num_indices = int(
98 | (1.0 - (self.max_drop_out * torch.rand(1, device=cloud.xyz.device)))
99 | * cloud.xyz.shape[0]
100 | )
101 |
102 | indices = torch.randint(
103 | high=cloud.xyz.shape[0], size=(num_indices, 1), device=cloud.xyz.device
104 | ).squeeze(1)
105 | return cloud.filter(indices)
106 |
107 |
108 | class AugmentationPipeline(Augmentation):
109 | @beartype
110 | def __init__(self, augmentations: Sequence[Augmentation]):
111 | self.augmentations = augmentations
112 |
113 | def __call__(self, cloud):
114 | for augmentation in self.augmentations:
115 | cloud = augmentation(cloud)
116 | return cloud
117 |
--------------------------------------------------------------------------------
/smart_tree/o3d_abstractions/camera.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import open3d as o3d
3 | import open3d.visualization.rendering as rendering
4 |
5 |
6 | def create_camera(width, height, fx=575, fy=575):
7 | camera_parameters = o3d.camera.PinholeCameraParameters()
8 | camera_parameters.extrinsic = np.eye(4)
9 | camera_parameters.intrinsic.set_intrinsics(
10 | width=width,
11 | height=height,
12 | fx=fx,
13 | fy=fy,
14 | cx=(width / 2.0 - 0.5),
15 | cy=(height / 2.0 - 0.5),
16 | )
17 | return camera_parameters
18 |
19 |
20 | def update_camera_position(
21 | camera, camera_position, camera_target=[0, 0, 0], up=np.asarray([0, 1, 0])
22 | ):
23 | camera_direction = (camera_target - camera_position) / np.linalg.norm(
24 | camera_target - camera_position
25 | )
26 | camera_right = np.cross(camera_direction, up)
27 | camera_right = camera_right / np.linalg.norm(camera_right)
28 | camera_up = np.cross(camera_direction, camera_right)
29 | position_matrix = np.eye(4)
30 | position_matrix[:3, 3] = -camera_position
31 | camera_look_at = np.eye(4)
32 | camera_look_at[:3, :3] = np.vstack((camera_right, camera_up, camera_direction))
33 | camera_look_at = np.matmul(camera_look_at, position_matrix)
34 | camera.extrinsic = camera_look_at
35 | return camera
36 |
37 |
38 | def o3d_headless_render(geoms, camera_position, camera_up):
39 | width, height = 1920, 1080
40 | camera = create_camera(width, height)
41 |
42 | # Setup a Offscreen Renderer
43 | render = rendering.OffscreenRenderer(width=width, height=height)
44 | render.scene.set_background([1.0, 1.0, 1.0, 1.0]) # RGBA
45 | render.setup_camera(camera.intrinsic, camera.extrinsic)
46 | render.scene.scene.set_sun_light(
47 | geoms[0].get_center() + np.asarray(camera_position), [1.0, 1.0, 1.0], 75000
48 | )
49 | render.scene.scene.enable_sun_light(True)
50 | render.scene.scene.enable_indirect_light(True)
51 | render.scene.scene.set_indirect_light_intensity(0.3)
52 | # render.scene.set_lighting(render.scene.LightingProfile.NO_SHADOWS, (0, 0, 0))
53 |
54 | mtl = o3d.visualization.rendering.MaterialRecord()
55 | mtl.shader = "defaultUnlit"
56 |
57 | for i, item in enumerate(geoms):
58 | render.scene.add_geometry(f"{i}", item, mtl)
59 |
60 | camera = update_camera_position(
61 | camera,
62 | geoms[0].get_center() + np.asarray(camera_position),
63 | camera_target=geoms[0].get_center(),
64 | up=np.asarray(camera_up),
65 | )
66 | render.setup_camera(camera.intrinsic, camera.extrinsic)
67 |
68 | return render.render_to_image()
69 |
70 |
71 | class Renderer:
72 | def __init__(self, width, height):
73 | self.camera = create_camera(width, height)
74 | self.render = rendering.OffscreenRenderer(width=width, height=height)
75 | self.render.scene.set_background([1.0, 1.0, 1.0, 1.0]) # RGBA
76 | self.render.setup_camera(self.camera.intrinsic, self.camera.extrinsic)
77 |
78 | self.render.scene.scene.enable_sun_light(True)
79 | self.render.scene.scene.enable_indirect_light(True)
80 | self.render.scene.scene.set_indirect_light_intensity(0.3)
81 | self.mtl = o3d.visualization.rendering.MaterialRecord()
82 | self.mtl.shader = "defaultUnlit"
83 |
84 | def capture(self, geoms, camera_position, camera_up):
85 | self.render.scene.scene.set_sun_light(
86 | geoms[0].get_center() + np.asarray(camera_position), [1.0, 1.0, 1.0], 75000
87 | )
88 | for i, item in enumerate(geoms):
89 | self.render.scene.add_geometry(f"{i}", item, self.mtl)
90 |
91 | camera = update_camera_position(
92 | self.camera,
93 | geoms[0].get_center() + np.asarray(camera_position),
94 | camera_target=geoms[0].get_center(),
95 | up=np.asarray(camera_up),
96 | )
97 | self.render.setup_camera(camera.intrinsic, camera.extrinsic)
98 | img = self.render.render_to_image()
99 | self.render.scene.clear_geometry()
100 |
101 | return img
102 |
--------------------------------------------------------------------------------
/smart_tree/conf/training.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | wandb:
4 | project: tree
5 | entity: harry1576
6 | mode: online
7 |
8 | fp16: True
9 | num_epoch: 1
10 | lr_decay: True
11 | lr: 0.1
12 | early_stop_epoch: 20
13 | early_stop: True
14 |
15 | batch_size: 8
16 | directory: /local/smart-tree/data/branches
17 | json_path: smart_tree/conf/training-split.json
18 | voxel_size: 0.01
19 |
20 | cmap:
21 | - [0.450, 0.325, 0.164] # Trunk
22 | - [0.541, 0.670, 0.164] # Foliage
23 |
24 |
25 | capture_output: 1
26 |
27 | input_features:
28 | - xyz
29 |
30 | target_features:
31 | - radius
32 | - direction
33 | - class_l
34 |
35 | train_dataset:
36 | _target_: smart_tree.dataset.dataset.TreeDataset
37 | mode: train
38 | voxel_size: ${voxel_size}
39 | directory: ${directory}
40 | json_path: ${json_path}
41 | input_features: ${input_features}
42 | target_features: ${target_features}
43 | augmentation:
44 | _target_: smart_tree.dataset.augmentations.AugmentationPipeline
45 | augmentations:
46 | - _target_: smart_tree.dataset.augmentations.RandomCubicCrop
47 | size: 4.0
48 |
49 | test_dataset:
50 | _target_: smart_tree.dataset.dataset.TreeDataset
51 | mode: test
52 | voxel_size: ${voxel_size}
53 | directory: ${directory}
54 | json_path: ${json_path}
55 | input_features: ${input_features}
56 | target_features: ${target_features}
57 | augmentation:
58 | _target_: smart_tree.dataset.augmentations.AugmentationPipeline
59 | augmentations:
60 | - _target_: smart_tree.dataset.augmentations.RandomCubicCrop
61 | size: 4.0
62 |
63 | validation_dataset:
64 | _target_: smart_tree.dataset.dataset.TreeDataset
65 | mode: validation
66 | voxel_size: ${voxel_size}
67 | directory: ${directory}
68 | json_path: ${json_path}
69 | input_features: ${input_features}
70 | target_features: ${target_features}
71 | cache: True
72 | augmentation:
73 | _target_: smart_tree.dataset.augmentations.AugmentationPipeline
74 | augmentations:
75 | - _target_: smart_tree.dataset.augmentations.RandomCubicCrop
76 | size: 4.0
77 |
78 |
79 | train_data_loader:
80 | _target_: torch.utils.data.DataLoader
81 | batch_size: ${batch_size}
82 | drop_last: False
83 | pin_memory: False
84 | num_workers: 0
85 | shuffle: False
86 | # sampler:
87 | # _target_: torch.utils.data.RandomSampler
88 | # replacement: True
89 | # num_samples: 8
90 | # data_source: ${train_dataset}
91 | collate_fn:
92 | _target_: smart_tree.model.sparse.batch_collate
93 | _partial_: True
94 | dataset: ${train_dataset}
95 |
96 | validation_data_loader:
97 | _target_: torch.utils.data.DataLoader
98 | batch_size: ${batch_size}
99 | drop_last: False
100 | pin_memory: False
101 | num_workers: 0
102 | shuffle: False
103 | collate_fn:
104 | _target_: smart_tree.model.sparse.batch_collate
105 | _partial_: True
106 | dataset: ${validation_dataset}
107 |
108 | test_data_loader:
109 | _target_: torch.utils.data.DataLoader
110 | batch_size: ${batch_size}
111 | drop_last: False
112 | pin_memory: False
113 | num_workers: 0
114 | shuffle: False
115 | collate_fn:
116 | _target_: smart_tree.model.sparse.batch_collate
117 | _partial_: True
118 | dataset: ${test_dataset}
119 |
120 |
121 | model:
122 | _target_: smart_tree.model.model.Smart_Tree
123 | input_channels: 3
124 | unet_planes: [8, 16, 32]
125 | radius_fc_planes: [8, 8, 4, 1]
126 | direction_fc_planes: [8, 8, 4, 3]
127 | class_fc_planes: [8, 8, 4, 2]
128 |
129 | optimizer:
130 | _target_: torch.optim.Adam
131 | lr: ${lr}
132 |
133 | scheduler:
134 | _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
135 | mode: "min"
136 |
137 | loss_fn:
138 | _target_: smart_tree.model.loss.compute_loss
139 | _partial_: True
140 | vector_class: 0
141 | target_radius_log: True
142 |
143 | radius_loss_fn:
144 |     _target_: smart_tree.model.loss.L1Loss
145 | _partial_: True
146 |
147 | direction_loss_fn:
148 | _target_: smart_tree.model.loss.cosine_similarity_loss
149 | _partial_: True
150 |
151 | class_loss_fn:
152 | _target_: smart_tree.model.loss.focal_loss
153 | _partial_: True
154 |
--------------------------------------------------------------------------------
/smart_tree/skeleton/graph.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import cugraph
4 | import frnn
5 | import numpy as np
6 | import torch
7 | from tqdm import tqdm
8 |
9 | from ..data_types.graph import Graph
10 |
11 |
12 | def knn(src, dest, K=50, r=1.0, grid=None):
13 | src_lengths = src.new_tensor([src.shape[0]], dtype=torch.long)
14 | dest_lengths = src.new_tensor([dest.shape[0]], dtype=torch.long)
15 | dists, idxs, grid, _ = frnn.frnn_grid_points(
16 | src.unsqueeze(0),
17 | dest.unsqueeze(0),
18 | src_lengths,
19 | dest_lengths,
20 | K,
21 | r,
22 | return_nn=False,
23 | return_sorted=True,
24 | )
25 |
26 | return idxs.squeeze(0), dists.sqrt().squeeze(0), grid
27 |
28 |
29 | def nn(src, dest, r=1.0, grid=None):
30 | idx, dist, grid = knn(src, dest, K=1, r=r, grid=grid)
31 | idx, dist = idx.squeeze(1), dist.squeeze(1)
32 |
33 | return idx, dist, grid
34 |
35 |
36 | def nn_graph(points: torch.Tensor, radii, K=40):
37 | idxs, dists, _ = knn(points, points, K=K, r=radii.max().item())
38 | idxs[dists > radii.unsqueeze(1)] = -1
39 | edges, edge_weights = make_edges(dists, idxs)
40 | return Graph(points, edges, edge_weights)
41 |
42 |
43 | def medial_nn_graph(points: torch.Tensor, radii, medial_dist, K=40):
44 | # edges weighted based on distance to medial axis
45 | idxs, dists, _ = knn(points, points, K=K, r=radii.max().item() / 4.0)
46 | dists_ = dists + medial_dist[idxs] # Add medial distance to distance graph...
47 | idxs[dists > radii.unsqueeze(1)] = -1
48 |
49 | return make_edges(dists_, idxs)
50 |
51 |
52 | def make_edges(dists, idxs):
53 | n = dists.shape[0]
54 | K = dists.shape[1]
55 |
56 | parent = torch.arange(n, device=dists.device).unsqueeze(1).expand(n, K)
57 | edges = torch.stack([parent, idxs], dim=2)
58 |
59 | valid = idxs.view(-1) > 0
60 | return edges.view(-1, 2)[valid], dists.view(-1)[valid]
61 |
62 |
63 | def nn_flat(points, K=50, r=1.0, device=torch.device("cuda")):
64 |     idxs, dists, _ = knn(points, points, K=K, r=r)
65 | return idxs.reshape(-1), dists.reshape(-1)
66 |
67 |
68 | def pcd_nn(points, radii, K=20):
69 | idxs, dists = nn_flat(np.asarray(points), K=K, r=float(radii.max() / 4.0))
70 |
71 | parent_vertex = np.repeat(np.arange(points.shape[0]), K)
72 | parent_vertex = parent_vertex.reshape(-1)
73 |
74 | edges = np.vstack((parent_vertex, idxs.reshape(-1))).T
75 |
76 | valid = idxs != -1
77 | return edges[valid], dists[valid]
78 |
79 |
80 | def decompose_cuda_graph(cuda_graph, renumber_edges=False, device=torch.device("cuda")):
81 | pdf = cugraph.to_pandas_edgelist(cuda_graph)
82 |
83 | edges = torch.stack((torch.tensor(pdf["src"]), torch.tensor(pdf["dst"])), dim=1)
84 | edge_weights = torch.tensor(pdf["weights"])
85 |
86 | edges, edge_weights = edges.long().to(device), edge_weights.to(device)
87 |
88 | if renumber_edges:
89 | edges = remap_edges(edges)
90 |
91 | return edges, edge_weights
92 |
93 |
94 | def remap_edges(edges):
95 | # Find unique node IDs and their corresponding indices
96 | unique_nodes, node_indices = torch.unique(edges, return_inverse=True)
97 |
98 | # Create a mapping from old node IDs to new node IDs
99 | mapping = torch.arange(unique_nodes.size(0), device=edges.device)
100 |
101 | # Map the old node IDs to new node IDs using the indices
102 | renumbered_edges = mapping[node_indices].reshape(edges.shape)
103 |
104 | return renumbered_edges
105 |
106 |
107 | def connected_components(edges, edge_weights, minimum_vertices=0, max_components=10):
108 | g = cuda_graph(edges, edge_weights)
109 |
110 | connected_components = cugraph.connected_components(g)
111 |
112 | num_labels = connected_components["labels"].to_pandas().value_counts()
113 | valid_labels = num_labels[num_labels > minimum_vertices].index
114 |
115 | graphs = []
116 |
117 | for label in tqdm(
118 |         valid_labels[:max_components], desc="Getting Connected Components", leave=False
119 | ):
120 | graphs.append(
121 | cugraph.subgraph(
122 | g, connected_components.query(f"labels == {label}")["vertex"]
123 | )
124 | )
125 |
126 | return graphs
127 |
--------------------------------------------------------------------------------
/smart_tree/skeleton/path.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | from tqdm import tqdm
4 |
5 | from ..data_types.branch import BranchSkeleton
6 | from .graph import nn
7 |
8 |
9 | def trace_route(preds, idx, termination_pts):
10 | path = []
11 |
12 | while idx >= 0 and idx not in termination_pts:
13 | path.append(idx)
14 | idx = preds[idx]
15 |
16 | return preds.new_tensor(path, dtype=torch.long).flip(0), idx
17 |
18 |
19 | def select_path_points(
20 | points: torch.tensor, path_verts: torch.tensor, radii: torch.tensor
21 | ):
22 | """
23 | Finds points nearest to a path (specified by points with radii).
24 | points: (N, 3) 3d points of point cloud
25 | path_verts: (M, 3) 3d points of path
26 | radii: (M, 1) radii of path
27 | returns: (X, 2) index tuples of (path, point) for X points falling within the path, ordered by path index
28 | """
29 |
30 | point_path, dists, _ = nn(
31 | points,
32 | path_verts,
33 | r=radii.max().item(),
34 | ) # nearest path idx for each point
35 |     valid = dists[point_path >= 0] < radii[point_path[point_path >= 0]].squeeze(
36 |         1
37 |     )  # keep points whose distance to their nearest path vertex is within that vertex's radius
38 |
39 | on_path = point_path.new_zeros(point_path.shape, dtype=torch.bool)
40 | on_path[point_path >= 0] = valid # points that are on the path.
41 |
42 | idx_point = on_path.nonzero().squeeze(1)
43 | idx_path = point_path[idx_point]
44 |
45 | order = torch.argsort(idx_path)
46 | return idx_point[order], idx_path[order]
47 |
48 |
49 | def sample_tree(
50 | medial_pts,
51 | medial_radii,
52 | preds,
53 | distances,
54 | all_points,
55 | root_idx=0,
56 | visualize=False,
57 | pbar=None,
58 | ):
59 | """
60 | Medial Points: NN estimated medial points
61 | Medial Radii: NN estimated radii of points
62 | Preds: Predecessor of each medial point (on path to root node)
63 | Distance: Distance from root node to medial points
64 |     All Points: the surface points the medial points were projected from
65 | """
66 |
67 | branch_id = 0
68 |
69 | branches = {}
70 |
71 | selection_mask = preds > 0
72 | distances[~selection_mask] = -1
73 |
74 |     termination_pts = torch.tensor([], device=torch.device("cuda"), dtype=torch.long)
75 | branch_ids = torch.full(
76 | (medial_pts.shape[0],),
77 | -1,
78 | device=torch.device("cuda"),
79 | dtype=int,
80 | )
81 |
82 | pbar = tqdm(
83 | total=distances.shape[0],
84 | leave=False,
85 | desc="Allocating Points",
86 | )
87 |
88 | while True:
89 | pbar.update(n=((distances < 0).sum().item() - pbar.n))
90 | pbar.refresh()
91 |
92 | farthest = distances.argmax().item()
93 |
94 | if distances[farthest] <= 0:
95 | break
96 |
97 | """ Traces the path of the futhrest point until it converges with allocated points """
98 | path_vertices_idx, termination_idx = trace_route(
99 | preds,
100 | farthest,
101 | termination_pts,
102 | )
103 |
104 | """ Gets the points around that path (and which path indexs they are close to) """
105 | idx_points, idx_path = select_path_points(
106 | medial_pts,
107 | medial_pts[path_vertices_idx],
108 | medial_radii[path_vertices_idx],
109 | )
110 |
111 | """ Mark this points as allocated and as termination points """
112 | distances[idx_points] = -1
113 | distances[path_vertices_idx] = -1
114 | termination_pts = torch.unique(
115 | torch.cat(
116 | (
117 | termination_pts,
118 | idx_points,
119 | path_vertices_idx,
120 | )
121 | )
122 | )
123 |
124 | """ If the path has at least two points, save it as a branch """
125 | if len(path_vertices_idx) < 2:
126 | continue
127 |
128 | branches[branch_id] = BranchSkeleton(
129 | branch_id,
130 | xyz=medial_pts[path_vertices_idx].cpu(),
131 | radii=medial_radii[path_vertices_idx].cpu(),
132 | parent_id=int(branch_ids[termination_idx].item()),
133 | )
134 |
135 | branch_ids[path_vertices_idx] = branch_id
136 | branch_ids[idx_points] = branch_id
137 |
138 | branch_id += 1
139 |
140 | return branches
141 |
--------------------------------------------------------------------------------
/smart_tree/pipeline.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import numpy as np
4 | import torch
5 |
6 | from .data_types.cloud import Cloud
7 | from .data_types.tree import DisjointTreeSkeleton
8 | from .o3d_abstractions.visualizer import o3d_viewer
9 | from .util.file import (load_cloud, save_o3d_cloud,
10 | save_o3d_lineset, save_o3d_mesh)
11 |
12 |
13 | class Pipeline:
14 | def __init__(
15 | self,
16 | preprocessing,
17 | model_inference,
18 | skeletonizer,
19 | repair_skeletons=False,
20 | smooth_skeletons=False,
21 | smooth_kernel_size=0,
22 | prune_skeletons=False,
23 | min_skeleton_radius=0.0,
24 | min_skeleton_length=1000,
25 | view_model_output=False,
26 | view_skeletons=False,
27 | save_outputs=False,
28 | save_path="/",
29 | branch_classes=[0],
30 | cmap=[[1, 0, 0], [0, 1, 0]],
31 | device=torch.device("cuda:0"),
32 | ):
33 | self.preprocessing = preprocessing
34 | self.model_inference = model_inference
35 | self.skeletonizer = skeletonizer
36 |
37 | self.repair_skeletons = repair_skeletons
38 | self.smooth_skeletons = smooth_skeletons
39 | self.smooth_kernel_size = smooth_kernel_size
40 | self.prune_skeletons = prune_skeletons
41 |
42 | self.min_skeleton_radius = min_skeleton_radius
43 | self.min_skeleton_length = min_skeleton_length
44 |
45 | self.view_model_output = view_model_output
46 | self.view_skeletons = view_skeletons
47 |
48 | self.save_outputs = save_outputs
49 | self.save_path = save_path
50 |
51 | self.branch_classes = branch_classes
52 | self.cmap = np.asarray(cmap)
53 | self.device = device
54 |
55 |     def process_cloud(self, path: Path = None, cloud: Cloud = None):
56 |         # Load the point cloud from disk if a path is supplied, otherwise use the cloud passed in
57 |         cloud: Cloud = load_cloud(path) if path is not None else cloud
58 |
59 | cloud = cloud.to_device(self.device)
60 | cloud = self.preprocessing(cloud)
61 |
62 | # Run point cloud through model to predict class, radius, direction
63 | lc: Cloud = self.model_inference.forward(cloud).to_device(self.device)
64 | if self.view_model_output:
65 | lc.view(self.cmap)
66 |
67 |         # Filter only the branch points for skeletonization
68 | branch_cloud: Cloud = lc.filter_by_class(self.branch_classes)
69 |
70 | # Run the branch cloud through skeletonization algorithm, then post process
71 | skeleton: DisjointTreeSkeleton = self.skeletonizer.forward(branch_cloud)
72 |
73 | self.post_process(skeleton)
74 |
75 | # View skeletonization results
76 | if self.view_skeletons:
77 | o3d_viewer(
78 | [
79 | skeleton.to_o3d_tube(),
80 | skeleton.to_o3d_lineset(),
81 | skeleton.to_o3d_tube(colour=False),
82 | cloud.to_o3d_cld(),
83 | ],
84 | line_width=5,
85 | )
86 |
87 | if self.save_outputs:
88 | print("Saving Outputs")
89 | sp = self.save_path
90 | save_o3d_lineset(f"{sp}/skeleton.ply", skeleton.to_o3d_lineset())
91 | save_o3d_mesh(f"{sp}/mesh.ply", skeleton.to_o3d_tube())
92 | save_o3d_cloud(f"{sp}/cloud.ply", lc.to_o3d_cld())
93 | save_o3d_cloud(f"{sp}/seg_cld.ply", lc.to_o3d_seg_cld(self.cmap))
94 |
95 | def post_process(self, skeleton: DisjointTreeSkeleton):
96 | if self.prune_skeletons:
97 | skeleton.prune(
98 | min_length=self.min_skeleton_length,
99 | min_radius=self.min_skeleton_radius,
100 | )
101 |
102 | if self.repair_skeletons:
103 | skeleton.repair()
104 |
105 | if self.smooth_skeletons:
106 | skeleton.smooth(self.smooth_kernel_size)
107 |
108 | @staticmethod
109 | def from_cfg(inferer, skeletonizer, cfg):
110 |         return Pipeline(
111 |             preprocessing=cfg.preprocessing,  # assumed to be an instantiated preprocessing callable
112 |             model_inference=inferer,
113 |             skeletonizer=skeletonizer,
114 | repair_skeletons=cfg.repair_skeletons,
115 | smooth_skeletons=cfg.smooth_skeletons,
116 | smooth_kernel_size=cfg.smooth_kernel_size,
117 | prune_skeletons=cfg.prune_skeletons,
118 | min_skeleton_radius=cfg.min_skeleton_radius,
119 | min_skeleton_length=cfg.min_skeleton_length,
120 | view_model_output=cfg.view_model_output,
121 | view_skeletons=cfg.view_skeletons,
122 | save_outputs=cfg.save_outputs,
123 | branch_classes=cfg.branch_classes,
124 | cmap=cfg.cmap,
125 | )
126 |
--------------------------------------------------------------------------------
/smart_tree/util/file.py:
--------------------------------------------------------------------------------
1 | import json
2 | from pathlib import Path
3 | from typing import Tuple
4 |
5 | import numpy as np
6 | import open3d as o3d
7 | import yaml
8 |
9 | from smart_tree.data_types.branch import BranchSkeleton
10 | from smart_tree.data_types.cloud import Cloud
11 | from smart_tree.data_types.tree import TreeSkeleton
12 |
13 |
14 | def unpackage_data(data: dict) -> Tuple[Cloud, TreeSkeleton]:
15 | tree_id = data["tree_id"]
16 | branch_id = data["branch_id"]
17 | branch_parent_id = data["branch_parent_id"]
18 | skeleton_xyz = data["skeleton_xyz"]
19 | skeleton_radii = data["skeleton_radii"]
20 | sizes = data["branch_num_elements"]
21 |
22 | cld = Cloud.from_numpy(
23 | xyz=data["xyz"],
24 | rgb=data["rgb"],
25 | vector=data["vector"],
26 | class_l=data["class_l"],
27 | )
28 |
29 |     # Rebuild the branch skeleton from the flattened per-branch arrays
30 |
31 | offsets = np.cumsum(np.append([0], sizes))
32 |
33 | branch_idx = [np.arange(size) + offset for size, offset in zip(sizes, offsets)]
34 | branches = {}
35 |
36 | for idx, _id, parent_id in zip(branch_idx, branch_id, branch_parent_id):
37 | branches[_id] = BranchSkeleton(
38 | _id, parent_id, skeleton_xyz[idx], skeleton_radii[idx]
39 | )
40 |
41 | return cld, TreeSkeleton(tree_id, branches)
42 |
43 |
44 | def package_data(skeleton: TreeSkeleton, pointcloud: Cloud) -> dict:
45 | data = {}
46 |
47 | data["tree_id"] = skeleton._id
48 |
49 | data["xyz"] = pointcloud.xyz
50 | data["rgb"] = pointcloud.rgb
51 | data["vector"] = pointcloud.vector
52 | data["class_l"] = pointcloud.class_l
53 |
54 | data["skeleton_xyz"] = np.concatenate(
55 | [branch.xyz for branch in skeleton.branches.values()]
56 | )
57 | data["skeleton_radii"] = np.concatenate(
58 | [branch.radii for branch in skeleton.branches.values()]
59 | )[..., np.newaxis]
60 | data["branch_id"] = np.asarray(
61 | [branch._id for branch in skeleton.branches.values()]
62 | )
63 | data["branch_parent_id"] = np.asarray(
64 | [branch.parent_id for branch in skeleton.branches.values()]
65 | )
66 | data["branch_num_elements"] = np.asarray(
67 | [len(branch) for branch in skeleton.branches.values()]
68 | )
69 |
70 | return data
71 |
72 |
73 | def save_skeleton(skeleton: TreeSkeleton, save_location):
74 | data = {}
75 | data["tree_id"] = skeleton._id
76 |
77 | data["skeleton_xyz"] = np.concatenate(
78 | [branch.xyz for branch in skeleton.branches.values()]
79 | )
80 | data["skeleton_radii"] = np.concatenate(
81 | [branch.radii for branch in skeleton.branches.values()]
82 | )[..., np.newaxis]
83 | data["branch_id"] = np.asarray(
84 | [branch._id for branch in skeleton.branches.values()]
85 | )
86 | data["branch_parent_id"] = np.asarray(
87 | [branch.parent_id for branch in skeleton.branches.values()]
88 | )
89 | data["branch_num_elements"] = np.asarray(
90 | [len(branch) for branch in skeleton.branches.values()]
91 | )
92 |
93 | np.savez(save_location, **data)
94 |
95 |
96 | def load_skeleton(path):
97 | data = np.load(path)
98 |
99 | # tree_id = data["tree_id"]
100 | branch_id = data["branch_id"]
101 | branch_parent_id = data["branch_parent_id"]
102 | skeleton_xyz = data["skeleton_xyz"]
103 | skeleton_radii = data["skeleton_radii"]
104 | sizes = data["branch_num_elements"]
105 |
106 | offsets = np.cumsum(np.append([0], sizes))
107 |
108 | branch_idx = [np.arange(size) + offset for size, offset in zip(sizes, offsets)]
109 | branches = {}
110 |
111 | for idx, _id, parent_id in zip(branch_idx, branch_id, branch_parent_id):
112 | branches[_id] = BranchSkeleton(
113 | _id, parent_id, skeleton_xyz[idx], skeleton_radii[idx]
114 | )
115 |
116 | return TreeSkeleton(0, branches)
117 |
118 |
119 | def save_data_npz(path: Path, skeleton: TreeSkeleton, pointcloud: Cloud):
120 | np.savez(path, **package_data(skeleton, pointcloud))
121 |
122 |
123 | def load_data_npz(path: Path) -> Tuple[Cloud, TreeSkeleton]:
124 | with np.load(path) as data:
125 | return unpackage_data(data)
126 |
127 |
128 | def load_json(json_path):
129 | return json.load(open(json_path))
130 |
131 |
132 | def save_o3d_cloud(filename: Path, cld: Cloud):
133 | o3d.io.write_point_cloud(str(filename), cld)
134 |
135 |
136 | def save_o3d_lineset(path: Path, ls):
137 | return o3d.io.write_line_set(path, ls)
138 |
139 |
140 | def save_o3d_mesh(path: Path, mesh):
141 | return o3d.io.write_triangle_mesh(path, mesh)
142 |
143 |
144 | def load_o3d_cloud(path: Path):
145 | return o3d.io.read_point_cloud(path)
146 |
147 |
148 | def load_o3d_lineset(path: Path):
149 | return o3d.io.read_line_set(path)
150 |
151 |
152 | def load_o3d_mesh(path: Path):
153 | return o3d.io.read_triangle_model(path)
154 |
155 |
156 | def load_cloud(path: Path):
157 | if path.suffix == ".npz":
158 | return Cloud.from_numpy(**np.load(path))
159 |
160 | data = o3d.io.read_point_cloud(str(path))
161 | xyz = np.asarray(data.points)
162 | rgb = (
163 | np.asarray(data.colors)
164 | if np.asarray(data.colors).shape[0] != 0
165 | else np.zeros_like(xyz)
166 | )
167 | return Cloud.from_numpy(xyz=xyz, rgb=rgb)
168 |
169 |
170 | def load_yaml(path: Path):
171 | with open(f"{path}") as f:
172 | config = yaml.load(f, Loader=yaml.FullLoader)
173 | return config
174 |
--------------------------------------------------------------------------------
/smart_tree/util/queries.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import List
4 |
5 | import numpy as np
6 | import torch
7 | from tqdm import tqdm
8 |
9 | from smart_tree.data_types.tube import CollatedTube, Tube, collate_tubes
10 |
11 | """
12 | For the following :
13 | N : number of pts
14 | M : number of tubes
15 | """
16 |
17 |
18 | def points_to_collated_tube_projections(
19 | pts: np.array, collated_tube: CollatedTube, eps=1e-12
20 | ): # N x 3, M x 2
21 | ab = collated_tube.b - collated_tube.a # M x 3
22 |
23 | ap = pts[:, np.newaxis] - collated_tube.a[np.newaxis, ...] # N x M x 3
24 |
25 | t = np.clip(
26 | np.einsum("nmd,md->nm", ap, ab) / (np.einsum("md,md->m", ab, ab) + eps),
27 | 0.0,
28 | 1.0,
29 | ) # N x M
30 | proj = collated_tube.a[np.newaxis, ...] + np.einsum(
31 | "nm,md->nmd", t, ab
32 | ) # N x M x 3
33 | return proj, t
34 |
35 |
36 | def projection_to_distance_matrix(projections, pts): # N x M x 3
37 | return np.sqrt(np.sum(np.square(projections - pts[:, np.newaxis, :]), 2)) # N x M
38 |
39 |
40 | def pts_to_nearest_tube(pts: np.array, tubes: List[Tube]):
41 |     """Vectors from each pt to the nearest tube"""
42 |     collated_tube = collate_tubes(tubes)
43 |     projections, t = points_to_collated_tube_projections(
44 |         pts, collated_tube
45 |     )  # N x M x 3
46 |     r = (1 - t) * collated_tube.r1 + t * collated_tube.r2  # N x M
47 |
48 |     distances = projection_to_distance_matrix(projections, pts)  # N x M
49 |     distances = distances - r
50 |     idx = np.argmin(distances, 1)  # N
51 |
52 |     return (
53 |         projections[np.arange(pts.shape[0]), idx] - pts,
54 |         idx,
55 |         r[np.arange(pts.shape[0]), idx],
56 |     )  # vector, idx, radius
57 |
58 |
59 | def pts_to_nearest_tube_keops(pts: np.array, tubes: List[Tube]):
60 |     """Vectors from each pt to the nearest tube (KeOps backend)"""
61 |
62 |     # Assumes a points_to_tube_distance_keops helper defined alongside the KeOps utilities.
63 |     distances, idx, r = points_to_tube_distance_keops(pts, tubes)
64 |
65 |     return distances.reshape(-1), idx, r
66 |
67 |
68 | def pairwise_pts_to_nearest_tube(pts: np.array, tubes: List[Tube]):
69 |     """Distance from each point to its corresponding tube (pairwise, so N == M)"""
70 |     collated_tube = collate_tubes(tubes)
71 |
72 |     ab = collated_tube.b - collated_tube.a  # M x 3
73 |     ap = pts - collated_tube.a  # M x 3
74 |
75 |     t = np.clip(
76 |         np.einsum("md,md->m", ap, ab) / (np.einsum("md,md->m", ab, ab) + 1e-12),
77 |         0.0,
78 |         1.0,
79 |     )  # M
80 |     proj = collated_tube.a + np.einsum("m,md->md", t, ab)  # M x 3
81 |     r = (1 - t) * collated_tube.r1 + t * collated_tube.r2
82 |
83 |     distances = np.sqrt(np.sum(np.square(proj - pts), 1))
84 |     distances = distances - r
85 |     return distances, r  # distance, radius
86 |
87 |
88 | # GPU
89 | def points_to_collated_tube_projections_gpu(
90 | pts: np.array, collated_tube: CollatedTube, device=torch.device("cuda")
91 | ):
92 | ab = collated_tube.b - collated_tube.a # M x 3
93 |
94 | ap = pts.unsqueeze(1) - collated_tube.a.unsqueeze(0) # N x M x 3
95 |
96 | t = (torch.einsum("nmd,md->nm", ap, ab) / torch.einsum("md,md->m", ab, ab)).clip(
97 | 0.0, 1.0
98 | ) # N x M
99 | proj = collated_tube.a.unsqueeze(0) + torch.einsum("nm,md->nmd", t, ab) # N x M x 3
100 | return proj, t
101 |
102 |
103 | def projection_to_distance_matrix_gpu(projections, pts): # N x M x 3
104 | return (projections - pts.unsqueeze(1)).square().sum(2).sqrt()
105 |
106 |
107 | def pts_to_nearest_tube_gpu(
108 | pts: torch.tensor, tubes: List[Tube], device=torch.device("cuda")
109 | ):
110 | """Vectors from pt to the nearest tube"""
111 |
112 | collated_tube_gpu = collate_tubes(tubes)
113 | collated_tube_gpu.to_gpu()
114 |
115 | pts = pts.float().to(device)
116 |
117 | projections, t = points_to_collated_tube_projections_gpu(
118 | pts, collated_tube_gpu, device=torch.device("cuda")
119 | ) # N x M x 3
120 | r = (1 - t) * collated_tube_gpu.r1 + t * collated_tube_gpu.r2
121 |
122 | distances = projection_to_distance_matrix_gpu(projections, pts) # N x M
123 |
124 | distances = torch.abs(distances - r)
125 | idx = torch.argmin(distances, 1) # N
126 |
127 | assert idx.shape[0] == pts.shape[0]
128 |
129 | return (
130 | projections[torch.arange(pts.shape[0]), idx] - pts,
131 | idx,
132 | r[torch.arange(pts.shape[0]), idx],
133 | )
134 |
135 |
136 | def projection_to_distance_matrix_keops(projections, pts): # N x M x 3
137 | return np.sqrt(np.sum(np.square(projections - pts[:, np.newaxis, :]), 2)) # N x M
138 |
139 |
140 | def skeleton_to_points(pcd, skeleton, chunk_size=4096, device="gpu"):
141 | distances = []
142 | radii = []
143 | vectors_ = []
144 |
145 | tubes = skeleton.to_tubes()
146 | pts_chunks = np.array_split(pcd.xyz, np.ceil(pcd.xyz.shape[0] / chunk_size))
147 |
148 | for pts in tqdm(pts_chunks, desc="Labelling Chunks", leave=False):
149 | if device == "gpu":
150 | vectors, idxs, radiuses = pts_to_nearest_tube_gpu(
151 | pts, tubes
152 | ) # vector to nearest skeleton...
153 | else:
154 | vectors, idxs, radiuses = pts_to_nearest_tube(
155 | pts, tubes
156 | ) # vector to nearest skeleton...
157 |
158 | distances.append(
159 | np.sqrt(np.einsum("ij,ij->i", vectors, vectors))
160 | ) # could do on gpu but meh
161 | radii.append([radius for radius in radiuses])
162 | vectors_.append([v for v in vectors])
163 |
164 | distances = np.concatenate(distances)
165 | radii = np.concatenate(radii)
166 | vectors_ = np.concatenate(vectors_)
167 |
168 | return distances, radii, vectors_
169 |
--------------------------------------------------------------------------------
/smart_tree/util/maths.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn.functional as F
6 |
7 |
8 | def np_normalized(a: np.array, axis=-1, order=2) -> np.array:
9 | """Normalizes a numpy array of points"""
10 | l2 = np.atleast_1d(np.linalg.norm(a, order, axis))
11 | l2[l2 == 0] = 1e-13
12 | return a / np.expand_dims(l2, axis), l2
13 |
14 |
15 | def torch_normalized(v):
16 | return F.normalize(v), v.pow(2).sum(1).sqrt().unsqueeze(1)
17 |
18 |
19 | def euler_angles_to_rotation(xyz: List) -> torch.tensor:
20 | x, y, z = xyz
21 |
22 | R_X = torch.tensor(
23 | [
24 | [1.0, 0.0, 0.0],
25 | [0.0, torch.cos(x), -torch.sin(x)],
26 | [0.0, torch.sin(x), torch.cos(x)],
27 | ]
28 | )
29 |
30 | R_Y = torch.tensor(
31 | [
32 | [torch.cos(y), 0.0, torch.sin(y)],
33 | [0.0, 1.0, 0.0],
34 | [-torch.sin(y), 0.0, torch.cos(y)],
35 | ]
36 | )
37 |
38 | R_Z = torch.tensor(
39 | [
40 | [torch.cos(z), -torch.sin(z), 0.0],
41 | [torch.sin(z), torch.cos(z), 0.0],
42 | [0.0, 0.0, 1.0],
43 | ]
44 | )
45 |
46 | return torch.mm(R_Z, torch.mm(R_Y, R_X))
47 |
48 |
49 | def rotation_matrix_from_vectors_np(vec1: np.array, vec2: np.array) -> np.array:
50 | """Find the rotation matrix that aligns vec1 to vec2
51 | :param vec1: A 3d "source" vector
52 | :param vec2: A 3d "destination" vector
53 | :return mat: A transform matrix (3x3) which when applied to vec1, aligns it with vec2.
54 | """
55 | a, b = (vec1 / np.linalg.norm(vec1)).reshape(3), (
56 | vec2 / np.linalg.norm(vec2)
57 | ).reshape(3)
58 | v = np.cross(a, b)
59 | c = np.dot(a, b)
60 | s = np.linalg.norm(v)
61 |
62 | kmat = np.array([[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]])
63 | rotation_matrix = np.eye(3) + kmat + kmat.dot(kmat) * ((1 - c) / (s**2))
64 | return rotation_matrix
65 |
66 |
67 | def rotation_matrix_from_vectors_torch(vec1, vec2):
68 | a, b = F.normalize(vec1, dim=0), F.normalize(vec2, dim=0)
69 | v = torch.cross(a, b)
70 | c = torch.dot(a, b)
71 | s = torch.linalg.norm(v)
72 |
73 | kmat = torch.tensor([[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]])
74 | rotation_matrix = torch.eye(3) + kmat + kmat.matmul(kmat) * ((1 - c) / (s**2))
75 |
76 | return rotation_matrix
77 |
78 |
79 | def make_transformation_matrix(rotation, translation):
80 | return torch.vstack(
81 | (torch.hstack((rotation, translation)), torch.tensor([0.0, 0.0, 0.0, 1.0]))
82 | )
83 |
84 |
85 | def bb_filter(
86 | points,
87 | min_x=-np.inf,
88 | max_x=np.inf,
89 | min_y=-np.inf,
90 | max_y=np.inf,
91 | min_z=-np.inf,
92 | max_z=np.inf,
93 | ):
94 | bound_x = np.logical_and(points[:, 0] >= min_x, points[:, 0] < max_x)
95 | bound_y = np.logical_and(points[:, 1] >= min_y, points[:, 1] < max_y)
96 | bound_z = np.logical_and(points[:, 2] >= min_z, points[:, 2] < max_z)
97 |
98 | bb_filter = np.logical_and(np.logical_and(bound_x, bound_y), bound_z)
99 |
100 | return bb_filter
101 |
102 |
103 | """
104 | def cube_filter(points, center, cube_size):
105 |
106 | min_x = center[0] - cube_size / 2
107 | max_x = center[0] + cube_size / 2
108 | min_y = center[1] - cube_size / 2
109 | max_y = center[1] + cube_size / 2
110 | min_z = center[2] - cube_size / 2
111 | max_z = center[2] + cube_size / 2
112 |
113 | return bb_filter(points, min_x, max_x, min_y, max_y, min_z, max_z)
114 | """
115 |
116 |
117 | def np_bb_filter(
118 | points,
119 | min_x=-np.inf,
120 | max_x=np.inf,
121 | min_y=-np.inf,
122 | max_y=np.inf,
123 | min_z=-np.inf,
124 | max_z=np.inf,
125 | ):
126 | bound_x = np.logical_and(points[:, 0] >= min_x, points[:, 0] < max_x)
127 | bound_y = np.logical_and(points[:, 1] >= min_y, points[:, 1] < max_y)
128 | bound_z = np.logical_and(points[:, 2] >= min_z, points[:, 2] < max_z)
129 |
130 | bb_filter = np.logical_and(np.logical_and(bound_x, bound_y), bound_z)
131 |
132 | return bb_filter
133 |
134 |
135 | def torch_bb_filter(points, min_x, max_x, min_y, max_y, min_z, max_z):
136 | bound_x = torch.logical_and(points[:, 0] >= min_x, points[:, 0] < max_x)
137 | bound_y = torch.logical_and(points[:, 1] >= min_y, points[:, 1] < max_y)
138 | bound_z = torch.logical_and(points[:, 2] >= min_z, points[:, 2] < max_z)
139 |
140 | bb_filter = torch.logical_and(torch.logical_and(bound_x, bound_y), bound_z)
141 |
142 | return bb_filter
143 |
144 |
145 | def cube_filter(points, center, cube_size):
146 | min = center - (cube_size / 2)
147 | max = center + (cube_size / 2)
148 |
149 |     if isinstance(center, np.ndarray):
150 | return np_bb_filter(points, min[0], max[0], min[1], max[1], min[2], max[2])
151 |
152 | max = max.to(points.device)
153 | min = min.to(points.device)
154 |
155 | return torch_bb_filter(points, min[0], max[0], min[1], max[1], min[2], max[2])
156 |
157 |
158 | def vertex_dirs(points):
159 | d = points[1:] - points[:-1]
160 |     d = d / np.linalg.norm(d, axis=1, keepdims=True)  # normalise each segment direction
161 |
162 | smooth = (d[1:] + d[:-1]) * 0.5
163 | dirs = np.concatenate([np.array(d[0:1]), smooth, np.array(d[-2:-1])])
164 |
165 | return dirs / np.linalg.norm(dirs, axis=1, keepdims=True)
166 |
167 |
168 | def random_unit(dtype=np.float32):
169 | x = np.random.randn(3).astype(dtype)
170 | return x / np.linalg.norm(x)
171 |
172 |
173 | def make_tangent(d, n):
174 | t = np.cross(d, n)
175 | t /= np.linalg.norm(t, axis=-1, keepdims=True)
176 | return np.cross(t, d)
177 |
178 |
179 | def gen_tangents(dirs, t):
180 | tangents = []
181 |
182 | for dir in dirs:
183 | t = make_tangent(dir, t)
184 | tangents.append(t)
185 |
186 | return np.stack(tangents)
187 |
--------------------------------------------------------------------------------
/smart_tree/o3d_abstractions/geometries.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import open3d as o3d
3 | import torch
4 |
5 | # from o3d_abstractions.view import o3d_view_geometries
6 | from smart_tree.util.maths import (gen_tangents, random_unit,
7 | vertex_dirs)
8 |
9 |
10 | def o3d_mesh(verts, tris):
11 | return o3d.geometry.TriangleMesh(
12 | o3d.utility.Vector3dVector(verts), o3d.utility.Vector3iVector(tris)
13 | ).compute_triangle_normals()
14 |
15 |
16 | def o3d_merge_meshes(meshes, colourize=False):
17 | if colourize:
18 | for m in meshes:
19 | m.paint_uniform_color(np.random.rand(3))
20 |
21 | mesh = meshes[0]
22 | for m in meshes[1:]:
23 | mesh += m
24 | return mesh
25 |
26 |
27 | def o3d_merge_linesets(line_sets, colour=[0, 0, 0]):
28 | sizes = [np.asarray(ls.points).shape[0] for ls in line_sets]
29 | offsets = np.cumsum([0] + sizes)
30 |
31 | points = np.concatenate([ls.points for ls in line_sets])
32 | idxs = np.concatenate([ls.lines + offset for ls, offset in zip(line_sets, offsets)])
33 |
34 | return o3d.geometry.LineSet(
35 | o3d.utility.Vector3dVector(points), o3d.utility.Vector2iVector(idxs)
36 | ).paint_uniform_color(colour)
37 |
38 |
39 | def points_to_edge_idx(points):
40 | idx = torch.arange(points.shape[0] - 1)
41 | return torch.column_stack((idx, idx + 1))
42 |
43 |
44 | def o3d_sphere(xyz, radius, colour=(1, 0, 0)):
45 | return (
46 | o3d.geometry.TriangleMesh.create_sphere(radius)
47 | .translate(np.asarray(xyz))
48 | .paint_uniform_color(colour)
49 | )
50 |
51 |
52 | def o3d_spheres(xyzs, radii, colour=None, colours=None):
53 | spheres = [o3d_sphere(xyz, r) for xyz, r in zip(xyzs, radii)]
54 | return paint_o3d_geoms(spheres, colour, colours)
55 |
56 |
57 | def o3d_line_set(vertices, edges, colour=None):
58 | line_set = o3d.geometry.LineSet(
59 | o3d.utility.Vector3dVector(vertices), o3d.utility.Vector2iVector(edges)
60 | )
61 | if colour is not None:
62 | return line_set.paint_uniform_color(colour)
63 | return line_set
64 |
65 |
66 | def o3d_line_sets(vertices, edges):
67 | return [o3d_line_set(v, e) for v, e in zip(vertices, edges)]
68 |
69 |
70 | def o3d_path(vertices, colour=None):
71 | edge_idx = points_to_edge_idx(vertices)
72 | if colour is not None:
73 | return o3d_line_set(vertices, edge_idx).paint_uniform_color(colour)
74 | return o3d_line_set(vertices, edge_idx)
75 |
76 |
77 | def o3d_paths(vertices):
78 | return [o3d_path(v) for v in vertices]
79 |
80 |
81 | def o3d_merge_clouds(points_clds):
82 | points = np.concatenate([np.asarray(pcd.points) for pcd in points_clds])
83 | colors = np.concatenate([np.asarray(pcd.colors) for pcd in points_clds])
84 |
85 | return o3d_cloud(points, colours=colors)
86 |
87 |
88 | def o3d_cloud(points, colour=None, colours=None, normals=None):
89 | cloud = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(points))
90 | if normals is not None:
91 | cloud.normals = o3d.utility.Vector3dVector(normals)
92 | if colour is not None:
93 | return cloud.paint_uniform_color(colour)
94 | elif colours is not None:
95 | cloud.colors = o3d.utility.Vector3dVector(colours)
96 | return cloud
97 |
98 | return cloud.paint_uniform_color([1, 0, 0])
99 |
100 |
101 | def class_label_o3d_cloud(cloud, class_labels, cmap=[]):
102 | cloud.colors = o3d.utility.Vector3dVector(np.asarray(cmap)[class_labels])
103 | return cloud
104 |
105 |
106 | def o3d_clouds(batch_points, colour=None, colours=None, p_colours=None):
107 |     # Colour handling:
108 |     # colour    -> paint all clouds the same colour
109 |     # colours   -> one colour per point cloud
110 |     # p_colours -> one colour per point in each cloud
111 | if colour is not None:
112 |         return [o3d_cloud(p, colour) for p in batch_points]
113 |
114 | if colours is not None:
115 | return [o3d_cloud(p, c) for p, c in zip(batch_points, colours)]
116 |
117 | if p_colours is not None:
118 | return [o3d_cloud(p, colours=c) for p, c in zip(batch_points, p_colours)]
119 |
120 |     return [o3d_cloud(p) for p in batch_points]
121 |
122 |
123 | def o3d_voxel_grid(
124 | width: float,
125 | depth: float,
126 | height: float,
127 | voxel_size: float,
128 | origin=np.asarray([0, 0, 0]),
129 | colour=np.asarray([1, 1, 0]),
130 | ):
131 | return o3d.geometry.VoxelGrid.create_dense(
132 | origin, colour, voxel_size, width, depth, height
133 | )
134 |
135 |
136 | def o3d_cylinder(radius, length, colour=(1, 0, 0)):
137 | return o3d.geometry.TriangleMesh.create_cylinder(
138 | radius, length
139 | ).paint_uniform_color(colour)
140 |
141 |
142 | def o3d_cylinders(radii, length, colour=None, colours=None):
143 |     cylinders = [o3d_cylinder(r, l) for r, l in zip(radii, length)]  # colours applied below
144 | return paint_o3d_geoms(cylinders, colour, colours)
145 |
146 |
147 | def paint_o3d_geoms(geometries, colour=None, colours=None):
148 | if colours is not None:
149 | return [
150 | geom.paint_uniform_color(colours[i]) for i, geom in enumerate(geometries)
151 | ]
152 | elif colour is not None:
153 | return [geom.paint_uniform_color(colour) for geom in geometries]
154 | return geometries
155 |
156 |
157 | def unit_circle(n):
158 | a = np.linspace(0, 2 * np.pi, n + 1)[:-1]
159 | return np.stack([np.sin(a), np.cos(a)], axis=1)
160 |
161 |
162 | def cylinder_triangles(m, n):
163 | tri1 = np.array([0, 1, 2])
164 | tri2 = np.array([2, 3, 0])
165 |
166 | v0 = np.arange(m)
167 | v1 = (v0 + 1) % m
168 | v2 = v1 + m
169 | v3 = v0 + m
170 |
171 | edges = np.stack([v0, v1, v2, v3], axis=1)
172 |
173 | segments = np.arange(n - 1) * m
174 | edges = edges.reshape(1, *edges.shape) + segments.reshape(n - 1, 1, 1)
175 |
176 | edges = edges.reshape(-1, 4)
177 | return np.concatenate([edges[:, tri1], edges[:, tri2]])
178 |
179 |
180 | def tube_vertices(points, radii, n=10):
181 | circle = unit_circle(n).astype(np.float32)
182 |
183 | dirs = vertex_dirs(points)
184 | t = gen_tangents(dirs, random_unit())
185 |
186 | b = np.stack([t, np.cross(t, dirs)], axis=1)
187 | b = b * radii.reshape(-1, 1, 1)
188 |
189 | return np.einsum("bdx,md->bmx", b, circle) + points.reshape(points.shape[0], 1, 3)
190 |
191 |
192 | def o3d_lines_between_clouds(cld1, cld2):
193 | pts1 = np.asarray(cld1.points)
194 | pts2 = np.asarray(cld2.points)
195 |
196 | interweaved = np.hstack((pts1, pts2)).reshape(-1, 3)
197 | return o3d_line_set(
198 | interweaved, np.arange(0, min(pts1.shape[0], pts2.shape[0]) * 2).reshape(-1, 2)
199 | )
200 |
201 |
202 | def o3d_tube_mesh(points, radii, colour=(1, 0, 0), n=10):
203 | points = tube_vertices(points, radii, n)
204 |
205 | n, m, _ = points.shape
206 | indexes = cylinder_triangles(m, n)
207 |
208 | mesh = o3d_mesh(points.reshape(-1, 3), indexes)
209 | mesh.compute_vertex_normals()
210 |
211 | return mesh.paint_uniform_color(colour)
212 |
213 |
214 | def sample_o3d_lineset(lineset, sample_rate):
215 | edges = np.asarray(lineset.lines)
216 | xyz = np.asarray(lineset.points)
217 |
218 | pts, radius = [], []
219 |
220 | for i, edge in enumerate(edges):
221 | start = xyz[edge[0]]
222 | end = xyz[edge[1]]
223 |
224 | v = end - start
225 | length = np.linalg.norm(v)
226 | direction = v / length
227 | num_points = np.ceil(length / sample_rate)
228 |
229 | if int(num_points) > 0.0:
230 | spaced_points = np.arange(
231 | 0, float(length), step=float(length / num_points)
232 | ).reshape(-1, 1)
233 | pts.append(start + direction * spaced_points)
234 |
235 | return np.concatenate(pts, axis=0)
236 |
--------------------------------------------------------------------------------
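A minimal usage sketch for the tube-meshing helpers above (the points and radii are made-up values, not data from the repository): o3d_tube_mesh sweeps an n-vertex ring along the polyline via tube_vertices and stitches neighbouring rings with cylinder_triangles.

import numpy as np

# Hypothetical example: mesh a short, tapering segment as a tube.
points = np.array([[0.0, 0.0, 0.0], [0.0, 0.5, 0.0], [0.0, 1.0, 0.0]], dtype=np.float32)
radii = np.array([0.05, 0.04, 0.03], dtype=np.float32)

mesh = o3d_tube_mesh(points, radii, colour=(0.4, 0.25, 0.1), n=12)
# o3d.visualization.draw_geometries([mesh])  # uncomment for an interactive view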
/smart_tree/model/train.py:
--------------------------------------------------------------------------------
1 | import contextlib
2 | import logging
3 | import math
4 | import os
5 | from pathlib import Path
6 | from typing import List
7 |
8 | import hydra
9 | import numpy as np
10 | import torch
11 | from hydra.core.hydra_config import HydraConfig
12 | from hydra.utils import instantiate
13 | from omegaconf import DictConfig, OmegaConf
14 | from tqdm import tqdm
15 |
16 | import wandb
17 | from smart_tree.data_types.cloud import Cloud
18 | from smart_tree.o3d_abstractions.camera import Renderer
19 |
20 | from .helper import get_batch, model_output_to_labelled_clds
21 | from .tracker import Tracker
22 |
23 |
24 | def train_epoch(
25 | data_loader,
26 | model,
27 | optimizer,
28 | loss_fn,
29 | fp16=False,
30 | scaler=None,
31 | device=torch.device("cuda"),
32 | ):
33 | tracker = Tracker()
34 |
35 | for sp_input, targets, mask, fn in tqdm(
36 | get_batch(data_loader, device, fp16),
37 | desc="Batch",
38 | leave=False,
39 |         total=len(data_loader),
40 | ):
41 | preds = model.forward(sp_input)
42 |
43 | loss = loss_fn(preds, targets, mask)
44 |
45 | if fp16:
46 | assert sum(loss.values()).dtype is torch.float32
47 | scaler.scale(sum(loss.values())).backward()
48 | scaler.step(optimizer)
49 | scaler.update()
50 | scale = scaler.get_scale()
51 | else:
52 | (sum(loss.values())).backward()
53 | optimizer.step()
54 |
55 | optimizer.zero_grad()
56 | tracker.update(loss)
57 |
58 | return tracker, scaler
59 |
60 |
61 | @torch.no_grad()
62 | def eval_epoch(
63 | data_loader,
64 | model,
65 | loss_fn,
66 | fp16=False,
67 | device=torch.device("cuda"),
68 | ):
69 | tracker = Tracker()
70 | model.eval()
71 |
72 | for sp_input, targets, mask, fn in tqdm(
73 | get_batch(data_loader, device, fp16),
74 | desc="Evaluating",
75 | leave=False,
76 |         total=len(data_loader),
77 | ):
78 | preds = model.forward(sp_input)
79 | loss = loss_fn(preds, targets, mask)
80 | tracker.update(loss)
81 | model.train()
82 |
83 | return tracker
84 |
85 |
86 | @torch.no_grad()
87 | def capture_epoch(
88 | data_loader,
89 | model,
90 | cmap,
91 | fp16=False,
92 | device=torch.device("cuda"),
93 | ):
94 | model.eval()
95 | captures = []
96 |
97 | for sp_input, targets, mask, fn in tqdm(
98 | get_batch(data_loader, device, fp16),
99 | desc="Capturing Outputs",
100 | leave=False,
101 |         total=len(data_loader),
102 | ):
103 | model_output = model.forward(sp_input)
104 |
105 |         captures.extend(model_output_to_labelled_clds(
106 |             sp_input,
107 |             model_output,
108 |             cmap,
109 |             fn,
110 |         ))
111 |
112 |     model.train()
113 |     return captures
114 |
115 |
116 | @torch.no_grad()
117 | def capture_clouds(
118 | data_loader,
119 | model,
120 | cmap,
121 | fp16=False,
122 | device=torch.device("cuda"),
123 | ) -> List[Cloud]:
124 | model.eval()
125 | clouds = []
126 |
127 | for sp_input, targets, mask, filenames in tqdm(
128 | get_batch(data_loader, device, fp16),
129 | desc="Capturing Outputs",
130 | leave=False,
131 |         total=len(data_loader),
132 | ):
133 | model_output = model.forward(sp_input)
134 | clouds.extend(
135 | model_output_to_labelled_clds(
136 | sp_input,
137 | model_output,
138 | cmap,
139 | filenames,
140 | )
141 | )
142 |
143 | model.train()
144 | return clouds
145 |
146 |
147 | def capture_and_log(loader, model, epoch, wandb_run, cfg):
148 | clouds = capture_clouds(
149 | loader,
150 | model,
151 | cfg.cmap,
152 | fp16=cfg.fp16,
153 | )
154 |
155 | for cloud in tqdm(clouds, desc="Uploading Clouds", leave=False):
156 | seg_cloud = cloud.to_o3d_seg_cld(np.asarray(cfg.cmap))
157 | xyz_rgb = np.concatenate(
158 | (np.asarray(seg_cloud.points), np.asarray(seg_cloud.colors) * 255), -1
159 | )
160 | wandb_run.log(
161 | {f"{Path(cloud.filename).stem}": wandb.Object3D(xyz_rgb)},
162 | step=epoch,
163 | )
164 |
165 |
166 | @hydra.main(
167 | version_base=None,
168 | config_path="../conf",
169 | config_name="training.yaml",
170 | )
171 | def main(cfg: DictConfig):
172 | torch.manual_seed(42)
173 | torch.cuda.manual_seed_all(42)
174 | log = logging.getLogger(__name__)
175 |
176 | wandb.init(
177 | project=cfg.wandb.project,
178 | entity=cfg.wandb.entity,
179 | mode=cfg.wandb.mode,
180 |         config=OmegaConf.to_container(cfg, resolve=True),
181 | )
182 | run_dir = HydraConfig.get().runtime.output_dir
183 | run_name = wandb.run.name
184 | log.info(f"Directory : {run_dir}")
185 | log.info(f"Machine: {os.uname()[1]}")
186 |
187 | renderer = Renderer(960, 540)
188 |
189 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
190 |
191 | train_loader = instantiate(cfg.train_data_loader)
192 | val_loader = instantiate(cfg.validation_data_loader)
193 | test_loader = instantiate(cfg.test_data_loader)
194 |
195 | log.info(f"Train Dataset Size: {len(train_loader.dataset)}")
196 | log.info(f"Validation Dataset Size: {len(val_loader.dataset)}")
197 | log.info(f"Test Dataset Size: {len(test_loader.dataset)}")
198 |
199 | # Model
200 | model = instantiate(cfg.model).to(device).train()
201 | torch.save(model, f"{run_dir}/{run_name}_model.pt")
202 |
203 | # Optimizer / Scheduler
204 | optimizer = instantiate(cfg.optimizer, params=model.parameters())
205 | scheduler = instantiate(cfg.scheduler, optimizer=optimizer)
206 | loss_fn = instantiate(cfg.loss_fn)
207 |
208 | # FP-16
209 | amp_ctx = torch.cuda.amp.autocast() if cfg.fp16 else contextlib.nullcontext()
210 | scaler = torch.cuda.amp.grad_scaler.GradScaler()
211 |
212 | epochs_no_improve = 0
213 | best_val_loss = torch.inf
214 |
215 | # Epochs
216 | for epoch in tqdm(range(0, cfg.num_epoch), leave=True, desc="Epoch"):
217 | with amp_ctx:
218 | training_tracker, scaler = train_epoch(
219 | train_loader,
220 | model,
221 | optimizer,
222 | loss_fn,
223 | scaler=scaler,
224 | fp16=cfg.fp16,
225 | )
226 |
227 | val_tracker = eval_epoch(
228 | val_loader,
229 | model,
230 | loss_fn,
231 | fp16=cfg.fp16,
232 | )
233 |
234 | test_tracker = eval_epoch(
235 | test_loader,
236 | model,
237 | loss_fn,
238 | fp16=cfg.fp16,
239 | )
240 |
241 | if (epoch + 1) % cfg.capture_output == 0:
242 | capture_and_log(test_loader, model, epoch, wandb.run, cfg)
243 | capture_and_log(val_loader, model, epoch, wandb.run, cfg)
244 |
245 | scheduler.step(val_tracker.total_loss) if cfg.lr_decay else None
246 |
247 | # Save Best Model
248 | if val_tracker.total_loss < best_val_loss:
249 | epochs_no_improve = 0
250 | best_val_loss = val_tracker.total_loss
251 |             wandb.run.summary["Best Validation Loss"] = best_val_loss
252 | torch.save(model.state_dict(), f"{run_dir}/{run_name}_model_weights.pt")
253 | log.info(f"Weights Saved at epoch: {epoch}")
254 | else:
255 | epochs_no_improve += 1
256 |
257 | if epochs_no_improve == cfg.early_stop_epoch and cfg.early_stop:
258 | log.info("Training Ended (Evaluation Test Score Not Improving)")
259 | break
260 |
261 | # log onto wandb...
262 | training_tracker.log("Training", epoch)
263 | val_tracker.log("Validation", epoch)
264 |
265 |
266 | if __name__ == "__main__":
267 | main()
268 |
--------------------------------------------------------------------------------
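For reference, a minimal, self-contained sketch of the mixed-precision pattern train_epoch relies on when cfg.fp16 is enabled (toy model and data, not the project's own classes): the forward pass and loss run under autocast, and the GradScaler scales the loss before backward, then steps the optimizer and updates its scale factor.

import torch

model = torch.nn.Linear(8, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
scaler = torch.cuda.amp.GradScaler()

x, y = torch.randn(4, 8, device="cuda"), torch.randn(4, 1, device="cuda")

with torch.cuda.amp.autocast():
    loss = torch.nn.functional.mse_loss(model(x), y)  # forward in reduced precision where safe

scaler.scale(loss).backward()  # scale the loss to avoid fp16 gradient underflow
scaler.step(optimizer)         # unscales gradients; skips the step on inf/nan
scaler.update()
optimizer.zero_grad()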
/smart_tree/dataset/dataset.py:
--------------------------------------------------------------------------------
1 | import json
2 | from pathlib import Path
3 | from typing import List
4 |
5 | import torch
6 | import torch.utils.data
7 | from spconv.pytorch.utils import PointToVoxel
8 | from torch.utils.data import DataLoader
9 | from tqdm import tqdm
10 |
11 | from ..data_types.cloud import Cloud
12 | from ..model.sparse import batch_collate
13 | from ..util.file import load_cloud
14 | from ..util.maths import cube_filter
15 | from ..util.misc import at_least_2d
16 |
17 |
18 | class TreeDataset:
19 | def __init__(
20 | self,
21 |         voxel_size: float,
22 | json_path: Path,
23 | directory: Path,
24 | mode: str,
25 | input_features: List[str],
26 | target_features: List[str],
27 | augmentation=None,
28 | cache: bool = False,
29 | device=torch.device("cuda:0"),
30 | ):
31 | self.voxel_size = voxel_size
32 | self.mode = mode
33 | self.augmentation = augmentation
34 | self.directory = directory
35 | self.device = device
36 |
37 | self.input_features = input_features
38 | self.target_features = target_features
39 |
40 | assert Path(
41 | json_path
42 | ).is_file(), f"json metadata does not exist at '{json_path}'"
43 | json_data = json.load(open(json_path))
44 |
45 | if self.mode == "train":
46 | self.tree_paths = json_data["train"]
47 | elif self.mode == "validation":
48 | self.tree_paths = json_data["validation"]
49 | elif self.mode == "test":
50 | self.tree_paths = json_data["test"]
51 |
52 | missing = [
53 | path
54 | for path in self.tree_paths
55 | if not Path(f"{self.directory}/{path}").is_file()
56 | ]
57 |
58 | assert len(missing) == 0, f"Missing {len(missing)} files: {missing}"
59 |
60 | self.cache = {} if cache else None
61 |         self.load_cloud = load_cloud
62 |
63 | def load(self, filename) -> Cloud:
64 | if self.cache is None:
65 | return self.load_cloud(filename)
66 |
67 | if filename not in self.cache:
68 | self.cache[filename] = self.load_cloud(filename).pin_memory()
69 |
70 | return self.cache[filename]
71 |
72 | def __getitem__(self, idx):
73 | filename = Path(f"{self.directory}/{self.tree_paths[idx]}")
74 | cld = self.load(filename)
75 |
76 | try:
77 | return self.process_cloud(cld, self.tree_paths[idx])
78 | except Exception:
79 | print(f"Exception processing {filename} with {len(cld)} points")
80 | raise
81 |
82 | def process_cloud(self, cld: Cloud, filename: str):
83 | cld = cld.to_device(self.device)
84 |
85 |         if self.augmentation is not None:
86 | cld = self.augmentation(cld)
87 |
88 | xyzmin, _ = torch.min(cld.xyz, axis=0)
89 | xyzmax, _ = torch.max(cld.xyz, axis=0)
90 |
91 | # data = cld.cat()
92 | input_features = torch.cat(
93 | [at_least_2d(getattr(cld, attr)) for attr in self.input_features], dim=1
94 | )
95 |
96 | target_features = torch.cat(
97 | [at_least_2d(getattr(cld, attr)) for attr in self.target_features], dim=1
98 | )
99 |
100 | data = torch.cat([input_features, target_features], dim=1)
101 |
102 | assert (
103 | data.shape[0] > 0
104 | ), f"Empty cloud after augmentation: {self.tree_paths[idx]}"
105 |
106 | surface_voxel_generator = PointToVoxel(
107 | vsize_xyz=[self.voxel_size] * 3,
108 | coors_range_xyz=[
109 | xyzmin[0],
110 | xyzmin[1],
111 | xyzmin[2],
112 | xyzmax[0],
113 | xyzmax[1],
114 | xyzmax[2],
115 | ],
116 | num_point_features=data.shape[1],
117 | max_num_voxels=data.shape[0],
118 | max_num_points_per_voxel=1,
119 | device=data.device,
120 | )
121 |
122 | feats, coords, _, _ = surface_voxel_generator.generate_voxel_with_id(data)
123 |
124 | indice = torch.zeros(
125 | (coords.shape[0], 1),
126 | dtype=coords.dtype,
127 | device=feats.device,
128 | )
129 | coords = torch.cat((indice, coords), dim=1)
130 |
131 | feats = feats.squeeze(1)
132 | coords = coords.squeeze(1)
133 | loss_mask = torch.ones(feats.shape[0], dtype=torch.bool, device=feats.device)
134 |
135 | input_feats = feats[:, : input_features.shape[1]]
136 | target_feats = feats[:, input_features.shape[1] :]
137 |
138 | return (input_feats, target_feats), coords, loss_mask, filename
139 |
140 | def __len__(self):
141 | return len(self.tree_paths)
142 |
143 |
144 | class SingleTreeInference:
145 | def __init__(
146 | self,
147 | cloud: Cloud,
148 | voxel_size: float,
149 | block_size: float = 4,
150 | buffer_size: float = 0.4,
151 | min_points=20,
152 | file_name=None,
153 | device=torch.device("cuda:0"),
154 | ):
155 | self.cloud = cloud
156 |
157 | self.voxel_size = voxel_size
158 | self.block_size = block_size
159 | self.buffer_size = buffer_size
160 | self.min_points = min_points
161 | self.device = device
162 | self.file_name = file_name
163 |
164 | self.compute_blocks()
165 |
166 | def compute_blocks(self):
167 | self.xyz_quantized = torch.div(
168 | self.cloud.xyz, self.block_size, rounding_mode="floor"
169 | )
170 | self.block_ids, pnt_counts = torch.unique(
171 | self.xyz_quantized, return_counts=True, dim=0
172 | )
173 |
174 | # Remove blocks that have less than specified amount of points...
175 | self.block_ids = self.block_ids[pnt_counts > self.min_points]
176 | self.block_centres = (self.block_ids * self.block_size) + (self.block_size / 2)
177 |
178 | self.clouds: List[Cloud] = []
179 |
180 | for centre in tqdm(self.block_centres, desc="Computing blocks...", leave=False):
181 | mask = cube_filter(
182 | self.cloud.xyz,
183 | centre,
184 | self.block_size + (self.buffer_size * 2),
185 | )
186 | block_cloud = self.cloud.filter(mask).to_device(torch.device("cpu"))
187 |
188 | self.clouds.append(block_cloud)
189 |
190 | self.block_centres = self.block_centres.to(torch.device("cpu"))
191 |
192 | def __getitem__(self, idx):
193 | block_centre = self.block_centres[idx]
194 | cloud: Cloud = self.clouds[idx]
195 |
196 | xyzmin, _ = torch.min(cloud.xyz, axis=0)
197 | xyzmax, _ = torch.max(cloud.xyz, axis=0)
198 |
199 | surface_voxel_generator = PointToVoxel(
200 | vsize_xyz=[self.voxel_size] * 3,
201 | coors_range_xyz=[
202 | xyzmin[0],
203 | xyzmin[1],
204 | xyzmin[2],
205 | xyzmax[0],
206 | xyzmax[1],
207 | xyzmax[2],
208 | ],
209 | num_point_features=6,
210 | max_num_voxels=len(cloud),
211 | max_num_points_per_voxel=1,
212 | )
213 |
214 | feats, coords, _, voxel_id_tv = surface_voxel_generator.generate_voxel_with_id(
215 | torch.cat((cloud.xyz, cloud.rgb), dim=1).contiguous()
216 | ) #
217 |
218 | indice = torch.zeros((coords.shape[0], 1), dtype=torch.int32)
219 | coords = torch.cat((indice, coords), dim=1)
220 |
221 | feats = feats.squeeze(1)
222 | coords = coords.squeeze(1)
223 |
224 | mask = cube_filter(feats[:, :3], block_centre, self.block_size)
225 |
226 | return feats, coords, mask, self.file_name
227 |
228 | def __len__(self):
229 | return len(self.clouds)
230 |
231 |
232 | def load_dataloader(
233 | cloud: Cloud,
234 | voxel_size: float,
235 | block_size: float,
236 | buffer_size: float,
237 |     num_workers: int,
238 |     batch_size: int,
239 | ):
240 | dataset = SingleTreeInference(cloud, voxel_size, block_size, buffer_size)
241 |
242 |     return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=batch_collate)
243 |
--------------------------------------------------------------------------------
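An illustrative instantiation of TreeDataset; the data directory, voxel size, and feature lists below are placeholders rather than the project's actual settings (in practice the data loaders are built by Hydra from conf/training.yaml).

from pathlib import Path

from torch.utils.data import DataLoader

# Placeholder paths and parameters for illustration only.
dataset = TreeDataset(
    voxel_size=0.01,
    json_path=Path("smart_tree/conf/training-split.json"),
    directory=Path("/path/to/synthetic-trees"),  # hypothetical dataset location
    mode="train",
    input_features=["xyz"],
    target_features=["class_l"],
)
loader = DataLoader(dataset, batch_size=2, collate_fn=batch_collate)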
/smart_tree/data_types/tree.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | from dataclasses import dataclass
3 | from typing import Dict, List
4 |
5 | import open3d as o3d
6 | import torch
7 | import torch.nn.functional as F
8 | from torchtyping import TensorType
9 | from tqdm import tqdm
10 |
11 | from ..o3d_abstractions.geometries import (o3d_merge_linesets,
12 | o3d_merge_meshes)
13 | from ..o3d_abstractions.visualizer import o3d_viewer
14 | from ..util.misc import flatten_list, merge_dictionaries
15 | from ..util.queries import pts_to_nearest_tube_gpu
16 | from .branch import BranchSkeleton
17 | from .tube import Tube
18 |
19 |
20 | @dataclass
21 | class TreeSkeleton:
22 | _id: int
23 | branches: Dict[int, BranchSkeleton]
24 |
25 | def __post_init__(self):
26 | self.colour = torch.rand(3)
27 |
28 | def __len__(self) -> int:
29 | return len(self.branches)
30 |
31 | def __str__(self) -> str:
32 | return f"Tree Skeleton ({self._id}) has {len(self)} branches..."
33 |
34 | def to_o3d_linesets(self) -> List[o3d.geometry.LineSet]:
35 | return [b.to_o3d_lineset() for b in self.branches.values()]
36 |
37 | def to_o3d_lineset(self, colour=(0, 0, 0)) -> o3d.geometry.LineSet:
38 | return o3d_merge_linesets(self.to_o3d_linesets(), colour=colour)
39 |
40 | def to_o3d_tubes(self) -> List[o3d.geometry.TriangleMesh]:
41 | return [b.to_o3d_tube() for b in self.branches.values()]
42 |
43 | def to_o3d_tube(self) -> o3d.geometry.TriangleMesh:
44 | return o3d_merge_meshes(self.to_o3d_tubes())
45 |
46 | def view(self):
47 | o3d_viewer([self.to_o3d_lineset(), self.to_o3d_tube()])
48 |
49 | def to_tubes(self) -> List[Tube]:
50 | return flatten_list([b.to_tubes() for b in self.branches.values()])
51 |
52 | def sample_skeleton(self, spacing):
53 | return sample_tubes(self.to_tubes(), spacing)
54 |
55 |     def to_device(self, device):
56 |         branches = {
57 |             branch_id: BranchSkeleton(
58 |                 branch._id, branch.parent_id,
59 |                 branch.xyz.to(device), branch.radii.to(device), branch.child_id,
60 |             )
61 |             for branch_id, branch in self.branches.items()
62 |         }
63 |         return TreeSkeleton(self._id, branches)
64 |
65 | def point_to_skeleton(self):
66 | tubes = self.to_tubes()
67 |
68 | def point_to_tube(pt):
69 | return pts_to_nearest_tube_gpu(pt.reshape(-1, 3), tubes)
70 |
71 | return point_to_tube # v, idx, _
72 |
73 | def repair(self):
74 | """skeletons are not connected between branches.
75 | this function connects the branches to their parent branches by finding
76 | the nearest point on the parent branch."""
77 |
78 | branch_ids = [branch._id for branch in self.branches.values()]
79 |
80 | for branch in self.branches.values():
81 | if branch.parent_id not in branch_ids:
82 | continue
83 |
84 | parent_branch = self.branches[branch.parent_id]
85 | tubes = parent_branch.to_tubes()
86 |
87 | v, idx, _ = pts_to_nearest_tube_gpu(branch.xyz[0].reshape(-1, 3), tubes)
88 |
89 | connection_pt = branch.xyz[0].reshape(-1, 3).cpu() + v[0].cpu()
90 |
91 | branch.xyz = torch.cat((connection_pt, branch.xyz))
92 | branch.radii = torch.cat((branch.radii[[0]], branch.radii))
93 |
94 |     def prune(self, min_radius: float, min_length: float, root_id=None):
95 |         """
96 |         Removes a branch, along with all of its descendants, if it doesn't meet
97 |         the initial-radius or length thresholds.
98 |         min_radius: the branch's initial radius must be at least this value to be kept.
99 |         min_length: the branch's total length must be greater than this value to be kept.
100 |         """
101 |         root_id = min(self.branches.keys()) if root_id is None else root_id
102 |
103 | keep = {root_id: self.branches[root_id]}
104 | remove = {}
105 |
106 | for branch_id, branch in tqdm(
107 | self.branches.items(),
108 | leave=False,
109 | desc="Pruning Branches",
110 | ):
111 | if branch.parent_id not in keep and branch._id != root_id:
112 | remove[branch_id] = branch
113 | elif branch.length < min_length:
114 | remove[branch_id] = branch
115 | elif branch.initial_radius < min_radius:
116 | remove[branch_id] = branch
117 | else:
118 | keep[branch_id] = branch
119 |
120 | self.branches = keep
121 | return TreeSkeleton(0, remove)
122 |
123 | def smooth(self, kernel_size=5):
124 | """
125 | Smooths the skeleton radius.
126 | """
127 | kernel = torch.ones(1, 1, kernel_size) / kernel_size
128 | for branch in self.branches.values():
129 | if branch.radii.shape[0] > kernel_size:
130 | branch.radii = F.conv1d(
131 | branch.radii.reshape(1, 1, -1),
132 | kernel,
133 | padding="same",
134 | ).reshape(-1)
135 |
136 | @property
137 | def length(self) -> TensorType[1]:
138 | return torch.sum(torch.tensor([b.length for b in self.branches.values()]))
139 |
140 | @property
141 | def biggest_radius_idx(self) -> TensorType[1]:
142 | return torch.argmax(self.radii)
143 |
144 | @property
145 | def key_branch_with_biggest_radius(self) -> TensorType[1]:
146 | """Returns the key of the branch with the biggest radius"""
147 | biggest_branch_radius = 0
148 | for key, branch in self.branches.items():
149 | if branch.biggest_radius > biggest_branch_radius:
150 | biggest_branch_radius = branch.biggest_radius
151 | biggest_branch_radius_key = key
152 |
153 | return biggest_branch_radius_key
154 |
155 | @property
156 | def max_branch_id(self):
157 | return max(self.branches.keys())
158 |
159 |
160 | @dataclass
161 | class DisjointTreeSkeleton:
162 | skeletons: List[TreeSkeleton]
163 |
164 | def prune(self, min_radius, min_length):
165 | self.skeletons[0].prune(
166 | min_radius=min_radius,
167 | min_length=min_length,
168 | ) # Can only prune the first skeleton as we don't know the root points for all the other skeletons...
169 |
170 | def repair(self):
171 | for skeleton in self.skeletons:
172 | skeleton.repair()
173 |
174 | def smooth(self, kernel_size=7):
175 | for skeleton in self.skeletons:
176 | skeleton.smooth(kernel_size=kernel_size)
177 |
178 | def to_o3d_lineset(self):
179 | return o3d_merge_linesets(
180 | [s.to_o3d_lineset().paint_uniform_color(s.colour) for s in self.skeletons]
181 | )
182 |
183 | def to_o3d_tube(self, colour=True):
184 | if colour:
185 | skeleton_tubes = [
186 | skel.to_o3d_tube().paint_uniform_color(skel.colour)
187 | for skel in self.skeletons
188 | ]
189 | else:
190 | skeleton_tubes = [s.to_o3d_tube() for s in self.skeletons]
191 |
192 | return o3d_merge_meshes(skeleton_tubes)
193 |
194 | def view(self):
195 | o3d_viewer([self.to_o3d_lineset(), self.to_o3d_tube()])
196 |
197 | def to_pickle(self, path):
198 | with open("disjoint_skeleton.pkl", "wb") as pickle_file:
199 | pickle.dump(self, pickle_file)
200 |
201 | @staticmethod
202 | def from_pickle(path):
203 | with open(f"{path}", "rb") as pickle_file:
204 | return pickle.load(pickle_file)
205 |
206 |
207 | def connect(
208 | skeleton_1: TreeSkeleton,
209 |     skeleton_1_parent_branch_key: int,
210 |     skeleton_1_parent_vert_idx: int,
211 | skeleton_2: TreeSkeleton,
212 | skeleton_2_child_branch_key: int,
213 | skeleton_2_child_vert_idx: int,
214 | ) -> TreeSkeleton:
215 |     # NOTE: this is a hack - it only gives the visual appearance of a connection.
216 |     # More processing is needed to actually merge the skeletons' connectivity...
217 |
218 | parent_branch = skeleton_1.branches[skeleton_1_parent_branch_key]
219 | child_branch = skeleton_2.branches[skeleton_2_child_branch_key]
220 |
221 | child_branch.parent_id = skeleton_1_parent_branch_key
222 | connection_pt = parent_branch.xyz[skeleton_1_parent_vert_idx]
223 |
224 | child_branch.xyz = torch.cat((connection_pt.unsqueeze(0), child_branch.xyz))
225 | child_branch.radii = torch.cat((child_branch.radii[[0]], child_branch.radii))
226 |
227 | for key, branch in skeleton_2.branches.items():
228 | branch._id += skeleton_1.max_branch_id
229 |
230 | if branch.parent_id != -1:
231 | branch.parent_id += skeleton_1.max_branch_id
232 |
233 | return TreeSkeleton(0, merge_dictionaries(skeleton_1.branches, skeleton_2.branches))
234 |
--------------------------------------------------------------------------------
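The skeleton post-processing methods above are typically chained; a sketch of one plausible order (the thresholds are illustrative, not tuned values from the pipeline config):

skeleton = DisjointTreeSkeleton.from_pickle("disjoint_skeleton.pkl")

skeleton.repair()               # reattach each branch to the nearest point on its parent
skeleton.smooth(kernel_size=7)  # box-filter each branch's radii
skeleton.prune(min_radius=0.002, min_length=0.05)  # only the first (rooted) skeleton is pruned
skeleton.view()                 # open the Open3D viewer with linesets and tube meshes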
/smart_tree/model/model_blocks.py:
--------------------------------------------------------------------------------
1 | import spconv.pytorch as spconv
2 | import torch
3 | import torch.cuda.amp
4 | import torch.nn as nn
5 | from spconv.pytorch import SparseModule
6 |
7 |
8 | class SubMConvBlock(SparseModule):
9 | def __init__(
10 | self,
11 | input_channels,
12 | output_channels,
13 | kernel_size,
14 | norm_fn,
15 | activation_fn,
16 | stride=1,
17 | padding=1,
18 | algo=spconv.ConvAlgo.Native,
19 | bias=False,
20 | ):
21 | super().__init__()
22 |
23 | self.sequence = spconv.SparseSequential(
24 | spconv.SubMConv3d(
25 | in_channels=input_channels,
26 | out_channels=output_channels,
27 | kernel_size=kernel_size,
28 | stride=stride,
29 | padding=padding,
30 | bias=bias,
31 | algo=algo,
32 | ),
33 | norm_fn(output_channels),
34 | activation_fn(),
35 | )
36 |
37 | def forward(self, input):
38 | return self.sequence(input)
39 |
40 |
41 | class EncoderBlock(SparseModule):
42 | def __init__(
43 | self,
44 | input_channels,
45 | output_channels,
46 | kernel_size,
47 | norm_fn,
48 | activation_fn,
49 | stride=2,
50 | padding=1,
51 | key=None,
52 | algo=spconv.ConvAlgo.Native,
53 | bias=False,
54 | ):
55 | super().__init__()
56 |
57 | self.sequence = spconv.SparseSequential(
58 | spconv.SparseConv3d(
59 | input_channels,
60 | output_channels,
61 | kernel_size=kernel_size,
62 | stride=stride,
63 | indice_key=key,
64 | algo=algo,
65 | bias=bias,
66 | padding=padding,
67 | ),
68 | norm_fn(output_channels),
69 | activation_fn(),
70 | )
71 |
72 | def forward(self, input):
73 | return self.sequence(input)
74 |
75 |
76 | class DecoderBlock(SparseModule):
77 | def __init__(
78 | self,
79 | input_channels,
80 | output_channels,
81 | kernel_size,
82 | norm_fn,
83 | activation_fn,
84 | key=None,
85 | algo=spconv.ConvAlgo.Native,
86 | bias=False,
87 | ):
88 | super().__init__()
89 |
90 | self.sequence = spconv.SparseSequential(
91 | spconv.SparseInverseConv3d(
92 | input_channels,
93 | output_channels,
94 | kernel_size,
95 | indice_key=key,
96 | algo=algo,
97 | bias=bias,
98 | ),
99 | norm_fn(output_channels),
100 | activation_fn(),
101 | )
102 |
103 | def forward(self, input):
104 | return self.sequence(input)
105 |
106 |
107 | class ResBlock(nn.Module):
108 | def __init__(
109 | self,
110 | input_channels,
111 | output_channels,
112 | kernel_size,
113 | norm_fn,
114 | activation_fn,
115 | algo=spconv.ConvAlgo.Native,
116 | bias=False,
117 | ):
118 | super().__init__()
119 |
120 | if input_channels == output_channels:
121 | self.identity = spconv.SparseSequential(nn.Identity())
122 | else:
123 | self.identity = spconv.SparseSequential(
124 | spconv.SubMConv3d(
125 | input_channels,
126 | output_channels,
127 | kernel_size=1,
128 | padding=1,
129 | bias=False,
130 | algo=algo,
131 | )
132 | )
133 |
134 | self.sequence = spconv.SparseSequential(
135 | spconv.SubMConv3d(
136 | input_channels, output_channels, kernel_size, bias=False, algo=algo
137 | ),
138 | norm_fn(output_channels),
139 | activation_fn(),
140 | spconv.SubMConv3d(
141 | output_channels, output_channels, kernel_size, bias=False, algo=algo
142 | ),
143 | norm_fn(output_channels),
144 | )
145 |
146 | self.activation_fn = spconv.SparseSequential(activation_fn())
147 |
148 | def forward(self, input):
149 | identity = spconv.SparseConvTensor(
150 | input.features, input.indices, input.spatial_shape, input.batch_size
151 | )
152 | output = self.sequence(input)
153 | output = output.replace_feature(
154 | output.features + self.identity(identity).features
155 | )
156 | return self.activation_fn(output)
157 |
158 |
159 | class UBlock(nn.Module):
160 | def __init__(
161 | self,
162 | n_planes,
163 | norm_fn,
164 | activation_fn,
165 | kernel_size=3,
166 | key_id=1,
167 | algo=spconv.ConvAlgo.Native,
168 | bias=False,
169 | ):
170 | super().__init__()
171 |
172 | self.n_planes = n_planes
173 | self.Head = ResBlock(
174 | n_planes[0],
175 | n_planes[0],
176 | kernel_size,
177 | norm_fn,
178 | activation_fn,
179 | algo=algo,
180 | bias=bias,
181 | )
182 |
183 | if len(n_planes) > 1:
184 | self.Encode = EncoderBlock(
185 | n_planes[0],
186 | n_planes[1],
187 | kernel_size,
188 | norm_fn,
189 | activation_fn,
190 | stride=2,
191 | key=key_id,
192 | algo=algo,
193 | bias=bias,
194 | )
195 | self.U = UBlock(
196 | n_planes[1:],
197 | norm_fn,
198 | activation_fn,
199 | kernel_size,
200 | key_id + 1,
201 | algo,
202 | bias,
203 | )
204 | self.Decode = DecoderBlock(
205 | n_planes[1],
206 | n_planes[0],
207 | kernel_size,
208 | norm_fn,
209 | activation_fn,
210 | key=key_id,
211 | algo=algo,
212 | bias=bias,
213 | )
214 | self.Tail = ResBlock(
215 | n_planes[0] * 2,
216 | n_planes[0],
217 | kernel_size,
218 | norm_fn,
219 | activation_fn,
220 | algo=algo,
221 | bias=bias,
222 | )
223 |
224 | def forward(self, input):
225 | output = self.Head(input)
226 |
227 | identity = spconv.SparseConvTensor(
228 | output.features,
229 | output.indices,
230 | output.spatial_shape,
231 | output.batch_size,
232 | )
233 |
234 | if len(self.n_planes) > 1:
235 | output = self.Encode(output)
236 | output = self.U(output)
237 | output = self.Decode(output)
238 | output = output.replace_feature(
239 | torch.cat((identity.features, output.features), dim=1)
240 | )
241 | output = self.Tail(output)
242 |
243 | return output
244 |
245 |
246 | class SparseFC(nn.Module):
247 | def __init__(
248 | self,
249 | n_planes,
250 | norm_fn,
251 | activation_fn=None,
252 | kernel_size=1,
253 | algo=spconv.ConvAlgo.Native,
254 | bias=False,
255 | ):
256 | super().__init__()
257 |
258 | self.sequence = spconv.SparseSequential()
259 | for i in range(len(n_planes) - 2):
260 | self.sequence.add(
261 | spconv.SubMConv3d(
262 | n_planes[i],
263 | n_planes[i + 1],
264 | kernel_size=kernel_size,
265 | bias=False,
266 | algo=algo,
267 | padding=0,
268 | )
269 | )
270 | self.sequence.add(norm_fn(n_planes[i + 1]))
271 | self.sequence.add(activation_fn())
272 |
273 | self.sequence.add(
274 | spconv.SubMConv3d(
275 | n_planes[-2],
276 | n_planes[-1],
277 | kernel_size=kernel_size,
278 | bias=False,
279 | algo=algo,
280 | padding=0,
281 | )
282 | )
283 |
284 | def forward(self, input):
285 | return self.sequence(input)
286 |
287 |
288 | class MLP(nn.Module):
289 | def __init__(
290 | self,
291 | n_planes,
292 | norm_fn,
293 | activation_fn=None,
294 | bias=False,
295 | ):
296 | super().__init__()
297 |
298 | self.sequence = spconv.SparseSequential()
299 |
300 | for i in range(len(n_planes) - 2):
301 | self.sequence.add(
302 | nn.Linear(
303 | n_planes[i],
304 | n_planes[i + 1],
305 | bias=bias,
306 | )
307 | )
308 | self.sequence.add(norm_fn(n_planes[i + 1]))
309 | self.sequence.add(activation_fn())
310 |
311 | self.sequence.add(
312 | nn.Linear(
313 | n_planes[-2],
314 | n_planes[-1],
315 | bias=bias,
316 | )
317 | )
318 |
319 | def forward(self, input):
320 | return self.sequence(input)
321 |
--------------------------------------------------------------------------------
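A small construction sketch for UBlock; the channel widths, normalisation, and activation below are illustrative, since the actual network is instantiated from the Hydra config (cfg.model).

import functools

import torch.nn as nn

norm_fn = functools.partial(nn.BatchNorm1d, eps=1e-4, momentum=0.1)

# Three resolution levels (16 -> 32 -> 64 channels): submanifold ResBlocks at each
# level, strided sparse convolutions down and inverse convolutions back up.
unet = UBlock([16, 32, 64], norm_fn, nn.ReLU, kernel_size=3)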
/smart_tree/data_types/cloud.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from dataclasses import dataclass
4 | from pathlib import Path
5 | from typing import Optional
6 |
7 | import numpy as np
8 | import torch
9 | import torch.nn.functional as F
10 | from torchtyping import TensorType
11 | from typeguard import typechecked
12 |
13 | from ..o3d_abstractions.geometries import o3d_cloud, o3d_lines_between_clouds
14 | from ..o3d_abstractions.visualizer import o3d_viewer
15 | from ..util.misc import voxel_downsample
16 | from ..util.queries import skeleton_to_points
17 |
18 |
19 | @typechecked
20 | @dataclass
21 | class Cloud:
22 | xyz: TensorType["N", 3]
23 | rgb: Optional[TensorType["N", 3]] = None
24 | medial_vector: Optional[TensorType["N", 3]] = None
25 | branch_direction: Optional[TensorType["N", 3]] = None
26 | branch_ids: Optional[TensorType["N", 1]] = None
27 | class_l: Optional[TensorType["N", 1]] = None
28 | filename: Optional[Path] = None
29 |
30 | def __len__(self):
31 | return self.xyz.shape[0]
32 |
33 | def __str__(self):
34 | return f"{'*' * 80}\nCloud with {self.xyz.shape[0]} Points.\nMin: {torch.min(self.xyz, 0)[0]}\nMax: {torch.max(self.xyz, 0)[0]}\nDevice:{self.xyz.device}\n{'*' * 80}\n"
35 |
36 |     def paint(self, colour=(1, 0, 0)):
37 |         self.rgb = torch.tensor([colour]).float().expand(len(self), -1)
38 |
39 | def to_o3d_cld(self):
40 | cpu_cld = self.cpu()
41 |         if cpu_cld.rgb is None:
42 | cpu_cld.paint()
43 | return o3d_cloud(cpu_cld.xyz, colours=cpu_cld.rgb)
44 |
45 | def to_o3d_seg_cld(self, cmap: np.ndarray = np.array([[1, 0, 0], [0, 1, 0]])):
46 | cpu_cld = self.cpu()
47 | colours = cmap[cpu_cld.class_l.view(-1).int()]
48 | return o3d_cloud(cpu_cld.xyz, colours=colours)
49 |
50 | def to_o3d_trunk_cld(self):
51 | cpu_cld = self.cpu()
52 | min_branch_id = cpu_cld.branch_ids[0]
53 | return cpu_cld.filter(cpu_cld.branch_ids == min_branch_id).to_o3d_cld()
54 |
55 | def to_o3d_branch_cld(self):
56 | cpu_cld = self.cpu()
57 | min_branch_id = cpu_cld.branch_ids[0]
58 | return cpu_cld.filter(cpu_cld.branch_ids != min_branch_id).to_o3d_cld()
59 |
60 | def to_o3d_medial_vectors(self, cmap=None):
61 | cpu_cld = self.cpu()
62 | medial_cloud = o3d_cloud(cpu_cld.xyz + cpu_cld.medial_vector)
63 | return o3d_lines_between_clouds(cpu_cld.to_o3d_cld(), medial_cloud)
64 |
65 | def to_o3d_branch_directions(self, scale=0.1, cmap=None):
66 | cpu_cld = self.cpu()
67 | branch_dir_cloud = o3d_cloud(cpu_cld.xyz + (cpu_cld.branch_direction * scale))
68 | return o3d_lines_between_clouds(cpu_cld.to_o3d_cld(), branch_dir_cloud)
69 |
70 | # def to_wandb_seg_cld(self, cmap: np.ndarray = np.array([[1, 0, 0], [0, 1, 0]])):
71 |
72 | def filter(cloud: Cloud, mask) -> Cloud:
73 | mask = mask.to(cloud.xyz.device)
74 | xyz = cloud.xyz[mask]
75 | rgb = cloud.rgb[mask] if cloud.rgb is not None else None
76 |
77 | medial_vector = (
78 | cloud.medial_vector[mask] if cloud.medial_vector is not None else None
79 | )
80 | branch_direction = (
81 | cloud.branch_direction[mask] if cloud.branch_direction is not None else None
82 | )
83 | class_l = cloud.class_l[mask] if cloud.class_l is not None else None
84 | branch_ids = cloud.branch_ids[mask] if cloud.branch_ids is not None else None
85 | filename = cloud.filename if cloud.filename is not None else None
86 |
87 | return Cloud(
88 | xyz=xyz,
89 | rgb=rgb,
90 | medial_vector=medial_vector,
91 | branch_direction=branch_direction,
92 | class_l=class_l,
93 | branch_ids=branch_ids,
94 | filename=filename,
95 | )
96 |
97 | def filter_by_class(self, classes):
98 | classes = torch.tensor(classes, device=self.class_l.device)
99 | mask = torch.isin(
100 | self.class_l,
101 | classes,
102 | )
103 | return self.filter(mask.view(-1))
104 |
105 | def filter_by_skeleton(skeleton: TreeSkeleton, cloud: Cloud, threshold=1.1):
106 | distances, radii, vectors_ = skeleton_to_points(cloud, skeleton, chunk_size=512)
107 | mask = distances < radii * threshold
108 |         return cloud.filter(mask)
109 |
110 | def pin_memory(self):
111 | xyz = self.xyz.pin_memory()
112 | rgb = self.rgb.pin_memory() if self.rgb is not None else None
113 | medial_vector = (
114 | self.medial_vector.pin_memory() if self.medial_vector is not None else None
115 | )
116 | branch_direction = (
117 | self.branch_direction.pin_memory()
118 | if self.branch_direction is not None
119 | else None
120 | )
121 | class_l = self.class_l.pin_memory() if self.class_l is not None else None
122 | branch_ids = (
123 | self.branch_ids.pin_memory() if self.branch_ids is not None else None
124 | )
125 |
126 | return Cloud(
127 | xyz=xyz,
128 | rgb=rgb,
129 | medial_vector=medial_vector,
130 | branch_direction=branch_direction,
131 | class_l=class_l,
132 | branch_ids=branch_ids,
133 | )
134 |
135 | def cpu(self):
136 | return self.to_device(torch.device("cpu"))
137 |
138 | def to_device(self, device):
139 | xyz = self.xyz.to(device)
140 | rgb = self.rgb.to(device) if self.rgb is not None else None
141 | medial_vector = (
142 | self.medial_vector.to(device) if self.medial_vector is not None else None
143 | )
144 | branch_direction = (
145 | self.branch_direction.to(device)
146 | if self.branch_direction is not None
147 | else None
148 | )
149 | class_l = self.class_l.to(device) if self.class_l is not None else None
150 | branch_ids = self.branch_ids.to(device) if self.branch_ids is not None else None
151 | filename = self.filename if self.filename is not None else None
152 |
153 | return Cloud(
154 | xyz=xyz,
155 | rgb=rgb,
156 | medial_vector=medial_vector,
157 | branch_direction=branch_direction,
158 | class_l=class_l,
159 | branch_ids=branch_ids,
160 | filename=filename,
161 | )
162 |
163 | def cat(self):
164 | return torch.cat(
165 | (
166 | self.xyz,
167 | self.rgb,
168 | ),
169 | 1,
170 | )
171 |
172 |     def view(self, cmap=None):
173 |         if cmap is None:
174 | cmap = np.random.rand(self.number_classes, 3)
175 |
176 | cpu_cld = self.cpu()
177 | geoms = []
178 |
179 | geoms.append(cpu_cld.to_o3d_cld())
180 |         if cpu_cld.class_l is not None:
181 |             geoms.append(cpu_cld.to_o3d_seg_cld(cmap))
182 |
183 |         if cpu_cld.medial_vector is not None:
184 | projected = o3d_cloud(cpu_cld.xyz + cpu_cld.medial_vector, colour=(1, 0, 0))
185 | geoms.append(projected)
186 | geoms.append(o3d_lines_between_clouds(cpu_cld.to_o3d_cld(), projected))
187 |
188 | o3d_viewer(geoms)
189 |
190 | def voxel_down_sample(self, voxel_size):
191 | idx = voxel_downsample(self.xyz, voxel_size)
192 | return self.filter(idx)
193 |
194 | def scale(self, factor):
195 | return Cloud(self.xyz * factor, self.rgb)
196 |
197 | def translate(self, xyz):
198 | return Cloud(self.xyz + xyz.to(self.xyz.device), self.rgb)
199 |
200 | def rotate(self, rot_mat):
201 | rot_mat = rot_mat.to(self.xyz.dtype)
202 | return Cloud(torch.matmul(self.xyz, rot_mat.to(self.xyz.device)), self.rgb)
203 |
204 | @property
205 | def root_idx(self) -> int:
206 | return torch.argmin(self.xyz[:, 1]).item()
207 |
208 | @property
209 | def number_classes(self) -> int:
210 |         if self.class_l is None:
211 |             return 1
212 |         return int(torch.max(self.class_l).item()) + 1
213 |
214 | @property
215 | def max_xyz(self) -> torch.Tensor:
216 | return torch.max(self.xyz, 0)[0]
217 |
218 | @property
219 | def min_xyz(self) -> torch.Tensor:
220 | return torch.min(self.xyz, 0)[0]
221 |
222 | @property
223 | def bbox(self) -> tuple[torch.Tensor, torch.Tensor]:
224 | # defined by centre coordinate, x/2, y/2, z/2
225 | dimensions = (self.max_xyz - self.min_xyz) / 2
226 | centre = self.min_xyz + dimensions
227 | return centre, dimensions
228 |
229 | @property
230 | def medial_pts(self) -> torch.Tensor:
231 | return self.xyz + self.medial_vector
232 |
233 | @staticmethod
234 | def from_numpy(**kwargs) -> "Cloud":
235 | torch_kwargs = {}
236 |
237 | for key, value in kwargs.items():
238 | if key in [
239 | "xyz",
240 | "rgb",
241 | "medial_vector",
242 | "branch_direction",
243 | "branch_ids",
244 | "class_l",
245 | ]:
246 | torch_kwargs[key] = torch.tensor(value).float()
247 |
248 | """ SUPPORT LEGACY NPZ -> Remove in Future..."""
249 | if key in ["vector"]:
250 | torch_kwargs["medial_vector"] = torch.tensor(value)
251 |
252 | return Cloud(**torch_kwargs)
253 |
254 | @property
255 | def radius(self) -> torch.Tensor:
256 | return self.medial_vector.pow(2).sum(1).sqrt()
257 |
258 | @property
259 | def direction(self) -> torch.Tensor:
260 | return F.normalize(self.medial_vector)
261 |
262 | @staticmethod
263 | def from_o3d_cld(cld) -> Cloud:
264 | return Cloud.from_numpy(xyz=np.asarray(cld.points), rgb=np.asarray(cld.colors))
265 |
--------------------------------------------------------------------------------
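A brief usage sketch for Cloud with random data (purely illustrative values):

import numpy as np

cld = Cloud.from_numpy(
    xyz=np.random.rand(1000, 3),
    rgb=np.random.rand(1000, 3),
)

print(cld)                                     # point count, bounds and device
down = cld.voxel_down_sample(voxel_size=0.01)  # voxel downsample at an illustrative resolution
lower = down.filter(down.xyz[:, 1] < 0.5)      # boolean-mask filtering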
/smart_tree/conf/training-split.json:
--------------------------------------------------------------------------------
1 | {
2 | "train": [
3 | "/cherry/cherry_21.npz",
4 | "/cherry/cherry_41.npz",
5 | "/cherry/cherry_67.npz",
6 | "/cherry/cherry_12.npz",
7 | "/cherry/cherry_92.npz",
8 | "/cherry/cherry_33.npz",
9 | "/cherry/cherry_29.npz",
10 | "/cherry/cherry_11.npz",
11 | "/cherry/cherry_84.npz",
12 | "/cherry/cherry_24.npz",
13 | "/cherry/cherry_5.npz",
14 | "/cherry/cherry_10.npz",
15 | "/cherry/cherry_66.npz",
16 | "/cherry/cherry_38.npz",
17 | "/cherry/cherry_61.npz",
18 | "/cherry/cherry_48.npz",
19 | "/cherry/cherry_26.npz",
20 | "/cherry/cherry_1.npz",
21 | "/cherry/cherry_40.npz",
22 | "/cherry/cherry_51.npz",
23 | "/cherry/cherry_70.npz",
24 | "/cherry/cherry_31.npz",
25 | "/cherry/cherry_81.npz",
26 | "/cherry/cherry_97.npz",
27 | "/cherry/cherry_62.npz",
28 | "/cherry/cherry_72.npz",
29 | "/cherry/cherry_86.npz",
30 | "/cherry/cherry_39.npz",
31 | "/cherry/cherry_79.npz",
32 | "/cherry/cherry_100.npz",
33 | "/cherry/cherry_83.npz",
34 | "/cherry/cherry_35.npz",
35 | "/cherry/cherry_87.npz",
36 | "/cherry/cherry_17.npz",
37 | "/cherry/cherry_57.npz",
38 | "/cherry/cherry_68.npz",
39 | "/cherry/cherry_77.npz",
40 | "/cherry/cherry_60.npz",
41 | "/cherry/cherry_3.npz",
42 | "/cherry/cherry_22.npz",
43 | "/cherry/cherry_80.npz",
44 | "/cherry/cherry_14.npz",
45 | "/cherry/cherry_23.npz",
46 | "/cherry/cherry_2.npz",
47 | "/cherry/cherry_53.npz",
48 | "/cherry/cherry_8.npz",
49 | "/cherry/cherry_32.npz",
50 | "/cherry/cherry_54.npz",
51 | "/cherry/cherry_63.npz",
52 | "/cherry/cherry_52.npz",
53 | "/cherry/cherry_16.npz",
54 | "/cherry/cherry_71.npz",
55 | "/cherry/cherry_44.npz",
56 | "/cherry/cherry_43.npz",
57 | "/cherry/cherry_59.npz",
58 | "/cherry/cherry_47.npz",
59 | "/cherry/cherry_69.npz",
60 | "/cherry/cherry_9.npz",
61 | "/cherry/cherry_95.npz",
62 | "/cherry/cherry_76.npz",
63 | "/cherry/cherry_93.npz",
64 | "/cherry/cherry_78.npz",
65 | "/cherry/cherry_99.npz",
66 | "/cherry/cherry_36.npz",
67 | "/cherry/cherry_27.npz",
68 | "/cherry/cherry_55.npz",
69 | "/cherry/cherry_45.npz",
70 | "/cherry/cherry_91.npz",
71 | "/cherry/cherry_46.npz",
72 | "/cherry/cherry_65.npz",
73 | "/cherry/cherry_85.npz",
74 | "/cherry/cherry_88.npz",
75 | "/cherry/cherry_58.npz",
76 | "/cherry/cherry_30.npz",
77 | "/cherry/cherry_37.npz",
78 | "/cherry/cherry_50.npz",
79 | "/cherry/cherry_25.npz",
80 | "/cherry/cherry_20.npz",
81 | "/cherry/cherry_64.npz",
82 | "/cherry/cherry_74.npz",
83 | "/apple/apple_57.npz",
84 | "/apple/apple_82.npz",
85 | "/apple/apple_38.npz",
86 | "/apple/apple_94.npz",
87 | "/apple/apple_49.npz",
88 | "/apple/apple_70.npz",
89 | "/apple/apple_71.npz",
90 | "/apple/apple_90.npz",
91 | "/apple/apple_41.npz",
92 | "/apple/apple_8.npz",
93 | "/apple/apple_75.npz",
94 | "/apple/apple_55.npz",
95 | "/apple/apple_19.npz",
96 | "/apple/apple_30.npz",
97 | "/apple/apple_88.npz",
98 | "/apple/apple_15.npz",
99 | "/apple/apple_65.npz",
100 | "/apple/apple_45.npz",
101 | "/apple/apple_78.npz",
102 | "/apple/apple_18.npz",
103 | "/apple/apple_26.npz",
104 | "/apple/apple_25.npz",
105 | "/apple/apple_80.npz",
106 | "/apple/apple_54.npz",
107 | "/apple/apple_72.npz",
108 | "/apple/apple_50.npz",
109 | "/apple/apple_1.npz",
110 | "/apple/apple_93.npz",
111 | "/apple/apple_99.npz",
112 | "/apple/apple_74.npz",
113 | "/apple/apple_33.npz",
114 | "/apple/apple_13.npz",
115 | "/apple/apple_34.npz",
116 | "/apple/apple_21.npz",
117 | "/apple/apple_59.npz",
118 | "/apple/apple_7.npz",
119 | "/apple/apple_84.npz",
120 | "/apple/apple_81.npz",
121 | "/apple/apple_85.npz",
122 | "/apple/apple_79.npz",
123 | "/apple/apple_29.npz",
124 | "/apple/apple_43.npz",
125 | "/apple/apple_40.npz",
126 | "/apple/apple_61.npz",
127 | "/apple/apple_68.npz",
128 | "/apple/apple_3.npz",
129 | "/apple/apple_24.npz",
130 | "/apple/apple_67.npz",
131 | "/apple/apple_32.npz",
132 | "/apple/apple_95.npz",
133 | "/apple/apple_4.npz",
134 | "/apple/apple_91.npz",
135 | "/apple/apple_46.npz",
136 | "/apple/apple_20.npz",
137 | "/apple/apple_89.npz",
138 | "/apple/apple_51.npz",
139 | "/apple/apple_10.npz",
140 | "/apple/apple_98.npz",
141 | "/apple/apple_47.npz",
142 | "/apple/apple_62.npz",
143 | "/apple/apple_52.npz",
144 | "/apple/apple_2.npz",
145 | "/apple/apple_31.npz",
146 | "/apple/apple_58.npz",
147 | "/apple/apple_22.npz",
148 | "/apple/apple_37.npz",
149 | "/apple/apple_9.npz",
150 | "/apple/apple_28.npz",
151 | "/apple/apple_69.npz",
152 | "/apple/apple_76.npz",
153 | "/apple/apple_100.npz",
154 | "/apple/apple_14.npz",
155 | "/apple/apple_44.npz",
156 | "/apple/apple_6.npz",
157 | "/apple/apple_66.npz",
158 | "/apple/apple_35.npz",
159 | "/apple/apple_27.npz",
160 | "/apple/apple_87.npz",
161 | "/apple/apple_48.npz",
162 | "/apple/apple_60.npz",
163 | "/ginkgo/ginkgo_75.npz",
164 | "/ginkgo/ginkgo_44.npz",
165 | "/ginkgo/ginkgo_65.npz",
166 | "/ginkgo/ginkgo_32.npz",
167 | "/ginkgo/ginkgo_82.npz",
168 | "/ginkgo/ginkgo_39.npz",
169 | "/ginkgo/ginkgo_38.npz",
170 | "/ginkgo/ginkgo_95.npz",
171 | "/ginkgo/ginkgo_4.npz",
172 | "/ginkgo/ginkgo_70.npz",
173 | "/ginkgo/ginkgo_1.npz",
174 | "/ginkgo/ginkgo_29.npz",
175 | "/ginkgo/ginkgo_90.npz",
176 | "/ginkgo/ginkgo_13.npz",
177 | "/ginkgo/ginkgo_2.npz",
178 | "/ginkgo/ginkgo_55.npz",
179 | "/ginkgo/ginkgo_68.npz",
180 | "/ginkgo/ginkgo_49.npz",
181 | "/ginkgo/ginkgo_63.npz",
182 | "/ginkgo/ginkgo_47.npz",
183 | "/ginkgo/ginkgo_66.npz",
184 | "/ginkgo/ginkgo_40.npz",
185 | "/ginkgo/ginkgo_97.npz",
186 | "/ginkgo/ginkgo_81.npz",
187 | "/ginkgo/ginkgo_41.npz",
188 | "/ginkgo/ginkgo_26.npz",
189 | "/ginkgo/ginkgo_42.npz",
190 | "/ginkgo/ginkgo_35.npz",
191 | "/ginkgo/ginkgo_10.npz",
192 | "/ginkgo/ginkgo_53.npz",
193 | "/ginkgo/ginkgo_24.npz",
194 | "/ginkgo/ginkgo_31.npz",
195 | "/ginkgo/ginkgo_45.npz",
196 | "/ginkgo/ginkgo_8.npz",
197 | "/ginkgo/ginkgo_87.npz",
198 | "/ginkgo/ginkgo_12.npz",
199 | "/ginkgo/ginkgo_92.npz",
200 | "/ginkgo/ginkgo_59.npz",
201 | "/ginkgo/ginkgo_22.npz",
202 | "/ginkgo/ginkgo_48.npz",
203 | "/ginkgo/ginkgo_76.npz",
204 | "/ginkgo/ginkgo_50.npz",
205 | "/ginkgo/ginkgo_84.npz",
206 | "/ginkgo/ginkgo_30.npz",
207 | "/ginkgo/ginkgo_80.npz",
208 | "/ginkgo/ginkgo_5.npz",
209 | "/ginkgo/ginkgo_28.npz",
210 | "/ginkgo/ginkgo_57.npz",
211 | "/ginkgo/ginkgo_79.npz",
212 | "/ginkgo/ginkgo_89.npz",
213 | "/ginkgo/ginkgo_17.npz",
214 | "/ginkgo/ginkgo_19.npz",
215 | "/ginkgo/ginkgo_98.npz",
216 | "/ginkgo/ginkgo_67.npz",
217 | "/ginkgo/ginkgo_69.npz",
218 | "/ginkgo/ginkgo_58.npz",
219 | "/ginkgo/ginkgo_33.npz",
220 | "/ginkgo/ginkgo_85.npz",
221 | "/ginkgo/ginkgo_3.npz",
222 | "/ginkgo/ginkgo_14.npz",
223 | "/ginkgo/ginkgo_51.npz",
224 | "/ginkgo/ginkgo_15.npz",
225 | "/ginkgo/ginkgo_72.npz",
226 | "/ginkgo/ginkgo_46.npz",
227 | "/ginkgo/ginkgo_16.npz",
228 | "/ginkgo/ginkgo_52.npz",
229 | "/ginkgo/ginkgo_34.npz",
230 | "/ginkgo/ginkgo_100.npz",
231 | "/ginkgo/ginkgo_74.npz",
232 | "/ginkgo/ginkgo_91.npz",
233 | "/ginkgo/ginkgo_78.npz",
234 | "/ginkgo/ginkgo_83.npz",
235 | "/ginkgo/ginkgo_6.npz",
236 | "/ginkgo/ginkgo_73.npz",
237 | "/ginkgo/ginkgo_20.npz",
238 | "/ginkgo/ginkgo_27.npz",
239 | "/ginkgo/ginkgo_25.npz",
240 | "/ginkgo/ginkgo_96.npz",
241 | "/ginkgo/ginkgo_9.npz",
242 | "/ginkgo/ginkgo_60.npz",
243 | "/walnut/walnut_59.npz",
244 | "/walnut/walnut_32.npz",
245 | "/walnut/walnut_10.npz",
246 | "/walnut/walnut_16.npz",
247 | "/walnut/walnut_48.npz",
248 | "/walnut/walnut_60.npz",
249 | "/walnut/walnut_23.npz",
250 | "/walnut/walnut_55.npz",
251 | "/walnut/walnut_1.npz",
252 | "/walnut/walnut_19.npz",
253 | "/walnut/walnut_35.npz",
254 | "/walnut/walnut_6.npz",
255 | "/walnut/walnut_49.npz",
256 | "/walnut/walnut_42.npz",
257 | "/walnut/walnut_62.npz",
258 | "/walnut/walnut_76.npz",
259 | "/walnut/walnut_21.npz",
260 | "/walnut/walnut_73.npz",
261 | "/walnut/walnut_58.npz",
262 | "/walnut/walnut_63.npz",
263 | "/walnut/walnut_85.npz",
264 | "/walnut/walnut_72.npz",
265 | "/walnut/walnut_39.npz",
266 | "/walnut/walnut_69.npz",
267 | "/walnut/walnut_78.npz",
268 | "/walnut/walnut_41.npz",
269 | "/walnut/walnut_3.npz",
270 | "/walnut/walnut_18.npz",
271 | "/walnut/walnut_99.npz",
272 | "/walnut/walnut_29.npz",
273 | "/walnut/walnut_31.npz",
274 | "/walnut/walnut_33.npz",
275 | "/walnut/walnut_89.npz",
276 | "/walnut/walnut_90.npz",
277 | "/walnut/walnut_83.npz",
278 | "/walnut/walnut_4.npz",
279 | "/walnut/walnut_24.npz",
280 | "/walnut/walnut_52.npz",
281 | "/walnut/walnut_38.npz",
282 | "/walnut/walnut_28.npz",
283 | "/walnut/walnut_51.npz",
284 | "/walnut/walnut_71.npz",
285 | "/walnut/walnut_100.npz",
286 | "/walnut/walnut_30.npz",
287 | "/walnut/walnut_64.npz",
288 | "/walnut/walnut_57.npz",
289 | "/walnut/walnut_75.npz",
290 | "/walnut/walnut_77.npz",
291 | "/walnut/walnut_81.npz",
292 | "/walnut/walnut_22.npz",
293 | "/walnut/walnut_13.npz",
294 | "/walnut/walnut_53.npz",
295 | "/walnut/walnut_56.npz",
296 | "/walnut/walnut_11.npz",
297 | "/walnut/walnut_8.npz",
298 | "/walnut/walnut_20.npz",
299 | "/walnut/walnut_65.npz",
300 | "/walnut/walnut_88.npz",
301 | "/walnut/walnut_54.npz",
302 | "/walnut/walnut_40.npz",
303 | "/walnut/walnut_95.npz",
304 | "/walnut/walnut_70.npz",
305 | "/walnut/walnut_37.npz",
306 | "/walnut/walnut_17.npz",
307 | "/walnut/walnut_44.npz",
308 | "/walnut/walnut_5.npz",
309 | "/walnut/walnut_36.npz",
310 | "/walnut/walnut_9.npz",
311 | "/walnut/walnut_86.npz",
312 | "/walnut/walnut_50.npz",
313 | "/walnut/walnut_46.npz",
314 | "/walnut/walnut_34.npz",
315 | "/walnut/walnut_82.npz",
316 | "/walnut/walnut_61.npz",
317 | "/walnut/walnut_96.npz",
318 | "/walnut/walnut_25.npz",
319 | "/walnut/walnut_80.npz",
320 | "/walnut/walnut_45.npz",
321 | "/walnut/walnut_94.npz",
322 | "/walnut/walnut_87.npz",
323 | "/pine/pine_47.npz",
324 | "/pine/pine_88.npz",
325 | "/pine/pine_58.npz",
326 | "/pine/pine_63.npz",
327 | "/pine/pine_32.npz",
328 | "/pine/pine_61.npz",
329 | "/pine/pine_81.npz",
330 | "/pine/pine_16.npz",
331 | "/pine/pine_50.npz",
332 | "/pine/pine_48.npz",
333 | "/pine/pine_44.npz",
334 | "/pine/pine_98.npz",
335 | "/pine/pine_17.npz",
336 | "/pine/pine_1.npz",
337 | "/pine/pine_36.npz",
338 | "/pine/pine_31.npz",
339 | "/pine/pine_100.npz",
340 | "/pine/pine_69.npz",
341 | "/pine/pine_49.npz",
342 | "/pine/pine_24.npz",
343 | "/pine/pine_78.npz",
344 | "/pine/pine_90.npz",
345 | "/pine/pine_77.npz",
346 | "/pine/pine_57.npz",
347 | "/pine/pine_41.npz",
348 | "/pine/pine_19.npz",
349 | "/pine/pine_56.npz",
350 | "/pine/pine_53.npz",
351 | "/pine/pine_60.npz",
352 | "/pine/pine_38.npz",
353 | "/pine/pine_25.npz",
354 | "/pine/pine_20.npz",
355 | "/pine/pine_89.npz",
356 | "/pine/pine_6.npz",
357 | "/pine/pine_82.npz",
358 | "/pine/pine_3.npz",
359 | "/pine/pine_97.npz",
360 | "/pine/pine_34.npz",
361 | "/pine/pine_73.npz",
362 | "/pine/pine_26.npz",
363 | "/pine/pine_51.npz",
364 | "/pine/pine_5.npz",
365 | "/pine/pine_91.npz",
366 | "/pine/pine_75.npz",
367 | "/pine/pine_66.npz",
368 | "/pine/pine_12.npz",
369 | "/pine/pine_96.npz",
370 | "/pine/pine_85.npz",
371 | "/pine/pine_86.npz",
372 | "/pine/pine_23.npz",
373 | "/pine/pine_64.npz",
374 | "/pine/pine_79.npz",
375 | "/pine/pine_59.npz",
376 | "/pine/pine_39.npz",
377 | "/pine/pine_94.npz",
378 | "/pine/pine_42.npz",
379 | "/pine/pine_54.npz",
380 | "/pine/pine_15.npz",
381 | "/pine/pine_72.npz",
382 | "/pine/pine_4.npz",
383 | "/pine/pine_28.npz",
384 | "/pine/pine_93.npz",
385 | "/pine/pine_87.npz",
386 | "/pine/pine_74.npz",
387 | "/pine/pine_18.npz",
388 | "/pine/pine_71.npz",
389 | "/pine/pine_92.npz",
390 | "/pine/pine_55.npz",
391 | "/pine/pine_62.npz",
392 | "/pine/pine_80.npz",
393 | "/pine/pine_68.npz",
394 | "/pine/pine_70.npz",
395 | "/pine/pine_14.npz",
396 | "/pine/pine_22.npz",
397 | "/pine/pine_95.npz",
398 | "/pine/pine_40.npz",
399 | "/pine/pine_8.npz",
400 | "/pine/pine_65.npz",
401 | "/pine/pine_37.npz",
402 | "/pine/pine_29.npz",
403 | "/eucalyptus/eucalyptus_79.npz",
404 | "/eucalyptus/eucalyptus_90.npz",
405 | "/eucalyptus/eucalyptus_23.npz",
406 | "/eucalyptus/eucalyptus_2.npz",
407 | "/eucalyptus/eucalyptus_17.npz",
408 | "/eucalyptus/eucalyptus_74.npz",
409 | "/eucalyptus/eucalyptus_93.npz",
410 | "/eucalyptus/eucalyptus_85.npz",
411 | "/eucalyptus/eucalyptus_24.npz",
412 | "/eucalyptus/eucalyptus_86.npz",
413 | "/eucalyptus/eucalyptus_82.npz",
414 | "/eucalyptus/eucalyptus_15.npz",
415 | "/eucalyptus/eucalyptus_64.npz",
416 | "/eucalyptus/eucalyptus_97.npz",
417 | "/eucalyptus/eucalyptus_43.npz",
418 | "/eucalyptus/eucalyptus_6.npz",
419 | "/eucalyptus/eucalyptus_78.npz",
420 | "/eucalyptus/eucalyptus_31.npz",
421 | "/eucalyptus/eucalyptus_29.npz",
422 | "/eucalyptus/eucalyptus_75.npz",
423 | "/eucalyptus/eucalyptus_84.npz",
424 | "/eucalyptus/eucalyptus_37.npz",
425 | "/eucalyptus/eucalyptus_18.npz",
426 | "/eucalyptus/eucalyptus_100.npz",
427 | "/eucalyptus/eucalyptus_20.npz",
428 | "/eucalyptus/eucalyptus_99.npz",
429 | "/eucalyptus/eucalyptus_62.npz",
430 | "/eucalyptus/eucalyptus_4.npz",
431 | "/eucalyptus/eucalyptus_38.npz",
432 | "/eucalyptus/eucalyptus_42.npz",
433 | "/eucalyptus/eucalyptus_72.npz",
434 | "/eucalyptus/eucalyptus_1.npz",
435 | "/eucalyptus/eucalyptus_51.npz",
436 | "/eucalyptus/eucalyptus_87.npz",
437 | "/eucalyptus/eucalyptus_80.npz",
438 | "/eucalyptus/eucalyptus_98.npz",
439 | "/eucalyptus/eucalyptus_91.npz",
440 | "/eucalyptus/eucalyptus_27.npz",
441 | "/eucalyptus/eucalyptus_11.npz",
442 | "/eucalyptus/eucalyptus_66.npz",
443 | "/eucalyptus/eucalyptus_63.npz",
444 | "/eucalyptus/eucalyptus_69.npz",
445 | "/eucalyptus/eucalyptus_59.npz",
446 | "/eucalyptus/eucalyptus_61.npz",
447 | "/eucalyptus/eucalyptus_13.npz",
448 | "/eucalyptus/eucalyptus_7.npz",
449 | "/eucalyptus/eucalyptus_41.npz",
450 | "/eucalyptus/eucalyptus_67.npz",
451 | "/eucalyptus/eucalyptus_77.npz",
452 | "/eucalyptus/eucalyptus_92.npz",
453 | "/eucalyptus/eucalyptus_26.npz",
454 | "/eucalyptus/eucalyptus_83.npz",
455 | "/eucalyptus/eucalyptus_95.npz",
456 | "/eucalyptus/eucalyptus_45.npz",
457 | "/eucalyptus/eucalyptus_60.npz",
458 | "/eucalyptus/eucalyptus_19.npz",
459 | "/eucalyptus/eucalyptus_49.npz",
460 | "/eucalyptus/eucalyptus_9.npz",
461 | "/eucalyptus/eucalyptus_46.npz",
462 | "/eucalyptus/eucalyptus_3.npz",
463 | "/eucalyptus/eucalyptus_22.npz",
464 | "/eucalyptus/eucalyptus_10.npz",
465 | "/eucalyptus/eucalyptus_39.npz",
466 | "/eucalyptus/eucalyptus_56.npz",
467 | "/eucalyptus/eucalyptus_71.npz",
468 | "/eucalyptus/eucalyptus_48.npz",
469 | "/eucalyptus/eucalyptus_36.npz",
470 | "/eucalyptus/eucalyptus_68.npz",
471 | "/eucalyptus/eucalyptus_30.npz",
472 | "/eucalyptus/eucalyptus_44.npz",
473 | "/eucalyptus/eucalyptus_34.npz",
474 | "/eucalyptus/eucalyptus_55.npz",
475 | "/eucalyptus/eucalyptus_8.npz",
476 | "/eucalyptus/eucalyptus_25.npz",
477 | "/eucalyptus/eucalyptus_65.npz",
478 | "/eucalyptus/eucalyptus_81.npz",
479 | "/eucalyptus/eucalyptus_32.npz",
480 | "/eucalyptus/eucalyptus_16.npz",
481 | "/eucalyptus/eucalyptus_96.npz",
482 | "/eucalyptus/eucalyptus_40.npz"
483 | ],
484 | "test": [
485 | "/cherry/cherry_15.npz",
486 | "/cherry/cherry_90.npz",
487 | "/cherry/cherry_75.npz",
488 | "/cherry/cherry_82.npz",
489 | "/cherry/cherry_34.npz",
490 | "/cherry/cherry_4.npz",
491 | "/cherry/cherry_49.npz",
492 | "/cherry/cherry_28.npz",
493 | "/cherry/cherry_19.npz",
494 | "/cherry/cherry_56.npz",
495 | "/apple/apple_97.npz",
496 | "/apple/apple_53.npz",
497 | "/apple/apple_63.npz",
498 | "/apple/apple_23.npz",
499 | "/apple/apple_83.npz",
500 | "/apple/apple_92.npz",
501 | "/apple/apple_16.npz",
502 | "/apple/apple_86.npz",
503 | "/apple/apple_36.npz",
504 | "/apple/apple_12.npz",
505 | "/ginkgo/ginkgo_21.npz",
506 | "/ginkgo/ginkgo_99.npz",
507 | "/ginkgo/ginkgo_88.npz",
508 | "/ginkgo/ginkgo_77.npz",
509 | "/ginkgo/ginkgo_11.npz",
510 | "/ginkgo/ginkgo_37.npz",
511 | "/ginkgo/ginkgo_18.npz",
512 | "/ginkgo/ginkgo_7.npz",
513 | "/ginkgo/ginkgo_71.npz",
514 | "/ginkgo/ginkgo_93.npz",
515 | "/walnut/walnut_67.npz",
516 | "/walnut/walnut_7.npz",
517 | "/walnut/walnut_92.npz",
518 | "/walnut/walnut_2.npz",
519 | "/walnut/walnut_12.npz",
520 | "/walnut/walnut_27.npz",
521 | "/walnut/walnut_43.npz",
522 | "/walnut/walnut_93.npz",
523 | "/walnut/walnut_66.npz",
524 | "/walnut/walnut_26.npz",
525 | "/pine/pine_13.npz",
526 | "/pine/pine_52.npz",
527 | "/pine/pine_2.npz",
528 | "/pine/pine_27.npz",
529 | "/pine/pine_46.npz",
530 | "/pine/pine_76.npz",
531 | "/pine/pine_30.npz",
532 | "/pine/pine_67.npz",
533 | "/pine/pine_21.npz",
534 | "/pine/pine_11.npz",
535 | "/eucalyptus/eucalyptus_28.npz",
536 | "/eucalyptus/eucalyptus_5.npz",
537 | "/eucalyptus/eucalyptus_94.npz",
538 | "/eucalyptus/eucalyptus_89.npz",
539 | "/eucalyptus/eucalyptus_54.npz",
540 | "/eucalyptus/eucalyptus_53.npz",
541 | "/eucalyptus/eucalyptus_58.npz",
542 | "/eucalyptus/eucalyptus_73.npz",
543 | "/eucalyptus/eucalyptus_35.npz",
544 | "/eucalyptus/eucalyptus_70.npz"
545 | ],
546 | "validation": [
547 | "/cherry/cherry_42.npz",
548 | "/cherry/cherry_13.npz",
549 | "/cherry/cherry_89.npz",
550 | "/cherry/cherry_18.npz",
551 | "/cherry/cherry_7.npz",
552 | "/cherry/cherry_73.npz",
553 | "/cherry/cherry_98.npz",
554 | "/cherry/cherry_96.npz",
555 | "/cherry/cherry_6.npz",
556 | "/cherry/cherry_94.npz",
557 | "/apple/apple_39.npz",
558 | "/apple/apple_56.npz",
559 | "/apple/apple_96.npz",
560 | "/apple/apple_17.npz",
561 | "/apple/apple_11.npz",
562 | "/apple/apple_5.npz",
563 | "/apple/apple_64.npz",
564 | "/apple/apple_42.npz",
565 | "/apple/apple_77.npz",
566 | "/apple/apple_73.npz",
567 | "/ginkgo/ginkgo_64.npz",
568 | "/ginkgo/ginkgo_54.npz",
569 | "/ginkgo/ginkgo_23.npz",
570 | "/ginkgo/ginkgo_94.npz",
571 | "/ginkgo/ginkgo_36.npz",
572 | "/ginkgo/ginkgo_56.npz",
573 | "/ginkgo/ginkgo_86.npz",
574 | "/ginkgo/ginkgo_61.npz",
575 | "/ginkgo/ginkgo_43.npz",
576 | "/ginkgo/ginkgo_62.npz",
577 | "/walnut/walnut_14.npz",
578 | "/walnut/walnut_97.npz",
579 | "/walnut/walnut_74.npz",
580 | "/walnut/walnut_84.npz",
581 | "/walnut/walnut_47.npz",
582 | "/walnut/walnut_15.npz",
583 | "/walnut/walnut_79.npz",
584 | "/walnut/walnut_68.npz",
585 | "/walnut/walnut_98.npz",
586 | "/walnut/walnut_91.npz",
587 | "/pine/pine_84.npz",
588 | "/pine/pine_7.npz",
589 | "/pine/pine_83.npz",
590 | "/pine/pine_43.npz",
591 | "/pine/pine_45.npz",
592 | "/pine/pine_35.npz",
593 | "/pine/pine_99.npz",
594 | "/pine/pine_9.npz",
595 | "/pine/pine_33.npz",
596 | "/pine/pine_10.npz",
597 | "/eucalyptus/eucalyptus_12.npz",
598 | "/eucalyptus/eucalyptus_14.npz",
599 | "/eucalyptus/eucalyptus_76.npz",
600 | "/eucalyptus/eucalyptus_33.npz",
601 | "/eucalyptus/eucalyptus_50.npz",
602 | "/eucalyptus/eucalyptus_47.npz",
603 | "/eucalyptus/eucalyptus_57.npz",
604 | "/eucalyptus/eucalyptus_21.npz",
605 | "/eucalyptus/eucalyptus_88.npz",
606 | "/eucalyptus/eucalyptus_52.npz"
607 | ]
608 | }
--------------------------------------------------------------------------------