├── .gitignore ├── atom3d ├── grid │ ├── __init__.py │ ├── octree_indexer.py │ └── cube_grid.py ├── apps │ ├── __init__.py │ ├── mesh_intersector.py │ ├── udf_query.py │ ├── visibility_query.py │ ├── sdf_query.py │ ├── flood_fill.py │ └── voxelizer.py ├── core │ ├── __init__.py │ ├── data_structures.py │ └── mesh_bvh.py ├── __init__.py └── kernels │ ├── bvh.py │ ├── __init__.py │ └── bvh_kernels.cu ├── LICENSE ├── setup.py ├── examples ├── udf_gradient.py └── basic_voxelization.py ├── README.md └── demo.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | atom3d/kernels/build/* 3 | atom3d.egg-info/* 4 | 5 | -------------------------------------------------------------------------------- /atom3d/grid/__init__.py: -------------------------------------------------------------------------------- 1 | """Grid module exports""" 2 | 3 | from .octree_indexer import OctreeIndexer 4 | from .cube_grid import CubeGrid 5 | 6 | __all__ = ['OctreeIndexer', 'CubeGrid'] 7 | -------------------------------------------------------------------------------- /atom3d/apps/__init__.py: -------------------------------------------------------------------------------- 1 | """Apps module exports""" 2 | 3 | from .voxelizer import Voxelizer 4 | from .mesh_intersector import MeshIntersector 5 | from .visibility_query import VisibilityQuery 6 | from .udf_query import UDFQuery 7 | from .sdf_query import SDFQuery 8 | from .flood_fill import FloodFill 9 | 10 | __all__ = [ 11 | "Voxelizer", 12 | "MeshIntersector", 13 | "VisibilityQuery", 14 | "UDFQuery", 15 | "SDFQuery", 16 | "FloodFill", 17 | ] 18 | -------------------------------------------------------------------------------- /atom3d/core/__init__.py: -------------------------------------------------------------------------------- 1 | """Core module exports""" 2 | 3 | from .mesh_bvh import MeshBVH 4 | from .data_structures import ( 5 | AABBIntersectResult, 6 | RayIntersectResult, 7 | 
SegmentIntersectResult, 8 | ClosestPointResult, 9 | TriangleIntersectResult, 10 | VoxelFaceMapping, 11 | VoxelPolygonMapping, 12 | VisibilityResult, 13 | ) 14 | 15 | __all__ = [ 16 | "MeshBVH", 17 | "AABBIntersectResult", 18 | "RayIntersectResult", 19 | "SegmentIntersectResult", 20 | "ClosestPointResult", 21 | "TriangleIntersectResult", 22 | "VoxelFaceMapping", 23 | "VoxelPolygonMapping", 24 | "VisibilityResult", 25 | ] 26 | -------------------------------------------------------------------------------- /atom3d/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Atom3D: Atomize Your 3D Meshes 3 | 4 | A high-performance CUDA library for mesh processing and representation 5 | to support 3D deep learning. 6 | """ 7 | 8 | __version__ = "0.1.0" 9 | __author__ = "Atom3D Contributors" 10 | 11 | from .core.mesh_bvh import MeshBVH 12 | from .core.data_structures import ( 13 | AABBIntersectResult, 14 | RayIntersectResult, 15 | SegmentIntersectResult, 16 | ClosestPointResult, 17 | TriangleIntersectResult, 18 | VoxelFaceMapping, 19 | VoxelPolygonMapping, 20 | VisibilityResult, 21 | ) 22 | from .grid.octree_indexer import OctreeIndexer 23 | from .grid.cube_grid import CubeGrid 24 | 25 | __all__ = [ 26 | # Core 27 | "MeshBVH", 28 | # Data structures 29 | "AABBIntersectResult", 30 | "RayIntersectResult", 31 | "SegmentIntersectResult", 32 | "ClosestPointResult", 33 | "TriangleIntersectResult", 34 | "VoxelFaceMapping", 35 | "VoxelPolygonMapping", 36 | "VisibilityResult", 37 | # Grid 38 | "OctreeIndexer", 39 | "CubeGrid", 40 | ] 41 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software 
without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """ 2 | Atom3D: Atomize Your 3D Meshes 3 | 4 | Installation: 5 | pip install -e . 
6 | """ 7 | 8 | from setuptools import setup, find_packages 9 | 10 | setup( 11 | name="atom3d", 12 | version="0.1.0", 13 | author="Atom3D Contributors", 14 | description="Atomize your 3D meshes - High-performance CUDA mesh voxelization and distance field queries", 15 | long_description=open("README.md").read(), 16 | long_description_content_type="text/markdown", 17 | url="https://github.com/your-org/Atom3D", 18 | packages=find_packages(), 19 | python_requires=">=3.8", 20 | install_requires=[ 21 | "torch>=2.0.0", 22 | "numpy>=1.20.0", 23 | ], 24 | extras_require={ 25 | "dev": [ 26 | "pytest>=6.0", 27 | "trimesh>=3.0", 28 | ], 29 | "full": [ 30 | "trimesh>=3.0", 31 | ], 32 | }, 33 | classifiers=[ 34 | "Development Status :: 4 - Beta", 35 | "Intended Audience :: Science/Research", 36 | "License :: OSI Approved :: MIT License", 37 | "Programming Language :: Python :: 3", 38 | "Programming Language :: Python :: 3.8", 39 | "Programming Language :: Python :: 3.9", 40 | "Programming Language :: Python :: 3.10", 41 | "Programming Language :: Python :: 3.11", 42 | "Topic :: Scientific/Engineering :: Visualization", 43 | "Topic :: Scientific/Engineering :: Image Processing", 44 | ], 45 | include_package_data=True, 46 | ) 47 | -------------------------------------------------------------------------------- /atom3d/apps/mesh_intersector.py: -------------------------------------------------------------------------------- 1 | """ 2 | MeshIntersector: Mesh-mesh collision detection application 3 | """ 4 | 5 | import torch 6 | 7 | from ..core.mesh_bvh import MeshBVH 8 | from ..core.data_structures import TriangleIntersectResult 9 | 10 | 11 | class MeshIntersector: 12 | """ 13 | 网格碰撞检测应用 14 | 15 | = MeshBVH.intersect_triangles 的封装 16 | 17 | 用于网格自相交检测、多网格碰撞检测 18 | 19 | Args: 20 | bvh: MeshBVH实例 21 | """ 22 | 23 | def __init__(self, bvh: MeshBVH): 24 | self.bvh = bvh 25 | 26 | def check_self_intersection( 27 | self, 28 | skip_adjacent: bool = True 29 | ) -> TriangleIntersectResult: 30 
| """ 31 | 检测网格自相交 32 | 33 | Args: 34 | skip_adjacent: 是否跳过相邻面(共享顶点) 35 | 36 | Returns: 37 | result: TriangleIntersectResult 38 | """ 39 | # Use the mesh against itself 40 | result = self.bvh.intersect_triangles( 41 | self.bvh.vertices, 42 | self.bvh.faces 43 | ) 44 | 45 | if skip_adjacent: 46 | # Filter out adjacent face collisions 47 | # (faces that share vertices) 48 | result = self._filter_adjacent(result) 49 | 50 | return result 51 | 52 | def _filter_adjacent( 53 | self, 54 | result: TriangleIntersectResult 55 | ) -> TriangleIntersectResult: 56 | """Filter out collisions from adjacent faces""" 57 | if result.hit_points.shape[0] == 0: 58 | return result 59 | 60 | # Get edge face indices 61 | edge_face_ids = result.hit_edge_ids // 3 # Each face has 3 edges 62 | hit_face_ids = result.hit_face_ids 63 | 64 | # Check if faces share vertices 65 | valid_mask = torch.ones(result.hit_points.shape[0], dtype=torch.bool, device=self.bvh.device) 66 | 67 | for i in range(result.hit_points.shape[0]): 68 | face1 = self.bvh.faces[hit_face_ids[i]] 69 | face2 = self.bvh.faces[edge_face_ids[i]] 70 | 71 | # Check for shared vertices 72 | shared = (face1.unsqueeze(1) == face2.unsqueeze(0)).any() 73 | if shared: 74 | valid_mask[i] = False 75 | 76 | return TriangleIntersectResult( 77 | edge_hit=result.edge_hit, # Keep original 78 | hit_points=result.hit_points[valid_mask], 79 | hit_face_ids=result.hit_face_ids[valid_mask], 80 | hit_edge_ids=result.hit_edge_ids[valid_mask] 81 | ) 82 | 83 | def intersect_with_mesh( 84 | self, 85 | other_vertices: torch.Tensor, 86 | other_faces: torch.Tensor 87 | ) -> TriangleIntersectResult: 88 | """ 89 | 与另一个网格碰撞检测 90 | 91 | Args: 92 | other_vertices: [M, 3] 93 | other_faces: [K, 3] 94 | 95 | Returns: 96 | result: TriangleIntersectResult 97 | """ 98 | return self.bvh.intersect_triangles(other_vertices, other_faces) 99 | -------------------------------------------------------------------------------- /examples/udf_gradient.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | UDF/SDF Query Example 4 | 5 | Demonstrates distance field queries with gradient support. 6 | """ 7 | 8 | import torch 9 | import time 10 | 11 | import sys 12 | from pathlib import Path 13 | sys.path.insert(0, str(Path(__file__).parent.parent)) 14 | 15 | from atom3d import MeshBVH 16 | 17 | 18 | def create_sphere_mesh(): 19 | """Create icosphere mesh.""" 20 | try: 21 | import trimesh 22 | mesh = trimesh.creation.icosphere(subdivisions=3, radius=0.5) 23 | return ( 24 | torch.tensor(mesh.vertices, dtype=torch.float32), 25 | torch.tensor(mesh.faces, dtype=torch.int64) 26 | ) 27 | except ImportError: 28 | raise RuntimeError("trimesh required: pip install trimesh") 29 | 30 | 31 | def main(): 32 | print("=" * 50) 33 | print("Atom3D: UDF/SDF Query Example") 34 | print("=" * 50) 35 | 36 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 37 | 38 | # Create sphere mesh 39 | vertices, faces = create_sphere_mesh() 40 | vertices = vertices.to(device) 41 | faces = faces.to(device) 42 | print(f"Mesh: sphere with {len(faces)} faces, radius=0.5") 43 | 44 | # Create BVH 45 | bvh = MeshBVH(vertices, faces, device=device) 46 | 47 | # Generate random query points 48 | num_points = 10000 49 | points = torch.randn(num_points, 3, device=device) 50 | points = points / points.norm(dim=1, keepdim=True) * torch.rand(num_points, 1, device=device) * 2 51 | 52 | print(f"\n--- UDF Query (no gradient) ---") 53 | t0 = time.time() 54 | result = bvh.udf(points, return_grad=False) 55 | print(f"Time: {time.time()-t0:.4f}s") 56 | print(f"Distance range: [{result.distances.min():.4f}, {result.distances.max():.4f}]") 57 | 58 | print(f"\n--- UDF Query (with gradient) ---") 59 | points_grad = points.clone().requires_grad_(True) 60 | 61 | t0 = time.time() 62 | result = bvh.udf(points_grad, return_grad=True) 63 | print(f"Time: {time.time()-t0:.4f}s") 64 | 65 | # Backprop 66 | loss = 
result.distances.mean() 67 | loss.backward() 68 | 69 | gradients = points_grad.grad 70 | print(f"Gradient shape: {gradients.shape}") 71 | print(f"Gradient norm mean: {gradients.norm(dim=1).mean():.4f}") 72 | 73 | # Verify gradient 74 | print(f"\n--- Gradient Verification ---") 75 | closest = result.closest_points 76 | expected_grad = (points_grad.detach() - closest) 77 | expected_grad = expected_grad / (expected_grad.norm(dim=1, keepdim=True) + 1e-8) 78 | 79 | dot = (gradients * expected_grad).sum(dim=1) 80 | print(f"Gradient · Expected: {dot.mean():.6f} (should be ~1.0)") 81 | 82 | print(f"\n--- SDF Query ---") 83 | sdf_result = bvh.sdf(points, return_grad=False) 84 | 85 | inside = (sdf_result.distances < 0).sum() 86 | outside = (sdf_result.distances >= 0).sum() 87 | print(f"Inside mesh: {inside}, Outside: {outside}") 88 | print(f"SDF range: [{sdf_result.distances.min():.4f}, {sdf_result.distances.max():.4f}]") 89 | 90 | # Application: move points towards surface 91 | print(f"\n--- Application: Move towards surface ---") 92 | 93 | with torch.no_grad(): 94 | step_size = 0.1 95 | direction = (points - result.closest_points) 96 | direction = direction / (direction.norm(dim=1, keepdim=True) + 1e-8) 97 | 98 | new_points = points - step_size * direction * result.distances[:, None].sign() 99 | new_result = bvh.udf(new_points) 100 | 101 | print(f"Before: mean distance = {result.distances.abs().mean():.4f}") 102 | print(f"After: mean distance = {new_result.distances.abs().mean():.4f}") 103 | 104 | print("\n" + "=" * 50) 105 | print("Done!") 106 | print("=" * 50) 107 | 108 | 109 | if __name__ == "__main__": 110 | main() 111 | -------------------------------------------------------------------------------- /examples/basic_voxelization.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Basic Mesh Voxelization Example 4 | 5 | Demonstrates octree-accelerated mesh voxelization with SAT intersection. 
6 | """ 7 | 8 | import torch 9 | import time 10 | 11 | import sys 12 | from pathlib import Path 13 | sys.path.insert(0, str(Path(__file__).parent.parent)) 14 | 15 | from atom3d import MeshBVH 16 | from atom3d.grid import OctreeIndexer 17 | 18 | 19 | def create_test_mesh(): 20 | """Create a simple icosphere mesh.""" 21 | try: 22 | import trimesh 23 | mesh = trimesh.creation.icosphere(subdivisions=3, radius=0.8) 24 | vertices = torch.tensor(mesh.vertices, dtype=torch.float32) 25 | faces = torch.tensor(mesh.faces, dtype=torch.int64) 26 | return vertices, faces 27 | except ImportError: 28 | raise RuntimeError("trimesh required: pip install trimesh") 29 | 30 | 31 | def main(): 32 | print("=" * 50) 33 | print("Atom3D: Basic Voxelization Example") 34 | print("=" * 50) 35 | 36 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 37 | print(f"Device: {device}") 38 | 39 | # Load mesh 40 | vertices, faces = create_test_mesh() 41 | vertices = vertices.to(device) 42 | faces = faces.to(device) 43 | print(f"Mesh: {len(vertices)} vertices, {len(faces)} faces") 44 | 45 | # Create BVH 46 | print("\n[1] Building BVH...") 47 | t0 = time.time() 48 | bvh = MeshBVH(vertices, faces, device=device) 49 | print(f" Time: {time.time()-t0:.3f}s") 50 | print(f" Bounds: {bvh.get_bounds()[0].tolist()}") 51 | 52 | # Create octree 53 | max_level = 7 # 128^3 resolution 54 | print(f"\n[2] Creating octree (level {max_level}, res {2**max_level})...") 55 | octree = OctreeIndexer(max_level=max_level, device=device) 56 | 57 | # Octree traversal using BVH-accelerated broadphase 58 | print("\n[3] Octree traversal (BVH-accelerated broadphase)...") 59 | t0 = time.time() 60 | 61 | min_level = 4 62 | candidates = octree.octree_traverse(bvh, min_level=min_level) 63 | 64 | print(f" Candidates at level {max_level}: {len(candidates)}") 65 | print(f" Broadphase time: {time.time()-t0:.3f}s") 66 | 67 | # SAT intersection (narrowphase) 68 | print("\n[4] SAT intersection (narrowphase)...") 69 | t0 = time.time() 70 | 
71 | voxel_min, voxel_max = octree.cube_aabb_level(candidates, max_level) 72 | result = bvh.intersect_aabb(voxel_min, voxel_max, mode=1) 73 | 74 | surface_voxels = candidates[result.hit] 75 | print(f" Surface voxels: {len(surface_voxels)}") 76 | print(f" SAT time: {time.time()-t0:.3f}s") 77 | print(f" Reduction: {len(candidates)} -> {len(surface_voxels)} ({100*(1-len(surface_voxels)/len(candidates)):.1f}%)") 78 | 79 | # UDF at voxel corners 80 | print("\n[5] UDF at voxel corners...") 81 | t0 = time.time() 82 | 83 | corners = octree.cube_corner_coords_level(surface_voxels[:100], max_level) # First 100 84 | corner_points = corners.reshape(-1, 3) 85 | 86 | udf_distances = bvh.udf(corner_points) # Returns tensor directly when no extras requested 87 | print(f" UDF range: [{udf_distances.min():.4f}, {udf_distances.max():.4f}]") 88 | print(f" UDF time: {time.time()-t0:.3f}s") 89 | 90 | # Polygon clipping (mode 2) 91 | print("\n[6] Polygon clipping (mode=2)...") 92 | t0 = time.time() 93 | 94 | clip_result = bvh.intersect_aabb(voxel_min[:100], voxel_max[:100], mode=2) 95 | if hasattr(clip_result, 'centroids'): 96 | print(f" Centroids shape: {clip_result.centroids.shape}") 97 | print(f" Areas range: [{clip_result.areas.min():.6f}, {clip_result.areas.max():.6f}]") 98 | print(f" Clip time: {time.time()-t0:.3f}s") 99 | 100 | print("\n" + "=" * 50) 101 | print("Done!") 102 | print("=" * 50) 103 | 104 | 105 | if __name__ == "__main__": 106 | main() 107 | -------------------------------------------------------------------------------- /atom3d/apps/udf_query.py: -------------------------------------------------------------------------------- 1 | """ 2 | UDFQuery: Unsigned Distance Field query with gradient support. 3 | """ 4 | 5 | from typing import Optional 6 | import torch 7 | 8 | from ..core.mesh_bvh import MeshBVH 9 | from ..core.data_structures import ClosestPointResult 10 | 11 | 12 | class UDFQuery: 13 | """ 14 | Unsigned Distance Field query application. 
15 | 16 | Wraps MeshBVH.query_closest_point with: 17 | - Gradient support via autograd 18 | - Batch processing to avoid OOM 19 | 20 | Args: 21 | bvh: MeshBVH instance 22 | """ 23 | 24 | def __init__(self, bvh: MeshBVH): 25 | self.bvh = bvh 26 | 27 | def query( 28 | self, 29 | points: torch.Tensor, 30 | compute_grad: bool = False, 31 | batch_size: Optional[int] = None 32 | ) -> ClosestPointResult: 33 | """ 34 | Query unsigned distance field. 35 | 36 | Args: 37 | points: [N, 3] query points 38 | If compute_grad=True, should have requires_grad=True 39 | compute_grad: Whether to enable gradient computation 40 | batch_size: Batch size (None = process all at once) 41 | 42 | Returns: 43 | result: ClosestPointResult 44 | - distances: [N] float32 45 | - face_ids: [N] int32 46 | - closest_points: [N, 3] 47 | - uvw: [N, 3] 48 | """ 49 | if batch_size is not None and points.shape[0] > batch_size: 50 | return self._query_batched(points, compute_grad, batch_size) 51 | 52 | if compute_grad: 53 | return self._query_with_grad(points) 54 | else: 55 | return self.bvh.query_closest_point(points, return_uvw=True) 56 | 57 | def _query_batched( 58 | self, 59 | points: torch.Tensor, 60 | compute_grad: bool, 61 | batch_size: int 62 | ) -> ClosestPointResult: 63 | """Batched query to avoid OOM.""" 64 | N = points.shape[0] 65 | 66 | all_distances = [] 67 | all_face_ids = [] 68 | all_closest_points = [] 69 | all_uvw = [] 70 | 71 | for i in range(0, N, batch_size): 72 | batch_points = points[i:i+batch_size] 73 | 74 | if compute_grad: 75 | result = self._query_with_grad(batch_points) 76 | else: 77 | result = self.bvh.query_closest_point(batch_points, return_uvw=True) 78 | 79 | all_distances.append(result.distances) 80 | all_face_ids.append(result.face_ids) 81 | all_closest_points.append(result.closest_points) 82 | if result.uvw is not None: 83 | all_uvw.append(result.uvw) 84 | 85 | return ClosestPointResult( 86 | distances=torch.cat(all_distances), 87 | face_ids=torch.cat(all_face_ids), 88 | 
closest_points=torch.cat(all_closest_points), 89 | uvw=torch.cat(all_uvw) if all_uvw else None 90 | ) 91 | 92 | def _query_with_grad(self, points: torch.Tensor) -> ClosestPointResult: 93 | """ 94 | Query with gradient computation. 95 | 96 | Gradient: d(distance)/d(point) = (point - closest_point) / distance 97 | """ 98 | # Get closest points (no grad) 99 | with torch.no_grad(): 100 | result = self.bvh.query_closest_point(points, return_uvw=True) 101 | 102 | # Compute distances with gradient 103 | closest_points = result.closest_points.detach() 104 | diff = points - closest_points 105 | distances = diff.norm(dim=1) 106 | 107 | # If input requires grad, the output distances will have grad_fn 108 | return ClosestPointResult( 109 | distances=distances, 110 | face_ids=result.face_ids, 111 | closest_points=closest_points, 112 | uvw=result.uvw 113 | ) 114 | -------------------------------------------------------------------------------- /atom3d/core/data_structures.py: -------------------------------------------------------------------------------- 1 | """ 2 | Data structures for cuMTV 3 | """ 4 | 5 | from dataclasses import dataclass 6 | from typing import Optional 7 | import torch 8 | 9 | 10 | @dataclass 11 | class AABBIntersectResult: 12 | """AABB intersection result""" 13 | hit: torch.Tensor # [N] bool whether each AABB has collision 14 | aabb_ids: Optional[torch.Tensor] = None # [total_hits] int32 colliding AABB indices 15 | face_ids: Optional[torch.Tensor] = None # [total_hits] int32 colliding face indices 16 | centroids: Optional[torch.Tensor] = None # [total_hits, 3] float32 clipped polygon centroids (mode>=2) 17 | areas: Optional[torch.Tensor] = None # [total_hits] float32 clipped polygon areas (mode>=2) 18 | poly_verts: Optional[torch.Tensor] = None # [total_hits, 8, 3] float32 clipped polygon vertices (mode==3) 19 | poly_counts: Optional[torch.Tensor] = None # [total_hits] int32 polygon vertex counts (mode>=2) 20 | 21 | 22 | @dataclass 23 | class 
RayIntersectResult: 24 | """Ray intersection result""" 25 | hit: torch.Tensor # [N] bool 26 | t: torch.Tensor # [N] float32 (miss=inf) 27 | face_ids: torch.Tensor # [N] int32 (miss=-1) 28 | hit_points: torch.Tensor # [N, 3] 29 | normals: torch.Tensor # [N, 3] 30 | bary_coords: torch.Tensor # [N, 3] 31 | 32 | 33 | @dataclass 34 | class SegmentIntersectResult: 35 | """Segment intersection result""" 36 | hit: torch.Tensor # [N] bool 37 | hit_points: torch.Tensor # [N, 3] or [total, 3] 38 | face_ids: torch.Tensor # [N] or [total] int32 39 | bary_coords: torch.Tensor # [N, 3] or [total, 3] 40 | segment_ids: Optional[torch.Tensor] = None # [total] (if return_all=True) 41 | 42 | 43 | @dataclass 44 | class ClosestPointResult: 45 | """Closest point query result (UDF)""" 46 | distances: torch.Tensor # [N] float32 unsigned distance 47 | face_ids: torch.Tensor # [N] int32 closest face 48 | closest_points: torch.Tensor # [N, 3] 49 | uvw: Optional[torch.Tensor] = None # [N, 3] barycentric coordinates 50 | 51 | 52 | @dataclass 53 | class TriangleIntersectResult: 54 | """Triangle-triangle intersection result""" 55 | edge_hit: torch.Tensor # [num_edges] bool whether each edge intersects 56 | hit_points: torch.Tensor # [num_hits, 3] intersection point coordinates 57 | hit_face_ids: torch.Tensor # [num_hits] int32 hit faces in this mesh 58 | hit_edge_ids: torch.Tensor # [num_hits] int32 hit edges in other mesh 59 | 60 | 61 | @dataclass 62 | class VoxelFaceMapping: 63 | """Voxel-face mapping (CSR sparse format)""" 64 | voxel_coords: torch.Tensor # [K, 3] int32 65 | face_indices: torch.Tensor # [total] int32 66 | face_start: torch.Tensor # [K] int32 67 | face_count: torch.Tensor # [K] int32 68 | 69 | def get_faces_for_voxel(self, voxel_idx: int) -> torch.Tensor: 70 | """Get all faces intersecting the specified voxel""" 71 | start = self.face_start[voxel_idx].item() 72 | count = self.face_count[voxel_idx].item() 73 | return self.face_indices[start:start+count] 74 | 75 | 76 | @dataclass 
77 | class VoxelPolygonMapping: 78 | """Voxel-polygon mapping (exact intersection region)""" 79 | voxel_coords: torch.Tensor # [K, 3] int32 80 | polygons: torch.Tensor # [total, max_verts, 3] float32 81 | polygon_counts: torch.Tensor # [total] int32 vertex count per polygon 82 | face_indices: torch.Tensor # [total] int32 83 | voxel_ids: torch.Tensor # [total] int32 84 | 85 | def get_polygon(self, idx: int) -> torch.Tensor: 86 | """Get intersection polygon at specified index""" 87 | count = self.polygon_counts[idx].item() 88 | return self.polygons[idx, :count] 89 | 90 | 91 | @dataclass 92 | class VisibilityResult: 93 | """Visibility query result""" 94 | visibility: torch.Tensor # [N] float32 visibility probability [0, 1] 95 | visible_mask: Optional[torch.Tensor] = None # [N, M] bool visibility of each point from each viewpoint 96 | hit_distances: Optional[torch.Tensor] = None # [N, M] float32 occlusion distance 97 | -------------------------------------------------------------------------------- /atom3d/apps/visibility_query.py: -------------------------------------------------------------------------------- 1 | """ 2 | VisibilityQuery: Visibility query application with statistical probability 3 | """ 4 | 5 | from typing import Union, Optional 6 | import torch 7 | 8 | from ..core.mesh_bvh import MeshBVH 9 | from ..core.data_structures import VisibilityResult 10 | 11 | 12 | class VisibilityQuery: 13 | """ 14 | 可见性查询应用 15 | 16 | = MeshBVH.intersect_ray + 统计概率 17 | 18 | 对任意query point查询其可见性,用多视角射线的统计概率表达 19 | 20 | Args: 21 | bvh: MeshBVH实例 22 | """ 23 | 24 | def __init__(self, bvh: MeshBVH): 25 | self.bvh = bvh 26 | 27 | def query( 28 | self, 29 | points: torch.Tensor, 30 | view_directions: torch.Tensor, 31 | return_details: bool = False 32 | ) -> Union[torch.Tensor, VisibilityResult]: 33 | """ 34 | 查询点的可见性(统计概率) 35 | 36 | 对任意query point,从多个视角方向发射射线检测遮挡 37 | 38 | Args: 39 | points: [N, 3] 任意查询点 40 | view_directions: [M, 3] 多个视角方向(归一化) 41 | return_details: 是否返回详细信息 42 | 
43 | Returns: 44 | 如果return_details=False: 45 | visibility: [N] float32 可见性概率 [0, 1] 46 | 如果return_details=True: 47 | result: VisibilityResult 48 | """ 49 | N = points.shape[0] 50 | M = view_directions.shape[0] 51 | device = points.device 52 | 53 | visible_mask = torch.zeros(N, M, dtype=torch.bool, device=device) 54 | hit_distances = torch.zeros(N, M, device=device) 55 | 56 | for j in range(M): 57 | direction = view_directions[j] 58 | 59 | # All points use the same direction 60 | rays_o = points 61 | rays_d = direction.expand_as(points) 62 | 63 | result = self.bvh.intersect_ray(rays_o, rays_d) 64 | 65 | # Not hit = visible 66 | visible_mask[:, j] = ~result.hit 67 | hit_distances[:, j] = result.t 68 | 69 | # Compute visibility probability 70 | visibility = visible_mask.float().mean(dim=1) 71 | 72 | if return_details: 73 | return VisibilityResult( 74 | visibility=visibility, 75 | visible_mask=visible_mask, 76 | hit_distances=hit_distances 77 | ) 78 | else: 79 | return visibility 80 | 81 | def query_from_cameras( 82 | self, 83 | points: torch.Tensor, 84 | camera_positions: torch.Tensor 85 | ) -> torch.Tensor: 86 | """ 87 | 从相机位置查询可见性 88 | 89 | Args: 90 | points: [N, 3] 查询点 91 | camera_positions: [M, 3] 相机位置 92 | 93 | Returns: 94 | visibility: [N] float32 可见性概率 95 | = 能看到该点的相机比例 96 | """ 97 | N = points.shape[0] 98 | M = camera_positions.shape[0] 99 | device = points.device 100 | 101 | visible_count = torch.zeros(N, device=device) 102 | 103 | for cam_pos in camera_positions: 104 | # 从点向相机发射射线 105 | rays_o = points 106 | rays_d = cam_pos - points 107 | dists = rays_d.norm(dim=1, keepdim=True) 108 | rays_d = rays_d / (dists + 1e-8) 109 | 110 | # 检测遮挡 111 | result = self.bvh.intersect_ray(rays_o, rays_d, max_t=dists.squeeze().max().item()) 112 | 113 | # 如果未击中或击中距离 >= 到相机距离,则可见 114 | visible = ~result.hit | (result.t >= dists.squeeze() - 1e-4) 115 | visible_count += visible.float() 116 | 117 | return visible_count / M 118 | 119 | def query_uniform_sphere( 120 | self, 121 | 
points: torch.Tensor, 122 | num_samples: int = 32, 123 | seed: Optional[int] = None 124 | ) -> torch.Tensor: 125 | """ 126 | 均匀球面采样查询可见性 127 | 128 | Args: 129 | points: [N, 3] 查询点 130 | num_samples: int 球面采样数量 131 | seed: 随机种子(用于可重复性) 132 | 133 | Returns: 134 | visibility: [N] float32 可见性概率 135 | = 均匀球面上未被遮挡的方向比例 136 | 137 | 用途: 138 | - 无特定视角时的通用可见性度量 139 | - 可用于SDF符号判断(内部点可见性低) 140 | """ 141 | device = points.device 142 | 143 | if seed is not None: 144 | torch.manual_seed(seed) 145 | 146 | # Generate uniform sphere directions using Fibonacci spiral 147 | directions = self._fibonacci_sphere(num_samples, device) 148 | 149 | return self.query(points, directions, return_details=False) 150 | 151 | def _fibonacci_sphere(self, n: int, device: str) -> torch.Tensor: 152 | """Generate n uniformly distributed points on sphere""" 153 | indices = torch.arange(n, dtype=torch.float32, device=device) 154 | 155 | phi = torch.acos(1 - 2 * (indices + 0.5) / n) 156 | theta = torch.pi * (1 + 5**0.5) * indices 157 | 158 | x = torch.sin(phi) * torch.cos(theta) 159 | y = torch.sin(phi) * torch.sin(theta) 160 | z = torch.cos(phi) 161 | 162 | return torch.stack([x, y, z], dim=1) 163 | -------------------------------------------------------------------------------- /atom3d/kernels/bvh.py: -------------------------------------------------------------------------------- 1 | """ 2 | BVH Accelerator for Atom3D 3 | 4 | Python interface to BVH CUDA kernels. 
Provides accelerated: 5 | - UDF queries (closest point to mesh) 6 | - Ray-mesh intersection 7 | - AABB-mesh intersection (with exact SAT) 8 | """ 9 | 10 | import torch 11 | from torch.utils.cpp_extension import load 12 | import os 13 | 14 | # Get the CUDA kernel source 15 | _KERNEL_DIR = os.path.dirname(os.path.abspath(__file__)) 16 | _BUILD_DIR = os.path.join(_KERNEL_DIR, 'build', 'bvh') 17 | 18 | # Global cache for JIT compiled module 19 | _bvh_cuda = None 20 | 21 | def get_bvh_kernels(): 22 | """Get or compile the BVH CUDA kernels.""" 23 | global _bvh_cuda 24 | 25 | if _bvh_cuda is not None: 26 | return _bvh_cuda 27 | 28 | os.makedirs(_BUILD_DIR, exist_ok=True) 29 | 30 | kernel_path = os.path.join(_KERNEL_DIR, 'bvh_kernels.cu') 31 | 32 | _bvh_cuda = load( 33 | name='bvh_cuda', 34 | sources=[kernel_path], 35 | build_directory=_BUILD_DIR, 36 | extra_cuda_cflags=['-O3'], 37 | verbose=False 38 | ) 39 | 40 | return _bvh_cuda 41 | 42 | 43 | class BVHAccelerator: 44 | """ 45 | BVH-accelerated mesh queries. 46 | 47 | Provides O(log M) queries instead of O(M) brute-force: 48 | - udf: closest point to mesh 49 | - ray_intersect: ray-mesh intersection 50 | - aabb_intersect: AABB-mesh intersection with exact SAT 51 | 52 | All returned face_ids are in ORIGINAL mesh order (not BVH reordered order). 53 | """ 54 | 55 | def __init__( 56 | self, 57 | vertices: torch.Tensor, 58 | faces: torch.Tensor, 59 | n_primitives_per_leaf: int = 8 60 | ): 61 | """ 62 | Build BVH from mesh. 
63 | 64 | Args: 65 | vertices: [N, 3] float32 vertices 66 | faces: [M, 3] int32 face indices 67 | n_primitives_per_leaf: Max triangles per leaf node 68 | """ 69 | self.device = vertices.device 70 | self.vertices = vertices.contiguous().float() 71 | self.faces = faces.contiguous().int() 72 | self.num_faces = faces.shape[0] 73 | 74 | # Build BVH - returns (nodes, triangles with original_id) 75 | cuda = get_bvh_kernels() 76 | result = cuda.build_bvh( 77 | self.vertices, 78 | self.faces, 79 | n_primitives_per_leaf 80 | ) 81 | self.nodes = result[0] # [num_nodes, 9] 82 | self.triangles = result[1] # [num_faces, 10] - includes original_id 83 | 84 | def udf( 85 | self, 86 | points: torch.Tensor 87 | ): 88 | """ 89 | Unsigned distance field query. 90 | 91 | Args: 92 | points: [K, 3] query points 93 | 94 | Returns: 95 | distances: [K] unsigned distances 96 | face_ids: [K] closest face indices (ORIGINAL order) 97 | closest_points: [K, 3] closest points on mesh 98 | uvw: [K, 3] barycentric coordinates 99 | """ 100 | cuda = get_bvh_kernels() 101 | distances, face_ids, closest_points, uvw = cuda.bvh_udf( 102 | self.nodes, 103 | self.triangles, 104 | points.contiguous().float() 105 | ) 106 | return distances, face_ids, closest_points, uvw 107 | 108 | def ray_intersect( 109 | self, 110 | rays_o: torch.Tensor, 111 | rays_d: torch.Tensor, 112 | max_t: float = 1e10 113 | ): 114 | """ 115 | Ray-mesh intersection. 
116 | 117 | Args: 118 | rays_o: [K, 3] ray origins 119 | rays_d: [K, 3] ray directions 120 | max_t: Maximum ray distance 121 | 122 | Returns: 123 | hit_mask: [K] bool - whether ray hit mesh 124 | hit_t: [K] hit distance (max_t if no hit) 125 | face_ids: [K] hit face indices (ORIGINAL order, -1 if no hit) 126 | hit_points: [K, 3] hit positions 127 | """ 128 | cuda = get_bvh_kernels() 129 | hit_mask, hit_t, face_ids, hit_points = cuda.bvh_ray_intersect( 130 | self.nodes, 131 | self.triangles, 132 | rays_o.contiguous().float(), 133 | rays_d.contiguous().float(), 134 | max_t 135 | ) 136 | return hit_mask, hit_t, face_ids, hit_points 137 | 138 | def aabb_intersect( 139 | self, 140 | query_min: torch.Tensor, 141 | query_max: torch.Tensor 142 | ): 143 | """ 144 | AABB-mesh intersection with exact SAT test. 145 | 146 | Args: 147 | query_min: [K, 3] query AABB mins 148 | query_max: [K, 3] query AABB maxs 149 | 150 | Returns: 151 | hit_mask: [K] bool - whether AABB intersects mesh 152 | aabb_ids: [N] query indices for each intersection pair 153 | face_ids: [N] face indices (ORIGINAL order) for each pair 154 | """ 155 | cuda = get_bvh_kernels() 156 | hit_mask, aabb_ids, face_ids = cuda.bvh_aabb_intersect( 157 | self.nodes, 158 | self.triangles, 159 | query_min.contiguous().float(), 160 | query_max.contiguous().float() 161 | ) 162 | return hit_mask, aabb_ids, face_ids 163 | 164 | 165 | # Check if BVH kernels are available 166 | def bvh_available(): 167 | """Check if BVH CUDA kernels can be compiled.""" 168 | try: 169 | get_bvh_kernels() 170 | return True 171 | except Exception: 172 | return False 173 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Atom3D 2 | 3 | **Atomize Your 3D Meshes** — High-performance CUDA mesh geometry with internal BVH acceleration for voxelization, distance fields, and intersection queries. 4 | 5 |

6 | CUDA 7 | BVH 8 | PyTorch 9 | Python 10 | License 11 |

12 | 13 | ## Highlights 14 | 15 | - **🚀 Up to 85x speedup** — BVH-accelerated octree traversal vs brute-force (4.8–85x depending on resolution; see Performance) 16 | - **📦 Zero external BVH dependencies** — Fully self-contained, ported from [cubvh](https://github.com/ashawkey/cubvh) 17 | - **⚡ All operations BVH-accelerated** — UDF, ray, segment, AABB intersection 18 | 19 | ## Features 20 | 21 | ### Core Geometry (`MeshBVH`) 22 | 23 | | Operation | Complexity | BVH Accelerated | 24 | |-----------|------------|-----------------| 25 | | **Triangle-AABB intersection** | O(N log M) | ✅ Exact SAT | 26 | | **Polygon clipping** | O(N log M) | ✅ Broadphase | 27 | | **UDF/SDF queries** | O(N log M) | ✅ Closest point | 28 | | **Ray intersection** | O(N log M) | ✅ Möller-Trumbore | 29 | | **Segment intersection** | O(N log M) | ✅ Via ray | 30 | 31 | ### Spatial Indexing (`OctreeIndexer`) 32 | - **Octree-accelerated voxelization** — Hierarchical coarse-to-fine surface detection 33 | - **Multi-resolution traversal** — Efficient broadphase filtering from coarse to fine levels 34 | - **Cube topology** — Vertex, edge, and face indexing for primal/dual grids 35 | 36 | ### CUDA Kernels 37 | - **`bvh_kernels.cu`** — Internal BVH: build, UDF, ray, AABB intersection 38 | - **`cumtv_kernels.cu`** — SAT clip polygon, segment-triangle intersection 39 | 40 | ## Installation 41 | 42 | ```bash 43 | git clone https://github.com/your-org/Atom3D.git 44 | cd Atom3D 45 | pip install -e .
--no-build-isolation 46 | ``` 47 | 48 | **Requirements:** Python ≥ 3.8, PyTorch ≥ 2.0, CUDA ≥ 11.0 49 | 50 | **Optional:** `pip install trimesh pyvista` 51 | 52 | ## Quick Start 53 | 54 | ### Surface Voxelization 55 | 56 | ```python 57 | import torch 58 | from atom3d import MeshBVH 59 | from atom3d.grid import OctreeIndexer 60 | 61 | # Load mesh 62 | bvh = MeshBVH(vertices.cuda(), faces.cuda(), device='cuda') 63 | 64 | # Create octree (256³ max resolution) 65 | octree = OctreeIndexer(max_level=8, device='cuda') 66 | 67 | # BVH-accelerated octree traversal (85x faster than brute-force!) 68 | candidates = octree.octree_traverse(bvh, min_level=4) 69 | 70 | # Narrowphase: precise SAT intersection 71 | voxel_min, voxel_max = octree.cube_aabb_level(candidates) 72 | result = bvh.intersect_aabb(voxel_min, voxel_max, mode=1) 73 | surface_voxels = candidates[result.hit] 74 | ``` 75 | 76 | ### Polygon Clipping 77 | 78 | ```python 79 | # Mode 2: Get clipped polygon centroids and areas 80 | result = bvh.intersect_aabb(voxel_min, voxel_max, mode=2) 81 | 82 | # result.aabb_ids: which voxel each hit belongs to 83 | # result.face_ids: which triangle was intersected 84 | # result.centroids: [N, 3] clipped polygon centroids 85 | # result.areas: [N] clipped polygon areas 86 | ``` 87 | 88 | ### UDF/SDF Query 89 | 90 | ```python 91 | points = torch.randn(1000, 3, device='cuda', requires_grad=True) 92 | 93 | # Unsigned distance with closest point (BVH-accelerated) 94 | result = bvh.udf(points, return_closest=True, return_uvw=True) 95 | # result.distances, result.closest_points, result.uvw 96 | 97 | # Signed distance (requires watertight mesh) 98 | distances = bvh.sdf(points) 99 | 100 | # Gradient support 101 | result = bvh.udf(points, return_grad=True) 102 | result.distances.mean().backward() 103 | ``` 104 | 105 | ### Ray Intersection 106 | 107 | ```python 108 | rays_o = torch.randn(1000, 3, device='cuda') 109 | rays_d = torch.randn(1000, 3, device='cuda') 110 | rays_d = rays_d / 
rays_d.norm(dim=1, keepdim=True) 111 | 112 | # BVH-accelerated ray-mesh intersection 113 | result = bvh.intersect_ray(rays_o, rays_d) 114 | # result.hit, result.t, result.face_ids, result.hit_points 115 | ``` 116 | 117 | ## Performance 118 | 119 | Benchmarks on robot.glb (687K faces) with FaithC_v2 encoding: 120 | 121 | | Resolution | Before BVH | After BVH | Speedup | 122 | |------------|-----------|-----------|---------| 123 | | 256 | 15.5s | 0.18s | **85x** | 124 | | 512 | 10.6s | 2.2s | **4.8x** | 125 | | 2048 | 210s | 32.5s | **6.5x** | 126 | 127 | ## API Reference 128 | 129 | ### MeshBVH 130 | 131 | | Method | Description | 132 | |--------|-------------| 133 | | `intersect_aabb(min, max, mode)` | Triangle-AABB SAT intersection. mode: 0=hit, 1=pairs, 2=clip | 134 | | `udf(points, ...)` | BVH-accelerated unsigned distance field query | 135 | | `sdf(points, ...)` | Signed distance field (watertight mesh required) | 136 | | `intersect_ray(o, d, max_t)` | BVH-accelerated ray-mesh intersection | 137 | | `intersect_segment(start, end)` | Segment-mesh intersection | 138 | | `get_bounds()` | Mesh AABB bounds | 139 | | `get_face_aabb()` | Per-triangle AABBs | 140 | 141 | ### OctreeIndexer 142 | 143 | | Method | Description | 144 | |--------|-------------| 145 | | `octree_traverse(bvh, min_level)` | BVH-accelerated hierarchical broadphase | 146 | | `cube_aabb_level(cubes, level)` | Get voxel AABB at level | 147 | | `ijk_to_cube(ijk)` | Grid coords to linear index | 148 | | `cube_to_ijk(idx)` | Linear index to grid coords | 149 | | `get_cell_size(level)` | Voxel size at level | 150 | 151 | ### Data Structures 152 | 153 | ```python 154 | # AABBIntersectResult 155 | result.hit # [N] bool 156 | result.aabb_ids # [H] int - which AABB 157 | result.face_ids # [H] int - which triangle 158 | result.centroids # [H, 3] float (mode >= 2) 159 | result.areas # [H] float (mode >= 2) 160 | 161 | # ClosestPointResult 162 | result.distances # [N] float 163 | result.closest_points # [N, 
3] float 164 | result.face_ids # [N] int 165 | result.uvw # [N, 3] barycentric coords 166 | ``` 167 | 168 | ## Examples 169 | 170 | ```bash 171 | python examples/basic_voxelization.py 172 | python examples/udf_gradient.py 173 | ``` 174 | 175 | ## Acknowledgements 176 | 177 | This project builds upon excellent open-source work: 178 | 179 | - **[cubvh](https://github.com/ashawkey/cubvh)** — BVH implementation reference (ported and extended internally) 180 | - **[diso](https://github.com/SarahWeiii/diso)** — Differentiable isosurface extraction 181 | - **[FlexiCubes](https://github.com/nv-tlabs/FlexiCubes)** — NVIDIA's flexible isosurface extraction 182 | - **[instant-ngp](https://github.com/NVlabs/instant-ngp)** — NVIDIA's instant neural graphics primitives 183 | 184 | ## License 185 | 186 | [MIT](LICENSE) 187 | -------------------------------------------------------------------------------- /atom3d/apps/sdf_query.py: -------------------------------------------------------------------------------- 1 | """ 2 | SDFQuery: Signed Distance Field query with multiple methods 3 | 4 | Provides SDF computation using winding number, flood fill, or ray stabbing. 5 | """ 6 | 7 | from typing import Optional 8 | import torch 9 | 10 | from ..core.mesh_bvh import MeshBVH 11 | from ..grid.cube_grid import CubeGrid 12 | from .flood_fill import FloodFill 13 | 14 | 15 | class SDFQuery: 16 | """ 17 | SDF query application. 18 | 19 | Combines MeshBVH.query_closest_point with sign determination algorithms. 20 | 21 | Supported methods: 22 | - winding: Winding number (exact, requires watertight mesh) 23 | - flood: Flood fill (works with open meshes) 24 | - raystab: Ray stabbing (robust but slow) 25 | 26 | Args: 27 | bvh: MeshBVH instance 28 | """ 29 | 30 | def __init__(self, bvh: MeshBVH): 31 | self.bvh = bvh 32 | 33 | def query_winding(self, points: torch.Tensor) -> torch.Tensor: 34 | """ 35 | SDF via Winding Number. 
36 | 37 | Args: 38 | points: [N, 3] 39 | 40 | Returns: 41 | sdf: [N] float32 (negative inside, positive outside) 42 | 43 | Suitable for: watertight or near-watertight meshes 44 | """ 45 | N = points.shape[0] 46 | device = points.device 47 | 48 | # Get UDF 49 | result = self.bvh.query_closest_point(points, return_uvw=False) 50 | distances = result.distances 51 | 52 | # Compute winding number 53 | winding = self._compute_winding_number(points) 54 | 55 | # Inside if winding > 0.5 56 | inside = winding > 0.5 57 | 58 | # SDF = distance * sign 59 | sdf = distances.clone() 60 | sdf[inside] = -sdf[inside] 61 | 62 | return sdf 63 | 64 | def _compute_winding_number(self, points: torch.Tensor) -> torch.Tensor: 65 | """ 66 | Compute generalized winding number. 67 | 68 | For each point, sum solid angles subtended by all triangles. 69 | """ 70 | N = points.shape[0] 71 | device = points.device 72 | 73 | # Get triangle vertices 74 | tri_verts = self.bvh.vertices[self.bvh.faces] # [M, 3, 3] 75 | v0, v1, v2 = tri_verts[:, 0], tri_verts[:, 1], tri_verts[:, 2] 76 | 77 | winding = torch.zeros(N, device=device) 78 | 79 | for i in range(N): 80 | p = points[i] 81 | 82 | # Vectors from point to vertices 83 | a = v0 - p 84 | b = v1 - p 85 | c = v2 - p 86 | 87 | # Normalize 88 | la = a.norm(dim=1, keepdim=True) + 1e-8 89 | lb = b.norm(dim=1, keepdim=True) + 1e-8 90 | lc = c.norm(dim=1, keepdim=True) + 1e-8 91 | 92 | a = a / la 93 | b = b / lb 94 | c = c / lc 95 | 96 | # Solid angle formula 97 | det = (a * torch.cross(b, c)).sum(dim=1) 98 | denom = 1 + (a * b).sum(dim=1) + (b * c).sum(dim=1) + (c * a).sum(dim=1) 99 | 100 | solid_angle = 2 * torch.atan2(det, denom) 101 | winding[i] = solid_angle.sum() / (4 * torch.pi) 102 | 103 | return winding 104 | 105 | def query_flood( 106 | self, 107 | points: torch.Tensor, 108 | voxel_coords: torch.Tensor, 109 | grid: CubeGrid 110 | ) -> torch.Tensor: 111 | """ 112 | SDF via Flood Fill. 
113 | 114 | Args: 115 | points: [N, 3] 116 | voxel_coords: [K, 3] surface voxel coordinates 117 | grid: CubeGrid 118 | 119 | Returns: 120 | sdf: [N] float32 121 | 122 | Suitable for: open meshes 123 | """ 124 | # Get UDF 125 | result = self.bvh.query_closest_point(points, return_uvw=False) 126 | distances = result.distances 127 | 128 | # Perform flood fill to get inside/outside labels 129 | labels = FloodFill.fill(voxel_coords, grid) 130 | 131 | # Convert points to grid coords and check labels 132 | grid_coords = grid.world_to_grid(points).floor().int() 133 | grid_coords = grid_coords.clamp(0, grid.res - 1) 134 | 135 | # Create lookup 136 | exterior_label = 0 # Connected to seed (assumed exterior) 137 | 138 | # For each point, check if it's inside or outside 139 | # Simplified: interpolate from voxel labels 140 | sdf = distances.clone() 141 | 142 | # TODO: Proper interpolation from flood fill labels 143 | # For now, use visibility-based estimation 144 | 145 | return sdf 146 | 147 | def query_raystab( 148 | self, 149 | points: torch.Tensor, 150 | num_rays: int = 8, 151 | seed: Optional[int] = None 152 | ) -> torch.Tensor: 153 | """ 154 | SDF via Ray Stabbing. 155 | 156 | Shoots multiple rays from each point and counts intersection parity. 
157 | 158 | Args: 159 | points: [N, 3] 160 | num_rays: Number of rays per point 161 | seed: Random seed 162 | 163 | Returns: 164 | sdf: [N] float32 165 | 166 | Suitable for: open, non-manifold meshes 167 | """ 168 | N = points.shape[0] 169 | device = points.device 170 | 171 | if seed is not None: 172 | torch.manual_seed(seed) 173 | 174 | # Get UDF first 175 | result = self.bvh.query_closest_point(points, return_uvw=False) 176 | distances = result.distances 177 | 178 | # Generate random directions 179 | directions = torch.randn(num_rays, 3, device=device) 180 | directions = directions / directions.norm(dim=1, keepdim=True) 181 | 182 | # Count intersections for each point 183 | inside_votes = torch.zeros(N, device=device) 184 | 185 | for direction in directions: 186 | rays_o = points 187 | rays_d = direction.expand_as(points) 188 | 189 | # Count intersections (simplified: just check if hits) 190 | result_ray = self.bvh.intersect_ray(rays_o, rays_d) 191 | 192 | # If hit, it's either entering or leaving 193 | # Odd number of hits = inside 194 | inside_votes += result_ray.hit.float() 195 | 196 | # Majority vote: if more than half rays hit, likely inside 197 | # (This is a simplification; proper impl would count all intersections) 198 | inside = (inside_votes / num_rays) > 0.5 199 | 200 | sdf = distances.clone() 201 | sdf[inside] = -sdf[inside] 202 | 203 | return sdf 204 | -------------------------------------------------------------------------------- /atom3d/kernels/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | cuMTV CUDA Kernels Python Interface 3 | 4 | Provides JIT compilation of CUDA kernels using torch.utils.cpp_extension 5 | """ 6 | 7 | import os 8 | import torch 9 | from torch.utils.cpp_extension import load 10 | 11 | _kernel_loaded = False 12 | _cumtv_cuda = None 13 | 14 | 15 | def get_cuda_kernels(): 16 | """ 17 | Load and compile CUDA kernels using JIT compilation 18 | 19 | Returns: 20 | Compiled CUDA 
module with functions: 21 | - triangle_aabb_intersect(vertices, faces, aabb_min, aabb_max) 22 | - ray_mesh_intersect(vertices, faces, rays_o, rays_d, max_t) 23 | - point_mesh_udf(vertices, faces, points) 24 | - segment_tri_intersect(seg_verts, tri_verts, tri_aabb_min, tri_aabb_max, eps) 25 | """ 26 | global _kernel_loaded, _cumtv_cuda 27 | 28 | if _kernel_loaded and _cumtv_cuda is not None: 29 | return _cumtv_cuda 30 | 31 | # Get kernel source directory 32 | kernel_dir = os.path.dirname(os.path.abspath(__file__)) 33 | kernel_file = os.path.join(kernel_dir, 'cumtv_kernels.cu') 34 | 35 | if not os.path.exists(kernel_file): 36 | raise RuntimeError(f"CUDA kernel file not found: {kernel_file}") 37 | 38 | # Build directory 39 | build_dir = os.path.join(kernel_dir, 'build') 40 | os.makedirs(build_dir, exist_ok=True) 41 | 42 | # JIT compile 43 | _cumtv_cuda = load( 44 | name='cumtv_cuda', 45 | sources=[kernel_file], 46 | build_directory=build_dir, 47 | extra_cuda_cflags=['-O3', '--use_fast_math'], 48 | verbose=False 49 | ) 50 | 51 | _kernel_loaded = True 52 | return _cumtv_cuda 53 | 54 | 55 | # Convenience functions 56 | 57 | def triangle_aabb_intersect(vertices, faces, aabb_min, aabb_max): 58 | """ 59 | CUDA Triangle-AABB intersection using SAT 60 | 61 | Args: 62 | vertices: [N, 3] float32 63 | faces: [M, 3] int32 64 | aabb_min: [K, 3] float32 65 | aabb_max: [K, 3] float32 66 | 67 | Returns: 68 | hit_mask: [K] bool 69 | aabb_ids: [num_hits] int32 70 | face_ids: [num_hits] int32 71 | """ 72 | cuda = get_cuda_kernels() 73 | return cuda.triangle_aabb_intersect( 74 | vertices.contiguous().float(), 75 | faces.contiguous().int(), 76 | aabb_min.contiguous().float(), 77 | aabb_max.contiguous().float() 78 | ) 79 | 80 | 81 | def ray_mesh_intersect(vertices, faces, rays_o, rays_d, max_t=1e10): 82 | """ 83 | CUDA Ray-Mesh intersection 84 | 85 | Args: 86 | vertices: [N, 3] float32 87 | faces: [M, 3] int32 88 | rays_o: [K, 3] float32 89 | rays_d: [K, 3] float32 90 | max_t: float 91 | 
92 | Returns: 93 | hit_mask: [K] bool 94 | hit_t: [K] float32 95 | hit_face_ids: [K] int32 96 | hit_points: [K, 3] float32 97 | hit_uvs: [K, 2] float32 98 | """ 99 | cuda = get_cuda_kernels() 100 | return cuda.ray_mesh_intersect( 101 | vertices.contiguous().float(), 102 | faces.contiguous().int(), 103 | rays_o.contiguous().float(), 104 | rays_d.contiguous().float(), 105 | float(max_t) 106 | ) 107 | 108 | 109 | def point_mesh_udf(vertices, faces, points): 110 | """ 111 | CUDA Point-Mesh UDF query 112 | 113 | Args: 114 | vertices: [N, 3] float32 115 | faces: [M, 3] int32 116 | points: [K, 3] float32 117 | 118 | Returns: 119 | distances: [K] float32 120 | closest_face_ids: [K] int32 121 | closest_points: [K, 3] float32 122 | uvw: [K, 3] float32 123 | """ 124 | cuda = get_cuda_kernels() 125 | return cuda.point_mesh_udf( 126 | vertices.contiguous().float(), 127 | faces.contiguous().int(), 128 | points.contiguous().float() 129 | ) 130 | 131 | 132 | def segment_tri_intersect(seg_verts, tri_verts, tri_aabb_min, tri_aabb_max, eps=1e-8): 133 | """ 134 | CUDA Segment-Triangle intersection 135 | 136 | Args: 137 | seg_verts: [N_seg, 6] float32 (p0, p1) 138 | tri_verts: [N_tri, 9] float32 (v0, v1, v2) 139 | tri_aabb_min: [N_tri, 3] float32 140 | tri_aabb_max: [N_tri, 3] float32 141 | eps: float 142 | 143 | Returns: 144 | seg_ids: [num_hits] int64 145 | tri_ids: [num_hits] int64 146 | t: [num_hits] float32 147 | """ 148 | cuda = get_cuda_kernels() 149 | return cuda.segment_tri_intersect( 150 | seg_verts.contiguous().float(), 151 | tri_verts.contiguous().float(), 152 | tri_aabb_min.contiguous().float(), 153 | tri_aabb_max.contiguous().float(), 154 | float(eps) 155 | ) 156 | 157 | 158 | def sat_clip_polygon(aabbs_min, aabbs_max, tris_verts, cand_a, cand_t, mode=1, eps=1e-8): 159 | """ 160 | CUDA SAT Clip Polygon - compute intersection polygon, centroid, area 161 | 162 | For each (AABB, triangle) candidate pair, clips the triangle against the 163 | AABB using Sutherland-Hodgman 
algorithm and outputs the resulting polygon. 164 | 165 | Args: 166 | aabbs_min: [K, 3] float32 - AABB min bounds 167 | aabbs_max: [K, 3] float32 - AABB max bounds 168 | tris_verts: [M, 9] float32 - Triangle vertices (v0, v1, v2 flattened) 169 | cand_a: [N] int64 - Candidate AABB indices 170 | cand_t: [N] int64 - Candidate triangle indices 171 | mode: int - Output mode: 172 | 0 = hit mask only 173 | 1 = hit mask + centroid + area (default) 174 | 2 = hit mask + centroid + area + full polygon vertices 175 | eps: float - Tolerance 176 | 177 | Returns: 178 | hit_mask: [N] bool - True if intersection exists 179 | poly_counts: [N] int32 - Number of vertices in clipped polygon 180 | poly_verts: [N, 8, 3] float32 - Polygon vertices (only if mode=2) 181 | centroids: [N, 3] float32 - Centroid of clipped polygon (projected to triangle) 182 | areas: [N] float32 - Area of clipped polygon 183 | out_a_idx: [N] int64 - AABB indices 184 | out_t_idx: [N] int64 - Triangle indices 185 | """ 186 | cuda = get_cuda_kernels() 187 | return cuda.sat_clip_polygon( 188 | aabbs_min.contiguous().float(), 189 | aabbs_max.contiguous().float(), 190 | tris_verts.contiguous().float(), 191 | cand_a.contiguous().long(), 192 | cand_t.contiguous().long(), 193 | int(mode), 194 | float(eps) 195 | ) 196 | 197 | 198 | # Check if CUDA is available 199 | def cuda_available(): 200 | """Check if CUDA kernels can be compiled and used""" 201 | if not torch.cuda.is_available(): 202 | return False 203 | try: 204 | get_cuda_kernels() 205 | return True 206 | except Exception as e: 207 | print(f"Warning: CUDA kernels not available: {e}") 208 | return False 209 | -------------------------------------------------------------------------------- /atom3d/apps/flood_fill.py: -------------------------------------------------------------------------------- 1 | """ 2 | FloodFill: Flood fill for voxel connectivity analysis 3 | 4 | Used to distinguish interior/exterior voxels for solid voxelization. 
5 | """ 6 | 7 | from typing import Optional 8 | import torch 9 | 10 | from ..grid.cube_grid import CubeGrid 11 | 12 | 13 | class FloodFill: 14 | """ 15 | Flood fill application for connectivity analysis. 16 | 17 | Uses CubeGrid for boundary checking and coordinate conversion. 18 | """ 19 | 20 | @staticmethod 21 | def fill( 22 | voxel_coords: torch.Tensor, 23 | grid: CubeGrid, 24 | seed: Optional[torch.Tensor] = None 25 | ) -> torch.Tensor: 26 | """ 27 | Connected component labeling via flood fill. 28 | 29 | Args: 30 | voxel_coords: [N, 3] int32 occupied voxel coordinates 31 | grid: CubeGrid (for boundary checking) 32 | seed: [3] seed point (default [0,0,0] as exterior) 33 | 34 | Returns: 35 | labels: [N] int32 connected component labels 36 | label=0 means connected to seed (exterior region) 37 | label>0 means interior connected components (may have multiple) 38 | """ 39 | N = voxel_coords.shape[0] 40 | device = voxel_coords.device 41 | resolution = grid.res 42 | 43 | if seed is None: 44 | seed = torch.tensor([0, 0, 0], device=device) 45 | 46 | # Create 3D occupancy grid 47 | occupied = torch.zeros(resolution, resolution, resolution, dtype=torch.bool, device=device) 48 | 49 | # Mark occupied voxels 50 | valid_mask = ( 51 | (voxel_coords[:, 0] >= 0) & (voxel_coords[:, 0] < resolution) & 52 | (voxel_coords[:, 1] >= 0) & (voxel_coords[:, 1] < resolution) & 53 | (voxel_coords[:, 2] >= 0) & (voxel_coords[:, 2] < resolution) 54 | ) 55 | valid_coords = voxel_coords[valid_mask] 56 | occupied[valid_coords[:, 0], valid_coords[:, 1], valid_coords[:, 2]] = True 57 | 58 | # Initialize labels 59 | labels_3d = torch.full((resolution, resolution, resolution), -1, dtype=torch.int32, device=device) 60 | labels_3d[occupied] = 0 # Occupied voxels start with label 0 61 | 62 | # BFS from seed 63 | visited = torch.zeros_like(occupied) 64 | current_label = 0 65 | 66 | # 6-connectivity offsets 67 | offsets = torch.tensor([ 68 | [1, 0, 0], [-1, 0, 0], 69 | [0, 1, 0], [0, -1, 0], 70 | [0, 0, 
1], [0, 0, -1] 71 | ], device=device) 72 | 73 | # Start flood fill from seed (marking exterior) 74 | if not occupied[seed[0], seed[1], seed[2]]: 75 | queue = [seed.clone()] 76 | visited[seed[0], seed[1], seed[2]] = True 77 | labels_3d[seed[0], seed[1], seed[2]] = 0 # Exterior 78 | 79 | while queue: 80 | current = queue.pop(0) 81 | 82 | for offset in offsets: 83 | neighbor = current + offset 84 | 85 | # Check bounds 86 | if (neighbor >= 0).all() and (neighbor < resolution).all(): 87 | nx, ny, nz = neighbor[0].item(), neighbor[1].item(), neighbor[2].item() 88 | 89 | if not visited[nx, ny, nz] and not occupied[nx, ny, nz]: 90 | visited[nx, ny, nz] = True 91 | labels_3d[nx, ny, nz] = 0 # Exterior 92 | queue.append(neighbor) 93 | 94 | # Extract labels for input coordinates 95 | labels = torch.zeros(N, dtype=torch.int32, device=device) 96 | labels[valid_mask] = labels_3d[valid_coords[:, 0], valid_coords[:, 1], valid_coords[:, 2]] 97 | 98 | return labels 99 | 100 | @staticmethod 101 | def get_interior_voxels( 102 | voxel_coords: torch.Tensor, 103 | grid: CubeGrid, 104 | seed: Optional[torch.Tensor] = None 105 | ) -> torch.Tensor: 106 | """ 107 | Get interior voxel coordinates. 
108 | 109 | Args: 110 | voxel_coords: [N, 3] surface voxel coordinates 111 | grid: CubeGrid 112 | seed: [3] exterior seed point 113 | 114 | Returns: 115 | interior_coords: [K, 3] interior voxel coordinates 116 | (empty voxels not connected to exterior) 117 | """ 118 | device = voxel_coords.device 119 | resolution = grid.res 120 | 121 | if seed is None: 122 | seed = torch.tensor([0, 0, 0], device=device) 123 | 124 | # Create occupancy grid from surface voxels 125 | occupied = torch.zeros(resolution, resolution, resolution, dtype=torch.bool, device=device) 126 | 127 | valid_mask = ( 128 | (voxel_coords[:, 0] >= 0) & (voxel_coords[:, 0] < resolution) & 129 | (voxel_coords[:, 1] >= 0) & (voxel_coords[:, 1] < resolution) & 130 | (voxel_coords[:, 2] >= 0) & (voxel_coords[:, 2] < resolution) 131 | ) 132 | valid_coords = voxel_coords[valid_mask] 133 | if valid_coords.shape[0] > 0: 134 | occupied[valid_coords[:, 0], valid_coords[:, 1], valid_coords[:, 2]] = True 135 | 136 | # Mark exterior via flood fill 137 | exterior = torch.zeros_like(occupied) 138 | 139 | offsets = torch.tensor([ 140 | [1, 0, 0], [-1, 0, 0], 141 | [0, 1, 0], [0, -1, 0], 142 | [0, 0, 1], [0, 0, -1] 143 | ], device=device) 144 | 145 | if not occupied[seed[0], seed[1], seed[2]]: 146 | queue = [seed.clone()] 147 | exterior[seed[0], seed[1], seed[2]] = True 148 | 149 | while queue: 150 | current = queue.pop(0) 151 | 152 | for offset in offsets: 153 | neighbor = current + offset 154 | 155 | if (neighbor >= 0).all() and (neighbor < resolution).all(): 156 | nx, ny, nz = neighbor[0].item(), neighbor[1].item(), neighbor[2].item() 157 | 158 | if not exterior[nx, ny, nz] and not occupied[nx, ny, nz]: 159 | exterior[nx, ny, nz] = True 160 | queue.append(neighbor) 161 | 162 | # Interior = not occupied and not exterior 163 | interior = ~occupied & ~exterior 164 | 165 | # Convert to coordinates 166 | interior_indices = torch.where(interior) 167 | interior_coords = torch.stack([ 168 | interior_indices[0], 169 | 
interior_indices[1], 170 | interior_indices[2] 171 | ], dim=1).int() 172 | 173 | return interior_coords 174 | -------------------------------------------------------------------------------- /demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "d3252283", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "name": "stderr", 11 | "output_type": "stream", 12 | "text": [ 13 | "/opt/conda/lib/python3.11/site-packages/pyvista/plotting/utilities/xvfb.py:48: PyVistaDeprecationWarning: This function is deprecated and will be removed in future version of PyVista. Use vtk-osmesa instead.\n", 14 | " warnings.warn(\n" 15 | ] 16 | } 17 | ], 18 | "source": [ 19 | "import torch\n", 20 | "from atom3d import MeshBVH\n", 21 | "from atom3d.grid import CubeGrid, OctreeIndexer\n", 22 | "\n", 23 | "import trimesh\n", 24 | "import pyvista as pv\n", 25 | "pv.start_xvfb()\n", 26 | "pv.set_jupyter_backend('html')\n", 27 | "\n", 28 | "import numpy as np\n" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 2, 34 | "id": "2f068874", 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "meshtem = trimesh.creation.icosphere(subdivisions=1)\n", 39 | "V = torch.tensor(meshtem.vertices, dtype=torch.float32)\n", 40 | "F = torch.tensor(meshtem.faces, dtype=torch.int32)\n", 41 | "\n", 42 | "bvh = MeshBVH(V, F, device='cuda')\n", 43 | "res = 512\n", 44 | "grid_indexer = OctreeIndexer(max_level=int(np.log2(res)), bounds=bvh.get_bounds(), device='cuda')\n" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 3, 50 | "id": "22be49e8", 51 | "metadata": {}, 52 | "outputs": [ 53 | { 54 | "name": "stdout", 55 | "output_type": "stream", 56 | "text": [ 57 | "tensor([1., 1., 1., ..., 1., 1., 1.], device='cuda:0')\n" 58 | ] 59 | } 60 | ], 61 | "source": [ 62 | "candidates_ijk = grid_indexer.octree_traverse(bvh)\n", 63 | "candidates_idx = 
grid_indexer.ijk_to_cube(candidates_ijk)\n", 64 | "\n", 65 | "\n", 66 | "vertex_unique_idx, unique_coords, mapping = grid_indexer.voxel_unique_vertices(candidates_idx)\n", 67 | "\n", 68 | "unique_coords = unique_coords.cuda()\n", 69 | "unique_coords.requires_grad = True\n", 70 | "\n", 71 | "udfs = bvh.udf(unique_coords, return_grad=True)\n", 72 | "\n", 73 | "udfs_grad = torch.autograd.grad(udfs.sum(), unique_coords)[0]\n", 74 | "\n", 75 | "print(udfs_grad.norm(dim=1))\n", 76 | "\n" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": 4, 82 | "id": "7bed4931", 83 | "metadata": {}, 84 | "outputs": [ 85 | { 86 | "data": { 87 | "application/vnd.jupyter.widget-view+json": { 88 | "model_id": "06297a0cbb42428ab318464d0a0a2e6d", 89 | "version_major": 2, 90 | "version_minor": 0 91 | }, 92 | "text/plain": [ 93 | "EmbeddableWidget(value='