├── .gitignore ├── .gitmodules ├── Cargo.toml ├── LICENSE.md ├── README.md ├── bench ├── Cargo.toml ├── scripts │ ├── generate_pointcloud_data.py │ ├── plot_affordance_time.py │ ├── plot_error_hist.py │ ├── plot_filter_strength.py │ ├── plot_forest_error.py │ └── plot_times_for_blog.py └── src │ ├── bin │ ├── correctness.rs │ ├── error.rs │ ├── filter_strength.rs │ ├── forest_error.rs │ └── perf_plots.rs │ ├── forest.rs │ ├── kdt.rs │ └── lib.rs ├── captree ├── Cargo.toml └── src │ └── lib.rs ├── morton_filter ├── Cargo.toml └── src │ └── lib.rs ├── rust-toolchain.toml └── rustfmt.toml /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | /Cargo.lock 3 | flamegraph.svg 4 | perf.* 5 | *.csv 6 | .vscode 7 | /hilbert_test/build/ 8 | *.idx 9 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "hilbert_test/hilbert-sort"] 2 | path = hilbert_test/hilbert-sort 3 | url = git@github.com:rkowalewski/hilbert-sort.git 4 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [workspace] 2 | members = ["captree", "bench", "morton_filter"] 3 | resolver = "2" 4 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | # PolyForm Noncommercial License 1.0.0 2 | 3 | 4 | 5 | ## Acceptance 6 | 7 | In order to get any license under these terms, you must agree 8 | to them as both strict obligations and conditions to all 9 | your licenses. 
10 | 11 | ## Copyright License 12 | 13 | The licensor grants you a copyright license for the 14 | software to do everything you might do with the software 15 | that would otherwise infringe the licensor's copyright 16 | in it for any permitted purpose. However, you may 17 | only distribute the software according to [Distribution 18 | License](#distribution-license) and make changes or new works 19 | based on the software according to [Changes and New Works 20 | License](#changes-and-new-works-license). 21 | 22 | ## Distribution License 23 | 24 | The licensor grants you an additional copyright license 25 | to distribute copies of the software. Your license 26 | to distribute covers distributing the software with 27 | changes and new works permitted by [Changes and New Works 28 | License](#changes-and-new-works-license). 29 | 30 | ## Notices 31 | 32 | You must ensure that anyone who gets a copy of any part of 33 | the software from you also gets a copy of these terms or the 34 | URL for them above, as well as copies of any plain-text lines 35 | beginning with `Required Notice:` that the licensor provided 36 | with the software. For example: 37 | 38 | > Required Notice: Copyright Rice University 39 | 40 | ## Changes and New Works License 41 | 42 | The licensor grants you an additional copyright license to 43 | make changes and new works based on the software for any 44 | permitted purpose. 45 | 46 | ## Patent License 47 | 48 | The licensor grants you a patent license for the software that 49 | covers patent claims the licensor can license, or becomes able 50 | to license, that you would infringe by using the software. 51 | 52 | ## Noncommercial Purposes 53 | 54 | Any noncommercial purpose is a permitted purpose. 
55 | 56 | ## Personal Uses 57 | 58 | Personal use for research, experiment, and testing for 59 | the benefit of public knowledge, personal study, private 60 | entertainment, hobby projects, amateur pursuits, or religious 61 | observance, without any anticipated commercial application, 62 | is use for a permitted purpose. 63 | 64 | ## Noncommercial Organizations 65 | 66 | Use by any charitable organization, educational institution, 67 | public research organization, public safety or health 68 | organization, environmental protection organization, 69 | or government institution is use for a permitted purpose 70 | regardless of the source of funding or obligations resulting 71 | from the funding. 72 | 73 | ## Fair Use 74 | 75 | You may have "fair use" rights for the software under the 76 | law. These terms do not limit them. 77 | 78 | ## No Other Rights 79 | 80 | These terms do not allow you to sublicense or transfer any of 81 | your licenses to anyone else, or prevent the licensor from 82 | granting licenses to anyone else. These terms do not imply 83 | any other licenses. 84 | 85 | ## Patent Defense 86 | 87 | If you make any written claim that the software infringes or 88 | contributes to infringement of any patent, your patent license 89 | for the software granted under these terms ends immediately. If 90 | your company makes such a claim, your patent license ends 91 | immediately for work on behalf of your company. 92 | 93 | ## Violations 94 | 95 | The first time you are notified in writing that you have 96 | violated any of these terms, or done anything with the software 97 | not covered by your licenses, your licenses can nonetheless 98 | continue if you come into full compliance with these terms, 99 | and take practical steps to correct past violations, within 100 | 32 days of receiving notice. Otherwise, all your licenses 101 | end immediately. 
102 | 103 | ## No Liability 104 | 105 | ***As far as the law allows, the software comes as is, without 106 | any warranty or condition, and the licensor will not be liable 107 | to you for any damages arising out of these terms or the use 108 | or nature of the software, under any kind of legal claim.*** 109 | 110 | ## Definitions 111 | 112 | The **licensor** is the individual or entity offering these 113 | terms, and the **software** is the software the licensor makes 114 | available under these terms. 115 | 116 | **You** refers to the individual or entity agreeing to these 117 | terms. 118 | 119 | **Your company** is any legal entity, sole proprietorship, 120 | or other kind of organization that you work for, plus all 121 | organizations that have control over, are under the control of, 122 | or are under common control with that organization. **Control** 123 | means ownership of substantially all the assets of an entity, 124 | or the power to direct its management and policies by vote, 125 | contract, or otherwise. Control can be direct or indirect. 126 | 127 | **Your licenses** are all the licenses granted to you for the 128 | software under these terms. 129 | 130 | **Use** means anything you do with the software requiring one 131 | of your licenses. 132 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Collision-Affording Point Trees: SIMD-Amenable Nearest Neighbors for Fast Collision Checking 2 | 3 | This is a Rust implementation of the _collision-affording point tree_ (CAPT), a data structure for 4 | SIMD-parallel collision-checking against point clouds. 
5 | 6 | You may also want to look at the following other sources: 7 | 8 | - [The paper](https://arxiv.org/abs/2406.02807) 9 | - [Demo video](https://www.youtube.com/watch?v=BzDKdrU1VpM) 10 | - [C++ implementation](https://github.com/KavrakiLab/vamp) 11 | - [Blog post about it](https://www.claytonwramsey.com/blog/captree) 12 | 13 | If you use this in an academic work, please cite it as follows: 14 | 15 | ```bibtex 16 | @InProceedings{capt, 17 | title = {Collision-Affording Point Trees: {SIMD}-Amenable Nearest Neighbors for Fast Collision Checking}, 18 | author = {Ramsey, Clayton W. and Kingston, Zachary and Thomason, Wil and Kavraki, Lydia E.}, 19 | booktitle = {Robotics: Science and Systems}, 20 | date = {2024}, 21 | url = {http://arxiv.org/abs/2406.02807}, 22 | note = {To Appear.} 23 | } 24 | ``` 25 | 26 | ## Usage 27 | 28 | The core data structure in this library is the `Capt`, which is a search tree used for collision checking. 29 | 30 | ```rust 31 | use captree::Capt; 32 | 33 | // list of points in tree 34 | let points = [[1.0, 1.0], [2.0, 1.0], [3.0, -1.0]]; 35 | 36 | // range of legal radii for collision-checking 37 | let radius_range = (0.0, 100.0); 38 | 39 | let captree = Capt::new(&points, radius_range); 40 | 41 | // sphere centered at (1.5, 1.5) with radius 0.01 does not collide 42 | assert!(!captree.collides(&[1.5, 1.5], 0.01)); 43 | 44 | // sphere centered at (1.5, 1.5) with radius 1.0 does collide 45 | assert!(captree.collides(&[1.5, 1.5], 1.0)); 46 | ``` 47 | 48 | ## License 49 | 50 | This work is licensed to you under the Polyform Non-Commercial License. 51 | For further details, refer to [LICENSE.md](/LICENSE.md).
52 | -------------------------------------------------------------------------------- /bench/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "bench" 3 | version = "0.1.0" 4 | edition = "2021" 5 | 6 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 7 | 8 | [dependencies] 9 | captree = {version = "0.1.0", path = "../captree", features = ["simd"]} 10 | morton_filter = {version = "0.1.0", path = "../morton_filter"} 11 | kiddo = {version = "4.0.0", features = ["simd"]} 12 | rand = "0.8.5" 13 | rand_chacha = "0.3.1" 14 | rand_distr = "0.4.3" 15 | 16 | [profile.release] 17 | lto = true 18 | panic = "abort" 19 | strip = true 20 | -------------------------------------------------------------------------------- /bench/scripts/generate_pointcloud_data.py: -------------------------------------------------------------------------------- 1 | from dataclasses import InitVar, dataclass, field 2 | from functools import partial 3 | from itertools import product 4 | from pathlib import Path 5 | from time import time 6 | 7 | import numpy as np 8 | from fire import Fire 9 | from grapeshot.assets import ROBOTS 10 | from grapeshot.model.camera import Camera, process_camera 11 | from grapeshot.model.group import GroupABC 12 | from grapeshot.model.robot import process_srdf 13 | from grapeshot.model.world import Skeleton, World 14 | from grapeshot.simulators.pybullet import PyBulletSimulator 15 | from grapeshot.util.mesh import mesh_to_sampled_pointcloud, pointcloud_to_mesh 16 | from grapeshot.util.worldpool import WorldPool 17 | from h5py import File 18 | from numpy.typing import NDArray 19 | 20 | 21 | @dataclass(slots=True) 22 | class RobotParams: 23 | name: str 24 | sensor_link: str 25 | sensor_group: str 26 | 27 | 28 | @dataclass(slots=True) 29 | class PointcloudParams: 30 | problem_name: str 31 | resolution: tuple[float, float] 32 | num_points: int 33 | 34 | 35 |
@dataclass(slots=True) 36 | class SampleWorld(World): 37 | camera: Camera = field(init=False) 38 | sensor_group: GroupABC = field(init=False) 39 | robot_params: InitVar[RobotParams | None] = None 40 | pointcloud_params: InitVar[PointcloudParams | None] = None 41 | 42 | def __post_init__(self, robot_params, sample_params): 43 | World.__post_init__(self) 44 | if robot_params is not None and sample_params is not None: 45 | # Initial setup 46 | robot = ROBOTS[robot_params.name] 47 | robot_skel = self.add_skeleton(robot.urdf, name="robot") 48 | groups = process_srdf(robot_skel, robot.srdf) 49 | self.add_skeleton(robot.problems[sample_params.problem_name].environment, name="environment") 50 | # NOTE: I assume we do not need to load the plane here, since we don't care about 51 | # visualization and aren't running the sim (so gravity doesn't matter) 52 | self.setup_collision_filter() 53 | self.camera = process_camera(robot.cameras[robot_params.sensor_link], robot_skel) 54 | self.sensor_group = groups[robot_params.sensor_group] 55 | 56 | @property 57 | def robot(self) -> Skeleton: 58 | return self.get_skeleton("robot") 59 | 60 | @property 61 | def environment(self) -> Skeleton: 62 | return self.get_skeleton("environment") 63 | 64 | 65 | def sample_pointcloud(q: NDArray, *, world: SampleWorld) -> NDArray: 66 | world.set_group_positions(world.sensor_group, q) 67 | # The ignore here is because we know the sim will be a PyBulletSim, which supports skeleton 68 | # filtering 69 | return world.sim.take_image(world.camera, [world.robot]).point_cloud # type: ignore 70 | 71 | 72 | def gather_pointcloud(num_points: int, *pointcloud_samples) -> NDArray: 73 | aggregate_cloud = np.concatenate(pointcloud_samples) 74 | # TODO: Parameterize the resolution 75 | cloud_mesh = pointcloud_to_mesh(aggregate_cloud, 0.005) 76 | return mesh_to_sampled_pointcloud(cloud_mesh, num_points) 77 | 78 | 79 | def main( 80 | problem_name: str, 81 | robot_name: str = "fetch", 82 | sensor_link: str = "head", 83 
| sensor_group: str = "head_with_torso", 84 | head_resolution: float = 0.15, 85 | torso_resolution: float = 1.0, 86 | num_points: int = 400_000, 87 | noise: float = 0.01, 88 | output_path: Path | None = None, 89 | show_progress: bool = True, 90 | ): 91 | print("Loading parameters...") 92 | robot_params = RobotParams(robot_name, sensor_link, sensor_group) 93 | pointcloud_params = PointcloudParams( 94 | problem_name, (torso_resolution, head_resolution), num_points 95 | ) 96 | print("Setting up robot and scene...") 97 | world = SampleWorld(PyBulletSimulator(), {}, robot_params, pointcloud_params) 98 | 99 | def setup_clone(w: SampleWorld) -> SampleWorld: 100 | w.camera = world.camera 101 | w.sensor_group = world.sensor_group 102 | return w 103 | 104 | with WorldPool(world, world_setup_fn=setup_clone, show_progress=show_progress) as wp: 105 | print( 106 | f"Building pointcloud for {pointcloud_params.problem_name} with {pointcloud_params.num_points} points..." 107 | ) 108 | start = time() 109 | torso_resolution, head_resolution = pointcloud_params.resolution 110 | mins = world.sensor_group.mins 111 | maxes = world.sensor_group.maxes 112 | pointcloud = wp(sample_pointcloud)( 113 | map_over=[[ 114 | np.array([x, y, z]) for x, 115 | y, 116 | z in product( 117 | np.arange(mins[0], maxes[0], torso_resolution), 118 | np.arange(mins[1], maxes[1], head_resolution), 119 | np.arange(mins[2], maxes[2], head_resolution) 120 | ) 121 | ]], 122 | gather=partial(gather_pointcloud, pointcloud_params.num_points) 123 | ).result() 124 | pointcloud += np.random.default_rng().uniform(low=-noise, high=noise, size=pointcloud.shape) 125 | end = time() 126 | print(f"Pointcloud time: {end-start}s", f"Pointcloud shape: {pointcloud.shape}") 127 | if output_path is None: 128 | output_path = Path( 129 | "_".join([ 130 | f"{robot_params.name}", 131 | f"{pointcloud_params.problem_name}", 132 | f"r{pointcloud_params.resolution[0]}x{pointcloud_params.resolution[1]}", 133 | f"e{noise}", 134 | 
f"n{pointcloud_params.num_points}.h5", 135 | ]) 136 | ) 137 | print(f"Saving pointcloud to {output_path}") 138 | output_path.parent.mkdir(parents=True, exist_ok=True) 139 | with File(output_path, "w") as output_file: 140 | pointcloud_group = output_file.create_group("pointcloud") 141 | pointcloud_dset = pointcloud_group.create_dataset( 142 | "points", (pointcloud_params.num_points, 3), dtype=np.float32 143 | ) 144 | pointcloud_dset[...] = pointcloud 145 | 146 | 147 | if __name__ == "__main__": 148 | Fire(main) 149 | -------------------------------------------------------------------------------- /bench/scripts/plot_affordance_time.py: -------------------------------------------------------------------------------- 1 | """ 2 | USAGE: 3 | ``` 4 | python ./plot_affordance_time.py path/to/perf.csv 5 | ``` 6 | 7 | Generate a plot of the runtime of the affordance tree construction and querying. 8 | The first argument must point to a CSV file with the following columns: 9 | 1. Number of points in the tree 10 | 2. Time to construct the tree (in seconds) 11 | 3. Time for sequential queries (in seconds) 12 | 4. Time per point for batch queries (in seconds) 13 | 5. Time to build a conventional KD tree (in seconds) 14 | 6. 
Time to query a range in a conventional KD tree (in seconds) 15 | """ 16 | 17 | import csv 18 | import matplotlib.pyplot as plt 19 | import sys 20 | 21 | csv_fname = sys.argv[1] 22 | n_points = [] 23 | build_times = [] 24 | seq_times = [] 25 | simd_times = [] 26 | kdt_build_times = [] 27 | kdt_query_times = [] 28 | with open(csv_fname) as f: 29 | for row in csv.reader(f, delimiter=","): 30 | n_points.append(int(row[0])) 31 | build_times.append(float(row[1])) 32 | seq_times.append(float(row[2])) 33 | simd_times.append(float(row[3])) 34 | kdt_build_times.append(float(row[4])) 35 | kdt_query_times.append(float(row[5])) 36 | 37 | plt.plot(n_points, build_times, label="Affordance tree") 38 | plt.plot(n_points, kdt_build_times, label="KD tree") 39 | plt.title("Construction time for trees") 40 | plt.xlabel("Number of points in the tree") 41 | plt.ylabel("Time for construction (s)") 42 | plt.legend() 43 | plt.show() 44 | 45 | 46 | plt.plot( 47 | n_points, [p * 1e9 for p in seq_times], label="Sequential affordance tree queries" 48 | ) 49 | plt.plot(n_points, [p * 1e9 for p in simd_times], label="SIMD affordance tree queries") 50 | plt.plot(n_points, [p * 1e9 for p in kdt_query_times], label="KD tree queries") 51 | plt.title("Nearest neighbors + collision check time") 52 | plt.xlabel("Number of points in the tree") 53 | plt.ylabel("Time for query and collision checking (ns)") 54 | plt.legend() 55 | plt.show() 56 | -------------------------------------------------------------------------------- /bench/scripts/plot_error_hist.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | 3 | """ 4 | USAGE: 5 | ``` 6 | python ./plot_error_hist.py path/to/errors.csv 7 | ``` 8 | 9 | Columns must be separated by tabs (`\\t`). 10 | errors.csv must have anything in column 0, true distance in column 1, estimated distance in column 11 | 2, and relative error in column 3. 
12 | """ 13 | 14 | import csv 15 | import matplotlib.pyplot as plt 16 | import sys 17 | import numpy as np 18 | 19 | csv_fname = sys.argv[1] 20 | rel_errs = [] 21 | true_dists = [] 22 | abs_errs = [] 23 | approx_dists = [] 24 | with open(csv_fname) as f: 25 | for row in csv.reader(f, delimiter="\t"): 26 | rel_errs.append(float(row[3])) 27 | true_dists.append(float(row[1])) 28 | abs_errs.append(float(row[2]) - float(row[1])) 29 | approx_dists.append(float(row[2])) 30 | 31 | h, edges = np.histogram(abs_errs, bins=400) 32 | cy = np.cumsum(h / np.sum(h)) 33 | 34 | plt.plot(edges[:-1], cy) 35 | plt.fill_between( edges[:-1], cy, step="pre", alpha=0.4) 36 | plt.xlabel("Absolute distance error (m)") 37 | plt.ylabel("Frequency") 38 | plt.title(f"CDF of forward tree absolute error distribution") 39 | plt.show() 40 | -------------------------------------------------------------------------------- /bench/scripts/plot_filter_strength.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import sys 4 | 5 | 6 | def main(): 7 | radii = np.genfromtxt(sys.argv[1], usecols=0, delimiter=',') 8 | ns = np.genfromtxt(sys.argv[1], usecols=(1, 2, 3, 4, 5, 6), delimiter=',') 9 | for i, col in enumerate(ns.T): 10 | plt.plot(radii * 100, col, label=f"{i + 1} permutations") 11 | plt.xlabel("Filter radius (cm)") 12 | plt.ylabel("Number of points after filtering") 13 | plt.semilogy() 14 | plt.title("Effectiveness of space-filling curve filter by radius") 15 | plt.legend() 16 | plt.show() 17 | 18 | 19 | if __name__ == "__main__": 20 | main() 21 | -------------------------------------------------------------------------------- /bench/scripts/plot_forest_error.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | """ 3 | USAGE: 4 | ``` 5 | python ./plot_forest_error.py path/to/errors.csv 6 | ``` 7 | 8 | Columns must be separated by tabs (`\\t`).
9 | errors.csv must have the number of trees in the forest in column 0, the absolute error in column 1, 10 | the relative error in column 2, and the exact distance in column 3. 11 | """ 12 | 13 | import csv 14 | import matplotlib.pyplot as plt 15 | import numpy as np 16 | import sys 17 | import collections 18 | 19 | csv_fname = sys.argv[1] 20 | abs_errs = collections.defaultdict(lambda: []) 21 | rel_errs = collections.defaultdict(lambda: []) 22 | exact_dists = collections.defaultdict(lambda: []) 23 | with open(csv_fname) as f: 24 | for row in csv.reader(f, delimiter="\t"): 25 | n_trees = int(row[0]) 26 | abs_err = float(row[1]) 27 | rel_err = float(row[2]) 28 | exact_dist = float(row[3]) 29 | abs_errs[n_trees].append(abs_err) 30 | rel_errs[n_trees].append(rel_err) 31 | exact_dists[n_trees].append(exact_dist) 32 | 33 | # CDFs 34 | 35 | for n_trees, abs_err in abs_errs.items(): 36 | h, edges = np.histogram(abs_err, bins=400) 37 | cy = np.cumsum(h / np.sum(h)) 38 | 39 | plt.plot(edges[:-1], cy, label=f"T={n_trees}") 40 | # plt.fill_between(edges[:-1], cy, step="pre", alpha=0.4) 41 | plt.xlabel("Absolute distance error (m)") 42 | plt.ylabel("Frequency") 43 | plt.title(f"CDF of forest absolute error distribution") 44 | plt.legend() 45 | plt.show() 46 | -------------------------------------------------------------------------------- /bench/scripts/plot_times_for_blog.py: -------------------------------------------------------------------------------- 1 | """ 2 | USAGE: 3 | ```sh 4 | ./plot_search_throughput path/to/hdf5 > results_pc.csv 5 | ./plot_search_throughput > results_unif.csv 6 | python plot_search_throughput.py results_pc.csv results_unif.csv 7 | ``` 8 | """ 9 | 10 | import sys 11 | import numpy as np 12 | import csv 13 | import matplotlib.pyplot as plt 14 | 15 | N_FORESTS = 10 16 | 17 | 18 | def plot_build_times(fname: str): 19 | n_points = np.genfromtxt(fname, usecols=0, delimiter=",") 20 | kdt_build_times = np.genfromtxt(fname, usecols=1, delimiter=",") 21 | 
forward_times = np.genfromtxt(fname, usecols=2, delimiter=",") 22 | captree_build_times = np.genfromtxt(fname, usecols=3, delimiter=",") 23 | forest_build_times = np.genfromtxt(fname, usecols=range(4, 4 + N_FORESTS), delimiter=",") 24 | 25 | plt.plot(n_points, kdt_build_times * 1e3, label="kiddo (k-d tree)") 26 | plt.plot(n_points, forward_times * 1e3, label="forward tree") 27 | 28 | for (i, btime) in enumerate(forest_build_times.T): 29 | plt.plot(n_points, btime * 1e3, label=f"forest (T={i + 1})") 30 | 31 | plt.legend() 32 | plt.xlabel("Number of points in cloud") 33 | plt.ylabel("Construction time (ms)") 34 | plt.title(f"Scaling of CC structure construction time") 35 | plt.show() 36 | 37 | plt.plot(n_points, kdt_build_times * 1e3, label="kiddo (k-d tree)") 38 | plt.plot(n_points, forward_times * 1e3, label="forward tree") 39 | plt.plot(n_points, captree_build_times * 1e3, label="CAPT") 40 | 41 | plt.legend() 42 | plt.xlabel("Number of points in cloud") 43 | plt.ylabel("Construction time (ms)") 44 | plt.title(f"Scaling of CC structure construction time") 45 | plt.show() 46 | 47 | 48 | 49 | def plot_query_times(fname: str, title: str): 50 | n_points = [] 51 | n_points = np.genfromtxt(fname, usecols=0, delimiter=",") 52 | n_tests = np.genfromtxt(fname, usecols=1, delimiter=",") 53 | kdt_times = np.genfromtxt(fname, usecols=2, delimiter=",") 54 | forward_seq_times = np.genfromtxt(fname, usecols=3, delimiter=",") 55 | forward_simd_times = np.genfromtxt(fname, usecols=4, delimiter=",") 56 | captree_seq_times = np.genfromtxt(fname, usecols=5, delimiter=",") 57 | captree_simd_times = np.genfromtxt(fname, usecols=6, delimiter=",") 58 | forest_times = np.genfromtxt(fname, usecols=range(7, 7 + N_FORESTS), delimiter=",") 59 | 60 | plt.plot(n_points, kdt_times * 1e9, label="kiddo (k-d tree)") 61 | plt.plot(n_points, forward_seq_times * 1e9, label="forward tree (sequential)") 62 | plt.plot(n_points, forward_simd_times * 1e9, label="forward tree (SIMD)") 63 | 
plt.legend(loc="upper left") 64 | plt.semilogy() 65 | plt.xlabel("Number of points in cloud") 66 | plt.ylabel("Query time (ns)") 67 | plt.title(title) 68 | plt.show() 69 | 70 | 71 | plt.plot(n_points, kdt_times * 1e9, label="kiddo (k-d tree)") 72 | plt.plot(n_points, forward_simd_times * 1e9, label="forward tree (SIMD)") 73 | 74 | plt.plot(n_points, forest_times[:, 0] * 1e9, label=f"forest (SIMD, T={1})") 75 | plt.plot(n_points, forest_times[:, 4] * 1e9, label=f"forest (SIMD, T={5})") 76 | plt.plot(n_points, forest_times[:, 9] * 1e9, label=f"forest (SIMD, T={10})") 77 | 78 | plt.legend(loc="upper left") 79 | plt.semilogy() 80 | plt.xlabel("Number of points in cloud") 81 | plt.ylabel("Query time (ns)") 82 | plt.title(title) 83 | plt.plot(n_points, forward_seq_times * 1e9, label="forward tree (sequential)") 84 | plt.plot(n_points, forward_simd_times * 1e9, label="forward tree (SIMD)") 85 | plt.plot(n_points, captree_seq_times * 1e9, label="CAPT (sequential)") 86 | plt.plot(n_points, captree_simd_times * 1e9, label="CAPT (SIMD)") 87 | 88 | plt.legend(loc="upper left") 89 | plt.semilogy() 90 | plt.xlabel("Number of points in cloud") 91 | plt.ylabel("Query time (ns)") 92 | plt.title(title) 93 | plt.show() 94 | 95 | 96 | def plot_mem(fname: str): 97 | n_points = [] 98 | forward_mem = [] 99 | captree_mem = [] 100 | 101 | with open(fname) as f: 102 | reader = csv.reader(f, delimiter=",") 103 | next(reader, None) 104 | for row in reader: 105 | n_points.append(int(row[0])) 106 | forward_mem.append(int(row[1])) 107 | captree_mem.append(int(row[2])) 108 | 109 | plt.plot(n_points, forward_mem, label="Forward tree") 110 | plt.plot(n_points, captree_mem, label="CAPT") 111 | plt.semilogy() 112 | plt.legend() 113 | plt.xlabel("Number of points in cloud") 114 | plt.ylabel("Memory used (bytes)") 115 | plt.title("Memory consumption") 116 | plt.show() 117 | 118 | 119 | plot_build_times(sys.argv[1]) 120 | plot_query_times(sys.argv[2], "Scaling of CC on mixed queries") 121 | 
plot_query_times(sys.argv[3], "Scaling of CC on all-colliding queries") 122 | plot_query_times(sys.argv[4], "Scaling of CC on non-colliding queries") 123 | plot_mem(sys.argv[5]) 124 | -------------------------------------------------------------------------------- /bench/src/bin/correctness.rs: -------------------------------------------------------------------------------- 1 | #![feature(portable_simd)] 2 | 3 | use std::simd::Simd; 4 | 5 | use bench::{dist, kdt::PkdTree, parse_pointcloud_csv, parse_trace_csv, trace_r_range}; 6 | use captree::Capt; 7 | use kiddo::SquaredEuclidean; 8 | use rand::{seq::SliceRandom, Rng, SeedableRng}; 9 | 10 | const N: usize = 1 << 12; 11 | const R: f32 = 0.02; 12 | 13 | fn main() -> Result<(), Box> { 14 | let mut rng = rand_chacha::ChaCha20Rng::seed_from_u64(2707); 15 | let args = std::env::args().collect::>(); 16 | let points: Box<[[f32; 3]]> = if args.len() < 2 { 17 | (0..1 << 16) 18 | .map(|_| { 19 | [ 20 | rng.gen_range::(0.0..1.0), 21 | rng.gen_range::(0.0..1.0), 22 | rng.gen_range::(0.0..1.0), 23 | ] 24 | }) 25 | .collect() 26 | } else { 27 | let mut p = parse_pointcloud_csv(&args[1])?.to_vec(); 28 | p.shuffle(&mut rng); 29 | p.truncate(1 << 16); 30 | p.into_boxed_slice() 31 | }; 32 | 33 | let trace: Box<[([f32; 3], f32)]> = if args.len() < 3 { 34 | (0..N) 35 | .map(|_| { 36 | ( 37 | [ 38 | rng.gen_range(0.0..1.0), 39 | rng.gen_range(0.0..1.0), 40 | rng.gen_range(0.0..1.0), 41 | ], 42 | rng.gen_range(0.0..=R), 43 | ) 44 | }) 45 | .collect() 46 | } else { 47 | parse_trace_csv(&args[2])? 
48 | }; 49 | 50 | let r_range = trace_r_range(&trace); 51 | 52 | let kdt = PkdTree::new(&points); 53 | let mut kiddo_kdt = kiddo::KdTree::new(); 54 | for pt in points.iter() { 55 | kiddo_kdt.add(pt, 0); 56 | } 57 | 58 | let aff_tree = Capt::<3>::new(&points, r_range); 59 | 60 | for (i, (center, r)) in trace.iter().enumerate() { 61 | let exact_kiddo_dist = kiddo_kdt 62 | .nearest_one::(center) 63 | .distance 64 | .sqrt(); 65 | let exact_dist = dist(kdt.get_point(kdt.query1_exact(*center)), *center); 66 | assert_eq!(exact_dist, exact_kiddo_dist); 67 | 68 | let simd_center: [Simd; 3] = [ 69 | Simd::splat(center[0]), 70 | Simd::splat(center[1]), 71 | Simd::splat(center[2]), 72 | ]; 73 | if exact_dist <= *r { 74 | println!("iter {i}: {:?} (collides)", (center, r)); 75 | assert!(aff_tree.collides(center, *r)); 76 | assert!(aff_tree.collides_simd(&simd_center, Simd::splat(*r))) 77 | } else { 78 | println!("iter {i}: {:?} (no collides)", (center, r)); 79 | assert!(!aff_tree.collides(center, *r)); 80 | assert!(!aff_tree.collides_simd(&simd_center, Simd::splat(*r))) 81 | } 82 | } 83 | 84 | Ok(()) 85 | } 86 | -------------------------------------------------------------------------------- /bench/src/bin/error.rs: -------------------------------------------------------------------------------- 1 | #![feature(portable_simd)] 2 | 3 | use std::simd::{LaneCount, SupportedLaneCount}; 4 | 5 | use bench::{dist, fuzz_pointcloud, get_points, kdt::PkdTree, make_needles}; 6 | use kiddo::SquaredEuclidean; 7 | use rand::{Rng, SeedableRng}; 8 | use rand_chacha::ChaCha20Rng; 9 | 10 | const N: usize = 1 << 16; 11 | const L: usize = 16; 12 | const D: usize = 3; 13 | 14 | fn main() { 15 | let mut rng = ChaCha20Rng::seed_from_u64(2707); 16 | let mut starting_points = get_points(N); 17 | fuzz_pointcloud(&mut starting_points, 0.001, &mut rng); 18 | measure_error::(&starting_points, &mut rng, 1 << 16) 19 | } 20 | 21 | pub fn measure_error( 22 | points: &[[f32; D]], 23 | rng: &mut impl Rng, 24 | 
n_trials: usize, 25 | ) where 26 | LaneCount: SupportedLaneCount, 27 | { 28 | let kdt = PkdTree::new(points); 29 | let mut kiddo_kdt = kiddo::KdTree::new(); 30 | for pt in points.iter() { 31 | kiddo_kdt.add(pt, 0); 32 | } 33 | 34 | let (seq_needles, _) = make_needles(rng, n_trials); 35 | 36 | for seq_needle in seq_needles { 37 | let exact_kiddo_dist = kiddo_kdt 38 | .nearest_one::(&seq_needle) 39 | .distance 40 | .sqrt(); 41 | let approx_dist = dist(seq_needle, kdt.approx_nearest(seq_needle)); 42 | let rel_error = approx_dist / exact_kiddo_dist - 1.0; 43 | println!("{seq_needle:?}\t{exact_kiddo_dist}\t{approx_dist}\t{rel_error}"); 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /bench/src/bin/filter_strength.rs: -------------------------------------------------------------------------------- 1 | use std::{fs::File, io::Write}; 2 | 3 | use bench::parse_pointcloud_csv; 4 | use morton_filter::filter_permutation; 5 | 6 | const PERMUTATIONS_3D: [[u8; 3]; 6] = [ 7 | [0, 1, 2], 8 | [0, 2, 1], 9 | [1, 0, 2], 10 | [1, 2, 0], 11 | [2, 0, 1], 12 | [2, 1, 0], 13 | ]; 14 | 15 | fn main() -> Result<(), Box> { 16 | let mut f_results = File::create("filter_strength.csv")?; 17 | let args: Vec<_> = std::env::args().collect(); 18 | let p = parse_pointcloud_csv(&args[1])?.to_vec(); 19 | 20 | for i in 0..1000 { 21 | let r_filter = i as f32 / 10_000.0; 22 | write!(&mut f_results, "{r_filter}")?; 23 | let mut new_points = p.clone(); 24 | for perm in PERMUTATIONS_3D { 25 | filter_permutation(&mut new_points, r_filter, perm); 26 | write!(&mut f_results, ",{}", new_points.len())?; 27 | } 28 | writeln!(&mut f_results)?; 29 | } 30 | 31 | Ok(()) 32 | } 33 | -------------------------------------------------------------------------------- /bench/src/bin/forest_error.rs: -------------------------------------------------------------------------------- 1 | use bench::{forest::PkdForest, fuzz_pointcloud, get_points, make_needles}; 2 | use 
kiddo::SquaredEuclidean; 3 | use rand::{Rng, SeedableRng}; 4 | use rand_chacha::ChaCha20Rng; 5 | 6 | const N: usize = 1 << 12; 7 | 8 | fn main() { 9 | let mut rng = ChaCha20Rng::seed_from_u64(2707); 10 | let mut starting_points = get_points(N); 11 | fuzz_pointcloud(&mut starting_points, 0.001, &mut rng); 12 | 13 | err_forest::<1>(&starting_points, &mut rng); 14 | err_forest::<2>(&starting_points, &mut rng); 15 | err_forest::<3>(&starting_points, &mut rng); 16 | err_forest::<4>(&starting_points, &mut rng); 17 | err_forest::<5>(&starting_points, &mut rng); 18 | err_forest::<6>(&starting_points, &mut rng); 19 | err_forest::<7>(&starting_points, &mut rng); 20 | err_forest::<8>(&starting_points, &mut rng); 21 | err_forest::<9>(&starting_points, &mut rng); 22 | err_forest::<10>(&starting_points, &mut rng); 23 | } 24 | 25 | fn err_forest(points: &[[f32; 3]], rng: &mut impl Rng) { 26 | let forest = PkdForest::<3, T>::new(points); 27 | 28 | let mut kiddo_kdt = kiddo::KdTree::new(); 29 | for pt in points { 30 | kiddo_kdt.add(pt, 0); 31 | } 32 | 33 | let (seq_needles, _) = make_needles::<3, 1>(rng, 10_000); 34 | 35 | let mut total_err = 0.0; 36 | for &needle in &seq_needles { 37 | let (_, forest_distsq) = forest.approx_nearest(needle); 38 | let exact_distsq = kiddo_kdt.nearest_one::(&needle).distance; 39 | 40 | let exact_dist = exact_distsq.sqrt(); 41 | let err = forest_distsq.sqrt() - exact_dist; 42 | total_err += err; 43 | let rel_err = err / exact_distsq.sqrt(); 44 | println!("{T}\t{err}\t{rel_err}\t{exact_dist}"); 45 | } 46 | 47 | eprintln!("T={T}: mean error {}", total_err / seq_needles.len() as f32); 48 | } 49 | -------------------------------------------------------------------------------- /bench/src/bin/perf_plots.rs: -------------------------------------------------------------------------------- 1 | #![feature(portable_simd)] 2 | 3 | use std::{ 4 | cmp::min, env::args, error::Error, fs::File, hint::black_box, io::Write, simd::f32x8, 5 | time::Duration, 6 | }; 7 | 8 
| use bench::{ 9 | forest::PkdForest, fuzz_pointcloud, kdt::PkdTree, parse_pointcloud_csv, parse_trace_csv, 10 | simd_trace_new, stopwatch, SimdTrace, Trace, 11 | }; 12 | use captree::Capt; 13 | #[allow(unused_imports)] 14 | use kiddo::SquaredEuclidean; 15 | use morton_filter::morton_filter; 16 | use rand::{seq::SliceRandom, Rng}; 17 | use rand_chacha::{rand_core::SeedableRng, ChaCha20Rng}; 18 | 19 | const N_TRIALS: usize = 100_000; 20 | const L: usize = 8; 21 | 22 | const QUERY_RADIUS: f32 = 0.05; 23 | 24 | struct Benchmark<'a> { 25 | seq: &'a Trace, 26 | simd: &'a SimdTrace, 27 | f_query: File, 28 | } 29 | 30 | fn main() -> Result<(), Box> { 31 | let mut f_construct = File::create("construct_time.csv")?; 32 | let mut f_mem = File::create("mem.csv")?; 33 | 34 | let args: Vec = args().collect(); 35 | 36 | let mut rng = ChaCha20Rng::seed_from_u64(2707); 37 | let points: Box<[[f32; 3]]> = if args.len() < 2 { 38 | (0..1 << 16) 39 | .map(|_| { 40 | [ 41 | rng.gen_range::(0.0..1.0), 42 | rng.gen_range::(0.0..1.0), 43 | rng.gen_range::(0.0..1.0), 44 | ] 45 | }) 46 | .collect() 47 | } else { 48 | let mut p = parse_pointcloud_csv(&args[1])?.to_vec(); 49 | fuzz_pointcloud(&mut p, 0.001, &mut rng); 50 | p.shuffle(&mut rng); 51 | p.truncate(1 << 16); 52 | p.into_boxed_slice() 53 | }; 54 | 55 | let all_trace: Box<[([f32; 3], f32)]> = if args.len() < 3 { 56 | (0..N_TRIALS) 57 | .map(|_| { 58 | ( 59 | [ 60 | rng.gen_range(0.0..1.0), 61 | rng.gen_range(0.0..1.0), 62 | rng.gen_range(0.0..1.0), 63 | ], 64 | rng.gen_range(0.0..=QUERY_RADIUS), 65 | ) 66 | }) 67 | .collect() 68 | } else { 69 | parse_trace_csv(&args[2])? 70 | }; 71 | 72 | // let rsq_range = ( 73 | // all_trace 74 | // .iter() 75 | // .map(|x| x.1) 76 | // .min_by(|a, b| a.partial_cmp(b).unwrap()) 77 | // .ok_or("no points")? 78 | // .powi(2), 79 | // all_trace 80 | // .iter() 81 | // .map(|x| x.1) 82 | // .max_by(|a, b| a.partial_cmp(b).unwrap()) 83 | // .ok_or("no points")? 
84 | // .powi(2), 85 | // ); 86 | let r_range = (0.01, 0.08); 87 | 88 | println!("number of points: {}", points.len()); 89 | println!("number of tests: {}", all_trace.len()); 90 | println!("radius range: {r_range:?}"); 91 | 92 | let captree = Capt::<3>::new(&points, r_range); 93 | 94 | let collide_trace: Box = all_trace 95 | .iter() 96 | .filter(|(center, r)| captree.collides(center, *r)) 97 | .copied() 98 | .collect(); 99 | 100 | let no_collide_trace: Box = all_trace 101 | .iter() 102 | .filter(|(center, r)| !captree.collides(center, *r)) 103 | .copied() 104 | .collect(); 105 | 106 | let all_simd_trace = simd_trace_new(&all_trace); 107 | let collide_simd_trace = simd_trace_new(&collide_trace); 108 | let no_collide_simd_trace = simd_trace_new(&no_collide_trace); 109 | 110 | let mut benchmarks = [ 111 | Benchmark { 112 | seq: &all_trace, 113 | simd: &all_simd_trace, 114 | f_query: File::create("mixed.csv")?, 115 | }, 116 | Benchmark { 117 | seq: &collide_trace, 118 | simd: &collide_simd_trace, 119 | f_query: File::create("collides.csv")?, 120 | }, 121 | Benchmark { 122 | seq: &no_collide_trace, 123 | simd: &no_collide_simd_trace, 124 | f_query: File::create("no_collides.csv")?, 125 | }, 126 | ]; 127 | 128 | let mut r_filter = 0.001; 129 | loop { 130 | let mut new_points = points.to_vec(); 131 | morton_filter(&mut new_points, r_filter); 132 | do_row( 133 | &new_points, 134 | &mut benchmarks, 135 | r_range, 136 | &mut f_construct, 137 | &mut f_mem, 138 | )?; 139 | r_filter *= 1.02; 140 | if new_points.len() < 500 { 141 | break; 142 | } 143 | } 144 | 145 | Ok(()) 146 | } 147 | 148 | fn do_row( 149 | points: &[[f32; 3]], 150 | benchmarks: &mut [Benchmark], 151 | r_range: (f32, f32), 152 | f_construct: &mut File, 153 | f_mem: &mut File, 154 | ) -> Result<(), Box> { 155 | let (kdt, kdt_time) = stopwatch(|| kiddo::ImmutableKdTree::new_from_slice(points)); 156 | 157 | let (pkdt, pkdt_time) = stopwatch(|| PkdTree::new(points)); 158 | 159 | let (captree, captree_time) = 
stopwatch(|| Capt::<3, L, f32, u32>::new(points, r_range)); 160 | 161 | let (f1, f1_time) = stopwatch(|| PkdForest::<3, 1>::new(points)); 162 | let (f2, f2_time) = stopwatch(|| PkdForest::<3, 2>::new(points)); 163 | let (f3, f3_time) = stopwatch(|| PkdForest::<3, 3>::new(points)); 164 | let (f4, f4_time) = stopwatch(|| PkdForest::<3, 4>::new(points)); 165 | let (f5, f5_time) = stopwatch(|| PkdForest::<3, 5>::new(points)); 166 | let (f6, f6_time) = stopwatch(|| PkdForest::<3, 6>::new(points)); 167 | let (f7, f7_time) = stopwatch(|| PkdForest::<3, 7>::new(points)); 168 | let (f8, f8_time) = stopwatch(|| PkdForest::<3, 8>::new(points)); 169 | let (f9, f9_time) = stopwatch(|| PkdForest::<3, 9>::new(points)); 170 | let (f10, f10_time) = stopwatch(|| PkdForest::<3, 10>::new(points)); 171 | 172 | writeln!( 173 | f_construct, 174 | "{},{},{},{},{},{},{},{},{},{},{},{},{},{}", 175 | points.len(), 176 | kdt_time.as_secs_f64(), 177 | pkdt_time.as_secs_f64(), 178 | captree_time.as_secs_f64(), 179 | f1_time.as_secs_f64(), 180 | f2_time.as_secs_f64(), 181 | f3_time.as_secs_f64(), 182 | f4_time.as_secs_f64(), 183 | f5_time.as_secs_f64(), 184 | f6_time.as_secs_f64(), 185 | f7_time.as_secs_f64(), 186 | f8_time.as_secs_f64(), 187 | f9_time.as_secs_f64(), 188 | f10_time.as_secs_f64(), 189 | )?; 190 | 191 | writeln!( 192 | f_mem, 193 | "{},{},{}", 194 | points.len(), 195 | pkdt.memory_used(), 196 | captree.memory_used() 197 | )?; 198 | 199 | for Benchmark { 200 | seq: trace, 201 | simd: simd_trace, 202 | f_query, 203 | } in benchmarks 204 | { 205 | let (_, kdt_within_q_time) = stopwatch(|| { 206 | for (center, radius) in trace.iter() { 207 | black_box( 208 | kdt.within_unsorted::(center, radius.powi(2)) 209 | .is_empty(), 210 | ); 211 | } 212 | }); 213 | let (_, kdt_nearest_q_time) = stopwatch(|| { 214 | for (center, radius) in trace.iter() { 215 | black_box(kdt.nearest_one::(center).distance <= radius.powi(2)); 216 | } 217 | }); 218 | let kdt_total_q_time = min(kdt_within_q_time, 
kdt_nearest_q_time); 219 | 220 | let (_, pkdt_total_seq_q_time) = stopwatch(|| { 221 | for (center, radius) in trace.iter() { 222 | black_box(pkdt.might_collide(*center, radius.powi(2))); 223 | } 224 | }); 225 | let (_, pkdt_total_simd_q_time) = stopwatch(|| { 226 | for (centers, radii) in simd_trace.iter() { 227 | black_box(pkdt.might_collide_simd(centers, radii * radii)); 228 | } 229 | }); 230 | let (_, captree_total_seq_q_time) = stopwatch(|| { 231 | for (center, radius) in trace.iter() { 232 | black_box(captree.collides(center, radius.powi(2))); 233 | } 234 | }); 235 | let (_, captree_total_simd_q_time) = stopwatch(|| { 236 | for (centers, radii) in simd_trace.iter() { 237 | black_box(captree.collides_simd(centers, radii * radii)); 238 | } 239 | }); 240 | 241 | let trace_len = trace.len() as f64; 242 | write!( 243 | f_query, 244 | "{},{},{},{},{},{},{}", 245 | points.len(), 246 | trace.len(), 247 | kdt_total_q_time.as_secs_f64() / trace_len, 248 | pkdt_total_seq_q_time.as_secs_f64() / trace_len, 249 | pkdt_total_simd_q_time.as_secs_f64() / trace_len, 250 | captree_total_seq_q_time.as_secs_f64() / trace_len, 251 | captree_total_simd_q_time.as_secs_f64() / trace_len, 252 | )?; 253 | 254 | let forest_results = [ 255 | bench_forest(&f1, simd_trace), 256 | bench_forest(&f2, simd_trace), 257 | bench_forest(&f3, simd_trace), 258 | bench_forest(&f4, simd_trace), 259 | bench_forest(&f5, simd_trace), 260 | bench_forest(&f6, simd_trace), 261 | bench_forest(&f7, simd_trace), 262 | bench_forest(&f8, simd_trace), 263 | bench_forest(&f9, simd_trace), 264 | bench_forest(&f10, simd_trace), 265 | ]; 266 | 267 | for query_time in forest_results { 268 | write!(f_query, ",{}", query_time.as_secs_f64() / trace_len)?; 269 | } 270 | 271 | writeln!(f_query)?; 272 | } 273 | 274 | Ok(()) 275 | } 276 | 277 | fn bench_forest( 278 | forest: &PkdForest<3, T>, 279 | simd_trace: &[([f32x8; 3], f32x8)], 280 | ) -> Duration { 281 | stopwatch(|| { 282 | for (centers, radii) in simd_trace { 283 | 
black_box(forest.might_collide_simd(centers, radii * radii)); 284 | } 285 | }) 286 | .1 287 | } 288 | -------------------------------------------------------------------------------- /bench/src/forest.rs: -------------------------------------------------------------------------------- 1 | //! Power-of-two k-d forests. 2 | 3 | use std::simd::{ 4 | cmp::SimdPartialOrd, ptr::SimdConstPtr, LaneCount, Mask, Simd, SupportedLaneCount, 5 | }; 6 | 7 | use crate::{distsq, median_partition}; 8 | 9 | #[derive(Clone, Debug)] 10 | struct RandomizedTree { 11 | tests: Box<[f32]>, 12 | seed: u32, 13 | points: Box<[[f32; K]]>, 14 | } 15 | 16 | #[derive(Clone, Debug)] 17 | #[allow(clippy::module_name_repetitions)] 18 | pub struct PkdForest { 19 | test_seqs: [RandomizedTree; T], 20 | } 21 | 22 | impl PkdForest { 23 | const T_NONE: Option> = None; 24 | #[allow(clippy::cast_possible_truncation)] 25 | #[must_use] 26 | pub fn new(points: &[[f32; K]]) -> Self { 27 | let mut trees = [Self::T_NONE; T]; 28 | trees 29 | .iter_mut() 30 | .enumerate() 31 | .for_each(|(t, opt)| *opt = Some(RandomizedTree::new(points, t as u32))); 32 | Self { 33 | test_seqs: trees.map(Option::unwrap), 34 | } 35 | } 36 | 37 | #[must_use] 38 | /// # Panics 39 | /// 40 | /// This function will panic if `T` is 0. 
41 | pub fn approx_nearest(&self, needle: [f32; K]) -> ([f32; K], f32) { 42 | self.test_seqs 43 | .iter() 44 | .map(|t| t.points[t.forward_pass(&needle)]) 45 | .map(|point| (point, distsq(needle, point))) 46 | .min_by(|(_, d1), (_, d2)| d1.partial_cmp(d2).unwrap()) 47 | .unwrap() 48 | } 49 | 50 | #[must_use] 51 | pub fn might_collide(&self, needle: [f32; K], r_squared: f32) -> bool { 52 | self.test_seqs 53 | .iter() 54 | .any(|t| distsq(t.points[t.forward_pass(&needle)], needle) < r_squared) 55 | } 56 | 57 | #[must_use] 58 | pub fn might_collide_simd( 59 | &self, 60 | needles: &[Simd; K], 61 | radii_squared: Simd, 62 | ) -> bool 63 | where 64 | LaneCount: SupportedLaneCount, 65 | { 66 | let mut not_yet_collided = Mask::splat(true); 67 | 68 | for tree in &self.test_seqs { 69 | let indices = tree.mask_query(needles, not_yet_collided); 70 | let mut dists_sq = Simd::splat(0.0); 71 | let mut ptrs = Simd::splat(tree.points.as_ptr().cast()).wrapping_offset(indices); 72 | for needle_set in needles { 73 | let diffs = 74 | unsafe { Simd::gather_select_ptr(ptrs, not_yet_collided, Simd::splat(0.0)) } 75 | - needle_set; 76 | dists_sq += diffs * diffs; 77 | ptrs = ptrs.wrapping_add(Simd::splat(1)); 78 | } 79 | 80 | not_yet_collided &= radii_squared.simd_lt(dists_sq).cast(); 81 | 82 | if !not_yet_collided.all() { 83 | // at least one has collided - can return quickly 84 | return true; 85 | } 86 | } 87 | 88 | false 89 | } 90 | } 91 | 92 | impl RandomizedTree { 93 | pub fn new(points: &[[f32; K]], seed: u32) -> Self { 94 | /// Recursive helper function to sort the points for the KD tree and generate the tests. 
95 | fn recur_sort_points( 96 | points: &mut [[f32; K]], 97 | tests: &mut [f32], 98 | test_dims: &mut [u8], 99 | i: usize, 100 | state: u32, 101 | ) { 102 | if points.len() > 1 { 103 | let d = state as usize % K; 104 | tests[i] = median_partition(points, d); 105 | test_dims[i] = u8::try_from(d).unwrap(); 106 | let (lhs, rhs) = points.split_at_mut(points.len() / 2); 107 | recur_sort_points(lhs, tests, test_dims, 2 * i + 1, xorshift(state)); 108 | recur_sort_points(rhs, tests, test_dims, 2 * i + 2, xorshift(state)); 109 | } 110 | } 111 | 112 | assert!(K < u8::MAX as usize); 113 | 114 | let n2 = points.len().next_power_of_two(); 115 | 116 | let mut tests = vec![f32::INFINITY; n2 - 1].into_boxed_slice(); 117 | 118 | // hack: just pad with infinity to make it a power of 2 119 | let mut new_points = vec![[f32::INFINITY; K]; n2]; 120 | new_points[..points.len()].copy_from_slice(points); 121 | let mut test_dims = vec![0; n2 - 1].into_boxed_slice(); 122 | recur_sort_points( 123 | new_points.as_mut(), 124 | tests.as_mut(), 125 | test_dims.as_mut(), 126 | 0, 127 | seed, 128 | ); 129 | 130 | Self { 131 | tests, 132 | points: new_points.into_boxed_slice(), 133 | seed, 134 | } 135 | } 136 | 137 | fn forward_pass(&self, point: &[f32; K]) -> usize { 138 | let mut test_idx = 0; 139 | let mut k = 0; 140 | let mut state = self.seed; 141 | for _ in 0..self.tests.len().trailing_ones() { 142 | test_idx = 2 * test_idx 143 | + 1 144 | + usize::from(unsafe { *self.tests.get_unchecked(test_idx) } <= point[k]); 145 | state = xorshift(state); 146 | k = state as usize % K; 147 | } 148 | 149 | // retrieve affordance buffer location 150 | test_idx - self.tests.len() 151 | } 152 | 153 | #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)] 154 | /// Perform a masked SIMD query of this tree, only determining the location of the nearest 155 | /// neighbors for points in `mask`. 
156 | fn mask_query( 157 | &self, 158 | needles: &[Simd; K], 159 | mask: Mask, 160 | ) -> Simd 161 | where 162 | LaneCount: SupportedLaneCount, 163 | { 164 | let mut test_idxs: Simd = Simd::splat(0); 165 | let mut state = self.seed; 166 | 167 | // Advance the tests forward 168 | for _ in 0..self.tests.len().trailing_ones() { 169 | let relevant_tests: Simd = unsafe { 170 | Simd::gather_select_ptr( 171 | Simd::splat(self.tests.as_ptr().cast()).wrapping_offset(test_idxs), 172 | mask, 173 | Simd::splat(f32::NAN), 174 | ) 175 | }; 176 | let d = state as usize % K; 177 | let cmp_results: Mask = (needles[d].simd_ge(relevant_tests)).into(); 178 | 179 | // TODO is there a faster way than using a conditional select? 180 | test_idxs <<= Simd::splat(1); 181 | test_idxs += Simd::splat(1); 182 | test_idxs += cmp_results.to_int() & Simd::splat(1); 183 | state = xorshift(state); 184 | } 185 | 186 | test_idxs - Simd::splat(self.tests.len() as isize) 187 | } 188 | } 189 | 190 | #[inline] 191 | /// Compute the next value in the xorshift sequence given the most recent value. 
const fn xorshift(mut x: u32) -> u32 {
    // Note: 0 is a fixed point of this sequence (every step maps 0 to 0).
    x ^= x << 13;
    x ^= x >> 17;
    x ^= x << 5;
    x
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn build_a_forest() {
        let points = [[0.0, 0.0], [0.2, 1.0], [-1.0, 0.4]];

        let forest = PkdForest::<2, 2>::new(&points);
        println!("{forest:#?}");
    }

    #[test]
    #[allow(clippy::float_cmp)]
    fn find_the_closest() {
        let points = [[0.0, 0.0], [0.2, 1.0], [-1.0, 0.4]];

        let forest = PkdForest::<2, 2>::new(&points);
        let (nearest, ndsq) = forest.approx_nearest([0.01, 0.02]);
        assert_eq!(nearest, [0.0, 0.0]);
        // 0.01^2 + 0.02^2 = 0.0005; compare the absolute difference so that a value that is
        // too small fails as well.
        assert!((ndsq - 0.0005).abs() < 1e-6);
        println!("{:?}", forest.approx_nearest([0.01, 0.02]));
    }
}
--------------------------------------------------------------------------------
/bench/src/kdt.rs:
--------------------------------------------------------------------------------
use std::{
    mem::size_of,
    simd::{num::SimdInt, Simd, SupportedLaneCount},
};

use captree::{Aabb, Axis, AxisSimd};

use std::simd::{
    cmp::{SimdPartialEq, SimdPartialOrd},
    ptr::SimdConstPtr,
    LaneCount, Mask,
};

use crate::{distsq, forward_pass, median_partition};

#[derive(Clone, Debug, PartialEq)]
/// A power-of-two KD-tree.
///
/// # Generic parameters
///
/// - `K`: The dimension of the space.
pub struct PkdTree<const K: usize> {
    /// The test values for determining which part of the tree to enter.
    ///
    /// The first element of `tests` should be the first value to test against.
    /// If we are less than `tests[0]`, we move on to `tests[1]`; if not, we move on to `tests[2]`.
27 | /// At the `i`-th test performed in sequence of the traversal, if we are less than 28 | /// `tests[idx]`, we advance to `2 * idx + 1`; otherwise, we go to `2 * idx + 2`. 29 | /// 30 | /// The length of `tests` must be `N`, rounded up to the next power of 2, minus one. 31 | tests: Box<[f32]>, 32 | /// The relevant points at the center of each volume divided by `tests`. 33 | points: Box<[[f32; K]]>, 34 | } 35 | 36 | impl PkdTree { 37 | #[must_use] 38 | #[allow(clippy::cast_possible_truncation)] 39 | /// Construct a new `PkdTree` containing all the points in `points`. 40 | /// For performance, this function changes the ordering of `points`, but does not affect the 41 | /// set of points inside it. 42 | /// 43 | /// # Panics 44 | /// 45 | /// This function will panic if `D` is greater than or equal to 255. 46 | /// 47 | /// TODO: do all our sorting on the allocation that we return? 48 | pub fn new(points: &[[f32; K]]) -> Self { 49 | /// Recursive helper function to sort the points for the KD tree and generate the tests. 
50 | /// Runs in O(n log n) 51 | fn build_tree(points: &mut [[f32; K]], tests: &mut [f32], k: u8, i: usize) { 52 | if points.len() > 1 { 53 | tests[i] = median_partition(points, k as usize); 54 | let next_k = (k + 1) % K as u8; 55 | let (lhs, rhs) = points.split_at_mut(points.len() / 2); 56 | build_tree(lhs, tests, next_k, 2 * i + 1); 57 | build_tree(rhs, tests, next_k, 2 * i + 2); 58 | } 59 | } 60 | 61 | assert!(K < u8::MAX as usize); 62 | 63 | let n2 = points.len().next_power_of_two(); 64 | 65 | let mut tests = vec![f32::INFINITY; n2 - 1].into_boxed_slice(); 66 | 67 | // hack: just pad with infinity to make it a power of 2 68 | let mut new_points = vec![[f32::INFINITY; K]; n2].into_boxed_slice(); 69 | new_points[..points.len()].copy_from_slice(points); 70 | build_tree(new_points.as_mut(), tests.as_mut(), 0, 0); 71 | 72 | Self { 73 | tests, 74 | points: new_points, 75 | } 76 | } 77 | 78 | #[must_use] 79 | pub fn approx_nearest(&self, needle: [f32; K]) -> [f32; K] { 80 | self.get_point(forward_pass(&self.tests, &needle)) 81 | } 82 | 83 | #[must_use] 84 | /// Determine whether a ball centered at `needle` with radius `r_squared` could collide with a 85 | /// point in this tree. 
86 | pub fn might_collide(&self, needle: [f32; K], r_squared: f32) -> bool { 87 | distsq(self.approx_nearest(needle), needle) <= r_squared 88 | } 89 | 90 | #[must_use] 91 | #[allow(clippy::cast_possible_wrap)] 92 | pub fn might_collide_simd( 93 | &self, 94 | needles: &[Simd; K], 95 | radii_squared: Simd, 96 | ) -> bool 97 | where 98 | LaneCount: SupportedLaneCount, 99 | { 100 | let indices = forward_pass_simd(&self.tests, needles); 101 | let mut dists_squared = Simd::splat(0.0); 102 | let mut ptrs = 103 | Simd::splat(self.points.as_ptr().cast()).wrapping_add(indices * Simd::splat(K)); 104 | for needle_values in needles { 105 | let deltas = unsafe { Simd::gather_ptr(ptrs) } - needle_values; 106 | dists_squared += deltas * deltas; 107 | ptrs = ptrs.wrapping_add(Simd::splat(1)); 108 | } 109 | dists_squared.simd_lt(radii_squared).any() 110 | } 111 | 112 | #[must_use] 113 | #[allow(clippy::cast_possible_wrap, clippy::cast_sign_loss)] 114 | /// Query for one point in this tree, returning an exact answer. 
115 | pub fn query1_exact(&self, needle: [f32; K]) -> usize { 116 | let mut id = usize::MAX; 117 | let mut best_distsq = f32::INFINITY; 118 | self.exact_help( 119 | 0, 120 | 0, 121 | &Aabb { 122 | lo: [-f32::INFINITY; K], 123 | hi: [f32::INFINITY; K], 124 | }, 125 | needle, 126 | &mut id, 127 | &mut best_distsq, 128 | ); 129 | id 130 | } 131 | 132 | #[allow(clippy::cast_possible_truncation)] 133 | fn exact_help( 134 | &self, 135 | test_idx: usize, 136 | k: u8, 137 | bounding_box: &Aabb, 138 | point: [f32; K], 139 | best_id: &mut usize, 140 | best_distsq: &mut f32, 141 | ) { 142 | if bounding_box.closest_distsq_to(&point) > *best_distsq { 143 | return; 144 | } 145 | 146 | if self.tests.len() <= test_idx { 147 | let id = test_idx - self.tests.len(); 148 | let new_distsq = distsq(point, self.get_point(id)); 149 | if new_distsq < *best_distsq { 150 | *best_id = id; 151 | *best_distsq = new_distsq; 152 | } 153 | 154 | return; 155 | } 156 | 157 | let test = self.tests[test_idx]; 158 | 159 | let mut bb_below = *bounding_box; 160 | bb_below.hi[k as usize] = test; 161 | let mut bb_above = *bounding_box; 162 | bb_above.lo[k as usize] = test; 163 | 164 | let next_k = (k + 1) % K as u8; 165 | if point[k as usize] < test { 166 | self.exact_help( 167 | 2 * test_idx + 1, 168 | next_k, 169 | &bb_below, 170 | point, 171 | best_id, 172 | best_distsq, 173 | ); 174 | self.exact_help( 175 | 2 * test_idx + 2, 176 | next_k, 177 | &bb_above, 178 | point, 179 | best_id, 180 | best_distsq, 181 | ); 182 | } else { 183 | self.exact_help( 184 | 2 * test_idx + 2, 185 | next_k, 186 | &bb_above, 187 | point, 188 | best_id, 189 | best_distsq, 190 | ); 191 | self.exact_help( 192 | 2 * test_idx + 1, 193 | next_k, 194 | &bb_below, 195 | point, 196 | best_id, 197 | best_distsq, 198 | ); 199 | } 200 | } 201 | 202 | #[must_use] 203 | #[allow(clippy::missing_panics_doc)] 204 | pub const fn get_point(&self, id: usize) -> [f32; K] { 205 | self.points[id] 206 | } 207 | 208 | #[must_use] 209 | /// Return the 
total memory used (stack + heap) by this structure. 210 | pub const fn memory_used(&self) -> usize { 211 | size_of::() + (self.points.len() * K + self.tests.len()) * size_of::() 212 | } 213 | } 214 | 215 | #[inline] 216 | #[allow(clippy::cast_possible_wrap)] 217 | fn forward_pass_simd( 218 | tests: &[A], 219 | centers: &[Simd; K], 220 | ) -> Simd 221 | where 222 | Simd: SimdPartialOrd, 223 | Mask: From< as SimdPartialEq>::Mask>, 224 | A: Axis + AxisSimd< as SimdPartialEq>::Mask>, 225 | LaneCount: SupportedLaneCount, 226 | { 227 | let mut i: Simd = Simd::splat(0); 228 | let mut k = 0; 229 | for _ in 0..tests.len().trailing_ones() { 230 | let test_ptrs = Simd::splat(tests.as_ptr()).wrapping_add(i); 231 | let relevant_tests = unsafe { Simd::gather_ptr(test_ptrs) }; 232 | let cmp: Mask = centers[k].simd_ge(relevant_tests).into(); 233 | 234 | let one = Simd::splat(1); 235 | i = (i << one) + one + (cmp.to_int().cast() & one); 236 | k = (k + 1) % K; 237 | } 238 | 239 | i - Simd::splat(tests.len()) 240 | } 241 | 242 | #[cfg(test)] 243 | mod tests { 244 | 245 | use crate::forward_pass; 246 | 247 | use super::*; 248 | 249 | #[test] 250 | fn single_query() { 251 | let points = vec![ 252 | [0.1, 0.1], 253 | [0.1, 0.2], 254 | [0.5, 0.0], 255 | [0.3, 0.9], 256 | [1.0, 1.0], 257 | [0.35, 0.75], 258 | [0.6, 0.2], 259 | [0.7, 0.8], 260 | ]; 261 | let kdt = PkdTree::new(&points); 262 | 263 | println!("testing for correctness..."); 264 | 265 | let neg1 = [-1.0, -1.0]; 266 | let neg1_idx = forward_pass(&kdt.tests, &neg1); 267 | assert_eq!(neg1_idx, 0); 268 | 269 | let pos1 = [1.0, 1.0]; 270 | let pos1_idx = forward_pass(&kdt.tests, &pos1); 271 | assert_eq!(pos1_idx, points.len() - 1); 272 | } 273 | 274 | #[test] 275 | #[allow(clippy::cast_possible_wrap)] 276 | fn multi_query() { 277 | let points = vec![ 278 | [0.1, 0.1], 279 | [0.1, 0.2], 280 | [0.5, 0.0], 281 | [0.3, 0.9], 282 | [1.0, 1.0], 283 | [0.35, 0.75], 284 | [0.6, 0.2], 285 | [0.7, 0.8], 286 | ]; 287 | let kdt = 
PkdTree::new(&points); 288 | 289 | let needles = [Simd::from_array([-1.0, 2.0]), Simd::from_array([-1.0, 2.0])]; 290 | assert_eq!( 291 | forward_pass_simd(&kdt.tests, &needles), 292 | Simd::from_array([0, points.len() - 1]) 293 | ); 294 | } 295 | 296 | #[test] 297 | fn not_a_power_of_two() { 298 | let points = vec![[0.0], [2.0], [4.0]]; 299 | let kdt = PkdTree::new(&points); 300 | 301 | println!("{kdt:?}"); 302 | 303 | assert_eq!(forward_pass(&kdt.tests, &[-0.1]), 0); 304 | assert_eq!(forward_pass(&kdt.tests, &[0.5]), 0); 305 | assert_eq!(forward_pass(&kdt.tests, &[1.5]), 1); 306 | assert_eq!(forward_pass(&kdt.tests, &[2.5]), 1); 307 | assert_eq!(forward_pass(&kdt.tests, &[3.5]), 2); 308 | assert_eq!(forward_pass(&kdt.tests, &[4.5]), 2); 309 | } 310 | 311 | #[test] 312 | fn a_power_of_two() { 313 | let points = vec![[0.0], [2.0], [4.0], [6.0]]; 314 | let kdt = PkdTree::new(&points); 315 | 316 | println!("{kdt:?}"); 317 | 318 | assert_eq!(forward_pass(&kdt.tests, &[-0.1]), 0); 319 | assert_eq!(forward_pass(&kdt.tests, &[0.5]), 0); 320 | assert_eq!(forward_pass(&kdt.tests, &[1.5]), 1); 321 | assert_eq!(forward_pass(&kdt.tests, &[2.5]), 1); 322 | assert_eq!(forward_pass(&kdt.tests, &[3.5]), 2); 323 | assert_eq!(forward_pass(&kdt.tests, &[4.5]), 2); 324 | } 325 | } 326 | -------------------------------------------------------------------------------- /bench/src/lib.rs: -------------------------------------------------------------------------------- 1 | #![feature(portable_simd)] 2 | 3 | use std::{ 4 | env, 5 | error::Error, 6 | path::Path, 7 | simd::{LaneCount, Simd, SupportedLaneCount}, 8 | time::{Duration, Instant}, 9 | }; 10 | 11 | use captree::Axis; 12 | use rand::{Rng, SeedableRng}; 13 | use rand_chacha::ChaCha20Rng; 14 | 15 | use rand_distr::{Distribution, Normal}; 16 | 17 | pub mod forest; 18 | pub mod kdt; 19 | 20 | pub fn get_points(n_points_if_no_cloud: usize) -> Box<[[f32; 3]]> { 21 | let args: Vec = env::args().collect(); 22 | let mut rng = 
ChaCha20Rng::seed_from_u64(2707);

    if args.len() > 1 {
        // first command-line argument: path of a point-cloud CSV
        eprintln!("Loading pointcloud from {}", &args[1]);
        parse_pointcloud_csv(&args[1]).unwrap()
    } else {
        eprintln!("No pointcloud file! Using N={n_points_if_no_cloud}");
        eprintln!("generating random points...");
        // uniform points in the unit cube
        (0..n_points_if_no_cloud)
            .map(|_| {
                [
                    rng.gen_range::<f32, _>(0.0..1.0),
                    rng.gen_range::<f32, _>(0.0..1.0),
                    rng.gen_range::<f32, _>(0.0..1.0),
                ]
            })
            .collect()
    }
}

/// Generate some randomized numbers for us to benchmark against.
///
/// Every coordinate is drawn uniformly from `[0, 1)`.
///
/// # Generic parameters
///
/// - `D`: the dimension of the space
/// - `L`: the number of SIMD lanes
///
/// # Returns
///
/// Returns a pair `(seq_needles, simd_needles)`, where `seq_needles` is correctly shaped for
/// sequential querying and `simd_needles` is correctly shaped for SIMD querying.
/// Lane `l` of `simd_needles[i]` is the same point as `seq_needles[i * L + l]`.
/// Only whole batches are produced, so `seq_needles` has `(n_trials / L) * L` entries.
pub fn make_needles<const D: usize, const L: usize>(
    rng: &mut impl Rng,
    n_trials: usize,
) -> (Vec<[f32; D]>, Vec<[Simd<f32, L>; D]>)
where
    LaneCount<L>: SupportedLaneCount,
{
    let n_batches = n_trials / L;
    let mut seq_needles = Vec::with_capacity(n_batches * L);
    let mut simd_needles = Vec::with_capacity(n_batches);

    for _ in 0..n_batches {
        let mut simd_pts = [Simd::splat(0.0); D];
        for l in 0..L {
            let mut seq_needle = [0.0; D];
            // Fill all `D` coordinates; a fixed bound of 3 would index out of range for
            // `D < 3` and leave the remaining coordinates at zero for `D > 3`.
            for d in 0..D {
                let value = rng.gen_range::<f32, _>(0.0..1.0);
                seq_needle[d] = value;
                simd_pts[d].as_mut_array()[l] = value;
            }
            seq_needles.push(seq_needle);
        }
        simd_needles.push(simd_pts);
    }

    assert_eq!(seq_needles.len(), simd_needles.len() * L);

    (seq_needles, simd_needles)
}

/// Generate some randomized numbers for us to benchmark against which are correlated within a SIMD
/// batch.
84 | /// 85 | /// # Generic parameters 86 | /// 87 | /// - `D`: the dimension of the space 88 | /// - `L`: the number of SIMD lanes 89 | /// 90 | /// # Returns 91 | /// 92 | /// Returns a pair `(seq_needles, simd_needles)`, where `seq_needles` is correctly shaped for 93 | /// sequential querying and `simd_needles` is correctly shaped for SIMD querying. 94 | /// Additionally, each element of each element of `simd_needles` will be relatively close in space. 95 | pub fn make_correlated_needles( 96 | rng: &mut impl Rng, 97 | n_trials: usize, 98 | ) -> (Vec<[f32; D]>, Vec<[Simd; D]>) 99 | where 100 | LaneCount: SupportedLaneCount, 101 | { 102 | let mut seq_needles = Vec::new(); 103 | let mut simd_needles = Vec::new(); 104 | 105 | for _ in 0..n_trials / L { 106 | let mut start_pt = [0.0; D]; 107 | for v in start_pt.iter_mut() { 108 | *v = rng.gen_range::(0.0..1.0); 109 | } 110 | let mut simd_pts = [Simd::splat(0.0); D]; 111 | for l in 0..L { 112 | let mut seq_needle = [0.0; D]; 113 | for d in 0..D { 114 | let value = start_pt[d] + rng.gen_range::(-0.02..0.02); 115 | seq_needle[d] = value; 116 | simd_pts[d].as_mut_array()[l] = value; 117 | } 118 | seq_needles.push(seq_needle); 119 | } 120 | simd_needles.push(simd_pts); 121 | } 122 | 123 | assert_eq!(seq_needles.len(), simd_needles.len() * L); 124 | 125 | (seq_needles, simd_needles) 126 | } 127 | 128 | pub fn dist(a: [f32; D], b: [f32; D]) -> f32 { 129 | a.into_iter() 130 | .zip(b) 131 | .map(|(x1, x2)| (x1 - x2).powi(2)) 132 | .sum::() 133 | .sqrt() 134 | } 135 | 136 | pub fn stopwatch R, R>(f: F) -> (R, Duration) { 137 | let tic = Instant::now(); 138 | let r = f(); 139 | (r, Instant::now().duration_since(tic)) 140 | } 141 | 142 | pub fn parse_pointcloud_csv( 143 | p: impl AsRef, 144 | ) -> Result, Box> { 145 | std::str::from_utf8(&std::fs::read(&p)?)? 
146 | .lines() 147 | .map(|l| { 148 | let mut split = l.split(',').flat_map(|s| s.parse::().ok()); 149 | Ok::<_, Box>([ 150 | split.next().ok_or("trace missing x")?, 151 | split.next().ok_or("trace missing y")?, 152 | split.next().ok_or("trace missing z")?, 153 | ]) 154 | }) 155 | .collect() 156 | } 157 | 158 | pub type Trace = [([f32; 3], f32)]; 159 | 160 | pub fn parse_trace_csv(p: impl AsRef) -> Result, Box> { 161 | std::str::from_utf8(&std::fs::read(&p)?)? 162 | .lines() 163 | .map(|l| { 164 | let mut split = l.split(',').flat_map(|s| s.parse::().ok()); 165 | Ok::<_, Box>(( 166 | [ 167 | split.next().ok_or("trace missing x")?, 168 | split.next().ok_or("trace missing y")?, 169 | split.next().ok_or("trace missing z")?, 170 | ], 171 | split.next().ok_or("trace missing r")?, 172 | )) 173 | }) 174 | .collect() 175 | } 176 | 177 | pub type SimdTrace = [([Simd; 3], Simd)]; 178 | 179 | pub fn simd_trace_new(trace: &Trace) -> Box> 180 | where 181 | LaneCount: SupportedLaneCount, 182 | { 183 | trace 184 | .chunks(L) 185 | .map(|w| { 186 | let mut centers = [[0.0; L]; 3]; 187 | let mut radii = [0.0; L]; 188 | for (l, ([x, y, z], r)) in w.iter().copied().enumerate() { 189 | centers[0][l] = x; 190 | centers[1][l] = y; 191 | centers[2][l] = z; 192 | radii[l] = r; 193 | } 194 | (centers.map(Simd::from_array), Simd::from_array(radii)) 195 | }) 196 | .collect() 197 | } 198 | 199 | pub fn trace_r_range(t: &Trace) -> (f32, f32) { 200 | ( 201 | t.iter() 202 | .map(|x| x.1) 203 | .min_by(|a, b| a.partial_cmp(b).unwrap()) 204 | .unwrap_or(0.0), 205 | t.iter() 206 | .map(|x| x.1) 207 | .max_by(|a, b| a.partial_cmp(b).unwrap()) 208 | .unwrap_or(f32::INFINITY), 209 | ) 210 | } 211 | 212 | pub fn fuzz_pointcloud(t: &mut [[f32; 3]], stddev: f32, rng: &mut impl Rng) { 213 | let normal = Normal::new(0.0, stddev).unwrap(); 214 | t.iter_mut() 215 | .for_each(|p| p.iter_mut().for_each(|x| *x += normal.sample(rng))) 216 | } 217 | 218 | #[inline] 219 | /// Calculate the "true" median (halfway 
between two midpoints) and partition `points` about said 220 | /// median along axis `d`. 221 | fn median_partition(points: &mut [[A; K]], k: usize) -> A { 222 | let (lh, med_hi, _) = 223 | points.select_nth_unstable_by(points.len() / 2, |a, b| a[k].partial_cmp(&b[k]).unwrap()); 224 | let med_lo = lh 225 | .iter_mut() 226 | .map(|p| p[k]) 227 | .max_by(|a, b| a.partial_cmp(b).unwrap()) 228 | .unwrap(); 229 | A::in_between(med_lo, med_hi[k]) 230 | } 231 | 232 | #[inline] 233 | fn forward_pass(tests: &[A], point: &[A; K]) -> usize { 234 | // forward pass through the tree 235 | let mut test_idx = 0; 236 | let mut k = 0; 237 | for _ in 0..tests.len().trailing_ones() { 238 | test_idx = 239 | 2 * test_idx + 1 + usize::from(unsafe { *tests.get_unchecked(test_idx) } <= point[k]); 240 | k = (k + 1) % K; 241 | } 242 | 243 | // retrieve affordance buffer location 244 | test_idx - tests.len() 245 | } 246 | 247 | fn distsq(a: [f32; K], b: [f32; K]) -> f32 { 248 | let mut total = 0.0f32; 249 | for i in 0..K { 250 | total += (a[i] - b[i]).powi(2); 251 | } 252 | total 253 | } 254 | -------------------------------------------------------------------------------- /captree/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "captree" 3 | version = "0.1.0" 4 | edition = "2021" 5 | 6 | [dev-dependencies] 7 | rand = "0.8.5" 8 | 9 | [features] 10 | simd = [] 11 | 12 | [dependencies] 13 | elain = "0.3.0" 14 | 15 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 16 | -------------------------------------------------------------------------------- /captree/src/lib.rs: -------------------------------------------------------------------------------- 1 | //! # Collision-Affording Point Trees: SIMD-Amenable Nearest Neighbors for Fast Collision Checking 2 | //! 3 | //! This is a Rust implementation of the _collision-affording point tree_ (CAPT), a data structure 4 | //! 
for SIMD-parallel collision-checking between spheres and point clouds. 5 | //! 6 | //! You may also want to look at the following other sources: 7 | //! 8 | //! - [The paper](https://arxiv.org/abs/2406.02807) 9 | //! - [C++ implementation](https://github.com/KavrakiLab/vamp) 10 | //! - [Blog post about it](https://www.claytonwramsey.com/blog/captree) 11 | //! - [Demo video](https://youtu.be/BzDKdrU1VpM) 12 | //! 13 | //! If you use this in an academic work, please cite it as follows: 14 | //! 15 | //! ```bibtex 16 | //! @InProceedings{capt, 17 | //! title = {Collision-Affording Point Trees: {SIMD}-Amenable Nearest Neighbors for Fast Collision Checking}, 18 | //! author = {Ramsey, Clayton W. and Kingston, Zachary and Thomason, Wil and Kavraki, Lydia E.}, 19 | //! booktitle = {Robotics: Science and Systems}, 20 | //! date = {2024}, 21 | //! url = {http://arxiv.org/abs/2406.02807}, 22 | //! note = {To Appear.} 23 | //! } 24 | //! ``` 25 | //! 26 | //! ## Usage 27 | //! 28 | //! The core data structure in this library is the [`Capt`], which is a search tree used for 29 | //! collision checking. [`Capt`]s are polymorphic over dimension and data type. On construction, 30 | //! they take in a list of points in a point cloud and a _radius range_: a tuple of the minimum and 31 | //! maximum radius used for querying. 32 | //! 33 | //! ```rust 34 | //! use captree::Capt; 35 | //! 36 | //! // list of points in cloud 37 | //! let points = [[0.0, 1.1], [0.2, 3.1]]; 38 | //! let r_min = 0.05; 39 | //! let r_max = 2.0; 40 | //! 41 | //! let capt = Capt::<2>::new(&points, (r_min, r_max)); 42 | //! ``` 43 | //! 44 | //! Once you have a `Capt`, you can use it for collision-checking against spheres. 45 | //! Correct answers are only guaranteed if you collision-check against spheres with a radius inside 46 | //! the radius range. 47 | //! 48 | //! ```rust 49 | //! # use captree::Capt; 50 | //! # let points = [[0.0, 1.1], [0.2, 3.1]]; 51 | //! 
# let capt = Capt::<2>::new(&points, (0.05, 2.0)); 52 | //! let center = [0.0, 0.0]; // center of sphere 53 | //! let radius0 = 1.0; // radius of sphere 54 | //! assert!(!capt.collides(&center, radius0)); 55 | //! 56 | //! let radius1 = 1.5; 57 | //! assert!(capt.collides(&center, radius1)); 58 | //! ``` 59 | //! 60 | //! ## License 61 | //! 62 | //! This work is licensed to you under the Polyform Non-Commercial License. 63 | #![cfg_attr(feature = "simd", feature(portable_simd))] 64 | #![warn(clippy::pedantic)] 65 | #![warn(clippy::nursery)] 66 | 67 | use std::{ 68 | array, 69 | fmt::Debug, 70 | mem::size_of, 71 | ops::{Add, Sub}, 72 | }; 73 | 74 | #[cfg(feature = "simd")] 75 | use std::{ 76 | ops::{AddAssign, Mul}, 77 | ptr, 78 | simd::{ 79 | cmp::{SimdPartialEq, SimdPartialOrd}, 80 | ptr::SimdConstPtr, 81 | LaneCount, Mask, Simd, SimdElement, SupportedLaneCount, 82 | }, 83 | }; 84 | 85 | use elain::{Align, Alignment}; 86 | 87 | /// A generic trait representing values which may be used as an "axis;" that is, elements of a 88 | /// vector representing a point. 89 | /// 90 | /// An array of `Axis` values is a point which can be stored in a [`Capt`]. 91 | /// Accordingly, this trait specifies nearly all the requirements for points that [`Capt`]s require. 92 | /// The only exception is that [`Axis`] values really ought to be [`Ord`] instead of [`PartialOrd`]; 93 | /// however, due to the disaster that is IEEE 754 floating point numbers, `f32` and `f64` are not 94 | /// totally ordered. As a compromise, we relax the `Ord` requirement so that you can use floats in a 95 | /// `Capt`. 96 | /// 97 | /// # Examples 98 | /// 99 | /// ``` 100 | /// #[derive(Clone, Copy, PartialOrd, PartialEq)] 101 | /// enum HyperInt { 102 | /// MinusInf, 103 | /// Real(i32), 104 | /// PlusInf, 105 | /// } 106 | /// 107 | /// impl std::ops::Add for HyperInt { 108 | /// // ... 
109 | /// # type Output = Self; 110 | /// # 111 | /// # fn add(self, rhs: Self) -> Self { 112 | /// # match (self, rhs) { 113 | /// # (Self::MinusInf, Self::PlusInf) => Self::Real(0), // evil, but who cares? 114 | /// # (Self::MinusInf, _) | (_, Self::MinusInf) => Self::MinusInf, 115 | /// # (Self::PlusInf, _) | (_, Self::PlusInf) => Self::PlusInf, 116 | /// # (Self::Real(x), Self::Real(y)) => Self::Real(x + y), 117 | /// # } 118 | /// # } 119 | /// } 120 | /// 121 | /// 122 | /// impl std::ops::Sub for HyperInt { 123 | /// // ... 124 | /// # type Output = Self; 125 | /// # 126 | /// # fn sub(self, rhs: Self) -> Self { 127 | /// # match (self, rhs) { 128 | /// # (Self::MinusInf, Self::MinusInf) | (Self::PlusInf, Self::PlusInf) => Self::Real(0), // evil, but who cares? 129 | /// # (Self::MinusInf, _) | (_, Self::PlusInf) => Self::MinusInf, 130 | /// # (Self::PlusInf, _) | (_, Self::MinusInf) => Self::PlusInf, 131 | /// # (Self::Real(x), Self::Real(y)) => Self::Real(x - y), 132 | /// # } 133 | /// # } 134 | /// } 135 | /// 136 | /// impl captree::Axis for HyperInt { 137 | /// const ZERO: Self = Self::Real(0); 138 | /// const INFINITY: Self = Self::PlusInf; 139 | /// const NEG_INFINITY: Self = Self::MinusInf; 140 | /// 141 | /// fn is_finite(self) -> bool { 142 | /// matches!(self, Self::Real(_)) 143 | /// } 144 | /// 145 | /// fn in_between(self, rhs: Self) -> Self { 146 | /// match (self, rhs) { 147 | /// (Self::PlusInf, Self::MinusInf) | (Self::MinusInf, Self::PlusInf) => Self::Real(0), 148 | /// (Self::MinusInf, _) | (_, Self::MinusInf) => Self::MinusInf, 149 | /// (Self::PlusInf, _) | (_, Self::PlusInf) => Self::PlusInf, 150 | /// (Self::Real(a), Self::Real(b)) => Self::Real((a + b) / 2) 151 | /// } 152 | /// } 153 | /// 154 | /// fn square(self) -> Self { 155 | /// match self { 156 | /// Self::PlusInf | Self::MinusInf => Self::PlusInf, 157 | /// Self::Real(a) => Self::Real(a * a), 158 | /// } 159 | /// } 160 | /// } 161 | /// ``` 162 | pub trait Axis: PartialOrd 
+ Copy + Sub + Add { 163 | /// A zero value. 164 | const ZERO: Self; 165 | /// A value which is larger than any finite value. 166 | const INFINITY: Self; 167 | /// A value which is smaller than any finite value. 168 | const NEG_INFINITY: Self; 169 | 170 | #[must_use] 171 | /// Determine whether this value is finite or infinite. 172 | fn is_finite(self) -> bool; 173 | 174 | #[must_use] 175 | /// Compute a value of `Self` which is halfway between `self` and `rhs`. 176 | /// If there are no legal values between `self` and `rhs`, it is acceptable to return `self` 177 | /// instead. 178 | fn in_between(self, rhs: Self) -> Self; 179 | 180 | #[must_use] 181 | /// Compute the square of this value. 182 | fn square(self) -> Self; 183 | } 184 | 185 | #[cfg(feature = "simd")] 186 | /// A trait used for masks over SIMD vectors, used for parallel querying on [`Capt`]s. 187 | /// 188 | /// The interface for this trait should be considered unstable since the standard SIMD API may 189 | /// change with Rust versions. 190 | pub trait AxisSimd: SimdElement + Default { 191 | #[must_use] 192 | /// Determine whether any element of this mask is set to `true`. 193 | fn any(mask: M) -> bool; 194 | } 195 | 196 | /// An index type used for lookups into and out of arrays. 197 | /// 198 | /// This is implemented so that [`Capt`]s can use smaller index sizes (such as [`u32`] or [`u16`]) 199 | /// for improved memory performance. 200 | pub trait Index: TryFrom + TryInto + Copy { 201 | /// The zero index. This must be equal to `(0usize).try_into().unwrap()`. 202 | const ZERO: Self; 203 | } 204 | 205 | #[cfg(feature = "simd")] 206 | /// A SIMD parallel version of [`Index`]. 207 | /// 208 | /// This is used for implementing SIMD lookups in a [`Capt`]. 209 | /// The interface for this trait should be considered unstable since the standard SIMD API may 210 | /// change with Rust versions. 
211 | pub trait IndexSimd: SimdElement + Default { 212 | #[must_use] 213 | /// Convert a SIMD array of `Self` to a SIMD array of `usize`, without checking that each 214 | /// element is valid. 215 | /// 216 | /// # Safety 217 | /// 218 | /// This function is only safe if all values of `x` are valid when converted to a `usize`. 219 | unsafe fn to_simd_usize_unchecked(x: Simd) -> Simd 220 | where 221 | LaneCount: SupportedLaneCount; 222 | } 223 | 224 | macro_rules! impl_axis { 225 | ($t: ty, $tm: ty) => { 226 | impl Axis for $t { 227 | const ZERO: Self = 0.0; 228 | const INFINITY: Self = <$t>::INFINITY; 229 | const NEG_INFINITY: Self = <$t>::NEG_INFINITY; 230 | fn is_finite(self) -> bool { 231 | <$t>::is_finite(self) 232 | } 233 | 234 | fn in_between(self, rhs: Self) -> Self { 235 | (self + rhs) / 2.0 236 | } 237 | 238 | fn square(self) -> Self { 239 | self * self 240 | } 241 | } 242 | 243 | #[cfg(feature = "simd")] 244 | impl AxisSimd> for $t 245 | where 246 | LaneCount: SupportedLaneCount, 247 | { 248 | fn any(mask: Mask<$tm, L>) -> bool { 249 | Mask::<$tm, L>::any(mask) 250 | } 251 | } 252 | }; 253 | } 254 | 255 | macro_rules! impl_idx { 256 | ($t: ty) => { 257 | impl Index for $t { 258 | const ZERO: Self = 0; 259 | } 260 | 261 | #[cfg(feature = "simd")] 262 | impl IndexSimd for $t { 263 | #[must_use] 264 | unsafe fn to_simd_usize_unchecked(x: Simd) -> Simd 265 | where 266 | LaneCount: SupportedLaneCount, 267 | { 268 | x.to_array().map(|a| a.try_into().unwrap_unchecked()).into() 269 | } 270 | } 271 | }; 272 | } 273 | 274 | impl_axis!(f32, i32); 275 | impl_axis!(f64, i64); 276 | 277 | impl_idx!(u8); 278 | impl_idx!(u16); 279 | impl_idx!(u32); 280 | impl_idx!(u64); 281 | impl_idx!(usize); 282 | 283 | /// Clamp a floating-point number. 
284 | fn clamp(x: A, min: A, max: A) -> A { 285 | if x < min { 286 | min 287 | } else if x > max { 288 | max 289 | } else { 290 | x 291 | } 292 | } 293 | 294 | #[inline] 295 | #[allow(clippy::cast_possible_wrap)] 296 | #[cfg(feature = "simd")] 297 | fn forward_pass_simd( 298 | tests: &[A], 299 | centers: &[Simd; K], 300 | ) -> Simd 301 | where 302 | Simd: SimdPartialOrd, 303 | Mask: From< as SimdPartialEq>::Mask>, 304 | A: Axis + AxisSimd< as SimdPartialEq>::Mask>, 305 | LaneCount: SupportedLaneCount, 306 | { 307 | let mut test_idxs: Simd = Simd::splat(0); 308 | let mut k = 0; 309 | for _ in 0..tests.len().trailing_ones() { 310 | let test_ptrs = Simd::splat(tests.as_ptr()).wrapping_offset(test_idxs); 311 | let relevant_tests: Simd = unsafe { Simd::gather_ptr(test_ptrs) }; 312 | let cmp_results: Mask = centers[k % K].simd_ge(relevant_tests).into(); 313 | 314 | let one = Simd::splat(1); 315 | test_idxs = (test_idxs << one) + one + (cmp_results.to_int() & Simd::splat(1)); 316 | k = (k + 1) % K; 317 | } 318 | 319 | test_idxs - Simd::splat(tests.len() as isize) 320 | } 321 | 322 | #[repr(C)] 323 | #[derive(Clone, Copy, Debug, PartialEq, Eq)] 324 | /// A stable-safe wrapper for `[A; L]` which is aligned to `L`. 325 | /// Equivalent to a `Simd`, but easier to work with. 326 | struct MySimd 327 | where 328 | Align: Alignment, 329 | { 330 | data: [A; L], 331 | _align: Align, 332 | } 333 | 334 | #[derive(Clone, Debug, PartialEq, Eq)] 335 | #[allow(clippy::module_name_repetitions)] 336 | /// A collision-affording point tree (CAPT), which allows for efficient collision-checking in a 337 | /// SIMD-parallel manner between spheres and point clouds. 338 | /// 339 | /// # Generic parameters 340 | /// 341 | /// - `K`: The dimension of the space. 342 | /// - `L`: The lane size of this tree. Internally, this is the upper bound on the width of a SIMD 343 | /// lane that can be used in this data structure. The alignment of this structure must be a power 344 | /// of two. 
345 | /// - `A`: The value of the axes of each point. This should typically be `f32` or `f64`. This should 346 | /// implement [`Axis`]. 347 | /// - `I`: The index integer. This should generally be an unsigned integer, such as `usize` or 348 | /// `u32`. This should implement [`Index`]. 349 | /// 350 | /// # Examples 351 | /// 352 | /// ``` 353 | /// // list of points in cloud 354 | /// let points = [[0.0, 0.1], [0.4, -0.2], [-0.2, -0.1]]; 355 | /// 356 | /// // query radii must be between 0.0 and 0.2 357 | /// let t = captree::Capt::<2>::new(&points, (0.0, 0.2)); 358 | /// 359 | /// assert!(!t.collides(&[0.0, 0.3], 0.1)); 360 | /// assert!(t.collides(&[0.0, 0.2], 0.15)); 361 | /// ``` 362 | pub struct Capt 363 | where 364 | Align: Alignment, 365 | { 366 | /// The test values for determining which part of the tree to enter. 367 | /// 368 | /// The first element of `tests` should be the first value to test against. 369 | /// If we are less than `tests[0]`, we move on to `tests[1]`; if not, we move on to `tests[2]`. 370 | /// At the `i`-th test performed in sequence of the traversal, if we are less than 371 | /// `tests[idx]`, we advance to `2 * idx + 1`; otherwise, we go to `2 * idx + 2`. 372 | /// 373 | /// The length of `tests` must be `N`, rounded up to the next power of 2, minus one. 374 | tests: Box<[A]>, 375 | /// Axis-aligned bounding boxes containing the set of afforded points for each cell. 376 | aabbs: Box<[Aabb]>, 377 | /// Indexes for the starts of the affordance buffer subsequence of `points` corresponding to 378 | /// each leaf cell in the tree. 379 | /// This buffer is padded with one extra `usize` at the end with the maximum length of `points` 380 | /// for the sake of branchless computation. 381 | starts: Box<[I]>, 382 | /// The sets of afforded points for each cell. 383 | afforded: [Box<[MySimd]>; K], 384 | } 385 | 386 | #[repr(C)] 387 | #[derive(Clone, Copy, Debug, PartialEq, Eq)] 388 | #[doc(hidden)] 389 | /// A prismatic bounding volume. 
390 | pub struct Aabb { 391 | /// The lower bound on the volume. 392 | pub lo: [A; K], 393 | /// The upper bound on the volume. 394 | pub hi: [A; K], 395 | } 396 | 397 | #[non_exhaustive] 398 | #[derive(Clone, Debug, PartialEq, Eq)] 399 | /// The errors which can occur when calling [`Capt::try_new`]. 400 | pub enum NewCaptError { 401 | /// There were too many points in the provided cloud to be represented without integer 402 | /// overflow. 403 | TooManyPoints, 404 | /// At least one of the points had a non-finite value. 405 | NonFinite, 406 | } 407 | 408 | impl Capt 409 | where 410 | A: Axis, 411 | I: Index, 412 | Align: Alignment, 413 | { 414 | /// Construct a new CAPT containing all the points in `points`. 415 | /// 416 | /// `r_range` is a `(minimum, maximum)` pair containing the lower and upper bound on the 417 | /// radius of the balls which will be queried against the tree. 418 | /// 419 | /// # Panics 420 | /// 421 | /// This function will panic if there are too many points in the tree to be addressed by `I`, or 422 | /// if any points contain non-finite non-real value. This can even be the case if there are 423 | /// fewer points in `points` than can be addressed by `I` as the CAPT may duplicate points 424 | /// for efficiency. 425 | /// 426 | /// # Examples 427 | /// 428 | /// ``` 429 | /// let points = [[0.0]]; 430 | /// 431 | /// let capt = captree::Capt::<1>::new(&points, (0.0, f32::INFINITY)); 432 | /// 433 | /// assert!(capt.collides(&[1.0], 1.5)); 434 | /// assert!(!capt.collides(&[1.0], 0.5)); 435 | /// ``` 436 | /// 437 | /// If there are too many points in `points`, this could cause a panic! 
438 | /// 439 | /// ```rust,should_panic 440 | /// let points = [[0.0]; 256]; 441 | /// 442 | /// // note that we are using `u8` as our index type 443 | /// let capt = captree::Capt::<1, 8, f32, u8>::new(&points, (0.0, f32::INFINITY)); 444 | /// ``` 445 | pub fn new(points: &[[A; K]], r_range: (A, A)) -> Self { 446 | Self::try_new(points, r_range) 447 | .expect("index type I must be able to support all points in CAPT during construction") 448 | } 449 | 450 | /// Construct a new CAPT containing all the points in `points`, checking for index overflow. 451 | /// 452 | /// `r_range` is a `(minimum, maximum)` pair containing the lower and upper bound on the 453 | /// radius of the balls which will be queried against the tree. 454 | /// `rng` is a random number generator. 455 | /// 456 | /// # Errors 457 | /// 458 | /// This function will return `Err(NewCaptError::TooManyPoints)` if there are too many points to 459 | /// be indexed by `I`. It will return `Err(NewCaptError::NonFinite)` if any element of 460 | /// `points` is non-finite. 461 | /// 462 | /// # Examples 463 | /// 464 | /// Unwrapping the output from this function is equivalent to calling [`Capt::new`]. 465 | /// 466 | /// ``` 467 | /// let points = [[0.0]]; 468 | /// 469 | /// let capt = captree::Capt::<1>::try_new(&points, (0.0, f32::INFINITY)).unwrap(); 470 | /// ``` 471 | /// 472 | /// In failure, we get an `Err`. 
473 | /// 474 | /// ``` 475 | /// let points = [[0.0]; 256]; 476 | /// 477 | /// // note that we are using `u8` as our index type 478 | /// let opt = captree::Capt::<1, 8, f32, u8>::try_new(&points, (0.0, f32::INFINITY)); 479 | /// 480 | /// assert!(opt.is_err()); 481 | /// ``` 482 | pub fn try_new(points: &[[A; K]], r_range: (A, A)) -> Result { 483 | let n2 = points.len().next_power_of_two(); 484 | 485 | if points.iter().any(|p| p.iter().any(|x| !x.is_finite())) { 486 | return Err(NewCaptError::NonFinite); 487 | } 488 | 489 | let mut tests = vec![A::INFINITY; n2 - 1].into_boxed_slice(); 490 | 491 | // hack: just pad with infinity to make it a power of 2 492 | let mut points2 = vec![[A::INFINITY; K]; n2].into_boxed_slice(); 493 | points2[..points.len()].copy_from_slice(points); 494 | // hack - reduce number of reallocations by allocating a lot of points from the start 495 | let mut afforded = array::from_fn(|_| Vec::with_capacity(n2 * 100)); 496 | let mut starts = vec![I::ZERO; n2 + 1].into_boxed_slice(); 497 | 498 | let mut aabbs = vec![ 499 | Aabb { 500 | lo: [A::NEG_INFINITY; K], 501 | hi: [A::INFINITY; K], 502 | }; 503 | n2 504 | ] 505 | .into_boxed_slice(); 506 | 507 | unsafe { 508 | // SAFETY: We tested that `points` contains no `NaN` values. 509 | Self::new_help( 510 | &mut points2, 511 | &mut tests, 512 | &mut aabbs, 513 | &mut afforded, 514 | &mut starts, 515 | 0, 516 | 0, 517 | r_range, 518 | Vec::new(), 519 | Aabb::ALL, 520 | )?; 521 | } 522 | 523 | Ok(Self { 524 | tests, 525 | starts, 526 | afforded: afforded.map(Vec::into_boxed_slice), 527 | aabbs, 528 | }) 529 | } 530 | 531 | #[allow(clippy::too_many_arguments, clippy::too_many_lines)] 532 | /// # Safety 533 | /// 534 | /// This function will contain undefined behavior if `points` contains any `NaN` values. 
535 | unsafe fn new_help( 536 | points: &mut [[A; K]], 537 | tests: &mut [A], 538 | aabbs: &mut [Aabb], 539 | afforded: &mut [Vec>; K], 540 | starts: &mut [I], 541 | k: usize, 542 | i: usize, 543 | r_range: (A, A), 544 | in_range: Vec<[A; K]>, 545 | cell: Aabb, 546 | ) -> Result<(), NewCaptError> { 547 | let rsq_min = r_range.0.square(); 548 | if let [rep] = *points { 549 | let z = i - tests.len(); 550 | let aabb = &mut aabbs[z]; 551 | *aabb = Aabb { lo: rep, hi: rep }; 552 | if rep[0].is_finite() { 553 | // lanes for afforded points 554 | let mut news = [[A::INFINITY; L]; K]; 555 | for k in 0..K { 556 | news[k][0] = rep[k]; 557 | } 558 | 559 | // index into the current lane 560 | let mut j = 1; 561 | 562 | // populate affordance buffer if the representative doesn't cover everything 563 | if !cell.contained_by_ball(&rep, rsq_min) { 564 | for ak in afforded.iter_mut() { 565 | ak.reserve(ak.len() + in_range.len() / L); 566 | } 567 | for p in in_range { 568 | aabb.insert(&p); 569 | 570 | // start a new lane if it's full 571 | if j == L { 572 | for k in 0..K { 573 | afforded[k].push(MySimd { 574 | data: news[k], 575 | _align: Align::NEW, 576 | }); 577 | } 578 | j = 0; 579 | } 580 | 581 | // add this point to the lane 582 | for k in 0..K { 583 | news[k][j] = p[k]; 584 | } 585 | 586 | j += 1; 587 | } 588 | } 589 | 590 | // fill out the last lane with infinities 591 | for k in 0..K { 592 | afforded[k].push(MySimd { 593 | data: news[k], 594 | _align: Align::NEW, 595 | }); 596 | } 597 | } 598 | 599 | starts[z + 1] = afforded[0] 600 | .len() 601 | .try_into() 602 | .map_err(|_| NewCaptError::TooManyPoints)?; 603 | return Ok(()); 604 | } 605 | 606 | let test = median_partition(points, k); 607 | tests[i] = test; 608 | 609 | let (lhs, rhs) = points.split_at_mut(points.len() / 2); 610 | let (lo_vol, hi_vol) = cell.split(test, k); 611 | 612 | let lo_too_small = distsq(lo_vol.lo, lo_vol.hi) <= rsq_min; 613 | let hi_too_small = distsq(hi_vol.lo, hi_vol.hi) <= rsq_min; 614 | 615 | 
// retain only points which might be in the affordance buffer for the split-out cells 616 | let (lo_afford, hi_afford) = match (lo_too_small, hi_too_small) { 617 | (false, false) => { 618 | let mut lo_afford = in_range; 619 | let mut hi_afford = lo_afford.clone(); 620 | lo_afford.retain(|pt| pt[k] <= test + r_range.1); 621 | lo_afford.extend(rhs.iter().filter(|pt| pt[k] <= test + r_range.1)); 622 | hi_afford.retain(|pt| pt[k].is_finite() && test - r_range.1 <= pt[k]); 623 | hi_afford.extend( 624 | lhs.iter() 625 | .filter(|pt| pt[k].is_finite() && test - r_range.1 <= pt[k]), 626 | ); 627 | 628 | (lo_afford, hi_afford) 629 | } 630 | (false, true) => { 631 | let mut lo_afford = in_range; 632 | lo_afford.retain(|pt| pt[k] <= test + r_range.1); 633 | lo_afford.extend(rhs.iter().filter(|pt| pt[k] <= test + r_range.1)); 634 | 635 | (lo_afford, Vec::new()) 636 | } 637 | (true, false) => { 638 | let mut hi_afford = in_range; 639 | hi_afford.retain(|pt| test - r_range.1 <= pt[k]); 640 | hi_afford.extend( 641 | lhs.iter() 642 | .filter(|pt| pt[k].is_finite() && test - r_range.1 <= pt[k]), 643 | ); 644 | 645 | (Vec::new(), hi_afford) 646 | } 647 | (true, true) => (Vec::new(), Vec::new()), 648 | }; 649 | 650 | let next_k = (k + 1) % K; 651 | Self::new_help( 652 | lhs, 653 | tests, 654 | aabbs, 655 | afforded, 656 | starts, 657 | next_k, 658 | 2 * i + 1, 659 | r_range, 660 | lo_afford, 661 | lo_vol, 662 | )?; 663 | Self::new_help( 664 | rhs, 665 | tests, 666 | aabbs, 667 | afforded, 668 | starts, 669 | next_k, 670 | 2 * i + 2, 671 | r_range, 672 | hi_afford, 673 | hi_vol, 674 | )?; 675 | 676 | Ok(()) 677 | } 678 | 679 | #[must_use] 680 | /// Determine whether a point in this tree is within a distance of `radius` to `center`. 681 | /// 682 | /// Note that this function will accept query radii outside of the range `r_range` passed to the 683 | /// construction for this CAPT in [`Capt::new`] or [`Capt::try_new`]. 
However, if the query 684 | /// radius is outside this range, the tree may erroneously return `false` (that is, erroneously 685 | /// report non-collision). 686 | /// 687 | /// # Examples 688 | /// 689 | /// ``` 690 | /// let points = [[0.0; 3], [1.0; 3], [0.1, 0.5, 1.0]]; 691 | /// let capt = captree::Capt::<3>::new(&points, (0.0, 1.0)); 692 | /// 693 | /// assert!(capt.collides(&[1.1; 3], 0.2)); 694 | /// assert!(!capt.collides(&[2.0; 3], 1.0)); 695 | /// 696 | /// // no guarantees about what this is, since the radius is greater than the construction range 697 | /// println!( 698 | /// "collision check result is {:?}", 699 | /// capt.collides(&[100.0; 3], 100.0) 700 | /// ); 701 | /// ``` 702 | pub fn collides(&self, center: &[A; K], radius: A) -> bool { 703 | // forward pass through the tree 704 | let mut test_idx = 0; 705 | let mut k = 0; 706 | for _ in 0..self.tests.len().trailing_ones() { 707 | test_idx = 2 * test_idx 708 | + 1 709 | + usize::from(unsafe { *self.tests.get_unchecked(test_idx) } <= center[k]); 710 | k = (k + 1) % K; 711 | } 712 | 713 | // retrieve affordance buffer location 714 | let rsq = radius.square(); 715 | let i = test_idx - self.tests.len(); 716 | let aabb = unsafe { self.aabbs.get_unchecked(i) }; 717 | if aabb.closest_distsq_to(center) > rsq { 718 | return false; 719 | } 720 | 721 | let mut range = unsafe { 722 | // SAFETY: The conversion worked the first way. 723 | self.starts[i].try_into().unwrap_unchecked() 724 | ..self.starts[i + 1].try_into().unwrap_unchecked() 725 | }; 726 | 727 | // check affordance buffer 728 | range.any(|i| { 729 | (0..L).any(|j| { 730 | let mut aff_pt = [A::INFINITY; K]; 731 | for (ak, sk) in aff_pt.iter_mut().zip(&self.afforded) { 732 | *ak = sk[i].data[j]; 733 | } 734 | distsq(aff_pt, *center) <= rsq 735 | }) 736 | }) 737 | } 738 | 739 | #[must_use] 740 | #[doc(hidden)] 741 | /// Get the total memory used (stack + heap) by this structure, measured in bytes. 
742 | /// This function should not be considered stable; it is only used internally for benchmarks. 743 | pub const fn memory_used(&self) -> usize { 744 | size_of::() 745 | + K * self.afforded[0].len() * size_of::() 746 | + self.starts.len() * size_of::() 747 | + self.tests.len() * size_of::() 748 | + self.aabbs.len() * size_of::>() 749 | } 750 | 751 | #[must_use] 752 | #[doc(hidden)] 753 | #[allow(clippy::cast_precision_loss)] 754 | /// Get the average number of affordances per point. 755 | /// This function should not be considered stable; it is only used internally for benchmarks. 756 | pub fn affordance_size(&self) -> f64 { 757 | self.afforded.len() as f64 / (self.tests.len() + 1) as f64 758 | } 759 | } 760 | 761 | #[allow(clippy::mismatching_type_param_order)] 762 | #[cfg(feature = "simd")] 763 | impl Capt 764 | where 765 | I: IndexSimd, 766 | A: Mul, 767 | Align: Alignment, 768 | { 769 | #[must_use] 770 | /// Determine whether any sphere in the list of provided spheres intersects a point in this 771 | /// tree. 
772 | /// 773 | /// # Examples 774 | /// 775 | /// ``` 776 | /// #![feature(portable_simd)] 777 | /// use std::simd::Simd; 778 | /// 779 | /// let points = [[1.0, 2.0], [1.1, 1.1]]; 780 | /// 781 | /// let centers = [ 782 | /// Simd::from_array([1.0, 1.1, 1.2, 1.3]), // x-positions 783 | /// Simd::from_array([1.0, 1.1, 1.2, 1.3]), // y-positions 784 | /// ]; 785 | /// let radii = Simd::splat(0.05); 786 | /// 787 | /// let tree = captree::Capt::<2, 4, f32, u32>::new(&points, (0.0, 0.1)); 788 | /// 789 | /// println!("{tree:?}"); 790 | /// 791 | /// assert!(tree.collides_simd(&centers, radii)); 792 | /// ``` 793 | pub fn collides_simd(&self, centers: &[Simd; K], radii: Simd) -> bool 794 | where 795 | LaneCount: SupportedLaneCount, 796 | Simd: 797 | SimdPartialOrd + Sub> + Mul> + AddAssign, 798 | Mask: From< as SimdPartialEq>::Mask>, 799 | A: Axis + AxisSimd< as SimdPartialEq>::Mask>, 800 | { 801 | let zs = forward_pass_simd(&self.tests, centers); 802 | 803 | let mut inbounds = Mask::splat(true); 804 | 805 | let mut aabb_ptrs = Simd::splat(self.aabbs.as_ptr()).wrapping_offset(zs).cast(); 806 | 807 | unsafe { 808 | for center in centers { 809 | inbounds &= Mask::::from( 810 | (Simd::gather_select_ptr(aabb_ptrs, inbounds, Simd::splat(A::NEG_INFINITY)) 811 | - radii) 812 | .simd_le(*center), 813 | ); 814 | aabb_ptrs = aabb_ptrs.wrapping_add(Simd::splat(1)); 815 | } 816 | for center in centers { 817 | inbounds &= Mask::::from( 818 | Simd::gather_select_ptr(aabb_ptrs, inbounds, Simd::splat(A::NEG_INFINITY)) 819 | .simd_ge(*center - radii), 820 | ); 821 | aabb_ptrs = aabb_ptrs.wrapping_add(Simd::splat(1)); 822 | } 823 | } 824 | if !inbounds.any() { 825 | return false; 826 | } 827 | 828 | // retrieve start/end pointers for the affordance buffer 829 | let start_ptrs = Simd::splat(self.starts.as_ptr()).wrapping_offset(zs); 830 | let starts = unsafe { I::to_simd_usize_unchecked(Simd::gather_ptr(start_ptrs)) }.to_array(); 831 | let ends = unsafe { 832 | 
I::to_simd_usize_unchecked(Simd::gather_ptr(start_ptrs.wrapping_add(Simd::splat(1)))) 833 | } 834 | .to_array(); 835 | 836 | starts 837 | .into_iter() 838 | .zip(ends) 839 | .zip(inbounds.to_array()) 840 | .enumerate() 841 | .filter_map(|(j, r)| r.1.then_some((j, r.0))) 842 | .any(|(j, (start, end))| { 843 | let mut n_center = [Simd::splat(A::ZERO); K]; 844 | for k in 0..K { 845 | n_center[k] = Simd::splat(centers[k][j]); 846 | } 847 | let rs = Simd::splat(radii[j]); 848 | let rs_sq = rs * rs; 849 | (start..end).any(|i| { 850 | let mut dists_sq = Simd::splat(A::ZERO); 851 | #[allow(clippy::needless_range_loop)] 852 | for k in 0..K { 853 | let vals: Simd = unsafe { 854 | *ptr::from_ref(&self.afforded[k].get_unchecked(i).data).cast() 855 | }; 856 | let diff = vals - n_center[k]; 857 | dists_sq += diff * diff; 858 | } 859 | A::any(dists_sq.simd_le(rs_sq)) 860 | }) 861 | }) 862 | } 863 | } 864 | 865 | fn distsq(a: [A; K], b: [A; K]) -> A { 866 | let mut total = A::ZERO; 867 | for i in 0..K { 868 | total = total + (a[i] - b[i]).square(); 869 | } 870 | total 871 | } 872 | 873 | impl Aabb 874 | where 875 | A: Axis, 876 | { 877 | const ALL: Self = Self { 878 | lo: [A::NEG_INFINITY; K], 879 | hi: [A::INFINITY; K], 880 | }; 881 | 882 | /// Split this volume by a test plane with value `test` along `dim`. 
883 | const fn split(mut self, test: A, dim: usize) -> (Self, Self) { 884 | let mut rhs = self; 885 | self.hi[dim] = test; 886 | rhs.lo[dim] = test; 887 | 888 | (self, rhs) 889 | } 890 | 891 | fn contained_by_ball(&self, center: &[A; K], rsq: A) -> bool { 892 | let mut dist = A::ZERO; 893 | 894 | #[allow(clippy::needless_range_loop)] 895 | for k in 0..K { 896 | let lo_diff = (self.lo[k] - center[k]).square(); 897 | let hi_diff = (self.hi[k] - center[k]).square(); 898 | 899 | dist = dist + if lo_diff < hi_diff { hi_diff } else { lo_diff }; 900 | } 901 | 902 | dist <= rsq 903 | } 904 | 905 | #[doc(hidden)] 906 | pub fn closest_distsq_to(&self, pt: &[A; K]) -> A { 907 | let mut dist = A::ZERO; 908 | 909 | #[allow(clippy::needless_range_loop)] 910 | for d in 0..K { 911 | let clamped = clamp(pt[d], self.lo[d], self.hi[d]); 912 | dist = dist + (pt[d] - clamped).square(); 913 | } 914 | 915 | dist 916 | } 917 | 918 | fn insert(&mut self, point: &[A; K]) { 919 | self.lo 920 | .iter_mut() 921 | .zip(&mut self.hi) 922 | .zip(point) 923 | .for_each(|((l, h), &x)| { 924 | if *l > x { 925 | *l = x; 926 | } 927 | if x > *h { 928 | *h = x; 929 | } 930 | }); 931 | } 932 | } 933 | 934 | #[inline] 935 | /// Calculate the "true" median (halfway between two midpoints) and partition `points` about said 936 | /// median along axis `k`. 937 | /// 938 | /// # Safety 939 | /// 940 | /// This function will result in undefined behavior if `points` contains any `NaN` values. 
941 | unsafe fn median_partition(points: &mut [[A; K]], k: usize) -> A { 942 | let (lh, med_hi, _) = points.select_nth_unstable_by(points.len() / 2, |a, b| { 943 | a[k].partial_cmp(&b[k]).unwrap_unchecked() 944 | }); 945 | let med_lo = lh 946 | .iter_mut() 947 | .map(|p| p[k]) 948 | .max_by(|a, b| a.partial_cmp(b).unwrap_unchecked()) 949 | .unwrap(); 950 | A::in_between(med_lo, med_hi[k]) 951 | } 952 | 953 | #[cfg(test)] 954 | mod tests { 955 | use rand::{thread_rng, Rng}; 956 | 957 | use super::*; 958 | 959 | #[test] 960 | fn build_simple() { 961 | let points = [[0.0, 0.1], [0.4, -0.2], [-0.2, -0.1]]; 962 | let t = Capt::<2>::new(&points, (0.0, 0.2)); 963 | println!("{t:?}"); 964 | } 965 | 966 | #[test] 967 | fn exact_query_single() { 968 | let points = [[0.0, 0.1], [0.4, -0.2], [-0.2, -0.1]]; 969 | let t = Capt::<2>::new(&points, (0.0, 0.2)); 970 | 971 | println!("{t:?}"); 972 | 973 | let q0 = [0.0, -0.01]; 974 | assert!(t.collides(&q0, 0.12)); 975 | } 976 | 977 | #[test] 978 | fn another_one() { 979 | let points = [[0.0, 0.1], [0.4, -0.2], [-0.2, -0.1]]; 980 | let t = Capt::<2>::new(&points, (0.0, 0.2)); 981 | 982 | println!("{t:?}"); 983 | 984 | let q0 = [0.003_265_380_9, 0.106_527_805]; 985 | assert!(t.collides(&q0, 0.02)); 986 | } 987 | 988 | #[test] 989 | fn three_d() { 990 | let points = [ 991 | [0.0; 3], 992 | [0.1, -1.1, 0.5], 993 | [-0.2, -0.3, 0.25], 994 | [0.1, -1.1, 0.5], 995 | ]; 996 | 997 | let t = Capt::<3>::new(&points, (0.0, 0.2)); 998 | 999 | println!("{t:?}"); 1000 | assert!(t.collides(&[0.0, 0.1, 0.0], 0.11)); 1001 | assert!(!t.collides(&[0.0, 0.1, 0.0], 0.05)); 1002 | } 1003 | 1004 | #[test] 1005 | fn fuzz() { 1006 | const R: f32 = 0.04; 1007 | let points = [[0.0, 0.1], [0.4, -0.2], [-0.2, -0.1]]; 1008 | let mut rng = thread_rng(); 1009 | let t = Capt::<2>::new(&points, (0.0, R)); 1010 | 1011 | for _ in 0..10_000 { 1012 | let p = [rng.gen_range(-1.0..1.0), rng.gen_range(-1.0..1.0)]; 1013 | let collides = points.iter().any(|a| distsq(*a, p) <= 
R * R); 1014 | println!("{p:?}; {collides}"); 1015 | assert_eq!(collides, t.collides(&p, R)); 1016 | } 1017 | } 1018 | 1019 | #[test] 1020 | /// This test _should_ fail, but it doesn't somehow? 1021 | fn weird_bounds() { 1022 | const R_SQ: f32 = 1.0; 1023 | let points = [ 1024 | [-1.0, 0.0], 1025 | [0.001, 0.0], 1026 | [0.0, 0.5], 1027 | [-1.0, 10.0], 1028 | [-2.0, 10.0], 1029 | [-3.0, 10.0], 1030 | [-0.5, 0.0], 1031 | [-11.0, 1.0], 1032 | [-1.0, -0.5], 1033 | [1.0, 1.0], 1034 | [2.0, 2.0], 1035 | [3.0, 3.0], 1036 | [4.0, 4.0], 1037 | [5.0, 5.0], 1038 | [6.0, 6.0], 1039 | [7.0, 7.0], 1040 | ]; 1041 | let rsq_range = (R_SQ - f32::EPSILON, R_SQ + f32::EPSILON); 1042 | let t = Capt::<2>::new(&points, rsq_range); 1043 | println!("{t:?}"); 1044 | 1045 | assert!(t.collides(&[-0.001, -0.2], 1.0)); 1046 | } 1047 | 1048 | #[test] 1049 | #[allow(clippy::float_cmp)] 1050 | fn does_it_partition() { 1051 | let mut points = vec![[1.0], [2.0], [1.5], [2.1], [-0.5]]; 1052 | let median = unsafe { median_partition(&mut points, 0) }; 1053 | assert_eq!(median, 1.25); 1054 | for p0 in &points[..points.len() / 2] { 1055 | assert!(p0[0] <= median); 1056 | } 1057 | 1058 | for p0 in &points[points.len() / 2..] { 1059 | assert!(p0[0] >= median); 1060 | } 1061 | } 1062 | } 1063 | -------------------------------------------------------------------------------- /morton_filter/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "morton_filter" 3 | version = "0.1.0" 4 | edition = "2021" 5 | -------------------------------------------------------------------------------- /morton_filter/src/lib.rs: -------------------------------------------------------------------------------- 1 | #![warn(clippy::pedantic)] 2 | 3 | //! A filtering algorithm for 3D point clouds. 4 | 5 | use std::ops::BitOr; 6 | 7 | /// Filter out `points` such that points within `min_sep` of each other may be removed. 
/// Thin out a 3D point cloud in place so that points closer than `min_sep` to a neighbor may
/// be removed.
///
/// Runs [`filter_permutation`] once for each of the six orderings of the three axes. Each pass
/// only compares points that end up next to each other along a Morton curve, so the filter is
/// approximate: some pairs closer than `min_sep` can survive.
pub fn morton_filter(points: &mut Vec<[f32; 3]>, min_sep: f32) {
    const PERMUTATIONS_3D: [[u8; 3]; 6] = [
        [0, 1, 2],
        [0, 2, 1],
        [1, 0, 2],
        [1, 2, 0],
        [2, 0, 1],
        [2, 1, 0],
    ];
    for permutation in PERMUTATIONS_3D {
        filter_permutation(points, min_sep, permutation);
    }
}

/// Run one filtering pass: sort `points` by Morton index, then drop every point that lies within
/// `min_sep` of the most recently kept point.
///
/// `perm` lists the axes in order of increasing bit significance within the Morton index:
/// `perm[0]` supplies the lowest bit of every 3-bit group.
///
/// The surviving points are left in Morton order. An empty `points` stays empty.
///
/// # Panics
///
/// Panics if any entry of `perm` is greater than `2`.
pub fn filter_permutation(points: &mut Vec<[f32; 3]>, min_sep: f32, perm: [u8; 3]) {
    let rsq = min_sep * min_sep;

    // Axis-aligned bounding box of the cloud; used to normalize coordinates before quantizing.
    let mut aabb_min = [f32::INFINITY; 3];
    let mut aabb_max = [f32::NEG_INFINITY; 3];
    for point in points.iter() {
        for ((lo, hi), &x) in aabb_min.iter_mut().zip(aabb_max.iter_mut()).zip(point) {
            if x < *lo {
                *lo = x;
            }
            if x > *hi {
                *hi = x;
            }
        }
    }

    points.sort_by_cached_key(|point| morton_index(point, &aabb_min, &aabb_max, perm));

    // `dedup_by` offers each candidate together with the last point that was kept and removes
    // the candidate when the closure returns `true`.
    points.dedup_by(|candidate, kept| {
        // Written as a negated `>` so that a NaN distance still removes the candidate.
        let far_enough = distsq(kept, candidate) > rsq;
        !far_enough
    });
}

/// Squared Euclidean distance between `a` and `b`.
fn distsq(a: &[f32; 3], b: &[f32; 3]) -> f32 {
    a.iter().zip(b).map(|(a, b)| (a - b).powi(2)).sum()
}

/// Compute the 30-bit Morton (Z-order) index of `point` inside the box `aabb_min..=aabb_max`.
///
/// Each coordinate is normalized to the box, quantized to one of `2^10` cells, and the three
/// cell numbers are bit-interleaved. `permutation[i]` names the axis whose bits land on
/// positions `i, i + 3, i + 6, ...` of the result.
///
/// A coordinate equal to the box maximum is clamped into the last cell. Without the clamp it
/// would quantize to `2^10`, whose only set bit does not fit in the 10 bits reserved per axis,
/// and the point would be filed under cell 0. An axis with zero (or non-positive/NaN) extent
/// always maps to cell 0.
#[allow(
    // The float-to-int conversion is the quantization step itself; the value is clamped after.
    clippy::cast_possible_truncation,
    clippy::cast_sign_loss,
    clippy::cast_precision_loss
)]
fn morton_index(
    point: &[f32; 3],
    aabb_min: &[f32; 3],
    aabb_max: &[f32; 3],
    permutation: [u8; 3],
) -> u32 {
    const WIDTH: u32 = u32::BITS / 3;
    const CELLS: u32 = 1 << WIDTH;
    const MASK: u32 = 0b001_001_001_001_001_001_001_001_001_001;

    permutation
        .iter()
        .enumerate()
        .map(|(i, &axis)| {
            let k = usize::from(axis);
            let extent = aabb_max[k] - aabb_min[k];
            let cell = if extent > 0.0 {
                let scaled = (point[k] - aabb_min[k]) / extent * CELLS as f32;
                // `as u32` saturates (negative and NaN become 0); cap the top end ourselves.
                (scaled as u32).min(CELLS - 1)
            } else {
                0
            };
            pdep(cell, MASK << i)
        })
        .fold(0, BitOr::bitor)
}

/// Parallel bit deposit: scatter the low bits of `a` onto the positions of the set bits of
/// `mask`, lowest first. Uses the BMI2 instruction when it is enabled at compile time.
#[cfg(all(target_arch = "x86_64", target_feature = "bmi2"))]
fn pdep(a: u32, mask: u32) -> u32 {
    // SAFETY: this function is only compiled when the `bmi2` target feature is enabled, so the
    // instruction is available.
    unsafe { core::arch::x86_64::_pdep_u32(a, mask) }
}

/// Parallel bit deposit: scatter the low bits of `a` onto the positions of the set bits of
/// `mask`, lowest first. Portable fallback for targets without BMI2.
#[cfg(not(all(target_arch = "x86_64", target_feature = "bmi2")))]
fn pdep(a: u32, mask: u32) -> u32 {
    let mut remaining = mask;
    let mut source = a;
    let mut out = 0;
    while remaining != 0 {
        let lowest = 1 << remaining.trailing_zeros();
        if source & 1 != 0 {
            out |= lowest;
        }
        source >>= 1;
        // Clear the mask bit that was just served.
        remaining &= remaining - 1;
    }
    out
}

#[cfg(test)]
mod tests {
    use super::{filter_permutation, morton_filter};

    #[test]
    fn one_point() {
        let mut points = vec![[0.0; 3]];
        morton_filter(&mut points, 0.01);
        assert_eq!(points, vec![[0.0; 3]]);
    }

    #[test]
    fn duplicate() {
        let mut points = vec![[0.0; 3]; 2];
        morton_filter(&mut points, 0.01);
        assert_eq!(points, vec![[0.0; 3]]);
    }

    #[test]
    fn too_close() {
        let mut points = vec![[0.0; 3], [0.001; 3]];
        morton_filter(&mut points, 0.01);
        assert_eq!(points, vec![[0.0; 3]]);
    }

    #[test]
    fn too_far() {
        let mut points = vec![[0.0; 3], [0.01; 3]];
        morton_filter(&mut points, 0.01);
        assert_eq!(points, vec![[0.0; 3], [0.01; 3]]);
    }

    /// The point at the bounding-box maximum must sort after everything else.
    #[test]
    fn maximum_corner_sorts_last() {
        let mut points = vec![[1.0; 3], [0.0; 3], [0.5; 3]];
        filter_permutation(&mut points, 0.0, [0, 1, 2]);
        assert_eq!(points, vec![[0.0; 3], [0.5; 3], [1.0; 3]]);
    }
}