├── pyproject.toml ├── pytest.ini ├── nebula ├── __init__.py ├── evaluators │ ├── __init__.py │ └── bspline.py ├── topology │ ├── __init__.py │ ├── topology.py │ ├── shells.py │ ├── solids.py │ ├── faces.py │ ├── edges.py │ └── wires.py ├── prim │ ├── __init__.py │ ├── bspline_surfaces.py │ ├── sparse.py │ ├── lines.py │ ├── axes.py │ └── bspline_curves.py ├── helpers │ ├── types.py │ ├── vector.py │ ├── intersection.py │ ├── wire.py │ └── clipper.py ├── tools │ ├── edge.py │ └── solid.py ├── render │ ├── visualization.py │ └── tesselation.py ├── cases │ └── airfoil.py └── workplane.py ├── assets ├── example1_width.png └── example1_arc_height.png ├── requirements_dev.txt ├── .pylintrc ├── .github └── dependabot.yml ├── .vscode ├── launch.json └── settings.json ├── .gitpod.yml ├── setup.py ├── .devcontainer └── devcontainer.json ├── tests ├── test_intersection.py └── test_bspline.py ├── LICENSE ├── TODO ├── README.md └── .gitignore /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 90 -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | pythonpath = . 
nebula -------------------------------------------------------------------------------- /nebula/__init__.py: -------------------------------------------------------------------------------- 1 | from .workplane import Workplane -------------------------------------------------------------------------------- /nebula/evaluators/__init__.py: -------------------------------------------------------------------------------- 1 | from .bspline import BSplineEvaluator -------------------------------------------------------------------------------- /assets/example1_width.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenOrion/nebula/HEAD/assets/example1_width.png -------------------------------------------------------------------------------- /assets/example1_arc_height.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenOrion/nebula/HEAD/assets/example1_arc_height.png -------------------------------------------------------------------------------- /nebula/topology/__init__.py: -------------------------------------------------------------------------------- 1 | from .edges import Edges 2 | from .wires import Wires 3 | from .faces import Faces 4 | from .shells import Shells 5 | from .solids import Solids -------------------------------------------------------------------------------- /nebula/prim/__init__.py: -------------------------------------------------------------------------------- 1 | from .sparse import SparseArray, SparseIndexable 2 | from .bspline_curves import BSplineCurves 3 | from .lines import Lines 4 | from .axes import Axes 5 | -------------------------------------------------------------------------------- /requirements_dev.txt: -------------------------------------------------------------------------------- 1 | plotly==5.11.0 2 | ipywidgets==7.6 3 | jupyterlab 4 | matplotlib 5 | pytest 6 | 7 | numpy 8 | 
ipython_genutils 9 | jupyter_cadquery 10 | jax[cpu] 11 | jax_dataclasses 12 | kaleido -------------------------------------------------------------------------------- /nebula/topology/topology.py: -------------------------------------------------------------------------------- 1 | from typing import Protocol, Union 2 | from nebula.prim.sparse import SparseIndexable 3 | 4 | class Topology(SparseIndexable, Protocol): 5 | 6 | def add(self, topology: "Topology", reorder_index: bool = False): ... 7 | -------------------------------------------------------------------------------- /nebula/helpers/types.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | import jax.numpy as jnp 3 | 4 | Number = Union[float, jnp.ndarray] 5 | CoordLike = Union[ 6 | list[tuple[Number, Number, Number]], list[tuple[Number, Number]], jnp.ndarray 7 | ] 8 | ArrayLike = Union[list[float], jnp.ndarray] 9 | -------------------------------------------------------------------------------- /nebula/helpers/vector.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | 3 | class VectorHelper: 4 | @staticmethod 5 | def normalize(arr: jnp.ndarray) -> jnp.ndarray: 6 | arr_min = jnp.min(arr) 7 | arr_max = jnp.max(arr) 8 | normalized = (arr - arr_min) / (arr_max - arr_min) 9 | return normalized 10 | -------------------------------------------------------------------------------- /.pylintrc: -------------------------------------------------------------------------------- 1 | 2 | [MESSAGES CONTROL] 3 | disable=import-error, 4 | invalid-name, 5 | non-ascii-name, 6 | missing-function-docstring, 7 | missing-module-docstring, 8 | missing-class-docstring, 9 | line-too-long, 10 | dangerous-default-value, 11 | line-too-long, 12 | unnecessary-lambda -------------------------------------------------------------------------------- /.github/dependabot.yml: 
-------------------------------------------------------------------------------- 1 | # To get started with Dependabot version updates, you'll need to specify which 2 | # package ecosystems to update and where the package manifests are located. 3 | # Please see the documentation for more information: 4 | # https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates 5 | # https://containers.dev/guide/dependabot 6 | 7 | version: 2 8 | updates: 9 | - package-ecosystem: "devcontainers" 10 | directory: "/" 11 | schedule: 12 | interval: weekly 13 | -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "name": "Python: Current File", 9 | "type": "python", 10 | "request": "launch", 11 | "program": "${file}", 12 | "console": "integratedTerminal", 13 | "justMyCode": true 14 | } 15 | ] 16 | } -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.defaultInterpreterPath": "/opt/conda/bin/python", 3 | "jupyter.notebookFileRoot": "${workspaceFolder}", 4 | "python.analysis.typeCheckingMode": "basic", 5 | "python.analysis.inlayHints.variableTypes": true, 6 | "python.analysis.inlayHints.functionReturnTypes": true, 7 | "python.analysis.inlayHints.pytestParameters": true, 8 | "python.analysis.inlayHints.callArgumentNames": "partial", 9 | "python.testing.pytestArgs": [ 10 | "tests" 11 | ], 12 | "python.testing.unittestEnabled": false, 13 | "python.testing.pytestEnabled": true, 14 | } 
-------------------------------------------------------------------------------- /.gitpod.yml: -------------------------------------------------------------------------------- 1 | # This configuration file was automatically generated by Gitpod. 2 | # Please adjust to your needs (see https://www.gitpod.io/docs/config-gitpod-file) 3 | # and commit this file to your remote git repository to share the goodness with others. 4 | image: 5 | file: .devcontainer/Dockerfile 6 | 7 | github: 8 | prebuilds: 9 | # enable for the master/default branch (defaults to true) 10 | master: true 11 | # enable for pull requests coming from this repo (defaults to true) 12 | pullRequests: false 13 | # add a "Review in Gitpod" button as a comment to pull requests (defaults to true) 14 | addComment: false 15 | 16 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup( 4 | name="nebula", 5 | version="1.0.0", 6 | description="auto differentiable CAD library in JAX", 7 | author="Afshawn Lotfi", 8 | author_email="", 9 | packages=[ 10 | "nebula", 11 | "nebula.cases", 12 | "nebula.prim", 13 | "nebula.evaluators", 14 | "nebula.helpers", 15 | "nebula.render", 16 | "nebula.tools", 17 | "nebula.topology", 18 | ], 19 | install_requires=[ 20 | "plotly==5.11.0", 21 | "ipywidgets==7.6", 22 | "jupyterlab", 23 | "matplotlib", 24 | "numpy", 25 | "ipython_genutils", 26 | "jupyter_cadquery", 27 | "jax[cpu]", 28 | "jax_dataclasses", 29 | "kaleido" 30 | ], 31 | ) 32 | -------------------------------------------------------------------------------- /.devcontainer/devcontainer.json: -------------------------------------------------------------------------------- 1 | // For format details, see https://aka.ms/devcontainer.json. 
For config options, see the 2 | // README at: https://github.com/devcontainers/templates/tree/main/src/python 3 | { 4 | "name": "Python 3", 5 | // Or use a Dockerfile or Docker Compose file. More info: https://containers.dev/guide/dockerfile 6 | "image": "mcr.microsoft.com/devcontainers/python:1-3.12-bullseye" 7 | 8 | // Features to add to the dev container. More info: https://containers.dev/features. 9 | // "features": {}, 10 | 11 | // Use 'forwardPorts' to make a list of ports inside the container available locally. 12 | // "forwardPorts": [], 13 | 14 | // Use 'postCreateCommand' to run commands after the container is created. 15 | // "postCreateCommand": "pip3 install --user -r requirements.txt", 16 | 17 | // Configure tool-specific properties. 18 | // "customizations": {}, 19 | 20 | // Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root. 21 | // "remoteUser": "root" 22 | } 23 | -------------------------------------------------------------------------------- /tests/test_intersection.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | from pytest import mark 3 | from nebula.helpers.intersection import Intersection 4 | 5 | 6 | @mark.parametrize( 7 | "line1, line2, expected", 8 | [ 9 | ( 10 | jnp.array([[50, 50], [163, 215]]), 11 | jnp.array([[50, 250], [170, 170]]), 12 | jnp.array([144.036, 187.309, 0]), 13 | ), 14 | ], 15 | ) 16 | def test_intersection(line1: jnp.ndarray, line2: jnp.ndarray, expected: jnp.ndarray): 17 | intersection = Intersection.line_segment_intersection(line1, line2) 18 | assert jnp.allclose(intersection.intersected_vertices, expected, atol=1e-3) 19 | 20 | 21 | @mark.parametrize( 22 | "line1, line2", 23 | [ 24 | ( 25 | jnp.array([[50, 50], [205, 172]]), 26 | jnp.array([[50, 250], [170, 170]]), 27 | ), 28 | ], 29 | ) 30 | def test_no_intersect(line1: jnp.ndarray, line2: jnp.ndarray): 31 | intersection = 
Intersection.line_segment_intersection(line1, line2) 32 | assert jnp.isnan(intersection.intersected_vertices).all() 33 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Open Orion, Inc. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /nebula/topology/shells.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | import jax_dataclasses as jdc 3 | import jax.numpy as jnp 4 | import numpy as np 5 | import numpy.typing as npt 6 | from nebula.topology.topology import Topology 7 | from nebula.topology.faces import Faces 8 | 9 | 10 | @jdc.pytree_dataclass 11 | class Shells(Topology): 12 | faces: Faces 13 | index: jdc.Static[npt.NDArray[np.int32]] 14 | 15 | def __repr__(self) -> str: 16 | return f"Shells(count={self.count}, faces={self.faces})" 17 | 18 | def add(self, shells: "Shells", reorder_index: bool = False): 19 | faces = self.faces.add(shells.faces) 20 | index = self.add_indices(shells.index, reorder_index) 21 | return Shells(faces, index) 22 | 23 | @staticmethod 24 | def empty(): 25 | return Shells(faces=Faces.empty(), index=np.empty((0,), dtype=np.int32)) 26 | 27 | @staticmethod 28 | def from_faces(faces: Faces): 29 | return Shells(faces=faces, index=np.full(len(faces.index), 0, dtype=np.int32)) 30 | 31 | def mask(self, mask: Union[jnp.ndarray, np.ndarray]): 32 | """Get shells from mask 33 | 34 | :param mask: Mask for shells 35 | :type mask: Union[jnp.ndarray, np.ndarray] 36 | :return: Shells 37 | :rtype: Shells 38 | """ 39 | return Shells( 40 | self.faces.mask(mask), 41 | Topology.reorder_index(self.index[mask]), 42 | ) 43 | -------------------------------------------------------------------------------- /nebula/topology/solids.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | import jax_dataclasses as jdc 3 | 4 | import numpy as np 5 | import numpy.typing as npt 6 | from nebula.topology.topology import Topology 7 | from nebula.topology.shells import Shells 8 | import jax.numpy as jnp 9 | 10 | 11 | @jdc.pytree_dataclass 12 | class Solids(Topology): 13 | shells: Shells 14 | index: 
jdc.Static[npt.NDArray[np.int32]] 15 | 16 | @staticmethod 17 | def empty(): 18 | return Solids(shells=Shells.empty(), index=np.empty((0,), dtype=np.int32)) 19 | 20 | @staticmethod 21 | def from_shells(shells: Shells): 22 | return Solids(shells=shells, index=np.full(len(shells.index), 0, dtype=np.int32)) 23 | 24 | def add(self, solids: "Solids", reorder_index: bool = False): 25 | shells = self.shells.add(solids.shells)  # NOTE(review): reorder_index is applied only to the solid index below, not forwarded to the nested shells -- confirm intended 26 | index = self.add_indices(solids.index, reorder_index) 27 | return Solids(shells, index) 28 | 29 | def __repr__(self) -> str: 30 | return f"Solids(count={self.count}, shells={self.shells})" 31 | 32 | def mask(self, mask: Union[jnp.ndarray, np.ndarray]): 33 | """Get solids from mask 34 | 35 | :param mask: Mask for solids 36 | :type mask: Union[jnp.ndarray, np.ndarray] 37 | :return: Masked solids 38 | :rtype: Solids 39 | """ 40 | return Solids( 41 | self.shells.mask(mask), 42 | Topology.reorder_index(self.index[mask]), 43 | ) 44 | -------------------------------------------------------------------------------- /nebula/prim/bspline_surfaces.py: -------------------------------------------------------------------------------- 1 | import jax_dataclasses as jdc 2 | import jax.numpy as jnp 3 | from nebula.evaluators.bspline import BSplineEvaluator 4 | from nebula.helpers.types import Number 5 | from nebula.topology.wires import Wires 6 | 7 | @jdc.pytree_dataclass 8 | class BSplineSurfaces: 9 | all_wires: Wires 10 | 11 | def evaluate(self, u: jnp.ndarray, v: jnp.ndarray, degree: Number = 1): 12 | bspline_wires = self.wires 13 | bspline_curves = bspline_wires.edges.curves 14 | num_edges = bspline_wires.get_num_edges() 15 | 16 | vertices = jnp.empty((0, 3)) 17 | for i in range(bspline_wires.count): 18 | curve = bspline_curves.mask(bspline_wires.index == i) 19 | # TODO: this is only if all curves in group are the same length 20 | ctrl_pnts = curve.ctrl_pnts.val.reshape( 21 | curve.count,-1, 3 22 | ) 23 | v_degree = curve.degree[0] 24 | v_knots =
curve.knots.val[curve.knots.index == 0] 25 | u_degree = degree 26 | u_knots = BSplineEvaluator.generate_clamped_knots(u_degree, num_edges[i]) 27 | surface_pnts = BSplineEvaluator.eval_surface( 28 | u_degree, v_degree, ctrl_pnts, u_knots, v_knots, u, v 29 | ) 30 | new_vertices = surface_pnts.reshape(-1, 3) 31 | vertices = jnp.concatenate([vertices, new_vertices]) 32 | return vertices 33 | 34 | def __repr__(self) -> str: 35 | return f"BSplineSurfaces(count={self.count})" 36 | 37 | @property 38 | def wires(self): 39 | return self.all_wires.mask(~self.all_wires.is_planar) 40 | 41 | @property 42 | def count(self) -> int: 43 | return self.wires.count 44 | 45 | 46 | # TODO: switch to this later 47 | # @property 48 | # def count(self) -> int: 49 | # index = SparseIndexable.reorder_index(self.all_wires.index[~self.all_wires.is_planar]) 50 | # return SparseIndexable.get_count(index) 51 | -------------------------------------------------------------------------------- /nebula/topology/faces.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Union 3 | import jax_dataclasses as jdc 4 | import numpy as np 5 | import numpy.typing as npt 6 | import jax.numpy as jnp 7 | from nebula.prim.bspline_curves import BSplineCurves 8 | from nebula.prim.bspline_surfaces import BSplineSurfaces 9 | from nebula.prim.axes import Axes 10 | from nebula.topology.topology import Topology 11 | from nebula.topology.wires import WireLike, Wires 12 | 13 | 14 | @jdc.pytree_dataclass 15 | class Faces(WireLike): 16 | wires: Wires 17 | index: jdc.Static[npt.NDArray[np.int32]] 18 | 19 | @property 20 | def bspline_surfaces(self): 21 | return BSplineSurfaces(all_wires=self.wires) 22 | 23 | @property 24 | def edges(self): 25 | return self.wires.edges 26 | 27 | def add(self, faces: "Faces", reorder_index: bool = False): 28 | wires = self.wires.add(faces.wires) 29 | index = self.add_indices(faces.index, reorder_index) 30 
| return Faces(wires, index) 31 | 32 | @staticmethod 33 | def empty(): 34 | return Faces(wires=Wires.empty(), index=np.empty((0,), dtype=np.int32)) 35 | 36 | @staticmethod 37 | def from_wires(wires: Wires, is_exterior: bool = False): 38 | """ 39 | Build Faces from Wires. 40 | 41 | :param wires: Wires to convert to faces 42 | :type wires: Wires 43 | :param is_exterior: if True, the face index is taken directly from wires.index (wires sharing an index form one face); otherwise every wire is assigned to a single face (index 0) 44 | :type is_exterior: bool, optional 45 | """ 46 | if is_exterior: 47 | return Faces(wires=wires, index=wires.index) 48 | return Faces(wires=wires, index=np.full(len(wires.index), 0, dtype=np.int32)) 49 | 50 | def translate(self, translation: jnp.ndarray): 51 | return Faces(wires=self.wires.translate(translation), index=self.index) 52 | 53 | def flip_winding(self): 54 | return Faces(wires=self.wires.flip_winding(), index=self.index[::-1])  # NOTE(review): reverses the face index order; assumes Wires.flip_winding reverses the wire order to match -- confirm 55 | 56 | def project(self, new_axes: Axes): 57 | return Faces(wires=self.wires.project(new_axes), index=self.index) 58 | 59 | def __repr__(self) -> str: 60 | return f"Faces(count={self.count}, wires={self.wires})" 61 | 62 | def mask(self, mask: Union[jnp.ndarray, np.ndarray]): 63 | """Get faces from mask 64 | 65 | :param mask: Mask for faces 66 | :type mask: Union[jnp.ndarray, np.ndarray] 67 | :return: Masked faces 68 | :rtype: Faces 69 | """ 70 | return Faces( 71 | self.wires.mask(mask), 72 | Topology.reorder_index(self.index[mask]), 73 | ) 74 | 75 | -------------------------------------------------------------------------------- /nebula/helpers/intersection.py: -------------------------------------------------------------------------------- 1 | import jax_dataclasses as jdc 2 | import jax 3 | import jax.numpy as jnp 4 | 5 | 6 | @jdc.pytree_dataclass 7 | class IntersectionResult: 8 | intersected_vertices: jnp.ndarray 9 | is_intersection: jnp.ndarray 10 | num_intersections: jnp.ndarray 11 | 12 | 13 | class Intersection: 14 | @staticmethod 15 | @jax.jit 16 | def
on_line(point: jnp.ndarray, line: jnp.ndarray): 17 | """ 18 | Determines if a point lies within the XY bounding box of a line segment (collinearity is not checked; z is ignored). 19 | 20 | :param point: point to check (3,) 21 | :type point: jnp.ndarray 22 | :param line: start and endpoint of line, (2, 2) or (2, 3); only x and y are used 23 | :type line: jnp.ndarray 24 | :return: True if point is within the segment's XY bounds (inclusive), False otherwise 25 | :rtype: jnp.ndarray (boolean scalar) 26 | """ 27 | line_x1, line_y1, line_x2, line_y2 = line[:, :2].flatten() 28 | 29 | return jnp.all( 30 | jnp.array( 31 | [ 32 | jnp.minimum(line_x1, line_x2) <= point[0], 33 | jnp.minimum(line_y1, line_y2) <= point[1], 34 | jnp.maximum(line_x1, line_x2) >= point[0], 35 | jnp.maximum(line_y1, line_y2) >= point[1], 36 | ] 37 | ) 38 | ) 39 | 40 | @staticmethod 41 | @jax.jit 42 | def line_segment_intersection( 43 | line1: jnp.ndarray, line2: jnp.ndarray, decimal_accuracy=5  # decimals used to round the candidate point and both segments before the bounds checks; NOTE(review): not declared static for jax.jit, so passing it explicitly likely fails inside jnp.round -- confirm 44 | ): 45 | """ 46 | Finds the intersection of two line segments in the XY plane given the endpoints of each line (z is ignored; the returned point has z = 0). 47 | https://en.wikipedia.org/wiki/Line%E2%80%93line_intersection 48 | 49 | :param line1: start and endpoint of line 1, (2, 2) or (2, 3) 50 | :type line1: jnp.ndarray 51 | :param line2: start and endpoint of line 2, (2, 2) or (2, 3) 52 | :type line2: jnp.ndarray 53 | :return: result whose intersected_vertices is the (3,) intersection point, or all NaN if the segments are parallel or do not intersect 54 | :rtype: IntersectionResult 55 | """ 56 | x1, y1, x2, y2 = line1[:, :2].flatten() 57 | x3, y3, x4, y4 = line2[:, :2].flatten() 58 | 59 | nx = (x1 * y2 - y1 * x2) * (x3 - x4) - (x1 - x2) * (x3 * y4 - y3 * x4) 60 | ny = (x1 * y2 - y1 * x2) * (y3 - y4) - (y1 - y2) * (x3 * y4 - y3 * x4) 61 | denom = (x1 - x2) * (y3 - y4) - (y1 - y2) * (x3 - x4)  # zero when the two infinite lines are parallel 62 | 63 | unconstrained_intersection_vertices = jnp.array([nx / denom, ny / denom, 0])  # NOTE(review): divides before the denom != 0 check; parallel inputs give inf/nan here, which the jnp.where below masks out of the returned value 64 | rounded_unconstrained_intersection_vertices = jnp.round( 65 | unconstrained_intersection_vertices, decimal_accuracy 66 | ) 67 | 68 | intersection_on_line1 = Intersection.on_line( 69 | rounded_unconstrained_intersection_vertices, 70 | jnp.round(line1, decimal_accuracy), 71 | ) 72 | intersection_on_line2 = Intersection.on_line( 73 |
rounded_unconstrained_intersection_vertices, 74 | jnp.round(line2, decimal_accuracy), 75 | ) 76 | is_intersection = (denom != 0) & intersection_on_line1 & intersection_on_line2 77 | 78 | nan_array = jnp.full((3,), jnp.nan) 79 | 80 | intersection_vertices = jnp.where( 81 | is_intersection, unconstrained_intersection_vertices, nan_array 82 | ) 83 | return IntersectionResult( 84 | intersected_vertices=intersection_vertices, 85 | is_intersection=is_intersection, 86 | num_intersections=jnp.where(is_intersection, 1, 0), 87 | ) 88 | -------------------------------------------------------------------------------- /nebula/prim/sparse.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Protocol, Union 2 | import jax_dataclasses as jdc 3 | import jax.numpy as jnp 4 | import numpy as np 5 | import numpy.typing as npt 6 | 7 | 8 | class SparseIndexable(Protocol): 9 | index: jdc.Static[npt.NDArray[np.int32]] 10 | 11 | def mask(self, mask: Union[jnp.ndarray, np.ndarray]) -> "SparseIndexable": ... 
12 | 13 | @property 14 | def last_index(self): 15 | return np.max(self.index).astype(np.int32) if len(self.index) else 0 16 | 17 | @property 18 | def count(self): 19 | return SparseIndexable.get_count(self.index) 20 | 21 | def add_indices(self, indices: np.ndarray, reorder_index: bool = False): 22 | if reorder_index: 23 | indices = SparseIndexable.reorder_index(indices) 24 | 25 | new_indices = indices + self.count 26 | return np.concatenate([self.index, new_indices]) 27 | 28 | @staticmethod 29 | def get_count(index: np.ndarray): 30 | return np.max(index).astype(np.int32) + 1 if len(index) else 0 31 | 32 | @staticmethod 33 | def reorder_index(index: np.ndarray): 34 | return np.unique(index, return_inverse=True)[1] 35 | 36 | @staticmethod 37 | def expanded_index(index: np.ndarray, repeats: int): 38 | count = SparseIndexable.get_count(index) 39 | index_increment = (np.arange(0, repeats) * count).repeat(len(index)) 40 | return np.repeat(index[None, :], repeats, axis=0).flatten() + index_increment 41 | 42 | 43 | @jdc.pytree_dataclass 44 | class SparseArray(SparseIndexable): 45 | val: jnp.ndarray 46 | index: jdc.Static[npt.NDArray[np.int32]] 47 | 48 | @staticmethod 49 | def empty(shape=(0,)): 50 | return SparseArray(jnp.empty(shape), np.empty((0,), dtype=np.int32)) 51 | 52 | @staticmethod 53 | def from_array(array: jnp.ndarray): 54 | return SparseArray(array, np.full(len(array), 0, dtype=np.int32)) 55 | 56 | def add(self, item: "SparseArray"): 57 | val = jnp.concatenate([self.val, item.val]) 58 | index = self.add_indices(item.index) 59 | return SparseArray(val, index) 60 | 61 | def flip_winding(self, mask: Optional[jnp.ndarray] = None): 62 | if mask is None: 63 | return SparseArray(self.val[::-1], self.index[::-1]) 64 | 65 | new_index = self.index.copy() 66 | new_index[mask] = np.flip(new_index[mask], axis=0) 67 | 68 | return SparseArray(self.val, new_index) 69 | 70 | def mask(self, mask: Union[jnp.ndarray, np.ndarray]) -> "SparseArray": 71 | return SparseArray( 72 | 
val=self.val[mask], 73 | index=SparseIndexable.reorder_index(self.index[mask]), 74 | ) 75 | 76 | # TODO: make this faster 77 | def reorder(self, index: np.ndarray): 78 | """ 79 | Reorder the array according to the given index. 80 | 81 | :param index: The new order of the array. 82 | :type index: np.ndarray 83 | :return: The reordered array. 84 | :rtype: SparseArray 85 | """ 86 | 87 | new_val = jnp.empty((0, *self.val.shape[1:])) 88 | new_index = np.empty((0,), dtype=np.int32) 89 | for i in index: 90 | index_mask = self.index == i 91 | new_val = jnp.concatenate([new_val, self.val[index_mask]]) 92 | new_index = np.concatenate([new_index, self.index[index_mask]]) 93 | 94 | return SparseArray(new_val, new_index) 95 | 96 | def __repr__(self) -> str: 97 | return f"SparseArray(count={self.count})" 98 | -------------------------------------------------------------------------------- /nebula/prim/lines.py: -------------------------------------------------------------------------------- 1 | import jax_dataclasses as jdc 2 | import numpy as np 3 | import jax.numpy as jnp 4 | from typing import Optional, Union 5 | from nebula.prim.axes import Axes 6 | import numpy.typing as npt 7 | 8 | @jdc.pytree_dataclass 9 | class Lines: 10 | vertices: jnp.ndarray 11 | index: jdc.Static[npt.NDArray[np.int32]] 12 | 13 | def add( 14 | self, 15 | lines: "Lines", 16 | ): 17 | vertices = lines.vertices 18 | # Allocate new indices for the new edges ahead of the current indices 19 | new_line_indices = lines.index + len(self.vertices) 20 | 21 | vertices = jnp.concatenate([self.vertices, vertices], axis=0) 22 | index = np.concatenate([self.index, new_line_indices], axis=0) 23 | return Lines(vertices, index) 24 | 25 | @staticmethod 26 | def empty(): 27 | return Lines(vertices=jnp.empty((0, 3)), index=np.empty((0, 2), dtype=np.int32)) 28 | 29 | @staticmethod 30 | def from_segments(segments: jnp.ndarray): 31 | vertices = segments.reshape(-1, 3) 32 | index = np.arange(vertices.shape[0], 
dtype=np.int32).reshape(-1, 2) 33 | return Lines(vertices=vertices, index=index) 34 | 35 | def get_segments(self): 36 | return self.vertices[self.index] 37 | 38 | def clone( 39 | self, 40 | vertices: Optional[jnp.ndarray] = None, 41 | index: Optional[np.ndarray] = None, 42 | ): 43 | return Lines( 44 | vertices=vertices if vertices is not None else self.vertices, 45 | index=index if index is not None else self.index, 46 | ) 47 | 48 | def evaluate_at(self, u: jnp.ndarray, index: jnp.ndarray): 49 | vertices = self.vertices[self.index[index]] 50 | return (1 - u) * vertices[0] + u * vertices[1] 51 | 52 | def project(self, curr_axes: Axes, new_axes: Axes): 53 | local_coords = curr_axes.to_local_coords(self.vertices) 54 | 55 | new_vertices = new_axes.to_world_coords(local_coords).reshape(-1, 3) 56 | 57 | max_index = np.max(self.index) if len(self.index) else 0 58 | index_increment = (np.arange(0, new_axes.count) * max_index).repeat(len(self.index)) 59 | new_index = np.repeat(self.index[None, :], new_axes.count, axis=0).reshape( 60 | -1, 2 61 | ) + np.expand_dims(index_increment, axis=1) 62 | 63 | return Lines(vertices=new_vertices, index=new_index) 64 | 65 | 66 | def translate(self, translation: jnp.ndarray): 67 | if len(translation.shape) == 1: 68 | new_vertices = self.vertices + translation 69 | else: 70 | new_vertices = self.vertices.at[self.index].add( 71 | jnp.expand_dims(translation, axis=1) 72 | ) 73 | 74 | return self.clone(vertices=new_vertices) 75 | 76 | 77 | def mask(self, mask: Union[jnp.ndarray, np.ndarray]): 78 | return Lines.from_segments(self.vertices[self.index[mask]]) 79 | 80 | def flip_winding(self, mask: Optional[jnp.ndarray] = None): 81 | if mask is None: 82 | return Lines(vertices=self.vertices, index=np.flip(self.index, (0, 1))) 83 | new_index = self.index.copy() 84 | new_index[mask] = np.flip(new_index[mask], (0, 1)) 85 | return Lines(vertices=self.vertices, index=new_index) 86 | 87 | def reorder(self, index: np.ndarray): 88 | new_index = 
self.index[index] 89 | return self.clone(index=new_index) 90 | 91 | def __repr__(self) -> str: 92 | return f"Lines(count={len(self.index)})" 93 | 94 | @property 95 | def count(self) -> int: 96 | return len(self.index) 97 | -------------------------------------------------------------------------------- /nebula/helpers/wire.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | 4 | from nebula.helpers.types import CoordLike 5 | from nebula.prim.axes import Axes 6 | 7 | 8 | @jax.jit 9 | def is_inside(polygon_segment: jnp.ndarray, point: jnp.ndarray): 10 | x1, y1, x2, y2 = polygon_segment[:, :2].flatten() 11 | xp, yp = point[:2] 12 | 13 | return ((yp < y1) != (yp < y2)) & (xp < (x1 + ((yp - y1) / (y2 - y1)) * (x2 - x1))) 14 | 15 | 16 | class WireHelper: 17 | @staticmethod 18 | def contains_vertex( 19 | polygon_segments: jnp.ndarray, 20 | vertex: jnp.ndarray, 21 | ): 22 | """ 23 | Checks if polygon contains vertex. 24 | 25 | :param polygon_segments: polygon segments (num_segments, 2, 3) 26 | :type polygon_segments: jnp.ndarray 27 | :param vertex: vertex to check (3,) 28 | :type vertex: jnp.ndarray 29 | :return: True if polygon contains vertex 30 | :rtype: bool 31 | 32 | """ 33 | ray_intersection = jax.vmap(is_inside, in_axes=(0, None))( 34 | polygon_segments, vertex 35 | ).astype(jnp.int32) 36 | 37 | # if odd number of intersections, then segment is inside polygon 38 | total_intersections = jnp.sum(ray_intersection) 39 | return jnp.mod(total_intersections, 2) == 1 40 | 41 | @staticmethod 42 | def contains_segment( 43 | polygon_segments: jnp.ndarray, 44 | segment: jnp.ndarray, 45 | ): 46 | """ 47 | Checks if segment is inside polygon. 
48 | 49 | :param polygon_segments: polygon segments (num_segments, 2, 3) 50 | :type polygon_segments: jnp.ndarray 51 | :param segment: segment to check (2, 3) 52 | :type segment: jnp.ndarray 53 | :return: True if segment is inside polygon 54 | :rtype: bool 55 | 56 | """ 57 | mean_vertex = jnp.mean(segment, axis=0) 58 | return WireHelper.contains_vertex(polygon_segments, mean_vertex) 59 | 60 | @staticmethod 61 | def fully_contains_segment( 62 | polygon_segments: jnp.ndarray, 63 | segment: jnp.ndarray, 64 | ): 65 | """ 66 | Checks if segment is inside polygon. 67 | 68 | :param polygon_segments: polygon segments (num_segments, 2, 3) 69 | :type polygon_segments: jnp.ndarray 70 | :param segment: segment to check (2, 3) 71 | :type segment: jnp.ndarray 72 | :return: True if segment is inside polygon 73 | :rtype: bool 74 | 75 | """ 76 | return jax.vmap(WireHelper.contains_vertex, in_axes=(None, 0))( 77 | polygon_segments, segment 78 | ).all() 79 | 80 | @staticmethod 81 | def to_3d_vertices(vertices: CoordLike): 82 | vertices = jnp.asarray(vertices) 83 | if vertices.shape[-1] == 2: 84 | return jnp.concatenate([vertices, jnp.zeros((len(vertices), 1))], axis=-1) 85 | return vertices 86 | 87 | @staticmethod 88 | @jax.jit 89 | def is_clockwise(first_segment: jnp.ndarray, last_segment: jnp.ndarray): 90 | """ 91 | Finds if the segments are in clockwise order. 
IMPORTANT: This assumes that segments are on the XY plane 92 | 93 | """ 94 | normal = WireHelper.get_normal(first_segment, last_segment) 95 | return normal.sum() < 0 96 | 97 | @staticmethod 98 | def get_normal(first_segment: jnp.ndarray, last_segment: jnp.ndarray): 99 | start_vecs = first_segment[1] - first_segment[0] 100 | end_vecs = last_segment[1] - last_segment[0] 101 | normals = jnp.cross(end_vecs, start_vecs) 102 | return normals / jnp.linalg.norm(normals, axis=-1, keepdims=True) 103 | 104 | 105 | -------------------------------------------------------------------------------- /TODO: -------------------------------------------------------------------------------- 1 | Baseline MVP 2 | 3 | - Top Priority 4 | [X] Parafoil demo 5 | 6 | [ ] Tesselation 7 | - [ ] Add normals for bspline surface triangles (normals issue) 8 | - [X] L issue with triangulaion 9 | - [X] Correct Orientation 10 | 11 | - [ ] Sweep 12 | - [ ] Same Wire along path 13 | - [ ] Different Wire along path 14 | 15 | - [ ] Revolve 16 | - [ ] Add circle representation of spline 17 | - [ ] Sweep curve/line around that circle 18 | 19 | [ ] Tech Debt 20 | - Peformance 21 | [ ] Optimizations 22 | - [X] Use different dataclasses that work better with jit 23 | - [ ] New masking/padding approach 24 | - [ ] Better profiling for exactly what takes the most time 25 | - [ ] Project 26 | - [ ] Extrude 27 | - [ ] Clipper 28 | - [ ] Sort 29 | - [ ] Tesselation 30 | - [X] Reorder by index 31 | - [X] Improve/change unique operations 32 | - [X] Don't need to sort on all clip operations? 
33 | - [ ] Jax types 34 | - [ ] Unit tests 35 | - [ ] Auto evaluation point count 36 | 37 | 38 | # Next Release 39 | [ ] Querying (next release) 40 | - [ ] Query centers by X,Y,Z 41 | - [ ] Parallel 42 | - [ ] Most/Least (+/-) 43 | - [X] Query index 44 | - [ ] Display selected entities 45 | 46 | 47 | [ ] OCC Compatibility (next release) 48 | - [ ] BSpline Surfaces 49 | - [ ] Polygons 50 | - [ ] CQ detach 51 | 52 | [X] Add tesselation to work with Jupyter Cadquery 53 | - [X] BSpline surfaces 54 | - [X] Add normals for bspline surface triangles (normals issue) 55 | - [X] Planar Surfaces 56 | - [X] Fix ordering for clipping 57 | - [X] Fix normals 58 | - [X] BSpline Curves 59 | - [X] Fix non showing planars in bspline case 60 | - [X] Trimmed Planar Surfaces 61 | - [X] Solve bug with specific case 62 | - [ ] Trimmed BSpline Surfaces (next release) 63 | 64 | 65 | [ ] Polar Array (next release) 66 | 67 | [ ] Add Edge Type 68 | - [X] Add Lines 69 | - [X] Add Bsplines 70 | - [ ] Efficient Bsplines (next release) 71 | 72 | [ ] Projection 73 | - [ ] Plane (next release) 74 | - [ ] Bspline Surface (next release) 75 | 76 | [ ] 2D Clip 77 | - [X] Line Line Intersection 78 | - [X] 2D Clip 79 | [X] 1x1 Intersections 80 | [ ] Curve intersections (next release) 81 | [ ] nx2 intersections (next release) 82 | [ ] Split into two wires (next release) 83 | 84 | [X] Auto Diff working for bspline example 85 | - [X] BSpline surface -> jacobian 86 | - [X] Plane -> BSpline surface -> jacobian 87 | - [ ] Polygon (next release) 88 | 89 | [ ] Add trims/holes 90 | - [ ] BSpline Surfaces (next release) 91 | - [X] Planar Surfaces 92 | 93 | [ ] 3D Contains (next release) 94 | (tesselate and ray test) 95 | - [ ] Planes 96 | - [ ] BSpline 97 | 98 | [] 3D Clip (next release) 99 | - [] Line Plane Intersection 100 | - [] 3D Clip 101 | 102 | 103 | [ ] Query topology and get all face points 104 | - [X] Inefficient method 105 | - [X] Eventually scatter 106 | 107 | [ ] Try to plot 108 | - [X] Fix normals 109 
| - [X] Make runtime faster - might downgrade later 110 | 111 | Workplane 112 | - [X] Project to Axes 113 | - [X] Basic Workplane 114 | 115 | - [X] Extrude 116 | - [X] Plane 117 | - [X] BSpline 118 | 119 | 120 | # Tomorrow's next steps 121 | [X] Fix extrude indexes for normal cases 122 | - [X] Add quick extrude unit tests 123 | [X] Finally get tesselation to work properly 124 | [X] More unit tests for existing edge cases 125 | [ ] Research into auto diff for polygons and in Jax in general 126 | 127 | [X] Start Bsplines 128 | 129 | [ ] Later: Account for multi line splits (not important right now) 130 | [ ] Later: Fix jax jit for interior intersection cases -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

Nebula

2 | 3 |

open source differentiable geometry engine

4 | 5 | 6 | # About 7 | Nebula was an attempt to create simple geometries that are auto-differentiable with respect to their design variables for design optimization and machine learning tasks. 8 | 9 | In the past, the Jacobians of geometry surface meshes with respect to the design variables were computed via finite differences: each design variable was perturbed by a small epsilon and the position differences of the corresponding meshes were taken. This process could be very slow depending on the geometry, and it was restricted to structured meshes because triangulation is highly non-deterministic under small changes. 10 | 11 | For performance reasons Nebula uses a single-to-many structure (i.e. a Solids object holding vertices and indexes rather than a list of Solid objects). This allows multiple items to be constructed simultaneously. 12 | 13 | The Workplane class of Nebula is modeled after CadQuery syntax for familiarity; although this has not been tested, some CadQuery scripts should migrate with few changes. 14 | 15 | Nebula still has some performance issues, mainly because the JAX JIT is very slow for operations that aren't uniformly sized. Unfortunately, operations like clipping and the number of topology items are not uniform in nature, so the only way to run fast would be to pad to a uniform size. This was not implemented mainly because of time constraints and because the main goal of this project was to obtain CAD sensitivities. 16 | 17 | Nebula is missing many features and most likely has a lot of bugs, so this project is for experimental use only. We will add more features as we need them, since we currently don't have the resources to invest our time in features that don't help our current roadmap.
18 | 19 | 20 | # Install 21 | ``` 22 | pip install git+https://github.com/OpenOrion/nebula.git#egg=nebula 23 | ``` 24 | 25 | 26 | 27 | ## Example 28 | See more examples in the [examples](/examples) directory. 29 | 30 | ```python 31 | import jax 32 | import jax.numpy as jnp 33 | from nebula import Workplane 34 | from nebula.render.visualization import show 35 | from nebula.render.visualization import Tesselator 36 | 37 | def make_mesh(height: jnp.ndarray, arc_height: jnp.ndarray): 38 | profile = ( 39 | Workplane.init("XY") 40 | .lineTo(0, height) 41 | .bspline([ 42 | (0.5, height+arc_height), 43 | (1.5, height+arc_height), 44 | (2, height) 45 | ], includeCurrent=True) 46 | .lineTo(2, 0) 47 | .close() 48 | .extrude(2) 49 | ) 50 | return Tesselator.get_differentiable_mesh(profile) 51 | 52 | # design variables 53 | height = jnp.array(2.0) 54 | arc_height = jnp.array(0.5) 55 | 56 | # the jacobian with respect to height and arc height i.e (0,1) 57 | mesh_jacobian = jax.jacobian( 58 | lambda height, arc_height: make_mesh(height, arc_height).vertices, 59 | (0,1) 60 | )(height, arc_height) 61 | arc_height_jac_magnitude = jnp.linalg.norm(mesh_jacobian[1], axis=-1) 62 | 63 | # construct regular surface mesh 64 | mesh = make_mesh(height, arc_height) 65 | 66 | # show the mesh with magnitude of jacobian as color 67 | show(mesh, "plot", arc_height_jac_magnitude, name="CAD Sensitivity w.r.t arc height
Nebula by Open Orion") 68 | 69 | ``` 70 | 71 | ![./assets/example1_arc_height.png](./assets/example1_arc_height.png) 72 | 73 | 74 | ## Current features 75 | 76 | * Basic Sketch 77 | - rect 78 | - lineTo 79 | - bsplineTo 80 | - polarLine 81 | - polyline 82 | * Extrude 83 | * Polyline 84 | * Bspline Operations 85 | - Skin 86 | - Evaluate 87 | 88 | 89 | ## Deficits 90 | All the project roadmap features are in the TODO file. The main things to keep in mind are: 91 | * Clip operations 92 | - Can only have 1-2 intersections; this keeps array sizes uniform for faster operations. We may add a workaround later if needed. 93 | 94 | 95 | # Development Setup 96 | ``` 97 | git clone https://github.com/OpenOrion/nebula.git 98 | cd nebula 99 | pip install -r requirements_dev.txt 100 | ``` 101 | 102 | # Help Wanted 103 | Please join the [Discord](https://discord.gg/H7qRauGkQ6) for project communications and collaboration. 104 | 105 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.cgns 2 | *.su2 3 | *.geo_unrolled 4 | *.msh 5 | su2/ 6 | cfd/ 7 | generated/ 8 | simulations/ 9 | paraview 10 | docs 11 | bin/CQ-editor* 12 | bin/CQ-editor/ 13 | *.curaprofile 14 | ./cad 15 | *.step 16 | *.geo 17 | *.stl 18 | *.out 19 | *.mod 20 | *.pkl 21 | 22 | # Byte-compiled / optimized / DLL files 23 | __pycache__/ 24 | *.py[cod] 25 | *$py.class 26 | 27 | # C extensions 28 | *.so 29 | 30 | # Distribution / packaging 31 | .Python 32 | build/ 33 | develop-eggs/ 34 | dist/ 35 | downloads/ 36 | eggs/ 37 | .eggs/ 38 | lib/ 39 | lib64/ 40 | parts/ 41 | sdist/ 42 | var/ 43 | wheels/ 44 | share/python-wheels/ 45 | *.egg-info/ 46 | .installed.cfg 47 | *.egg 48 | MANIFEST 49 | 50 | # PyInstaller 51 | # Usually these files are written by a python script from a template 52 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
53 | *.manifest 54 | *.spec 55 | 56 | # Installer logs 57 | pip-log.txt 58 | pip-delete-this-directory.txt 59 | 60 | # Unit test / coverage reports 61 | htmlcov/ 62 | .tox/ 63 | .nox/ 64 | .coverage 65 | .coverage.* 66 | .cache 67 | nosetests.xml 68 | coverage.xml 69 | *.cover 70 | *.py,cover 71 | .hypothesis/ 72 | .pytest_cache/ 73 | cover/ 74 | 75 | # Translations 76 | *.mo 77 | *.pot 78 | 79 | # Django stuff: 80 | *.log 81 | local_settings.py 82 | db.sqlite3 83 | db.sqlite3-journal 84 | 85 | # Flask stuff: 86 | instance/ 87 | .webassets-cache 88 | 89 | # Scrapy stuff: 90 | .scrapy 91 | 92 | # Sphinx documentation 93 | docs/_build/ 94 | 95 | # PyBuilder 96 | .pybuilder/ 97 | target/ 98 | 99 | # Jupyter Notebook 100 | .ipynb_checkpoints 101 | 102 | # IPython 103 | profile_default/ 104 | ipython_config.py 105 | 106 | # pyenv 107 | # For a library or package, you might want to ignore these files since the code is 108 | # intended to run in multiple environments; otherwise, check them in: 109 | # .python-version 110 | 111 | # pipenv 112 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 113 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 114 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 115 | # install all needed dependencies. 116 | #Pipfile.lock 117 | 118 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 119 | __pypackages__/ 120 | 121 | # Celery stuff 122 | celerybeat-schedule 123 | celerybeat.pid 124 | 125 | # SageMath parsed files 126 | *.sage.py 127 | 128 | # Environments 129 | .env 130 | .venv 131 | env/ 132 | venv/ 133 | ENV/ 134 | env.bak/ 135 | venv.bak/ 136 | 137 | # Spyder project settings 138 | .spyderproject 139 | .spyproject 140 | 141 | # Rope project settings 142 | .ropeproject 143 | 144 | # mkdocs documentation 145 | /site 146 | 147 | # mypy 148 | .mypy_cache/ 149 | .dmypy.json 150 | dmypy.json 151 | 152 | # Pyre type checker 153 | .pyre/ 154 | 155 | # pytype static type analyzer 156 | .pytype/ 157 | 158 | # Cython debug symbols 159 | cython_debug/ 160 | 161 | # General 162 | .DS_Store 163 | .AppleDouble 164 | .LSOverride 165 | 166 | # Icon must end with two \r 167 | Icon 168 | 169 | 170 | # Thumbnails 171 | ._* 172 | 173 | # Files that might appear in the root of a volume 174 | .DocumentRevisions-V100 175 | .fseventsd 176 | .Spotlight-V100 177 | .TemporaryItems 178 | .Trashes 179 | .VolumeIcon.icns 180 | .com.apple.timemachine.donotpresent 181 | 182 | # Directories potentially created on remote AFP share 183 | .AppleDB 184 | .AppleDesktop 185 | Network Trash Folder 186 | Temporary Items 187 | .apdisk 188 | 189 | *~ 190 | 191 | # temporary files which can be created if a process still has a handle open of a deleted file 192 | .fuse_hidden* 193 | 194 | # KDE directory preferences 195 | .directory 196 | 197 | # Linux trash folder which might appear on any partition or disk 198 | .Trash-* 199 | 200 | # .nfs files are created when an open file is removed but is still being accessed 201 | .nfs* 202 | 203 | # Windows thumbnail cache files 204 | Thumbs.db 205 | Thumbs.db:encryptable 206 | ehthumbs.db 207 | ehthumbs_vista.db 208 | 209 | # Dump file 210 | *.stackdump 211 | 212 | # Folder config file 213 | [Dd]esktop.ini 214 | 215 | # Recycle Bin used on file shares 216 | $RECYCLE.BIN/ 217 | 218 | # Windows Installer files 219 | 
*.cab 220 | *.msi 221 | *.msix 222 | *.msm 223 | *.msp 224 | 225 | # Windows shortcuts 226 | *.lnk 227 | -------------------------------------------------------------------------------- /nebula/tools/edge.py: -------------------------------------------------------------------------------- 1 | import jax 2 | from typing import Optional 3 | import jax.numpy as jnp 4 | from nebula.evaluators.bspline import BSplineEvaluator 5 | from nebula.prim.bspline_curves import BSplineCurves 6 | from nebula.helpers.clipper import Clipper 7 | from nebula.helpers.types import ArrayLike, Number 8 | from nebula.helpers.wire import WireHelper 9 | from nebula.prim.sparse import SparseArray 10 | from nebula.topology.edges import Edges 11 | from nebula.topology.wires import Wires 12 | 13 | 14 | class EdgeTool: 15 | @staticmethod 16 | @jax.jit 17 | def make_rect( 18 | x_len: Number, y_len: Number, origins: Optional[jnp.ndarray] = None, centered: bool = True 19 | ): 20 | """Creates a rectangle. 21 | 22 | :param x_len: length of the rectangle along x-axis 23 | :type x_len: float 24 | :param y_len: length of the rectangle along y-axis 25 | :type y_len: float 26 | :param axes: axes to create the rectangle on, defaults to XY 27 | :type axes: Axes, optional 28 | :param center: flag to create the rectangle at the center 29 | :type center: bool 30 | :return: rectangle array of size (num_axis, 4, 2, 3) 31 | :rtype: jnp.ndarray 32 | """ 33 | if origins is None: 34 | origins = jnp.zeros((1,3)) 35 | 36 | 37 | vertices = jnp.array( 38 | [ 39 | (0.0, 0.0, 0.0), 40 | (x_len, 0.0, 0.0), 41 | (x_len, y_len, 0.0), 42 | (0.0, y_len, 0.0), 43 | (0.0, 0.0, 0.0), 44 | ] 45 | ) 46 | 47 | # center the rectangle if requested 48 | vertices = jnp.where( 49 | centered, vertices - jnp.array([x_len / 2, y_len / 2, 0.0]), vertices 50 | ) 51 | return EdgeTool.make_polyline(vertices + origins) 52 | 53 | @staticmethod 54 | @jax.jit 55 | def make_polyline(vertices: jnp.ndarray): 56 | segments = jnp.stack([vertices[:-1], 
vertices[1:]], axis=1) 57 | return Edges.from_line_segments(segments) 58 | 59 | @staticmethod 60 | @jax.jit 61 | def make_line(start_vertex: jnp.ndarray, end_vertex: jnp.ndarray): 62 | line_segments = jnp.stack([start_vertex, end_vertex]) 63 | return Edges.from_line_segments(jnp.expand_dims(line_segments, axis=0)) 64 | 65 | @staticmethod 66 | @jax.jit 67 | def make_polar_line(start_vertex: jnp.ndarray, distance: Number, angle: Number): 68 | end_vertex = jnp.array( 69 | [ 70 | jnp.cos(jnp.radians(angle)) * distance, 71 | jnp.sin(jnp.radians(angle)) * distance, 72 | ] 73 | ) 74 | 75 | return EdgeTool.make_line(start_vertex, end_vertex) 76 | 77 | @staticmethod 78 | def make_bspline_curve( 79 | ctrl_pnts: jnp.ndarray, 80 | degree: jnp.ndarray, 81 | knots: Optional[ArrayLike] = None, 82 | ): 83 | if knots is None: 84 | knots = BSplineEvaluator.generate_clamped_knots(degree, len(ctrl_pnts)) 85 | knots = jnp.asarray(knots) 86 | 87 | return BSplineCurves( 88 | degree=jnp.array([degree], dtype=jnp.int32), 89 | ctrl_pnts=SparseArray.from_array(ctrl_pnts), 90 | knots=SparseArray.from_array(knots), 91 | ) 92 | 93 | @staticmethod 94 | def consolidate_wires(wires: Wires, pending_edges: Edges, validate: bool): 95 | """ 96 | Consolidates the pending edges with the existing wires. 97 | 98 | :param wires: The existing wires 99 | :type wires: Wires 100 | :param pending_edges: The pending edges 101 | :type pending_edges: Edges 102 | :param validate: Flag to validate the pending edges for sort order before consolidation. No validation is more performant. 
103 | :type validate: bool 104 | 105 | """ 106 | if validate: 107 | pending_edge_result = pending_edges.get_sorted() 108 | pending_edges = pending_edge_result.edges 109 | is_clockwise = WireHelper.is_clockwise(pending_edge_result.segments[0], pending_edge_result.segments[-1]) 110 | if is_clockwise: 111 | pending_edges = pending_edges.flip_winding() 112 | 113 | clipper_wire = Wires.from_edges(pending_edges) 114 | if wires.count > 0: 115 | return Clipper.cut_polygons(wires, clipper_wire) 116 | else: 117 | return clipper_wire 118 | -------------------------------------------------------------------------------- /tests/test_bspline.py: -------------------------------------------------------------------------------- 1 | from nebula.evaluators.bspline import BSplineEvaluator 2 | import jax.numpy as jnp 3 | 4 | 5 | def test_bspline_curve(): 6 | u = jnp.linspace(0.0, 1.0, 10) 7 | actual = BSplineEvaluator.eval_curve( 8 | degree=jnp.array(2), 9 | ctrl_pnts=jnp.array([[1, 0, 0], [1, 1, 0], [0, 1, 0]]), 10 | u=u, 11 | knots=jnp.array([0, 0, 0, 1, 1, 1]), 12 | ) 13 | expected = jnp.array( 14 | [ 15 | [1.0, 0.0, 0.0], 16 | [0.9876543209876543, 0.20987654320987653, 0.0], 17 | [0.9506172839506173, 0.3950617283950617, 0.0], 18 | [0.8888888888888891, 0.5555555555555556, 0.0], 19 | [0.8024691358024691, 0.691358024691358, 0.0], 20 | [0.691358024691358, 0.8024691358024691, 0.0], 21 | [0.5555555555555556, 0.8888888888888888, 0.0], 22 | [0.3950617283950617, 0.9506172839506173, 0.0], 23 | [0.20987654320987661, 0.9876543209876544, 0.0], 24 | [0.0, 1.0, 0.0], 25 | ] 26 | ) 27 | 28 | assert jnp.allclose(actual, expected, atol=1e-3) 29 | 30 | 31 | def test_bspline_surface(): 32 | u = jnp.linspace(0.0, 1.0, 5) 33 | actual = BSplineEvaluator.eval_surface( 34 | u_degree=jnp.array(3), 35 | v_degree=jnp.array(3), 36 | ctrl_pnts=jnp.array( 37 | [ 38 | [-25.0, -25.0, -10.0], 39 | [-25.0, -15.0, -5.0], 40 | [-25.0, -5.0, 0.0], 41 | [-25.0, 5.0, 0.0], 42 | [-25.0, 15.0, -5.0], 43 | [-25.0, 25.0, 
-10.0], 44 | [-15.0, -25.0, -8.0], 45 | [-15.0, -15.0, -4.0], 46 | [-15.0, -5.0, -4.0], 47 | [-15.0, 5.0, -4.0], 48 | [-15.0, 15.0, -4.0], 49 | [-15.0, 25.0, -8.0], 50 | [-5.0, -25.0, -5.0], 51 | [-5.0, -15.0, -3.0], 52 | [-5.0, -5.0, -8.0], 53 | [-5.0, 5.0, -8.0], 54 | [-5.0, 15.0, -3.0], 55 | [-5.0, 25.0, -5.0], 56 | [5.0, -25.0, -3.0], 57 | [5.0, -15.0, -2.0], 58 | [5.0, -5.0, -8.0], 59 | [5.0, 5.0, -8.0], 60 | [5.0, 15.0, -2.0], 61 | [5.0, 25.0, -3.0], 62 | [15.0, -25.0, -8.0], 63 | [15.0, -15.0, -4.0], 64 | [15.0, -5.0, -4.0], 65 | [15.0, 5.0, -4.0], 66 | [15.0, 15.0, -4.0], 67 | [15.0, 25.0, -8.0], 68 | [25.0, -25.0, -10.0], 69 | [25.0, -15.0, -5.0], 70 | [25.0, -5.0, 2.0], 71 | [25.0, 5.0, 2.0], 72 | [25.0, 15.0, -5.0], 73 | [25.0, 25.0, -10.0], 74 | ] 75 | ).reshape(6, 6, 3), 76 | u_knots=jnp.array([0.0, 0.0, 0.0, 0.0, 0.33, 0.66, 1.0, 1.0, 1.0, 1.0]), 77 | v_knots=jnp.array([0.0, 0.0, 0.0, 0.0, 0.33, 0.66, 1.0, 1.0, 1.0, 1.0]), 78 | u=u, 79 | v=u, 80 | ) 81 | expected = jnp.array( 82 | [ 83 | [ 84 | [-25.0, -25.0, -10.0], 85 | [-25.000004, -9.077171, -2.397286], 86 | [-25.000002, 0.11256424, -0.3082978], 87 | [-25.000002, 9.309787, -2.4978478], 88 | [-25.0, 25.0, -10.0], 89 | ], 90 | [ 91 | [-9.077171, -25.000004, -6.2806444], 92 | [-9.077172, -9.077172, -4.8789206], 93 | [-9.077172, 0.11256425, -5.9172673], 94 | [-9.077171, 9.309788, -4.8489876], 95 | [-9.077171, 25.000004, -6.2806444], 96 | ], 97 | [ 98 | [0.11256424, -25.000002, -4.2381387], 99 | [0.1125643, -9.077172, -5.3740025], 100 | [0.1125643, 0.11256424, -7.434884], 101 | [0.11256424, 9.309788, -5.299428], 102 | [0.11256424, 25.000002, -4.2381387], 103 | ], 104 | [ 105 | [9.309787, -25.000002, -5.5793867], 106 | [9.309788, -9.077171, -4.6443815], 107 | [9.309788, 0.11256424, -5.7848625], 108 | [9.309787, 9.309787, -4.609164], 109 | [9.309787, 25.000002, -5.5793867], 110 | ], 111 | [ 112 | [25.0, -25.0, -10.0], 113 | [25.000004, -9.077171, -1.3277057], 114 | [25.000002, 0.11256424, 1.5683833], 115 
| [25.000002, 9.309787, -1.4598912], 116 | [25.0, 25.0, -10.0], 117 | ], 118 | ], 119 | dtype=jnp.float32, 120 | ) 121 | 122 | assert jnp.allclose(actual, expected, atol=1e-3) 123 | -------------------------------------------------------------------------------- /nebula/tools/solid.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import numpy as np 4 | from nebula.helpers.types import Number 5 | from nebula.prim.axes import Axes 6 | from nebula.topology.faces import Faces 7 | from nebula.topology.shells import Shells 8 | from nebula.topology.solids import Solids 9 | from nebula.topology.wires import Wires 10 | 11 | 12 | class SolidTool: 13 | @staticmethod 14 | # @jax.jit 15 | def extrude_line_segments(segments: jnp.ndarray, translation: Number): 16 | """ 17 | Returns the segments for the extruded edge points 18 | 19 | :param segment: edge segments to extrude 20 | :type segment: jnp.ndarray 21 | :param translation: translation to extrude by 22 | :type translation: Number 23 | :return: extruded segments (num_edges, 2, 3) 24 | 25 | 26 | """ 27 | bottom_edge = segments 28 | top_edge = bottom_edge[::-1] + translation 29 | right_edge = jnp.stack([bottom_edge[-1], top_edge[0]]) 30 | left_edge = jnp.stack([top_edge[-1], bottom_edge[0]]) 31 | 32 | return jnp.stack([bottom_edge, right_edge, top_edge, left_edge]) 33 | 34 | @staticmethod 35 | @jax.jit 36 | def extruded_planar_faces( 37 | faces: Faces, translation: Number, is_interior: jnp.ndarray 38 | ): 39 | # Handle planar extrusion faces 40 | line_segments = faces.wires.edges.lines.get_segments() 41 | plane_segments = jnp.where( 42 | is_interior[:, None, None], 43 | jnp.flip(line_segments, (0, 1)), 44 | line_segments, 45 | ) 46 | 47 | extruded_plane_segments = jax.vmap( 48 | SolidTool.extrude_line_segments, in_axes=(0, None) 49 | )( 50 | plane_segments, translation 51 | ) # (num_edges, 4, 2, 3) 52 | 53 | return Faces.from_wires( 54 | 
Wires.from_line_segments(extruded_plane_segments), 55 | is_exterior=True, 56 | ) 57 | 58 | # TODO: this is a work in progress. 59 | # @staticmethod 60 | # def sweep(wires: Wires, path: Wires): 61 | # assert path.is_single, "Only one path is supported for sweep" 62 | # faces = Faces.from_wires(wires) 63 | # start_face = faces.flip_winding() 64 | 65 | 66 | # sample = jnp.linspace(0.1, 1.0, 10) 67 | 68 | # new_axes = Axes( 69 | # normals=path.dir_vec_at(sample), 70 | # origins=path.evaluate_at(sample), 71 | # ) 72 | 73 | # end_face = faces.project(new_axes) 74 | # end_face.wires.plot() 75 | 76 | 77 | # @staticmethod 78 | # def revolve(wires: Wires, angle: Number, axis: Axes): 79 | # return SolidTool.sweep(wires, Wires.make_circle(1.0, angle, axis)) 80 | 81 | @staticmethod 82 | def extrude(wires: Wires, distance: Number): 83 | """Extrudes a wire by a distance on axis. 84 | 85 | E-------F 86 | /| /| 87 | / | / | 88 | A--|----B | 89 | | H----|--G 90 | | / ^ / 91 | |/ |/ 92 | C--->---D 93 | 94 | :param wire: Wires to extrude 95 | :type wire: Wire 96 | :param distance: distance to extrude 97 | :type distance: float 98 | :return: new extruded Solids 99 | :rtype: Solids 100 | """ 101 | 102 | faces = Faces.from_wires(wires) 103 | 104 | faces_normals = faces.get_axis_normals() 105 | translation = faces_normals * distance 106 | 107 | start_face = faces.flip_winding() 108 | end_face = faces.translate(translation) 109 | 110 | # Handle planar extrusion faces 111 | line_segments = faces.wires.edges.lines.get_segments() 112 | planar_is_interior = faces.wires.is_interior[~faces.wires.edges.is_curve] 113 | plane_segments = line_segments.at[planar_is_interior].set( 114 | jnp.flip(line_segments[planar_is_interior], (0, 1)) 115 | ) 116 | extruded_plane_segments = jax.jit( 117 | jax.vmap(SolidTool.extrude_line_segments, in_axes=(0, None)) 118 | )(plane_segments, translation) 119 | 120 | exterior_extruded_plane_faces = Faces.from_wires( 121 | 
Wires.from_line_segments(extruded_plane_segments[~planar_is_interior]), 122 | is_exterior=True, 123 | ) 124 | interior_extruded_plane_faces = Faces.from_wires( 125 | Wires.from_line_segments(extruded_plane_segments[planar_is_interior]), 126 | is_exterior=True, 127 | ) 128 | 129 | # Handle non-planar extrusion faces 130 | start_curves = faces.wires.edges.curves 131 | end_curves = end_face.wires.edges.curves 132 | bspline_surface_faces = Faces.from_wires(Wires.skin((start_curves, end_curves))) 133 | 134 | exterior_faces = Faces.combine( 135 | (start_face, end_face, exterior_extruded_plane_faces, bspline_surface_faces) 136 | ) 137 | start_face_exterior_index = start_face.index[~wires.is_interior][0] 138 | exterior_shell_index = ( 139 | np.ones(len(exterior_faces.index), dtype=np.int32) * start_face_exterior_index 140 | ) 141 | interior_shell_index = np.repeat(wires.index[wires.is_interior], 4) 142 | shell_index = np.concatenate([exterior_shell_index, interior_shell_index]) 143 | 144 | new_faces = Faces.combine((exterior_faces, interior_extruded_plane_faces)) 145 | shells = Shells(faces=new_faces, index=shell_index) 146 | 147 | return Solids.from_shells(shells) 148 | -------------------------------------------------------------------------------- /nebula/prim/axes.py: -------------------------------------------------------------------------------- 1 | from dataclasses import field 2 | from functools import cached_property 3 | from typing import Literal 4 | import jax_dataclasses as jdc 5 | import jax 6 | import jax.numpy as jnp 7 | 8 | 9 | @jax.jit 10 | def get_rotation_matrix(a: jnp.ndarray, b: jnp.ndarray): 11 | """Returns a rotation matrix that rotates vector a to vector b. 
12 | 13 | :param a: vector a 14 | :type a: jnp.ndarray 15 | :param b: vector b 16 | :type b: jnp.ndarray 17 | :return: rotation matrix 18 | :rtype: jnp.ndarray 19 | """ 20 | # Tolerance for floating point errors 21 | eps = 1.0e-10 22 | 23 | # Normalize the vectors 24 | a = a / jnp.linalg.norm(a) 25 | b = b / jnp.linalg.norm(b) 26 | 27 | # dimension of the space and identity 28 | I = jnp.identity(a.size) 29 | 30 | # Get the cross product of the two vectors 31 | v = jnp.cross(a, b) 32 | 33 | # Get the dot product of the two vectors 34 | c = jnp.dot(a, b=b) 35 | 36 | # the cross product matrix of a vector to rotate around 37 | K = jnp.outer(b, a) - jnp.outer(a, b) 38 | 39 | return jnp.where( 40 | # same direction 41 | jnp.abs(c - 1.0) < eps, 42 | I, 43 | jnp.where( 44 | # opposite direction 45 | jnp.abs(c + 1.0) < eps, 46 | -I, 47 | # Rodrigues' formula 48 | I + K + (K @ K) / (1 + c), 49 | ), 50 | ) 51 | 52 | 53 | AxesString = Literal["XY", "YX", "XZ", "ZX", "YZ", "ZY"] 54 | 55 | 56 | 57 | 58 | @jdc.pytree_dataclass 59 | class Axes: 60 | normals: jnp.ndarray 61 | origins: jnp.ndarray = field(default_factory=lambda: jnp.array([[0.0, 0.0, 0.0]])) 62 | 63 | @cached_property 64 | def local_rotation_matrix(self): 65 | return jax.vmap(get_rotation_matrix, (0, None))(self.normals, XY.normals[0]) 66 | 67 | @cached_property 68 | def world_rotation_matrix(self): 69 | return jax.vmap(get_rotation_matrix, (None, 0))(XY.normals[0], self.normals) 70 | 71 | @property 72 | def count(self): 73 | return self.normals.shape[0] 74 | 75 | @property 76 | def local_origins(self): 77 | return self.to_local_coords(self.origins) 78 | 79 | @property 80 | def shape(self): 81 | return self.origins.shape 82 | 83 | def __add__(self, translation: jnp.ndarray): 84 | return Axes(origins=self.origins + translation, normals=self.normals) 85 | 86 | def to_local_coords(self, world_coords: jnp.ndarray): 87 | """Returns a rotation matrix that rotates the plane to the xy-plane. 
88 | 89 | :param origin: origin of the plane 90 | :type origin: djnp.ndarray 91 | :param normal: normal of the plane 92 | :type normal: jnp.ndarray 93 | :param world_coords: world coordinates to rotate 94 | :type world_coords: jnp.ndarray 95 | :return: rotated coordinates 96 | :rtype: jnp.ndarray 97 | 98 | """ 99 | # Translation vector 100 | translation_vector = jnp.expand_dims(XY.origins[0] - self.origins, axis=1) 101 | 102 | # Rotate the points 103 | return jax.jit(jax.vmap(jnp.matmul, (None, 0)))( 104 | world_coords + translation_vector, self.local_rotation_matrix 105 | ) 106 | 107 | def to_world_coords(self, local_coords: jnp.ndarray): 108 | """Returns a rotation matrix that rotates the plane to the xy-plane. 109 | 110 | :param origin: origin of the plane 111 | :type origin: jnp.ndarray 112 | :param normal: normal of the plane 113 | :type normal: jnp.ndarray 114 | :param local_coords: local coordinates to rotate 115 | :type local_coords: jnp.ndarray 116 | :return: rotated coordinates 117 | :rtype: jnp.ndarray 118 | """ 119 | if local_coords.shape[-1] == 2: 120 | local_coords = jnp.concatenate( 121 | [local_coords, jnp.zeros((local_coords.shape[0], 1))], axis=-1 122 | ) 123 | 124 | # Translation vector 125 | translation_vector = jnp.expand_dims(self.origins - XY.origins[0], axis=1) 126 | 127 | # Rotate the points 128 | return ( 129 | jax.jit(jax.vmap(jnp.matmul, (None, 0)))( 130 | local_coords, self.world_rotation_matrix 131 | ) 132 | + translation_vector 133 | ) 134 | 135 | def __repr__(self) -> str: 136 | return f"Axes(origin={self.origins}, normal={self.normals})" 137 | 138 | @staticmethod 139 | def from_str(str: AxesString): 140 | if str == "XY": 141 | return Axes( 142 | origins=jnp.array([[0.0, 0.0, 0.0]]), normals=jnp.array([[0.0, 0.0, 1.0]]) 143 | ) 144 | elif str == "YX": 145 | return Axes( 146 | origins=jnp.array([[0.0, 0.0, 0.0]]), 147 | normals=jnp.array([[0.0, 0.0, -1.0]]), 148 | ) 149 | elif str == "XZ": 150 | return Axes( 151 | 
origins=jnp.array([[0.0, 0.0, 0.0]]), normals=jnp.array([[0.0, 1.0, 0.0]]) 152 | ) 153 | elif str == "ZX": 154 | return Axes( 155 | origins=jnp.array([[0.0, 0.0, 0.0]]), 156 | normals=jnp.array([[0.0, -1.0, 0.0]]), 157 | ) 158 | elif str == "YZ": 159 | return Axes( 160 | origins=jnp.array([[0.0, 0.0, 0.0]]), normals=jnp.array([[1.0, 0.0, 0.0]]) 161 | ) 162 | elif str == "ZY": 163 | return Axes( 164 | origins=jnp.array([[0.0, 0.0, 0.0]]), 165 | normals=jnp.array([[-1.0, 0.0, 0.0]]), 166 | ) 167 | 168 | 169 | XY = Axes( 170 | origins=jnp.array([[0.0, 0.0, 0.0]]), normals=jnp.array(object=[[0.0, 0.0, 1.0]]) 171 | ) 172 | -------------------------------------------------------------------------------- /nebula/prim/bspline_curves.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax_dataclasses as jdc 3 | from nebula.evaluators.bspline import BSplineEvaluator 4 | from nebula.prim.axes import Axes 5 | from nebula.prim import SparseArray, SparseIndexable 6 | from typing import Optional, Sequence, Union 7 | import jax.numpy as jnp 8 | import numpy as np 9 | 10 | 11 | @jdc.pytree_dataclass 12 | class BSplineCurves: 13 | degree: jnp.ndarray 14 | ctrl_pnts: SparseArray 15 | knots: SparseArray 16 | 17 | @staticmethod 18 | def empty(): 19 | return BSplineCurves( 20 | degree=jnp.empty((0,), dtype=np.int32), 21 | ctrl_pnts=SparseArray.empty((0, 3)), 22 | knots=SparseArray.empty(), 23 | ) 24 | 25 | 26 | # TODO: make this more efficient 27 | def evaluate(self, u: Optional[jnp.ndarray] = None): 28 | """ 29 | Evaluate the curves at the given u values. 30 | 31 | :param u: The u values to evaluate the curves at. If None then gives sements of start and end ctrl pnts 32 | :type u: jnp.ndarray 33 | :returns: The vertices of the curves. 
34 | :rtype: jnp.ndarray 35 | 36 | """ 37 | if not self.ctrl_pnts.count: 38 | return jnp.empty((0, 2, 3)) 39 | curve_segments = [] 40 | for i in range(self.ctrl_pnts.count): 41 | curve = self.mask(i) 42 | if u is not None: 43 | vertices = BSplineEvaluator.eval_curve( 44 | curve.degree, curve.ctrl_pnts.val, u, curve.knots.val 45 | ) 46 | 47 | segments = jnp.stack([vertices[:-1], vertices[1:]], axis=1) 48 | else: 49 | segments = jnp.array([[curve.ctrl_pnts.val[0], curve.ctrl_pnts.val[-1]]]) 50 | curve_segments.append(segments) 51 | return jnp.concatenate(curve_segments, axis=0) 52 | 53 | def evaluate_at(self, u: jnp.ndarray, index: jnp.ndarray): 54 | curve = self.mask(index) 55 | return BSplineEvaluator.eval_curve( 56 | curve.degree, curve.ctrl_pnts.val, u, curve.knots.val 57 | ) 58 | 59 | # TODO: Add and test this later when converting planes to bspline surfaces 60 | @staticmethod 61 | @jax.jit 62 | def from_line_segment(segment: jnp.ndarray): 63 | # other segments are the connectors between top and bottom segments in quad 64 | ctrl_pnts = segment.reshape(-1, 3) 65 | line_knots = BSplineEvaluator.generate_line_knots() 66 | return BSplineCurves( 67 | jnp.array([1]), 68 | SparseArray.from_array(ctrl_pnts), 69 | SparseArray.from_array(line_knots), 70 | ) 71 | 72 | def get_segments(self): 73 | last_vertices = ( 74 | jnp.empty((self.count, 3)).at[self.ctrl_pnts.index].set(self.ctrl_pnts.val) 75 | ) 76 | 77 | first_vertices = ( 78 | jnp.empty((self.count, 3)) 79 | .at[self.ctrl_pnts.index[::-1]] 80 | .set(self.ctrl_pnts.val[::-1]) 81 | ) 82 | 83 | return jnp.stack([first_vertices, last_vertices], axis=1) 84 | 85 | 86 | def add( 87 | self, 88 | curves: "BSplineCurves", 89 | ): 90 | return BSplineCurves( 91 | degree=jnp.concatenate([self.degree, curves.degree]), 92 | ctrl_pnts=self.ctrl_pnts.add(curves.ctrl_pnts), 93 | knots=self.knots.add(curves.knots), 94 | ) 95 | 96 | 97 | @classmethod 98 | def combine(cls, all_curves: Sequence["BSplineCurves"]): 99 | """Combine curves 
into single 100 | 101 | :param curves: Topology to combine 102 | :type curves: Topology 103 | :return: Combined curves 104 | :rtype: Topology 105 | """ 106 | new_curves = cls(None) # type: ignore 107 | for curves in all_curves: 108 | new_curves.add(curves) 109 | return new_curves 110 | 111 | def clone( 112 | self, 113 | degree: Optional[jnp.ndarray] = None, 114 | ctrl_pnts: Optional[SparseArray] = None, 115 | knots: Optional[SparseArray] = None, 116 | ): 117 | return BSplineCurves( 118 | degree=degree if degree is not None else self.degree, 119 | ctrl_pnts=( 120 | SparseArray(ctrl_pnts.val, ctrl_pnts.index) 121 | if ctrl_pnts is not None 122 | else self.ctrl_pnts 123 | ), 124 | knots=( 125 | SparseArray(knots.val, knots.index) if knots is not None else self.knots 126 | ), 127 | ) 128 | 129 | # TODO: this is assuming it is already on the XY plane 130 | def project(self, curr_axes: Axes, new_axes: Axes): 131 | local_coords = curr_axes.to_local_coords(self.ctrl_pnts.val) 132 | return BSplineCurves( 133 | degree=jnp.repeat(self.degree[None, :], new_axes.count, axis=0).flatten(), 134 | ctrl_pnts=SparseArray( 135 | val=new_axes.to_world_coords(local_coords).reshape(-1, 3), 136 | index=SparseIndexable.expanded_index(self.ctrl_pnts.index, new_axes.count), 137 | ), 138 | knots=SparseArray( 139 | val=jnp.repeat(self.knots.val[None, :], new_axes.count, axis=0).flatten(), 140 | index=SparseIndexable.expanded_index(self.knots.index, new_axes.count), 141 | ), 142 | ) 143 | 144 | def translate(self, translation: jnp.ndarray): 145 | return self.clone( 146 | ctrl_pnts=SparseArray(self.ctrl_pnts.val + translation, index=self.ctrl_pnts.index), 147 | ) 148 | 149 | def mask(self, mask: Union[jnp.ndarray, np.ndarray, int]): 150 | index = np.arange(self.count)[mask] 151 | ctrl_pnts_mask = jnp.isin(self.ctrl_pnts.index, index) 152 | knots_mask = jnp.isin(self.knots.index, index) 153 | return BSplineCurves( 154 | degree=self.degree[mask], 155 | 
ctrl_pnts=self.ctrl_pnts.mask(ctrl_pnts_mask), 156 | knots=self.knots.mask(knots_mask), 157 | ) 158 | 159 | def reorder(self, index: np.ndarray): 160 | return BSplineCurves( 161 | degree=self.degree[index], 162 | ctrl_pnts=self.ctrl_pnts.reorder(index), 163 | knots=self.knots.reorder(index), 164 | ) 165 | 166 | def flip_winding(self, mask: Optional[jnp.ndarray] = None): 167 | return BSplineCurves( 168 | degree=self.degree[::-1], 169 | ctrl_pnts=self.ctrl_pnts.flip_winding(mask), 170 | knots=self.knots.flip_winding(mask), 171 | ) 172 | 173 | def __repr__(self) -> str: 174 | return f"BSplineCurves(count={self.count})" 175 | 176 | @property 177 | def count(self): 178 | return self.ctrl_pnts.count 179 | -------------------------------------------------------------------------------- /nebula/render/visualization.py: -------------------------------------------------------------------------------- 1 | import zmq 2 | import pickle 3 | from http.client import REQUEST_TIMEOUT 4 | import jax.numpy as jnp 5 | from typing import Literal, Optional, Union 6 | from dataclasses import dataclass, asdict 7 | from nebula.helpers.vector import VectorHelper 8 | from nebula.render.tesselation import Mesh, Tesselator 9 | from nebula.topology.solids import Solids 10 | 11 | 12 | @dataclass 13 | class BoundingBox: 14 | xmin: float 15 | xmax: float 16 | ymin: float 17 | ymax: float 18 | zmin: float 19 | zmax: float 20 | 21 | 22 | @dataclass 23 | class PartShape: 24 | vertices: list[list[float]] 25 | triangles: list[int] 26 | normals: list[list[float]] 27 | edges: list[list[list[float]]] 28 | 29 | 30 | @dataclass 31 | class Part: 32 | name: str 33 | id: str 34 | shape: PartShape 35 | type: str = "shapes" 36 | color: str = "#e8b024" 37 | renderback: bool = True 38 | 39 | 40 | @dataclass 41 | class Shapes: 42 | bb: Optional[BoundingBox] 43 | parts: list[Part] 44 | name: str = "Group" 45 | id: str = "/Group" 46 | loc: Optional[tuple[list[float], tuple[float, float, float, float]]] = None 47 | 48 
| 49 | def get_part(mesh: Mesh, index: int = 0): 50 | return Part( 51 | name=f"Part_{index}", 52 | id=f"/Group/Part_{index}", 53 | shape=PartShape( 54 | vertices=mesh.vertices.tolist(), 55 | triangles=mesh.simplices.flatten().tolist(), 56 | normals=mesh.normals.tolist(), 57 | edges=mesh.edges.tolist(), 58 | ), 59 | ) 60 | 61 | 62 | def get_bounding_box(vertices: jnp.ndarray): 63 | return BoundingBox( 64 | xmin=float(vertices[:, 0].min()), 65 | xmax=float(vertices[:, 0].max()), 66 | ymin=float(vertices[:, 1].min()), 67 | ymax=float(vertices[:, 1].max()), 68 | zmin=float(vertices[:, 2].min()) or -1e-07, 69 | zmax=float(vertices[:, 2].max()) or 1e-07, 70 | ) 71 | 72 | 73 | DEFAULT_CONFIG = { 74 | "viewer": "", 75 | "anchor": "right", 76 | "theme": "light", 77 | "pinning": False, 78 | "angular_tolerance": 0.2, 79 | "deviation": 0.1, 80 | "edge_accuracy": None, 81 | "default_color": [232, 176, 36], 82 | "default_edge_color": "#707070", 83 | "optimal_bb": False, 84 | "render_normals": False, 85 | "render_edges": True, 86 | "render_mates": False, 87 | "parallel": False, 88 | "mate_scale": 1, 89 | "control": "trackball", 90 | "up": "Z", 91 | "axes": False, 92 | "axes0": False, 93 | "grid": [False, False, False], 94 | "ticks": 10, 95 | "ortho": True, 96 | "transparent": False, 97 | "black_edges": False, 98 | "ambient_intensity": 0.75, 99 | "direct_intensity": 0.15, 100 | "reset_camera": True, 101 | "show_parent": True, 102 | "show_bbox": False, 103 | "quaternion": None, 104 | "target": None, 105 | "zoom_speed": 1.0, 106 | "pan_speed": 1.0, 107 | "rotate_speed": 1.0, 108 | "collapse": 1, 109 | "tools": True, 110 | "timeit": False, 111 | "js_debug": False, 112 | "normal_len": 0, 113 | } 114 | ZMQ_PORT = 5555 115 | 116 | 117 | def connect(context): 118 | endpoint = f"tcp://localhost:{ZMQ_PORT}" 119 | socket = context.socket(zmq.REQ) 120 | socket.connect(endpoint) 121 | return socket 122 | 123 | 124 | def send(data): 125 | context = zmq.Context() 126 | socket = connect(context) 
127 | 128 | msg = pickle.dumps(data, 4) 129 | print(" sending ... ", end="") 130 | socket.send(msg) 131 | 132 | retries_left = 3 133 | while True: 134 | if (socket.poll(REQUEST_TIMEOUT) & zmq.POLLIN) != 0: 135 | reply = socket.recv_json() 136 | 137 | if reply["result"] == "success": 138 | print("done") 139 | else: 140 | print("\n", reply["msg"]) 141 | break 142 | 143 | retries_left -= 1 144 | 145 | # Socket is confused. Close and remove it. 146 | socket.setsockopt(zmq.LINGER, 0) 147 | socket.close() 148 | if retries_left == 0: 149 | break 150 | 151 | print("Reconnecting to server…") 152 | # Create new connection 153 | socket = connect(context) 154 | 155 | print("Resending ...") 156 | socket.send(msg) 157 | 158 | 159 | def show( 160 | item: Union[Solids, Mesh], 161 | type: Literal["cad", "plot"] = "cad", 162 | intensity: Optional[jnp.ndarray] = None, 163 | name: Optional[str] = None, 164 | file_name: Optional[str] = None, 165 | **kwargs, 166 | ): 167 | mesh = Tesselator.get_mesh(item) if isinstance(item, Solids) else item 168 | 169 | if type == "cad": 170 | parts: list[Part] = [get_part(mesh)] 171 | 172 | shapes = Shapes( 173 | bb=get_bounding_box(mesh.vertices), 174 | parts=parts, 175 | ) 176 | 177 | states = {"/Group/Part_0": [1, 1]} 178 | data = { 179 | "data": dict(shapes=asdict(shapes), states=states), 180 | "type": "data", 181 | "config": DEFAULT_CONFIG, 182 | "count": len(parts), 183 | } 184 | send(data) 185 | 186 | else: 187 | import plotly.graph_objects as go 188 | 189 | Xe = [] 190 | Ye = [] 191 | Ze = [] 192 | for T in mesh.vertices[mesh.simplices]: 193 | Xe.extend([T[k % 3][0] for k in range(4)] + [None]) 194 | Ye.extend([T[k % 3][1] for k in range(4)] + [None]) 195 | Ze.extend([T[k % 3][2] for k in range(4)] + [None]) 196 | 197 | # Create a mesh object 198 | if intensity is not None: 199 | intensity = VectorHelper.normalize(intensity) 200 | plot_mesh = go.Mesh3d( 201 | x=mesh.vertices[:, 0], 202 | y=mesh.vertices[:, 1], 203 | z=mesh.vertices[:, 2], 204 
| i=[face[0] for face in mesh.simplices], 205 | j=[face[1] for face in mesh.simplices], 206 | k=[face[2] for face in mesh.simplices], 207 | intensity=intensity, 208 | ) 209 | 210 | lines = go.Scatter3d( 211 | x=Xe, 212 | y=Ye, 213 | z=Ze, 214 | mode="lines", 215 | line=dict(color="rgb(70,70,70)", width=1), 216 | ) 217 | 218 | # Create a figure and add the mesh to it 219 | fig = go.Figure(data=[plot_mesh, lines]) 220 | 221 | # Update layout to remove axes and background and make the graph larger 222 | fig.update_layout( 223 | scene=dict( 224 | xaxis=dict(visible=False), 225 | yaxis=dict(visible=False), 226 | zaxis=dict(visible=False), 227 | bgcolor="rgba(0,0,0,0)", 228 | ), 229 | width=800, # Set the width of the graph to 800 pixels 230 | height=600, # Set the height of the graph to 600 pixels 231 | title=name, 232 | ) 233 | fig.update_layout(scene=dict(bgcolor="rgba(0,0,0,0)")) 234 | if file_name is not None: 235 | fig.write_image(file_name, scale=6, format="png", engine="kaleido") 236 | # Display the figure 237 | fig.show() 238 | -------------------------------------------------------------------------------- /nebula/cases/airfoil.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from functools import cached_property 3 | from typing import Literal, Optional 4 | import plotly.graph_objects as go 5 | from nebula.evaluators.bspline import BSplineEvaluator, get_sampling 6 | import jax.numpy as jnp 7 | from nebula.tools.edge import EdgeTool 8 | 9 | 10 | def get_thickness_dist_ctrl_pnts( 11 | camber: jnp.ndarray, 12 | camber_normal: jnp.ndarray, 13 | thickness_dist: jnp.ndarray, 14 | thickness_sampling: jnp.ndarray, 15 | degree: jnp.ndarray, 16 | ): 17 | "get thickness distribution control points" 18 | camber_normal_thickness = BSplineEvaluator.eval_curve( 19 | degree, thickness_dist, thickness_sampling 20 | ) 21 | 22 | return jnp.concatenate( 23 | [ 24 | jnp.array([camber[0]]), 25 | 
camber + camber_normal * camber_normal_thickness, 26 | jnp.array([camber[-1]]), 27 | ] 28 | ) 29 | 30 | 31 | @dataclass 32 | class CamberThicknessAirfoil: 33 | "parametric airfoil using B-splines" 34 | 35 | inlet_angle: jnp.ndarray 36 | "inlet angle (rad)" 37 | 38 | outlet_angle: jnp.ndarray 39 | "outlet angle (rad)" 40 | 41 | upper_thick_prop: jnp.ndarray 42 | "upper thickness proportion to chord length (length)" 43 | 44 | lower_thick_prop: jnp.ndarray 45 | "lower thickness proportion to chord length (length)" 46 | 47 | leading_prop: jnp.ndarray 48 | "leading edge tangent line proportion [0.0-1.0] (dimensionless)" 49 | 50 | trailing_prop: jnp.ndarray 51 | "trailing edge tangent line proportion [0.0-1.0] (dimensionless)" 52 | 53 | chord_length: jnp.ndarray 54 | "chord length (length)" 55 | 56 | stagger_angle: Optional[jnp.ndarray] = None 57 | "stagger angle (rad)" 58 | 59 | num_samples: int = 50 60 | "number of samples" 61 | 62 | is_cosine_sampling: bool = True 63 | "use cosine sampling" 64 | 65 | leading_ctrl_pnt: jnp.ndarray = field( 66 | default_factory=lambda: jnp.array([0.0, 0.0, 0.0]) 67 | ) 68 | "leading control point (length)" 69 | 70 | angle_units: Literal["rad", "deg"] = "rad" 71 | "angle units" 72 | 73 | def __post_init__(self): 74 | if self.angle_units == "deg": 75 | self.inlet_angle = jnp.radians(self.inlet_angle) 76 | self.outlet_angle = jnp.radians(self.outlet_angle) 77 | 78 | if self.upper_thick_prop is not None: 79 | self.upper_thick_dist = [ 80 | self.chord_length * prop for prop in self.upper_thick_prop 81 | ] 82 | 83 | if self.lower_thick_prop is not None: 84 | self.lower_thick_dist = [ 85 | self.chord_length * prop for prop in self.lower_thick_prop 86 | ] 87 | 88 | if self.stagger_angle is None: 89 | self.stagger_angle = (self.inlet_angle + self.outlet_angle) / 2 90 | 91 | self.degree = jnp.array(3) 92 | self.num_thickness_dist_pnts = len(self.upper_thick_dist) + 4 93 | self.camber_knots = BSplineEvaluator.generate_clamped_knots(self.degree, 4) 
94 | 95 | self.thickness_dist_sampling = jnp.linspace( 96 | 0, 1, self.num_thickness_dist_pnts, endpoint=True 97 | ) 98 | self.sampling = get_sampling(0.0, 1.0, self.num_samples, self.is_cosine_sampling) 99 | self.axial_chord_length = self.chord_length * jnp.cos(self.stagger_angle) 100 | self.height = self.chord_length * jnp.sin(self.stagger_angle) 101 | 102 | @cached_property 103 | def camber_ctrl_pnts(self): 104 | assert self.stagger_angle is not None, "stagger angle is not defined" 105 | p_le = jnp.array(self.leading_ctrl_pnt) 106 | 107 | p_te = p_le + jnp.array( 108 | [ 109 | self.chord_length * jnp.cos(self.stagger_angle), 110 | self.chord_length * jnp.sin(self.stagger_angle), 111 | 0.0, 112 | ] 113 | ) 114 | 115 | # leading edge tangent control point 116 | p1 = p_le + self.leading_prop * jnp.array( 117 | [ 118 | self.chord_length * jnp.cos(self.inlet_angle), 119 | self.chord_length * jnp.sin(self.inlet_angle), 120 | 0.0, 121 | ] 122 | ) 123 | 124 | # trailing edge tangent control point 125 | p2 = p_te - self.trailing_prop * jnp.array( 126 | [ 127 | self.chord_length * jnp.cos(self.outlet_angle), 128 | self.chord_length * jnp.sin(self.outlet_angle), 129 | 0.0, 130 | ] 131 | ) 132 | 133 | return jnp.vstack((p_le, p1, p2, p_te)) 134 | 135 | @cached_property 136 | def top_ctrl_pnts(self): 137 | "upper side bspline" 138 | assert ( 139 | self.upper_thick_dist is not None 140 | ), "upper thickness distribution is not defined" 141 | thickness_dist = jnp.vstack(self.upper_thick_dist) 142 | return get_thickness_dist_ctrl_pnts( 143 | self.camber_coords, 144 | self.camber_normal_coords, 145 | thickness_dist, 146 | self.thickness_dist_sampling, 147 | self.degree, 148 | ) 149 | 150 | @cached_property 151 | def bottom_ctrl_pnts(self): 152 | "lower side bspline" 153 | assert ( 154 | self.lower_thick_dist is not None 155 | ), "lower thickness distribution is not defined" 156 | thickness_dist = -jnp.vstack(self.lower_thick_dist) 157 | return get_thickness_dist_ctrl_pnts( 158 | 
self.camber_coords, 159 | self.camber_normal_coords, 160 | thickness_dist, 161 | self.thickness_dist_sampling, 162 | self.degree, 163 | ) 164 | 165 | @cached_property 166 | def camber_coords(self): 167 | "camber line coordinates" 168 | return BSplineEvaluator.eval_curve( 169 | self.degree, self.camber_ctrl_pnts, self.thickness_dist_sampling 170 | ) 171 | 172 | @cached_property 173 | def camber_normal_coords(self): 174 | "camber normal line coordinates" 175 | dy = jnp.gradient(self.camber_coords[:, 1]) 176 | dx = jnp.gradient(self.camber_coords[:, 0]) 177 | normal = jnp.vstack([-dy, dx, jnp.zeros(len(self.camber_coords))]).T 178 | return normal / jnp.linalg.norm(normal, axis=1)[:, None] 179 | 180 | def get_coords(self): 181 | "airfoil coordinates" 182 | top_coords = BSplineEvaluator.eval_curve( 183 | self.degree, self.top_ctrl_pnts, self.sampling 184 | ) 185 | bottom_coords = BSplineEvaluator.eval_curve( 186 | self.degree, self.bottom_ctrl_pnts, self.sampling 187 | ) 188 | 189 | return jnp.concatenate([top_coords[1:-1], bottom_coords[::-1]]) 190 | 191 | def get_edges(self): 192 | airfoil_top_edge = EdgeTool.make_bspline_curve(self.top_ctrl_pnts, self.degree) 193 | airfoil_bottom_edge = EdgeTool.make_bspline_curve( 194 | self.bottom_ctrl_pnts, self.degree 195 | ) 196 | return airfoil_top_edge, airfoil_bottom_edge 197 | 198 | def visualize( 199 | self, 200 | include_camber=True, 201 | include_camber_ctrl_pnts=False, 202 | filename: Optional[str] = None, 203 | ): 204 | fig = go.Figure(layout=go.Layout(title=go.layout.Title(text="Airfoil"))) 205 | if include_camber_ctrl_pnts: 206 | fig.add_trace( 207 | go.Scatter( 208 | x=self.camber_ctrl_pnts[:, 0], 209 | y=self.camber_ctrl_pnts[:, 1], 210 | name=f"Camber Control Points", 211 | ) 212 | ) 213 | 214 | if include_camber: 215 | camber_coords = self.camber_coords 216 | fig.add_trace( 217 | go.Scatter(x=camber_coords[:, 0], y=camber_coords[:, 1], name=f"Camber") 218 | ) 219 | 220 | coords = self.get_coords() 221 | 
fig.add_trace(go.Scatter(x=coords[:, 0], y=coords[:, 1], name=f"Airfoil")) 222 | 223 | fig.layout.yaxis.scaleanchor = "x" # type: ignore 224 | if filename: 225 | fig.write_image(filename, width=500, height=500) 226 | else: 227 | fig.show() 228 | 229 | 230 | -------------------------------------------------------------------------------- /nebula/evaluators/bspline.py: -------------------------------------------------------------------------------- 1 | 2 | from nebula.helpers.types import Number 3 | import jax 4 | import jax.numpy as jnp 5 | from typing import Callable, Optional 6 | 7 | 8 | SpanFunction = Callable[[int, jnp.ndarray, int, jnp.ndarray], jnp.ndarray] 9 | 10 | 11 | def get_degree(degree: Number): 12 | degree = jnp.asarray(degree) 13 | return degree[0] if len(degree.shape) > 0 else degree 14 | 15 | def get_sampling(start: Number, end: Number, num_points: int, is_cosine_sampling: bool = False): 16 | if is_cosine_sampling: 17 | beta = jnp.linspace(0.0,jnp.pi, num_points, endpoint=True) 18 | return 0.5*(1.0-jnp.cos(beta)) 19 | return jnp.linspace(start, end, num_points, endpoint=True) 20 | 21 | 22 | class BSplineEvaluator: 23 | @staticmethod 24 | def find_spans( 25 | degree: jnp.ndarray, 26 | knot_vector: jnp.ndarray, 27 | num_ctrlpts: int, 28 | knot_samples: jnp.ndarray, 29 | ): 30 | """Finds the span of a single knot over the knot vector using linear search. 31 | 32 | Alternative implementation for the Algorithm A2.1 from The NURBS Book by Piegl & Tiller. 
33 | 34 | :param degree: degree, :math:`p` 35 | :type degree: jnp.ndarray, (1,) 36 | :param knot_vector: knot vector, :math:`U` 37 | :type knot_vector: torch.Tensor 38 | :param num_ctrlpts: number of control points, :math:`n + 1` 39 | :type num_ctrlpts: int 40 | :param knot: knot or parameter, :math:`u` 41 | :type knot: float 42 | :return: knot span 43 | :rtype: int 44 | """ 45 | span_start = degree + 1 46 | span_offset = jnp.sum( 47 | jnp.expand_dims(knot_samples, axis=-1) > knot_vector[span_start:], axis=-1 48 | ) 49 | span = jnp.clip(span_start + span_offset, a_max=num_ctrlpts) 50 | return span - 1 51 | 52 | @staticmethod 53 | def basis_functions( 54 | degree: Number, 55 | knot_vector: jnp.ndarray, 56 | span: jnp.ndarray, 57 | knot_samples: jnp.ndarray, 58 | ): 59 | """Computes the non-vanishing basis functions for a single parameter. 60 | 61 | Implementation of Algorithm A2.2 pg 70 from The NURBS Book by Piegl & Tiller. 62 | Uses recurrence to compute the basis functions, also known as Cox - de 63 | Boor recursion formula. 
64 | 65 | :param degree: degree, :math:`p` 66 | :type degree: jnp.ndarray (1,) 67 | :param knot_vector: knot vector, :math:`U` 68 | :type knot_vector: list, tuple 69 | :param span: knot span, :math:`i` 70 | :type span: int 71 | :param knot: knot or parameter, :math:`u` 72 | :type knot: float 73 | :return: basis functions 74 | :rtype: list 75 | """ 76 | N = jnp.ones((degree + 1, len(knot_samples))) 77 | left = jnp.expand_dims(knot_samples, axis=0) - knot_vector[span + 1 - jnp.arange(degree + 1)[:, None]] 78 | right = knot_vector[span + jnp.arange(degree + 1)[:, None]] - jnp.expand_dims(knot_samples, axis=0) 79 | 80 | def inner_body_fun(r, init_value): 81 | j, saved, new_N = init_value 82 | temp = new_N[r] / (right[r + 1] + left[j - r]) 83 | next_N = new_N.at[r].set(saved + right[r + 1] * temp) 84 | saved = left[j - r] * temp 85 | return j, saved, next_N 86 | 87 | def outer_body_fun(j, N: jnp.ndarray): 88 | saved = jnp.zeros(len(knot_samples)) 89 | _, saved, N = jax.lax.fori_loop(0, j, inner_body_fun, (j, saved, N)) 90 | return N.at[j].set(saved) 91 | 92 | return jax.lax.fori_loop(1, degree+1, outer_body_fun, N) 93 | 94 | 95 | @staticmethod 96 | def generate_line_knots(): 97 | return BSplineEvaluator.generate_clamped_knots(1, 2) 98 | 99 | @staticmethod 100 | def generate_clamped_knots(degree: Number, num_ctrlpts: Number): 101 | """Generates a clamped knot vector. 
102 | 103 | :param degree: non-zero degree of the curve 104 | :type degree: int 105 | :param num_ctrlpts: non-zero number of control points 106 | :type num_ctrlpts: int 107 | :return: clamped knot vector 108 | :rtype: Array 109 | """ 110 | # Number of repetitions at the start and end of the array 111 | num_repeat = degree 112 | # Number of knots in the middle 113 | num_segments = int(num_ctrlpts - (degree + 1)) 114 | 115 | return jnp.concatenate( 116 | ( 117 | jnp.zeros(num_repeat), 118 | jnp.linspace(0.0, 1.0, num_segments + 2), 119 | jnp.ones(num_repeat), 120 | ) 121 | ) 122 | 123 | @staticmethod 124 | def generate_unclamped_knots(degree: int, num_ctrlpts: int): 125 | """Generates a unclamped knot vector. 126 | 127 | :param degree: non-zero degree of the curve 128 | :type degree: int 129 | :param num_ctrlpts: non-zero number of control points 130 | :type num_ctrlpts: int 131 | :return: clamped knot vector 132 | :rtype: Array 133 | """ 134 | # Should conform the rule: m = n + p + 1 135 | return jnp.linspace(0.0, 1.0, degree + num_ctrlpts + 1) 136 | 137 | 138 | @staticmethod 139 | def eval_curve_pnt( 140 | degree: jnp.ndarray, ctrl_pnts: jnp.ndarray, basis: jnp.ndarray, span: jnp.ndarray 141 | ): 142 | dim = ctrl_pnts.shape[-1] 143 | if len(ctrl_pnts) < degree + 1: 144 | raise ValueError("Invalid size of control points for the given degree.") 145 | 146 | ctrl_pnt_slice = jax.lax.dynamic_slice( 147 | ctrl_pnts, (span - degree, 0), (1 + degree, dim) 148 | ) 149 | return jnp.sum(ctrl_pnt_slice * jnp.expand_dims(basis, axis=1), axis=0) 150 | 151 | @staticmethod 152 | def eval_curve( 153 | degree: jnp.ndarray, ctrl_pnts: jnp.ndarray, u: jnp.ndarray, knots: Optional[jnp.ndarray] = None 154 | ): 155 | degree = get_degree(degree) 156 | knots = ( 157 | BSplineEvaluator.generate_clamped_knots(degree, len(ctrl_pnts)) 158 | if knots is None 159 | else knots 160 | ) 161 | 162 | span = BSplineEvaluator.find_spans(degree, knots, len(ctrl_pnts), u) 163 | basis = 
BSplineEvaluator.basis_functions(degree, knots, span, u) 164 | 165 | return jax.vmap(BSplineEvaluator.eval_curve_pnt, in_axes=(None, None, 1, 0))( 166 | degree, ctrl_pnts, basis, span 167 | ) 168 | 169 | @staticmethod 170 | def eval_surface_pnt( 171 | u_degree: jnp.ndarray, 172 | v_degree: jnp.ndarray, 173 | ctrl_pnts: jnp.ndarray, 174 | basis_u: jnp.ndarray, 175 | basis_v: jnp.ndarray, 176 | span_u: jnp.ndarray, 177 | span_v: jnp.ndarray, 178 | ): 179 | ctrl_pnt_slice = jax.lax.dynamic_slice( 180 | ctrl_pnts, 181 | (span_u - u_degree, span_v - v_degree, 0), 182 | (1 + u_degree, 1 + v_degree, 3), 183 | ) 184 | 185 | eval_prev_pnt = jnp.sum(ctrl_pnt_slice * jnp.expand_dims(basis_v, axis=1), axis=1) 186 | return jnp.sum(eval_prev_pnt * jnp.expand_dims(basis_u, axis=1), axis=0) 187 | 188 | @staticmethod 189 | def eval_surface( 190 | u_degree: Number, 191 | v_degree: Number, 192 | ctrl_pnts: jnp.ndarray, 193 | u_knots: jnp.ndarray, 194 | v_knots: jnp.ndarray, 195 | u: jnp.ndarray, 196 | v: jnp.ndarray, 197 | ): 198 | u_degree = get_degree(u_degree) 199 | v_degree = get_degree(v_degree) 200 | 201 | assert ( 202 | ctrl_pnts.shape[0] >= u_degree + 1 203 | ), f"Number of curves should be at least {u_degree + 1}" 204 | assert ( 205 | ctrl_pnts.shape[1] >= v_degree + 1 206 | ), f"Number of control points should be at least {v_degree + 1}" 207 | 208 | span_u = BSplineEvaluator.find_spans( 209 | u_degree, u_knots, ctrl_pnts.shape[0], u 210 | ) 211 | basis_u = BSplineEvaluator.basis_functions(u_degree, u_knots, span_u, u) 212 | 213 | span_v = BSplineEvaluator.find_spans( 214 | v_degree, v_knots, ctrl_pnts.shape[1], v 215 | ) 216 | basis_v = BSplineEvaluator.basis_functions(v_degree, v_knots, span_v, v) 217 | 218 | return jax.vmap( 219 | jax.vmap( 220 | BSplineEvaluator.eval_surface_pnt, 221 | in_axes=(None, None, None, None, 1, None, 0), 222 | ), 223 | in_axes=(None, None, None, 1, None, 0, None), 224 | )(u_degree, v_degree, ctrl_pnts, basis_u, basis_v, span_u, span_v) 225 | 
# ==== /nebula/render/tesselation.py ====
from dataclasses import dataclass, field
from typing import Sequence, Union
import jax
import numpy as np
from scipy.spatial import Delaunay
import jax.numpy as jnp
from nebula.evaluators.bspline import BSplineEvaluator, get_sampling
from nebula.prim.bspline_curves import BSplineCurves
from nebula.helpers.wire import WireHelper
from nebula.topology.edges import DEFAULT_CURVE_SAMPLES
from nebula.topology.faces import Faces
from nebula.topology.solids import Solids
from nebula.topology.wires import Wires
from nebula.workplane import Workplane


@dataclass
class Mesh:
    """Triangle mesh plus the edge line segments drawn on top of it."""

    vertices: jnp.ndarray  # (num_vertices, 3)
    simplices: jnp.ndarray  # (num_triangles, 3) indices into `vertices`
    normals: jnp.ndarray  # (n, 3); see the producing method for per-vertex vs per-triangle
    edges: jnp.ndarray  # (num_segments, 2, 3)

    @staticmethod
    def empty():
        """Return a mesh with no vertices, triangles, normals or edges."""
        return Mesh(
            vertices=jnp.empty((0, 3)),
            simplices=jnp.empty((0, 3), dtype=jnp.int32),
            normals=jnp.empty((0, 3)),
            edges=jnp.empty((0, 2, 3)),
        )

    @staticmethod
    def combine(meshes: Sequence["Mesh"]):
        """Concatenate meshes into one, re-indexing their simplices.

        Each mesh's simplex indices are shifted by the total number of vertices
        of all meshes before it, so they keep pointing at their own vertices.

        :param meshes: meshes to combine, in order
        :return: the combined mesh (an empty mesh for an empty sequence)
        """
        if len(meshes) == 0:
            return Mesh.empty()

        shifted_simplices = []
        vertex_offset = 0
        for mesh in meshes:
            shifted_simplices.append(mesh.simplices + vertex_offset)
            vertex_offset += len(mesh.vertices)

        vertices = jnp.concatenate([mesh.vertices for mesh in meshes])
        simplices = jnp.concatenate(shifted_simplices)
        normals = jnp.concatenate([mesh.normals for mesh in meshes])
        edges = jnp.concatenate([mesh.edges for mesh in meshes])
        return Mesh(vertices, simplices, normals, edges)


class Tesselator:
    @staticmethod
    def get_mesh(solids: Solids):
        """Tessellate solids: Delaunay for planar faces, grids for non-planar faces."""
        faces = solids.shells.faces
        planar_mesh = Tesselator.get_planar_mesh(faces.mask(faces.wires.is_planar))
        non_planar_mesh = Tesselator.get_non_planar_mesh(
            faces.mask(~faces.wires.is_planar)
        )
        return Mesh.combine([planar_mesh, non_planar_mesh])

    @staticmethod
    def get_differentiable_mesh(target: Union[Workplane, Solids]):
        """Tessellate using only JAX-evaluated B-spline grids.

        Planar faces are routed through :meth:`get_differentiable_planar_mesh`
        instead of the SciPy Delaunay path.
        """
        solids = target.base_solids if isinstance(target, Workplane) else target
        faces = solids.shells.faces
        planar_mesh = Tesselator.get_differentiable_planar_mesh(
            faces.mask(faces.wires.is_planar)
        )
        non_planar_mesh = Tesselator.get_non_planar_mesh(
            faces.mask(~faces.wires.is_planar)
        )
        return Mesh.combine([planar_mesh, non_planar_mesh])

    @staticmethod
    def get_trimmed_simplices(
        vertices: jnp.ndarray,
        simplices: jnp.ndarray,
        exterior_segments: jnp.ndarray,
        interior_segments: jnp.ndarray,
    ):
        """Keep triangles whose centroid is inside the exterior wire and not inside an interior wire."""
        triangles = vertices[simplices]
        triangle_centers = jnp.mean(triangles, axis=1)
        is_exterior_triangle = jax.vmap(WireHelper.contains_vertex, in_axes=(None, 0))(
            exterior_segments, triangle_centers
        )
        is_interior_triangle = jax.vmap(WireHelper.contains_vertex, in_axes=(None, 0))(
            interior_segments, triangle_centers
        )
        return simplices[is_exterior_triangle & ~is_interior_triangle]

    @staticmethod
    def get_tri_normal(tri_vertices: Union[jnp.ndarray, np.ndarray]) -> jnp.ndarray:
        """Unit normal of one triangle, following its vertex winding order."""
        # cross product of the two edges leaving vertex 0
        normal = jnp.cross(
            tri_vertices[1] - tri_vertices[0], tri_vertices[2] - tri_vertices[0]
        )
        unit_normal = normal / jnp.linalg.norm(normal, axis=0, keepdims=True)
        return unit_normal

    @staticmethod
    def get_differentiable_planar_mesh(faces: Faces, curve_samples: int = 50):
        """Mesh quad-like planar faces by converting them to B-spline surfaces.

        :param curve_samples: samples per edge curve for the mesh's edge lines
        """
        is_quad_face = ((faces.get_num_edges() - faces.get_num_curves()) == 4).repeat(4)
        quad_faces = faces.mask(is_quad_face)
        bspline_faces = Faces.from_wires(
            Wires.convert_to_bspline(quad_faces.wires), is_exterior=True
        )
        # pass by keyword: positionally this value would land in `num_u`
        bspline_face_mesh = Tesselator.get_non_planar_mesh(
            bspline_faces, curve_samples=curve_samples
        )
        return bspline_face_mesh

    @staticmethod
    def get_planar_mesh(faces: Faces, curve_samples: int = 50):
        """Triangulate planar faces with a Delaunay triangulation in each face's plane.

        Triangles outside the exterior wire or inside interior wires are
        trimmed. Every vertex gets its face's plane normal.
        """
        vertices = jnp.empty((0, 3))
        simplices = jnp.empty((0, 3), dtype=jnp.int32)
        normals = jnp.empty((0, 3))

        # TODO: replace this with jax implementation of Delaunay triangulation
        for face_index in range(faces.count):
            face = faces.mask(faces.index == face_index)
            assert face.wires.is_planar.all(), "Only planar faces supported"

            axis = face.get_axes()
            axis_normal = axis.normals[0]

            # triangulate face in the plane's local 2D coordinates
            segments = face.wires.edges.evaluate(curve_samples, is_cosine_sampling=True)
            local_segments = axis.to_local_coords(segments)[0]

            delaunay = Delaunay(local_segments[:, 0, 0:2])

            tri_local_vertices = jnp.array(delaunay.points)
            tri_simplices = jnp.array(delaunay.simplices)
            tri_vertices = axis.to_world_coords(tri_local_vertices)[0]

            # Flip the winding when the first triangle faces away from the plane
            # normal; a sign test tolerates float noise, unlike exact equality.
            tri_normal = Tesselator.get_tri_normal(tri_vertices[tri_simplices[0]])
            if jnp.dot(tri_normal, axis_normal) < 0:
                tri_simplices = jnp.flip(tri_simplices, axis=1)

            # Trim triangles outside the exterior wire or inside interior wires
            is_interior = face.wires.get_segment_is_interior(curve_samples)
            trimmed_simplices = Tesselator.get_trimmed_simplices(
                tri_local_vertices,
                tri_simplices,
                local_segments[~is_interior],
                local_segments[is_interior],
            )
            # one plane normal per vertex added for this face
            new_normals = axis.normals.repeat(len(tri_vertices), axis=0)

            simplices = jnp.concatenate([simplices, trimmed_simplices + len(vertices)])
            vertices = jnp.concatenate([vertices, tri_vertices])
            normals = jnp.concatenate([normals, new_normals])

        return Mesh(
            vertices,
            simplices,
            normals,
            faces.edges.evaluate(curve_samples),
        )

    @staticmethod
    @jax.jit
    def get_bspline_tri_index(i: jnp.ndarray, j: jnp.ndarray, num_v: jnp.ndarray):
        """Split grid cell (i, j) of a row-major ``num_v``-wide grid into two triangles.

        :returns: int32 array of shape ``(2, 3)``
        """
        # cell corners, in order around the quad
        index = jnp.array(
            [
                j + (i * num_v),
                j + ((i + 1) * num_v),
                j + 1 + ((i + 1) * num_v),
                j + 1 + (i * num_v),
            ],
            dtype=jnp.int32,
        )

        # fan triangulation from corner 0: (0, 1, 2) and (0, 2, 3)
        return jax.vmap(
            lambda tri_index, i: jnp.array(
                [tri_index[0], tri_index[i], tri_index[i + 1]],
                dtype=jnp.int32,
            ),
            in_axes=(None, 0),
        )(index, jnp.arange(1, 3))

    @staticmethod
    def get_non_planar_mesh(
        faces: Faces,
        num_u: int = 20,
        num_v: int = 20,
        curve_samples: int = DEFAULT_CURVE_SAMPLES,
    ):
        """Mesh B-spline faces by sampling each surface on a ``num_u`` x ``num_v`` grid.

        :param curve_samples: samples per edge curve for the mesh's edge lines
        :returns: combined mesh; its normals are per triangle, not per vertex
        """
        meshes = []
        if len(faces.wires.is_planar) > 0 and faces.count > 0:
            assert not faces.wires.is_planar.all(), "Only non-planar faces supported"

            # Sampling and grid connectivity are the same for every face.
            u = get_sampling(0.0, 1.0, num_u)
            v = get_sampling(0.0, 1.0, num_v)
            simplices = jax.vmap(
                jax.vmap(
                    Tesselator.get_bspline_tri_index,
                    in_axes=(None, 0, None),
                ),
                in_axes=(0, None, None),
            )(jnp.arange(num_u - 1), jnp.arange(num_v - 1), jnp.array(num_v)).reshape(-1, 3)
            u_curve = get_sampling(0.0, 1.0, curve_samples)

            # TODO: get rid of for loop here
            for face_index in range(faces.count):
                face = faces.mask(faces.index == face_index)
                vertices = face.bspline_surfaces.evaluate(u, v)
                # this face's own edges only (all faces' edges were added on every iteration before)
                line_segments = face.wires.edges.curves.evaluate(u_curve)
                # TODO: fix these normals from Bspline derivatives later
                normals = jax.vmap(Tesselator.get_tri_normal, in_axes=(0,))(
                    vertices[simplices]
                )
                meshes.append(Mesh(vertices, simplices, normals, line_segments))

        return Mesh.combine([Mesh.empty(), *meshes])
# -------------------------------------------------------------------------------- /nebula/workplane.py:
-------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | from typing import Optional, Union 3 | from nebula.helpers.types import ArrayLike, CoordLike, Number 4 | from nebula.helpers.wire import WireHelper 5 | from nebula.tools.solid import SolidTool 6 | from nebula.tools.edge import EdgeTool 7 | from nebula.prim.axes import Axes, AxesString 8 | from nebula.topology.edges import Edges 9 | from nebula.topology.solids import Solids 10 | from nebula.topology.wires import Wires 11 | import jax_dataclasses as jdc 12 | 13 | 14 | @jdc.pytree_dataclass 15 | class Workplane: 16 | axes: Axes 17 | "The axes of the workplane" 18 | 19 | base_solids: Solids = jdc.field(default_factory=Solids.empty) 20 | "The base solids of the workplane" 21 | 22 | pending_edges: Edges = jdc.field(default_factory=Edges.empty) 23 | "The pending edges of the workplane" 24 | 25 | pending_wires: Wires = jdc.field(default_factory=Wires.empty) 26 | "The pending wires of the workplane" 27 | 28 | trace_vertex: jnp.ndarray = jdc.field( 29 | default_factory=lambda: jnp.array([0, 0, 0], dtype=jnp.float32) 30 | ) 31 | "The trace vertex of the workplane" 32 | 33 | @staticmethod 34 | def init( 35 | axes: Union[AxesString, Axes] = "XY", 36 | ): 37 | "Initialize the workplane" 38 | axes = axes if isinstance(axes, Axes) else Axes.from_str(axes) 39 | workplane = Workplane(axes) 40 | return workplane 41 | 42 | def clone( 43 | self, 44 | axes: Optional[Axes] = None, 45 | base_solids: Optional[Solids] = None, 46 | pending_edges: Optional[Edges] = None, 47 | pending_wires: Optional[Wires] = None, 48 | trace_vertex: Optional[jnp.ndarray] = None, 49 | ): 50 | "Clone the workplane" 51 | return Workplane( 52 | axes if axes is not None else self.axes, 53 | base_solids if base_solids is not None else self.base_solids, 54 | pending_edges if pending_edges is not None else self.pending_edges, 55 | pending_wires if pending_wires is not None else self.pending_wires, 56 | 
trace_vertex if trace_vertex is not None else self.trace_vertex, 57 | ) 58 | 59 | def plot(self): 60 | self.base_solids.shells.faces.wires.plot() 61 | 62 | @property 63 | def axis_locked(self): 64 | return self.pending_wires.count > 0 65 | 66 | def consolidateWires(self, validate: bool = False): 67 | """Consolidate pending wires into a single wire 68 | 69 | :param validate: Validate the edges for order 70 | 71 | """ 72 | pending_wires = EdgeTool.consolidate_wires( 73 | self.pending_wires, self.pending_edges, validate 74 | ) 75 | # pending_wires = Wires.empty() 76 | pending_edges = Edges.empty() 77 | trace_vertex = jnp.array([0, 0, 0]) 78 | return self.clone( 79 | pending_edges=pending_edges, 80 | pending_wires=pending_wires, 81 | trace_vertex=trace_vertex, 82 | ) 83 | 84 | def moveTo(self, x: Number = 0.0, y: Number = 0.0): 85 | "Move to the specified point, without drawing" 86 | trace_vertex = jnp.array([x, y, 0]) 87 | return self.clone(trace_vertex=trace_vertex) 88 | 89 | def move(self, xDist: Number = 0.0, yDist: Number = 0.0): 90 | "Move the specified distance from the current point, without drawing" 91 | trace_vertex = self.trace_vertex + jnp.array([xDist, yDist, 0]) 92 | return self.clone(trace_vertex=trace_vertex) 93 | 94 | def polarLineTo(self, distance: Number, angle: Number, forConstruction: bool = False): 95 | """ 96 | Make a line from the current point to the given polar coordinates 97 | 98 | Useful if it is more convenient to specify the end location rather than 99 | the distance and angle from the current point 100 | 101 | :param distance: distance of the end of the line from the origin 102 | :param angle: angle of the vector to the end of the line with the x-axis 103 | :return: the Workplane object with the current point at the end of the new line 104 | """ 105 | edge = EdgeTool.make_polar_line(self.trace_vertex, distance, angle) 106 | 107 | new_workplane = self.moveTo( 108 | edge.lines.vertices[-1][0], edge.lines.vertices[-1][1] 109 | ) 110 | if not 
forConstruction: 111 | new_pending_edges = self.pending_edges.add(edge) 112 | new_workplane = new_workplane.clone(pending_edges=new_pending_edges) 113 | return new_workplane 114 | 115 | def box(self, length: Number, width: Number, height: Number, centered=True): 116 | "Make a box for each item on the stack" 117 | return self.rect(length, width).extrude(height) 118 | 119 | def lineTo(self, x: Number, y: Number, forConstruction=False): 120 | "Make a line from the current point to the provided point" 121 | edge = EdgeTool.make_line(self.trace_vertex, jnp.array([x, y, 0])) 122 | 123 | new_workplane = self.moveTo(x, y) 124 | if not forConstruction: 125 | new_pending_edges = self.pending_edges.add(edge) 126 | new_workplane = new_workplane.clone(pending_edges=new_pending_edges) 127 | return new_workplane 128 | 129 | def rect(self, xLen: Number, yLen: Number, centered=True, forConstruction=False): 130 | "Make a rectangle for each item on the stack" 131 | edges = EdgeTool.make_rect(xLen, yLen, self.trace_vertex, centered=centered) 132 | if not forConstruction: 133 | new_pending_edges = self.pending_edges.add(edges) 134 | new_workplane = self.clone(pending_edges=new_pending_edges).consolidateWires( 135 | validate=False 136 | ) 137 | 138 | return new_workplane 139 | 140 | # def show(self): 141 | 142 | def polyline( 143 | self, 144 | vertices: CoordLike, 145 | forConstruction: bool = False, 146 | includeCurrent: bool = False, 147 | ): 148 | "Create a polyline from a list of points" 149 | 150 | vertices = WireHelper.to_3d_vertices(vertices) 151 | 152 | if includeCurrent: 153 | vertices = jnp.concatenate( 154 | [jnp.expand_dims(self.trace_vertex, axis=0), vertices] 155 | ) 156 | 157 | edges = EdgeTool.make_polyline(vertices) 158 | new_workplane = self.moveTo(vertices[-1][0], vertices[-1][1]) 159 | if not forConstruction: 160 | new_pending_edges = self.pending_edges.add(edges) 161 | new_workplane = new_workplane.clone(pending_edges=new_pending_edges) 162 | return new_workplane 
163 | 164 | def bspline( 165 | self, 166 | ctrl_pnts: CoordLike, 167 | degree: Number = 3, 168 | knots: Optional[ArrayLike] = None, 169 | includeCurrent: bool = False, 170 | forConstruction: bool = False, 171 | ): 172 | ctrl_pnts = WireHelper.to_3d_vertices(ctrl_pnts) 173 | if includeCurrent: 174 | ctrl_pnts = jnp.concatenate( 175 | [jnp.expand_dims(self.trace_vertex, axis=0), ctrl_pnts] 176 | ) 177 | 178 | curve = EdgeTool.make_bspline_curve(ctrl_pnts, jnp.array(degree), knots) 179 | edges = Edges.from_bspline_curves(curve) 180 | new_workplane = self.moveTo(ctrl_pnts[-1][0], ctrl_pnts[-1][1]) 181 | if not forConstruction: 182 | new_pending_edges = self.pending_edges.add(edges) 183 | new_workplane = new_workplane.clone(pending_edges=new_pending_edges) 184 | return new_workplane 185 | 186 | def close(self, validate: bool = True): 187 | """End construction, and attempt to build a closed wire 188 | 189 | If the wire is already closed, nothing happens 190 | 191 | :param validate: Validate the edges for order 192 | """ 193 | assert self.pending_edges.count > 0, "No segments to close" 194 | 195 | start_point = self.pending_edges.vertices[0] 196 | is_closed = jnp.allclose(self.trace_vertex, start_point, atol=1e-3) 197 | assert ~is_closed, "Wire already closed" 198 | 199 | return self.lineTo(start_point[0], start_point[1]).consolidateWires(validate) 200 | 201 | def extrude( 202 | self, 203 | until: Number, 204 | ): 205 | "Use all un-extruded wires in the parent chain to create a prismatic solid" 206 | projected_wires = self.pending_wires.project(self.axes) 207 | extruded_solids = SolidTool.extrude(projected_wires, until) 208 | base_solids = self.base_solids.add(extruded_solids) 209 | return Workplane(self.axes, base_solids) 210 | 211 | # TODO: implement this later 212 | # def sweep(self, path: Union[Wires, Edges, "Workplane"]): 213 | # "Use all un-extruded wires in the parent chain to create a swept solid" 214 | # if isinstance(path, Workplane): 215 | # path = 
path.pending_wires 216 | # elif isinstance(path, Edges): 217 | # path = Wires.from_edges(path) 218 | 219 | # projected_wires = self.pending_wires.project(self.axes) 220 | # swept_solids = SolidTool.sweep(projected_wires, path) 221 | # self.base_solids = self.base_solids.add(swept_solids) 222 | # self.reset() 223 | 224 | # return self 225 | -------------------------------------------------------------------------------- /nebula/helpers/clipper.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | import jax_dataclasses as jdc 3 | import jax 4 | import jax.numpy as jnp 5 | import numpy as np 6 | from nebula.helpers.intersection import Intersection 7 | from nebula.helpers.wire import WireHelper 8 | from nebula.topology.edges import Edges 9 | from nebula.topology.wires import Wires 10 | 11 | 12 | @jdc.pytree_dataclass 13 | class PolygonContainsResult: 14 | subject_in_clipper: jnp.ndarray 15 | clipper_in_subject: jnp.ndarray 16 | 17 | 18 | @jdc.pytree_dataclass 19 | class PolygonSplitResult: 20 | split_subject_segments: jnp.ndarray 21 | split_clipper_segments: jnp.ndarray 22 | subject_intersected: jnp.ndarray 23 | clipper_intersected: jnp.ndarray 24 | 25 | 26 | @jdc.pytree_dataclass 27 | class WireSplitResult: 28 | subject_segments: jnp.ndarray 29 | clipper_segments: jnp.ndarray 30 | unsplit_wires: Wires 31 | 32 | 33 | @jdc.pytree_dataclass 34 | class PolygonClipResult: 35 | clip_edges: Edges 36 | unclipped_wires: Wires 37 | 38 | 39 | class Clipper: 40 | @staticmethod 41 | def cut_polygons( 42 | subject_wires: Wires, clipper: Union[Wires, Edges], sort_edges: bool = True 43 | ): 44 | cutter_union_result = Clipper.union( 45 | subject_wires.mask(subject_wires.is_interior), 46 | clipper, 47 | ) 48 | 49 | intersection_result = Clipper.intersect( 50 | subject_wires.mask(~subject_wires.is_interior), 51 | cutter_union_result.clip_edges, 52 | ) 53 | clip_edges = intersection_result.clip_edges 54 | if sort_edges: 55 
| clip_edges = clip_edges.get_sorted().edges 56 | 57 | intersection_wires = Wires.from_edges(clip_edges) 58 | return Wires.combine( 59 | [ 60 | intersection_wires, 61 | cutter_union_result.unclipped_wires, 62 | intersection_result.unclipped_wires, 63 | ] 64 | ) 65 | 66 | @staticmethod 67 | def intersect(subject_wires: Wires, clipper: Union[Wires, Edges]): 68 | # Split subject and clipper edges at intersection points 69 | split_result = Clipper.split_wires(subject_wires, clipper) 70 | 71 | if subject_wires.edges.curves.count: 72 | raise NotImplementedError("Intersecting curves is not supported yet.") 73 | 74 | new_subject_segments = ( 75 | split_result.subject_segments 76 | if len(split_result.subject_segments) 77 | else subject_wires.edges.lines.get_segments() 78 | ) 79 | 80 | # Get which subject and clipper edges are in the other polygon 81 | polygon_contains = Clipper.contains_polygons( 82 | new_subject_segments, split_result.clipper_segments 83 | ) 84 | # TODO: check for orientation of clipper 85 | filtered_subject_segments, filtered_clipper_segments = ( 86 | new_subject_segments[~polygon_contains.subject_in_clipper], 87 | split_result.clipper_segments[polygon_contains.clipper_in_subject], 88 | ) 89 | 90 | 91 | if len(split_result.subject_segments): 92 | filtered_clipper_segments = jnp.flip(filtered_clipper_segments, (0,1)) 93 | clip_edges = Edges.from_line_segments( 94 | jnp.concatenate( 95 | [ 96 | filtered_subject_segments, 97 | filtered_clipper_segments, 98 | ] 99 | ), 100 | ) 101 | return PolygonClipResult( 102 | clip_edges=clip_edges, 103 | unclipped_wires=split_result.unsplit_wires, 104 | ) 105 | 106 | return PolygonClipResult( 107 | clip_edges=Edges.from_line_segments(filtered_subject_segments), 108 | unclipped_wires=Wires.from_line_segments( 109 | filtered_clipper_segments, is_interior=True 110 | ), 111 | ) 112 | 113 | @staticmethod 114 | def union(subject_wires: Wires, clipper: Union[Wires, Edges]): 115 | # Split subject and clipper edges at 
intersection points 116 | split_result = Clipper.split_wires(subject_wires, clipper) 117 | 118 | # Get which subject and clipper edges are in the other polygon 119 | polygon_contains = Clipper.contains_polygons( 120 | split_result.subject_segments, split_result.clipper_segments 121 | ) 122 | 123 | clip_edges = Edges.from_line_segments( 124 | jnp.concatenate( 125 | [ 126 | split_result.subject_segments[~polygon_contains.subject_in_clipper], 127 | split_result.clipper_segments[~polygon_contains.clipper_in_subject], 128 | ] 129 | ), 130 | ) 131 | return PolygonClipResult( 132 | clip_edges=clip_edges, 133 | unclipped_wires=split_result.unsplit_wires, 134 | ) 135 | 136 | # TODO: make segments have same count for jit 137 | @staticmethod 138 | @jax.jit 139 | def contains_polygons( 140 | subject_segments: jnp.ndarray, 141 | clipper_segments: jnp.ndarray, 142 | ): 143 | subject_in_clipper = jax.vmap(WireHelper.contains_segment, in_axes=(None, 0))( 144 | clipper_segments, subject_segments 145 | ) 146 | 147 | clipper_in_subject = jax.vmap(WireHelper.contains_segment, in_axes=(None, 0))( 148 | subject_segments, clipper_segments 149 | ) 150 | 151 | return PolygonContainsResult( 152 | subject_in_clipper=subject_in_clipper, clipper_in_subject=clipper_in_subject 153 | ) 154 | 155 | @staticmethod 156 | def split_polygon(subject_segments: jnp.ndarray, clipper_segments: jnp.ndarray): 157 | intersections = jax.vmap( 158 | jax.vmap(Intersection.line_segment_intersection, in_axes=(None, 0)), 159 | in_axes=(0, None), 160 | )(subject_segments, clipper_segments) 161 | # find which subject and clipper edges intersect 162 | subject_intersected = jnp.any(intersections.is_intersection, axis=(1,)) 163 | clipper_intersected = jnp.any(intersections.is_intersection, axis=(0,)) 164 | 165 | # split subject edges at intersection points 166 | subject_split_segments = Clipper.split_segments( 167 | subject_segments[subject_intersected], 168 | 
intersections.intersected_vertices[intersections.is_intersection], 169 | ) 170 | 171 | # Split clipper edges at intersection points 172 | clipper_split_segments = Clipper.split_segments( 173 | clipper_segments[clipper_intersected], 174 | jnp.swapaxes(intersections.intersected_vertices, 0, 1)[ 175 | jnp.swapaxes(intersections.is_intersection, 0, 1) 176 | ], 177 | ) 178 | 179 | return PolygonSplitResult( 180 | split_subject_segments=subject_split_segments, 181 | split_clipper_segments=clipper_split_segments, 182 | subject_intersected=subject_intersected, 183 | clipper_intersected=clipper_intersected, 184 | ) 185 | 186 | @staticmethod 187 | def split_wires(subject_wires: Wires, clipper: Union[Wires, Edges]): 188 | clipper_edges = clipper.edges if isinstance(clipper, Wires) else clipper 189 | if subject_wires.edges.curves.count or clipper_edges.curves.count: 190 | raise NotImplementedError("Intersecting curves is not supported yet.") 191 | subject_segments = subject_wires.edges.lines.get_segments() 192 | clipper_segments = clipper_edges.lines.get_segments() 193 | split_result = Clipper.split_polygon(subject_segments, clipper_segments) 194 | 195 | # Exclude all subject wires that are not affected from final subject segements 196 | is_split_wire = ( 197 | jnp.zeros(subject_wires.count) 198 | .at[subject_wires.index] 199 | .add(split_result.subject_intersected.astype(jnp.int32))[subject_wires.index] 200 | .astype(bool) 201 | ) 202 | unsplit_wires = subject_wires.mask(~is_split_wire) 203 | unsplit_affected_wire_segments = subject_segments[ 204 | ~split_result.subject_intersected & is_split_wire 205 | ] 206 | 207 | new_subject_segments = jnp.concatenate( 208 | [unsplit_affected_wire_segments, split_result.split_subject_segments] 209 | ) 210 | 211 | new_clipper_segments = jnp.concatenate( 212 | [ 213 | clipper_segments[~split_result.clipper_intersected], 214 | split_result.split_clipper_segments, 215 | ] 216 | ) 217 | 218 | return WireSplitResult( 219 | 
subject_segments=new_subject_segments, 220 | clipper_segments=new_clipper_segments, 221 | unsplit_wires=unsplit_wires, 222 | ) 223 | 224 | # TODO: make segments have same count for jit 225 | @staticmethod 226 | @jax.jit 227 | def split_segments(segments: jnp.ndarray, intersection_vertices: jnp.ndarray): 228 | fun_split_edges = lambda edge_vertices, intersection_point: jnp.array( 229 | [ 230 | [edge_vertices[0], intersection_point], 231 | [intersection_point, edge_vertices[1]], 232 | ] 233 | ) # (2, 3) 234 | 235 | # Split subject edges at intersection points 236 | segment_groups = jax.vmap(fun_split_edges, in_axes=(0, 0))( 237 | segments, 238 | intersection_vertices, 239 | ) 240 | return ( 241 | jnp.concatenate(segment_groups) 242 | if len(segment_groups) > 0 243 | else jnp.empty((0, 2, 3)) 244 | ) 245 | -------------------------------------------------------------------------------- /nebula/topology/edges.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional, Union 2 | import jax_dataclasses as jdc 3 | import jax.numpy as jnp 4 | import numpy as np 5 | from nebula.evaluators.bspline import get_sampling 6 | from nebula.prim import BSplineCurves 7 | from nebula.prim import Lines 8 | import numpy.typing as npt 9 | from nebula.helpers.wire import WireHelper 10 | from nebula.prim.axes import Axes 11 | from nebula.topology.topology import Topology 12 | 13 | DEFAULT_CURVE_SAMPLES = 20 14 | 15 | 16 | @jdc.pytree_dataclass 17 | class SortedEdgeResult: 18 | edges: "Edges" 19 | index: jdc.Static[npt.NDArray[np.int32]] 20 | segments: jnp.ndarray 21 | 22 | 23 | @jdc.pytree_dataclass 24 | class Edges(Topology): 25 | lines: Lines 26 | curves: BSplineCurves 27 | is_curve: jdc.Static[npt.NDArray[np.bool_]] 28 | 29 | def add( 30 | self, 31 | edges: Union["Edges", Lines, BSplineCurves, jnp.ndarray], 32 | reorder_index: bool = False, 33 | ): 34 | """ 35 | Add a new edge to the topology. 
36 | 37 | :param properties: Either new vertices of the edges or new Edges to add. 38 | :type a: jnp.ndarray 39 | 40 | :returns: The indices of the new edges. 41 | :rtype: jnp.ndarray 42 | """ 43 | # Add new unqiue vertices to the list of vertices 44 | if isinstance(edges, Edges): 45 | lines = self.lines.add(edges.lines) 46 | curves = self.curves.add(edges.curves) 47 | is_curve = np.concatenate([self.is_curve, edges.is_curve]) 48 | 49 | else: 50 | if isinstance(edges, BSplineCurves): 51 | curves = self.curves.add(edges) 52 | lines = self.lines 53 | else: 54 | if isinstance(edges, jnp.ndarray): 55 | edges = Lines.from_segments(edges) 56 | lines = self.lines.add(edges) 57 | curves = self.curves 58 | 59 | new_indices = np.arange(edges.count, dtype=np.int32) 60 | is_curve = np.concatenate( 61 | [self.is_curve, np.full(len(new_indices), isinstance(edges, BSplineCurves))] 62 | ) 63 | 64 | return Edges( 65 | lines, 66 | curves, 67 | is_curve, 68 | ) 69 | 70 | @staticmethod 71 | def empty(): 72 | return Edges( 73 | lines=Lines.empty(), 74 | curves=BSplineCurves.empty(), 75 | is_curve=np.empty((0,), dtype=bool), 76 | ) 77 | 78 | @staticmethod 79 | def from_line_segments(segments: jnp.ndarray): 80 | """Create edges from segments 81 | 82 | :param segments: segments to create edges from 83 | :type segments: jnp.ndarray 84 | :return: Edges 85 | :rtype: Edges 86 | """ 87 | lines = Lines.from_segments(segments) 88 | return Edges( 89 | lines=lines, 90 | curves=BSplineCurves.empty(), 91 | is_curve=np.zeros(lines.count, dtype=bool), 92 | ) 93 | 94 | @staticmethod 95 | def from_bspline_curves( 96 | bspline_curves: BSplineCurves, 97 | ): 98 | """Create edges from bspline curves 99 | 100 | :param bspline_curves: bspline curves 101 | :type bspline_curves: BSplineCurves 102 | :return: Edges 103 | :rtype: Edges 104 | """ 105 | new_edges = Edges.empty() 106 | return new_edges.add(bspline_curves) 107 | 108 | def clone( 109 | self, 110 | lines: Optional[Lines] = None, 111 | curves: 
Optional[BSplineCurves] = None, 112 | is_curve: Optional[np.ndarray] = None, 113 | ): 114 | return Edges( 115 | lines=lines if lines is not None else self.lines, 116 | curves=curves if curves is not None else self.curves, 117 | is_curve=is_curve if is_curve is not None else self.is_curve, 118 | ) 119 | 120 | @property 121 | def vertices(self): 122 | return jnp.concatenate( 123 | [self.lines.get_segments()[:, 0], self.curves.ctrl_pnts.val] 124 | ) 125 | 126 | # TODO: make this simpler and more efficient 127 | def get_sorted(self): 128 | segments = self.get_segments() 129 | sorted_segments = jnp.zeros((len(segments), 2, 3)) 130 | sorting_order = np.zeros(self.index.shape[0], dtype=int) 131 | 132 | order_index = np.arange(self.count) 133 | for i in range(0, len(segments)): 134 | if i == 0: 135 | sorted_segments = sorted_segments.at[i].set(segments[0]) 136 | sorting_order[i] = self.index[0] 137 | else: 138 | previous_segment_end = sorted_segments[i - 1][1] 139 | # check which start segments match the previous segment's end 140 | sort_mask = jnp.all( 141 | jnp.abs(segments[:, 0] - previous_segment_end) <= 1e-5, axis=1 142 | ) 143 | sorted_segments = sorted_segments.at[i].set(segments[sort_mask][0]) 144 | sorting_order[i] = order_index[sort_mask][0] 145 | 146 | sorted_is_curve = self.is_curve[sorting_order] 147 | sorted_index = self.index[sorting_order] 148 | 149 | new_edges = self.clone( 150 | lines=self.lines.reorder(sorted_index[~sorted_is_curve]), 151 | curves=self.curves.reorder(sorted_index[sorted_is_curve]), 152 | is_curve=sorted_is_curve, 153 | ) 154 | 155 | return SortedEdgeResult(new_edges, sorted_index, sorted_segments) 156 | # if return_index: 157 | # return self.clone(), self.index 158 | # return self.clone() 159 | 160 | # def get_segments(self): 161 | # line_segments = self.lines.get_segments() 162 | # curve_segments = self.curves.get_segments() 163 | 164 | # total_size = line_segments.shape[0] + curve_segments.shape[0] 165 | # padded_curve_segments = 
jnp.pad( 166 | # curve_segments, ((0, total_size - curve_segments.shape[0]), (0, 0), (0, 0)) 167 | # ) 168 | # padded_line_segments = jnp.pad( 169 | # line_segments, ((0, total_size - line_segments.shape[0]), (0, 0), (0, 0)) 170 | # ) 171 | 172 | # return jnp.where( 173 | # self.is_curve[:, None, None], padded_curve_segments, padded_line_segments 174 | # ) 175 | 176 | def get_segments(self): 177 | line_segments = self.lines.get_segments() 178 | curve_segments = self.curves.get_segments() 179 | 180 | return ( 181 | jnp.empty((self.count, 2, 3)) 182 | .at[~self.is_curve] 183 | .set(line_segments) 184 | .at[self.is_curve] 185 | .set(curve_segments) 186 | ) 187 | 188 | @property 189 | def index(self): 190 | """ 191 | Generates edge index for each component, use mask self.is_curve to mask off each index 192 | """ 193 | index = np.empty(self.count, dtype=np.int32) 194 | index[~self.is_curve] = np.arange(self.lines.count) 195 | index[self.is_curve] = np.arange(self.curves.count) 196 | return index 197 | 198 | @property 199 | def count(self): 200 | return len(self.is_curve) 201 | 202 | def evaluate( 203 | self, 204 | curve_samples: Optional[int] = DEFAULT_CURVE_SAMPLES, 205 | is_cosine_sampling: bool = False, 206 | ): 207 | """ 208 | Evaluate edges, not in order 209 | 210 | :param curve_samples: Number of samples to use for curves, 211 | if None then uses ctrl pnt start and end as segments, defaults to DEFAULT_CURVE_SAMPLES 212 | :type curve_samples: Optional[int], optional 213 | :return: Segments 214 | :rtype: jnp.ndarray 215 | """ 216 | u = ( 217 | get_sampling(0.0, 1.0, curve_samples + 1, is_cosine_sampling) 218 | if curve_samples 219 | else None 220 | ) 221 | curve_segments = self.curves.evaluate(u) 222 | line_segments = self.lines.get_segments() 223 | 224 | return jnp.concatenate([line_segments, curve_segments], axis=0) 225 | 226 | def dir_vec_at(self, u: jnp.ndarray, index: jnp.ndarray, eps=1e-5): 227 | u_next = jnp.where(u >= 1.0, u, u + eps) 228 | u = jnp.where(u 
>= 1.0, u_next - eps, u) 229 | 230 | start_vertices = self.evaluate_at(u, index) 231 | end_vertices = self.evaluate_at(u_next, index) 232 | 233 | return end_vertices - start_vertices 234 | 235 | def evaluate_at(self, u: jnp.ndarray, index: jnp.ndarray): 236 | """ 237 | Evaluate the edge at the given u values. 238 | 239 | :param u: The u values to evaluate the edge at. 240 | :type u: jnp.ndarray 241 | :param index: The index of the edge to evaluate. 242 | :type index: jnp.ndarray 243 | :return: The vertices of the edge. 244 | :rtype: jnp.ndarray 245 | """ 246 | if self.is_curve[index]: 247 | return self.curves.evaluate_at(u, self.index[index]) 248 | return self.lines.evaluate_at(u, self.index[index]) 249 | 250 | def mask(self, mask: Union[jnp.ndarray, np.ndarray]): 251 | """Get edges from mask 252 | 253 | :param mask: Mask for edges 254 | :type mask: Union[jnp.ndarray, np.ndarray] 255 | :return: Edges 256 | :rtype: Edges 257 | """ 258 | 259 | line_mask = mask[~self.is_curve] 260 | curve_mask = mask[self.is_curve] 261 | 262 | return self.clone( 263 | self.lines.mask(line_mask), 264 | self.curves.mask(curve_mask), 265 | self.is_curve[mask], 266 | ) 267 | 268 | def project(self, curr_axes: Axes, new_axes: Axes): 269 | new_lines = self.lines.project(curr_axes, new_axes) 270 | new_curves = self.curves.project(curr_axes, new_axes) 271 | is_curve = np.repeat(self.is_curve[None, :], new_axes.count, axis=0).flatten() 272 | return Edges(lines=new_lines, curves=new_curves, is_curve=is_curve) 273 | 274 | def translate(self, translation: jnp.ndarray): 275 | line_translation = ( 276 | translation if len(translation.shape) == 1 else translation[~self.is_curve] 277 | ) 278 | 279 | curve_translation = ( 280 | translation if len(translation.shape) == 1 else translation[self.is_curve] 281 | ) 282 | 283 | new_lines = self.lines.translate(line_translation) 284 | new_curves = self.curves.translate(curve_translation) 285 | return self.clone(new_lines, new_curves) 286 | 287 | def 
flip_winding(self, mask: Optional[jnp.ndarray] = None): 288 | line_mask = mask[~self.is_curve] if mask is not None else None 289 | curve_mask = mask[self.is_curve] if mask is not None else None 290 | flipped_edges = self.clone( 291 | self.lines.flip_winding(line_mask), 292 | self.curves.flip_winding(curve_mask), 293 | self.is_curve[::-1], 294 | ) 295 | 296 | return flipped_edges 297 | 298 | def __repr__(self) -> str: 299 | return f"Edges(count={self.count})" 300 | -------------------------------------------------------------------------------- /nebula/topology/wires.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Protocol, Sequence, Union 2 | import jax 3 | import jax_dataclasses as jdc 4 | import jax.numpy as jnp 5 | import numpy as np 6 | import numpy.typing as npt 7 | from nebula.helpers.types import Number 8 | from nebula.helpers.wire import WireHelper 9 | from nebula.prim.axes import Axes 10 | from nebula.topology.topology import SparseIndexable, Topology 11 | from nebula.topology.edges import DEFAULT_CURVE_SAMPLES, BSplineCurves, Edges 12 | 13 | 14 | class BoundingBox: 15 | def __init__(self, min: jnp.ndarray, max: jnp.ndarray) -> None: 16 | self.min = min 17 | self.max = max 18 | 19 | @staticmethod 20 | def from_points(points: jnp.ndarray): 21 | flattened_points = points.reshape(-1, 3) 22 | return BoundingBox( 23 | jnp.min(flattened_points, axis=0), jnp.max(flattened_points, axis=0) 24 | ) 25 | 26 | def contains(self, point: jnp.ndarray): 27 | return jnp.all(point >= self.min, axis=-1) & jnp.all(point <= self.max, axis=-1) 28 | 29 | def __repr__(self) -> str: 30 | return f"BoundingBox(min={self.min}, max={self.max})" 31 | 32 | 33 | class WireLike(Topology, Protocol): 34 | edges: Edges 35 | 36 | @classmethod 37 | def combine(cls, topologies: Sequence["Topology"], reorder_index: bool = True): 38 | """Combine topologies into single 39 | 40 | :param topologies: Topology to combine 41 | :type 
topologies: Topology 42 | :return: Combined topologies 43 | :rtype: Topology 44 | """ 45 | new_topology = cls.empty() # type: ignore 46 | for topology in topologies: 47 | new_topology = new_topology.add(topology, reorder_index) 48 | 49 | return new_topology 50 | 51 | def get_segment_index(self, curve_samples: Optional[int] = None): 52 | if curve_samples is None: 53 | curve_samples = 0 54 | # Get the wire segment index including repeats for curve segments 55 | return jnp.concatenate( 56 | [ 57 | self.index[~self.edges.is_curve], 58 | self.index[self.edges.is_curve].repeat(curve_samples, axis=0), 59 | ] 60 | ) 61 | 62 | def get_axis_normals(self) -> jnp.ndarray: 63 | """Compute the normals of the wire axes. Only works for planar wires.""" 64 | line_segments = self.edges.get_segments() 65 | # Get the first and last edges 66 | first_edges = ( 67 | jnp.zeros(shape=(self.count, 2, 3)) 68 | .at[self.index[::-1]] 69 | .set(line_segments[::-1]) 70 | ) 71 | last_edges = jnp.zeros(shape=(self.count, 2, 3)).at[self.index].set(line_segments) 72 | 73 | return jax.vmap(WireHelper.get_normal, in_axes=(0, 0))(first_edges, last_edges) 74 | 75 | def get_axes(self): 76 | """Compute the axis of the faces.""" 77 | return Axes(normals=self.get_axis_normals()) 78 | 79 | def get_lengths(self, curve_samples=DEFAULT_CURVE_SAMPLES): 80 | segments = self.edges.evaluate(curve_samples) 81 | segment_index = self.get_segment_index(curve_samples) 82 | 83 | start_vertices = segments[:, 0] 84 | end_vertices = segments[:, 1] 85 | 86 | lengths = jnp.linalg.norm(end_vertices - start_vertices) 87 | return jnp.zeros(self.count).at[segment_index].add(lengths) 88 | 89 | def get_edge_index(self, u: Number): 90 | if isinstance(u, (int, float)): 91 | u = jnp.array([u]) 92 | wire_lengths = self.get_lengths() 93 | length_ratios = wire_lengths / jnp.sum(wire_lengths) 94 | 95 | wire_ratios = jnp.cumsum(length_ratios) 96 | # TODO: check this for correctness 97 | # return jnp.arange(len(wire_ratios))[(wire_ratios <= 
u)][-1] 98 | return jnp.array(0) 99 | 100 | def dir_vec_at(self, u: Number): 101 | if isinstance(u, (int, float)): 102 | u = jnp.array([u]) 103 | 104 | edge_index = self.get_edge_index(u) 105 | # TODO: add this later 106 | # edge_u = (u - length_ratios[edge_index]) / length_ratios[edge_index] 107 | return self.edges.dir_vec_at(u, edge_index) 108 | 109 | def evaluate_at(self, u: Number): 110 | if isinstance(u, (int, float)): 111 | u = jnp.array([u]) 112 | 113 | edge_index = self.get_edge_index(u) 114 | 115 | # TODO: add this later 116 | # edge_u = (u - length_ratios[edge_index]) / length_ratios[edge_index] 117 | return self.edges.evaluate_at(u, edge_index) 118 | 119 | def get_centroids(self, curve_samples=DEFAULT_CURVE_SAMPLES): 120 | """Compute the centers of the faces.""" 121 | segments = self.edges.evaluate(curve_samples) 122 | segment_index = self.get_segment_index(curve_samples) 123 | num_segments = jnp.expand_dims( 124 | (jnp.zeros(self.count).at[segment_index].add(values=jnp.ones(len(segments)))), 125 | axis=1, 126 | ) 127 | 128 | start_vertices = segments[:, 0] 129 | 130 | # Sum vertices and divide by number of vertices 131 | vertex_sum = ( 132 | jnp.zeros(shape=(self.count, 3)).at[segment_index].add(start_vertices) 133 | ) # (num_wires, 3) 134 | centers = vertex_sum / num_segments 135 | return centers 136 | 137 | def get_num_edges(self): 138 | """Compute the number of edges in each wire.""" 139 | return ( 140 | jnp.zeros(shape=self.count) 141 | .at[self.index] 142 | .add(values=jnp.ones(shape=self.edges.count)) 143 | ) 144 | 145 | def get_num_curves(self): 146 | """Compute the number of edges in each wire.""" 147 | return ( 148 | jnp.zeros(shape=self.count) 149 | .at[self.index] 150 | .add(values=self.edges.is_curve.astype(jnp.int32)) 151 | ) 152 | 153 | def get_bounding_box(self): 154 | start_vertices = self.edges.evaluate()[:, 0] 155 | 156 | max_coords = ( 157 | jnp.zeros(shape=(self.count, 3)).at[self.index].max(start_vertices) 158 | ) # (num_wires, 3) 
159 | 160 | min_coords = ( 161 | jnp.zeros(shape=(self.count, 3)).at[self.index].min(start_vertices) 162 | ) # (num_wires, 3) 163 | 164 | return BoundingBox(min_coords, max_coords) 165 | 166 | 167 | @jdc.pytree_dataclass 168 | class Wires(WireLike): 169 | edges: Edges 170 | index: jdc.Static[npt.NDArray[np.int32]] 171 | is_interior: jdc.Static[npt.NDArray[np.bool_]] 172 | is_planar: jdc.Static[npt.NDArray[np.bool_]] 173 | 174 | def get_segment_is_interior(self, curve_samples: int): 175 | # Get the wire segment index including repeats for curve segments 176 | return jnp.concatenate( 177 | [ 178 | self.is_interior[~self.edges.is_curve], 179 | self.is_interior[self.edges.is_curve].repeat(curve_samples, axis=0), 180 | ] 181 | ) 182 | 183 | @property 184 | def is_single(self): 185 | return self.count == 1 186 | 187 | # TODO: make this more efficient 188 | @staticmethod 189 | def convert_to_bspline(wires: "Wires"): 190 | """ 191 | :param wires: Wires to convert to bspline if they are planar and quads 192 | :type wires: Wires 193 | :return: BSpline curves 194 | :rtype: BSplineCurves 195 | """ 196 | eval_wires = Wires.empty() 197 | for i in range(wires.count): 198 | if wires.is_planar[i]: 199 | curr_wire = wires.mask(wires.index == i) 200 | segments = curr_wire.edges.get_segments() 201 | assert segments.shape[0] == 4, "Only quads are supported" 202 | top_curve = BSplineCurves.from_line_segment(segments[0]) 203 | bottom_curve = BSplineCurves.from_line_segment(segments[2][::-1]) 204 | bspline_plane_wire = Wires.skin([top_curve, bottom_curve]) 205 | else: 206 | bspline_plane_wire = wires.mask(wires.index == i) 207 | eval_wires = eval_wires.add(bspline_plane_wire) 208 | return eval_wires 209 | 210 | def mask(self, mask: Union[jnp.ndarray, np.ndarray]): 211 | """Get wires from mask 212 | 213 | :param mask: Mask for wires 214 | :type mask: Union[jnp.ndarray, np.ndarray] 215 | :return: Masked wires 216 | :rtype: Wires 217 | """ 218 | wires = Wires( 219 | 
edges=self.edges.mask(mask), 220 | index=Topology.reorder_index(self.index[mask]), 221 | is_interior=self.is_interior[mask], 222 | is_planar=self.is_planar[mask], 223 | ) 224 | assert len(wires.index) == wires.edges.count, "index and node count mismatch" 225 | return wires 226 | 227 | @staticmethod 228 | def empty(): 229 | return Wires( 230 | edges=Edges.empty(), 231 | index=np.empty((0,), dtype=np.int32), 232 | is_interior=np.empty((0,), dtype=bool), 233 | is_planar=np.empty((0,), dtype=bool), 234 | ) 235 | 236 | @staticmethod 237 | def skin(sections: Sequence[BSplineCurves]): 238 | """Create Wires from bspline surfaces 239 | 240 | :param bspline_surfaces: bspline surfaces 241 | :type bspline_surfaces: BSplineSurfaces 242 | :return: Wires 243 | :rtype: Wires 244 | """ 245 | new_curves = BSplineCurves.empty() 246 | index = np.empty((0,), dtype=np.int32) 247 | num_section_pnts = sections[0].ctrl_pnts.count 248 | for section in sections: 249 | assert section.ctrl_pnts.count == num_section_pnts, "section size mismatch" 250 | index = np.concatenate([index, np.arange(section.count, dtype=np.int32)]) 251 | new_curves = new_curves.add(section) 252 | 253 | edges = Edges.from_bspline_curves(new_curves) 254 | num_edges = edges.count 255 | 256 | return Wires( 257 | edges=edges, 258 | index=index, 259 | is_interior=np.full(num_edges, False), 260 | is_planar=np.full(num_edges, False), 261 | ) 262 | 263 | @staticmethod 264 | def from_edges(edges: Edges): 265 | """Create Wires from edges 266 | 267 | :param edges: edges 268 | :type edges: Edges 269 | :return: Wires 270 | :rtype: Wires 271 | """ 272 | num_edges = edges.count 273 | return Wires( 274 | edges=edges, 275 | index=np.full(num_edges, 0, dtype=np.int32), 276 | is_interior=np.full(num_edges, False), 277 | is_planar=np.full(num_edges, True), 278 | ) 279 | 280 | @staticmethod 281 | def from_line_segments( 282 | segments: jnp.ndarray, 283 | is_interior: bool = False, 284 | ): 285 | """Create Wires from edge segments for each 
wire (num_wires, num_edges, 2, 3) 286 | 287 | :param edge_segments: edge segments for each wire (num_wires, num_edges, 2, 3) 288 | :type edge_segments: jnp.ndarray 289 | :return: Wires 290 | :rtype: Wires 291 | """ 292 | if len(segments.shape) == 3: 293 | segments = jnp.expand_dims(segments, axis=0) 294 | 295 | edges = Edges.from_line_segments(segments) 296 | 297 | index = ( 298 | # create new indices for each wire 299 | np.arange(0, len(segments)) 300 | # repeat for each edge 301 | .repeat(segments.shape[1], axis=0) 302 | ) 303 | 304 | return Wires( 305 | edges=edges, 306 | index=index, 307 | is_interior=np.full(edges.count, is_interior), 308 | is_planar=np.full(edges.count, True), 309 | ) 310 | 311 | def add(self, wires: "Wires", reorder_index: bool = False): 312 | is_interior = np.concatenate([self.is_interior, wires.is_interior]) 313 | is_planar = np.concatenate([self.is_planar, wires.is_planar]) 314 | edges = self.edges.add(wires.edges) 315 | index = self.add_indices(wires.index, reorder_index) 316 | assert len(index) == edges.count, "index and edges count mismatch" 317 | 318 | return Wires(edges, index, is_interior, is_planar) 319 | 320 | def clone( 321 | self, 322 | edges: Optional[Edges] = None, 323 | index: Optional[np.ndarray] = None, 324 | is_interior: Optional[np.ndarray] = None, 325 | is_planar: Optional[np.ndarray] = None, 326 | ): 327 | edges = edges if edges is not None else self.edges 328 | return Wires( 329 | edges=edges, 330 | index=index if index is not None else self.index, 331 | is_interior=np.full( 332 | edges.count, is_interior if is_interior is not None else False 333 | ), 334 | is_planar=np.full(edges.count, is_planar if is_planar is not None else True), 335 | ) 336 | 337 | def project(self, new_axes: Axes): 338 | """ 339 | Project the wires onto the given axes from XY plane. Assumes all wires are on XY plane. 
340 | :param axes: Axes to project on 341 | :type axes: Axes 342 | :return: Projected wires 343 | :rtype: Wires 344 | """ 345 | curr_axes = self.get_axes() 346 | new_edges = self.edges.project(curr_axes, new_axes) 347 | index = SparseIndexable.expanded_index(self.index, new_axes.count) 348 | is_interior = np.repeat( 349 | self.is_interior[None, :], new_axes.count, axis=0 350 | ).flatten() 351 | is_planar = np.repeat(self.is_planar[None, :], new_axes.count, axis=0).flatten() 352 | 353 | return Wires(new_edges, index, is_interior, is_planar) 354 | 355 | def translate(self, translation: jnp.ndarray): 356 | """Translate the wires by a given translation. 357 | 358 | :param translation: The translation to apply. 359 | :type translation: jnp.ndarray 360 | """ 361 | if len(translation.shape) == 1: 362 | new_edges = self.edges.translate(translation) 363 | elif len(translation.shape) == 2: 364 | new_edges = self.edges.translate(translation[self.index]) 365 | else: 366 | raise ValueError( 367 | f"Translation must be of shape (3,) or (num_wires, 3), got {translation.shape}" 368 | ) 369 | return self.clone(new_edges) 370 | 371 | def wind(self, is_clockwise: bool = False): 372 | normals = self.get_axis_normals() 373 | # checks if winding order is clockwise, if so it needs to be flipped to counter clockwise 374 | flip_mask = jnp.sum(normals, axis=1) < 0 375 | if is_clockwise: 376 | # if we want it to be clockwise, it is the inverse of the flip mask 377 | flip_mask = ~flip_mask 378 | new_flipped = self.flip_winding(flip_mask) 379 | return new_flipped 380 | 381 | def flip_winding(self, mask: Optional[jnp.ndarray] = None): 382 | """Returns the wires with fliped winding order. 
383 | 384 | :param edges: Wires to flip 385 | :type wires: Wires 386 | :return: Wire with flipped winding order 387 | :rtype: jnp.ndarray 388 | """ 389 | # assign to new Wires 390 | edge_flip_mask = mask[self.index] if mask is not None else None 391 | new_edges = self.edges.flip_winding(edge_flip_mask) 392 | new_wires = self.clone( 393 | new_edges, self.index[::-1], is_interior=self.is_interior[::-1] 394 | ) 395 | 396 | return new_wires 397 | 398 | @property 399 | def num_interior(self): 400 | return jnp.sum(jnp.any(self.is_interior)) 401 | 402 | def __repr__(self) -> str: 403 | return f"Wires(count={self.count}, edges={self.edges}, num_interior={self.num_interior})" 404 | 405 | def plot(self, is_3d=True): 406 | """Plotly plot the wires.""" 407 | 408 | import plotly.graph_objects as go 409 | 410 | x = [] 411 | y = [] 412 | z = [] 413 | segments = self.edges.evaluate() 414 | for segment in segments: 415 | x += [*segment[:, 0], None] 416 | y += [*segment[:, 1], None] 417 | z += [*segment[:, 2], None] 418 | if is_3d: 419 | fig = go.Figure( 420 | go.Scatter3d( 421 | x=x, 422 | y=y, 423 | z=z, 424 | mode="lines", 425 | line=dict(color="blue", width=2), 426 | ) 427 | ) 428 | fig.update_scenes(aspectmode="data") 429 | 430 | else: 431 | fig = go.Figure( 432 | go.Scatter( 433 | x=x, 434 | y=y, 435 | mode="lines", 436 | line=dict(color="blue", width=2), 437 | ) 438 | ) 439 | # equal scale 440 | fig.update_xaxes(scaleanchor="y", scaleratio=1) 441 | fig.update_yaxes(scaleanchor="x", scaleratio=1) 442 | fig.show() 443 | --------------------------------------------------------------------------------