├── .gitignore ├── atom3d ├── grid │ ├── __init__.py │ ├── octree_indexer.py │ └── cube_grid.py ├── apps │ ├── __init__.py │ ├── mesh_intersector.py │ ├── udf_query.py │ ├── visibility_query.py │ ├── sdf_query.py │ ├── flood_fill.py │ └── voxelizer.py ├── core │ ├── __init__.py │ ├── data_structures.py │ └── mesh_bvh.py ├── __init__.py └── kernels │ ├── bvh.py │ ├── __init__.py │ └── bvh_kernels.cu ├── LICENSE ├── setup.py ├── examples ├── udf_gradient.py └── basic_voxelization.py ├── README.md └── demo.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | atom3d/kernels/build/* 3 | atom3d.egg-info/* 4 | 5 | -------------------------------------------------------------------------------- /atom3d/grid/__init__.py: -------------------------------------------------------------------------------- 1 | """Grid module exports""" 2 | 3 | from .octree_indexer import OctreeIndexer 4 | from .cube_grid import CubeGrid 5 | 6 | __all__ = ['OctreeIndexer', 'CubeGrid'] 7 | -------------------------------------------------------------------------------- /atom3d/apps/__init__.py: -------------------------------------------------------------------------------- 1 | """Apps module exports""" 2 | 3 | from .voxelizer import Voxelizer 4 | from .mesh_intersector import MeshIntersector 5 | from .visibility_query import VisibilityQuery 6 | from .udf_query import UDFQuery 7 | from .sdf_query import SDFQuery 8 | from .flood_fill import FloodFill 9 | 10 | __all__ = [ 11 | "Voxelizer", 12 | "MeshIntersector", 13 | "VisibilityQuery", 14 | "UDFQuery", 15 | "SDFQuery", 16 | "FloodFill", 17 | ] 18 | -------------------------------------------------------------------------------- /atom3d/core/__init__.py: -------------------------------------------------------------------------------- 1 | """Core module exports""" 2 | 3 | from .mesh_bvh import MeshBVH 4 | from .data_structures import ( 5 | AABBIntersectResult, 6 | RayIntersectResult, 7 | SegmentIntersectResult, 8 | ClosestPointResult, 9 | TriangleIntersectResult, 10 | VoxelFaceMapping, 11 | VoxelPolygonMapping, 12 | VisibilityResult, 13 | ) 14 | 15 | __all__ = [ 16 | "MeshBVH", 17 | "AABBIntersectResult", 18 | "RayIntersectResult", 19 | "SegmentIntersectResult", 20 | "ClosestPointResult", 21 | "TriangleIntersectResult", 22 | "VoxelFaceMapping", 23 | "VoxelPolygonMapping", 24 | "VisibilityResult", 25 | ] 26 | -------------------------------------------------------------------------------- /atom3d/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Atom3D: Atomize Your 3D Meshes 3 | 4 | A high-performance CUDA library for mesh processing and representation 5 | to support 3D deep learning. 6 | """ 7 | 8 | __version__ = "0.1.0" 9 | __author__ = "Atom3D Contributors" 10 | 11 | from .core.mesh_bvh import MeshBVH 12 | from .core.data_structures import ( 13 | AABBIntersectResult, 14 | RayIntersectResult, 15 | SegmentIntersectResult, 16 | ClosestPointResult, 17 | TriangleIntersectResult, 18 | VoxelFaceMapping, 19 | VoxelPolygonMapping, 20 | VisibilityResult, 21 | ) 22 | from .grid.octree_indexer import OctreeIndexer 23 | from .grid.cube_grid import CubeGrid 24 | 25 | __all__ = [ 26 | # Core 27 | "MeshBVH", 28 | # Data structures 29 | "AABBIntersectResult", 30 | "RayIntersectResult", 31 | "SegmentIntersectResult", 32 | "ClosestPointResult", 33 | "TriangleIntersectResult", 34 | "VoxelFaceMapping", 35 | "VoxelPolygonMapping", 36 | "VisibilityResult", 37 | # Grid 38 | "OctreeIndexer", 39 | "CubeGrid", 40 | ] 41 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """ 2 | Atom3D: Atomize Your 3D Meshes 3 | 4 | Installation: 5 | pip install -e . 6 | """ 7 | 8 | from setuptools import setup, find_packages 9 | 10 | setup( 11 | name="atom3d", 12 | version="0.1.0", 13 | author="Atom3D Contributors", 14 | description="Atomize your 3D meshes - High-performance CUDA mesh voxelization and distance field queries", 15 | long_description=open("README.md").read(), 16 | long_description_content_type="text/markdown", 17 | url="https://github.com/your-org/Atom3D", 18 | packages=find_packages(), 19 | python_requires=">=3.8", 20 | install_requires=[ 21 | "torch>=2.0.0", 22 | "numpy>=1.20.0", 23 | ], 24 | extras_require={ 25 | "dev": [ 26 | "pytest>=6.0", 27 | "trimesh>=3.0", 28 | ], 29 | "full": [ 30 | "trimesh>=3.0", 31 | ], 32 | }, 33 | classifiers=[ 34 | "Development Status :: 4 - Beta", 35 | "Intended Audience :: Science/Research", 36 | "License :: OSI Approved :: MIT License", 37 | "Programming Language :: Python :: 3", 38 | "Programming Language :: Python :: 3.8", 39 | "Programming Language :: Python :: 3.9", 40 | "Programming Language :: Python :: 3.10", 41 | "Programming Language :: Python :: 3.11", 42 | "Topic :: Scientific/Engineering :: Visualization", 43 | "Topic :: Scientific/Engineering :: Image Processing", 44 | ], 45 | include_package_data=True, 46 | ) 47 | -------------------------------------------------------------------------------- /atom3d/apps/mesh_intersector.py: -------------------------------------------------------------------------------- 1 | """ 2 | MeshIntersector: Mesh-mesh collision detection application 3 | """ 4 | 5 | import torch 6 | 7 | from ..core.mesh_bvh import MeshBVH 8 | from ..core.data_structures import TriangleIntersectResult 9 | 10 | 11 | class MeshIntersector: 12 | """ 13 | 网格碰撞检测应用 14 | 15 | = MeshBVH.intersect_triangles 的封装 16 | 17 | 用于网格自相交检测、多网格碰撞检测 18 | 19 | Args: 20 | bvh: MeshBVH实例 21 | """ 22 | 23 | def __init__(self, bvh: MeshBVH): 24 | self.bvh = bvh 25 | 26 | def check_self_intersection( 27 | self, 28 | skip_adjacent: bool = True 29 | ) -> TriangleIntersectResult: 30 | """ 31 | 检测网格自相交 32 | 33 | Args: 34 | skip_adjacent: 是否跳过相邻面(共享顶点) 35 | 36 | Returns: 37 | result: TriangleIntersectResult 38 | """ 39 | # Use the mesh against itself 40 | result = self.bvh.intersect_triangles( 41 | self.bvh.vertices, 42 | self.bvh.faces 43 | ) 44 | 45 | if skip_adjacent: 46 | # Filter out adjacent face collisions 47 | # (faces that share vertices) 48 | result = self._filter_adjacent(result) 49 | 50 | return result 51 | 52 | def _filter_adjacent( 53 | self, 54 | result: TriangleIntersectResult 55 | ) -> TriangleIntersectResult: 56 | """Filter out collisions from adjacent faces""" 57 | if result.hit_points.shape[0] == 0: 58 | return result 59 | 60 | # Get edge face indices 61 | edge_face_ids = result.hit_edge_ids // 3 # Each face has 3 edges 62 | hit_face_ids = result.hit_face_ids 63 | 64 | # Check if faces share vertices 65 | valid_mask = torch.ones(result.hit_points.shape[0], dtype=torch.bool, device=self.bvh.device) 66 | 67 | for i in range(result.hit_points.shape[0]): 68 | face1 = self.bvh.faces[hit_face_ids[i]] 69 | face2 = self.bvh.faces[edge_face_ids[i]] 70 | 71 | # Check for shared vertices 72 | shared = (face1.unsqueeze(1) == face2.unsqueeze(0)).any() 73 | if shared: 74 | valid_mask[i] = False 75 | 76 | return TriangleIntersectResult( 77 | edge_hit=result.edge_hit, # Keep original 78 | hit_points=result.hit_points[valid_mask], 79 | hit_face_ids=result.hit_face_ids[valid_mask], 80 | hit_edge_ids=result.hit_edge_ids[valid_mask] 81 | ) 82 | 83 | def intersect_with_mesh( 84 | self, 85 | other_vertices: torch.Tensor, 86 | other_faces: torch.Tensor 87 | ) -> TriangleIntersectResult: 88 | """ 89 | 与另一个网格碰撞检测 90 | 91 | Args: 92 | other_vertices: [M, 3] 93 | other_faces: [K, 3] 94 | 95 | Returns: 96 | result: TriangleIntersectResult 97 | """ 98 | return self.bvh.intersect_triangles(other_vertices, other_faces) 99 | -------------------------------------------------------------------------------- /examples/udf_gradient.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | UDF/SDF Query Example 4 | 5 | Demonstrates distance field queries with gradient support. 6 | """ 7 | 8 | import torch 9 | import time 10 | 11 | import sys 12 | from pathlib import Path 13 | sys.path.insert(0, str(Path(__file__).parent.parent)) 14 | 15 | from atom3d import MeshBVH 16 | 17 | 18 | def create_sphere_mesh(): 19 | """Create icosphere mesh.""" 20 | try: 21 | import trimesh 22 | mesh = trimesh.creation.icosphere(subdivisions=3, radius=0.5) 23 | return ( 24 | torch.tensor(mesh.vertices, dtype=torch.float32), 25 | torch.tensor(mesh.faces, dtype=torch.int64) 26 | ) 27 | except ImportError: 28 | raise RuntimeError("trimesh required: pip install trimesh") 29 | 30 | 31 | def main(): 32 | print("=" * 50) 33 | print("Atom3D: UDF/SDF Query Example") 34 | print("=" * 50) 35 | 36 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 37 | 38 | # Create sphere mesh 39 | vertices, faces = create_sphere_mesh() 40 | vertices = vertices.to(device) 41 | faces = faces.to(device) 42 | print(f"Mesh: sphere with {len(faces)} faces, radius=0.5") 43 | 44 | # Create BVH 45 | bvh = MeshBVH(vertices, faces, device=device) 46 | 47 | # Generate random query points 48 | num_points = 10000 49 | points = torch.randn(num_points, 3, device=device) 50 | points = points / points.norm(dim=1, keepdim=True) * torch.rand(num_points, 1, device=device) * 2 51 | 52 | print(f"\n--- UDF Query (no gradient) ---") 53 | t0 = time.time() 54 | result = bvh.udf(points, return_grad=False) 55 | print(f"Time: {time.time()-t0:.4f}s") 56 | print(f"Distance range: [{result.distances.min():.4f}, {result.distances.max():.4f}]") 57 | 58 | print(f"\n--- UDF Query (with gradient) ---") 59 | points_grad = points.clone().requires_grad_(True) 60 | 61 | t0 = time.time() 62 | result = bvh.udf(points_grad, return_grad=True) 63 | print(f"Time: {time.time()-t0:.4f}s") 64 | 65 | # Backprop 66 | loss = result.distances.mean() 67 | loss.backward() 68 | 69 | gradients = points_grad.grad 70 | print(f"Gradient shape: {gradients.shape}") 71 | print(f"Gradient norm mean: {gradients.norm(dim=1).mean():.4f}") 72 | 73 | # Verify gradient 74 | print(f"\n--- Gradient Verification ---") 75 | closest = result.closest_points 76 | expected_grad = (points_grad.detach() - closest) 77 | expected_grad = expected_grad / (expected_grad.norm(dim=1, keepdim=True) + 1e-8) 78 | 79 | dot = (gradients * expected_grad).sum(dim=1) 80 | print(f"Gradient · Expected: {dot.mean():.6f} (should be ~1.0)") 81 | 82 | print(f"\n--- SDF Query ---") 83 | sdf_result = bvh.sdf(points, return_grad=False) 84 | 85 | inside = (sdf_result.distances < 0).sum() 86 | outside = (sdf_result.distances >= 0).sum() 87 | print(f"Inside mesh: {inside}, Outside: {outside}") 88 | print(f"SDF range: [{sdf_result.distances.min():.4f}, {sdf_result.distances.max():.4f}]") 89 | 90 | # Application: move points towards surface 91 | print(f"\n--- Application: Move towards surface ---") 92 | 93 | with torch.no_grad(): 94 | step_size = 0.1 95 | direction = (points - result.closest_points) 96 | direction = direction / (direction.norm(dim=1, keepdim=True) + 1e-8) 97 | 98 | new_points = points - step_size * direction * result.distances[:, None].sign() 99 | new_result = bvh.udf(new_points) 100 | 101 | print(f"Before: mean distance = {result.distances.abs().mean():.4f}") 102 | print(f"After: mean distance = {new_result.distances.abs().mean():.4f}") 103 | 104 | print("\n" + "=" * 50) 105 | print("Done!") 106 | print("=" * 50) 107 | 108 | 109 | if __name__ == "__main__": 110 | main() 111 | -------------------------------------------------------------------------------- /examples/basic_voxelization.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Basic Mesh Voxelization Example 4 | 5 | Demonstrates octree-accelerated mesh voxelization with SAT intersection. 6 | """ 7 | 8 | import torch 9 | import time 10 | 11 | import sys 12 | from pathlib import Path 13 | sys.path.insert(0, str(Path(__file__).parent.parent)) 14 | 15 | from atom3d import MeshBVH 16 | from atom3d.grid import OctreeIndexer 17 | 18 | 19 | def create_test_mesh(): 20 | """Create a simple icosphere mesh.""" 21 | try: 22 | import trimesh 23 | mesh = trimesh.creation.icosphere(subdivisions=3, radius=0.8) 24 | vertices = torch.tensor(mesh.vertices, dtype=torch.float32) 25 | faces = torch.tensor(mesh.faces, dtype=torch.int64) 26 | return vertices, faces 27 | except ImportError: 28 | raise RuntimeError("trimesh required: pip install trimesh") 29 | 30 | 31 | def main(): 32 | print("=" * 50) 33 | print("Atom3D: Basic Voxelization Example") 34 | print("=" * 50) 35 | 36 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 37 | print(f"Device: {device}") 38 | 39 | # Load mesh 40 | vertices, faces = create_test_mesh() 41 | vertices = vertices.to(device) 42 | faces = faces.to(device) 43 | print(f"Mesh: {len(vertices)} vertices, {len(faces)} faces") 44 | 45 | # Create BVH 46 | print("\n[1] Building BVH...") 47 | t0 = time.time() 48 | bvh = MeshBVH(vertices, faces, device=device) 49 | print(f" Time: {time.time()-t0:.3f}s") 50 | print(f" Bounds: {bvh.get_bounds()[0].tolist()}") 51 | 52 | # Create octree 53 | max_level = 7 # 128^3 resolution 54 | print(f"\n[2] Creating octree (level {max_level}, res {2**max_level})...") 55 | octree = OctreeIndexer(max_level=max_level, device=device) 56 | 57 | # Octree traversal using BVH-accelerated broadphase 58 | print("\n[3] Octree traversal (BVH-accelerated broadphase)...") 59 | t0 = time.time() 60 | 61 | min_level = 4 62 | candidates = octree.octree_traverse(bvh, min_level=min_level) 63 | 64 | print(f" Candidates at level {max_level}: {len(candidates)}") 65 | print(f" Broadphase time: {time.time()-t0:.3f}s") 66 | 67 | # SAT intersection (narrowphase) 68 | print("\n[4] SAT intersection (narrowphase)...") 69 | t0 = time.time() 70 | 71 | voxel_min, voxel_max = octree.cube_aabb_level(candidates, max_level) 72 | result = bvh.intersect_aabb(voxel_min, voxel_max, mode=1) 73 | 74 | surface_voxels = candidates[result.hit] 75 | print(f" Surface voxels: {len(surface_voxels)}") 76 | print(f" SAT time: {time.time()-t0:.3f}s") 77 | print(f" Reduction: {len(candidates)} -> {len(surface_voxels)} ({100*(1-len(surface_voxels)/len(candidates)):.1f}%)") 78 | 79 | # UDF at voxel corners 80 | print("\n[5] UDF at voxel corners...") 81 | t0 = time.time() 82 | 83 | corners = octree.cube_corner_coords_level(surface_voxels[:100], max_level) # First 100 84 | corner_points = corners.reshape(-1, 3) 85 | 86 | udf_distances = bvh.udf(corner_points) # Returns tensor directly when no extras requested 87 | print(f" UDF range: [{udf_distances.min():.4f}, {udf_distances.max():.4f}]") 88 | print(f" UDF time: {time.time()-t0:.3f}s") 89 | 90 | # Polygon clipping (mode 2) 91 | print("\n[6] Polygon clipping (mode=2)...") 92 | t0 = time.time() 93 | 94 | clip_result = bvh.intersect_aabb(voxel_min[:100], voxel_max[:100], mode=2) 95 | if hasattr(clip_result, 'centroids'): 96 | print(f" Centroids shape: {clip_result.centroids.shape}") 97 | print(f" Areas range: [{clip_result.areas.min():.6f}, {clip_result.areas.max():.6f}]") 98 | print(f" Clip time: {time.time()-t0:.3f}s") 99 | 100 | print("\n" + "=" * 50) 101 | print("Done!") 102 | print("=" * 50) 103 | 104 | 105 | if __name__ == "__main__": 106 | main() 107 | -------------------------------------------------------------------------------- /atom3d/apps/udf_query.py: -------------------------------------------------------------------------------- 1 | """ 2 | UDFQuery: Unsigned Distance Field query with gradient support. 3 | """ 4 | 5 | from typing import Optional 6 | import torch 7 | 8 | from ..core.mesh_bvh import MeshBVH 9 | from ..core.data_structures import ClosestPointResult 10 | 11 | 12 | class UDFQuery: 13 | """ 14 | Unsigned Distance Field query application. 15 | 16 | Wraps MeshBVH.query_closest_point with: 17 | - Gradient support via autograd 18 | - Batch processing to avoid OOM 19 | 20 | Args: 21 | bvh: MeshBVH instance 22 | """ 23 | 24 | def __init__(self, bvh: MeshBVH): 25 | self.bvh = bvh 26 | 27 | def query( 28 | self, 29 | points: torch.Tensor, 30 | compute_grad: bool = False, 31 | batch_size: Optional[int] = None 32 | ) -> ClosestPointResult: 33 | """ 34 | Query unsigned distance field. 35 | 36 | Args: 37 | points: [N, 3] query points 38 | If compute_grad=True, should have requires_grad=True 39 | compute_grad: Whether to enable gradient computation 40 | batch_size: Batch size (None = process all at once) 41 | 42 | Returns: 43 | result: ClosestPointResult 44 | - distances: [N] float32 45 | - face_ids: [N] int32 46 | - closest_points: [N, 3] 47 | - uvw: [N, 3] 48 | """ 49 | if batch_size is not None and points.shape[0] > batch_size: 50 | return self._query_batched(points, compute_grad, batch_size) 51 | 52 | if compute_grad: 53 | return self._query_with_grad(points) 54 | else: 55 | return self.bvh.query_closest_point(points, return_uvw=True) 56 | 57 | def _query_batched( 58 | self, 59 | points: torch.Tensor, 60 | compute_grad: bool, 61 | batch_size: int 62 | ) -> ClosestPointResult: 63 | """Batched query to avoid OOM.""" 64 | N = points.shape[0] 65 | 66 | all_distances = [] 67 | all_face_ids = [] 68 | all_closest_points = [] 69 | all_uvw = [] 70 | 71 | for i in range(0, N, batch_size): 72 | batch_points = points[i:i+batch_size] 73 | 74 | if compute_grad: 75 | result = self._query_with_grad(batch_points) 76 | else: 77 | result = self.bvh.query_closest_point(batch_points, return_uvw=True) 78 | 79 | all_distances.append(result.distances) 80 | all_face_ids.append(result.face_ids) 81 | all_closest_points.append(result.closest_points) 82 | if result.uvw is not None: 83 | all_uvw.append(result.uvw) 84 | 85 | return ClosestPointResult( 86 | distances=torch.cat(all_distances), 87 | face_ids=torch.cat(all_face_ids), 88 | closest_points=torch.cat(all_closest_points), 89 | uvw=torch.cat(all_uvw) if all_uvw else None 90 | ) 91 | 92 | def _query_with_grad(self, points: torch.Tensor) -> ClosestPointResult: 93 | """ 94 | Query with gradient computation. 95 | 96 | Gradient: d(distance)/d(point) = (point - closest_point) / distance 97 | """ 98 | # Get closest points (no grad) 99 | with torch.no_grad(): 100 | result = self.bvh.query_closest_point(points, return_uvw=True) 101 | 102 | # Compute distances with gradient 103 | closest_points = result.closest_points.detach() 104 | diff = points - closest_points 105 | distances = diff.norm(dim=1) 106 | 107 | # If input requires grad, the output distances will have grad_fn 108 | return ClosestPointResult( 109 | distances=distances, 110 | face_ids=result.face_ids, 111 | closest_points=closest_points, 112 | uvw=result.uvw 113 | ) 114 | -------------------------------------------------------------------------------- /atom3d/core/data_structures.py: -------------------------------------------------------------------------------- 1 | """ 2 | Data structures for cuMTV 3 | """ 4 | 5 | from dataclasses import dataclass 6 | from typing import Optional 7 | import torch 8 | 9 | 10 | @dataclass 11 | class AABBIntersectResult: 12 | """AABB intersection result""" 13 | hit: torch.Tensor # [N] bool whether each AABB has collision 14 | aabb_ids: Optional[torch.Tensor] = None # [total_hits] int32 colliding AABB indices 15 | face_ids: Optional[torch.Tensor] = None # [total_hits] int32 colliding face indices 16 | centroids: Optional[torch.Tensor] = None # [total_hits, 3] float32 clipped polygon centroids (mode>=2) 17 | areas: Optional[torch.Tensor] = None # [total_hits] float32 clipped polygon areas (mode>=2) 18 | poly_verts: Optional[torch.Tensor] = None # [total_hits, 8, 3] float32 clipped polygon vertices (mode==3) 19 | poly_counts: Optional[torch.Tensor] = None # [total_hits] int32 polygon vertex counts (mode>=2) 20 | 21 | 22 | @dataclass 23 | class RayIntersectResult: 24 | """Ray intersection result""" 25 | hit: torch.Tensor # [N] bool 26 | t: torch.Tensor # [N] float32 (miss=inf) 27 | face_ids: torch.Tensor # [N] int32 (miss=-1) 28 | hit_points: torch.Tensor # [N, 3] 29 | normals: torch.Tensor # [N, 3] 30 | bary_coords: torch.Tensor # [N, 3] 31 | 32 | 33 | @dataclass 34 | class SegmentIntersectResult: 35 | """Segment intersection result""" 36 | hit: torch.Tensor # [N] bool 37 | hit_points: torch.Tensor # [N, 3] or [total, 3] 38 | face_ids: torch.Tensor # [N] or [total] int32 39 | bary_coords: torch.Tensor # [N, 3] or [total, 3] 40 | segment_ids: Optional[torch.Tensor] = None # [total] (if return_all=True) 41 | 42 | 43 | @dataclass 44 | class ClosestPointResult: 45 | """Closest point query result (UDF)""" 46 | distances: torch.Tensor # [N] float32 unsigned distance 47 | face_ids: torch.Tensor # [N] int32 closest face 48 | closest_points: torch.Tensor # [N, 3] 49 | uvw: Optional[torch.Tensor] = None # [N, 3] barycentric coordinates 50 | 51 | 52 | @dataclass 53 | class TriangleIntersectResult: 54 | """Triangle-triangle intersection result""" 55 | edge_hit: torch.Tensor # [num_edges] bool whether each edge intersects 56 | hit_points: torch.Tensor # [num_hits, 3] intersection point coordinates 57 | hit_face_ids: torch.Tensor # [num_hits] int32 hit faces in this mesh 58 | hit_edge_ids: torch.Tensor # [num_hits] int32 hit edges in other mesh 59 | 60 | 61 | @dataclass 62 | class VoxelFaceMapping: 63 | """Voxel-face mapping (CSR sparse format)""" 64 | voxel_coords: torch.Tensor # [K, 3] int32 65 | face_indices: torch.Tensor # [total] int32 66 | face_start: torch.Tensor # [K] int32 67 | face_count: torch.Tensor # [K] int32 68 | 69 | def get_faces_for_voxel(self, voxel_idx: int) -> torch.Tensor: 70 | """Get all faces intersecting the specified voxel""" 71 | start = self.face_start[voxel_idx].item() 72 | count = self.face_count[voxel_idx].item() 73 | return self.face_indices[start:start+count] 74 | 75 | 76 | @dataclass 77 | class VoxelPolygonMapping: 78 | """Voxel-polygon mapping (exact intersection region)""" 79 | voxel_coords: torch.Tensor # [K, 3] int32 80 | polygons: torch.Tensor # [total, max_verts, 3] float32 81 | polygon_counts: torch.Tensor # [total] int32 vertex count per polygon 82 | face_indices: torch.Tensor # [total] int32 83 | voxel_ids: torch.Tensor # [total] int32 84 | 85 | def get_polygon(self, idx: int) -> torch.Tensor: 86 | """Get intersection polygon at specified index""" 87 | count = self.polygon_counts[idx].item() 88 | return self.polygons[idx, :count] 89 | 90 | 91 | @dataclass 92 | class VisibilityResult: 93 | """Visibility query result""" 94 | visibility: torch.Tensor # [N] float32 visibility probability [0, 1] 95 | visible_mask: Optional[torch.Tensor] = None # [N, M] bool visibility of each point from each viewpoint 96 | hit_distances: Optional[torch.Tensor] = None # [N, M] float32 occlusion distance 97 | -------------------------------------------------------------------------------- /atom3d/apps/visibility_query.py: -------------------------------------------------------------------------------- 1 | """ 2 | VisibilityQuery: Visibility query application with statistical probability 3 | """ 4 | 5 | from typing import Union, Optional 6 | import torch 7 | 8 | from ..core.mesh_bvh import MeshBVH 9 | from ..core.data_structures import VisibilityResult 10 | 11 | 12 | class VisibilityQuery: 13 | """ 14 | 可见性查询应用 15 | 16 | = MeshBVH.intersect_ray + 统计概率 17 | 18 | 对任意query point查询其可见性,用多视角射线的统计概率表达 19 | 20 | Args: 21 | bvh: MeshBVH实例 22 | """ 23 | 24 | def __init__(self, bvh: MeshBVH): 25 | self.bvh = bvh 26 | 27 | def query( 28 | self, 29 | points: torch.Tensor, 30 | view_directions: torch.Tensor, 31 | return_details: bool = False 32 | ) -> Union[torch.Tensor, VisibilityResult]: 33 | """ 34 | 查询点的可见性(统计概率) 35 | 36 | 对任意query point,从多个视角方向发射射线检测遮挡 37 | 38 | Args: 39 | points: [N, 3] 任意查询点 40 | view_directions: [M, 3] 多个视角方向(归一化) 41 | return_details: 是否返回详细信息 42 | 43 | Returns: 44 | 如果return_details=False: 45 | visibility: [N] float32 可见性概率 [0, 1] 46 | 如果return_details=True: 47 | result: VisibilityResult 48 | """ 49 | N = points.shape[0] 50 | M = view_directions.shape[0] 51 | device = points.device 52 | 53 | visible_mask = torch.zeros(N, M, dtype=torch.bool, device=device) 54 | hit_distances = torch.zeros(N, M, device=device) 55 | 56 | for j in range(M): 57 | direction = view_directions[j] 58 | 59 | # All points use the same direction 60 | rays_o = points 61 | rays_d = direction.expand_as(points) 62 | 63 | result = self.bvh.intersect_ray(rays_o, rays_d) 64 | 65 | # Not hit = visible 66 | visible_mask[:, j] = ~result.hit 67 | hit_distances[:, j] = result.t 68 | 69 | # Compute visibility probability 70 | visibility = visible_mask.float().mean(dim=1) 71 | 72 | if return_details: 73 | return VisibilityResult( 74 | visibility=visibility, 75 | visible_mask=visible_mask, 76 | hit_distances=hit_distances 77 | ) 78 | else: 79 | return visibility 80 | 81 | def query_from_cameras( 82 | self, 83 | points: torch.Tensor, 84 | camera_positions: torch.Tensor 85 | ) -> torch.Tensor: 86 | """ 87 | 从相机位置查询可见性 88 | 89 | Args: 90 | points: [N, 3] 查询点 91 | camera_positions: [M, 3] 相机位置 92 | 93 | Returns: 94 | visibility: [N] float32 可见性概率 95 | = 能看到该点的相机比例 96 | """ 97 | N = points.shape[0] 98 | M = camera_positions.shape[0] 99 | device = points.device 100 | 101 | visible_count = torch.zeros(N, device=device) 102 | 103 | for cam_pos in camera_positions: 104 | # 从点向相机发射射线 105 | rays_o = points 106 | rays_d = cam_pos - points 107 | dists = rays_d.norm(dim=1, keepdim=True) 108 | rays_d = rays_d / (dists + 1e-8) 109 | 110 | # 检测遮挡 111 | result = self.bvh.intersect_ray(rays_o, rays_d, max_t=dists.squeeze().max().item()) 112 | 113 | # 如果未击中或击中距离 >= 到相机距离,则可见 114 | visible = ~result.hit | (result.t >= dists.squeeze() - 1e-4) 115 | visible_count += visible.float() 116 | 117 | return visible_count / M 118 | 119 | def query_uniform_sphere( 120 | self, 121 | points: torch.Tensor, 122 | num_samples: int = 32, 123 | seed: Optional[int] = None 124 | ) -> torch.Tensor: 125 | """ 126 | 均匀球面采样查询可见性 127 | 128 | Args: 129 | points: [N, 3] 查询点 130 | num_samples: int 球面采样数量 131 | seed: 随机种子(用于可重复性) 132 | 133 | Returns: 134 | visibility: [N] float32 可见性概率 135 | = 均匀球面上未被遮挡的方向比例 136 | 137 | 用途: 138 | - 无特定视角时的通用可见性度量 139 | - 可用于SDF符号判断(内部点可见性低) 140 | """ 141 | device = points.device 142 | 143 | if seed is not None: 144 | torch.manual_seed(seed) 145 | 146 | # Generate uniform sphere directions using Fibonacci spiral 147 | directions = self._fibonacci_sphere(num_samples, device) 148 | 149 | return self.query(points, directions, return_details=False) 150 | 151 | def _fibonacci_sphere(self, n: int, device: str) -> torch.Tensor: 152 | """Generate n uniformly distributed points on sphere""" 153 | indices = torch.arange(n, dtype=torch.float32, device=device) 154 | 155 | phi = torch.acos(1 - 2 * (indices + 0.5) / n) 156 | theta = torch.pi * (1 + 5**0.5) * indices 157 | 158 | x = torch.sin(phi) * torch.cos(theta) 159 | y = torch.sin(phi) * torch.sin(theta) 160 | z = torch.cos(phi) 161 | 162 | return torch.stack([x, y, z], dim=1) 163 | -------------------------------------------------------------------------------- /atom3d/kernels/bvh.py: -------------------------------------------------------------------------------- 1 | """ 2 | BVH Accelerator for Atom3D 3 | 4 | Python interface to BVH CUDA kernels. Provides accelerated: 5 | - UDF queries (closest point to mesh) 6 | - Ray-mesh intersection 7 | - AABB-mesh intersection (with exact SAT) 8 | """ 9 | 10 | import torch 11 | from torch.utils.cpp_extension import load 12 | import os 13 | 14 | # Get the CUDA kernel source 15 | _KERNEL_DIR = os.path.dirname(os.path.abspath(__file__)) 16 | _BUILD_DIR = os.path.join(_KERNEL_DIR, 'build', 'bvh') 17 | 18 | # Global cache for JIT compiled module 19 | _bvh_cuda = None 20 | 21 | def get_bvh_kernels(): 22 | """Get or compile the BVH CUDA kernels.""" 23 | global _bvh_cuda 24 | 25 | if _bvh_cuda is not None: 26 | return _bvh_cuda 27 | 28 | os.makedirs(_BUILD_DIR, exist_ok=True) 29 | 30 | kernel_path = os.path.join(_KERNEL_DIR, 'bvh_kernels.cu') 31 | 32 | _bvh_cuda = load( 33 | name='bvh_cuda', 34 | sources=[kernel_path], 35 | build_directory=_BUILD_DIR, 36 | extra_cuda_cflags=['-O3'], 37 | verbose=False 38 | ) 39 | 40 | return _bvh_cuda 41 | 42 | 43 | class BVHAccelerator: 44 | """ 45 | BVH-accelerated mesh queries. 46 | 47 | Provides O(log M) queries instead of O(M) brute-force: 48 | - udf: closest point to mesh 49 | - ray_intersect: ray-mesh intersection 50 | - aabb_intersect: AABB-mesh intersection with exact SAT 51 | 52 | All returned face_ids are in ORIGINAL mesh order (not BVH reordered order). 53 | """ 54 | 55 | def __init__( 56 | self, 57 | vertices: torch.Tensor, 58 | faces: torch.Tensor, 59 | n_primitives_per_leaf: int = 8 60 | ): 61 | """ 62 | Build BVH from mesh. 63 | 64 | Args: 65 | vertices: [N, 3] float32 vertices 66 | faces: [M, 3] int32 face indices 67 | n_primitives_per_leaf: Max triangles per leaf node 68 | """ 69 | self.device = vertices.device 70 | self.vertices = vertices.contiguous().float() 71 | self.faces = faces.contiguous().int() 72 | self.num_faces = faces.shape[0] 73 | 74 | # Build BVH - returns (nodes, triangles with original_id) 75 | cuda = get_bvh_kernels() 76 | result = cuda.build_bvh( 77 | self.vertices, 78 | self.faces, 79 | n_primitives_per_leaf 80 | ) 81 | self.nodes = result[0] # [num_nodes, 9] 82 | self.triangles = result[1] # [num_faces, 10] - includes original_id 83 | 84 | def udf( 85 | self, 86 | points: torch.Tensor 87 | ): 88 | """ 89 | Unsigned distance field query. 90 | 91 | Args: 92 | points: [K, 3] query points 93 | 94 | Returns: 95 | distances: [K] unsigned distances 96 | face_ids: [K] closest face indices (ORIGINAL order) 97 | closest_points: [K, 3] closest points on mesh 98 | uvw: [K, 3] barycentric coordinates 99 | """ 100 | cuda = get_bvh_kernels() 101 | distances, face_ids, closest_points, uvw = cuda.bvh_udf( 102 | self.nodes, 103 | self.triangles, 104 | points.contiguous().float() 105 | ) 106 | return distances, face_ids, closest_points, uvw 107 | 108 | def ray_intersect( 109 | self, 110 | rays_o: torch.Tensor, 111 | rays_d: torch.Tensor, 112 | max_t: float = 1e10 113 | ): 114 | """ 115 | Ray-mesh intersection. 116 | 117 | Args: 118 | rays_o: [K, 3] ray origins 119 | rays_d: [K, 3] ray directions 120 | max_t: Maximum ray distance 121 | 122 | Returns: 123 | hit_mask: [K] bool - whether ray hit mesh 124 | hit_t: [K] hit distance (max_t if no hit) 125 | face_ids: [K] hit face indices (ORIGINAL order, -1 if no hit) 126 | hit_points: [K, 3] hit positions 127 | """ 128 | cuda = get_bvh_kernels() 129 | hit_mask, hit_t, face_ids, hit_points = cuda.bvh_ray_intersect( 130 | self.nodes, 131 | self.triangles, 132 | rays_o.contiguous().float(), 133 | rays_d.contiguous().float(), 134 | max_t 135 | ) 136 | return hit_mask, hit_t, face_ids, hit_points 137 | 138 | def aabb_intersect( 139 | self, 140 | query_min: torch.Tensor, 141 | query_max: torch.Tensor 142 | ): 143 | """ 144 | AABB-mesh intersection with exact SAT test. 145 | 146 | Args: 147 | query_min: [K, 3] query AABB mins 148 | query_max: [K, 3] query AABB maxs 149 | 150 | Returns: 151 | hit_mask: [K] bool - whether AABB intersects mesh 152 | aabb_ids: [N] query indices for each intersection pair 153 | face_ids: [N] face indices (ORIGINAL order) for each pair 154 | """ 155 | cuda = get_bvh_kernels() 156 | hit_mask, aabb_ids, face_ids = cuda.bvh_aabb_intersect( 157 | self.nodes, 158 | self.triangles, 159 | query_min.contiguous().float(), 160 | query_max.contiguous().float() 161 | ) 162 | return hit_mask, aabb_ids, face_ids 163 | 164 | 165 | # Check if BVH kernels are available 166 | def bvh_available(): 167 | """Check if BVH CUDA kernels can be compiled.""" 168 | try: 169 | get_bvh_kernels() 170 | return True 171 | except Exception: 172 | return False 173 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Atom3D 2 | 3 | **Atomize Your 3D Meshes** — High-performance CUDA mesh geometry with internal BVH acceleration for voxelization, distance fields, and intersection queries. 4 | 5 |
6 |
7 |
8 |
9 |
10 |
11 |